Best Python code snippet using tempest_python
cluster_schema.py
Source:cluster_schema.py  
1# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.2#3# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance4# with the License. A copy of the License is located at5#6# http://aws.amazon.com/apache2.0/7#8# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES9# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and10# limitations under the License.11#12# This module contains all the classes representing the Schema of the configuration file.13# These classes are created by following marshmallow syntax.14#15import copy16import hashlib17import logging18import re19from urllib.request import urlopen20import yaml21from marshmallow import ValidationError, fields, post_load, pre_dump, pre_load, validate, validates, validates_schema22from yaml import YAMLError23from pcluster.aws.aws_api import AWSApi24from pcluster.aws.common import AWSClientError25from pcluster.config.cluster_config import (26    AdditionalPackages,27    AmiSearchFilters,28    AwsBatchClusterConfig,29    AwsBatchComputeResource,30    AwsBatchQueue,31    AwsBatchQueueNetworking,32    AwsBatchScheduling,33    AwsBatchSettings,34    CapacityType,35    CloudWatchDashboards,36    CloudWatchLogs,37    ClusterDevSettings,38    ClusterIam,39    ComputeSettings,40    CustomAction,41    CustomActions,42    Dashboards,43    Dcv,44    DirectoryService,45    Dns,46    Efa,47    EphemeralVolume,48    ExistingFsxOntap,49    ExistingFsxOpenZfs,50    FlexibleInstanceType,51    HeadNode,52    HeadNodeImage,53    HeadNodeNetworking,54    Iam,55    Image,56    Imds,57    IntelSoftware,58    LocalStorage,59    Logs,60    Monitoring,61    PlacementGroup,62    Proxy,63    QueueImage,64    QueueUpdateStrategy,65    Raid,66    Roles,67    RootVolume,68    S3Access,69    SchedulerPluginCloudFormationInfrastructure,70    SchedulerPluginClusterConfig,71    SchedulerPluginClusterInfrastructure,72    SchedulerPluginClusterSharedArtifact,73    SchedulerPluginComputeResource,74    SchedulerPluginComputeResourceConstraints,75    SchedulerPluginDefinition,76    SchedulerPluginEvent,77    SchedulerPluginEvents,78    SchedulerPluginExecuteCommand,79    SchedulerPluginFile,80    SchedulerPluginLogs,81    SchedulerPluginMonitoring,82    SchedulerPluginPluginResources,83    SchedulerPluginQueue,84    SchedulerPluginQueueConstraints,85    SchedulerPluginQueueNetworking,86    SchedulerPluginRequirements,87    SchedulerPluginScheduling,88    SchedulerPluginSettings,89    SchedulerPluginSupportedDistros,90    SchedulerPluginUser,91    SharedEbs,92    SharedEfs,93    SharedFsxLustre,94    SlurmClusterConfig,95    SlurmComputeResource,96    SlurmFlexibleComputeResource,97    SlurmQueue,98    SlurmQueueNetworking,99    SlurmScheduling,100    SlurmSettings,101    Ssh,102    SudoerConfiguration,103    Timeouts,104)105from pcluster.config.update_policy import UpdatePolicy106from pcluster.constants import (107    DELETION_POLICIES,108    DELETION_POLICIES_WITH_SNAPSHOT,109    FSX_LUSTRE,110    FSX_ONTAP,111    FSX_OPENZFS,112    FSX_VOLUME_ID_REGEX,113    LUSTRE,114    ONTAP,115    OPENZFS,116    SCHEDULER_PLUGIN_MAX_NUMBER_OF_USERS,117    SUPPORTED_OSES,118)119from pcluster.models.s3_bucket import parse_bucket_url120from pcluster.schemas.common_schema import (121    AdditionalIamPolicySchema,122    BaseDevSettingsSchema,123    BaseSchema,124    TagSchema,125    get_field_validator,126    validate_no_reserved_tag,127)128from pcluster.validators.cluster_validators import EFS_MESSAGES, FSX_MESSAGES129# pylint: disable=C0302130LOGGER = logging.getLogger(__name__)131# ---------------------- Storage ---------------------- #132class HeadNodeRootVolumeSchema(BaseSchema):133    """Represent the RootVolume schema for the Head node."""134    volume_type = fields.Str(135        validate=get_field_validator("volume_type"),136        metadata={137            "update_policy": UpdatePolicy(138                UpdatePolicy.UNSUPPORTED, action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"]139            )140        },141    )142    iops = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})143    size = fields.Int(144        metadata={145            "update_policy": UpdatePolicy(146                UpdatePolicy.UNSUPPORTED,147                fail_reason=UpdatePolicy.FAIL_REASONS["ebs_volume_resize"],148                action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"],149            )150        }151    )152    throughput = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})153    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})154    delete_on_termination = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})155    @post_load156    def make_resource(self, data, **kwargs):157        """Generate resource."""158        return RootVolume(**data)159class QueueRootVolumeSchema(BaseSchema):160    """Represent the RootVolume schema for the queue."""161    size = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})162    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})163    volume_type = fields.Str(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})164    iops = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})165    throughput = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})166    @post_load167    def make_resource(self, data, **kwargs):168        """Generate resource."""169        return RootVolume(**data)170class RaidSchema(BaseSchema):171    """Represent the schema of the parameters specific to Raid. It is a child of EBS schema."""172    raid_type = fields.Int(173        required=True,174        data_key="Type",175        validate=validate.OneOf([0, 1]),176        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},177    )178    number_of_volumes = fields.Int(179        validate=validate.Range(min=2, max=5), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}180    )181    @post_load182    def make_resource(self, data, **kwargs):183        """Generate resource."""184        return Raid(**data)185class EbsSettingsSchema(BaseSchema):186    """Represent the schema of EBS."""187    volume_type = fields.Str(188        validate=get_field_validator("volume_type"),189        metadata={190            "update_policy": UpdatePolicy(191                UpdatePolicy.UNSUPPORTED, action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"]192            )193        },194    )195    iops = fields.Int(metadata={"update_policy": UpdatePolicy.SUPPORTED})196    size = fields.Int(197        metadata={198            "update_policy": UpdatePolicy(199                UpdatePolicy.UNSUPPORTED,200                fail_reason=UpdatePolicy.FAIL_REASONS["ebs_volume_resize"],201                action_needed=UpdatePolicy.ACTIONS_NEEDED["ebs_volume_update"],202            )203        }204    )205    kms_key_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})206    throughput = fields.Int(metadata={"update_policy": UpdatePolicy.SUPPORTED})207    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})208    snapshot_id = fields.Str(209        validate=validate.Regexp(r"^snap-[0-9a-z]{8}$|^snap-[0-9a-z]{17}$"),210        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},211    )212    volume_id = fields.Str(213        validate=validate.Regexp(r"^vol-[0-9a-z]{8}$|^vol-[0-9a-z]{17}$"),214        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},215    )216    raid = fields.Nested(RaidSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})217    deletion_policy = fields.Str(218        validate=validate.OneOf(DELETION_POLICIES_WITH_SNAPSHOT), metadata={"update_policy": UpdatePolicy.SUPPORTED}219    )220class HeadNodeEphemeralVolumeSchema(BaseSchema):221    """Represent the schema of ephemeral volume.It is a child of storage schema."""222    mount_dir = fields.Str(223        validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}224    )225    @post_load226    def make_resource(self, data, **kwargs):227        """Generate resource."""228        return EphemeralVolume(**data)229class QueueEphemeralVolumeSchema(BaseSchema):230    """Represent the schema of ephemeral volume.It is a child of storage schema."""231    mount_dir = fields.Str(232        validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}233    )234    @post_load235    def make_resource(self, data, **kwargs):236        """Generate resource."""237        return EphemeralVolume(**data)238class HeadNodeStorageSchema(BaseSchema):239    """Represent the schema of storage attached to a node."""240    root_volume = fields.Nested(HeadNodeRootVolumeSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})241    ephemeral_volume = fields.Nested(242        HeadNodeEphemeralVolumeSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}243    )244    @post_load245    def make_resource(self, data, **kwargs):246        """Generate resource."""247        return LocalStorage(**data)248class QueueStorageSchema(BaseSchema):249    """Represent the schema of storage attached to a node."""250    root_volume = fields.Nested(QueueRootVolumeSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})251    ephemeral_volume = fields.Nested(252        QueueEphemeralVolumeSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}253    )254    @post_load255    def make_resource(self, data, **kwargs):256        """Generate resource."""257        return LocalStorage(**data)258class EfsSettingsSchema(BaseSchema):259    """Represent the EFS schema."""260    encrypted = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})261    kms_key_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})262    performance_mode = fields.Str(263        validate=validate.OneOf(["generalPurpose", "maxIO"]), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}264    )265    throughput_mode = fields.Str(266        validate=validate.OneOf(["provisioned", "bursting"]), metadata={"update_policy": UpdatePolicy.SUPPORTED}267    )268    provisioned_throughput = fields.Int(269        validate=validate.Range(min=1, max=1024), metadata={"update_policy": UpdatePolicy.SUPPORTED}270    )271    file_system_id = fields.Str(272        validate=validate.Regexp(r"^fs-[0-9a-z]{8}$|^fs-[0-9a-z]{17}$"),273        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},274    )275    deletion_policy = fields.Str(276        validate=validate.OneOf(DELETION_POLICIES), metadata={"update_policy": UpdatePolicy.SUPPORTED}277    )278    @validates_schema279    def validate_file_system_id_ignored_parameters(self, data, **kwargs):280        """Return errors for parameters in the Efs config section that would be ignored."""281        # If file_system_id is specified, all parameters are ignored.282        messages = []283        if data.get("file_system_id") is not None:284            for key in data:285                if key is not None and key != "file_system_id":286                    messages.append(EFS_MESSAGES["errors"]["ignored_param_with_efs_fs_id"].format(efs_param=key))287            if messages:288                raise ValidationError(message=messages)289    @validates_schema290    def validate_existence_of_mode_throughput(self, data, **kwargs):291        """Validate the conditional existence requirement between throughput_mode and provisioned_throughput."""292        if kwargs.get("partial"):293            # If the schema is to be loaded partially, do not check existence constrain.294            return295        throughput_mode = data.get("throughput_mode")296        provisioned_throughput = data.get("provisioned_throughput")297        if throughput_mode != "provisioned" and provisioned_throughput:298            raise ValidationError(299                message="When specifying provisioned throughput, the throughput mode must be set to provisioned",300                field_name="ThroughputMode",301            )302        if throughput_mode == "provisioned" and not provisioned_throughput:303            raise ValidationError(304                message="When specifying throughput mode to provisioned,"305                " the provisioned throughput option must be specified",306                field_name="ProvisionedThroughput",307            )308class FsxLustreSettingsSchema(BaseSchema):309    """Represent the FSX schema."""310    storage_capacity = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})311    deployment_type = fields.Str(312        validate=validate.OneOf(["SCRATCH_1", "SCRATCH_2", "PERSISTENT_1", "PERSISTENT_2"]),313        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},314    )315    imported_file_chunk_size = fields.Int(316        validate=validate.Range(min=1, max=512000, error="has a minimum size of 1 MiB, and max size of 512,000 MiB"),317        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},318    )319    export_path = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})320    import_path = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})321    weekly_maintenance_start_time = fields.Str(322        validate=validate.Regexp(r"^[1-7]:([01]\d|2[0-3]):([0-5]\d)$"),323        metadata={"update_policy": UpdatePolicy.SUPPORTED},324    )325    automatic_backup_retention_days = fields.Int(326        validate=validate.Range(min=0, max=35), metadata={"update_policy": UpdatePolicy.SUPPORTED}327    )328    copy_tags_to_backups = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})329    daily_automatic_backup_start_time = fields.Str(330        validate=validate.Regexp(r"^([01]\d|2[0-3]):([0-5]\d)$"), metadata={"update_policy": UpdatePolicy.SUPPORTED}331    )332    per_unit_storage_throughput = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})333    backup_id = fields.Str(334        validate=validate.Regexp("^(backup-[0-9a-f]{8,})$"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}335    )336    kms_key_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})337    file_system_id = fields.Str(338        validate=validate.Regexp(r"^fs-[0-9a-z]{17}$"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}339    )340    auto_import_policy = fields.Str(341        validate=validate.OneOf(["NEW", "NEW_CHANGED", "NEW_CHANGED_DELETED"]),342        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},343    )344    drive_cache_type = fields.Str(345        validate=validate.OneOf(["READ"]), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}346    )347    data_compression_type = fields.Str(348        validate=validate.OneOf(["LZ4"]), metadata={"update_policy": UpdatePolicy.SUPPORTED}349    )350    fsx_storage_type = fields.Str(351        data_key="StorageType",352        validate=validate.OneOf(["HDD", "SSD"]),353        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},354    )355    deletion_policy = fields.Str(356        validate=validate.OneOf(DELETION_POLICIES), metadata={"update_policy": UpdatePolicy.SUPPORTED}357    )358    @validates_schema359    def validate_file_system_id_ignored_parameters(self, data, **kwargs):360        """Return errors for parameters in the FSx config section that would be ignored."""361        # If file_system_id is specified, all parameters are ignored.362        messages = []363        if data.get("file_system_id") is not None:364            for key in data:365                if key is not None and key != "file_system_id":366                    messages.append(FSX_MESSAGES["errors"]["ignored_param_with_fsx_fs_id"].format(fsx_param=key))367            if messages:368                raise ValidationError(message=messages)369    @validates_schema370    def validate_backup_id_unsupported_parameters(self, data, **kwargs):371        """Return errors for parameters in the FSx config section that would be ignored."""372        # If file_system_id is specified, all parameters are ignored.373        messages = []374        if data.get("backup_id") is not None:375            unsupported_config_param_names = [376                "deployment_type",377                "per_unit_storage_throughput",378                "storage_capacity",379                "import_path",380                "export_path",381                "imported_file_chunk_size",382                "kms_key_id",383            ]384            for key in data:385                if key in unsupported_config_param_names:386                    messages.append(FSX_MESSAGES["errors"]["unsupported_backup_param"].format(name=key))387            if messages:388                raise ValidationError(message=messages)389class FsxOpenZfsSettingsSchema(BaseSchema):390    """Represent the FSX OpenZFS schema."""391    volume_id = fields.Str(392        required=True,393        validate=validate.Regexp(FSX_VOLUME_ID_REGEX),394        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},395    )396class FsxOntapSettingsSchema(BaseSchema):397    """Represent the FSX Ontap schema."""398    volume_id = fields.Str(399        required=True,400        validate=validate.Regexp(FSX_VOLUME_ID_REGEX),401        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},402    )403class SharedStorageSchema(BaseSchema):404    """Represent the generic SharedStorage schema."""405    mount_dir = fields.Str(406        required=True, validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}407    )408    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})409    storage_type = fields.Str(410        required=True,411        validate=validate.OneOf(["Ebs", FSX_LUSTRE, FSX_OPENZFS, FSX_ONTAP, "Efs"]),412        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},413    )414    ebs_settings = fields.Nested(EbsSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})415    efs_settings = fields.Nested(EfsSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})416    fsx_lustre_settings = fields.Nested(FsxLustreSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})417    fsx_open_zfs_settings = fields.Nested(418        FsxOpenZfsSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}419    )420    fsx_ontap_settings = fields.Nested(FsxOntapSettingsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})421    @validates_schema422    def no_coexist_storage_settings(self, data, **kwargs):423        """Validate that *_settings for different storage types do not co-exist."""424        if self.fields_coexist(425            data,426            ["ebs_settings", "efs_settings", "fsx_lustre_settings", "fsx_open_zfs_settings", "fsx_ontap_settings"],427            **kwargs,428        ):429            raise ValidationError("Multiple *Settings sections cannot be specified in the SharedStorage items.")430    @validates_schema431    def right_storage_settings(self, data, **kwargs):432        """Validate that *_settings param is associated to the right storage type."""433        for storage_type, settings in [434            ("Ebs", "ebs_settings"),435            ("Efs", "efs_settings"),436            (FSX_LUSTRE, "fsx_lustre_settings"),437            (FSX_OPENZFS, "fsx_open_zfs_settings"),438            (FSX_ONTAP, "fsx_ontap_settings"),439        ]:440            # Verify the settings section is associated to the right storage type441            if data.get(settings, None) and storage_type != data.get("storage_type"):442                raise ValidationError(443                    "SharedStorage > *Settings section is not appropriate to the "444                    f"StorageType {data.get('storage_type')}."445                )446    @post_load447    def make_resource(self, data, **kwargs):448        """Generate the right type of shared storage according to the child type (EBS vs EFS vs FsxLustre)."""449        storage_type = data.get("storage_type")450        shared_volume_attributes = {"mount_dir": data.get("mount_dir"), "name": data.get("name")}451        settings = (452            data.get("efs_settings", None)453            or data.get("ebs_settings", None)454            or data.get("fsx_lustre_settings", None)455            or data.get("fsx_open_zfs_settings", None)456            or data.get("fsx_ontap_settings", None)457        )458        if settings:459            shared_volume_attributes.update(**settings)460        if storage_type == "Efs":461            return SharedEfs(**shared_volume_attributes)462        elif storage_type == "Ebs":463            return SharedEbs(**shared_volume_attributes)464        elif storage_type == FSX_LUSTRE:465            return SharedFsxLustre(**shared_volume_attributes)466        elif storage_type == FSX_OPENZFS:467            return ExistingFsxOpenZfs(**shared_volume_attributes)468        elif storage_type == FSX_ONTAP:469            return ExistingFsxOntap(**shared_volume_attributes)470        return None471    @pre_dump472    def restore_child(self, data, **kwargs):473        """Restore back the child in the schema."""474        adapted_data = copy.deepcopy(data)475        # Move SharedXxx as a child to be automatically managed by marshmallow, see post_load action476        if adapted_data.shared_storage_type == "efs":477            storage_type = "efs"478        elif adapted_data.shared_storage_type == "fsx":479            mapping = {LUSTRE: "fsx_lustre", OPENZFS: "fsx_open_zfs", ONTAP: "fsx_ontap"}480            storage_type = mapping.get(adapted_data.file_system_type)481        else:  # "raid", "ebs"482            storage_type = "ebs"483        setattr(adapted_data, f"{storage_type}_settings", copy.copy(adapted_data))484        # Restore storage type attribute485        if adapted_data.shared_storage_type == "fsx":486            mapping = {LUSTRE: FSX_LUSTRE, OPENZFS: FSX_OPENZFS, ONTAP: FSX_ONTAP}487            adapted_data.storage_type = mapping.get(adapted_data.file_system_type)488        else:489            adapted_data.storage_type = storage_type.capitalize()490        return adapted_data491    @validates("mount_dir")492    def shared_dir_validator(self, value):493        """Validate that user is not specifying /NONE or NONE as shared_dir for any filesystem."""494        # FIXME: pcluster2 doesn't allow "^/?NONE$" mount dir to avoid an ambiguity in cookbook.495        #  We should change cookbook to solve the ambiguity and allow "^/?NONE$" for mount dir496        #  Cookbook location to be modified:497        #  https://github.com/aws/aws-parallelcluster-cookbook/blob/develop/recipes/head_node_base_config.rb#L51498        if re.match("^/?NONE$", value):499            raise ValidationError(f"{value} cannot be used as a shared directory")500# ---------------------- Networking ---------------------- #501class HeadNodeProxySchema(BaseSchema):502    """Represent the schema of proxy for the Head node."""503    http_proxy_address = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})504    @post_load505    def make_resource(self, data, **kwargs):506        """Generate resource."""507        return Proxy(**data)508class QueueProxySchema(BaseSchema):509    """Represent the schema of proxy for a queue."""510    http_proxy_address = fields.Str(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})511    @post_load512    def make_resource(self, data, **kwargs):513        """Generate resource."""514        return Proxy(**data)515class BaseNetworkingSchema(BaseSchema):516    """Represent the schema of common networking parameters used by head and compute nodes."""517    additional_security_groups = fields.List(518        fields.Str(validate=get_field_validator("security_group_id")),519        metadata={"update_policy": UpdatePolicy.SUPPORTED},520    )521    security_groups = fields.List(522        fields.Str(validate=get_field_validator("security_group_id")),523        metadata={"update_policy": UpdatePolicy.SUPPORTED},524    )525    @validates_schema526    def no_coexist_security_groups(self, data, **kwargs):527        """Validate that security_groups and additional_security_groups do not co-exist."""528        if self.fields_coexist(data, ["security_groups", "additional_security_groups"], **kwargs):529            raise ValidationError("SecurityGroups and AdditionalSecurityGroups can not be configured together.")530class HeadNodeNetworkingSchema(BaseNetworkingSchema):531    """Represent the schema of the Networking, child of the HeadNode."""532    subnet_id = fields.Str(533        required=True, validate=get_field_validator("subnet_id"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}534    )535    elastic_ip = fields.Raw(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})536    proxy = fields.Nested(HeadNodeProxySchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})537    @post_load538    def make_resource(self, data, **kwargs):539        """Generate resource."""540        return HeadNodeNetworking(**data)541class PlacementGroupSchema(BaseSchema):542    """Represent the schema of placement group."""543    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})544    id = fields.Str(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})545    @post_load546    def make_resource(self, data, **kwargs):547        """Generate resource."""548        return PlacementGroup(**data)549class QueueNetworkingSchema(BaseNetworkingSchema):550    """Represent the schema of the Networking, child of Queue."""551    subnet_ids = fields.List(552        fields.Str(validate=get_field_validator("subnet_id")),553        required=True,554        validate=validate.Length(equal=1),555        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},556    )557    assign_public_ip = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})558class SlurmQueueNetworkingSchema(QueueNetworkingSchema):559    """Represent the schema of the Networking, child of slurm Queue."""560    placement_group = fields.Nested(561        PlacementGroupSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}562    )563    proxy = fields.Nested(QueueProxySchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})564    @post_load565    def make_resource(self, data, **kwargs):566        """Generate resource."""567        return SlurmQueueNetworking(**data)568class AwsBatchQueueNetworkingSchema(QueueNetworkingSchema):569    """Represent the schema of the Networking, child of aws batch Queue."""570    @post_load571    def make_resource(self, data, **kwargs):572        """Generate resource."""573        return AwsBatchQueueNetworking(**data)574class SchedulerPluginQueueNetworkingSchema(SlurmQueueNetworkingSchema):575    """Represent the schema of the Networking, child of Scheduler Plugin Queue."""576    @post_load577    def make_resource(self, data, **kwargs):578        """Generate resource."""579        return SchedulerPluginQueueNetworking(**data)580class SshSchema(BaseSchema):581    """Represent the schema of the SSH."""582    key_name = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})583    allowed_ips = fields.Str(validate=get_field_validator("cidr"), metadata={"update_policy": UpdatePolicy.SUPPORTED})584    @post_load585    def make_resource(self, data, **kwargs):586        """Generate resource."""587        return Ssh(**data)588class DcvSchema(BaseSchema):589    """Represent the schema of DCV."""590    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})591    port = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})592    allowed_ips = fields.Str(validate=get_field_validator("cidr"), metadata={"update_policy": UpdatePolicy.SUPPORTED})593    @post_load594    def make_resource(self, data, **kwargs):595        """Generate resource."""596        return Dcv(**data)597class EfaSchema(BaseSchema):598    """Represent the schema of EFA for a Compute Resource."""599    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})600    gdr_support = fields.Bool(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})601    @post_load602    def make_resource(self, data, **kwargs):603        """Generate resource."""604        return Efa(**data)605# ---------------------- Monitoring ---------------------- #606class CloudWatchLogsSchema(BaseSchema):607    """Represent the schema of the CloudWatchLogs section."""608    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})609    retention_in_days = fields.Int(610        validate=validate.OneOf([1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, 3653]),611        metadata={"update_policy": UpdatePolicy.SUPPORTED},612    )613    deletion_policy = fields.Str(614        validate=validate.OneOf(DELETION_POLICIES), metadata={"update_policy": UpdatePolicy.SUPPORTED}615    )616    @post_load617    def make_resource(self, data, **kwargs):618        """Generate resource."""619        return CloudWatchLogs(**data)620class CloudWatchDashboardsSchema(BaseSchema):621    """Represent the schema of the CloudWatchDashboards section."""622    enabled = fields.Bool(metadata={"update_policy": UpdatePolicy.SUPPORTED})623    @post_load624    def make_resource(self, data, **kwargs):625        """Generate resource."""626        return CloudWatchDashboards(**data)627class LogsSchema(BaseSchema):628    """Represent the schema of the Logs section."""629    cloud_watch = fields.Nested(CloudWatchLogsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})630    @post_load631    def make_resource(self, data, **kwargs):632        """Generate resource."""633        return Logs(**data)634class DashboardsSchema(BaseSchema):635    """Represent the schema of the Dashboards section."""636    cloud_watch = fields.Nested(CloudWatchDashboardsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})637    @post_load638    def make_resource(self, data, **kwargs):639        """Generate resource."""640        return Dashboards(**data)641class MonitoringSchema(BaseSchema):642    """Represent the schema of the Monitoring section."""643    detailed_monitoring = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})644    logs = fields.Nested(LogsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})645    dashboards = fields.Nested(DashboardsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})646    @post_load647    def make_resource(self, data, **kwargs):648        """Generate resource."""649        return Monitoring(**data)650# ---------------------- Others ---------------------- #651class RolesSchema(BaseSchema):652    """Represent the schema of roles."""653    lambda_functions_role = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})654    @post_load655    def make_resource(self, data, **kwargs):656        """Generate resource."""657        return Roles(**data)658class S3AccessSchema(BaseSchema):659    """Represent the schema of S3 access."""660    bucket_name = fields.Str(661        required=True,662        metadata={"update_policy": UpdatePolicy.SUPPORTED},663        validate=validate.Regexp(r"^[\*a-z0-9\-\.]+$"),664    )665    key_name = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})666    enable_write_access = fields.Bool(metadata={"update_policy": UpdatePolicy.SUPPORTED})667    @post_load668    def make_resource(self, data, **kwargs):669        """Generate resource."""670        return S3Access(**data)671class ClusterIamSchema(BaseSchema):672    """Represent the schema of IAM for Cluster."""673    roles = fields.Nested(RolesSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})674    permissions_boundary = fields.Str(675        metadata={"update_policy": UpdatePolicy.SUPPORTED}, validate=validate.Regexp("^arn:.*:policy/")676    )677    @post_load678    def make_resource(self, data, **kwargs):679        """Generate resource."""680        return ClusterIam(**data)681class IamSchema(BaseSchema):682    """Common schema of IAM for HeadNode and Queue."""683    instance_role = fields.Str(684        metadata={"update_policy": UpdatePolicy.SUPPORTED}, validate=validate.Regexp("^arn:.*:role/")685    )686    s3_access = fields.Nested(687        S3AccessSchema, many=True, metadata={"update_policy": UpdatePolicy.SUPPORTED, "update_key": "BucketName"}688    )689    additional_iam_policies = fields.Nested(690        AdditionalIamPolicySchema, many=True, metadata={"update_policy": UpdatePolicy.SUPPORTED, "update_key": "Policy"}691    )692    @validates_schema693    def no_coexist_role_policies(self, data, **kwargs):694        """Validate that instance_role, instance_profile or additional_iam_policies do not co-exist."""695        if self.fields_coexist(data, ["instance_role", "instance_profile", "additional_iam_policies"], **kwargs):696            raise ValidationError(697                "InstanceProfile, InstanceRole or AdditionalIamPolicies can not be configured together."698            )699    @validates_schema700    def no_coexist_s3_access(self, data, **kwargs):701        """Validate that instance_role, instance_profile or additional_iam_policies do not co-exist."""702        if self.fields_coexist(data, ["instance_role", "s3_access"], **kwargs):703            raise ValidationError("S3Access can not be configured when InstanceRole is set.")704        if self.fields_coexist(data, ["instance_profile", "s3_access"], **kwargs):705            raise ValidationError("S3Access can not be configured when InstanceProfile is set.")706    @post_load707    def make_resource(self, data, **kwargs):708        """Generate resource."""709        return Iam(**data)710class HeadNodeIamSchema(IamSchema):711    """Represent the schema of IAM for HeadNode."""712    instance_profile = fields.Str(713        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp("^arn:.*:instance-profile/")714    )715class QueueIamSchema(IamSchema):716    """Represent the schema of IAM for Queue."""717    instance_profile = fields.Str(718        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP},719        validate=validate.Regexp("^arn:.*:instance-profile/"),720    )721class ImdsSchema(BaseSchema):722    """Represent the schema of IMDS for HeadNode."""723    secured = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})724    @post_load725    def make_resource(self, data, **kwargs):726        """Generate resource."""727        return Imds(**data)728class IntelSoftwareSchema(BaseSchema):729    """Represent the schema of additional packages."""730    intel_hpc_platform = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})731    @post_load732    def make_resource(self, data, **kwargs):733        """Generate resource."""734        return IntelSoftware(**data)735class AdditionalPackagesSchema(BaseSchema):736    """Represent the schema of additional packages."""737    intel_software = fields.Nested(IntelSoftwareSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})738    @post_load739    def make_resource(self, data, **kwargs):740        """Generate resource."""741        return AdditionalPackages(**data)742class AmiSearchFiltersSchema(BaseSchema):743    """Represent the schema of the AmiSearchFilters section."""744    tags = fields.Nested(745        TagSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Key"}746    )747    owner = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})748    @post_load()749    def make_resource(self, data, **kwargs):750        """Generate resource."""751        return AmiSearchFilters(**data)752class TimeoutsSchema(BaseSchema):753    """Represent the schema of the Timeouts section."""754    head_node_bootstrap_timeout = fields.Int(755        validate=validate.Range(min=1), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}756    )757    compute_node_bootstrap_timeout = fields.Int(758        validate=validate.Range(min=1), metadata={"update_policy": UpdatePolicy.SUPPORTED}759    )760    @post_load()761    def make_resource(self, data, **kwargs):762        """Generate resource."""763        return Timeouts(**data)764class ClusterDevSettingsSchema(BaseDevSettingsSchema):765    """Represent the schema of Dev Setting."""766    cluster_template = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})767    ami_search_filters = fields.Nested(AmiSearchFiltersSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})768    instance_types_data = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})769    timeouts = fields.Nested(TimeoutsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})770    @post_load771    def make_resource(self, data, **kwargs):772        """Generate resource."""773        return ClusterDevSettings(**data)774# ---------------------- Node and Cluster Schema ---------------------- #775class ImageSchema(BaseSchema):776    """Represent the schema of the Image."""777    os = fields.Str(778        required=True, validate=validate.OneOf(SUPPORTED_OSES), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}779    )780    custom_ami = fields.Str(781        validate=validate.Regexp(r"^ami-[0-9a-z]{8}$|^ami-[0-9a-z]{17}$"),782        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},783    )784    @post_load785    def make_resource(self, data, **kwargs):786        """Generate resource."""787        return Image(**data)788class HeadNodeImageSchema(BaseSchema):789    """Represent the schema of the HeadNode Image."""790    custom_ami = fields.Str(791        validate=validate.Regexp(r"^ami-[0-9a-z]{8}$|^ami-[0-9a-z]{17}$"),792        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},793    )794    @post_load795    def make_resource(self, data, **kwargs):796        """Generate resource."""797        return HeadNodeImage(**data)798class QueueImageSchema(BaseSchema):799    """Represent the schema of the Queue Image."""800    custom_ami = fields.Str(801        validate=validate.Regexp(r"^ami-[0-9a-z]{8}$|^ami-[0-9a-z]{17}$"),802        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},803    )804    @post_load805    def make_resource(self, data, **kwargs):806        """Generate resource."""807        return QueueImage(**data)808class HeadNodeCustomActionSchema(BaseSchema):809    """Represent the schema of the custom action."""810    script = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})811    args = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})812    @post_load813    def make_resource(self, data, **kwargs):814        """Generate resource."""815        return CustomAction(**data)816class HeadNodeCustomActionsSchema(BaseSchema):817    """Represent the schema for all available custom actions."""818    on_node_start = fields.Nested(HeadNodeCustomActionSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})819    on_node_configured = fields.Nested(HeadNodeCustomActionSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})820    @post_load821    def make_resource(self, data, **kwargs):822        """Generate resource."""823        return CustomActions(**data)824class QueueCustomActionSchema(BaseSchema):825    """Represent the schema of the custom action."""826    script = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})827    args = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})828    @post_load829    def make_resource(self, data, **kwargs):830        """Generate resource."""831        return CustomAction(**data)832class QueueCustomActionsSchema(BaseSchema):833    """Represent the schema for all available custom actions."""834    on_node_start = fields.Nested(835        QueueCustomActionSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}836    )837    on_node_configured = fields.Nested(838        QueueCustomActionSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}839    )840    @post_load841    def make_resource(self, data, **kwargs):842        """Generate resource."""843        return CustomActions(**data)844class InstanceTypeSchema(BaseSchema):845    """Schema of a compute resource that supports a pool of instance types."""846    instance_type = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})847    @post_load848    def make_resource(self, data, **kwargs):849        """Generate resource."""850        return FlexibleInstanceType(**data)851class HeadNodeSchema(BaseSchema):852    """Represent the schema of the HeadNode."""853    instance_type = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})854    disable_simultaneous_multithreading = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})855    networking = fields.Nested(856        HeadNodeNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}857    )858    ssh = fields.Nested(SshSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})859    local_storage = fields.Nested(HeadNodeStorageSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})860    dcv = fields.Nested(DcvSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})861    custom_actions = fields.Nested(HeadNodeCustomActionsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})862    iam = fields.Nested(HeadNodeIamSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})863    imds = fields.Nested(ImdsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})864    image = fields.Nested(HeadNodeImageSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})865    @post_load()866    def make_resource(self, data, **kwargs):867        """Generate resource."""868        return HeadNode(**data)869class _ComputeResourceSchema(BaseSchema):870    """Represent the schema of the ComputeResource."""871    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})872class SlurmComputeResourceSchema(_ComputeResourceSchema):873    """Represent the schema of the Slurm ComputeResource."""874    instance_type = fields.Str(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})875    instance_type_list = fields.Nested(876        InstanceTypeSchema,877        many=True,878        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP_ON_REMOVE, "update_key": "InstanceType"},879    )880    max_count = fields.Int(validate=validate.Range(min=1), metadata={"update_policy": UpdatePolicy.MAX_COUNT})881    min_count = fields.Int(validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})882    spot_price = fields.Float(883        validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}884    )885    efa = fields.Nested(EfaSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})886    disable_simultaneous_multithreading = fields.Bool(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})887    schedulable_memory = fields.Int(metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})888    @validates_schema889    def no_coexist_instance_type_flexibility(self, data, **kwargs):890        """Validate that 'instance_type' and 'instance_type_list' do not co-exist."""891        if self.fields_coexist(892            data,893            ["instance_type", "instance_type_list"],894            one_required=True,895            **kwargs,896        ):897            raise ValidationError("A Compute Resource needs to specify either InstanceType or InstanceTypeList.")898    @post_load899    def make_resource(self, data, **kwargs):900        """Generate resource."""901        if data.get("instance_type_list"):902            return SlurmFlexibleComputeResource(**data)903        return SlurmComputeResource(**data)904class AwsBatchComputeResourceSchema(_ComputeResourceSchema):905    """Represent the schema of the Batch ComputeResource."""906    instance_types = fields.List(907        fields.Str(), required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}908    )909    max_vcpus = fields.Int(910        data_key="MaxvCpus",911        validate=validate.Range(min=1),912        metadata={"update_policy": UpdatePolicy.AWSBATCH_CE_MAX_RESIZE},913    )914    min_vcpus = fields.Int(915        data_key="MinvCpus", validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.SUPPORTED}916    )917    desired_vcpus = fields.Int(918        data_key="DesiredvCpus", validate=validate.Range(min=0), metadata={"update_policy": UpdatePolicy.IGNORED}919    )920    spot_bid_percentage = fields.Int(921        validate=validate.Range(min=0, max=100, min_inclusive=False), metadata={"update_policy": UpdatePolicy.SUPPORTED}922    )923    @post_load924    def make_resource(self, data, **kwargs):925        """Generate resource."""926        return AwsBatchComputeResource(**data)927class SchedulerPluginComputeResourceSchema(SlurmComputeResourceSchema):928    """Represent the schema of the Scheduler Plugin ComputeResource."""929    custom_settings = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})930    @post_load931    def make_resource(self, data, **kwargs):932        """Generate resource."""933        return SchedulerPluginComputeResource(**data)934class ComputeSettingsSchema(BaseSchema):935    """Represent the schema of the compute_settings schedulers queues."""936    local_storage = fields.Nested(QueueStorageSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})937    @post_load()938    def make_resource(self, data, **kwargs):939        """Generate resource."""940        return ComputeSettings(**data)941class BaseQueueSchema(BaseSchema):942    """Represent the schema of the attributes in common between all the schedulers queues."""943    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})944    capacity_type = fields.Str(945        validate=validate.OneOf([event.value for event in CapacityType]),946        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},947    )948class _CommonQueueSchema(BaseQueueSchema):949    """Represent the schema of common part between Slurm and Scheduler Plugin Queue."""950    compute_settings = fields.Nested(951        ComputeSettingsSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}952    )953    custom_actions = fields.Nested(954        QueueCustomActionsSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}955    )956    iam = fields.Nested(QueueIamSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})957    image = fields.Nested(QueueImageSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY})958class SlurmQueueSchema(_CommonQueueSchema):959    """Represent the schema of a Slurm Queue."""960    allocation_strategy = fields.Str(961        validate=validate.OneOf(["lowest-price", "capacity-optimized"]),962        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},963    )964    compute_resources = fields.Nested(965        SlurmComputeResourceSchema,966        many=True,967        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP_ON_REMOVE, "update_key": "Name"},968    )969    networking = fields.Nested(970        SlurmQueueNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}971    )972    @post_load973    def make_resource(self, data, **kwargs):974        """Generate resource."""975        return SlurmQueue(**data)976class AwsBatchQueueSchema(BaseQueueSchema):977    """Represent the schema of a Batch Queue."""978    compute_resources = fields.Nested(979        AwsBatchComputeResourceSchema,980        many=True,981        validate=validate.Length(equal=1),982        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},983    )984    networking = fields.Nested(985        AwsBatchQueueNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}986    )987    @post_load988    def make_resource(self, data, **kwargs):989        """Generate resource."""990        return AwsBatchQueue(**data)991class SchedulerPluginQueueSchema(_CommonQueueSchema):992    """Represent the schema of a Scheduler Plugin Queue."""993    compute_resources = fields.Nested(994        SchedulerPluginComputeResourceSchema,995        many=True,996        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},997    )998    networking = fields.Nested(999        SchedulerPluginQueueNetworkingSchema, required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}1000    )1001    custom_settings = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1002    @post_load1003    def make_resource(self, data, **kwargs):1004        """Generate resource."""1005        return SchedulerPluginQueue(**data)1006class DnsSchema(BaseSchema):1007    """Represent the schema of Dns Settings."""1008    disable_managed_dns = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1009    hosted_zone_id = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1010    use_ec2_hostnames = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1011    @post_load1012    def make_resource(self, data, **kwargs):1013        """Generate resource."""1014        return Dns(**data)1015class SlurmSettingsSchema(BaseSchema):1016    """Represent the schema of the Scheduling Settings."""1017    scaledown_idletime = fields.Int(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1018    dns = fields.Nested(DnsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1019    queue_update_strategy = fields.Str(1020        validate=validate.OneOf([strategy.value for strategy in QueueUpdateStrategy]),1021        metadata={"update_policy": UpdatePolicy.IGNORED},1022    )1023    enable_memory_based_scheduling = fields.Bool(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1024    @post_load1025    def make_resource(self, data, **kwargs):1026        """Generate resource."""1027        return SlurmSettings(**data)1028class AwsBatchSettingsSchema(BaseSchema):1029    """Represent the schema of the AwsBatch Scheduling Settings."""1030    @post_load1031    def make_resource(self, data, **kwargs):1032        """Generate resource."""1033        return AwsBatchSettings(**data)1034class SchedulerPluginSupportedDistrosSchema(BaseSchema):1035    """Represent the schema for SupportedDistros in a Scheduler Plugin."""1036    x86 = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1037    arm64 = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1038    @post_load1039    def make_resource(self, data, **kwargs):1040        """Generate resource."""1041        return SchedulerPluginSupportedDistros(**data)1042class SchedulerPluginQueueConstraintsSchema(BaseSchema):1043    """Represent the schema for QueueConstraints in a Scheduler Plugin."""1044    max_count = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1045    @post_load1046    def make_resource(self, data, **kwargs):1047        """Generate resource."""1048        return SchedulerPluginQueueConstraints(**data)1049class SchedulerPluginComputeResourceConstraintsSchema(BaseSchema):1050    """Represent the schema for ComputeResourceConstraints in a Scheduler Plugin."""1051    max_count = fields.Int(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1052    @post_load1053    def make_resource(self, data, **kwargs):1054        """Generate resource."""1055        return SchedulerPluginComputeResourceConstraints(**data)1056class SchedulerPluginRequirementsSchema(BaseSchema):1057    """Represent the schema for Requirements in a Scheduler Plugin."""1058    supported_distros = fields.Nested(1059        SchedulerPluginSupportedDistrosSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1060    )1061    supported_regions = fields.List(fields.Str(), metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1062    queue_constraints = fields.Nested(1063        SchedulerPluginQueueConstraintsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1064    )1065    compute_resource_constraints = fields.Nested(1066        SchedulerPluginComputeResourceConstraintsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1067    )1068    requires_sudo_privileges = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1069    supports_cluster_update = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1070    supported_parallel_cluster_versions = fields.Str(1071        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},1072        validate=validate.Regexp(1073            r"^((>|<|>=|<=)?[0-9]+\.[0-9]+\.[0-9]+([a-z][0-9]+)?,\s*)*(>|<|>=|<=)?[0-9]+\.[0-9]+\.[0-9]+([a-z][0-9]+)?$"1074        ),1075    )1076    @post_load1077    def make_resource(self, data, **kwargs):1078        """Generate resource."""1079        return SchedulerPluginRequirements(**data)1080class SchedulerPluginCloudFormationClusterInfrastructureSchema(BaseSchema):1081    """Represent the CloudFormation section of the Scheduler Plugin ClusterInfrastructure schema."""1082    template = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1083    s3_bucket_owner = fields.Str(1084        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^\d{12}$")1085    )1086    checksum = fields.Str(1087        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[A-Fa-f0-9]{64}$")1088    )1089    @post_load1090    def make_resource(self, data, **kwargs):1091        """Generate resource."""1092        return SchedulerPluginCloudFormationInfrastructure(**data)1093class SchedulerPluginClusterInfrastructureSchema(BaseSchema):1094    """Represent the schema for ClusterInfrastructure schema in a Scheduler Plugin."""1095    cloud_formation = fields.Nested(1096        SchedulerPluginCloudFormationClusterInfrastructureSchema,1097        required=True,1098        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},1099    )1100    @post_load1101    def make_resource(self, data, **kwargs):1102        """Generate resource."""1103        return SchedulerPluginClusterInfrastructure(**data)1104class SchedulerPluginClusterSharedArtifactSchema(BaseSchema):1105    """Represent the schema for Cluster Shared Artifact in a Scheduler Plugin."""1106    source = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1107    s3_bucket_owner = fields.Str(1108        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^\d{12}$")1109    )1110    checksum = fields.Str(1111        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[A-Fa-f0-9]{64}$")1112    )1113    @post_load1114    def make_resource(self, data, **kwargs):1115        """Generate resource."""1116        return SchedulerPluginClusterSharedArtifact(**data)1117class SchedulerPluginResourcesSchema(BaseSchema):1118    """Represent the schema for Plugin Resouces in a Scheduler Plugin."""1119    cluster_shared_artifacts = fields.Nested(1120        SchedulerPluginClusterSharedArtifactSchema,1121        many=True,1122        required=True,1123        metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Source"},1124    )1125    @post_load1126    def make_resource(self, data, **kwargs):1127        """Generate resource."""1128        return SchedulerPluginPluginResources(**data)1129class SchedulerPluginExecuteCommandSchema(BaseSchema):1130    """Represent the schema for ExecuteCommand in a Scheduler Plugin."""1131    command = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1132    @post_load1133    def make_resource(self, data, **kwargs):1134        """Generate resource."""1135        return SchedulerPluginExecuteCommand(**data)1136class SchedulerPluginEventSchema(BaseSchema):1137    """Represent the schema for Event in a Scheduler Plugin."""1138    execute_command = fields.Nested(1139        SchedulerPluginExecuteCommandSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1140    )1141    @post_load1142    def make_resource(self, data, **kwargs):1143        """Generate resource."""1144        return SchedulerPluginEvent(**data)1145class SchedulerPluginEventsSchema(BaseSchema):1146    """Represent the schema for Events in a Scheduler Plugin."""1147    head_init = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1148    head_configure = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1149    head_finalize = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1150    compute_init = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1151    compute_configure = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1152    compute_finalize = fields.Nested(SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1153    head_cluster_update = fields.Nested(1154        SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1155    )1156    head_compute_fleet_update = fields.Nested(1157        SchedulerPluginEventSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1158    )1159    @post_load1160    def make_resource(self, data, **kwargs):1161        """Generate resource."""1162        return SchedulerPluginEvents(**data)1163class SchedulerPluginFileSchema(BaseSchema):1164    """Represent the schema of the Scheduler Plugin."""1165    file_path = fields.Str(1166        required=True, validate=get_field_validator("file_path"), metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1167    )1168    timestamp_format = fields.Str(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1169    node_type = fields.Str(1170        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.OneOf(["HEAD", "COMPUTE", "ALL"])1171    )1172    log_stream_name = fields.Str(1173        required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[^:*]*$")1174    )1175    @post_load1176    def make_resource(self, data, **kwargs):1177        """Generate resource."""1178        return SchedulerPluginFile(**data)1179class SchedulerPluginLogsSchema(BaseSchema):1180    """Represent the schema of the Scheduler Plugin Logs."""1181    files = fields.Nested(1182        SchedulerPluginFileSchema,1183        required=True,1184        many=True,1185        metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "FilePath"},1186    )1187    @post_load1188    def make_resource(self, data, **kwargs):1189        """Generate resource."""1190        return SchedulerPluginLogs(**data)1191class SchedulerPluginMonitoringSchema(BaseSchema):1192    """Represent the schema of the Scheduler plugin Monitoring."""1193    logs = fields.Nested(SchedulerPluginLogsSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1194    @post_load1195    def make_resource(self, data, **kwargs):1196        """Generate resource."""1197        return SchedulerPluginMonitoring(**data)1198class SudoerConfigurationSchema(BaseSchema):1199    """Represent the SudoerConfiguration for scheduler plugin SystemUsers declared in the SchedulerDefinition."""1200    commands = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1201    run_as = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1202    @post_load1203    def make_resource(self, data, **kwargs):1204        """Generate resource."""1205        return SudoerConfiguration(**data)1206class SchedulerPluginUserSchema(BaseSchema):1207    """Represent the schema of the Scheduler Plugin."""1208    name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1209    enable_imds = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1210    sudoer_configuration = fields.Nested(1211        SudoerConfigurationSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Name"}1212    )1213    @post_load1214    def make_resource(self, data, **kwargs):1215        """Generate resource."""1216        return SchedulerPluginUser(**data)1217class SchedulerPluginDefinitionSchema(BaseSchema):1218    """Represent the schema of the Scheduler Plugin SchedulerDefinition."""1219    plugin_interface_version = fields.Str(1220        required=True,1221        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},1222        validate=validate.Regexp(r"^[0-9]+\.[0-9]+$"),1223    )1224    metadata = fields.Dict(metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, required=True)1225    requirements = fields.Nested(1226        SchedulerPluginRequirementsSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1227    )1228    cluster_infrastructure = fields.Nested(1229        SchedulerPluginClusterInfrastructureSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1230    )1231    plugin_resources = fields.Nested(1232        SchedulerPluginResourcesSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1233    )1234    events = fields.Nested(1235        SchedulerPluginEventsSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1236    )1237    monitoring = fields.Nested(SchedulerPluginMonitoringSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1238    system_users = fields.Nested(1239        SchedulerPluginUserSchema,1240        many=True,1241        validate=validate.Length(max=SCHEDULER_PLUGIN_MAX_NUMBER_OF_USERS),1242        metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Name"},1243    )1244    tags = fields.Nested(1245        TagSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Key"}1246    )1247    @post_load1248    def make_resource(self, data, **kwargs):1249        """Generate resource."""1250        return SchedulerPluginDefinition(**data)1251    @validates("metadata")1252    def validate_metadata(self, value):1253        """Validate metadata contains fieds 'name' and 'version'."""1254        for key in ["Name", "Version"]:1255            if key not in value.keys():1256                raise ValidationError(f"{key} is required for scheduler plugin Metadata.")1257class SchedulerPluginSettingsSchema(BaseSchema):1258    """Represent the schema of the Scheduling Settings."""1259    scheduler_definition = fields.Nested(1260        SchedulerPluginDefinitionSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED}1261    )1262    grant_sudo_privileges = fields.Bool(metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1263    custom_settings = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1264    scheduler_definition_s3_bucket_owner = fields.Str(1265        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^\d{12}$")1266    )1267    scheduler_definition_checksum = fields.Str(1268        metadata={"update_policy": UpdatePolicy.UNSUPPORTED}, validate=validate.Regexp(r"^[A-Fa-f0-9]{64}$")1269    )1270    def _verify_checksum(self, file_content, original_definition, expected_checksum):1271        if expected_checksum:1272            actual_checksum = hashlib.sha256(file_content.encode()).hexdigest()1273            if actual_checksum != expected_checksum:1274                raise ValidationError(1275                    f"Error when validating SchedulerDefinition '{original_definition}': "1276                    f"checksum ({actual_checksum}) does not match expected one ({expected_checksum})"1277                )1278    def _fetch_scheduler_definition_from_s3(self, original_scheduler_definition, s3_bucket_owner):1279        try:1280            bucket_parsing_result = parse_bucket_url(original_scheduler_definition)1281            result = AWSApi.instance().s3.get_object(1282                bucket_name=bucket_parsing_result["bucket_name"],1283                key=bucket_parsing_result["object_key"],1284                expected_bucket_owner=s3_bucket_owner,1285            )1286            scheduler_definition = result["Body"].read().decode("utf-8")1287            return scheduler_definition1288        except AWSClientError as e:1289            error_message = (1290                f"Error while downloading scheduler definition from {original_scheduler_definition}: {str(e)}"1291            )1292            if s3_bucket_owner and e.error_code == "AccessDenied":1293                error_message = (1294                    f"{error_message}. This can be due to bucket owner not matching the expected "1295                    f"one '{s3_bucket_owner}'"1296                )1297            raise ValidationError(error_message) from e1298        except Exception as e:1299            raise ValidationError(1300                f"Error while downloading scheduler definition from {original_scheduler_definition}: {str(e)}"1301            ) from e1302    def _fetch_scheduler_definition_from_https(self, original_scheduler_definition):1303        try:1304            with urlopen(original_scheduler_definition) as f:  # nosec nosemgrep1305                scheduler_definition = f.read().decode("utf-8")1306                return scheduler_definition1307        except Exception:1308            error_message = (1309                f"Error while downloading scheduler definition from {original_scheduler_definition}: "1310                "The provided URL is invalid or unavailable."1311            )1312            raise ValidationError(error_message)1313    def _validate_scheduler_definition_url(self, original_scheduler_definition, s3_bucket_owner):1314        """Validate SchedulerDefinition url is valid."""1315        if not original_scheduler_definition.startswith("s3") and not original_scheduler_definition.startswith("https"):1316            raise ValidationError(1317                f"Error while downloading scheduler definition from {original_scheduler_definition}: The provided value"1318                " for SchedulerDefinition is invalid. You can specify this as an S3 URL, HTTPS URL or as an inline "1319                "YAML object."1320            )1321        if original_scheduler_definition.startswith("https") and s3_bucket_owner:1322            raise ValidationError(1323                f"Error while downloading scheduler definition from {original_scheduler_definition}: "1324                "SchedulerDefinitionS3BucketOwner can only be specified when SchedulerDefinition is S3 URL."1325            )1326    def _fetch_scheduler_definition_from_url(1327        self, original_scheduler_definition, s3_bucket_owner, scheduler_definition_checksum, data1328    ):1329        LOGGER.info("Downloading scheduler plugin definition from %s", original_scheduler_definition)1330        if original_scheduler_definition.startswith("s3"):1331            scheduler_definition = self._fetch_scheduler_definition_from_s3(1332                original_scheduler_definition, s3_bucket_owner1333            )1334        elif original_scheduler_definition.startswith("https"):1335            scheduler_definition = self._fetch_scheduler_definition_from_https(original_scheduler_definition)1336        self._verify_checksum(scheduler_definition, original_scheduler_definition, scheduler_definition_checksum)1337        LOGGER.info("Using the following scheduler plugin definition:\n%s", scheduler_definition)1338        try:1339            data["SchedulerDefinition"] = yaml.safe_load(scheduler_definition)1340        except YAMLError as e:1341            raise ValidationError(1342                f"The retrieved SchedulerDefinition ({original_scheduler_definition}) is not a valid YAML."1343            ) from e1344    @post_load1345    def make_resource(self, data, **kwargs):1346        """Generate resource."""1347        return SchedulerPluginSettings(**data)1348    @pre_load1349    def fetch_scheduler_definition(self, data, **kwargs):1350        """Fetch scheduler definition if it is s3 or https url."""1351        original_scheduler_definition = data["SchedulerDefinition"]1352        s3_bucket_owner = data.get("SchedulerDefinitionS3BucketOwner", None)1353        scheduler_definition_checksum = data.get("SchedulerDefinitionChecksum", None)1354        if isinstance(original_scheduler_definition, str):1355            self._validate_scheduler_definition_url(original_scheduler_definition, s3_bucket_owner)1356            self._fetch_scheduler_definition_from_url(1357                original_scheduler_definition, s3_bucket_owner, scheduler_definition_checksum, data1358            )1359        elif s3_bucket_owner or scheduler_definition_checksum:1360            raise ValidationError(1361                "SchedulerDefinitionS3BucketOwner or SchedulerDefinitionChecksum can only specified when "1362                "SchedulerDefinition is a URL."1363            )1364        return data1365class SchedulingSchema(BaseSchema):1366    """Represent the schema of the Scheduling."""1367    scheduler = fields.Str(1368        required=True,1369        validate=validate.OneOf(["slurm", "awsbatch", "plugin"]),1370        metadata={"update_policy": UpdatePolicy.UNSUPPORTED},1371    )1372    # Slurm schema1373    slurm_settings = fields.Nested(SlurmSettingsSchema, metadata={"update_policy": UpdatePolicy.IGNORED})1374    slurm_queues = fields.Nested(1375        SlurmQueueSchema,1376        many=True,1377        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP_ON_REMOVE, "update_key": "Name"},1378    )1379    # Awsbatch schema:1380    aws_batch_queues = fields.Nested(1381        AwsBatchQueueSchema,1382        many=True,1383        validate=validate.Length(equal=1),1384        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},1385    )1386    aws_batch_settings = fields.Nested(1387        AwsBatchSettingsSchema, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}1388    )1389    # Scheduler Plugin1390    scheduler_settings = fields.Nested(1391        SchedulerPluginSettingsSchema, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}1392    )1393    scheduler_queues = fields.Nested(1394        SchedulerPluginQueueSchema,1395        many=True,1396        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP, "update_key": "Name"},1397    )1398    @validates_schema1399    def no_coexist_schedulers(self, data, **kwargs):1400        """Validate that *_settings and *_queues for different schedulers do not co-exist."""1401        scheduler = data.get("scheduler")1402        if self.fields_coexist(data, ["aws_batch_settings", "slurm_settings", "scheduler_settings"], **kwargs):1403            raise ValidationError("Multiple *Settings sections cannot be specified in the Scheduling section.")1404        if self.fields_coexist(1405            data, ["aws_batch_queues", "slurm_queues", "scheduler_queues"], one_required=True, **kwargs1406        ):1407            if scheduler == "awsbatch":1408                scheduler_prefix = "AwsBatch"1409            elif scheduler == "plugin":1410                scheduler_prefix = "Scheduler"1411            else:1412                scheduler_prefix = scheduler.capitalize()1413            raise ValidationError(f"{scheduler_prefix}Queues section must be specified in the Scheduling section.")1414    @validates_schema1415    def right_scheduler_schema(self, data, **kwargs):1416        """Validate that *_settings field is associated to the right scheduler."""1417        for scheduler, settings, queues in [1418            ("awsbatch", "aws_batch_settings", "aws_batch_queues"),1419            ("slurm", "slurm_settings", "slurm_queues"),1420            ("plugin", "scheduler_settings", "scheduler_queues"),1421        ]:1422            # Verify the settings section is associated to the right scheduler type1423            configured_scheduler = data.get("scheduler")1424            if settings in data and scheduler != configured_scheduler:1425                raise ValidationError(1426                    f"Scheduling > *Settings section is not appropriate to the Scheduler: {configured_scheduler}."1427                )1428            if queues in data and scheduler != configured_scheduler:1429                raise ValidationError(1430                    f"Scheduling > *Queues section is not appropriate to the Scheduler: {configured_scheduler}."1431                )1432    @validates_schema1433    def same_subnet_in_different_queues(self, data, **kwargs):1434        """Validate subnet_ids configured in different queues are the same."""1435        if "slurm_queues" in data or "scheduler_queues" in data:1436            queues = "slurm_queues" if "slurm_queues" in data else "scheduler_queues"1437            def _queue_has_subnet_ids(queue):1438                return queue.networking and queue.networking.subnet_ids1439            subnet_ids = {tuple(set(q.networking.subnet_ids)) for q in data[queues] if _queue_has_subnet_ids(q)}1440            if len(subnet_ids) > 1:1441                raise ValidationError("The SubnetIds used for all of the queues should be the same.")1442    @post_load1443    def make_resource(self, data, **kwargs):1444        """Generate the right type of scheduling according to the child type (Slurm vs AwsBatch vs Custom)."""1445        scheduler = data.get("scheduler")1446        if scheduler == "slurm":1447            return SlurmScheduling(queues=data.get("slurm_queues"), settings=data.get("slurm_settings", None))1448        if scheduler == "plugin":1449            return SchedulerPluginScheduling(1450                queues=data.get("scheduler_queues"), settings=data.get("scheduler_settings", None)1451            )1452        if scheduler == "awsbatch":1453            return AwsBatchScheduling(1454                queues=data.get("aws_batch_queues"), settings=data.get("aws_batch_settings", None)1455            )1456        return None1457    @pre_dump1458    def restore_child(self, data, **kwargs):1459        """Restore back the child in the schema, see post_load action."""1460        adapted_data = copy.deepcopy(data)1461        if adapted_data.scheduler == "awsbatch":1462            scheduler_prefix = "aws_batch"1463        elif adapted_data.scheduler == "plugin":1464            scheduler_prefix = "scheduler"1465        else:1466            scheduler_prefix = adapted_data.scheduler1467        setattr(adapted_data, f"{scheduler_prefix}_queues", copy.copy(getattr(adapted_data, "queues", None)))1468        setattr(adapted_data, f"{scheduler_prefix}_settings", copy.copy(getattr(adapted_data, "settings", None)))1469        return adapted_data1470class DirectoryServiceSchema(BaseSchema):1471    """Represent the schema of the DirectoryService."""1472    domain_name = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1473    domain_addr = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1474    password_secret_arn = fields.Str(1475        required=True,1476        validate=validate.Regexp(r"^arn:.*:secret"),1477        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP},1478    )1479    domain_read_only_user = fields.Str(required=True, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1480    ldap_tls_ca_cert = fields.Str(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1481    ldap_tls_req_cert = fields.Str(1482        validate=validate.OneOf(["never", "allow", "try", "demand", "hard"]),1483        metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP},1484    )1485    ldap_access_filter = fields.Str(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1486    generate_ssh_keys_for_users = fields.Bool(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1487    additional_sssd_configs = fields.Dict(metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP})1488    @post_load1489    def make_resource(self, data, **kwargs):1490        """Generate resource."""1491        return DirectoryService(**data)1492class ClusterSchema(BaseSchema):1493    """Represent the schema of the Cluster."""1494    image = fields.Nested(ImageSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1495    head_node = fields.Nested(HeadNodeSchema, required=True, metadata={"update_policy": UpdatePolicy.SUPPORTED})1496    scheduling = fields.Nested(SchedulingSchema, required=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1497    shared_storage = fields.Nested(1498        SharedStorageSchema,1499        many=True,1500        metadata={1501            "update_policy": UpdatePolicy(1502                UpdatePolicy.UNSUPPORTED, fail_reason=UpdatePolicy.FAIL_REASONS["shared_storage_change"]1503            ),1504            "update_key": "Name",1505        },1506    )1507    monitoring = fields.Nested(MonitoringSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})1508    additional_packages = fields.Nested(AdditionalPackagesSchema, metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1509    tags = fields.Nested(1510        TagSchema, many=True, metadata={"update_policy": UpdatePolicy.UNSUPPORTED, "update_key": "Key"}1511    )1512    iam = fields.Nested(ClusterIamSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})1513    directory_service = fields.Nested(1514        DirectoryServiceSchema, metadata={"update_policy": UpdatePolicy.COMPUTE_FLEET_STOP}1515    )1516    config_region = fields.Str(data_key="Region", metadata={"update_policy": UpdatePolicy.UNSUPPORTED})1517    custom_s3_bucket = fields.Str(metadata={"update_policy": UpdatePolicy.READ_ONLY_RESOURCE_BUCKET})1518    additional_resources = fields.Str(metadata={"update_policy": UpdatePolicy.SUPPORTED})1519    dev_settings = fields.Nested(ClusterDevSettingsSchema, metadata={"update_policy": UpdatePolicy.SUPPORTED})1520    def __init__(self, cluster_name: str):1521        super().__init__()1522        self.cluster_name = cluster_name1523    @validates("tags")1524    def validate_tags(self, tags):1525        """Validate tags."""1526        validate_no_reserved_tag(tags)1527    @validates_schema1528    def no_settings_for_batch(self, data, **kwargs):1529        """Ensure IntelSoftware and DirectoryService section is not included when AWS Batch is the scheduler."""1530        scheduling = data.get("scheduling")1531        if scheduling and scheduling.scheduler == "awsbatch":1532            error_message = "The use of the {} configuration is not supported when using awsbatch as the scheduler."1533            additional_packages = data.get("additional_packages")1534            if additional_packages and additional_packages.intel_software.intel_hpc_platform:1535                raise ValidationError(error_message.format("IntelSoftware"))1536            if data.get("directory_service"):1537                raise ValidationError(error_message.format("DirectoryService"))1538    @post_load(pass_original=True)1539    def make_resource(self, data, original_data, **kwargs):1540        """Generate cluster according to the scheduler. Save original configuration."""1541        scheduler = data.get("scheduling").scheduler1542        if scheduler == "slurm":1543            cluster = SlurmClusterConfig(cluster_name=self.cluster_name, **data)1544        elif scheduler == "awsbatch":1545            cluster = AwsBatchClusterConfig(cluster_name=self.cluster_name, **data)1546        elif scheduler == "plugin":1547            cluster = SchedulerPluginClusterConfig(cluster_name=self.cluster_name, **data)1548        else:1549            raise ValidationError(f"Unsupported scheduler {scheduler}.")1550        cluster.source_config = original_data...maddpg.py
Source:maddpg.py  
1import torch2import numpy as np3import copy4import itertools5from offpolicy.utils.util import huber_loss, mse_loss, to_torch6from offpolicy.utils.popart import PopArt7from offpolicy.algorithms.base.trainer import Trainer8class MADDPG(Trainer):9    def __init__(self, args, num_agents, policies, policy_mapping_fn, device=None, actor_update_interval=1):10        """11        Trainer class for MADDPG. See parent class for more information.12        :param actor_update_interval: (int) number of critic updates to perform between every update to the actor.13        """14        self.args = args15        self.use_popart = self.args.use_popart16        self.use_value_active_masks = self.args.use_value_active_masks17        self.use_per = self.args.use_per18        self.per_eps = self.args.per_eps19        self.use_huber_loss = self.args.use_huber_loss20        self.huber_delta = self.args.huber_delta21        self.tpdv = dict(dtype=torch.float32, device=device)22        self.num_agents = num_agents23        self.policies = policies24        self.policy_mapping_fn = policy_mapping_fn25        self.policy_ids = sorted(list(self.policies.keys()))26        self.policy_agents = {policy_id: sorted(27            [agent_id for agent_id in range(self.num_agents) if self.policy_mapping_fn(agent_id) == policy_id]) for policy_id in28            self.policies.keys()}29        if self.use_popart:30            self.value_normalizer = {policy_id: PopArt(1) for policy_id in self.policies.keys()}31        self.num_updates = {p_id : 0 for p_id in self.policy_ids}32        self.use_same_share_obs = self.args.use_same_share_obs33        self.actor_update_interval = actor_update_interval34    # @profile35    def get_update_info(self, update_policy_id, obs_batch, act_batch, nobs_batch, navail_act_batch):36        """37        Form centralized observation and action info for current and next timestep.38        :param update_policy_id: (str) id of policy being updated.39        :param obs_batch: (np.ndarray) batch of observation sequences sampled from buffer.40        :param act_batch: (np.ndarray) batch of action sequences sampled from buffer.41        :param avail_act_batch: (np.ndarray) batch of available action sequences sampled from buffer. None if environment does not limit actions.42        :return cent_act: (list) list of action sequences corresponding to each agent.43        :return replace_ind_start: (int) index of act_sequences from which to replace actions for actor update.44        :return cent_nact: (np.ndarray) batch of centralize next step actions.45        """46        cent_act = []47        cent_nact = []48        replace_ind_start = None49        # iterate through policies to get the target acts and other centralized info50        ind = 051        for p_id in self.policy_ids:52            batch_size = obs_batch[p_id].shape[1]53            policy = self.policies[p_id]54            if p_id == update_policy_id:55                replace_ind_start = ind56            num_pol_agents = len(self.policy_agents[p_id])57            cent_act.append(list(act_batch[p_id]))58            combined_nobs_batch = np.concatenate(nobs_batch[p_id], axis=0)59            if navail_act_batch[p_id] is not None:60                combined_navail_act_batch = np.concatenate(navail_act_batch[p_id], axis=0)61            else:62                combined_navail_act_batch = None63            # use target actor to get next step actions64            with torch.no_grad():65                pol_nact, _ = policy.get_actions(combined_nobs_batch, combined_navail_act_batch, use_target=True)66                ind_agent_nacts = pol_nact.cpu().split(split_size=batch_size, dim=0)67            # cat to form the centralized next step actions68            cent_nact.append(torch.cat(ind_agent_nacts, dim=-1))69            ind += num_pol_agents70        cent_act = list(itertools.chain.from_iterable(cent_act))71        cent_nact = np.concatenate(cent_nact, axis=-1)72        return cent_act, replace_ind_start, cent_nact73    def train_policy_on_batch(self, update_policy_id, batch):74        """See parent class."""75        if self.use_same_share_obs:76            return self.shared_train_policy_on_batch(update_policy_id, batch)77        else:78            return self.cent_train_policy_on_batch(update_policy_id, batch)79    def shared_train_policy_on_batch(self, update_policy_id, batch):80        """Training function when all agents share the same centralized observation. See train_policy_on_batch."""81        obs_batch, cent_obs_batch, \82        act_batch, rew_batch, \83        nobs_batch, cent_nobs_batch, \84        dones_batch, dones_env_batch, valid_transition_batch,\85        avail_act_batch, navail_act_batch, \86        importance_weights, idxes = batch87        train_info = {}88        update_actor = self.num_updates[update_policy_id] % self.actor_update_interval == 089        cent_act, replace_ind_start, cent_nact = self.get_update_info(update_policy_id, obs_batch, act_batch, nobs_batch, navail_act_batch)90        cent_obs = cent_obs_batch[update_policy_id]91        cent_nobs = cent_nobs_batch[update_policy_id]92        rewards = rew_batch[update_policy_id][0]93        dones_env = dones_env_batch[update_policy_id]94        update_policy = self.policies[update_policy_id]95        batch_size = cent_obs.shape[0]96        # critic update97        with torch.no_grad():98            next_step_Qs = update_policy.target_critic(cent_nobs, cent_nact)99        next_step_Q = torch.cat(next_step_Qs, dim=-1)100        # take min to prevent overestimation bias101        next_step_Q, _ = torch.min(next_step_Q, dim=-1, keepdim=True)102        rewards = to_torch(rewards).to(**self.tpdv).view(-1, 1)103        dones_env = to_torch(dones_env).to(**self.tpdv).view(-1, 1)104        if self.use_popart:105            target_Qs = rewards + self.args.gamma * (1 - dones_env) * self.value_normalizer[p_id].denormalize(next_step_Q)106            target_Qs = self.value_normalizer[p_id](target_Qs)107        else:108            target_Qs = rewards + self.args.gamma * (1 - dones_env) * next_step_Q109        predicted_Qs = update_policy.critic(cent_obs, np.concatenate(cent_act, axis=-1))110        update_policy.critic_optimizer.zero_grad()111        # detach the targets112        errors = [target_Qs.detach() - predicted_Q for predicted_Q in predicted_Qs]113        if self.use_per:114            importance_weights = to_torch(importance_weights).to(**self.tpdv)115            if self.use_huber_loss:116                critic_loss = [huber_loss(error, self.huber_delta).flatten() for error in errors]117            else:118                critic_loss = [mse_loss(error).flatten() for error in errors]119            # weight each loss element by their importance sample weight120            critic_loss = [(loss * importance_weights).mean() for loss in critic_loss]121            critic_loss = torch.stack(critic_loss).sum(dim=0)122            # new priorities are TD error123            new_priorities = np.stack([error.abs().cpu().detach().numpy().flatten() for error in errors]).mean(axis=0) + self.per_eps124        else:125            if self.use_huber_loss:126                critic_loss = [huber_loss(error, self.huber_delta).mean() for error in errors]127            else:128                critic_loss = [mse_loss(error).mean() for error in errors]129            critic_loss = torch.stack(critic_loss).sum(dim=0)130            new_priorities = None131        critic_loss.backward()132        critic_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.critic.parameters(),133                                                                 self.args.max_grad_norm)134        update_policy.critic_optimizer.step()135        train_info['critic_loss'] = critic_loss136        train_info['critic_grad_norm'] = critic_grad_norm137        138        if update_actor:139            # actor update140            # need to zero the critic gradient and the actor gradient since the gradients first flow through critic before getting to actor during backprop141            # freeze Q-networks142            for p in update_policy.critic.parameters():143                p.requires_grad = False144            num_update_agents = len(self.policy_agents[update_policy_id])145            mask_temp = []146            for p_id in self.policy_ids:147                if isinstance(self.policies[p_id].act_dim, np.ndarray):148                    # multidiscrete case149                    sum_act_dim = int(sum(self.policies[p_id].act_dim))150                else:151                    sum_act_dim = self.policies[p_id].act_dim152                for _ in self.policy_agents[p_id]:153                    mask_temp.append(np.zeros(sum_act_dim, dtype=np.float32))154            masks = []155            valid_trans_mask = []156            # need to iterate through agents, but only formulate masks at each step157            for i in range(num_update_agents):158                curr_mask_temp = copy.deepcopy(mask_temp)159                # set the mask to 1 at locations where the action should come from the actor output160                if isinstance(update_policy.act_dim, np.ndarray):161                    # multidiscrete case162                    sum_act_dim = int(sum(update_policy.act_dim))163                else:164                    sum_act_dim = update_policy.act_dim165                curr_mask_temp[replace_ind_start + i] = np.ones(sum_act_dim, dtype=np.float32)166                curr_mask_vec = np.concatenate(curr_mask_temp)167                # expand this mask into the proper size168                curr_mask = np.tile(curr_mask_vec, (batch_size, 1))169                masks.append(curr_mask)170                # agent valid transitions171                agent_valid_trans_batch = to_torch(valid_transition_batch[update_policy_id][i]).to(**self.tpdv)172                valid_trans_mask.append(agent_valid_trans_batch)173            # cat to form into tensors174            mask = to_torch(np.concatenate(masks)).to(**self.tpdv)175            valid_trans_mask = torch.cat(valid_trans_mask, dim=0)176            pol_agents_obs_batch = np.concatenate(obs_batch[update_policy_id], axis=0)177            if avail_act_batch[update_policy_id] is not None:178                pol_agents_avail_act_batch = np.concatenate(avail_act_batch[update_policy_id], axis=0)179            else:180                pol_agents_avail_act_batch = None181            # get all actions from actor182            pol_acts, _ = update_policy.get_actions(pol_agents_obs_batch, pol_agents_avail_act_batch, use_gumbel=True)183            # separate into individual agent batches184            agent_actor_batches = pol_acts.split(split_size=batch_size, dim=0)185            cent_act = list(map(lambda arr: to_torch(arr).to(**self.tpdv), cent_act))186            actor_cent_acts = copy.deepcopy(cent_act)187            for i in range(num_update_agents):188                actor_cent_acts[replace_ind_start + i] = agent_actor_batches[i]189            actor_cent_acts = torch.cat(actor_cent_acts, dim=-1).repeat((num_update_agents, 1))190            # convert buffer acts to torch, formulate centralized buffer action and repeat as done above191            buffer_cent_acts = torch.cat(cent_act, dim=-1).repeat(num_update_agents, 1)192            # also repeat cent obs193            stacked_cent_obs = np.tile(cent_obs, (num_update_agents, 1))194            # combine the buffer cent acts with actor cent acts and pass into buffer195            actor_update_cent_acts = mask * actor_cent_acts + (1 - mask) * buffer_cent_acts196            actor_Qs = update_policy.critic(stacked_cent_obs, actor_update_cent_acts)197            # use only the first Q output for actor loss198            actor_Qs = actor_Qs[0]199            actor_Qs = actor_Qs * valid_trans_mask200            actor_loss = -(actor_Qs).sum() / (valid_trans_mask).sum()201            update_policy.critic_optimizer.zero_grad()202            update_policy.actor_optimizer.zero_grad()203            actor_loss.backward()204            actor_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.actor.parameters(),205                                                                    self.args.max_grad_norm)206            update_policy.actor_optimizer.step()207            for p in update_policy.critic.parameters():208                p.requires_grad = True209            train_info['actor_loss'] = actor_loss210            train_info['actor_grad_norm'] = actor_grad_norm211            train_info['update_actor'] = update_actor212        return train_info, new_priorities, idxes213    def cent_train_policy_on_batch(self, update_policy_id, batch):214        """Training function when each agent has its own centralized observation. See train_policy_on_batch."""215        obs_batch, cent_obs_batch, \216        act_batch, rew_batch, \217        nobs_batch, cent_nobs_batch, \218        dones_batch, dones_env_batch, valid_transition_batch,\219        avail_act_batch, navail_act_batch, \220        importance_weights, idxes = batch221        train_info = {}222        update_actor = self.num_updates[update_policy_id] % self.actor_update_interval == 0223        cent_act, replace_ind_start, cent_nact = self.get_update_info(224            update_policy_id, obs_batch, act_batch, nobs_batch, navail_act_batch)225        cent_obs = cent_obs_batch[update_policy_id]226        cent_nobs = cent_nobs_batch[update_policy_id]227        rewards = rew_batch[update_policy_id][0]228        dones_env = dones_env_batch[update_policy_id]229        dones = dones_batch[update_policy_id]230        valid_trans = valid_transition_batch[update_policy_id]231        update_policy = self.policies[update_policy_id]232        batch_size = obs_batch[update_policy_id].shape[1]233        num_update_agents = len(self.policy_agents[update_policy_id])234        all_agent_cent_obs = np.concatenate(cent_obs, axis=0)235        all_agent_cent_nobs = np.concatenate(cent_nobs, axis=0)236        # since this is the same for each agent, just repeat when stacking237        cent_act_buffer = np.concatenate(cent_act, axis=-1)238        all_agent_cent_act_buffer = np.tile(cent_act_buffer, (num_update_agents, 1))239        all_agent_cent_nact = np.tile(cent_nact, (num_update_agents, 1))240        all_env_dones = np.tile(dones_env, (num_update_agents, 1))241        all_agent_rewards = np.tile(rewards, (num_update_agents, 1))242        # critic update243        update_policy.critic_optimizer.zero_grad()244        all_agent_rewards = to_torch(all_agent_rewards).to(**self.tpdv).reshape(-1, 1)245        all_env_dones = to_torch(all_env_dones).to(**self.tpdv).reshape(-1, 1)246        all_agent_valid_trans = to_torch(valid_trans).to(**self.tpdv).reshape(-1, 1)247        # critic update248        with torch.no_grad():249            next_step_Q = update_policy.target_critic(all_agent_cent_nobs, all_agent_cent_nact).reshape(-1, 1)250        if self.use_popart:251            target_Qs = all_agent_rewards + self.args.gamma * (1 - all_env_dones) * \252                self.value_normalizer[p_id].denormalize(next_step_Q)253            target_Qs = self.value_normalizer[p_id](target_Qs)254        else:255            target_Qs = all_agent_rewards + self.args.gamma * (1 - all_env_dones) * next_step_Q256        predicted_Qs = update_policy.critic(all_agent_cent_obs, all_agent_cent_act_buffer).reshape(-1, 1)257        error = target_Qs.detach() - predicted_Qs258        if self.use_per:259            agent_importance_weights = np.tile(importance_weights, num_update_agents)260            agent_importance_weights = to_torch(agent_importance_weights).to(**self.tpdv)261            if self.use_huber_loss:262                critic_loss = huber_loss(error, self.huber_delta).flatten()263            else:264                critic_loss = mse_loss(error).flatten()265            # weight each loss element by their importance sample weight266            critic_loss = critic_loss * agent_importance_weights267            if self.use_value_active_masks:268                critic_loss = (critic_loss.view(-1, 1) * (all_agent_valid_trans)).sum() / (all_agent_valid_trans).sum()269            else:270                critic_loss = critic_loss.mean()271            # new priorities are TD error272            agent_new_priorities = error.abs().cpu().detach().numpy().flatten()273            new_priorities = np.mean(np.split(agent_new_priorities, num_update_agents), axis=0) + self.per_eps274        else:275            if self.use_huber_loss:276                critic_loss = huber_loss(error, self.huber_delta)277            else:278                critic_loss = mse_loss(error)279            if self.use_value_active_masks:280                critic_loss = (critic_loss * (all_agent_valid_trans)).sum() / (all_agent_valid_trans).sum()281            else:282                critic_loss = critic_loss.mean()283            new_priorities = None284        critic_loss.backward()285        critic_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.critic.parameters(),286                                                                 self.args.max_grad_norm)287        update_policy.critic_optimizer.step()288        train_info['critic_loss'] = critic_loss289        train_info['critic_grad_norm'] = critic_grad_norm290        # actor update291        if update_actor:292            for p in update_policy.critic.parameters():293                p.requires_grad = False294            num_update_agents = len(self.policy_agents[update_policy_id])295            mask_temp = []296            for p_id in self.policy_ids:297                if isinstance(self.policies[p_id].act_dim, np.ndarray):298                    # multidiscrete case299                    sum_act_dim = int(sum(self.policies[p_id].act_dim))300                else:301                    sum_act_dim = self.policies[p_id].act_dim302                for _ in self.policy_agents[p_id]:303                    mask_temp.append(np.zeros(sum_act_dim, dtype=np.float32))304            masks = []305            valid_trans_mask = []306            # need to iterate through agents, but only formulate masks at each step307            for i in range(num_update_agents):308                curr_mask_temp = copy.deepcopy(mask_temp)309                # set the mask to 1 at locations where the action should come from the actor output310                if isinstance(update_policy.act_dim, np.ndarray):311                    # multidiscrete case312                    sum_act_dim = int(sum(update_policy.act_dim))313                else:314                    sum_act_dim = update_policy.act_dim315                curr_mask_temp[replace_ind_start + i] = np.ones(sum_act_dim, dtype=np.float32)316                curr_mask_vec = np.concatenate(curr_mask_temp)317                # expand this mask into the proper size318                curr_mask = np.tile(curr_mask_vec, (batch_size, 1))319                masks.append(curr_mask)320                # agent valid transitions321                agent_valid_trans_batch = to_torch(valid_transition_batch[update_policy_id][i]).to(**self.tpdv)322                valid_trans_mask.append(agent_valid_trans_batch)323            # cat to form into tensors324            mask = to_torch(np.concatenate(masks)).to(**self.tpdv)325            valid_trans_mask = torch.cat(valid_trans_mask, dim=0)326            pol_agents_obs_batch = np.concatenate(obs_batch[update_policy_id], axis=0)327            if avail_act_batch[update_policy_id] is not None:328                pol_agents_avail_act_batch = np.concatenate(avail_act_batch[update_policy_id], axis=0)329            else:330                pol_agents_avail_act_batch = None331            # get all actions from actor332            pol_acts, _ = update_policy.get_actions(pol_agents_obs_batch, pol_agents_avail_act_batch, use_gumbel=True)333            # separate into individual agent batches334            agent_actor_batches = pol_acts.split(split_size=batch_size, dim=0)335            # cat along final dim to formulate centralized action and stack copies of the batch336            cent_act = list(map(lambda arr: to_torch(arr).to(**self.tpdv), cent_act))337            actor_cent_acts = copy.deepcopy(cent_act)338            for i in range(num_update_agents):339                actor_cent_acts[replace_ind_start + i] = agent_actor_batches[i]340            actor_cent_acts = torch.cat(actor_cent_acts, dim=-1).repeat((num_update_agents, 1))341            # combine the buffer cent acts with actor cent acts and pass into buffer342            actor_update_cent_acts = mask * actor_cent_acts + (1 - mask) * to_torch(all_agent_cent_act_buffer).to(**self.tpdv)343            actor_Qs = update_policy.critic(all_agent_cent_obs, actor_update_cent_acts)344            # actor_loss = -actor_Qs.mean()345            actor_Qs = actor_Qs * valid_trans_mask346            actor_loss = -(actor_Qs).sum() / (valid_trans_mask).sum()347            update_policy.critic_optimizer.zero_grad()348            update_policy.actor_optimizer.zero_grad()349            actor_loss.backward()350            actor_grad_norm = torch.nn.utils.clip_grad_norm_(update_policy.actor.parameters(),351                                                                    self.args.max_grad_norm)352            update_policy.actor_optimizer.step()353            for p in update_policy.critic.parameters():354                p.requires_grad = True355            train_info['actor_loss'] = actor_loss356            train_info['actor_grad_norm'] = actor_grad_norm357        return train_info, new_priorities, idxes358    def prep_training(self):359        """See parent class."""360        for policy in self.policies.values():361            policy.actor.train()362            policy.critic.train()363            policy.target_actor.train()364            policy.target_critic.train()365    def prep_rollout(self):366        """See parent class."""367        for policy in self.policies.values():368            policy.actor.eval()369            policy.critic.eval()370            policy.target_actor.eval()...fasttext.py
Source:fasttext.py  
1import sys2# Python 2/3 compatibility3if sys.version_info.major==3:4    xrange=range5sys.path.insert(0,'./util')6from py2py3 import *7import numpy as np8import tensorflow as tf9sys.path.insert(0,'./model')10from __init__ import *11class fastText(base_net):12    def __init__(self, hyper_params):13        '''14        >>> Construct a fastText model15        >>> hyper_params: dict, a dictionary containing all hyper parameters16            >>> batch_size: int, batch size17            >>> sequence_length: int, maximum sentence length18            >>> class_num: int, number of categories19            >>> vocab_size: int, vocabulary size20            >>> embedding_dim: int, dimension of word embeddings21            >>> update_policy: dict, update policy22            >>> embedding_matrix: optional, numpy.array, initial embedding matrix23            >>> embedding_trainable: optional, bool, whether or not the embedding is trainable, default is true24        '''25        self.batch_size=hyper_params['batch_size']26        self.sequence_length=hyper_params['sequence_length']27        self.class_num=hyper_params['class_num']28        self.vocab_size=hyper_params['vocab_size']29        self.embedding_dim=hyper_params['embedding_dim']30        self.update_policy=hyper_params['update_policy']31        self.embedding_trainable=hyper_params['embedding_trainable'] if 'embedding_trainable' in hyper_params else True32        self.grad_clip_norm=hyper_params['grad_clip_norm'] if 'grad_clip_norm' in hyper_params else 1.033        self.name='fast_text model' if not 'name' in hyper_params else hyper_params['name']34        self.sess=None35        with tf.variable_scope('fastText'):36            if not 'embedding_matrix' in hyper_params:37                print('Word embeddings are initialized from scrach')38                self.embedding_matrix=tf.get_variable('embedding_matrix', shape=[self.vocab_size, self.embedding_dim],39                    initializer=tf.random_uniform_initializer(-1.0,1.0), dtype=tf.float32)40            else:41                print('Pre-trained word embeddings are imported')42                embedding_value=tf.constant(hyper_params['embedding_matrix'], dtype=tf.float32)43                self.embedding_matrix=tf.get_variable('embedding_matrix', initializer=embedding_value, dtype=tf.float32)44        self.inputs=tf.placeholder(tf.int32,shape=[self.batch_size, self.sequence_length])45        self.masks=tf.placeholder(tf.int32,shape=[self.batch_size, self.sequence_length])46        self.labels=tf.placeholder(tf.int32,shape=[self.batch_size,])47        self.embedding_output=tf.nn.embedding_lookup(self.embedding_matrix,self.inputs)             # of shape [self.batch_size, self.sequence_length, self.embedding_dim]48        embedding_sum=tf.reduce_sum(self.embedding_output,axis=1)                                   # of shape [self.batch_size, self.embedding_dim]49        mask_sum=tf.reduce_sum(self.masks,axis=1)                                                   # of shape [self.batch_size,]50        mask_sum=tf.expand_dims(mask_sum,axis=-1)                                                   # of shape [self.batch_size, 1]51        self.sentence_embedding=tf.div(embedding_sum, tf.cast(mask_sum, dtype=tf.float32))                                     # broadcasting mask_sum, the embedding of padded token must be zero52        # Construct softmax classifier53        with tf.variable_scope('fastText'):54            W=tf.get_variable(name='w',shape=[self.embedding_dim,self.class_num],55                initializer=tf.truncated_normal_initializer(stddev=0.5))56            b=tf.get_variable(name='b',shape=[self.class_num,],57                initializer=tf.truncated_normal_initializer(stddev=0.05))58        output=tf.add(tf.matmul(self.sentence_embedding,W),b)                                       # of shape [self.batch_size, self.class_num], unnormalized output probability distribution59        # Outputs60        self.probability=tf.nn.softmax(output)                                                      # of shape [self.batch_size, self.class_num], normalized output probability distribution61        self.prediction=tf.argmax(self.probability,axis=1)                                          # of shape [self.batch_size], discrete prediction62        self.loss=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output,labels=self.labels))  # loss value63        # Construct Optimizer64        if self.update_policy['name'].lower() in ['sgd', 'stochastic gradient descent']:65            learning_rate=self.update_policy['learning_rate']66            momentum=0.0 if not 'momentum' in self.update_policy else self.update_policy['momentum']67            self.optimizer=tf.train.MomentumOptimizer(learning_rate, momentum)68        elif self.update_policy['name'].lower() in ['adagrad',]:69            learning_rate=self.update_policy['learning_rate']70            initial_accumulator_value=0.1 if not 'initial_accumulator_value' in self.update_policy \71                else self.update_policy['initial_accumulator_value']72            self.optimizer=tf.train.AdagradOptimizer(learning_rate, initial_accumulator_value)73        elif self.update_policy['name'].lower() in ['adadelta']:74            learning_rate=self.update_policy['learning_rate']75            rho=0.95 if not 'rho' in self.update_policy else self.update_policy['rho']76            epsilon=1e-8 if not 'epsilon' in self.update_policy else self.update_policy['epsilon']77            self.optimizer=tf.train.AdadeltaOptimizer(learning_rate, rho, epsilon)78        elif self.update_policy['name'].lower() in ['rms', 'rmsprop']:79            learning_rate=self.update_policy['learning_rate']80            decay=0.9 if not 'decay' in self.update_policy else self.update_policy['decay']81            momentum=0.0 if not 'momentum' in self.update_policy else self.update_policy['momentum']82            epsilon=1e-10 if not 'epsilon' in self.update_policy else self.update_policy['epsilon']83            self.optimizer=tf.train.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)84        elif self.update_policy['name'].lower() in ['adam']:85            learning_rate=self.update_policy['learning_rate']86            beta1=0.9 if not 'beta1' in self.update_policy else self.update_policy['beta1']87            beta2=0.999 if not 'beta2' in self.update_policy else self.update_policy['beta2']88            epsilon=1e-8 if not 'epsilon' in self.update_policy else self.update_policy['epsilon']89            self.optimizer=tf.train.AdamOptimizer(learning_rate, beta1, beta2, epsilon)90        else:91            raise ValueError('Unrecognized Optimizer Category: %s'%self.update_policy['name'])92        print('gradient clip is applied, max = %.2f'%self.grad_clip_norm)93        gradients=self.optimizer.compute_gradients(self.loss)94        clipped_gradients=[(tf.clip_by_value(grad,-self.grad_clip_norm,self.grad_clip_norm),var) for grad,var in gradients]95        self.update=self.optimizer.apply_gradients(clipped_gradients)96    def train_validate_test_init(self):97        '''98        >>> Initialize the training validation and test phrase99        '''100        self.sess=tf.Session()101        self.sess.run(tf.global_variables_initializer())102    def train(self,inputs,masks,labels):103        '''104        >>> Training phrase105        '''106        train_dict={self.inputs:inputs,self.masks:masks,self.labels:labels}107        self.sess.run(self.update,feed_dict=train_dict)108        prediction_this_batch, loss_this_batch=self.sess.run([self.prediction,self.loss],feed_dict=train_dict)109        return prediction_this_batch, loss_this_batch110    def validate(self,inputs,masks,labels):111        '''112        >>> Validation phrase113        '''114        validate_dict={self.inputs:inputs,self.masks:masks,self.labels:labels}115        prediction_this_batch, loss_this_batch=self.sess.run([self.prediction,self.loss],feed_dict=validate_dict)116        return prediction_this_batch, loss_this_batch117    def test(self,inputs,masks,fine_tune=False):118        '''119        >>> Test phrase120        '''121        test_dict={self.inputs:inputs,self.masks:masks}122        if fine_tune==False:123            prediction_this_batch,=self.sess.run([self.prediction,],feed_dict=test_dict)124        else:125            prediction_this_batch,=self.sess.run([self.probability,],feed_dict=test_dict)126        return prediction_this_batch127    def do_summarization(self, file_list, folder2store, data_generator, n_top=5):128        '''129        >>> Not implemented130        '''131        raise NotImplementedError('Current model is fasttext, where "do_summarization" function is not implemented')132    def dump_params(self,file2dump):133        '''134        >>> Save the parameters135        >>> file2dump: str, file to store the parameters136        '''137        saver=tf.train.Saver()138        saved_path=saver.save(self.sess, file2dump)139        print('parameters are saved in file %s'%saved_path)140    def load_params(self,file2load, loaded_params=[]):141        '''142        >>> Load the parameters143        >>> file2load: str, file to load the parameters144        '''145        param2load=[]146        for var in tf.global_variables():147            if not var in loaded_params:148                param2load.append(var)149        saver=tf.train.Saver(param2load)150        saver.restore(self.sess, file2load)151        print('parameters are imported from file %s'%file2load)152    def train_validate_test_end(self):153        '''154        >>> End current training validation and test phrase155        '''156        self.sess.close()...Learn to execute automation testing from scratch with LambdaTest Learning Hub. Right from setting up the prerequisites to run your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile a list of step-by-step guides to help you be proficient with different test automation frameworks i.e. Selenium, Cypress, TestNG etc.
You could also refer to video tutorials over LambdaTest YouTube channel to get step by step demonstration from industry experts.
Get 100 minutes of automation test minutes FREE!!
