How to use the update_samples method in autotest

Best Python code snippets using autotest_python

solver.py

Source: solver.py on GitHub

copy

Full Screen

1import torch2import numpy as np3import os4import json5import matplotlib.pyplot as plt6from utils.dataloader import torcs_dataset,synthetic_example7from utils.utils import print_n_txt8from core.net import Decoder,Policy,Encoder9from core.transformer.encoder import Transformer_Encoder10from core.loss import INFONCE_loss,BT_loss,mdn_loss,recon_loss11import torch.nn.functional as F12class SOLVER():13 def __init__(self,args):14 self.device = 'cuda'15 self.args= args16 self.state_only=True17 self.path = './res/{}/'.format(args.id)18 try:19 os.mkdir(self.path)20 except:21 pass22 self.lr = args.lr23 self.load_dataset()24 self.load_model()25 def load_dataset(self):26 if self.args.data == 'torcs':27 self.state_samples = 5028 dataset = torcs_dataset(num_traj=self.args.num_traj,state_samples=self.state_samples)29 self.traj_dim = self.args.num_traj*3130 self.a_dim = 231 self.s_dim = 2932 self.dim = 3133 self.z_dim = 534 self.update_samples = 535 elif self.args.data == 'syn':36 self.state_samples = 4037 dataset = synthetic_example(num_traj=self.args.num_traj,fixed_len=self.args.fixed_len,38 state_samples=self.state_samples)39 self.a_dim = 240 self.s_dim = 241 self.z_dim = 442 if self.state_only:43 self.dim = 244 else:45 self.dim=446 self.traj_dim = self.args.num_traj*self.dim47 self.update_samples = 548 self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size,shuffle=True)49 def load_model(self):50 if self.args.encoder_base == 'mlp':51 self.encoder = Encoder(x_dim=int(self.args.num_traj*self.dim),z_dim=self.z_dim).to(self.device)52 elif self.args.encoder_base == 'transformer':53 self.encoder = Transformer_Encoder(x_dim=self.dim,z_dim=self.z_dim).to(self.device)54 self.policy = Policy(s_dim=self.s_dim,a_dim=self.a_dim,z_dim=self.z_dim).to(self.device)55 # self.model = MODEL(self.args,x_dim=int(self.args.num_traj*self.dim),s_dim=self.s_dim,a_dim=self.a_dim,56 # z_dim=self.z_dim,k=5,sig_max=None).to(self.device)57 if self.args.policy == 'mlp':58 
self.weight = [1,1]59 elif self.args.policy == 'mdn':60 self.weight = [1,0.01]61 self.policy.init_param()62 self.encoder.init_param()63 if self.args.recon_loss:64 self.decoder = Decoder(x_dim=int(self.args.num_traj*self.dim),z_dim=5).to(self.device)65 def train(self):66 with open(self.path+'config.json','w') as jf:67 json.dump(self.args.__dict__, jf, indent=2)68 f = open(self.path+'logs.txt', 'w') 69 self.loss = []70 self.ploss = []71 self.eloss = []72 if self.args.loss=='simclr':73 ecriterion = INFONCE_loss74 elif self.args.loss == 'BT':75 ecriterion = BT_loss76 if self.args.policy == 'mlp':77 pcriterion = F.mse_loss78 elif self.args.policy == 'mdn':79 pcriterion = mdn_loss80 flag=081 #optimizer = torch.optim.Adam(self.model.parameters(), self.lr,weight_decay=1e-4)82 poptimizer = torch.optim.Adam(self.policy.parameters(), self.lr,weight_decay=1e-4)83 eoptimizer = torch.optim.Adam(self.encoder.parameters(), self.lr,weight_decay=1e-4)84 for e in range(200):85 total_loss = 086 ploss = 087 eloss = 088 for traj_1,traj_2,state_1,action_1,state_2,action_2,_,traj_len in self.dataloader:89 if self.args.encoder_base == 'mlp':90 traj_1 = traj_1.view(-1,self.traj_dim)91 traj_2 = traj_2.view(-1,self.traj_dim)92 sampled_z_1 = self.encoder(traj_1.to(self.device))93 sampled_z_2 = self.encoder(traj_2.to(self.device))94 else:95 traj_1 = traj_1.view(-1,self.args.num_traj,self.dim)96 traj_2 = traj_2.view(-1,self.args.num_traj,self.dim)97 # traj_1 = torch.cat((torch.zeros(traj_1.size(0),1,self.dim),traj_1),dim=1)98 # traj_2 = torch.cat((torch.zeros(traj_1.size(0),1,self.dim),traj_2),dim=1)99 sampled_z_1 = self.encoder(traj_1.to(self.device))#[:,0,:]100 sampled_z_2 = self.encoder(traj_2.to(self.device))#[:,0,:]101 index = traj_len.unsqueeze(-1).unsqueeze(-1).repeat(1,self.args.num_traj,self.z_dim).to(self.device)102 sampled_z_1 = torch.gather(sampled_z_1,dim=1,index=index)[:,0,:]103 sampled_z_2 = torch.gather(sampled_z_2,dim=1,index=index)[:,0,:]104 encoder_loss = 
ecriterion(sampled_z_1,sampled_z_2) # encoder loss105 if torch.sum(torch.isnan(sampled_z_1)).item()>0:106 flag=1107 break108 eoptimizer.zero_grad()109 encoder_loss.backward()110 eoptimizer.step()111 112 piter = self.state_samples//self.update_samples113 for p in range(piter):114 sampled_z_1_s = sampled_z_1.detach().unsqueeze(1).repeat(1,self.update_samples,1).to(self.device)115 sampled_z_2_s = sampled_z_2.detach().unsqueeze(1).repeat(1,self.update_samples,1).to(self.device)116 state_1_s = state_1[:,p*self.update_samples:(p+1)*self.update_samples,:].to(self.device)117 state_2_s = state_2[:,p*self.update_samples:(p+1)*self.update_samples,:].to(self.device)118 action_1_s = action_1[:,p*self.update_samples:(p+1)*self.update_samples,:].to(self.device)119 #action_1_s = action_1_s.view(-1,self.s_dim)120 action_2_s = action_2[:,p*self.update_samples:(p+1)*self.update_samples,:].to(self.device)121 input_1 = torch.cat((sampled_z_1_s,state_1_s),dim=-1).view(-1,self.s_dim+self.z_dim)122 pred_1 = self.policy(input_1)123 # print(pred_1[0,:],action_1_s[0,0,:])124 pred_2 = self.policy(torch.cat((sampled_z_2_s,state_2_s),dim=-1).view(-1,self.s_dim+self.z_dim))125 policy_loss = pcriterion(pred_1,action_1_s.view(-1,self.a_dim).to(self.device)) \126 + pcriterion(pred_2,action_2_s.view(-1,self.a_dim).to(self.device)) # Policy Loss127 poptimizer.zero_grad()128 policy_loss.backward()129 poptimizer.step()130 loss = self.weight[0]*encoder_loss + self.weight[1]*policy_loss/self.update_samples131 if self.args.recon_loss:132 recon_1 = self.decoder(sampled_z_1)133 recon_2 = self.decoder(sampled_z_2)134 loss += 0.1*(recon_loss(traj_1,recon_1) + recon_loss(traj_2,recon_2))135 total_loss += loss136 eloss += encoder_loss137 ploss += policy_loss/self.update_samples138 if flag==1:139 break140 total_loss /= len(self.dataloader)141 eloss /= len(self.dataloader)142 ploss /= len(self.dataloader)143 strtemp = ("Epoch: %d loss: %.3f encoder loss: %.3f policy loss: %.3f")%(e,total_loss,eloss,ploss)144 
print_n_txt(_chars=strtemp, _f=f)145 self.loss.append(total_loss)146 self.eloss.append(eloss)147 self.ploss.append(ploss)148 torch.save(self.policy.state_dict(),self.path+'policy_final.pt')149 torch.save(self.encoder.state_dict(),self.path+'encoder_final.pt')150 def plot_loss(self):151 plt.figure(figsize=(10,4))152 plt.subplot(1,3,1)153 plt.title("Total Loss")154 plt.xlabel("Epochs")155 plt.plot(self.loss)156 plt.subplot(1,3,2)157 plt.title("Encoder Loss")158 plt.xlabel("Epochs")159 plt.plot(self.eloss)160 plt.subplot(1,3,3)161 plt.title("Policy Loss")162 plt.xlabel("Epochs")163 plt.plot(self.ploss)...

Full Screen

Full Screen

test_databus.py

Source: test_databus.py on GitHub

copy

Full Screen

...10 meta = ChannelMeta(name='RPM')11 sample.channel_metas = [meta]12 sample.samples = [SampleValue(1234, meta)]13 14 dataBus.update_samples(sample)15 dataBus.notify_listeners(None)16 17 value = dataBus.getData('RPM')18 self.assertEqual(value, 1234)19 listenerVal0 = None20 def test_listener(self):21 22 def listener(value):23 self.listenerVal0 = value24 sample = Sample()25 26 meta = ChannelMeta(name='RPM')27 sample.channel_metas = [meta]28 sample.samples = [SampleValue(1111, meta)]29 dataBus = DataBus()30 dataBus.addChannelListener('RPM', listener)31 dataBus.update_samples(sample)32 dataBus.notify_listeners(None)33 self.assertEqual(self.listenerVal0, 1111)34 35 listenerVal1 = None36 listenerVal2 = None37 def test_multiple_listeners(self):38 def listener1(value):39 self.listenerVal1 = value40 def listener2(value):41 self.listenerVal2 = value42 43 dataBus = DataBus()44 dataBus.addChannelListener('RPM', listener1)45 dataBus.addChannelListener('RPM', listener2)46 47 sample = Sample()48 meta = ChannelMeta(name='RPM')49 sample.channel_metas = [meta]50 sample.samples = [SampleValue(1111, meta)]51 52 dataBus.update_samples(sample)53 dataBus.notify_listeners(None)54 self.assertEqual(self.listenerVal1, 1111)55 self.assertEqual(self.listenerVal2, 1111)56 57 listenerVal3 = None58 listenerVal4 = None59 def test_mixed_listeners(self):60 def listener3(value):61 self.listenerVal3 = value62 def listener4(value):63 self.listenerVal4 = value64 65 66 sample = Sample()67 metaRpm = ChannelMeta(name='RPM')68 metaEngineTemp = ChannelMeta(name='EngineTemp')69 sample.channel_metas = [metaRpm, metaEngineTemp]70 sample.samples = [SampleValue(1111, metaRpm)]71 72 dataBus = DataBus()73 dataBus.addChannelListener('RPM', listener3)74 dataBus.addChannelListener('EngineTemp', listener4)75 dataBus.update_samples(sample)76 dataBus.notify_listeners(None)77 #ensure we don't set the wrong listener78 self.assertEqual(self.listenerVal3, 1111)79 self.assertEqual(self.listenerVal4, None)80 81 sample.samples 
= [SampleValue(1111, metaRpm), SampleValue(199, metaEngineTemp)]82 83 dataBus.update_samples(sample)84 dataBus.notify_listeners(None)85 #ensure we don't affect unrelated channels86 self.assertEqual(self.listenerVal3, 1111)87 self.assertEqual(self.listenerVal4, 199)88 89 def test_no_listener(self):90 sample = Sample()91 meta = ChannelMeta(name='EngineTemp')92 sample.channel_metas = [meta]93 sample.samples = [SampleValue(200, meta)]94 95 dataBus = DataBus()96 dataBus.update_samples(sample)97 dataBus.notify_listeners(None)98 #no listener for this channel, should not cause an error99 100 channelMeta = None101 def test_meta_listener(self):102 dataBus = DataBus()103 104 def metaListener(channel):105 self.channelMeta = channel106 metas = ChannelMetaCollection()107 metas.channel_metas = [ChannelMeta(name='RPM')]108 dataBus.addMetaListener(metaListener)109 dataBus.update_channel_meta(metas)110 dataBus.notify_listeners(None)...

Full Screen

Full Screen

06zmq-pub-sub.py

Source: 06zmq-pub-sub.py on GitHub

copy

Full Screen

from handofcats import as_subcommand
import os
import sys
import zmq
import subprocess
from tinyrpc.protocols.jsonrpc import JSONRPCProtocol


@as_subcommand
def run(*, endpoint: str = "tcp://127.0.0.1:5001"):
    """Spawn a server and a client subprocess on *endpoint* and wait for the client.

    The server publishes forever, so it is terminated once the client exits.
    The ``finally`` clause also stops it if starting or waiting on the client
    raises (e.g. KeyboardInterrupt).
    """
    server_proc = subprocess.Popen(
        [sys.executable, __file__, "server", "--endpoint", endpoint]
    )
    try:
        client_proc = subprocess.Popen(
            [sys.executable, __file__, "client", "--endpoint", endpoint]
        )
        client_proc.wait()
    finally:
        server_proc.terminate()
        # Reap the child so it does not linger as a zombie process.
        server_proc.wait()
    print("ok")


@as_subcommand
def server(*, endpoint: str):
    """Publish random weather updates on *endpoint* forever.

    Each message is the UTF-8 text
    ``"<zipcode> <temperature> <relhumidity> <message_id>"``. The zipcode is
    zero-padded to five digits, so subscribers can filter by prefix.
    """
    import time
    from random import randrange

    context = zmq.Context()
    publisher = context.socket(zmq.PUB)
    publisher.bind(endpoint)
    message_id = 0
    while True:
        zipcode = randrange(10001, 10010)
        temperature = randrange(0, 215) - 80
        relhumidity = randrange(0, 50) + 10
        update = "%05d %d %d %d" % (zipcode, temperature, relhumidity, message_id)
        message_id += 1
        print(os.getpid(), update)
        time.sleep(0.5)
        publisher.send(update.encode("utf-8"))


@as_subcommand
def client(*, endpoint: str, zipfilter="10001 "):
    """Receive 10 updates whose text starts with *zipfilter* and print the mean temperature.

    ZMQ ``SUBSCRIBE`` filters by byte prefix. The default filter's trailing
    space keeps "10001" from matching any longer token that merely starts
    with those digits.
    """
    update_samples = 10
    total_temp = 0.0
    context = zmq.Context()
    subscriber = context.socket(zmq.SUB)
    try:
        subscriber.connect(endpoint)
        # Receive only messages whose text starts with zipfilter.
        subscriber.setsockopt(zmq.SUBSCRIBE, zipfilter.encode("utf-8"))
        for _ in range(update_samples):
            # Decode so the printed fields are plain text rather than b'...' reprs.
            message = subscriber.recv().decode("utf-8")
            zipcode, temperature, relhumidity, message_id = message.split()
            print(
                "zip:%s, temp:%s, relh:%s, id:%s"
                % (zipcode, temperature, relhumidity, message_id)
            )
            # Accumulate; the previous version overwrote the running total each pass.
            total_temp += float(temperature)
    finally:
        subscriber.close()
        context.term()
    print(
        "average temperature for zipcode '%s' was '%f'"
        % (zipfilter, total_temp / update_samples)
    )
# ... (listing truncated in the source page)

Full Screen

Full Screen

Automation Testing Tutorials

Learn automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient in different test automation frameworks, e.g. Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also refer to video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.

Run autotest automation tests on the LambdaTest cloud grid

Perform automation testing on 3000+ real desktop and mobile devices online.

Try LambdaTest now!

Get 100 minutes of automation testing free!

Next-Gen App & Browser Testing Cloud

Was this article helpful?

Helpful

Not Helpful