How to use update_policy in Python (example from AWS ParallelCluster)


cluster_schema.py

Source: cluster_schema.py (aws-parallelcluster, GitHub)


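The listing below is the cluster_schema.py module of AWS ParallelCluster (the pcluster package). It shows the update_policy pattern at scale: every marshmallow field declares an "update_policy" entry in its metadata, either one of the predefined UpdatePolicy constants (SUPPORTED, UNSUPPORTED, IGNORED, COMPUTE_FLEET_STOP, QUEUE_UPDATE_STRATEGY, ...) or an UpdatePolicy instance that additionally carries a fail_reason and an action_needed message. As a warm-up, here is a minimal self-contained sketch of the pattern; the UpdatePolicy class below is a hypothetical stand-in for pcluster.config.update_policy.UpdatePolicy, not the real implementation:

from marshmallow import Schema, fields

class UpdatePolicy:  # hypothetical stand-in for pcluster.config.update_policy.UpdatePolicy
    def __init__(self, level):
        self.level = level

UpdatePolicy.SUPPORTED = UpdatePolicy("SUPPORTED")
UpdatePolicy.UNSUPPORTED = UpdatePolicy("UNSUPPORTED")

class VolumeSchema(Schema):
    size = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    iops = fields.Int(metadata={"update_policy": UpdatePolicy.SUPPORTED})

# Read the policy back from the declared fields, as an update validator would:
for name, field in VolumeSchema().fields.items():
    print(name, field.metadata["update_policy"].level)

Because the policy lives in each field's metadata, any code that walks schema.fields can recover it, without coupling the schema declarations to the update logic.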
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
#
# This module contains all the classes representing the Schema of the configuration file.
# These classes are created by following marshmallow syntax.
#
import copy
import hashlib
import logging
import re
from urllib.request import urlopen
import yaml
from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load, validate, validates, validates_schema
from yaml import YAMLError
from pcluster.aws.aws_api import AWSApi
from pcluster.aws.common import AWSClientError
from pcluster.config.cluster_config import (
    AdditionalPackages,
    AmiSearchFilters,
    AwsBatchClusterConfig,
    AwsBatchComputeResource,
    AwsBatchQueue,
    AwsBatchQueueNetworking,
    AwsBatchScheduling,
    AwsBatchSettings,
    CapacityType,
    CloudWatchDashboards,
    CloudWatchLogs,
    ClusterDevSettings,
    ClusterIam,
    ComputeSettings,
    CustomAction,
    CustomActions,
    Dashboards,
    Dcv,
    DirectoryService,
    Dns,
    Efa,
    EphemeralVolume,
    ExistingFsxOntap,
    ExistingFsxOpenZfs,
    FlexibleInstanceType,
    HeadNode,
    HeadNodeImage,
    HeadNodeNetworking,
    Iam,
    Image,
    Imds,
    IntelSoftware,
    LocalStorage,
    Logs,
    Monitoring,
    PlacementGroup,
    Proxy,
    QueueImage,
    QueueUpdateStrategy,
    Raid,
    Roles,
    RootVolume,
    S3Access,
    SchedulerPluginCloudFormationInfrastructure,
    SchedulerPluginClusterConfig,
    SchedulerPluginClusterInfrastructure,
    SchedulerPluginClusterSharedArtifact,
    SchedulerPluginComputeResource,
    SchedulerPluginComputeResourceConstraints,
    SchedulerPluginDefinition,
    SchedulerPluginEvent,
    SchedulerPluginEvents,
    SchedulerPluginExecuteCommand,
    SchedulerPluginFile,
    SchedulerPluginLogs,
    SchedulerPluginMonitoring,
    SchedulerPluginPluginResources,
    SchedulerPluginQueue,
    SchedulerPluginQueueConstraints,
    SchedulerPluginQueueNetworking,
    SchedulerPluginRequirements,
    SchedulerPluginScheduling,
    SchedulerPluginSettings,
    SchedulerPluginSupportedDistros,
    SchedulerPluginUser,
    SharedEbs,
    SharedEfs,
    SharedFsxLustre,
    SlurmClusterConfig,
    SlurmComputeResource,
    SlurmFlexibleComputeResource,
    SlurmQueue,
    SlurmQueueNetworking,
    SlurmScheduling,
    SlurmSettings,
    Ssh,
    SudoerConfiguration,
    Timeouts,
)
from pcluster.config.update_policy import UpdatePolicy
from pcluster.constants import (
    DELETION_POLICIES,
    DELETION_POLICIES_WITH_SNAPSHOT,
    FSX_LUSTRE,
    FSX_ONTAP,
    FSX_OPENZFS,
    FSX_VOLUME_ID_REGEX,
    LUSTRE,
    ONTAP,
    OPENZFS,
    SCHEDULER_PLUGIN_MAX_NUMBER_OF_USERS,
    SUPPORTED_OSES,
)
from pcluster.models.s3_bucket import parse_bucket_url
from pcluster.schemas.common_schema import (
    AdditionalIamPolicySchema,
    BaseDevSettingsSchema,
    BaseSchema,
    TagSchema,
    get_field_validator,
    validate_no_reserved_tag,
)
from pcluster.validators.cluster_validators import EFS_MESSAGES, FSX_MESSAGES

# pylint: disable=C0302

LOGGER = logging.getLogger(__name__)
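# Note on the pattern used throughout this module: every field below attaches an
# "update_policy" entry to its marshmallow metadata, either one of the predefined
# UpdatePolicy constants (SUPPORTED, UNSUPPORTED, IGNORED, COMPUTE_FLEET_STOP,
# QUEUE_UPDATE_STRATEGY, ...) or an UpdatePolicy instance that also carries a
# fail_reason and/or an action_needed message. The "update_key" metadata entry
# marks the attribute used to match list items between the old and the new
# configuration during a cluster update.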
# ---------------------- Storage ---------------------- #
class HeadNodeRootVolumeSchema(BaseSchema):
    """Represent the RootVolume schema for the Head node."""
    volume_type = fields.Str(
        validate=get_field_validator("volume_type"),
        metadata={
            "update_policy": UpdatePolicy(
                UpdatePolicy.UNSUPPORTED, action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"]
            )
        },
    )
    iops = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    size = fields.Int(
        metadata={
            "update_policy": UpdatePolicy(
                UpdatePolicy.UNSUPPORTED,
                fail_reason=UpdatePolicy.FAIL_REASONS["ebs_volume_resize"],
                action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"],
            )
        }
    )
    throughput = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    delete_on_termination = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return RootVolume(**data)


class QueueRootVolumeSchema(BaseSchema):
    """Represent the RootVolume schema for the queue."""
    size = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    volume_type = fields.Str(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    iops = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    throughput = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return RootVolume(**data)


class RaidSchema(BaseSchema):
    """Represent the schema of the parameters specific to Raid. It is a child of EBS schema."""
    raid_type = fields.Int(
        required=True,
        data_key="Type",
        validate=validate.OneOf([0, 1]),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    number_of_volumes = fields.Int(
        validate=validate.Range(min=2, max=5), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Raid(**data)


class EbsSettingsSchema(BaseSchema):
    """Represent the schema of EBS."""
    volume_type = fields.Str(
        validate=get_field_validator("volume_type"),
        metadata={
            "update_policy": UpdatePolicy(
                UpdatePolicy.UNSUPPORTED, action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"]
            )
        },
    )
    iops = fields.Int(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    size = fields.Int(
        metadata={
            "update_policy": UpdatePolicy(
                UpdatePolicy.UNSUPPORTED,
                fail_reason=UpdatePolicy.FAIL_REASONS["ebs_volume_resize"],
                action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"],
            )
        }
    )
    kms_key_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    throughput = fields.Int(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    snapshot_id = fields.Str(
        validate=validate.Regexp(r"^snap-[0-9a-z]{8}$|^snap-[0-9a-z]{17}$"),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    volume_id = fields.Str(
        validate=validate.Regexp(r"^vol-[0-9a-z]{8}$|^vol-[0-9a-z]{17}$"),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    raid = fields.Nested(RaidSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    deletion_policy = fields.Str(
        validate=validate.OneOf(DELETION_POLICIES_WITH_SNAPSHOT), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )


class HeadNodeEphemeralVolumeSchema(BaseSchema):
    """Represent the schema of ephemeral volume. It is a child of storage schema."""
    mount_dir = fields.Str(
        validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return EphemeralVolume(**data)


class QueueEphemeralVolumeSchema(BaseSchema):
    """Represent the schema of ephemeral volume. It is a child of storage schema."""
    mount_dir = fields.Str(
        validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return EphemeralVolume(**data)


class HeadNodeStorageSchema(BaseSchema):
    """Represent the schema of storage attached to a node."""
    root_volume = fields.Nested(HeadNodeRootVolumeSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    ephemeral_volume = fields.Nested(
        HeadNodeEphemeralVolumeSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return LocalStorage(**data)


class QueueStorageSchema(BaseSchema):
    """Represent the schema of storage attached to a node."""
    root_volume = fields.Nested(QueueRootVolumeSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    ephemeral_volume = fields.Nested(
        QueueEphemeralVolumeSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return LocalStorage(**data)


class EfsSettingsSchema(BaseSchema):
    """Represent the EFS schema."""
    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    kms_key_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    performance_mode = fields.Str(
        validate=validate.OneOf(["generalPurpose", "maxIO"]), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    throughput_mode = fields.Str(
        validate=validate.OneOf(["provisioned", "bursting"]), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    provisioned_throughput = fields.Int(
        validate=validate.Range(min=1, max=1024), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    file_system_id = fields.Str(
        validate=validate.Regexp(r"^fs-[0-9a-z]{8}$|^fs-[0-9a-z]{17}$"),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    deletion_policy = fields.Str(
        validate=validate.OneOf(DELETION_POLICIES), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    @validates_schema
    def validate_file_system_id_ignored_parameters(self, data, **kwargs):
        """Return errors for parameters in the Efs config section that would be ignored."""
        # If file_system_id is specified, all other parameters are ignored.
        messages = []
        if data.get("file_system_id") is not None:
            for key in data:
                if key is not None and key != "file_system_id":
                    messages.append(EFS_MESSAGES["errors"]["ignored_param_with_efs_fs_id"].format(efs_param=key))
        if messages:
            raise ValidationError(message=messages)
    @validates_schema
    def validate_existence_of_mode_throughput(self, data, **kwargs):
        """Validate the conditional existence requirement between throughput_mode and provisioned_throughput."""
        if kwargs.get("partial"):
            # If the schema is to be loaded partially, do not check the existence constraint.
            return
        throughput_mode = data.get("throughput_mode")
        provisioned_throughput = data.get("provisioned_throughput")
        if throughput_mode != "provisioned" and provisioned_throughput:
            raise ValidationError(
                message="When specifying provisioned throughput, the throughput mode must be set to provisioned",
                field_name="ThroughputMode",
            )
        if throughput_mode == "provisioned" and not provisioned_throughput:
            raise ValidationError(
                message="When specifying throughput mode to provisioned,"
                " the provisioned throughput option must be specified",
                field_name="ProvisionedThroughput",
            )
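# Illustration of the constraint enforced by validate_existence_of_mode_throughput
# above, using a hypothetical EfsSettings snippet from a cluster configuration:
#
#   EfsSettings:
#     ThroughputMode: provisioned
#     ProvisionedThroughput: 256   # required if and only if ThroughputMode is "provisioned"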
class FsxLustreSettingsSchema(BaseSchema):
    """Represent the FSX schema."""
    storage_capacity = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    deployment_type = fields.Str(
        validate=validate.OneOf(["SCRATCH_1", "SCRATCH_2", "PERSISTENT_1", "PERSISTENT_2"]),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    imported_file_chunk_size = fields.Int(
        validate=validate.Range(min=1, max=512000, error="has a minimum size of 1 MiB, and max size of 512,000 MiB"),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    export_path = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    import_path = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    weekly_maintenance_start_time = fields.Str(
        validate=validate.Regexp(r"^[1-7]:([01]\d|2[0-3]):([0-5]\d)$"),
        metadata={"update_policy": UpdatePolicy.SUPPORTED},
    )
    automatic_backup_retention_days = fields.Int(
        validate=validate.Range(min=0, max=35), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    copy_tags_to_backups = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    daily_automatic_backup_start_time = fields.Str(
        validate=validate.Regexp(r"^([01]\d|2[0-3]):([0-5]\d)$"), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    per_unit_storage_throughput = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    backup_id = fields.Str(
        validate=validate.Regexp("^(backup-[0-9a-f]{8,})$"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    kms_key_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    file_system_id = fields.Str(
        validate=validate.Regexp(r"^fs-[0-9a-z]{17}$"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    auto_import_policy = fields.Str(
        validate=validate.OneOf(["NEW", "NEW_CHANGED", "NEW_CHANGED_DELETED"]),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    drive_cache_type = fields.Str(
        validate=validate.OneOf(["READ"]), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    data_compression_type = fields.Str(
        validate=validate.OneOf(["LZ4"]), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    fsx_storage_type = fields.Str(
        data_key="StorageType",
        validate=validate.OneOf(["HDD", "SSD"]),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    deletion_policy = fields.Str(
        validate=validate.OneOf(DELETION_POLICIES), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    @validates_schema
    def validate_file_system_id_ignored_parameters(self, data, **kwargs):
        """Return errors for parameters in the FSx config section that would be ignored."""
        # If file_system_id is specified, all other parameters are ignored.
        messages = []
        if data.get("file_system_id") is not None:
            for key in data:
                if key is not None and key != "file_system_id":
                    messages.append(FSX_MESSAGES["errors"]["ignored_param_with_fsx_fs_id"].format(fsx_param=key))
        if messages:
            raise ValidationError(message=messages)
    @validates_schema
    def validate_backup_id_unsupported_parameters(self, data, **kwargs):
        """Return errors for parameters in the FSx config section that are unsupported with a backup."""
        # If backup_id is specified, these parameters are unsupported.
        messages = []
        if data.get("backup_id") is not None:
            unsupported_config_param_names = [
                "deployment_type",
                "per_unit_storage_throughput",
                "storage_capacity",
                "import_path",
                "export_path",
                "imported_file_chunk_size",
                "kms_key_id",
            ]
            for key in data:
                if key in unsupported_config_param_names:
                    messages.append(FSX_MESSAGES["errors"]["unsupported_backup_param"].format(name=key))
        if messages:
            raise ValidationError(message=messages)


class FsxOpenZfsSettingsSchema(BaseSchema):
    """Represent the FSX OpenZFS schema."""
    volume_id = fields.Str(
        required=True,
        validate=validate.Regexp(FSX_VOLUME_ID_REGEX),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )


class FsxOntapSettingsSchema(BaseSchema):
    """Represent the FSX Ontap schema."""
    volume_id = fields.Str(
        required=True,
        validate=validate.Regexp(FSX_VOLUME_ID_REGEX),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )


class SharedStorageSchema(BaseSchema):
    """Represent the generic SharedStorage schema."""
    mount_dir = fields.Str(
        required=True, validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    storage_type = fields.Str(
        required=True,
        validate=validate.OneOf(["Ebs", FSX_LUSTRE, FSX_OPENZFS, FSX_ONTAP, "Efs"]),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    ebs_settings = fields.Nested(EbsSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    efs_settings = fields.Nested(EfsSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    fsx_lustre_settings = fields.Nested(FsxLustreSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    fsx_open_zfs_settings = fields.Nested(
        FsxOpenZfsSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    fsx_ontap_settings = fields.Nested(FsxOntapSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @validates_schema
    def no_coexist_storage_settings(self, data, **kwargs):
        """Validate that *_settings for different storage types do not co-exist."""
        if self.fields_coexist(
            data,
            ["ebs_settings", "efs_settings", "fsx_lustre_settings", "fsx_open_zfs_settings", "fsx_ontap_settings"],
            **kwargs,
        ):
            raise ValidationError("Multiple *Settings sections cannot be specified in the SharedStorage items.")
    @validates_schema
    def right_storage_settings(self, data, **kwargs):
        """Validate that the *_settings param is associated to the right storage type."""
        for storage_type, settings in [
            ("Ebs", "ebs_settings"),
            ("Efs", "efs_settings"),
            (FSX_LUSTRE, "fsx_lustre_settings"),
            (FSX_OPENZFS, "fsx_open_zfs_settings"),
            (FSX_ONTAP, "fsx_ontap_settings"),
        ]:
            # Verify the settings section is associated to the right storage type
            if data.get(settings, None) and storage_type != data.get("storage_type"):
                raise ValidationError(
                    "SharedStorage > *Settings section is not appropriate to the "
                    f"StorageType {data.get('storage_type')}."
                )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate the right type of shared storage according to the child type (EBS vs EFS vs FsxLustre)."""
        storage_type = data.get("storage_type")
        shared_volume_attributes = {"mount_dir": data.get("mount_dir"), "name": data.get("name")}
        settings = (
            data.get("efs_settings", None)
            or data.get("ebs_settings", None)
            or data.get("fsx_lustre_settings", None)
            or data.get("fsx_open_zfs_settings", None)
            or data.get("fsx_ontap_settings", None)
        )
        if settings:
            shared_volume_attributes.update(**settings)
        if storage_type == "Efs":
            return SharedEfs(**shared_volume_attributes)
        elif storage_type == "Ebs":
            return SharedEbs(**shared_volume_attributes)
        elif storage_type == FSX_LUSTRE:
            return SharedFsxLustre(**shared_volume_attributes)
        elif storage_type == FSX_OPENZFS:
            return ExistingFsxOpenZfs(**shared_volume_attributes)
        elif storage_type == FSX_ONTAP:
            return ExistingFsxOntap(**shared_volume_attributes)
        return None
    @pre_dump
    def restore_child(self, data, **kwargs):
        """Restore back the child in the schema."""
        adapted_data = copy.deepcopy(data)
        # Move SharedXxx as a child to be automatically managed by marshmallow, see post_load action
        if adapted_data.shared_storage_type == "efs":
            storage_type = "efs"
        elif adapted_data.shared_storage_type == "fsx":
            mapping = {LUSTRE: "fsx_lustre", OPENZFS: "fsx_open_zfs", ONTAP: "fsx_ontap"}
            storage_type = mapping.get(adapted_data.file_system_type)
        else:  # "raid", "ebs"
            storage_type = "ebs"
        setattr(adapted_data, f"{storage_type}_settings", copy.copy(adapted_data))
        # Restore storage type attribute
        if adapted_data.shared_storage_type == "fsx":
            mapping = {LUSTRE: FSX_LUSTRE, OPENZFS: FSX_OPENZFS, ONTAP: FSX_ONTAP}
            adapted_data.storage_type = mapping.get(adapted_data.file_system_type)
        else:
            adapted_data.storage_type = storage_type.capitalize()
        return adapted_data
    @validates("mount_dir")
    def shared_dir_validator(self, value):
        """Validate that user is not specifying /NONE or NONE as shared_dir for any filesystem."""
        # FIXME: pcluster2 doesn't allow "^/?NONE$" mount dir to avoid an ambiguity in cookbook.
        # We should change cookbook to solve the ambiguity and allow "^/?NONE$" for mount dir
        # Cookbook location to be modified:
        # https://github.com/aws/aws-parallelcluster-cookbook/blob/develop/recipes/head_node_base_config.rb#L51
        if re.match("^/?NONE$", value):
            raise ValidationError(f"{value} cannot be used as a shared directory")


# ---------------------- Networking ---------------------- #
class HeadNodeProxySchema(BaseSchema):
    """Represent the schema of proxy for the Head node."""
    http_proxy_address = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Proxy(**data)


class QueueProxySchema(BaseSchema):
    """Represent the schema of proxy for a queue."""
    http_proxy_address = fields.Str(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Proxy(**data)


class BaseNetworkingSchema(BaseSchema):
    """Represent the schema of common networking parameters used by head and compute nodes."""
    additional_security_groups = fields.List(
        fields.Str(validate=get_field_validator("security_group_id")),
        metadata={"update_policy": UpdatePolicy.SUPPORTED},
    )
    security_groups = fields.List(
        fields.Str(validate=get_field_validator("security_group_id")),
        metadata={"update_policy": UpdatePolicy.SUPPORTED},
    )
    @validates_schema
    def no_coexist_security_groups(self, data, **kwargs):
        """Validate that security_groups and additional_security_groups do not co-exist."""
        if self.fields_coexist(data, ["security_groups", "additional_security_groups"], **kwargs):
            raise ValidationError("SecurityGroups and AdditionalSecurityGroups can not be configured together.")


class HeadNodeNetworkingSchema(BaseNetworkingSchema):
    """Represent the schema of the Networking, child of the HeadNode."""
    subnet_id = fields.Str(
        required=True, validate=get_field_validator("subnet_id"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    elastic_ip = fields.Raw(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    proxy = fields.Nested(HeadNodeProxySchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return HeadNodeNetworking(**data)


class PlacementGroupSchema(BaseSchema):
    """Represent the schema of placement group."""
    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    id = fields.Str(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return PlacementGroup(**data)


class QueueNetworkingSchema(BaseNetworkingSchema):
    """Represent the schema of the Networking, child of Queue."""
    subnet_ids = fields.List(
        fields.Str(validate=get_field_validator("subnet_id")),
        required=True,
        validate=validate.Length(equal=1),
        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},
    )
    assign_public_ip = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})


class SlurmQueueNetworkingSchema(QueueNetworkingSchema):
    """Represent the schema of the Networking, child of slurm Queue."""
    placement_group = fields.Nested(
        PlacementGroupSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    proxy = fields.Nested(QueueProxySchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SlurmQueueNetworking(**data)


class AwsBatchQueueNetworkingSchema(QueueNetworkingSchema):
    """Represent the schema of the Networking, child of aws batch Queue."""
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return AwsBatchQueueNetworking(**data)


class SchedulerPluginQueueNetworkingSchema(SlurmQueueNetworkingSchema):
    """Represent the schema of the Networking, child of Scheduler Plugin Queue."""
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginQueueNetworking(**data)


class SshSchema(BaseSchema):
    """Represent the schema of the SSH."""
    key_name = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    allowed_ips = fields.Str(validate=get_field_validator("cidr"), metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Ssh(**data)


class DcvSchema(BaseSchema):
    """Represent the schema of DCV."""
    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    port = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    allowed_ips = fields.Str(validate=get_field_validator("cidr"), metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Dcv(**data)


class EfaSchema(BaseSchema):
    """Represent the schema of EFA for a Compute Resource."""
    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    gdr_support = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Efa(**data)


# ---------------------- Monitoring ---------------------- #
class CloudWatchLogsSchema(BaseSchema):
    """Represent the schema of the CloudWatchLogs section."""
    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    retention_in_days = fields.Int(
        validate=validate.OneOf([1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, 3653]),
        metadata={"update_policy": UpdatePolicy.SUPPORTED},
    )
    deletion_policy = fields.Str(
        validate=validate.OneOf(DELETION_POLICIES), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return CloudWatchLogs(**data)


class CloudWatchDashboardsSchema(BaseSchema):
    """Represent the schema of the CloudWatchDashboards section."""
    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return CloudWatchDashboards(**data)


class LogsSchema(BaseSchema):
    """Represent the schema of the Logs section."""
    cloud_watch = fields.Nested(CloudWatchLogsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Logs(**data)


class DashboardsSchema(BaseSchema):
    """Represent the schema of the Dashboards section."""
    cloud_watch = fields.Nested(CloudWatchDashboardsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Dashboards(**data)


class MonitoringSchema(BaseSchema):
    """Represent the schema of the Monitoring section."""
    detailed_monitoring = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    logs = fields.Nested(LogsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    dashboards = fields.Nested(DashboardsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Monitoring(**data)


# ---------------------- Others ---------------------- #
class RolesSchema(BaseSchema):
    """Represent the schema of roles."""
    lambda_functions_role = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Roles(**data)


class S3AccessSchema(BaseSchema):
    """Represent the schema of S3 access."""
    bucket_name = fields.Str(
        required=True,
        metadata={"update_policy": UpdatePolicy.SUPPORTED},
        validate=validate.Regexp(r"^[\*a-z0-9\-\.]+$"),
    )
    key_name = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    enable_write_access = fields.Bool(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return S3Access(**data)


class ClusterIamSchema(BaseSchema):
    """Represent the schema of IAM for Cluster."""
    roles = fields.Nested(RolesSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    permissions_boundary = fields.Str(
        metadata={"update_policy": UpdatePolicy.SUPPORTED}, validate=validate.Regexp("^arn:.*:policy/")
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return ClusterIam(**data)


class IamSchema(BaseSchema):
    """Common schema of IAM for HeadNode and Queue."""
    instance_role = fields.Str(
        metadata={"update_policy": UpdatePolicy.SUPPORTED}, validate=validate.Regexp("^arn:.*:role/")
    )
    s3_access = fields.Nested(
        S3AccessSchema, many=True, metadata={"update_policy": UpdatePolicy.SUPPORTED, "update_key": "BucketName"}
    )
    additional_iam_policies = fields.Nested(
        AdditionalIamPolicySchema, many=True, metadata={"update_policy": UpdatePolicy.SUPPORTED, "update_key": "Policy"}
    )
    @validates_schema
    def no_coexist_role_policies(self, data, **kwargs):
        """Validate that instance_role, instance_profile or additional_iam_policies do not co-exist."""
        if self.fields_coexist(data, ["instance_role", "instance_profile", "additional_iam_policies"], **kwargs):
            raise ValidationError(
                "InstanceProfile, InstanceRole or AdditionalIamPolicies can not be configured together."
            )
    @validates_schema
    def no_coexist_s3_access(self, data, **kwargs):
        """Validate that s3_access does not co-exist with instance_role or instance_profile."""
        if self.fields_coexist(data, ["instance_role", "s3_access"], **kwargs):
            raise ValidationError("S3Access can not be configured when InstanceRole is set.")
        if self.fields_coexist(data, ["instance_profile", "s3_access"], **kwargs):
            raise ValidationError("S3Access can not be configured when InstanceProfile is set.")
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Iam(**data)


class HeadNodeIamSchema(IamSchema):
    """Represent the schema of IAM for HeadNode."""
    instance_profile = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp("^arn:.*:instance-profile/")
    )


class QueueIamSchema(IamSchema):
    """Represent the schema of IAM for Queue."""
    instance_profile = fields.Str(
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP},
        validate=validate.Regexp("^arn:.*:instance-profile/"),
    )


class ImdsSchema(BaseSchema):
    """Represent the schema of IMDS for HeadNode."""
    secured = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Imds(**data)


class IntelSoftwareSchema(BaseSchema):
    """Represent the schema of Intel software."""
    intel_hpc_platform = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return IntelSoftware(**data)


class AdditionalPackagesSchema(BaseSchema):
    """Represent the schema of additional packages."""
    intel_software = fields.Nested(IntelSoftwareSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return AdditionalPackages(**data)


class AmiSearchFiltersSchema(BaseSchema):
    """Represent the schema of the AmiSearchFilters section."""
    tags = fields.Nested(
        TagSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Key"}
    )
    owner = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load()
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return AmiSearchFilters(**data)


class TimeoutsSchema(BaseSchema):
    """Represent the schema of the Timeouts section."""
    head_node_bootstrap_timeout = fields.Int(
        validate=validate.Range(min=1), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    compute_node_bootstrap_timeout = fields.Int(
        validate=validate.Range(min=1), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    @post_load()
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Timeouts(**data)


class ClusterDevSettingsSchema(BaseDevSettingsSchema):
    """Represent the schema of Dev Setting."""
    cluster_template = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    ami_search_filters = fields.Nested(AmiSearchFiltersSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    instance_types_data = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})
    timeouts = fields.Nested(TimeoutsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return ClusterDevSettings(**data)


# ---------------------- Node and Cluster Schema ---------------------- #
class ImageSchema(BaseSchema):
    """Represent the schema of the Image."""
    os = fields.Str(
        required=True, validate=validate.OneOf(SUPPORTED_OSES), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    custom_ami = fields.Str(
        validate=validate.Regexp(r"^ami-[0-9a-z]{8}$|^ami-[0-9a-z]{17}$"),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Image(**data)


class HeadNodeImageSchema(BaseSchema):
    """Represent the schema of the HeadNode Image."""
    custom_ami = fields.Str(
        validate=validate.Regexp(r"^ami-[0-9a-z]{8}$|^ami-[0-9a-z]{17}$"),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return HeadNodeImage(**data)


class QueueImageSchema(BaseSchema):
    """Represent the schema of the Queue Image."""
    custom_ami = fields.Str(
        validate=validate.Regexp(r"^ami-[0-9a-z]{8}$|^ami-[0-9a-z]{17}$"),
        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return QueueImage(**data)


class HeadNodeCustomActionSchema(BaseSchema):
    """Represent the schema of the custom action."""
    script = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    args = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return CustomAction(**data)


class HeadNodeCustomActionsSchema(BaseSchema):
    """Represent the schema for all available custom actions."""
    on_node_start = fields.Nested(HeadNodeCustomActionSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    on_node_configured = fields.Nested(HeadNodeCustomActionSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return CustomActions(**data)


class QueueCustomActionSchema(BaseSchema):
    """Represent the schema of the custom action."""
    script = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    args = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return CustomAction(**data)


class QueueCustomActionsSchema(BaseSchema):
    """Represent the schema for all available custom actions."""
    on_node_start = fields.Nested(
        QueueCustomActionSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    on_node_configured = fields.Nested(
        QueueCustomActionSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return CustomActions(**data)


class InstanceTypeSchema(BaseSchema):
    """Schema of a compute resource that supports a pool of instance types."""
    instance_type = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return FlexibleInstanceType(**data)


class HeadNodeSchema(BaseSchema):
    """Represent the schema of the HeadNode."""
    instance_type = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    disable_simultaneous_multithreading = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    networking = fields.Nested(
        HeadNodeNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    ssh = fields.Nested(SshSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    local_storage = fields.Nested(HeadNodeStorageSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    dcv = fields.Nested(DcvSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    custom_actions = fields.Nested(HeadNodeCustomActionsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    iam = fields.Nested(HeadNodeIamSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    imds = fields.Nested(ImdsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    image = fields.Nested(HeadNodeImageSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load()
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return HeadNode(**data)


class _ComputeResourceSchema(BaseSchema):
    """Represent the schema of the ComputeResource."""
    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})


class SlurmComputeResourceSchema(_ComputeResourceSchema):
    """Represent the schema of the Slurm ComputeResource."""
    instance_type = fields.Str(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    instance_type_list = fields.Nested(
        InstanceTypeSchema,
        many=True,
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP_ON_REMOVE, "update_key": "InstanceType"},
    )
    max_count = fields.Int(validate=validate.Range(min=1), metadata={"update_policy": UpdatePolicy.MAX_COUNT})
    min_count = fields.Int(validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    spot_price = fields.Float(
        validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    efa = fields.Nested(EfaSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    disable_simultaneous_multithreading = fields.Bool(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    schedulable_memory = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @validates_schema
    def no_coexist_instance_type_flexibility(self, data, **kwargs):
        """Validate that 'instance_type' and 'instance_type_list' do not co-exist."""
        if self.fields_coexist(
            data,
            ["instance_type", "instance_type_list"],
            one_required=True,
            **kwargs,
        ):
            raise ValidationError("A Compute Resource needs to specify either InstanceType or InstanceTypeList.")
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        if data.get("instance_type_list"):
            return SlurmFlexibleComputeResource(**data)
        return SlurmComputeResource(**data)


class AwsBatchComputeResourceSchema(_ComputeResourceSchema):
    """Represent the schema of the Batch ComputeResource."""
    instance_types = fields.List(
        fields.Str(), required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}
    )
    max_vcpus = fields.Int(
        data_key="MaxvCpus",
        validate=validate.Range(min=1),
        metadata={"update_policy": UpdatePolicy.AWSBATCH_CE_MAX_RESIZE},
    )
    min_vcpus = fields.Int(
        data_key="MinvCpus", validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    desired_vcpus = fields.Int(
        data_key="DesiredvCpus", validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.IGNORED}
    )
    spot_bid_percentage = fields.Int(
        validate=validate.Range(min=0, max=100, min_inclusive=False), metadata={"update_policy": UpdatePolicy.SUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return AwsBatchComputeResource(**data)


class SchedulerPluginComputeResourceSchema(SlurmComputeResourceSchema):
    """Represent the schema of the Scheduler Plugin ComputeResource."""
    custom_settings = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginComputeResource(**data)


class ComputeSettingsSchema(BaseSchema):
    """Represent the schema of the ComputeSettings of the scheduler queues."""
    local_storage = fields.Nested(QueueStorageSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})
    @post_load()
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return ComputeSettings(**data)


class BaseQueueSchema(BaseSchema):
    """Represent the schema of the attributes in common between all the schedulers queues."""
    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    capacity_type = fields.Str(
        validate=validate.OneOf([event.value for event in CapacityType]),
        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},
    )


class _CommonQueueSchema(BaseQueueSchema):
    """Represent the schema of common part between Slurm and Scheduler Plugin Queue."""
    compute_settings = fields.Nested(
        ComputeSettingsSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    custom_actions = fields.Nested(
        QueueCustomActionsSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    iam = fields.Nested(QueueIamSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})
    image = fields.Nested(QueueImageSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})


class SlurmQueueSchema(_CommonQueueSchema):
    """Represent the schema of a Slurm Queue."""
    allocation_strategy = fields.Str(
        validate=validate.OneOf(["lowest-price", "capacity-optimized"]),
        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},
    )
    compute_resources = fields.Nested(
        SlurmComputeResourceSchema,
        many=True,
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP_ON_REMOVE, "update_key": "Name"},
    )
    networking = fields.Nested(
        SlurmQueueNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SlurmQueue(**data)


class AwsBatchQueueSchema(BaseQueueSchema):
    """Represent the schema of a Batch Queue."""
    compute_resources = fields.Nested(
        AwsBatchComputeResourceSchema,
        many=True,
        validate=validate.Length(equal=1),
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},
    )
    networking = fields.Nested(
        AwsBatchQueueNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return AwsBatchQueue(**data)


class SchedulerPluginQueueSchema(_CommonQueueSchema):
    """Represent the schema of a Scheduler Plugin Queue."""
    compute_resources = fields.Nested(
        SchedulerPluginComputeResourceSchema,
        many=True,
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},
    )
    networking = fields.Nested(
        SchedulerPluginQueueNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}
    )
    custom_settings = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginQueue(**data)


class DnsSchema(BaseSchema):
    """Represent the schema of Dns Settings."""
    disable_managed_dns = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    hosted_zone_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    use_ec2_hostnames = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return Dns(**data)


class SlurmSettingsSchema(BaseSchema):
    """Represent the schema of the Scheduling Settings."""
    scaledown_idletime = fields.Int(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    dns = fields.Nested(DnsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    queue_update_strategy = fields.Str(
        validate=validate.OneOf([strategy.value for strategy in QueueUpdateStrategy]),
        metadata={"update_policy": UpdatePolicy.IGNORED},
    )
    enable_memory_based_scheduling = fields.Bool(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SlurmSettings(**data)


class AwsBatchSettingsSchema(BaseSchema):
    """Represent the schema of the AwsBatch Scheduling Settings."""
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return AwsBatchSettings(**data)


class SchedulerPluginSupportedDistrosSchema(BaseSchema):
    """Represent the schema for SupportedDistros in a Scheduler Plugin."""
    x86 = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    arm64 = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginSupportedDistros(**data)


class SchedulerPluginQueueConstraintsSchema(BaseSchema):
    """Represent the schema for QueueConstraints in a Scheduler Plugin."""
    max_count = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginQueueConstraints(**data)


class SchedulerPluginComputeResourceConstraintsSchema(BaseSchema):
    """Represent the schema for ComputeResourceConstraints in a Scheduler Plugin."""
    max_count = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginComputeResourceConstraints(**data)


class SchedulerPluginRequirementsSchema(BaseSchema):
    """Represent the schema for Requirements in a Scheduler Plugin."""
    supported_distros = fields.Nested(
        SchedulerPluginSupportedDistrosSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    supported_regions = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    queue_constraints = fields.Nested(
        SchedulerPluginQueueConstraintsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    compute_resource_constraints = fields.Nested(
        SchedulerPluginComputeResourceConstraintsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    requires_sudo_privileges = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    supports_cluster_update = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    supported_parallel_cluster_versions = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
        validate=validate.Regexp(
            r"^((>|<|>=|<=)?[0-9]+\.[0-9]+\.[0-9]+([a-z][0-9]+)?,\s*)*(>|<|>=|<=)?[0-9]+\.[0-9]+\.[0-9]+([a-z][0-9]+)?$"
        ),
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginRequirements(**data)


class SchedulerPluginCloudFormationClusterInfrastructureSchema(BaseSchema):
    """Represent the CloudFormation section of the Scheduler Plugin ClusterInfrastructure schema."""
    template = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    s3_bucket_owner = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^\d{12}$")
    )
    checksum = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[A-Fa-f0-9]{64}$")
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginCloudFormationInfrastructure(**data)


class SchedulerPluginClusterInfrastructureSchema(BaseSchema):
    """Represent the schema for ClusterInfrastructure in a Scheduler Plugin."""
    cloud_formation = fields.Nested(
        SchedulerPluginCloudFormationClusterInfrastructureSchema,
        required=True,
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginClusterInfrastructure(**data)


class SchedulerPluginClusterSharedArtifactSchema(BaseSchema):
    """Represent the schema for Cluster Shared Artifact in a Scheduler Plugin."""
    source = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    s3_bucket_owner = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^\d{12}$")
    )
    checksum = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[A-Fa-f0-9]{64}$")
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginClusterSharedArtifact(**data)


class SchedulerPluginResourcesSchema(BaseSchema):
    """Represent the schema for Plugin Resources in a Scheduler Plugin."""
    cluster_shared_artifacts = fields.Nested(
        SchedulerPluginClusterSharedArtifactSchema,
        many=True,
        required=True,
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Source"},
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginPluginResources(**data)


class SchedulerPluginExecuteCommandSchema(BaseSchema):
    """Represent the schema for ExecuteCommand in a Scheduler Plugin."""
    command = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginExecuteCommand(**data)


class SchedulerPluginEventSchema(BaseSchema):
    """Represent the schema for Event in a Scheduler Plugin."""
    execute_command = fields.Nested(
        SchedulerPluginExecuteCommandSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginEvent(**data)


class SchedulerPluginEventsSchema(BaseSchema):
    """Represent the schema for Events in a Scheduler Plugin."""
    head_init = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    head_configure = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    head_finalize = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    compute_init = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    compute_configure = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    compute_finalize = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    head_cluster_update = fields.Nested(
        SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    head_compute_fleet_update = fields.Nested(
        SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginEvents(**data)


class SchedulerPluginFileSchema(BaseSchema):
    """Represent the schema of a Scheduler Plugin log file."""
    file_path = fields.Str(
        required=True, validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    timestamp_format = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    node_type = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.OneOf(["HEAD", "COMPUTE", "ALL"])
    )
    log_stream_name = fields.Str(
        required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[^:*]*$")
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginFile(**data)


class SchedulerPluginLogsSchema(BaseSchema):
    """Represent the schema of the Scheduler Plugin Logs."""
    files = fields.Nested(
        SchedulerPluginFileSchema,
        required=True,
        many=True,
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "FilePath"},
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginLogs(**data)


class SchedulerPluginMonitoringSchema(BaseSchema):
    """Represent the schema of the Scheduler plugin Monitoring."""
    logs = fields.Nested(SchedulerPluginLogsSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginMonitoring(**data)


class SudoerConfigurationSchema(BaseSchema):
    """Represent the SudoerConfiguration for scheduler plugin SystemUsers declared in the SchedulerDefinition."""
    commands = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    run_as = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SudoerConfiguration(**data)


class SchedulerPluginUserSchema(BaseSchema):
    """Represent the schema of a Scheduler Plugin system user."""
    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    enable_imds = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    sudoer_configuration = fields.Nested(
        SudoerConfigurationSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Name"}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginUser(**data)


class SchedulerPluginDefinitionSchema(BaseSchema):
    """Represent the schema of the Scheduler Plugin SchedulerDefinition."""
    plugin_interface_version = fields.Str(
        required=True,
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
        validate=validate.Regexp(r"^[0-9]+\.[0-9]+$"),
    )
    metadata = fields.Dict(metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, required=True)
    requirements = fields.Nested(
        SchedulerPluginRequirementsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    cluster_infrastructure = fields.Nested(
        SchedulerPluginClusterInfrastructureSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    plugin_resources = fields.Nested(
        SchedulerPluginResourcesSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    events = fields.Nested(
        SchedulerPluginEventsSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    monitoring = fields.Nested(SchedulerPluginMonitoringSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    system_users = fields.Nested(
        SchedulerPluginUserSchema,
        many=True,
        validate=validate.Length(max=SCHEDULER_PLUGIN_MAX_NUMBER_OF_USERS),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Name"},
    )
    tags = fields.Nested(
        TagSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Key"}
    )
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginDefinition(**data)
    @validates("metadata")
    def validate_metadata(self, value):
        """Validate that metadata contains the fields 'Name' and 'Version'."""
        for key in ["Name", "Version"]:
            if key not in value.keys():
                raise ValidationError(f"{key} is required for scheduler plugin Metadata.")


class SchedulerPluginSettingsSchema(BaseSchema):
    """Represent the schema of the Scheduling Settings."""
    scheduler_definition = fields.Nested(
        SchedulerPluginDefinitionSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}
    )
    grant_sudo_privileges = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})
    custom_settings = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})
    scheduler_definition_s3_bucket_owner = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^\d{12}$")
    )
    scheduler_definition_checksum = fields.Str(
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[A-Fa-f0-9]{64}$")
    )
    def _verify_checksum(self, file_content, original_definition, expected_checksum):
        if expected_checksum:
            actual_checksum = hashlib.sha256(file_content.encode()).hexdigest()
            if actual_checksum != expected_checksum:
                raise ValidationError(
                    f"Error when validating SchedulerDefinition '{original_definition}': "
                    f"checksum ({actual_checksum}) does not match expected one ({expected_checksum})"
                )
    def _fetch_scheduler_definition_from_s3(self, original_scheduler_definition, s3_bucket_owner):
        try:
            bucket_parsing_result = parse_bucket_url(original_scheduler_definition)
            result = AWSApi.instance().s3.get_object(
                bucket_name=bucket_parsing_result["bucket_name"],
                key=bucket_parsing_result["object_key"],
                expected_bucket_owner=s3_bucket_owner,
            )
            scheduler_definition = result["Body"].read().decode("utf-8")
            return scheduler_definition
        except AWSClientError as e:
            error_message = (
                f"Error while downloading scheduler definition from {original_scheduler_definition}: {str(e)}"
            )
            if s3_bucket_owner and e.error_code == "AccessDenied":
                error_message = (
                    f"{error_message}. This can be due to bucket owner not matching the expected "
                    f"one '{s3_bucket_owner}'"
                )
            raise ValidationError(error_message) from e
        except Exception as e:
            raise ValidationError(
                f"Error while downloading scheduler definition from {original_scheduler_definition}: {str(e)}"
            ) from e
    def _fetch_scheduler_definition_from_https(self, original_scheduler_definition):
        try:
            with urlopen(original_scheduler_definition) as f:  # nosec nosemgrep
                scheduler_definition = f.read().decode("utf-8")
            return scheduler_definition
        except Exception:
            error_message = (
                f"Error while downloading scheduler definition from {original_scheduler_definition}: "
                "The provided URL is invalid or unavailable."
            )
            raise ValidationError(error_message)
    def _validate_scheduler_definition_url(self, original_scheduler_definition, s3_bucket_owner):
        """Validate that the SchedulerDefinition URL is valid."""
        if not original_scheduler_definition.startswith("s3") and not original_scheduler_definition.startswith("https"):
            raise ValidationError(
                f"Error while downloading scheduler definition from {original_scheduler_definition}: The provided value"
                " for SchedulerDefinition is invalid. You can specify this as an S3 URL, HTTPS URL or as an inline "
                "YAML object."
            )
        if original_scheduler_definition.startswith("https") and s3_bucket_owner:
            raise ValidationError(
                f"Error while downloading scheduler definition from {original_scheduler_definition}: "
                "SchedulerDefinitionS3BucketOwner can only be specified when SchedulerDefinition is S3 URL."
            )
    def _fetch_scheduler_definition_from_url(
        self, original_scheduler_definition, s3_bucket_owner, scheduler_definition_checksum, data
    ):
        LOGGER.info("Downloading scheduler plugin definition from %s", original_scheduler_definition)
        if original_scheduler_definition.startswith("s3"):
            scheduler_definition = self._fetch_scheduler_definition_from_s3(
                original_scheduler_definition, s3_bucket_owner
            )
        elif original_scheduler_definition.startswith("https"):
            scheduler_definition = self._fetch_scheduler_definition_from_https(original_scheduler_definition)
        self._verify_checksum(scheduler_definition, original_scheduler_definition, scheduler_definition_checksum)
        LOGGER.info("Using the following scheduler plugin definition:\n%s", scheduler_definition)
        try:
            data["SchedulerDefinition"] = yaml.safe_load(scheduler_definition)
        except YAMLError as e:
            raise ValidationError(
                f"The retrieved SchedulerDefinition ({original_scheduler_definition}) is not a valid YAML."
            ) from e
    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return SchedulerPluginSettings(**data)
    @pre_load
    def fetch_scheduler_definition(self, data, **kwargs):
        """Fetch the scheduler definition if it is an S3 or HTTPS URL."""
        original_scheduler_definition = data["SchedulerDefinition"]
        s3_bucket_owner = data.get("SchedulerDefinitionS3BucketOwner", None)
        scheduler_definition_checksum = data.get("SchedulerDefinitionChecksum", None)
        if isinstance(original_scheduler_definition, str):
            self._validate_scheduler_definition_url(original_scheduler_definition, s3_bucket_owner)
            self._fetch_scheduler_definition_from_url(
                original_scheduler_definition, s3_bucket_owner, scheduler_definition_checksum, data
            )
        elif s3_bucket_owner or scheduler_definition_checksum:
            raise ValidationError(
                "SchedulerDefinitionS3BucketOwner or SchedulerDefinitionChecksum can only be specified when "
                "SchedulerDefinition is a URL."
            )
        return data


class SchedulingSchema(BaseSchema):
    """Represent the schema of the Scheduling."""
    scheduler = fields.Str(
        required=True,
        validate=validate.OneOf(["slurm", "awsbatch", "plugin"]),
        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},
    )
    # Slurm schema
    slurm_settings = fields.Nested(SlurmSettingsSchema, metadata={"update_policy": UpdatePolicy.IGNORED})
    slurm_queues = fields.Nested(
        SlurmQueueSchema,
        many=True,
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP_ON_REMOVE, "update_key": "Name"},
    )
    # Awsbatch schema:
    aws_batch_queues = fields.Nested(
        AwsBatchQueueSchema,
        many=True,
        validate=validate.Length(equal=1),
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},
    )
    aws_batch_settings = fields.Nested(
        AwsBatchSettingsSchema, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}
    )
    # Scheduler Plugin
    scheduler_settings = fields.Nested(
        SchedulerPluginSettingsSchema, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}
    )
    scheduler_queues = fields.Nested(
        SchedulerPluginQueueSchema,
        many=True,
        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},
    )
    @validates_schema
    def no_coexist_schedulers(self, data, **kwargs):
        """Validate that *_settings and *_queues for different schedulers do not co-exist."""
        scheduler = data.get("scheduler")
        if self.fields_coexist(data, ["aws_batch_settings", "slurm_settings", "scheduler_settings"], **kwargs):
            raise ValidationError("Multiple *Settings sections cannot be specified in the Scheduling section.")
        if self.fields_coexist(
            data, ["aws_batch_queues", "slurm_queues", "scheduler_queues"], one_required=True, **kwargs
        ):
            if scheduler == "awsbatch":
                scheduler_prefix = "AwsBatch"
            elif scheduler == "plugin":
                scheduler_prefix = "Scheduler"
            else:
                scheduler_prefix = scheduler.capitalize()
            raise ValidationError(f"{scheduler_prefix}Queues section must be specified in the Scheduling section.")
    @validates_schema
    def right_scheduler_schema(self, data, **kwargs):
        """Validate that the *_settings field is associated to the right scheduler."""
        for scheduler, settings, queues in [
            ("awsbatch", "aws_batch_settings", "aws_batch_queues"),
            ("slurm", "slurm_settings", "slurm_queues"),
            ("plugin", "scheduler_settings", "scheduler_queues"),
        ]:
            # Verify the settings section is associated to the right scheduler type
            configured_scheduler = data.get("scheduler")
            if settings in data and scheduler != configured_scheduler:
                raise ValidationError(
                    f"Scheduling > *Settings section is not appropriate to the Scheduler: {configured_scheduler}."
                )
            if queues in data and scheduler != configured_scheduler:
                raise ValidationError(
                    f"Scheduling > *Queues section is not appropriate to the Scheduler: {configured_scheduler}."
                )
    @validates_schema
    def same_subnet_in_different_queues(self, data, **kwargs):
        """Validate that subnet_ids configured in different queues are the same."""
        if "slurm_queues" in data or "scheduler_queues" in data:
            queues = "slurm_queues" if "slurm_queues" in data else "scheduler_queues"
            def _queue_has_subnet_ids(queue):
                return queue.networking and queue.networking.subnet_ids
            subnet_ids
= {tuple(set(q.networking.subnet_ids)) for q in data[queues] if _queue_has_subnet_ids(q)}1440 if len(subnet_ids) > 1:1441 raise ValidationError("The SubnetIds used for all of the queues should be the same.")1442 @post_load1443 def make_resource(self, data, **kwargs):1444 """Generate the right type of scheduling according to the child type (Slurm vs AwsBatch vs Custom)."""1445 scheduler = data.get("scheduler")1446 if scheduler == "slurm":1447 return SlurmScheduling(queues=data.get("slurm_queues"), settings=data.get("slurm_settings", None))1448 if scheduler == "plugin":1449 return SchedulerPluginScheduling(1450 queues=data.get("scheduler_queues"), settings=data.get("scheduler_settings", None)1451 )1452 if scheduler == "awsbatch":1453 return AwsBatchScheduling(1454 queues=data.get("aws_batch_queues"), settings=data.get("aws_batch_settings", None)1455 )1456 return None1457 @pre_dump1458 def restore_child(self, data, **kwargs):1459 """Restore back the child in the schema, see post_load action."""1460 adapted_data = copy.deepcopy(data)1461 if adapted_data.scheduler == "awsbatch":1462 scheduler_prefix = "aws_batch"1463 elif adapted_data.scheduler == "plugin":1464 scheduler_prefix = "scheduler"1465 else:1466 scheduler_prefix = adapted_data.scheduler1467 setattr(adapted_data, f"{scheduler_prefix}_queues", copy.copy(getattr(adapted_data, "queues", None)))1468 setattr(adapted_data, f"{scheduler_prefix}_settings", copy.copy(getattr(adapted_data, "settings", None)))1469 return adapted_data1470class DirectoryServiceSchema(BaseSchema):1471 """Represent the schema of the DirectoryService."""1472 domain_name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1473 domain_addr = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1474 password_secret_arn = fields.Str(1475 required=True,1476 validate=validate.Regexp(r"^arn:.*:secret"),1477 metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP},1478 )1479 domain_read_only_user = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1480 ldap_tls_ca_cert = fields.Str(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1481 ldap_tls_req_cert = fields.Str(1482 validate=validate.OneOf(["never", "allow", "try", "demand", "hard"]),1483 metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP},1484 )1485 ldap_access_filter = fields.Str(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1486 generate_ssh_keys_for_users = fields.Bool(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1487 additional_sssd_configs = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1488 @post_load1489 def make_resource(self, data, **kwargs):1490 """Generate resource."""1491 return DirectoryService(**data)1492class ClusterSchema(BaseSchema):1493 """Represent the schema of the Cluster."""1494 image = fields.Nested(ImageSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1495 head_node = fields.Nested(HeadNodeSchema, required=True, metadata={"update_policy": UpdatePolicy.SUPPORTED})1496 scheduling = fields.Nested(SchedulingSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1497 shared_storage = fields.Nested(1498 SharedStorageSchema,1499 many=True,1500 metadata={1501 "update_policy": UpdatePolicy(1502 UpdatePolicy.UNSUPPORTED, fail_reason=UpdatePolicy.FAIL_REASONS["shared_storage_change"]1503 ),1504 "update_key": "Name",1505 },1506 )1507 monitoring = fields.Nested(MonitoringSchema, 
metadata={"update_policy": UpdatePolicy.SUPPORTED})1508 additional_packages = fields.Nested(AdditionalPackagesSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1509 tags = fields.Nested(1510 TagSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Key"}1511 )1512 iam = fields.Nested(ClusterIamSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})1513 directory_service = fields.Nested(1514 DirectoryServiceSchema, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}1515 )1516 config_region = fields.Str(data_key="Region", metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1517 custom_s3_bucket = fields.Str(metadata={"update_policy": UpdatePolicy.READ_ONLY_RESOURCE_BUCKET})1518 additional_resources = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})1519 dev_settings = fields.Nested(ClusterDevSettingsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})1520 def __init__(self, cluster_name: str):1521 super().__init__()1522 self.cluster_name = cluster_name1523 @validates("tags")1524 def validate_tags(self, tags):1525 """Validate tags."""1526 validate_no_reserved_tag(tags)1527 @validates_schema1528 def no_settings_for_batch(self, data, **kwargs):1529 """Ensure IntelSoftware and DirectoryService section is not included when AWS Batch is the scheduler."""1530 scheduling = data.get("scheduling")1531 if scheduling and scheduling.scheduler == "awsbatch":1532 error_message = "The use of the {} configuration is not supported when using awsbatch as the scheduler."1533 additional_packages = data.get("additional_packages")1534 if additional_packages and additional_packages.intel_software.intel_hpc_platform:1535 raise ValidationError(error_message.format("IntelSoftware"))1536 if data.get("directory_service"):1537 raise ValidationError(error_message.format("DirectoryService"))1538 @post_load(pass_original=True)1539 def make_resource(self, data, original_data, **kwargs):1540 """Generate cluster according to the scheduler. Save original configuration."""1541 scheduler = data.get("scheduling").scheduler1542 if scheduler == "slurm":1543 cluster = SlurmClusterConfig(cluster_name=self.cluster_name, **data)1544 elif scheduler == "awsbatch":1545 cluster = AwsBatchClusterConfig(cluster_name=self.cluster_name, **data)1546 elif scheduler == "plugin":1547 cluster = SchedulerPluginClusterConfig(cluster_name=self.cluster_name, **data)1548 else:1549 raise ValidationError(f"Unsupported scheduler {scheduler}.")1550 cluster.source_config = original_data...


maddpg.py

Source:maddpg.py Github


import copy
import itertools

import numpy as np
import torch

from offpolicy.utils.util import huber_loss, mse_loss, to_torch
from offpolicy.utils.popart import PopArt
from offpolicy.algorithms.base.trainer import Trainer


class MADDPG(Trainer):
    def __init__(self, args, num_agents, policies, policy_mapping_fn, device=None, actor_update_interval=1):
        """
        Trainer class for MADDPG. See parent class for more information.
        :param actor_update_interval: (int) number of critic updates to perform between every update to the actor.
        """
        self.args = args
        self.use_popart = self.args.use_popart
        self.use_value_active_masks = self.args.use_value_active_masks
        self.use_per = self.args.use_per
        self.per_eps = self.args.per_eps
        self.use_huber_loss = self.args.use_huber_loss
        self.huber_delta = self.args.huber_delta
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.num_agents = num_agents
        self.policies = policies
        self.policy_mapping_fn = policy_mapping_fn
        self.policy_ids = sorted(list(self.policies.keys()))
        # map each policy id to the (sorted) agent ids it controls
        self.policy_agents = {
            policy_id: sorted(
                [agent_id for agent_id in range(self.num_agents) if self.policy_mapping_fn(agent_id) == policy_id]
            )
            for policy_id in self.policies.keys()
        }
        if self.use_popart:
            self.value_normalizer = {policy_id: PopArt(1) for policy_id in self.policies.keys()}
        self.num_updates = {p_id: 0 for p_id in self.policy_ids}
        self.use_same_share_obs = self.args.use_same_share_obs
        self.actor_update_interval = actor_update_interval

    # @profile
    def get_update_info(self, update_policy_id, obs_batch, act_batch, nobs_batch, navail_act_batch):
        """
        Form centralized observation and action info for current and next timestep.
        :param update_policy_id: (str) id of policy being updated.
        :param obs_batch: (np.ndarray) batch of observation sequences sampled from buffer.
        :param act_batch: (np.ndarray) batch of action sequences sampled from buffer.
        :param navail_act_batch: (np.ndarray) batch of available action sequences sampled from buffer. None if environment does not limit actions.
        :return cent_act: (list) list of action sequences corresponding to each agent.
        :return replace_ind_start: (int) index of act_sequences from which to replace actions for actor update.
        :return cent_nact: (np.ndarray) batch of centralized next step actions.
        """
        cent_act = []
        cent_nact = []
        replace_ind_start = None
        # iterate through policies to get the target acts and other centralized info
        ind = 0
        for p_id in self.policy_ids:
            batch_size = obs_batch[p_id].shape[1]
            policy = self.policies[p_id]
            if p_id == update_policy_id:
                replace_ind_start = ind
            num_pol_agents = len(self.policy_agents[p_id])
            cent_act.append(list(act_batch[p_id]))
            combined_nobs_batch = np.concatenate(nobs_batch[p_id], axis=0)
            if navail_act_batch[p_id] is not None:
                combined_navail_act_batch = np.concatenate(navail_act_batch[p_id], axis=0)
            else:
                combined_navail_act_batch = None
            # use target actor to get next step actions
            with torch.no_grad():
                pol_nact, _ = policy.get_actions(combined_nobs_batch, combined_navail_act_batch, use_target=True)
                ind_agent_nacts = pol_nact.cpu().split(split_size=batch_size, dim=0)
            # cat to form the centralized next step actions
            cent_nact.append(torch.cat(ind_agent_nacts, dim=-1))
            ind += num_pol_agents
        cent_act = list(itertools.chain.from_iterable(cent_act))
        cent_nact = np.concatenate(cent_nact, axis=-1)
        return cent_act, replace_ind_start, cent_nact

    def train_policy_on_batch(self, update_policy_id, batch):
        """See parent class."""
        if self.use_same_share_obs:
            return self.shared_train_policy_on_batch(update_policy_id, batch)
        else:
            return self.cent_train_policy_on_batch(update_policy_id, batch)

    def shared_train_policy_on_batch(self, update_policy_id, batch):
        """Training function when all agents share the same centralized observation. See train_policy_on_batch."""
        obs_batch, cent_obs_batch, \
            act_batch, rew_batch, \
            nobs_batch, cent_nobs_batch, \
            dones_batch, dones_env_batch, valid_transition_batch, \
            avail_act_batch, navail_act_batch, \
            importance_weights, idxes = batch
        train_info = {}
        update_actor = self.num_updates[update_policy_id] % self.actor_update_interval == 0
        cent_act, replace_ind_start, cent_nact = self.get_update_info(
            update_policy_id, obs_batch, act_batch, nobs_batch, navail_act_batch)
        cent_obs = cent_obs_batch[update_policy_id]
        cent_nobs = cent_nobs_batch[update_policy_id]
        rewards = rew_batch[update_policy_id][0]
        dones_env = dones_env_batch[update_policy_id]
        update_policy = self.policies[update_policy_id]
        batch_size = cent_obs.shape[0]
        # critic update
        with torch.no_grad():
            next_step_Qs = update_policy.target_critic(cent_nobs, cent_nact)
            next_step_Q = torch.cat(next_step_Qs, dim=-1)
            # take min to prevent overestimation bias
            next_step_Q, _ = torch.min(next_step_Q, dim=-1, keepdim=True)
            rewards = to_torch(rewards).to(**self.tpdv).view(-1, 1)
            dones_env = to_torch(dones_env).to(**self.tpdv).view(-1, 1)
            if self.use_popart:
                # index the PopArt normalizer by the policy being updated
                target_Qs = rewards + self.args.gamma * (1 - dones_env) * \
                    self.value_normalizer[update_policy_id].denormalize(next_step_Q)
                target_Qs = self.value_normalizer[update_policy_id](target_Qs)
            else:
                target_Qs = rewards + self.args.gamma * (1 - dones_env) * next_step_Q
        predicted_Qs = update_policy.critic(cent_obs, np.concatenate(cent_act, axis=-1))
        update_policy.critic_optimizer.zero_grad()
        # detach the targets
        errors = [target_Qs.detach() - predicted_Q for predicted_Q in predicted_Qs]
        if self.use_per:
            importance_weights = to_torch(importance_weights).to(**self.tpdv)
            if self.use_huber_loss:
                critic_loss = [huber_loss(error, self.huber_delta).flatten() for error in errors]
            else:
                critic_loss = [mse_loss(error).flatten() for error in errors]
            # weight each loss element by its importance sample weight
            critic_loss = [(loss * importance_weights).mean() for loss in critic_loss]
            critic_loss = torch.stack(critic_loss).sum(dim=0)
            # new priorities are the TD errors
            new_priorities = np.stack(
                [error.abs().cpu().detach().numpy().flatten() for error in errors]).mean(axis=0) + self.per_eps
        else:
            if self.use_huber_loss:
                critic_loss = [huber_loss(error, self.huber_delta).mean() for error in errors]
            else:
                critic_loss = [mse_loss(error).mean() for error in errors]
            critic_loss = torch.stack(critic_loss).sum(dim=0)
            new_priorities = None
        critic_loss.backward()
        critic_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.critic.parameters(),
                                                          self.args.max_grad_norm)
        update_policy.critic_optimizer.step()
        train_info['critic_loss'] = critic_loss
        train_info['critic_grad_norm'] = critic_grad_norm

        if update_actor:
            # actor update
            # need to zero the critic gradient and the actor gradient since the gradients first flow
            # through the critic before getting to the actor during backprop
            # freeze Q-networks
            for p in update_policy.critic.parameters():
                p.requires_grad = False
            num_update_agents = len(self.policy_agents[update_policy_id])
            mask_temp = []
            for p_id in self.policy_ids:
                if isinstance(self.policies[p_id].act_dim, np.ndarray):
                    # multidiscrete case
                    sum_act_dim = int(sum(self.policies[p_id].act_dim))
                else:
                    sum_act_dim = self.policies[p_id].act_dim
                for _ in self.policy_agents[p_id]:
                    mask_temp.append(np.zeros(sum_act_dim, dtype=np.float32))
            masks = []
            valid_trans_mask = []
            # need to iterate through agents, but only formulate masks at each step
            for i in range(num_update_agents):
                curr_mask_temp = copy.deepcopy(mask_temp)
                # set the mask to 1 at locations where the action should come from the actor output
                if isinstance(update_policy.act_dim, np.ndarray):
                    # multidiscrete case
                    sum_act_dim = int(sum(update_policy.act_dim))
                else:
                    sum_act_dim = update_policy.act_dim
                curr_mask_temp[replace_ind_start + i] = np.ones(sum_act_dim, dtype=np.float32)
                curr_mask_vec = np.concatenate(curr_mask_temp)
                # expand this mask into the proper size
                curr_mask = np.tile(curr_mask_vec, (batch_size, 1))
                masks.append(curr_mask)
                # agent valid transitions
                agent_valid_trans_batch = to_torch(valid_transition_batch[update_policy_id][i]).to(**self.tpdv)
                valid_trans_mask.append(agent_valid_trans_batch)
            # cat to form into tensors
            mask = to_torch(np.concatenate(masks)).to(**self.tpdv)
            valid_trans_mask = torch.cat(valid_trans_mask, dim=0)
            pol_agents_obs_batch = np.concatenate(obs_batch[update_policy_id], axis=0)
            if avail_act_batch[update_policy_id] is not None:
                pol_agents_avail_act_batch = np.concatenate(avail_act_batch[update_policy_id], axis=0)
            else:
                pol_agents_avail_act_batch = None
            # get all actions from actor
            pol_acts, _ = update_policy.get_actions(pol_agents_obs_batch, pol_agents_avail_act_batch, use_gumbel=True)
            # separate into individual agent batches
            agent_actor_batches = pol_acts.split(split_size=batch_size, dim=0)
            cent_act = list(map(lambda arr: to_torch(arr).to(**self.tpdv), cent_act))
            actor_cent_acts = copy.deepcopy(cent_act)
            for i in range(num_update_agents):
                actor_cent_acts[replace_ind_start + i] = agent_actor_batches[i]
            actor_cent_acts = torch.cat(actor_cent_acts, dim=-1).repeat((num_update_agents, 1))
            # convert buffer acts to torch, formulate centralized buffer action and repeat as done above
            buffer_cent_acts = torch.cat(cent_act, dim=-1).repeat(num_update_agents, 1)
            # also repeat cent obs
            stacked_cent_obs = np.tile(cent_obs, (num_update_agents, 1))
            # combine the buffer cent acts with actor cent acts and pass into the critic
            actor_update_cent_acts = mask * actor_cent_acts + (1 - mask) * buffer_cent_acts
            actor_Qs = update_policy.critic(stacked_cent_obs, actor_update_cent_acts)
            # use only the first Q output for actor loss
            actor_Qs = actor_Qs[0]
            actor_Qs = actor_Qs * valid_trans_mask
            actor_loss = -(actor_Qs).sum() / (valid_trans_mask).sum()
            update_policy.critic_optimizer.zero_grad()
            update_policy.actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.actor.parameters(),
                                                             self.args.max_grad_norm)
            update_policy.actor_optimizer.step()
            # unfreeze the Q-networks
            for p in update_policy.critic.parameters():
                p.requires_grad = True
            train_info['actor_loss'] = actor_loss
            train_info['actor_grad_norm'] = actor_grad_norm
        train_info['update_actor'] = update_actor
        return train_info, new_priorities, idxes

    def cent_train_policy_on_batch(self, update_policy_id, batch):
        """Training function when each agent has its own centralized observation. See train_policy_on_batch."""
        obs_batch, cent_obs_batch, \
            act_batch, rew_batch, \
            nobs_batch, cent_nobs_batch, \
            dones_batch, dones_env_batch, valid_transition_batch, \
            avail_act_batch, navail_act_batch, \
            importance_weights, idxes = batch
        train_info = {}
        update_actor = self.num_updates[update_policy_id] % self.actor_update_interval == 0
        cent_act, replace_ind_start, cent_nact = self.get_update_info(
            update_policy_id, obs_batch, act_batch, nobs_batch, navail_act_batch)
        cent_obs = cent_obs_batch[update_policy_id]
        cent_nobs = cent_nobs_batch[update_policy_id]
        rewards = rew_batch[update_policy_id][0]
        dones_env = dones_env_batch[update_policy_id]
        dones = dones_batch[update_policy_id]
        valid_trans = valid_transition_batch[update_policy_id]
        update_policy = self.policies[update_policy_id]
        batch_size = obs_batch[update_policy_id].shape[1]
        num_update_agents = len(self.policy_agents[update_policy_id])
        all_agent_cent_obs = np.concatenate(cent_obs, axis=0)
        all_agent_cent_nobs = np.concatenate(cent_nobs, axis=0)
        # since this is the same for each agent, just repeat when stacking
        cent_act_buffer = np.concatenate(cent_act, axis=-1)
        all_agent_cent_act_buffer = np.tile(cent_act_buffer, (num_update_agents, 1))
        all_agent_cent_nact = np.tile(cent_nact, (num_update_agents, 1))
        all_env_dones = np.tile(dones_env, (num_update_agents, 1))
        all_agent_rewards = np.tile(rewards, (num_update_agents, 1))
        # critic update
        update_policy.critic_optimizer.zero_grad()
        all_agent_rewards = to_torch(all_agent_rewards).to(**self.tpdv).reshape(-1, 1)
        all_env_dones = to_torch(all_env_dones).to(**self.tpdv).reshape(-1, 1)
        all_agent_valid_trans = to_torch(valid_trans).to(**self.tpdv).reshape(-1, 1)
        with torch.no_grad():
            next_step_Q = update_policy.target_critic(all_agent_cent_nobs, all_agent_cent_nact).reshape(-1, 1)
            if self.use_popart:
                # index the PopArt normalizer by the policy being updated
                target_Qs = all_agent_rewards + self.args.gamma * (1 - all_env_dones) * \
                    self.value_normalizer[update_policy_id].denormalize(next_step_Q)
                target_Qs = self.value_normalizer[update_policy_id](target_Qs)
            else:
                target_Qs = all_agent_rewards + self.args.gamma * (1 - all_env_dones) * next_step_Q
        predicted_Qs = update_policy.critic(all_agent_cent_obs, all_agent_cent_act_buffer).reshape(-1, 1)
        error = target_Qs.detach() - predicted_Qs
        if self.use_per:
            agent_importance_weights = np.tile(importance_weights, num_update_agents)
            agent_importance_weights = to_torch(agent_importance_weights).to(**self.tpdv)
            if self.use_huber_loss:
                critic_loss = huber_loss(error, self.huber_delta).flatten()
            else:
                critic_loss = mse_loss(error).flatten()
            # weight each loss element by its importance sample weight
            critic_loss = critic_loss * agent_importance_weights
            if self.use_value_active_masks:
                critic_loss = (critic_loss.view(-1, 1) * (all_agent_valid_trans)).sum() / (all_agent_valid_trans).sum()
            else:
                critic_loss = critic_loss.mean()
            # new priorities are the TD errors
            agent_new_priorities = error.abs().cpu().detach().numpy().flatten()
            new_priorities = np.mean(np.split(agent_new_priorities, num_update_agents), axis=0) + self.per_eps
        else:
            if self.use_huber_loss:
                critic_loss = huber_loss(error, self.huber_delta)
            else:
                critic_loss = mse_loss(error)
            if self.use_value_active_masks:
                critic_loss = (critic_loss * (all_agent_valid_trans)).sum() / (all_agent_valid_trans).sum()
            else:
                critic_loss = critic_loss.mean()
            new_priorities = None
        critic_loss.backward()
        critic_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.critic.parameters(),
                                                          self.args.max_grad_norm)
        update_policy.critic_optimizer.step()
        train_info['critic_loss'] = critic_loss
        train_info['critic_grad_norm'] = critic_grad_norm
        # actor update
        if update_actor:
            # freeze Q-networks
            for p in update_policy.critic.parameters():
                p.requires_grad = False
            num_update_agents = len(self.policy_agents[update_policy_id])
            mask_temp = []
            for p_id in self.policy_ids:
                if isinstance(self.policies[p_id].act_dim, np.ndarray):
                    # multidiscrete case
                    sum_act_dim = int(sum(self.policies[p_id].act_dim))
                else:
                    sum_act_dim = self.policies[p_id].act_dim
                for _ in self.policy_agents[p_id]:
                    mask_temp.append(np.zeros(sum_act_dim, dtype=np.float32))
            masks = []
            valid_trans_mask = []
            # need to iterate through agents, but only formulate masks at each step
            for i in range(num_update_agents):
                curr_mask_temp = copy.deepcopy(mask_temp)
                # set the mask to 1 at locations where the action should come from the actor output
                if isinstance(update_policy.act_dim, np.ndarray):
                    # multidiscrete case
                    sum_act_dim = int(sum(update_policy.act_dim))
                else:
                    sum_act_dim = update_policy.act_dim
                curr_mask_temp[replace_ind_start + i] = np.ones(sum_act_dim, dtype=np.float32)
                curr_mask_vec = np.concatenate(curr_mask_temp)
                # expand this mask into the proper size
                curr_mask = np.tile(curr_mask_vec, (batch_size, 1))
                masks.append(curr_mask)
                # agent valid transitions
                agent_valid_trans_batch = to_torch(valid_transition_batch[update_policy_id][i]).to(**self.tpdv)
                valid_trans_mask.append(agent_valid_trans_batch)
            # cat to form into tensors
            mask = to_torch(np.concatenate(masks)).to(**self.tpdv)
            valid_trans_mask = torch.cat(valid_trans_mask, dim=0)
            pol_agents_obs_batch = np.concatenate(obs_batch[update_policy_id], axis=0)
            if avail_act_batch[update_policy_id] is not None:
                pol_agents_avail_act_batch = np.concatenate(avail_act_batch[update_policy_id], axis=0)
            else:
                pol_agents_avail_act_batch = None
            # get all actions from actor
            pol_acts, _ = update_policy.get_actions(pol_agents_obs_batch, pol_agents_avail_act_batch, use_gumbel=True)
            # separate into individual agent batches
            agent_actor_batches = pol_acts.split(split_size=batch_size, dim=0)
            # cat along final dim to formulate centralized action and stack copies of the batch
            cent_act = list(map(lambda arr: to_torch(arr).to(**self.tpdv), cent_act))
            actor_cent_acts = copy.deepcopy(cent_act)
            for i in range(num_update_agents):
                actor_cent_acts[replace_ind_start + i] = agent_actor_batches[i]
            actor_cent_acts = torch.cat(actor_cent_acts, dim=-1).repeat((num_update_agents, 1))
            # combine the buffer cent acts with actor cent acts and pass into the critic
            actor_update_cent_acts = mask * actor_cent_acts + \
                (1 - mask) * to_torch(all_agent_cent_act_buffer).to(**self.tpdv)
            actor_Qs = update_policy.critic(all_agent_cent_obs, actor_update_cent_acts)
            # actor_loss = -actor_Qs.mean()
            actor_Qs = actor_Qs * valid_trans_mask
            actor_loss = -(actor_Qs).sum() / (valid_trans_mask).sum()
            update_policy.critic_optimizer.zero_grad()
            update_policy.actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.actor.parameters(),
                                                             self.args.max_grad_norm)
            update_policy.actor_optimizer.step()
            # unfreeze the Q-networks
            for p in update_policy.critic.parameters():
                p.requires_grad = True
            train_info['actor_loss'] = actor_loss
            train_info['actor_grad_norm'] = actor_grad_norm
        return train_info, new_priorities, idxes

    def prep_training(self):
        """See parent class."""
        for policy in self.policies.values():
            policy.actor.train()
            policy.critic.train()
            policy.target_actor.train()
            policy.target_critic.train()

    def prep_rollout(self):
        """See parent class."""
        for policy in self.policies.values():
            policy.actor.eval()
            policy.critic.eval()
            policy.target_actor.eval()
...
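
One detail worth isolating from the trainer above is the actor_update_interval gate: the actor is only updated on every N-th critic update, driven by the per-policy num_updates counter. The standalone sketch below reproduces just that gating logic with hypothetical values; it is not part of maddpg.py, and the full trainer increments the counter elsewhere.

# Hypothetical values; mirrors the `update_actor` check in the trainer above.
actor_update_interval = 2  # one actor update per two critic updates
num_updates = {"policy_0": 0}

for step in range(5):
    update_actor = num_updates["policy_0"] % actor_update_interval == 0
    print(f"step {step}: critic update" + (" + actor update" if update_actor else ""))
    num_updates["policy_0"] += 1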


fasttext.py

Source:fasttext.py Github


import sys

# Python 2/3 compatibility
if sys.version_info.major == 3:
    xrange = range

sys.path.insert(0, './util')
from py2py3 import *

import numpy as np
import tensorflow as tf

sys.path.insert(0, './model')
from __init__ import *


class fastText(base_net):

    def __init__(self, hyper_params):
        '''
        >>> Construct a fastText model
        >>> hyper_params: dict, a dictionary containing all hyper parameters
        >>> batch_size: int, batch size
        >>> sequence_length: int, maximum sentence length
        >>> class_num: int, number of categories
        >>> vocab_size: int, vocabulary size
        >>> embedding_dim: int, dimension of word embeddings
        >>> update_policy: dict, update policy
        >>> embedding_matrix: optional, numpy.array, initial embedding matrix
        >>> embedding_trainable: optional, bool, whether or not the embedding is trainable, default is True
        '''
        self.batch_size = hyper_params['batch_size']
        self.sequence_length = hyper_params['sequence_length']
        self.class_num = hyper_params['class_num']
        self.vocab_size = hyper_params['vocab_size']
        self.embedding_dim = hyper_params['embedding_dim']
        self.update_policy = hyper_params['update_policy']
        self.embedding_trainable = hyper_params['embedding_trainable'] if 'embedding_trainable' in hyper_params else True
        self.grad_clip_norm = hyper_params['grad_clip_norm'] if 'grad_clip_norm' in hyper_params else 1.0
        self.name = 'fast_text model' if not 'name' in hyper_params else hyper_params['name']
        self.sess = None
        with tf.variable_scope('fastText'):
            if not 'embedding_matrix' in hyper_params:
                print('Word embeddings are initialized from scratch')
                self.embedding_matrix = tf.get_variable('embedding_matrix', shape=[self.vocab_size, self.embedding_dim],
                                                        initializer=tf.random_uniform_initializer(-1.0, 1.0), dtype=tf.float32)
            else:
                print('Pre-trained word embeddings are imported')
                embedding_value = tf.constant(hyper_params['embedding_matrix'], dtype=tf.float32)
                self.embedding_matrix = tf.get_variable('embedding_matrix', initializer=embedding_value, dtype=tf.float32)
        self.inputs = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length])
        self.masks = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length])
        self.labels = tf.placeholder(tf.int32, shape=[self.batch_size,])
        self.embedding_output = tf.nn.embedding_lookup(self.embedding_matrix, self.inputs)  # of shape [self.batch_size, self.sequence_length, self.embedding_dim]
        embedding_sum = tf.reduce_sum(self.embedding_output, axis=1)  # of shape [self.batch_size, self.embedding_dim]
        mask_sum = tf.reduce_sum(self.masks, axis=1)  # of shape [self.batch_size,]
        mask_sum = tf.expand_dims(mask_sum, axis=-1)  # of shape [self.batch_size, 1]
        self.sentence_embedding = tf.div(embedding_sum, tf.cast(mask_sum, dtype=tf.float32))  # broadcasting mask_sum; the embedding of a padded token must be zero
        # Construct softmax classifier
        with tf.variable_scope('fastText'):
            W = tf.get_variable(name='w', shape=[self.embedding_dim, self.class_num],
                                initializer=tf.truncated_normal_initializer(stddev=0.5))
            b = tf.get_variable(name='b', shape=[self.class_num,],
                                initializer=tf.truncated_normal_initializer(stddev=0.05))
            output = tf.add(tf.matmul(self.sentence_embedding, W), b)  # of shape [self.batch_size, self.class_num], unnormalized output distribution
        # Outputs
        self.probability = tf.nn.softmax(output)  # of shape [self.batch_size, self.class_num], normalized output probability distribution
        self.prediction = tf.argmax(self.probability, axis=1)  # of shape [self.batch_size], discrete prediction
        self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, labels=self.labels))  # loss value
        # Construct Optimizer, selected by the update_policy dict
        if self.update_policy['name'].lower() in ['sgd', 'stochastic gradient descent']:
            learning_rate = self.update_policy['learning_rate']
            momentum = 0.0 if not 'momentum' in self.update_policy else self.update_policy['momentum']
            self.optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
        elif self.update_policy['name'].lower() in ['adagrad',]:
            learning_rate = self.update_policy['learning_rate']
            initial_accumulator_value = 0.1 if not 'initial_accumulator_value' in self.update_policy \
                else self.update_policy['initial_accumulator_value']
            self.optimizer = tf.train.AdagradOptimizer(learning_rate, initial_accumulator_value)
        elif self.update_policy['name'].lower() in ['adadelta']:
            learning_rate = self.update_policy['learning_rate']
            rho = 0.95 if not 'rho' in self.update_policy else self.update_policy['rho']
            epsilon = 1e-8 if not 'epsilon' in self.update_policy else self.update_policy['epsilon']
            self.optimizer = tf.train.AdadeltaOptimizer(learning_rate, rho, epsilon)
        elif self.update_policy['name'].lower() in ['rms', 'rmsprop']:
            learning_rate = self.update_policy['learning_rate']
            decay = 0.9 if not 'decay' in self.update_policy else self.update_policy['decay']
            momentum = 0.0 if not 'momentum' in self.update_policy else self.update_policy['momentum']
            epsilon = 1e-10 if not 'epsilon' in self.update_policy else self.update_policy['epsilon']
            self.optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)
        elif self.update_policy['name'].lower() in ['adam']:
            learning_rate = self.update_policy['learning_rate']
            beta1 = 0.9 if not 'beta1' in self.update_policy else self.update_policy['beta1']
            beta2 = 0.999 if not 'beta2' in self.update_policy else self.update_policy['beta2']
            epsilon = 1e-8 if not 'epsilon' in self.update_policy else self.update_policy['epsilon']
            self.optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2, epsilon)
        else:
            raise ValueError('Unrecognized Optimizer Category: %s' % self.update_policy['name'])
        print('gradient clip is applied, max = %.2f' % self.grad_clip_norm)
        gradients = self.optimizer.compute_gradients(self.loss)
        clipped_gradients = [(tf.clip_by_value(grad, -self.grad_clip_norm, self.grad_clip_norm), var) for grad, var in gradients]
        self.update = self.optimizer.apply_gradients(clipped_gradients)

    def train_validate_test_init(self):
        '''
        >>> Initialize the training, validation and test phase
        '''
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def train(self, inputs, masks, labels):
        '''
        >>> Training phase
        '''
        train_dict = {self.inputs: inputs, self.masks: masks, self.labels: labels}
        self.sess.run(self.update, feed_dict=train_dict)
        prediction_this_batch, loss_this_batch = self.sess.run([self.prediction, self.loss], feed_dict=train_dict)
        return prediction_this_batch, loss_this_batch

    def validate(self, inputs, masks, labels):
        '''
        >>> Validation phase
        '''
        validate_dict = {self.inputs: inputs, self.masks: masks, self.labels: labels}
        prediction_this_batch, loss_this_batch = self.sess.run([self.prediction, self.loss], feed_dict=validate_dict)
        return prediction_this_batch, loss_this_batch

    def test(self, inputs, masks, fine_tune=False):
        '''
        >>> Test phase
        '''
        test_dict = {self.inputs: inputs, self.masks: masks}
        if fine_tune == False:
            prediction_this_batch, = self.sess.run([self.prediction,], feed_dict=test_dict)
        else:
            prediction_this_batch, = self.sess.run([self.probability,], feed_dict=test_dict)
        return prediction_this_batch

    def do_summarization(self, file_list, folder2store, data_generator, n_top=5):
        '''
        >>> Not implemented
        '''
        raise NotImplementedError('Current model is fasttext, where "do_summarization" function is not implemented')

    def dump_params(self, file2dump):
        '''
        >>> Save the parameters
        >>> file2dump: str, file to store the parameters
        '''
        saver = tf.train.Saver()
        saved_path = saver.save(self.sess, file2dump)
        print('parameters are saved in file %s' % saved_path)

    def load_params(self, file2load, loaded_params=[]):
        '''
        >>> Load the parameters
        >>> file2load: str, file to load the parameters
        '''
        param2load = []
        for var in tf.global_variables():
            if not var in loaded_params:
                param2load.append(var)
        saver = tf.train.Saver(param2load)
        saver.restore(self.sess, file2load)
        print('parameters are imported from file %s' % file2load)

    def train_validate_test_end(self):
        '''
        >>> End current training, validation and test phase
        '''
        self.sess.close()
...
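
Since the optimizer branch above is selected entirely by the update_policy dict, a short usage sketch makes the contract concrete. This assumes the module above is importable in a TensorFlow 1.x environment; the hyper-parameter values below are placeholders chosen for illustration, not recommendations from the original source.

# Placeholder hyper-parameters; any optimizer name handled above would work.
hyper_params = {
    'batch_size': 32,
    'sequence_length': 100,
    'class_num': 5,
    'vocab_size': 20000,
    'embedding_dim': 128,
    'update_policy': {'name': 'adam', 'learning_rate': 1e-3},
}
model = fastText(hyper_params)    # builds the graph and the Adam optimizer
model.train_validate_test_init()  # opens the TF session

Unrecognized optimizer names raise a ValueError, and any omitted optimizer options (momentum, rho, epsilon, and so on) fall back to the defaults hard-coded in the branches above.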

