Best Python code snippet using fMBT_python
test_instance_actions.py
Source:test_instance_actions.py  
# Copyright 2016 IBM Corp.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import mock
from oslo_policy import policy as oslo_policy

from nova import exception
from nova import policy
from nova import test
from nova.tests import fixtures as nova_fixtures
from nova.tests.functional.api import client
from nova.tests.functional import fixtures as func_fixtures
from nova.tests.functional import integrated_helpers


class InstanceActionsTestV2(integrated_helpers._IntegratedTestBase):
    """Tests Instance Actions API"""

    def test_get_instance_actions(self):
        """A freshly created server lists 'create' as its first action."""
        server = self._create_server()
        actions = self.api.get_instance_actions(server['id'])
        self.assertEqual('create', actions[0]['action'])

    def test_get_instance_actions_deleted(self):
        """Listing the actions of a deleted server raises NotFound."""
        server = self._create_server()
        self._delete_server(server)
        self.assertRaises(client.OpenStackApiNotFoundException,
                          self.api.get_instance_actions,
                          server['id'])


class InstanceActionsTestV21(InstanceActionsTestV2):
    """Runs the same tests with ``api_major_version`` set to 'v2.1'."""

    api_major_version = 'v2.1'


class InstanceActionsTestV221(InstanceActionsTestV21):
    """Runs the tests with the API client pinned to microversion 2.21."""

    microversion = '2.21'

    def setUp(self):
        super(InstanceActionsTestV221, self).setUp()
        self.api.microversion = self.microversion

    def test_get_instance_actions_deleted(self):
        """With 2.21 the actions of a deleted server are still returned."""
        server = self._create_server()
        self._delete_server(server)
        actions = self.api.get_instance_actions(server['id'])
        self.assertEqual('delete', actions[0]['action'])
        self.assertEqual('create', actions[1]['action'])


class HypervisorError(Exception):
    """This is just used to make sure the exception type is in the events."""
    pass


class InstanceActionEventFaultsTestCase(
    test.TestCase, integrated_helpers.InstanceHelperMixin):
    """Tests for the instance action event details reporting from the API"""

    def setUp(self):
        super(InstanceActionEventFaultsTestCase, self).setUp()
        # Setup the standard fixtures.
        self.useFixture(nova_fixtures.GlanceFixture(self))
        self.useFixture(nova_fixtures.NeutronFixture(self))
        self.useFixture(func_fixtures.PlacementFixture())
        self.useFixture(nova_fixtures.RealPolicyFixture())
        # Start the compute services.
        self.start_service('conductor')
        self.start_service('scheduler')
        self.compute = self.start_service('compute')
        api_fixture = self.useFixture(nova_fixtures.OSAPIFixture(
            api_version='v2.1'))
        self.api = api_fixture.api
        self.admin_api = api_fixture.admin_api

    def _set_policy_rules(self, overwrite=True):
        """Install the instance-action policy rules used by these tests.

        :param overwrite: passed straight through to ``policy.set_rules``.
        """
        rules = {'os_compute_api:os-instance-actions:show': '',
                 'os_compute_api:os-instance-actions:events:details':
                     'project_id:%(project_id)s'}
        policy.set_rules(oslo_policy.Rules.from_dict(rules),
                         overwrite=overwrite)

    def _create_stopped_server(self):
        """Create a server as the non-admin user and stop it.

        :returns: the server dict once it has reached SHUTOFF.
        """
        # Create the server with the non-admin user.
        server = self._build_server(
            networks=[{'port': nova_fixtures.NeutronFixture.port_1['id']}])
        server = self.api.post_server({'server': server})
        server = self._wait_for_state_change(server, 'ACTIVE')
        # Stop the server before rebooting it so that after the driver.reboot
        # method raises an exception, the fake driver does not report the
        # instance power state as running - that will make the compute manager
        # set the instance vm_state to error.
        self.api.post_server_action(server['id'], {'os-stop': None})
        return self._wait_for_state_change(server, 'SHUTOFF')

    def _hard_reboot_expecting_error(self, server, side_effect):
        """Hard reboot ``server`` while the driver's reboot raises.

        :param server: server dict as returned by the API.
        :param side_effect: exception the stubbed driver.reboot raises.
        :returns: the server dict once it is in ERROR with no task_state.
        """
        with mock.patch.object(
                self.compute.manager.driver, 'reboot',
                side_effect=side_effect) as mock_reboot:
            reboot_request = {'reboot': {'type': 'HARD'}}
            self.api.post_server_action(server['id'], reboot_request)
            # In this case we wait for the status to change to ERROR using
            # the non-admin user so we can assert the fault details. We also
            # wait for the task_state to be None since the wrap_instance_fault
            # decorator runs before the reverts_task_state decorator so we will
            # be sure the fault is set on the server.
            server = self._wait_for_server_parameter(
                server, {'status': 'ERROR', 'OS-EXT-STS:task_state': None},
                api=self.api)
            mock_reboot.assert_called_once()
        return server

    def _get_reboot_request_id(self, server_id):
        """Return the request_id of the reboot action of ``server_id``.

        Fails the test with a clear message when no reboot action exists,
        instead of surfacing later as an unbound local variable. When
        several reboot actions are listed the last one wins, as before.
        """
        response = self.api.api_get('/servers/%s/os-instance-actions' %
                                    server_id)
        request_ids = [action['request_id']
                       for action in response.body['instanceActions']
                       if action['action'] == 'reboot']
        self.assertTrue(
            request_ids,
            'no reboot action recorded for server %s' % server_id)
        return request_ids[-1]

    def _get_reboot_action(self, api, server_id, request_id):
        """GET one instance action with ``api`` at microversion 2.84.

        Events can be shown since microversion 2.51, but 'details' is only
        shown to non-admin users since 2.84, hence the pinned version.
        """
        api.microversion = '2.84'
        response = api.api_get(
            '/servers/%s/os-instance-actions/%s' % (server_id, request_id))
        return response.body['instanceAction']

    def test_instance_action_event_details_non_nova_exception(self):
        """Creates a server using the non-admin user, then reboot it which
        will generate a non-NovaException fault and put the instance into
        ERROR status. Then checks that fault details are visible.
        """
        server = self._create_stopped_server()
        # Stub out the compute driver reboot method to raise a non-nova
        # exception to simulate some error from the underlying hypervisor
        # which in this case we are going to say has sensitive content.
        error_msg = 'sensitive info'
        server = self._hard_reboot_expecting_error(
            server, HypervisorError(error_msg))
        self._set_policy_rules(overwrite=False)
        server_id = server['id']
        reboot_request_id = self._get_reboot_request_id(server_id)
        # Non-admin view of the reboot action.
        reboot_action = self._get_reboot_action(
            self.api, server_id, reboot_request_id)
        # Since reboot action failed, the 'message' property in reboot action
        # should be 'Error', otherwise it's None.
        self.assertEqual('Error', reboot_action['message'])
        reboot_action_events = reboot_action['events']
        # The instance action events from the non-admin user API response
        # should not have 'traceback' in it.
        self.assertNotIn('traceback', reboot_action_events[0])
        # And the sensitive details from the non-nova exception should not be
        # in the details.
        self.assertIn('details', reboot_action_events[0])
        self.assertNotIn(error_msg, reboot_action_events[0]['details'])
        # The exception type class name should be in the details.
        self.assertIn('HypervisorError', reboot_action_events[0]['details'])
        # Get the server fault details for the admin user.
        reboot_action = self._get_reboot_action(
            self.admin_api, server_id, reboot_request_id)
        self.assertEqual('Error', reboot_action['message'])
        reboot_action_events = reboot_action['events']
        # The admin can see the fault details which includes the traceback,
        # and make sure the traceback is there by looking for part of it.
        self.assertIn('traceback', reboot_action_events[0])
        self.assertIn('in reboot_instance',
                      reboot_action_events[0]['traceback'])
        # The exception type class name should be in the details for the admin
        # user as well since the fault handling code cannot distinguish who
        # is going to see the message so it only sets class name.
        self.assertIn('HypervisorError', reboot_action_events[0]['details'])

    def test_instance_action_event_details_with_nova_exception(self):
        """Creates a server using the non-admin user, then reboot it which
        will generate a nova exception fault and put the instance into
        ERROR status. Then checks that fault details are visible.
        """
        server = self._create_stopped_server()
        # Stub out the compute driver reboot method to raise a nova
        # exception 'InstanceRebootFailure' to simulate some error.
        exc_reason = 'reboot failure'
        server = self._hard_reboot_expecting_error(
            server, exception.InstanceRebootFailure(reason=exc_reason))
        self._set_policy_rules(overwrite=False)
        server_id = server['id']
        reboot_request_id = self._get_reboot_request_id(server_id)
        # Non-admin view of the reboot action.
        reboot_action = self._get_reboot_action(
            self.api, server_id, reboot_request_id)
        # Since reboot action failed, the 'message' property in reboot action
        # should be 'Error', otherwise it's None.
        self.assertEqual('Error', reboot_action['message'])
        reboot_action_events = reboot_action['events']
        # The instance action events from the non-admin user API response
        # should not have 'traceback' in it.
        self.assertNotIn('traceback', reboot_action_events[0])
        # The nova exception format message should be in the details.
        self.assertIn('details', reboot_action_events[0])
        self.assertIn(exc_reason, reboot_action_events[0]['details'])
        # Get the server fault details for the admin user.
        reboot_action = self._get_reboot_action(
            self.admin_api, server_id, reboot_request_id)
        self.assertEqual('Error', reboot_action['message'])
        reboot_action_events = reboot_action['events']
        # The admin can see the fault details which includes the traceback,
        # and make sure the traceback is there by looking for part of it.
        self.assertIn('traceback', reboot_action_events[0])
        self.assertIn('in reboot_instance',
                      reboot_action_events[0]['traceback'])
        # The nova exception format message should be in the details.
        # NOTE(review): the snippet is cut off at this point; the assertion
        # the comment above introduces is not visible and was not invented.
Source:Service.py  
import re
import time
import datetime
from Method import CPU, Memory, Alarm, Other
from Dao import monitor_SQL


class Service(CPU, Memory, Alarm, Other):
    """Reads device metrics and stores each reading through ``monitor_SQL``.

    The probe methods called here (``uptime``, ``cpu_cal``, ``mem_cal``,
    ``core_file`` and so on) are not defined in this class; presumably they
    come from the ``CPU``/``Memory``/``Alarm``/``Other`` bases (verify in
    ``Method``).

    References:
        https://www.jianshu.com/p/a64ad351ebb2
        https://blog.csdn.net/specter11235/article/details/89198032
    """

    # Timestamp layout used when comparing mem-check log times.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, ip, port, username, password, dbhost, dbport, dbname, dbuser, dbpwd):
        """Store the device credentials and open the database connection.

        ``ip``/``port``/``username``/``password`` are kept on the instance;
        the ``db*`` values are both kept and used to build ``self.conn``.
        """
        super(Service, self).__init__()
        self.ip = ip
        self.port = port
        self.username = username
        self.password = password
        self.dbhost = dbhost
        self.dbport = dbport
        self.dbname = dbname
        self.dbuser = dbuser
        self.dbpwd = dbpwd
        # Database connection used by every read_*/check method below.
        self.conn = monitor_SQL(host=dbhost, port=dbport, database=dbname, user=dbuser, passwd=dbpwd)

    def initial_uptime(self):
        """Reset the up-time table for this IP and store the baseline row."""
        # clean up_time table
        self.conn.cleanUpTime(self.ip)
        free, card1, card2, clock = self.uptime()
        # Assign the initial value of system startup time which is the time of script startup
        # to facilitate the subsequent calculation of system restart times
        self.conn.addToUpTime(SECONDS=int(free[0]), IP=self.ip, REBOOT=0, TIME=clock,
                              CARD1_REBOOT_TIMES_BEGIN=card1[0], CARD1_REBOOT_INFO=0,
                              CARD2_REBOOT_TIMES_BEGIN=card2[0], CARD2_REBOOT_INFO=0, CARD1_REBOOT_TIMES=0,
                              CARD2_REBOOT_TIMES=0)

    def read_uptime(self, last_uptime, card1_reboot_times_begin, card2_reboot_times_begin, INTERVAL):
        """Detect system and card reboots since the previous poll.

        :param last_uptime: up-time (seconds) seen on the previous poll.
        :param card1_reboot_times_begin: baseline reboot counter of card 1.
        :param card2_reboot_times_begin: baseline reboot counter of card 2.
        :param INTERVAL: poll interval; a system reboot is assumed when the
            up-time went backwards and is smaller than this value.
        :returns: ``(system_reboots, card1_reboots, card2_reboots, uptime)``.
        """
        free, card1, card2, clock = self.uptime()
        seconds = int(free[0])
        card1_count = int(card1[0])
        card2_count = int(card2[0])
        card1_begin = int(card1_reboot_times_begin)
        card2_begin = int(card2_reboot_times_begin)
        # Card reboots are counted relative to the stored baseline.
        card1_reboot = card1_count - card1_begin
        card2_reboot = card2_count - card2_begin
        info = self.conn.selectFromUpTime(self.ip)[0]
        reboot_times = int(info['REBOOT_TIMES'])

        # The system rebooted.
        if seconds < int(last_uptime) and seconds < INTERVAL:
            reboot_times += 1
            # After a system reboot the two cards do not come up at the same
            # time. Without a second reading the card counters would be
            # compared against a baseline that has dropped back to zero and
            # the computed reboot numbers would be far too large.
            time.sleep(10)
            free, card1, card2, clock = self.uptime()
            seconds = int(free[0])
            # Re-baseline both cards and reset their counters.
            self.conn.addToUpTime(SECONDS=seconds, IP=self.ip, REBOOT=reboot_times, TIME=clock,
                                  CARD1_REBOOT_TIMES=0,
                                  CARD1_REBOOT_INFO=0,
                                  CARD2_REBOOT_TIMES=0,
                                  CARD2_REBOOT_INFO=0,
                                  CARD1_REBOOT_TIMES_BEGIN=card1[0],
                                  CARD2_REBOOT_TIMES_BEGIN=card2[0])
            return reboot_times, 0, 0, seconds

        # The system did not reboot but at least one card did.
        if card1_count != 0 and card2_count != 0 and (
                card1_count > card1_begin or card2_count > card2_begin):
            self.conn.addToUpTime(SECONDS=seconds, IP=self.ip, REBOOT=reboot_times, TIME=clock,
                                  CARD1_REBOOT_TIMES=card1_reboot,
                                  CARD1_REBOOT_INFO=card1[1],
                                  CARD2_REBOOT_TIMES=card2_reboot,
                                  CARD2_REBOOT_INFO=card2[1],
                                  CARD1_REBOOT_TIMES_BEGIN=card1_reboot_times_begin,
                                  CARD2_REBOOT_TIMES_BEGIN=card2_reboot_times_begin)

        # system reboot times, card1 reboot times, card2 reboot times
        return reboot_times, card1_reboot, card2_reboot, seconds

    def disk_space_check(self):
        """Store and return the current disk usage."""
        usage, clock = self.disk_space()
        self.conn.addToDiskTotalUse(USE=usage, IP=self.ip, TIME=clock)
        return usage

    # only use in mem_check
    def mem_check_judge(self, TIME, memory, name):
        """Store a mem-check reading only if it is newer than the last one.

        :param TIME: reading time as a ``%Y-%m-%d %H:%M:%S`` string, or 0.
        :param memory: memory value of the reading, or 0.
        :param name: process name the reading belongs to.

        Nothing is stored when both ``TIME`` and ``memory`` are 0.
        """
        if TIME == 0 and memory == 0:
            return
        rows = self.conn.selectFromMemCheck(self.ip, name)
        # An empty history is treated like a row without a timestamp instead
        # of raising IndexError on rows[-1].
        preTIME = rows[-1]['TIME'] if rows else None
        if preTIME is None:
            self.conn.addToMemCheck(name, self.ip, TIME, memory)
            return
        # Round-trip the stored time through the same text format so both
        # sides are compared at one-second resolution.
        current = datetime.datetime.strptime(TIME, self._TIME_FORMAT)
        previous = datetime.datetime.strptime(
            preTIME.strftime(self._TIME_FORMAT), self._TIME_FORMAT)
        # convert to timestamp to compare
        if time.mktime(previous.timetuple()) < time.mktime(current.timetuple()):
            self.conn.addToMemCheck(name, self.ip, TIME, memory)

    def mem_check(self, lmd_begin, pon_begin):
        """Record the latest "lmd" and "ponmgrd" mem-check readings.

        ``lmd_begin`` and ``pon_begin`` are accepted for interface
        compatibility but are not used at the moment.

        :returns: ``(lmd_memory, pon_memory)`` as floats.
        """
        TIME, lmd_memory = self.mem_check_log("lmd")
        self.mem_check_judge(TIME, lmd_memory, "lmd")
        TIME, pon_memory = self.mem_check_log("ponmgrd")
        self.mem_check_judge(TIME, pon_memory, "ponmgrd")
        return float(lmd_memory), float(pon_memory)

    def read_cpu(self):
        """Store and return the total CPU usage."""
        free, clock = self.cpu_cal()
        cpu_usage = free[0]
        self.conn.addToCpuTotalUse(CPU=cpu_usage, IP=self.ip, TIME=clock)
        return cpu_usage

    def read_cpu_card(self, number):
        """Store and return the CPU usage of card ``number``."""
        free, clock = self.cpu_cal_card(number)
        cpu_usage_card = free[0]
        self.conn.addToCpuUseCard(cpu_usage_card, number, self.ip, TIME=clock)
        return cpu_usage_card

    def read_mem(self):
        """Store and return the total memory usage."""
        free, clock = self.mem_cal()
        mem_usage = free[0]
        self.conn.addToMemTotalUse(MEM=mem_usage, IP=self.ip, TIME=clock)
        return mem_usage

    def read_mem_card(self, number):
        """Store and return the memory usage of card ``number``."""
        free, clock = self.mem_cal_card(number)
        mem_usage = free[0]
        self.conn.addToMemUseCard(MEM=mem_usage, CARD=number, IP=self.ip, TIME=clock)
        return mem_usage

    def read_corefile(self):
        """Store and return the core-file count as an int."""
        free, clock = self.core_file()
        corefile = int(free[0])
        self.conn.addToCoreFile(COUNT=corefile, IP=self.ip, TIME=clock)
        return corefile

    def read_crash(self):
        """Store and return 1 if a crashed process was reported, else 0."""
        free, clock = self.crash_process()
        crash = free[0] if free else None
        if crash:
            print('crash process:', crash)
        crashed = 1 if crash else 0
        self.conn.addToCrash(CRASH=crashed, IP=self.ip, TIME=clock)
        return crashed

    def alarm_act(self):
        """Store and return the number of active alarms."""
        COUNT, clock = self.active_alarm_check()
        self.conn.addToAlarm(COUNT=COUNT, IP=self.ip, TIME=clock)
        return COUNT

    def ont_action_check(self):
        """Store and return the ONT alarm count."""
        COUNT, clock = self.ont_alarm_check()
        self.conn.addToOntAlarm(COUNT=COUNT, IP=self.ip, TIME=clock)
        return COUNT

    def port_rate_check(self):
        """Store one row per port with its three rate values.

        :returns: the mapping returned by ``real_rate_check``; each value is
            indexed as (fixed, assured, excess).
        """
        rate, clock = self.real_rate_check()
        for port_name, rates in rate.items():
            self.conn.addToPortRate(PORT_NAME=port_name, FIXED_RATE=rates[0], ASSURED_RATE=rates[1],
                                    EXCESS_RATE=rates[2], IP=self.ip, TIME=clock)
        return rate

    def ont_online_number(self):
        """Store and return the number of online ONTs."""
        number, clock = self.ont_online()
        self.conn.addToOntOnLineNumber(COUNT=number, IP=self.ip, TIME=clock)
        return number

    # still not write into DB
    def sensors_temperature(self):
        """Return the temperature readings of both cards (not stored)."""
        card1, card2, _clock = self.temperature()
        return card1, card2

    def sshclose(self):
        """Close the database connection and the SSH session, if one is open."""
        if self.ssh2:
            self.conn.close()
            # NOTE(review): the original had one more call here whose method
            # name was lost (`self.('xxx')`), which is a syntax error. Restore
            # the real call from the upstream file if it is still needed.
            self.ssh2.close()
Source:rebootDevicesDebug.py  
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Copyright (c) 2021 Cisco and/or its affiliates.
This software is licensed to you under the terms of the Cisco Sample
Code License, Version 1.1 (the "License"). You may obtain a copy of the
License at
               https://developer.cisco.com/docs/licenses
All use of the material herein must be in accordance with the terms of
the License. All rights not expressly granted by the License are
reserved. Unless required by applicable law or agreed to separately in
writing, software distributed under the License is distributed on an "AS
IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied.
"""
import json
import meraki
from webexteamssdk import WebexTeamsAPI
import time
import tokensJason as tokens
import csv

dashboard = meraki.DashboardAPI(tokens.API_KEY, suppress_logging=True)
# Seconds to wait after each reboot request.
delay = 2


def readCsvFile(path='apSerials.csv'):
    """Read every row of the CSV file at ``path``.

    :param path: CSV file to read; defaults to the original 'apSerials.csv'.
    :returns: list of rows, each row being a list of strings.
    """
    startTime = time.perf_counter()
    # newline='' is what the csv module asks for so that embedded newlines
    # inside quoted fields are handled by the reader itself.
    with open(path, 'r', newline='') as file:
        csvRows = list(csv.reader(file))
    stopTime = time.perf_counter()
    print(f"readCsvFile completed in {stopTime-startTime}")
    return csvRows


def rebootDevice(deviceSerialNumbers):
    """Trigger a reboot for every serial number.

    :param deviceSerialNumbers: iterable of device serial numbers.
    :returns: list of ``{"serial": ..., "status": ...}`` dicts, one per
        request that did not raise. Devices whose request raised are only
        reported on stdout.
    """
    startTime = time.perf_counter()
    results = []
    for item in deviceSerialNumbers:
        loopStart = time.perf_counter()
        try:
            reboot = dashboard.devices.rebootDevice(serial=item)
        except Exception:
            # `except Exception` rather than a bare `except` so that Ctrl-C
            # still stops a long, sleep-paced run.
            print(f"Exception error: {item} -Check SDK Logs. Continuing.")
            continue
        results.append({"serial": item, "status": reboot})
        time.sleep(delay)
        loopStop = time.perf_counter()
        print(f"Iteration for an AP reboot completed in {loopStop-loopStart} secs.")
    stopTime = time.perf_counter()
    print(f"rebootDevice function completed in {stopTime-startTime} secs.")
    # Return list of serial and status of the reboot trigger
    return results


def getAllDevices(model="MR42"):
    """List every device in the organization whose model contains ``model``.

    This is the debug variant: the reboot call is left commented out, so the
    function only prints and counts the matching devices.

    :param model: substring matched against each device's 'model' field.
    :returns: list of reboot results; always empty while the reboot call
        stays commented out.
    """
    startTime = time.perf_counter()
    results = []
    apCount = 0
    response = dashboard.organizations.getOrganizationDevices(tokens.ORG_ID, total_pages='all')
    getOrgDeviceTime = time.perf_counter()
    print(f"Get all devices for reboot completed in {getOrgDeviceTime-startTime} secs.")
    for device in response:
        loopStart = time.perf_counter()
        if model not in device['model']:
            continue
        try:
            print(f"{device['serial']}, {device['name']}, {device['model']}")
            apCount += 1
            # Dry run: re-enable the three lines below to actually reboot.
            #reboot = dashboard.devices.rebootDevice(serial=device['serial'])
            #results.append({"serial": device['serial'], "status": reboot})
            #time.sleep(delay)
        except Exception:
            # .get() so the handler cannot itself fail on the key that may
            # have caused the exception.
            print(f"Exception error: {device.get('serial')} -Check SDK Logs. Continuing.")
            continue
        loopStop = time.perf_counter()
        print(f"Iteration for an AP reboot completed in {loopStop-loopStart} secs.")
    print(apCount)
    stopTime = time.perf_counter()
    # The original message named rebootDevice here (copy/paste slip).
    print(f"getAllDevices function completed in {stopTime-startTime} secs.")
    # Return list of serial and status of the reboot trigger
    return results


def rebootStatus(rebootResults):
    """Summarize reboot results.

    :param rebootResults: list of ``{"serial", "status"}`` dicts as returned
        by ``rebootDevice``; ``status["success"]`` decides pass or fail.
    :returns: one-element list holding the pass count, the fail count and
        the serials that failed.
    """
    startTime = time.perf_counter()
    apFailList = [item["serial"] for item in rebootResults
                  if item["status"]["success"] is not True]
    failCount = len(apFailList)
    passCount = len(rebootResults) - failCount
    results = [{"apRebooted": passCount, "apNotRebooted": failCount, "apFailList": apFailList}]
    print(results)
    stopTime = time.perf_counter()
    print(f"rebootStatus completed in {stopTime-startTime} sec.")
    return results


def postWebex(rebootStatus):
    """Post the summary (JSON-encoded) to the configured Webex room."""
    startTime = time.perf_counter()
    webex = WebexTeamsAPI(access_token=tokens.WEBEX_TOKEN)
    # Post the rebootStatus to the webex room
    webex.messages.create(tokens.WEBEX_ROOMID, markdown=json.dumps(rebootStatus))
    print("Reboot status results posted to Webex room.")
    stopTime = time.perf_counter()
    print(f"postWebex completed in {stopTime-startTime} secs.")


if __name__ == '__main__':
    # Get list of device serial numbers to reboot from CSV file.
    startTime = time.perf_counter()
    rows = readCsvFile()
    # Flatten: the CSV may hold several rows with several serials each.
    deviceSerialNumbers = [serial for row in rows for serial in row]
    stopTime = time.perf_counter()
    print(f"Device list from CSV completed in {stopTime-startTime}")
    print("Starting reboot script.  It may take several minutes to complete.")
    # Dry run: the reboot, summary and Webex post are disabled on purpose.
#    rebootAP = rebootDevice(deviceSerialNumbers)
#    status = rebootStatus(rebootAP)
#    postWebex(status)
Source:main_fedreboot.py  
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 26 22:15:29 2021
@author: wangyumeng

Federated "reboot" training on FashionMNIST: in every communication round
each client is updated locally, labels the shared reboot sample, and a global
CNN is then trained on those client-produced labels.
"""
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import random_split, DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
import utilis
import CNN
from CNN import ConvNet
from FedAvg import local_update, avg_parameters

torch.manual_seed(1234)
# The validation split below is drawn with the stdlib `random` module, which
# torch.manual_seed does not cover; seed it too so runs are repeatable.
random.seed(1234)

# Split sizes (FashionMNIST: 60000 train / 10000 test images).
NUM_TRAIN = 10000    # images shared out among the clients for training
NUM_REBOOT = 50000   # images labelled by the clients for the reboot step
NUM_VALID = 5000     # images shared out among the clients for validation
NUM_TEST = 5000      # images used for the final accuracy
NUM_REBOOT_VALID = 2500  # reboot images held out for early stopping

# load data
dataset1, dataset2 = random_split(datasets.FashionMNIST('data', train=True,
                                  transform=transforms.ToTensor()), [NUM_TRAIN, NUM_REBOOT])
dataset3, dataset4 = random_split(datasets.FashionMNIST('data', train=False,
                                  transform=transforms.ToTensor()), [NUM_VALID, NUM_TEST])
# reboot data
sample = utilis.get_sample_from_dataset(dataset2)
label = utilis.get_label_from_dataset(dataset2)
sample_loader = DataLoader(sample, shuffle=False)
# train, valid and test data
train_loader = DataLoader(dataset1, batch_size=25, shuffle=True)
valid_loader = DataLoader(dataset3, batch_size=25, shuffle=True)
test_loader = DataLoader(dataset4, batch_size=25, shuffle=False)

m = 20          # number of clients
num_comm = 50   # number of communication rounds
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.set_num_threads(36)
device = torch.device("cpu")

idx_train = utilis.partition(range(NUM_TRAIN), m)
idx_valid = utilis.partition(range(NUM_VALID), m)
local_train_loader = {}
local_valid_loader = {}
for i in range(m):
    client = 'client' + str(i + 1)
    local_train_loader[client] = DataLoader(Subset(train_loader.dataset, idx_train[i]),
                                            batch_size=25, shuffle=True)
    local_valid_loader[client] = DataLoader(Subset(valid_loader.dataset, idx_valid[i]),
                                            batch_size=25, shuffle=True)

num_acc = []
reboot_CNN = CNN.ConvNet().to(device)
reboot_parameters = reboot_CNN.state_dict()
for i in range(num_comm):
    print("\n----------------communicate round {}----------------\n".format(i+1))
    clients_parameters = {}
    client_outputs = []

    for k in range(m):
        client = 'client' + str(k + 1)
        print("----------------clients {}--------------\n".format(k+1))
        client_train_loader = local_train_loader[client]
        client_valid_loader = local_valid_loader[client]
        clients_parameters[client] = local_update(client_train_loader,
                                                  client_valid_loader, reboot_parameters, device)
        local_CNN = CNN.ConvNet().to(device)
        local_CNN.load_state_dict(clients_parameters[client], strict=True)
        outputs = CNN.predict(local_CNN, device, sample_loader)
        # Collect rows in a list and build the array once; growing it with
        # np.append copies the whole array for every client.
        client_outputs.append([int(x[0]) for x in outputs])
        del local_CNN

    # One row of predicted labels per client; the reshape also enforces the
    # expected NUM_REBOOT columns, as the original np.append did.
    all_outputs = np.array(client_outputs, dtype=int).reshape(-1, NUM_REBOOT)

    print('------------reboot estimator------------\n')
    valid_sam = random.sample(range(NUM_REBOOT), NUM_REBOOT_VALID)
    train_sam = np.delete(np.arange(NUM_REBOOT), valid_sam)
    # Every client labels the same images, so the inputs are the same block
    # repeated m times and the targets are the clients' rows laid end to end.
    # Built once here rather than re-concatenated (and the loaders rebuilt)
    # on every pass of a per-client loop, which also left the loaders
    # undefined when m == 1.
    reboot_sample_train = torch.cat([sample[train_sam]] * m, 0)
    reboot_sample_valid = torch.cat([sample[valid_sam]] * m, 0)
    reboot_label_train = all_outputs[:, train_sam].reshape(-1)
    reboot_label_valid = all_outputs[:, valid_sam].reshape(-1)
    reboot_train = TensorDataset(reboot_sample_train, torch.tensor(reboot_label_train))
    reboot_valid = TensorDataset(reboot_sample_valid, torch.tensor(reboot_label_valid))
    reboot_train_loader = DataLoader(reboot_train, batch_size=512, shuffle=True, num_workers=8)
    reboot_valid_loader = DataLoader(reboot_valid, batch_size=512, shuffle=True, num_workers=8)

    optim_reboot = optim.Adam(reboot_CNN.parameters())
    num_acc_reboot = []
    for epoch in range(1, 101):
        CNN.train(reboot_CNN, device, reboot_train_loader, optim_reboot, epoch)
        acc = CNN.test(reboot_CNN, device, reboot_valid_loader)
        num_acc_reboot.append(acc)
        # Early stop once validation accuracy has moved by at most 0.5 over
        # the last ten epochs.
        if epoch > 10:
            gap = np.max(num_acc_reboot[-10:]) - np.min(num_acc_reboot[-10:])
            if gap <= .5:
                break

    #acc_reboot = CNN.test(reboot_CNN, device, test_loader)
    reboot_CNN.eval()
    with torch.no_grad():
        sum_acc = 0
        # `target`, not `label`, so the module-level reboot labels are not
        # overwritten by the last test batch.
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = torch.argmax(reboot_CNN(data), dim=1)
            sum_acc += (outputs == target).sum()
        acc_reboot = 100. * sum_acc.float() / len(test_loader.dataset)
        num_acc.append(acc_reboot)
        print('FedReboot accuracy: {}'.format(acc_reboot))

    reboot_parameters = reboot_CNN.state_dict()

# Learn to execute automation testing from scratch with LambdaTest Learning
# Hub. Right from setting up the prerequisites to run your first automation
# test, to following best practices and diving deeper into advanced test
# scenarios. LambdaTest Learning Hubs compile a list of step-by-step guides to
# help you be proficient with different test automation frameworks i.e.
# Selenium, Cypress, TestNG etc.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 automation test minutes FREE!!
