# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmpose.core import (aggregate_scale, aggregate_stage_flip,
                         flip_feature_maps, get_group_preds, split_ae_outputs)


def test_split_ae_outputs():
    """Smoke-test ``split_ae_outputs`` on a single all-zero output tensor.

    No values are asserted; the test passes when the call completes and its
    return value unpacks into exactly two items.
    """
    fake_outputs = [torch.zeros((1, 4, 2, 2))]
    # Unpacking doubles as a check that two values are returned.
    heatmaps, tags = split_ae_outputs(
        fake_outputs,
        num_joints=4,
        with_heatmaps=[False],
        with_ae=[True],
        select_output_index=[0])


def test_flip_feature_maps():
    """Smoke-test ``flip_feature_maps`` without and with a ``flip_index``.

    No values are asserted; the test passes when both calls complete.
    """
    fake_outputs = [torch.zeros((1, 4, 2, 2))]
    _ = flip_feature_maps(fake_outputs, None)
    _ = flip_feature_maps(fake_outputs, flip_index=[1, 0])


def test_aggregate_stage_flip():
    """Check ``aggregate_stage_flip`` returns a list for each mode pairing.

    The same inputs are fed through four ``(aggregate_stage,
    aggregate_flip)`` combinations of ``'concat'`` and ``'average'``.
    """
    fake_outputs = [torch.zeros((1, 4, 2, 2))]
    fake_flip_outputs = [torch.ones((1, 4, 2, 2))]

    # (aggregate_stage, aggregate_flip) pairs; every other argument is fixed.
    mode_pairs = (
        ('concat', 'average'),
        ('average', 'average'),
        ('average', 'concat'),
        ('concat', 'concat'),
    )
    for stage_mode, flip_mode in mode_pairs:
        output = aggregate_stage_flip(
            fake_outputs,
            fake_flip_outputs,
            index=-1,
            project2image=True,
            size_projected=(4, 4),
            align_corners=False,
            aggregate_stage=stage_mode,
            aggregate_flip=flip_mode)
        assert isinstance(output, list)


def test_aggregate_scale():
    """Check the type and shape of ``aggregate_scale`` results.

    With ``'average'`` the result must be a tensor shaped like one input;
    with ``'unsqueeze_concat'`` it must be a tensor with one extra dimension.
    """
    fake_outputs = [torch.zeros((1, 4, 2, 2)), torch.zeros((1, 4, 2, 2))]

    output = aggregate_scale(
        fake_outputs, align_corners=False, aggregate_scale='average')
    assert isinstance(output, torch.Tensor)
    assert output.shape == fake_outputs[0].shape

    output = aggregate_scale(
        fake_outputs, align_corners=False, aggregate_scale='unsqueeze_concat')
    assert isinstance(output, torch.Tensor)
    assert len(output.shape) == len(fake_outputs[0].shape) + 1


def test_get_group_preds():
    """Check ``get_group_preds`` gives a non-empty result with and without
    ``use_udp``."""
    fake_grouped_joints = [np.array([[[0, 0], [1, 1]]])]

    results = get_group_preds(
        fake_grouped_joints,
        center=np.array([0, 0]),
        scale=np.array([1, 1]),
        heatmap_size=np.array([2, 2]))
    assert not results == []

    results = get_group_preds(
        fake_grouped_joints,
        center=np.array([0, 0]),
        scale=np.array([1, 1]),
        heatmap_size=np.array([2, 2]),
        use_udp=True)
    # Verify the UDP call as well, mirroring the check on the default call.
    assert not results == []

