Best Python code snippet using fMBT_python
fista_para.py
Source: fista_para.py
# -*- coding: utf-8 -*-
"""
Salience-driven patch selection and DPCN (FISTA) scoring helpers.

Created on Tue Mar 12 13:59:42 2019

@author: 61995
"""
#import pandas
#import tensorflow as tf
import foa_image as foai
import foa_convolution as foac
import foa_saliency as foas
import numpy as np
import scipy.io as scio
import cv2
# NOTE(review): labeling() references IICDataset, which nothing in this file
# imports while the next line stays commented out -- confirm its origin.
#from data import *
import torch
import time
#import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D


def dpcn_detect(input_img, A, B, C, d_x1, d_u1, batch_size):
    """Infer DPCN causes for a batch of grey-scale windows.

    Each image is cut into a 2 x 2 grid of overlapping 12 x 12 patches
    (window 12, stride 8). Every patch is flattened to a 144-vector, and
    all patches go to ``fista`` as the columns of one matrix.

    Args:
        input_img: tensor of shape (batch_size, H, W). DPCN_score passes
            (n, 20, 20) windows; the 2 x 2 grid assumes 20 <= H, W < 28.
        A, B, C: model matrices forwarded to ``fista``.
        d_x1: state dimension.
        d_u1: cause dimension.
        batch_size: number of images in ``input_img``.

    Returns:
        The ``cause`` dict produced by ``fista`` (causes under ``'uk'``).
    """
    # (batch, H, W) -> (batch, 2, 2, 12, 12). Unfolding the 3-D tensor
    # directly (no unsqueeze/squeeze round trip) keeps the batch axis even
    # when the batch holds a single image.
    patches = input_img.unfold(1, 12, 8).unfold(2, 12, 8)
    # -> (12, 12, 2, 2, batch) -> (144, 4 * batch).
    # Column index is (grid_row * 2 + grid_col) * batch_size + image.
    patches = patches.permute(3, 4, 1, 2, 0).reshape([144, -1])

    x_hat = torch.zeros([d_x1, 2, 2, batch_size])
    return fista(x_hat, C, B, A, patches, 0, d_x1, d_u1, batch_size)


def pre_process(im):
    """Standardise an image, then apply 5 x 5 local contrast normalisation.

    The image is first shifted to zero mean and unit standard deviation.
    Then the local 5 x 5 mean is subtracted, and the result is divided by
    the local 5 x 5 standard deviation, clamped below at 1 so flat regions
    are not amplified.

    Args:
        im: 2-D image array (one grey-scale patch in DPCN_score).

    Returns:
        The normalised image as a new array.
    """
    mn = im.mean()
    sd = im.std()
    if sd == 0:
        # A constant image would otherwise be divided by zero and turn all-NaN.
        # After mean removal it is all zeros, so any non-zero divisor works.
        sd = 1
    im = (im - mn) / sd

    k = np.ones([5, 5]) / 25.0  # 5 x 5 box (mean) filter

    lmn = cv2.filter2D(im, -1, k)
    lmnsq = cv2.filter2D(im ** 2, -1, k)
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding.
    lvar = np.maximum(lmnsq - lmn ** 2, 0)
    lstd = np.maximum(np.sqrt(lvar), 1)

    return (im - lmn) / lstd
  # ---------------------------------------------------------------------------


def fista(x_hat, C, B, A, imgout, lamda1, d_x1, d_u, batch_size):
    """Jointly infer DPCN states and causes with alternating FISTA updates.

    Each outer iteration takes one accelerated proximal-gradient (FISTA)
    step on the states ``x``, then one on the causes ``u``:

    * States minimise ``0.5 * ||y - C x||^2`` plus an L1 penalty whose
      per-element weight is ``(1 + exp(-B u)) / 2``.
    * Causes minimise ``sum(|x| * exp(-B u) / 2)`` plus ``gamma * ||u||_1``.
      Here ``|x|`` is summed over the four patch regions of each image, and
      ``gamma`` decays by a factor of 0.99 per iteration down to 1e-3.

    Both steps run a backtracking line search on a per-column Lipschitz
    estimate, with at most 6 trials. Iteration stops once the fraction of
    state entries whose zero / non-zero status changed drops to 0.2 or below
    (checked only while some entry is non-zero), or after 20 outer
    iterations.

    Args:
        x_hat: unused; kept for interface compatibility.
        C: (patch_dim, d_x1) matrix mapping states to patches.
        B: (d_x1, d_u) matrix mapping causes to state weights.
        A: unused; kept for interface compatibility.
        imgout: (patch_dim, 4 * batch_size) patches; column
            ``region * batch_size + image``.
        lamda1: unused (the L1 term it weighted is disabled).
        d_x1: state dimension.
        d_u: cause dimension.
        batch_size: number of images; each contributes 4 columns.

    Returns:
        dict with the final cause-layer state. ``'uk'`` holds the
        (d_u, batch_size) causes, and ``'x'`` holds the (d_x1, batch_size)
        region-pooled absolute states.
    """
    n_regions = 4            # 2 x 2 patch grid per image
    n_cols = int(batch_size * n_regions)
    eps = 2.2204e-15         # entries at least this large count as active
    max_outer_iter = 20
    active_set_tol = 0.2
    max_linesearch = 5       # the search body runs at most max_linesearch + 1 times
    # New tensors follow the input's dtype/device (CPU float32 in DPCN_score).
    opts = {'dtype': imgout.dtype, 'device': imgout.device}

    def per_region(x):
        # (d_x1, 4 * batch) -> (d_x1, 4, batch); column = region * batch + image
        return x.reshape([d_x1, n_regions, batch_size])

    def state_weights(u):
        # L1 weight on the states, (1 + exp(-B u)) / 2, shared by all 4 regions.
        return ((1 + torch.exp(-torch.mm(B, u))) / 2).repeat(1, n_regions)

    def active_set(x):
        return torch.abs(per_region(x)) >= eps

    def grow(L, accepted, eta):
        # Multiply L by eta for every column whose sufficient-decrease test failed.
        factor = torch.where(accepted > 0, torch.ones_like(accepted),
                             torch.full_like(accepted, eta))
        return factor.T * L

    # ---- state layer ----
    CTC = torch.mm(C.T, C)
    CTy = torch.mm(C.T, imgout)
    y = imgout
    eta_x = 2
    L_x = torch.ones([1, n_cols], **opts)
    t_x = 1
    xk = torch.zeros([d_x1, n_cols], **opts)
    xk_prev = torch.zeros_like(xk)
    z_x = torch.zeros_like(xk)

    # ---- cause layer ----
    eta_u = 1.5
    L_u = torch.ones([1, batch_size], **opts)
    t_u = 1
    t_u_next = 1
    uk = torch.zeros([d_u, batch_size], **opts)
    uk_prev = torch.zeros_like(uk)
    z_u = torch.zeros_like(uk)
    x_pooled = torch.zeros([d_x1, batch_size], **opts)
    gamma = 0.5
    gamma_bar = 1e-3

    weights = state_weights(uk)
    active = active_set(xk)
    n_outer = 0
    keep_going = True
    while keep_going:
        # -- FISTA step on the states, backtracking on L_x --
        grad_x = torch.mm(CTC, z_x) - CTy
        f_z = 0.5 * torch.sum((torch.mm(C, z_x) - y) ** 2, dim=0)
        accepted = torch.zeros([n_cols, 1], **opts)
        tries = 0
        while torch.sum(accepted) != n_cols:
            gk = z_x - grad_x / L_x
            xk = torch.sign(gk) * torch.clamp(torch.abs(gk) - weights / L_x, min=0)
            f_x = 0.5 * torch.sum((y - torch.mm(C, xk)) ** 2, dim=0)
            step = xk - z_x
            bound = (f_z + torch.sum(step * grad_x, dim=0)
                     + (L_x[0] / 2) * torch.sum(step ** 2, dim=0))
            accepted[f_x <= bound] = 1
            L_x = grow(L_x, accepted, eta_x)
            tries += 1
            if tries > max_linesearch:
                break
        t_next = (1 + np.sqrt(4 * t_x ** 2 + 1)) / 2
        z_x = xk + (t_x - 1) / t_next * (xk - xk_prev)
        xk_prev = xk
        t_x = t_next

        # -- FISTA step on the causes, backtracking on L_u --
        stc = per_region(xk)
        # Summed left to right: |r0| + |r1| + |r2| + |r3|.
        x_pooled = sum(torch.abs(stc[:, r, :]) for r in range(n_regions))
        exp_z = torch.exp(-torch.mm(B, z_u)) / 2
        grad_u = -torch.mm(B.T, exp_z * x_pooled)
        f_zu = torch.sum(x_pooled * torch.exp(-torch.mm(B, z_u)) / 2, dim=0)
        accepted = torch.zeros([batch_size, 1], **opts)
        tries = 0
        while torch.sum(accepted) != batch_size:
            gk = z_u - grad_u / L_u
            uk = torch.sign(gk) * torch.clamp(torch.abs(gk) - gamma / L_u, min=0)
            f_u = torch.sum(x_pooled * torch.exp(-torch.mm(B, uk)) / 2, dim=0)
            step = uk - z_u
            bound = (f_zu + torch.sum(step * grad_u, dim=0)
                     + (L_u[0] / 2) * torch.sum(step ** 2, dim=0))
            accepted[f_u <= bound] = 1
            L_u = grow(L_u, accepted, eta_u)
            tries += 1
            if tries > max_linesearch:
                break
        gamma = max(0.99 * gamma, gamma_bar)
        t_u_next = (1 + np.sqrt(4 * t_u ** 2 + 1)) / 2
        z_u = uk + (t_u - 1) / t_u_next * (uk - uk_prev)
        uk_prev = uk
        t_u = t_u_next

        # -- stopping rule: relative change of the states' active set --
        weights = state_weights(uk)
        prev_active, active = active, active_set(xk)
        n_changed = torch.sum(active != prev_active)
        n_active = torch.sum(active)
        if n_active >= 1:
            keep_going = bool(n_changed / n_active > active_set_tol)
        n_outer += 1
        if n_outer >= max_outer_iter:
            keep_going = False

    return {'L': L_u, 'eta': eta_u, 'tk': t_u, 'uk': uk, 'uk_1': uk_prev,
            'z_k': z_u, 'tk_1': t_u_next, 'x': x_pooled,
            'gamma_bar': gamma_bar, 'gamma': gamma}


def Gamma(test_image, k=None, mu=None):
    """Run the gamma-kernel saliency model and extract salient patches.

    Args:
        test_image: input frame, wrapped in ``foa_image.ImageObject``.
        k: gamma-kernel parameter for ``foa_convolution.gamma_kernel``.
            Defaults to ``[1, 20, 1, 30, 1, 35]`` (float).
        mu: gamma-kernel parameter for ``foa_convolution.gamma_kernel``.
            Defaults to ``[4, 4, 4, 4, 4, 4]`` (float).
            ``None`` builds a fresh array per call, so defaults are never
            shared between calls.

    Returns:
        (processed_patch, processed_location, map1): the image object's
        ``patch`` array cast to uint8, its ``location`` array cast to int,
        and the value returned by ``foa_saliency.salience_scan``.
    """
    if k is None:
        k = np.array([1, 20, 1, 30, 1, 35], dtype=float)
    if mu is None:
        mu = np.array([4, 4, 4, 4, 4, 4], dtype=float)
    test_image = foai.ImageObject(test_image)
    foveation_prior = foac.matlab_style_gauss2D(test_image.modified.shape, 300)
    kernel = foac.gamma_kernel(test_image, k=k, mu=mu)
    rankCount = 8  # Number of maps scans to run

    foac.convolution(test_image, kernel, foveation_prior)
    map1 = foas.salience_scan(test_image, rankCount=rankCount)
    processed_patch = test_image.patch.astype(np.uint8)
    # Builtin int: the np.int alias no longer exists in NumPy >= 1.24.
    processed_location = test_image.location.astype(int)
    return processed_patch, processed_location, map1


def DPCN_score(processed_patch, d_x1, d_u1, d_input1, batch_size, A, B, C, ind_ac):
    """Score every 20 x 20 window of each 32 x 32 patch with the DPCN model.

    Args:
        processed_patch: (n_patches, 32, 32, 3) array, presumably 0..255.
            NOTE(review): channel 0 gets the 0.2989 (red) weight; confirm the
            patches are RGB rather than BGR.
        d_x1: state dimension forwarded to ``dpcn_detect``.
        d_u1: cause dimension forwarded to ``dpcn_detect``.
        d_input1: unused.
        batch_size: windows per patch; must equal (32 - 20) ** 2 = 144.
        A, B, C: DPCN model matrices.
        ind_ac: cause indices whose responses are zeroed before scoring.

    Returns:
        (n_patches, batch_size) array. Each entry is the sum, over causes
        not in ``ind_ac``, of the responses that are at least 1.2.
    """
    length = len(processed_patch)
    rgb = processed_patch / 255
    state = 0.2989 * rgb[:, :, :, 0] + 0.5870 * rgb[:, :, :, 1] + 0.1140 * rgb[:, :, :, 2]
    for idx in range(length):
        state[idx, :, :] = pre_process(state[idx, :, :])

    img_sz = 32
    win = 20
    n_off = img_sz - win  # 12 offsets per axis -> 144 windows per patch
    # windows[patch, row_off * n_off + col_off] is the window at that offset.
    windows = np.zeros([length, n_off * n_off, win, win])
    for ss in range(length):
        for i in range(n_off):
            for j in range(n_off):
                windows[ss, i * n_off + j] = state[ss, i:i + win, j:j + win]
    windows = windows.reshape([-1, win, win])

    U = dpcn_detect(torch.Tensor(windows), A, B, C, d_x1, d_u1, int(length * batch_size))

    output = np.reshape(U['uk'].cpu().numpy(), [d_u1, length, batch_size])
    output[ind_ac, :, :] = 0
    output[output < 1.2] = 0
    return np.sum(output, 0).reshape([length, batch_size])


def thresholding(count_patch, processed_patch, mario):
    """Select up to five patch indices from per-patch scores.

    Args:
        count_patch: 1-D array of per-patch scores.
        processed_patch: unused.
        mario: if true and some score is non-zero, keep only the patches
            that reach the maximum score. Otherwise keep every patch with a
            non-zero score.

    Returns:
        (ptt, ind, n):
        ``ptt`` is a length-5 int64 index array holding the first five
        selected indices, padded by repeating the last one (all zeros when
        nothing is selected).
        ``ind`` tells whether anything was selected.
        ``n`` is the number of selected patches. NOTE(review): ``n`` can
        exceed 5 even though ``ptt`` never holds more than 5 indices.
    """
    if mario and np.any(count_patch != 0):
        pt = np.where(count_patch == np.max(count_patch))[0]
    else:
        pt = np.where(count_patch != 0)[0]

    ptt = np.zeros(5, dtype=np.int64)
    if len(pt) > 0:
        head = pt[:5]
        ptt[:len(head)] = head
        ptt[len(head):] = pt[-1]  # pad with the last selected index
    return ptt, len(pt) > 0, len(pt)


def scores_locations(img, A, B, C, ind_ac, processed_patch, processed_location, mario=False):
    """Select the salient patches the DPCN model responds to.

    Args:
        img: unused.
        A, B, C: DPCN model matrices (see ``DPCN_score``).
        ind_ac: cause indices zeroed before scoring (see ``DPCN_score``).
        processed_patch: (n, 32, 32, 3) candidate patches.
        processed_location: per-patch locations, indexed like ``processed_patch``.
        mario: if true, also score horizontally flipped patches, take the
            element-wise maximum, and select only the top-scoring patches.

    Returns:
        (patches, locations, num_object, ind, a):
        ``patches`` and ``locations`` are the selected rows.
        ``num_object`` and ``ind`` come from ``thresholding``.
        ``a`` is the (n, 144) per-window score array.
    """
    a = DPCN_score(processed_patch, 150, 60, 144, int(12 * 12), A, B, C, ind_ac)
    scores_patch = np.sum(a, 1)

    if mario:
        flipped = np.array([cv2.flip(patch, 1) for patch in processed_patch])
        a_f = DPCN_score(flipped, 150, 60, 144, int(12 * 12), A, B, C, ind_ac)
        a = np.maximum(a, a_f)
        scores_patch = np.maximum(scores_patch, np.sum(a_f, 1))

    ptt, ind, num_object = thresholding(scores_patch, processed_patch, mario)
    return ((processed_patch[ptt, :])[0:num_object, :],
            (processed_location[ptt, :])[0:num_object, :],
            num_object, ind, a)


def labeling(batch_concat, batch_size_c, net, device, num_object):
    """Embed a batch of patches with ``net`` and return the leading outputs.

    Args:
        batch_concat: raw batch handed to ``IICDataset``.
        batch_size_c: number of samples read from the dataset. Each sample's
            ``'image'`` must fit a (3, 28, 28) tensor.
        net: model called as ``latent, _ = net(x)``.
        device: device the input batch is moved to before calling ``net``.
        num_object: number of leading rows to return.

    Returns:
        numpy array with the first ``num_object`` rows of ``net``'s first output.
    """
    # NOTE(review): IICDataset is not imported anywhere in the visible file
    # (``from data import *`` is commented out) -- confirm where it comes from.
    dataset = IICDataset(ip=batch_concat, transform=1)

    feed_x = torch.ones((batch_size_c, 3, 28, 28), dtype=torch.float)
    for s in range(batch_size_c):
        feed_x[s, :, :, :] = dataset[s]['image']
    # Tensor.to() returns a new tensor; it does not move feed_x in place.
    feed_x = feed_x.to(device)
    with torch.no_grad():
        latent_x, _ = net(feed_x)
    return latent_x.cpu().detach().numpy()[0:num_object]


def TK(location_p, location):
    """Snap each query location to its nearest candidate location.

    Args:
        location_p: (n, 2) array of query locations.
        location: sequence of candidate (a, b) pairs.

    Returns:
        float array shaped like ``location_p``. Row ``i`` is the candidate
        with the smallest squared Euclidean distance to ``location_p[i]``;
        the first candidate wins ties.

    Raises:
        ValueError: if ``location`` is empty and ``location_p`` is not.
    """
    location_r = np.zeros(location_p.shape)
    for i in range(np.size(location_p, 0)):
        qa, qb = location_p[i][0], location_p[i][1]
        dists = np.array([(cand[0] - qa) ** 2 + (cand[1] - qb) ** 2 for cand in location])
        location_r[i, :] = location[int(np.argmin(dists))]
    return location_r


def Merge(processed_location3):
    """Collapse locations whose second coordinate is within 5 of a later one.

    For each row ``i``, every later row ``j`` with
    ``|loc[i, 1] - loc[j, 1]| < 5`` takes ``min(loc[j, 0], loc[i, 0])`` as
    its first coordinate, and row ``i`` is dropped. The input array is
    modified in place.

    Returns:
        The surviving rows of ``processed_location3``.
    """
    n = len(processed_location3)
    drop = np.zeros(n, dtype=bool)
    for i in range(n):
        # Start at i + 1: a row must never be compared with itself, because
        # that test is always true and would discard the row.
        for j in range(i + 1, n):
            if np.abs(processed_location3[i, 1] - processed_location3[j, 1]) < 5:
                processed_location3[j, 0] = min(processed_location3[j, 0],
                                                processed_location3[i, 0])
                drop[i] = True
    return processed_location3[~drop]


def _match_template(img_gray, template, threshold):
    """Return np.where-style indices where TM_CCOEFF_NORMED >= threshold."""
    res = cv2.matchTemplate(img_gray, template, cv2.TM_CCOEFF_NORMED)
    return np.where(res >= threshold)


def truth_matching(img_gray, template, threshold=0.7):
    """Find positions where ``template`` matches ``img_gray``.

    Returns:
        Tuple of index arrays (as from ``np.where``) into the
        ``cv2.matchTemplate`` result, where the normalised correlation
        coefficient is at least ``threshold``.
    """
    return _match_template(img_gray, template, threshold)


def truth_matching_mario(img_gray, template, threshold=0.75):
    """Like ``truth_matching`` but stricter by default.

    The stricter 0.75 default is used to tell Mario apart from blocks.
    """
    return _match_template(img_gray, template, threshold)


if __name__ == "__main__":
    # NOTE(review): the body of this entry point is cut off in this copy of
    # the file; restore it from the original source.
    pass
...corrCUDA.py
Source: corrCUDA.py
1import numpy as np2import accelerate.cuda.blas as blas3import accelerate.cuda.fft as ft4from numba import cuda5def corr_td_single (x1,x2):6    c_12 = blas.dot(x1,x2)7    return c_128def best_grid_size(size, tpb):9    bpg = np.ceil(np.array(size, dtype=np.float) / tpb).astype(np.int).tolist()10    return tuple(bpg)11@cuda.jit('void(float32[:], float32[:])')12def mult_inplace(img, resp):13    i = cuda.grid(1)14    img[i] *= resp[i]15def corr_FD(x1,x2):16    threadperblock = 32, 817    blockpergrid = best_grid_size(tuple(reversed(x1.shape)), threadperblock)18    print('kernel config: %s x %s' % (blockpergrid, threadperblock))19    # Trigger initialization the cuFFT system.20    # This takes significant time for small dataset.21    # We should not be including the time wasted here22    #ft.FFTPlan(shape=x1.shape, itype=np.float32, otype=np.complex64)23    X1 = x1.astype(np.float32)24    X2 = x2.astype(np.float32)25    stream1 = cuda.stream()26    stream2 = cuda.stream()27    fftplan1 = ft.FFTPlan(shape=x1.shape, itype=np.float32,28                       otype=np.complex64, stream=stream1)29    fftplan2 = ft.FFTPlan(shape=x2.shape, itype=np.float32,30                       otype=np.complex64, stream=stream2)31    # pagelock memory32    with cuda.pinned(X1, X2):33        # We can overlap the transfer of response_complex with the forward FFT34        # on image_complex.35        d_X1 = cuda.to_device(X1, stream=stream1)36        d_X2 = cuda.to_device(X2, stream=stream2)37        fftplan1.forward(d_X1, out=d_X1)38        fftplan2.forward(d_X2, out=d_X2)39        print ('d_X1 is ',np.shape(d_X1),type(d_X1),np.max(d_X1))40        print ('d_X2 is ',np.shape(d_X2),type(d_X2),np.max(d_X2))41        stream2.synchronize()42        mult_inplace[blockpergrid, threadperblock, stream1](d_X1, d_X2)43        fftplan1.inverse(d_X1, out=d_X1)44        # implicitly synchronizes the streams45        c = d_X1.copy_to_host().real / np.prod(x1.shape)...Learn to execute automation testing 
from scratch with LambdaTest Learning Hub. The guides cover everything from setting up the prerequisites and running your first automation test to following best practices and exploring advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing for FREE!
