Best Python code snippet using robotframework
checks.py
Source:checks.py  
1"""2Decorator for checking input/output arguments of functions.3"""4__all__ = [5    "check_values",6    "check_units",7    "check_relativistic",8    "CheckBase",9    "CheckUnits",10    "CheckValues",11]12import collections13import functools14import inspect15import numpy as np16import warnings17from astropy import units as u18from astropy.constants import c19from functools import reduce20from typing import Any, Dict, List, Tuple, Union21from plasmapy.utils.decorators.helpers import preserve_signature22from plasmapy.utils.exceptions import (23    PlasmaPyWarning,24    RelativityError,25    RelativityWarning,26)27try:28    from astropy.units.equivalencies import Equivalency29except ImportError:30    # TODO: remove once we have dependency Astropy >= 3.2.131    # astropy defined the Equivalency class in v3.2.132    class Equivalency:33        pass34class CheckBase:35    """36    Base class for 'Check' decorator classes.37    Parameters38    ----------39    checks_on_return40        specified checks on the return of the wrapped function41    **checks42        specified checks on the input arguments of the wrapped function43    """44    def __init__(self, checks_on_return=None, **checks):45        self._checks = checks46        if checks_on_return is not None:47            self._checks["checks_on_return"] = checks_on_return48    @property49    def checks(self):50        """51        Requested checks on the decorated function's input arguments52        and/or return.53        """54        return self._checks55class CheckValues(CheckBase):56    """57    A decorator class to 'check' -- limit/control -- the values of input and return58    arguments to a function or method.59    Parameters60    ----------61    checks_on_return: Dict[str, bool]62        Specifications for value checks on the return of the function being wrapped.63        (see `check values`_ for valid specifications)64    **checks: Dict[str, Dict[str, bool]]65        Specifications for value checks on the 
input arguments of the function66        being wrapped.  Each keyword argument in `checks` is the name of a function67        argument to be checked and the keyword value contains the value check68        specifications.69        .. _`check values`:70        The value check specifications are defined within a dictionary containing71        the keys defined below.  If the dictionary is empty or omitting keys,72        then the default value will be assumed for the missing keys.73        ================ ======= ================================================74        Key              Type    Description75        ================ ======= ================================================76        can_be_negative  `bool`  [DEFAULT `True`] values can be negative77        can_be_complex   `bool`  [DEFAULT `False`] values can be complex numbers78        can_be_inf       `bool`  [DEFAULT `True`] values can be :data:`~numpy.inf`79        can_be_nan       `bool`  [DEFAULT `True`] values can be :data:`~numpy.nan`80        none_shall_pass  `bool`  [DEFAULT `False`] values can be a python `None`81        ================ ======= ================================================82    Notes83    -----84    * Checking of function arguments `*args` and `**kwargs` is not supported.85    Examples86    --------87    .. 
code-block:: python88        from plasmapy.utils.decorators.checks import CheckValues89        @CheckValues(arg1={'can_be_negative': False, 'can_be_nan': False},90                     arg2={'can_be_inf': False},91                     checks_on_return={'none_shall_pass': True)92        def foo(arg1, arg2):93            return None94        # on a method95        class Foo:96            @CheckValues(arg1={'can_be_negative': False, 'can_be_nan': False},97                         arg2={'can_be_inf': False},98                         checks_on_return={'none_shall_pass': True)99            def bar(self, arg1, arg2):100                return None101    """102    #: Default values for the possible 'check' keys.103    # To add a new check to the class, the following needs to be done:104    #   1. Add a key & default value to the `__check_defaults` dictionary105    #   2. Add a corresponding if-statement to method `_check_value`106    #107    __check_defaults = {108        "can_be_negative": True,109        "can_be_complex": False,110        "can_be_inf": True,111        "can_be_nan": True,112        "none_shall_pass": False,113    }114    def __init__(115        self, checks_on_return: Dict[str, bool] = None, **checks: Dict[str, bool]116    ):117        super().__init__(checks_on_return=checks_on_return, **checks)118    def __call__(self, f):119        """120        Parameters121        ----------122        f123            Function to be wrapped124        Returns125        -------126        function127            wrapped function of `f`128        """129        self.f = f130        wrapped_sign = inspect.signature(f)131        @preserve_signature132        @functools.wraps(f)133        def wrapper(*args, **kwargs):134            # map args and kwargs to function parameters135            bound_args = wrapped_sign.bind(*args, **kwargs)136            bound_args.apply_defaults()137            # get checks138            checks = self._get_value_checks(bound_args)139            # 
check input arguments140            for arg_name in checks:141                # skip check of output/return142                if arg_name == "checks_on_return":143                    continue144                # check argument145                self._check_value(146                    bound_args.arguments[arg_name], arg_name, checks[arg_name]147                )148            # call function149            _return = f(**bound_args.arguments)150            # check function return151            if "checks_on_return" in checks:152                self._check_value(153                    _return, "checks_on_return", checks["checks_on_return"]154                )155            return _return156        return wrapper157    def _get_value_checks(158        self, bound_args: inspect.BoundArguments159    ) -> Dict[str, Dict[str, bool]]:160        """161        Review :attr:`checks` and function bound arguments to build a complete 'checks'162        dictionary.  If a check key is omitted from the argument checks, then a default163        value is assumed (see `check values`_).164        Parameters165        ----------166        bound_args: :class:`inspect.BoundArguments`167            arguments passed into the function being wrapped168            .. 
code-block:: python169                bound_args = inspect.signature(f).bind(*args, **kwargs)170        Returns171        -------172        Dict[str, Dict[str, bool]]173            A complete 'checks' dictionary for checking function input arguments174            and return.175        """176        # initialize validation dictionary177        out_checks = {}178        # Iterate through function bound arguments + return and build `out_checks:179        #180        # artificially add "return" to parameters181        things_to_check = bound_args.signature.parameters.copy()182        things_to_check["checks_on_return"] = inspect.Parameter(183            "checks_on_return",184            inspect.Parameter.POSITIONAL_ONLY,185            annotation=bound_args.signature.return_annotation,186        )187        for param in things_to_check.values():188            # variable arguments are NOT checked189            # e.g. in foo(x, y, *args, d=None, **kwargs) variable arguments190            #      *args and **kwargs will NOT be checked191            #192            if param.kind in (193                inspect.Parameter.VAR_KEYWORD,194                inspect.Parameter.VAR_POSITIONAL,195            ):196                continue197            # grab the checks dictionary for the desired parameter198            try:199                param_in_checks = self.checks[param.name]200            except KeyError:201                # checks for parameter not specified202                continue203            # build `out_checks`204            # read checks and/or apply defaults values205            out_checks[param.name] = {}206            for v_name, v_default in self.__check_defaults.items():207                try:208                    out_checks[param.name][v_name] = param_in_checks.get(209                        v_name, v_default210                    )211                except AttributeError:212                    # for the case that checks are defined for an argument,213           
         # but is NOT a dictionary214                    # (e.g. CheckValues(x=u.cm) ... this scenario could happen215                    # during subclassing)216                    out_checks[param.name][v_name] = v_default217        # Does `self.checks` indicate arguments not used by f?218        missing_params = [219            param for param in set(self.checks.keys()) - set(out_checks.keys())220        ]221        if len(missing_params) > 0:222            params_str = ", ".join(missing_params)223            warnings.warn(224                PlasmaPyWarning(225                    f"Expected to value check parameters {params_str} but they "226                    f"are missing from the call to {self.f.__name__}"227                )228            )229        return out_checks230    def _check_value(self, arg, arg_name: str, arg_checks: Dict[str, bool]):231        """232        Perform checks `arg_checks` on function argument `arg`.233        Parameters234        ----------235        arg236            The argument to be checked237        arg_name: str238            The name of the argument to be checked239        arg_checks: Dict[str, bool]240            The requested checks for the argument241        Raises242        ------243        ValueError244            raised if a check fails245        """246        if arg_name == "checks_on_return":247            valueerror_msg = f"The return value "248        else:249            valueerror_msg = f"The argument '{arg_name}' "250        valueerror_msg += f"to function {self.f.__name__}() can not contain"251        # check values252        # * 'none_shall_pass' always needs to be checked first253        ckeys = list(self.__check_defaults.keys())254        ckeys.remove("none_shall_pass")255        ckeys = ("none_shall_pass",) + tuple(ckeys)256        for ckey in ckeys:257            if ckey == "none_shall_pass":258                if arg is None and arg_checks[ckey]:259                    break260                elif arg is 
None:261                    raise ValueError(f"{valueerror_msg} Nones.")262            elif ckey == "can_be_negative":263                if not arg_checks[ckey]:264                    # Allow NaNs through without raising a warning265                    with np.errstate(invalid="ignore"):266                        isneg = np.any(arg < 0)267                    if isneg:268                        raise ValueError(f"{valueerror_msg} negative numbers.")269            elif ckey == "can_be_complex":270                if not arg_checks[ckey] and np.any(np.iscomplexobj(arg)):271                    raise ValueError(f"{valueerror_msg} complex numbers.")272            elif ckey == "can_be_inf":273                if not arg_checks[ckey] and np.any(np.isinf(arg)):274                    raise ValueError(f"{valueerror_msg} infs.")275            elif ckey == "can_be_nan":276                if not arg_checks["can_be_nan"] and np.any(np.isnan(arg)):277                    raise ValueError(f"{valueerror_msg} NaNs.")278class CheckUnits(CheckBase):279    """280    A decorator class to 'check' -- limit/control -- the units of input and return281    arguments to a function or method.282    Parameters283    ----------284    checks_on_return: list of astropy :mod:`~astropy.units` or dict of unit specifications285        Specifications for unit checks on the return of the function being wrapped.286        (see `check units`_ for valid specifications)287    **checks: list of astropy :mod:`~astropy.units` or dict of unit specifications288        Specifications for unit checks on the input arguments of the function289        being wrapped.  Each keyword argument in `checks` is the name of a function290        argument to be checked and the keyword value contains the unit check291        specifications.292        .. 
_`check units`:293        Unit checks can be defined by passing one of the astropy294        :mod:`~astropy.units`, a list of astropy units, or a dictionary containing295        the keys defined below.  Units can also be defined with function296        annotations, but must be consistent with decorator `**checks` arguments if297        used concurrently. If a key is omitted, then the default value will be assumed.298        ====================== ======= ================================================299        Key                    Type    Description300        ====================== ======= ================================================301        units                          list of desired astropy :mod:`~astropy.units`302        equivalencies                  | [DEFAULT `None`] A list of equivalent pairs to303                                         try if304                                       | the units are not directly convertible.305                                       | (see :mod:`~astropy.units.equivalencies`,306                                         and/or `astropy equivalencies`_)307        pass_equivalent_units  `bool`  | [DEFAULT `False`] allow equivalent units308                                       | to pass309        ====================== ======= ================================================310    Notes311    -----312    * Checking of function arguments `*args` and `**kwargs` is not supported.313    * Decorator does NOT perform any unit conversions.314    * If it is desired that `None` values do not raise errors or warnings, then315      include `None` in the list of units or as a default value for the function316      argument.317    * If units are not specified in `checks`, then the decorator will attempt318      to identify desired units by examining the function annotations.319    Examples320    --------321    Define units with decorator parameters::322        import astropy.units as u323        from plasmapy.utils.decorators 
import CheckUnits324        @CheckUnits(arg1={'units': u.cm},325                    arg2=u.cm,326                    checks_on_return=[u.cm, u.km])327        def foo(arg1, arg2):328            return arg1 + arg2329        # or on a method330        class Foo:331            @CheckUnits(arg1={'units': u.cm},332                        arg2=u.cm,333                        checks_on_return=[u.cm, u.km])334            def bar(self, arg1, arg2):335                return arg1 + arg2336    Define units with function annotations::337        import astropy.units as u338        from plasmapy.utils.decorators import CheckUnits339        @CheckUnits()340        def foo(arg1: u.cm, arg2: u.cm) -> u.cm:341            return arg1 + arg2342        # or on a method343        class Foo:344            @CheckUnits()345            def bar(self, arg1: u.cm, arg2: u.cm) -> u.cm:346                return arg1 + arg2347    Allow `None` values to pass, on input and output::348        import astropy.units as u349        from plasmapy.utils.decorators import CheckUnits350        @CheckUnits(checks_on_return=[u.cm, None])351        def foo(arg1: u.cm = None):352            return arg1353    Allow return values to have equivalent units::354        import astropy.units as u355        from plasmapy.utils.decorators import CheckUnits356        @CheckUnits(arg1={'units': u.cm},357                    checks_on_return={'units': u.km,358                                      'pass_equivalent_units': True})359        def foo(arg1):360            return arg1361    Allow equivalent units to pass with specified equivalencies::362        import astropy.units as u363        from plasmapy.utils.decorators import CheckUnits364        @CheckUnits(arg1={'units': u.K,365                          'equivalencies': u.temperature_energy(),366                          'pass_equivalent_units': True})367        def foo(arg1):368            return arg1369    .. 
_astropy equivalencies:370        https://docs.astropy.org/en/stable/units/equivalencies.html371    """372    #: Default values for the possible 'check' keys.373    # To add a new check the the class, the following needs to be done:374    #   1. Add a key & default value to the `__check_defaults` dictionary375    #   2. Add a corresponding conditioning statement to `_get_unit_checks`376    #   3. Add a corresponding behavior to `_check_unit`377    #378    __check_defaults = {379        "units": None,380        "equivalencies": None,381        "pass_equivalent_units": False,382        "none_shall_pass": False,383    }384    def __init__(385        self,386        checks_on_return: Union[u.Unit, List[u.Unit], Dict[str, Any]] = None,387        **checks: Union[u.Unit, List[u.Unit], Dict[str, Any]],388    ):389        super().__init__(checks_on_return=checks_on_return, **checks)390    def __call__(self, f):391        """392        Parameters393        ----------394        f395            Function to be wrapped396        Returns397        -------398        function399            wrapped function of `f`400        """401        self.f = f402        wrapped_sign = inspect.signature(f)403        @preserve_signature404        @functools.wraps(f)405        def wrapper(*args, **kwargs):406            # combine args and kwargs into dictionary407            bound_args = wrapped_sign.bind(*args, **kwargs)408            bound_args.apply_defaults()409            # get checks410            checks = self._get_unit_checks(bound_args)411            # check (input) argument units412            for arg_name in checks:413                # skip check of output/return414                if arg_name == "checks_on_return":415                    continue416                # check argument417                self._check_unit(418                    bound_args.arguments[arg_name], arg_name, checks[arg_name]419                )420            # call function421            _return = 
f(**bound_args.arguments)422            # check output423            if "checks_on_return" in checks:424                self._check_unit(425                    _return, "checks_on_return", checks["checks_on_return"]426                )427            return _return428        return wrapper429    def _get_unit_checks(430        self, bound_args: inspect.BoundArguments431    ) -> Dict[str, Dict[str, Any]]:432        """433        Review :attr:`checks` and function bound arguments to build a complete 'checks'434        dictionary.  If a check key is omitted from the argument checks, then a default435        value is assumed (see `check units`_)436        Parameters437        ----------438        bound_args: :class:`inspect.BoundArguments`439            arguments passed into the function being wrapped440            .. code-block:: python441                bound_args = inspect.signature(f).bind(*args, **kwargs)442        Returns443        -------444        Dict[str, Dict[str, Any]]445            A complete 'checks' dictionary for checking function input arguments446            and return.447        """448        # initialize validation dictionary449        out_checks = {}450        # Iterate through function bound arguments + return and build `out_checks`:451        #452        # artificially add "return" to parameters453        things_to_check = bound_args.signature.parameters.copy()454        things_to_check["checks_on_return"] = inspect.Parameter(455            "checks_on_return",456            inspect.Parameter.POSITIONAL_ONLY,457            annotation=bound_args.signature.return_annotation,458        )459        for param in things_to_check.values():460            # variable arguments are NOT checked461            # e.g. 
in foo(x, y, *args, d=None, **kwargs) variable arguments462            #      *args and **kwargs will NOT be checked463            #464            if param.kind in (465                inspect.Parameter.VAR_KEYWORD,466                inspect.Parameter.VAR_POSITIONAL,467            ):468                continue469            # grab the checks dictionary for the desired parameter470            try:471                param_checks = self.checks[param.name]472            except KeyError:473                param_checks = None474            # -- Determine target units `_units` --475            # target units can be defined in one of three ways (in476            # preferential order):477            #   1. direct keyword pass-through478            #      i.e. CheckUnits(x=u.cm)479            #           CheckUnits(x=[u.cm, u.s])480            #   2. keyword pass-through via dictionary definition481            #      i.e. CheckUnits(x={'units': u.cm})482            #           CheckUnits(x={'units': [u.cm, u.s]})483            #   3. function annotations484            #485            # * if option (3) is used simultaneously with option (1) or (2), then486            #   checks defined by (3) must be consistent with checks from (1) or (2)487            #   to avoid raising an error.488            # * if None is included in the units list, then None values are allowed489            #490            _none_shall_pass = False491            _units = None492            _units_are_from_anno = False493            if param_checks is not None:494                # checks for argument were defined with decorator495                try:496                    _units = param_checks["units"]497                except TypeError:498                    # if checks is NOT None and is NOT a dictionary, then assume499                    # only units were specified500                    #   e.g. 
CheckUnits(x=u.cm)501                    #502                    _units = param_checks503                except KeyError:504                    # if checks does NOT have 'units' but is still a dictionary,505                    # then other check conditions may have been specified and the506                    # user is relying on function annotations to define desired507                    # units508                    _units = None509            # If no units have been specified by decorator checks, then look for510            # function annotations.511            #512            # Reconcile units specified by decorator checks and function annotations513            _units_anno = None514            if param.annotation is not inspect.Parameter.empty:515                # unit annotations defined516                _units_anno = param.annotation517            if _units is None and _units_anno is None and param_checks is None:518                # no checks specified and no unit annotations defined519                continue520            elif _units is None and _units_anno is None:521                # checks specified, but NO unit checks522                msg = f"No astropy.units specified for "523                if param.name == "checks_on_return":524                    msg += f"return value "525                else:526                    msg += f"argument {param.name} "527                msg += f"of function {self.f.__name__}()."528                raise ValueError(msg)529            elif _units is None:530                _units = _units_anno531                _units_are_from_anno = True532                _units_anno = None533            # Ensure `_units` is an iterable534            if not isinstance(_units, collections.abc.Iterable):535                _units = [_units]536            if not isinstance(_units_anno, collections.abc.Iterable):537                _units_anno = [_units_anno]538            # Is None allowed?539            if None in _units or param.default 
is None:540                _none_shall_pass = True541            # Remove Nones542            if None in _units:543                _units = [t for t in _units if t is not None]544            if None in _units_anno:545                _units_anno = [t for t in _units_anno if t is not None]546            # ensure all _units are astropy.units.Unit or physical types &547            # define 'units' for unit checks &548            # define 'none_shall_pass' check549            _units = self._condition_target_units(550                _units, from_annotations=_units_are_from_anno551            )552            _units_anno = self._condition_target_units(553                _units_anno, from_annotations=True554            )555            if not all(_u in _units for _u in _units_anno):556                raise ValueError(557                    f"For argument '{param.name}', "558                    f"annotation units ({_units_anno}) are not included in the units "559                    f"specified by decorator arguments ({_units}).  
Use either "560                    f"decorator arguments or function annotations to defined unit "561                    f"types, or make sure annotation specifications match decorator "562                    f"argument specifications."563                )564            if len(_units) == 0 and len(_units_anno) == 0 and param_checks is None:565                # annotations did not specify units566                continue567            elif len(_units) == 0 and len(_units_anno) == 0:568                # checks specified, but NO unit checks569                msg = f"No astropy.units specified for "570                if param.name == "checks_on_return":571                    msg += f"return value "572                else:573                    msg += f"argument {param.name} "574                msg += f"of function {self.f.__name__}()."575                raise ValueError(msg)576            out_checks[param.name] = {577                "units": _units,578                "none_shall_pass": _none_shall_pass,579            }580            # -- Determine target equivalencies --581            # Unit equivalences can be defined by:582            # 1. keyword pass-through via dictionary definition583            #    e.g. 
CheckUnits(x={'units': u.C,584            #                       'equivalencies': u.temperature})585            #586            # initialize equivalencies587            try:588                _equivs = param_checks["equivalencies"]589            except (KeyError, TypeError):590                _equivs = self.__check_defaults["equivalencies"]591            # ensure equivalences are properly formatted592            if _equivs is None or _equivs == [None]:593                _equivs = None594            elif isinstance(_equivs, Equivalency):595                pass596            elif isinstance(_equivs, (list, tuple)):597                # flatten list to non-list elements598                if isinstance(_equivs, tuple):599                    _equivs = [_equivs]600                else:601                    _equivs = self._flatten_equivalencies_list(_equivs)602                # ensure passed equivalencies list is structured properly603                #   [(), ...]604                #   or [Equivalency(), ...]605                #606                # * All equivalencies must be a list of 2, 3, or 4 element tuples607                #   structured like...608                #     (from_unit, to_unit, forward_func, backward_func)609                #610                if all(isinstance(el, Equivalency) for el in _equivs):611                    _equivs = reduce(lambda x, y: x + y, _equivs)612                else:613                    _equivs = self._normalize_equivalencies(_equivs)614            out_checks[param.name]["equivalencies"] = _equivs615            # -- Determine if equivalent units pass --616            try:617                peu = param_checks.get(618                    "pass_equivalent_units",619                    self.__check_defaults["pass_equivalent_units"],620                )621            except (AttributeError, TypeError):622                peu = self.__check_defaults["pass_equivalent_units"]623            out_checks[param.name]["pass_equivalent_units"] = 
peu624        # Does `self.checks` indicate arguments not used by f?625        missing_params = [626            param for param in set(self.checks.keys()) - set(out_checks.keys())627        ]628        if len(missing_params) > 0:629            params_str = ", ".join(missing_params)630            warnings.warn(631                PlasmaPyWarning(632                    f"Expected to unit check parameters {params_str} but they "633                    f"are missing from the call to {self.f.__name__}"634                )635            )636        return out_checks637    def _check_unit(self, arg, arg_name: str, arg_checks: Dict[str, Any]):638        """639        Perform unit checks `arg_checks` on function argument `arg`.640        Parameters641        ----------642        arg643            The argument to be checked644        arg_name: str645            The name of the argument to be checked646        arg_checks: Dict[str, Any]647            The requested checks for the argument648        Raises649        ------650        ValueError651            If `arg` is `None` when `arg_checks['none_shall_pass']=False`652        TypeError653            If `arg` does not have `units`654        :class:`astropy.units.UnitTypeError`655            If the units of `arg` do not satisfy conditions of `arg_checks`656        """657        arg, unit, equiv, err = self._check_unit_core(arg, arg_name, arg_checks)658        if err is not None:659            raise err660    def _check_unit_core(661        self, arg, arg_name: str, arg_checks: Dict[str, Any]662    ) -> Tuple[663        Union[None, u.Quantity],664        Union[None, u.Unit],665        Union[None, List[Any]],666        Union[None, Exception],667    ]:668        """669        Determines if `arg` passes unit checks `arg_checks` and if the units of670        `arg` is equivalent to any units specified in `arg_checks`.671        Parameters672        ----------673        arg674            The argument to be checked675        arg_name: 
str676            The name of the argument to be checked677        arg_checks: Dict[str, Any]678            The requested checks for the argument679        Returns680        -------681        (`arg`, `unit`, `equivalencies`, `error`)682            * `arg` is the original input argument `arg` or `None` if unit683              checks fail684            * `unit` is the identified astropy :mod:`~astropy.units` that `arg`685              can be converted to or `None` if none exist686            * `equivalencies` is the astropy :mod:`~astropy.units.equivalencies`687              used for the unit conversion or `None`688            * `error` is the `Exception` associated with the failed unit checks689              or `None` for successful unit checks690        """691        # initialize str for error messages692        if arg_name == "checks_on_return":693            err_msg = f"The return value "694        else:695            err_msg = f"The argument '{arg_name}' "696        err_msg += f"to function {self.f.__name__}()"697        # initialize ValueError message698        valueerror_msg = f"{err_msg} can not contain"699        # initialize TypeError message700        typeerror_msg = f"{err_msg} should be an astropy Quantity with "701        if len(arg_checks["units"]) == 1:702            typeerror_msg += f"the following unit: {arg_checks['units'][0]}"703        else:704            typeerror_msg += "one of the following units: "705            for unit in arg_checks["units"]:706                typeerror_msg += str(unit)707                if unit != arg_checks["units"][-1]:708                    typeerror_msg += ", "709        if arg_checks["none_shall_pass"]:710            typeerror_msg += "or None "711        # pass Nones if allowed712        if arg is None:713            if arg_checks["none_shall_pass"]:714                return arg, None, None, None715            else:716                return None, None, None, ValueError(f"{valueerror_msg} Nones")717        # check 
units718        in_acceptable_units = []719        equiv = arg_checks["equivalencies"]720        for unit in arg_checks["units"]:721            try:722                in_acceptable_units.append(723                    arg.unit.is_equivalent(unit, equivalencies=equiv)724                )725            except AttributeError:726                if hasattr(arg, "unit"):727                    err_specifier = (728                        "a 'unit' attribute without an 'is_equivalent' method"729                    )730                else:731                    err_specifier = "no 'unit' attribute"732                msg = (733                    f"{err_msg} has {err_specifier}. "734                    f"Use an astropy Quantity instead."735                )736                return None, None, None, TypeError(msg)737        # How many acceptable units?738        nacceptable = np.count_nonzero(in_acceptable_units)739        unit = None740        equiv = None741        err = None742        if nacceptable == 0:743            # NO equivalent units744            arg = None745            err = u.UnitTypeError(typeerror_msg)746        else:747            # is there an exact match?748            units_arr = np.array(arg_checks["units"])749            units_equal_mask = np.equal(units_arr, arg.unit)750            units_mask = np.logical_and(units_equal_mask, in_acceptable_units)751            if np.count_nonzero(units_mask) == 1:752                # matched exactly to a desired unit753                unit = units_arr[units_mask][0]754                equiv = arg_checks["equivalencies"]755            elif nacceptable == 1:756                # there is a match to 1 equivalent unit757                unit = units_arr[in_acceptable_units][0]758                equiv = arg_checks["equivalencies"]759                if not arg_checks["pass_equivalent_units"]:760                    err = u.UnitTypeError(typeerror_msg)761            elif arg_checks["pass_equivalent_units"]:762                # 
there is a match to more than one equivalent units763                pass764            else:765                # there is a match to more than 1 equivalent units766                arg = None767                err = u.UnitTypeError(typeerror_msg)768        return arg, unit, equiv, err769    @staticmethod770    def _condition_target_units(targets: List, from_annotations: bool = False):771        """772        From a list of target units (either as a string or astropy773        :class:`~astropy.units.Unit` objects), return a list of conditioned774        :class:`~astropy.units.Unit` objects.775        Parameters776        ----------777        targets: list of target units778            list of units (either as a string or :class:`~astropy.units.Unit`)779            to be conditioned into astropy :class:`~astropy.units.Unit` objects780        from_annotations: bool781            (Default `False`) Indicates if `targets` originated from function/method782            annotations versus decorator input arguments.783        Returns784        -------785        list:786            list of `targets` converted into astropy787            :class:`~astropy.units.Unit` objects788        Raises789        ------790        TypeError791            If `target` is not a valid type for :class:`~astropy.units.Unit` when792            `from_annotations == True`,793        ValueError794            If a `target` is a valid unit type but not a valid value for795            :class:`~astropy.units.Unit`.796        """797        # Note: this method does not allow for astropy physical types. 
This is798        #       done because we expect all use cases of CheckUnits to define the799        #       exact units desired.800        #801        allowed_units = []802        for target in targets:803            try:804                target_unit = u.Unit(target)805                allowed_units.append(target_unit)806            except TypeError as err:807                # not a unit type808                if not from_annotations:809                    raise err810                continue811        return allowed_units812    @staticmethod813    def _normalize_equivalencies(equivalencies):814        """815        Normalizes equivalencies to ensure each is in a 4-tuple form::816            (from_unit, to_unit, forward_func, backward_func)817        `forward_func` maps `from_unit` into `to_unit` and `backward_func` does818        the reverse.819        Parameters820        ----------821        equivalencies: list of equivalent pairs822            list of astropy :mod:`~astropy.units.equivalencies` to be normalized823        Raises824        ------825        ValueError826            if an equivalency can not be interpreted827        Notes828        -----829        * the code here was copied and modified from830          :func:`astropy.units.core._normalize_equivalencies` from AstroPy831          version 3.2.3832        * this will work on both the old style list equivalencies (pre AstroPy v3.2.1)833          and the modern equivalencies defined with the834          :class:`~astropy.units.equivalencies.Equivalency` class835        """836        if equivalencies is None:837            return []838        normalized = []839        for i, equiv in enumerate(equivalencies):840            if len(equiv) == 2:841                from_unit, to_unit = equiv842                a = b = lambda x: x843            elif len(equiv) == 3:844                from_unit, to_unit, a = equiv845                b = a846            elif len(equiv) == 4:847                from_unit, to_unit, 
a, b = equiv848            else:849                raise ValueError(f"Invalid equivalence entry {i}: {equiv!r}")850            if not (851                from_unit is u.Unit(from_unit)852                and (to_unit is None or to_unit is u.Unit(to_unit))853                and callable(a)854                and callable(b)855            ):856                raise ValueError(f"Invalid equivalence entry {i}: {equiv!r}")857            normalized.append((from_unit, to_unit, a, b))858        return normalized859    def _flatten_equivalencies_list(self, elist):860        """861        Given a list of equivalencies, flatten out any sub-element lists862        Parameters863        ----------864        elist: list865            list of astropy :mod:`~astropy.units.equivalencies` to be flattened866        Returns867        -------868        list869            a flattened list of astropy :mod:`~astropy.units.equivalencies`870        """871        new_list = []872        for el in elist:873            if not isinstance(el, list):874                new_list.append(el)875            else:876                new_list.extend(self._flatten_equivalencies_list(el))877        return new_list878def check_units(879    func=None, checks_on_return: Dict[str, Any] = None, **checks: Dict[str, Any]880):881    """882    A decorator to 'check' -- limit/control -- the units of input and return883    arguments to a function or method.884    Parameters885    ----------886    func:887        The function to be decorated888    checks_on_return: list of astropy :mod:`~astropy.units` or dict of unit specifications889        Specifications for unit checks on the return of the function being wrapped.890        (see `check units`_ for valid specifications)891    **checks: list of astropy :mod:`~astropy.units` or dict of unit specifications892        Specifications for unit checks on the input arguments of the function893        being wrapped.  
Each keyword argument in `checks` is the name of a function894        argument to be checked and the keyword value contains the unit check895        specifications.896        .. _`check units`:897        Unit checks can be defined by passing one of the astropy898        :mod:`~astropy.units`, a list of astropy units, or a dictionary containing899        the keys defined below.  Units can also be defined with function900        annotations, but must be consistent with decorator `**checks` arguments if901        used concurrently. If a key is omitted, then the default value will be assumed.902        ====================== ======= ================================================903        Key                    Type    Description904        ====================== ======= ================================================905        units                          list of desired astropy :mod:`~astropy.units`906        equivalencies                  | [DEFAULT `None`] A list of equivalent pairs to907                                         try if908                                       | the units are not directly convertible.909                                       | (see :mod:`~astropy.units.equivalencies`,910                                         and/or `astropy equivalencies`_)911        pass_equivalent_units  `bool`  | [DEFAULT `False`] allow equivalent units912                                       | to pass913        ====================== ======= ================================================914    Notes915    -----916    * Checking of function arguments `*args` and `**kwargs` is not supported.917    * Decorator does NOT perform any unit conversions, look to918      :func:`~plasmapy.utils.decorators.validate_quantities` if that functionality is919      desired.920    * If it is desired that `None` values do not raise errors or warnings, then921      include `None` in the list of units or as a default value for the function922      argument.923    * If units 
are not specified in `checks`, then the decorator will attempt924      to identify desired units by examining the function annotations.925    * Full functionality is defined by the class :class:`CheckUnits`.926    Examples927    --------928    Define units with decorator parameters::929        import astropy.units as u930        from plasmapy.utils.decorators import check_units931        @check_units(arg1={'units': u.cm},932                     arg2=u.cm,933                     checks_on_return=[u.cm, u.km])934        def foo(arg1, arg2):935            return arg1 + arg2936        # or on a method937        class Foo:938            @check_units(arg1={'units': u.cm},939                         arg2=u.cm,940                         checks_on_return=[u.cm, u.km])941            def bar(self, arg1, arg2):942                return arg1 + arg2943    Define units with function annotations::944        import astropy.units as u945        from plasmapy.utils.decorators import check_units946        @check_units947        def foo(arg1: u.cm, arg2: u.cm) -> u.cm:948            return arg1 + arg2949        # or on a method950        class Foo:951            @check_units952            def bar(self, arg1: u.cm, arg2: u.cm) -> u.cm:953                return arg1 + arg2954    Allow `None` values to pass::955        import astropy.units as u956        from plasmapy.utils.decorators import check_units957        @check_units(checks_on_return=[u.cm, None])958        def foo(arg1: u.cm = None):959            return arg1960    Allow return values to have equivalent units::961        import astropy.units as u962        from plasmapy.utils.decorators import check_units963        @check_units(arg1={'units': u.cm},964                     checks_on_return={'units': u.km,965                                       'pass_equivalent_units': True})966        def foo(arg1):967            return arg1968    Allow equivalent units to pass with specified equivalencies::969        import astropy.units as 
u970        from plasmapy.utils.decorators import check_units971        @check_units(arg1={'units': u.K,972                           'equivalencies': u.temperature(),973                           'pass_equivalent_units': True})974        def foo(arg1):975            return arg1976    .. _astropy equivalencies:977        https://docs.astropy.org/en/stable/units/equivalencies.html978    """979    if checks_on_return is not None:980        checks["checks_on_return"] = checks_on_return981    if func is not None:982        # `check_units` called as a function983        return CheckUnits(**checks)(func)984    else:985        # `check_units` called as a decorator "sugar-syntax"986        return CheckUnits(**checks)987def check_values(988    func=None, checks_on_return: Dict[str, bool] = None, **checks: Dict[str, bool]989):990    """991    A decorator to 'check' -- limit/control -- the values of input and return992    arguments to a function or method.993    Parameters994    ----------995    func:996        The function to be decorated997    checks_on_return: Dict[str, bool]998        Specifications for value checks on the return of the function being wrapped.999        (see `check values`_ for valid specifications)1000    **checks: Dict[str, Dict[str, bool]]1001        Specifications for value checks on the input arguments of the function1002        being wrapped.  Each keyword argument in `checks` is the name of a function1003        argument to be checked and the keyword value contains the value check1004        specifications.1005        .. _`check values`:1006        The value check specifications are defined within a dictionary containing1007        the keys defined below.  
If the dictionary is empty or omitting keys,1008        then the default value will be assumed for the missing keys.1009        ================ ======= ================================================1010        Key              Type    Description1011        ================ ======= ================================================1012        can_be_negative  `bool`  [DEFAULT `True`] values can be negative1013        can_be_complex   `bool`  [DEFAULT `False`] values can be complex numbers1014        can_be_inf       `bool`  [DEFAULT `True`] values can be :data:`~numpy.inf`1015        can_be_nan       `bool`  [DEFAULT `True`] values can be :data:`~numpy.nan`1016        none_shall_pass  `bool`  [DEFAULT `False`] values can be a python `None`1017        ================ ======= ================================================1018    Notes1019    -----1020    * Checking of function arguments `*args` and `**kwargs` is not supported.1021    * Full functionality is defined by the class :class:`CheckValues`.1022    Examples1023    --------1024    .. 
code-block:: python1025        from plasmapy.utils.decorators import check_values1026        @check_values(arg1={'can_be_negative': False, 'can_be_nan': False},1027                      arg2={'can_be_inf': False},1028                      checks_on_return={'none_shall_pass': True)1029        def foo(arg1, arg2):1030            return None1031        # on a method1032        class Foo:1033            @check_values(arg1={'can_be_negative': False, 'can_be_nan': False},1034                          arg2={'can_be_inf': False},1035                          checks_on_return={'none_shall_pass': True)1036            def bar(self, arg1, arg2):1037                return None1038    """1039    if checks_on_return is not None:1040        checks["checks_on_return"] = checks_on_return1041    if func is not None:1042        # `check_values` called as a function1043        return CheckValues(**checks)(func)1044    else:1045        # `check_values` called as a decorator "sugar-syntax"1046        return CheckValues(**checks)1047def check_relativistic(func=None, betafrac=0.05):1048    r"""1049    Warns or raises an exception when the output of the decorated1050    function is greater than `betafrac` times the speed of light.1051    Parameters1052    ----------1053    func : `function`, optional1054        The function to decorate.1055    betafrac : float, optional1056        The minimum fraction of the speed of light that will raise a1057        `~plasmapy.utils.RelativityWarning`. 
Defaults to 5%.1058    Returns1059    -------1060    function1061        Decorated function.1062    Raises1063    ------1064    TypeError1065        If `V` is not a `~astropy.units.Quantity`.1066    ~astropy.units.UnitConversionError1067        If `V` is not in units of velocity.1068    ValueError1069        If `V` contains any `~numpy.nan` values.1070    ~plasmapy.utils.exceptions.RelativityError1071        If `V` is greater than or equal to the speed of light.1072    Warns1073    -----1074    : `~plasmapy.utils.exceptions.RelativityWarning`1075        If `V` is greater than or equal to `betafrac` times the speed of light,1076        but less than the speed of light.1077    Examples1078    --------1079    >>> from astropy import units as u1080    >>> @check_relativistic1081    ... def speed():1082    ...     return 1 * u.m / u.s1083    Passing in a custom `betafrac`:1084    >>> @check_relativistic(betafrac=0.01)1085    ... def speed():1086    ...     return 1 * u.m / u.s1087    """1088    def decorator(f):1089        @preserve_signature1090        @functools.wraps(f)1091        def wrapper(*args, **kwargs):1092            return_ = f(*args, **kwargs)1093            _check_relativistic(return_, f.__name__, betafrac=betafrac)1094            return return_1095        return wrapper1096    if func:1097        return decorator(func)1098    return decorator1099def _check_relativistic(V, funcname, betafrac=0.05):1100    r"""1101    Warn or raise error for relativistic or superrelativistic1102    velocities.1103    Parameters1104    ----------1105    V : ~astropy.units.Quantity1106        A velocity.1107    funcname : str1108        The name of the original function to be printed in the error1109        messages.1110    betafrac : float, optional1111        The minimum fraction of the speed of light that will generate1112        a warning. 
Defaults to 5%.1113    Raises1114    ------1115    TypeError1116        If `V` is not a `~astropy.units.Quantity`.1117    ~astropy.units.UnitConversionError1118        If `V` is not in units of velocity.1119    ValueError1120        If `V` contains any `~numpy.nan` values.1121    RelativityError1122        If `V` is greater than or equal to the speed of light.1123    Warns1124    -----1125    ~plasmapy.utils.RelativityWarning1126        If `V` is greater than or equal to the specified fraction of the1127        speed of light.1128    Examples1129    --------1130    >>> from astropy import units as u1131    >>> _check_relativistic(1*u.m/u.s, 'function_calling_this')1132    """1133    # TODO: Replace `funcname` with func.__name__?1134    errmsg = "V must be a Quantity with units of velocity in _check_relativistic"1135    if not isinstance(V, u.Quantity):1136        raise TypeError(errmsg)1137    try:1138        V_over_c = (V / c).to_value(u.dimensionless_unscaled)1139    except Exception:1140        raise u.UnitConversionError(errmsg)1141    beta = np.max(np.abs((V_over_c)))1142    if beta == np.inf:1143        raise RelativityError(f"{funcname} is yielding an infinite velocity.")1144    elif beta >= 1:1145        raise RelativityError(1146            f"{funcname} is yielding a velocity that is {str(round(beta, 3))} "1147            f"times the speed of light."1148        )1149    elif beta >= betafrac:1150        warnings.warn(1151            f"{funcname} is yielding a velocity that is "1152            f"{str(round(beta * 100, 3))}% of the speed of "1153            f"light. Relativistic effects may be important.",1154            RelativityWarning,...test_checks.py
Source:test_checks.py  
1"""2Tests for 'check` decorators (i.e. decorators that only check objects but do not3change them).4"""5import inspect6import numpy as np7import pytest8from astropy import units as u9from astropy.constants import c10from types import LambdaType11from typing import Any, Dict12from unittest import mock13from plasmapy.utils.decorators.checks import (14    _check_relativistic,15    check_relativistic,16    check_units,17    check_values,18    CheckBase,19    CheckUnits,20    CheckValues,21)22from plasmapy.utils.exceptions import (23    PlasmaPyWarning,24    RelativityError,25    RelativityWarning,26)27# ----------------------------------------------------------------------------------------28# Test Decorator class `CheckBase`29# ----------------------------------------------------------------------------------------30class TestCheckBase:31    """32    Test for decorator class :class:`~plasmapy.utils.decorators.checks.CheckBase`.33    """34    def test_for_members(self):35        assert hasattr(CheckUnits, "checks")36    def test_checks(self):37        _cases = [38            {"input": (None, {"x": 1, "y": 2}), "output": {"x": 1, "y": 2}},39            {40                "input": (6, {"x": 1, "y": 2}),41                "output": {"x": 1, "y": 2, "checks_on_return": 6},42            },43        ]44        for case in _cases:45            cb = CheckBase(checks_on_return=case["input"][0], **case["input"][1])46            assert cb.checks == case["output"]47# ----------------------------------------------------------------------------------------48# Test Decorator class `CheckValues` and decorator `check_values`49# ----------------------------------------------------------------------------------------50class TestCheckUnits:51    """52    Tests for decorator :func:`~plasmapy.utils.decorators.checks.check_units` and53    decorator class :class:`~plasmapy.utils.decorators.checks.CheckUnits`.54    """55    check_defaults = CheckUnits._CheckUnits__check_defaults  # type: 
Dict[str, Any]56    @staticmethod57    def foo_no_anno(x, y):58        return x + y59    @staticmethod60    def foo_partial_anno(x: u.Quantity, y: u.cm) -> u.Quantity:61        return x.value + y.value62    @staticmethod63    def foo_return_anno(x, y) -> u.um:64        return x.value + y.value65    @staticmethod66    def foo_stars(x: u.Quantity, *args, y=3 * u.cm, **kwargs):67        return x.value + y.value68    @staticmethod69    def foo_with_none(x: u.Quantity, y: u.cm = None):70        return x.value + y.value71    def test_inheritance(self):72        assert issubclass(CheckUnits, CheckBase)73    def test_cu_default_check_values(self):74        """Test the default check dictionary for CheckUnits."""75        cu = CheckUnits()76        assert hasattr(cu, "_CheckUnits__check_defaults")77        assert isinstance(cu._CheckUnits__check_defaults, dict)78        _defaults = [79            ("units", None),80            ("equivalencies", None),81            ("pass_equivalent_units", False),82            ("none_shall_pass", False),83        ]84        for key, val in _defaults:85            assert cu._CheckUnits__check_defaults[key] == val86    def test_cu_method__flatten_equivalencies_list(self):87        assert hasattr(CheckUnits, "_flatten_equivalencies_list")88        cu = CheckUnits()89        pairs = [([1, 2, 4], [1, 2, 4]), ([1, 2, (3, 4), [5, 6]], [1, 2, (3, 4), 5, 6])]90        for pair in pairs:91            assert cu._flatten_equivalencies_list(pair[0]) == pair[1]92    def test_cu_method__condition_target_units(self):93        """Test method `CheckUnits._condition_target_units`."""94        assert hasattr(CheckUnits, "_condition_target_units")95        cu = CheckUnits()96        targets = ["cm", u.km, u.Quantity, float]97        conditioned_targets = [u.cm, u.km]98        with pytest.raises(TypeError):99            cu._condition_target_units(targets)100        assert (101            cu._condition_target_units(targets, from_annotations=True)102            == 
conditioned_targets103        )104        with pytest.raises(ValueError):105            cu._condition_target_units(["five"])106    def test_cu_method__normalize_equivalencies(self):107        """Test method `CheckUnits._normalize_equivalencies`."""108        assert hasattr(CheckUnits, "_normalize_equivalencies")109        cu = CheckUnits()110        assert cu._normalize_equivalencies(None) == []111        # 2 element equivalency112        norme = cu._normalize_equivalencies([(u.cm, u.cm)])113        assert len(norme) == 1114        assert isinstance(norme[0], tuple)115        assert len(norme[0]) == 4116        assert norme[0][0] == norme[0][1]117        assert norme[0][2] == norme[0][3]118        assert isinstance(norme[0][2], LambdaType)119        assert norme[0][1] == norme[0][2](norme[0][0])120        assert norme[0][0] == norme[0][3](norme[0][1])121        # 3 element equivalency122        norme = cu._normalize_equivalencies([(u.cm, u.cm, lambda x: x)])123        assert len(norme) == 1124        assert isinstance(norme[0], tuple)125        assert len(norme[0]) == 4126        assert norme[0][0] == norme[0][1]127        assert norme[0][2] == norme[0][3]128        assert isinstance(norme[0][2], LambdaType)129        assert norme[0][1] == norme[0][2](norme[0][0])130        assert norme[0][0] == norme[0][3](norme[0][1])131        # 3 element equivalency132        norme = cu._normalize_equivalencies(133            [(u.K, u.deg_C, lambda x: x - 273.15, lambda x: x + 273.15)]134        )135        assert len(norme) == 1136        assert isinstance(norme[0], tuple)137        assert len(norme[0]) == 4138        assert norme[0][0] == u.K139        assert norme[0][1] == u.deg_C140        assert isinstance(norme[0][2], LambdaType)141        assert isinstance(norme[0][3], LambdaType)142        for val in [-20.0, 50.0, 195.0]:143            assert norme[0][2](val) == (lambda x: x - 273.15)(val)144            assert norme[0][3](val) == (lambda x: x + 273.15)(val)145        # 
not a 2, 3, or 4-tuple146        with pytest.raises(ValueError):147            cu._normalize_equivalencies([(u.cm,)])148        # input is not a astropy.unit.Unit149        with pytest.raises(ValueError):150            cu._normalize_equivalencies([("cm", u.cm)])151    def test_cu_method__get_unit_checks(self):152        """153        Test functionality/behavior of the method `_get_unit_checks` on `CheckUnits`.154        This method reviews the decorator `checks` arguments and wrapped function155        annotations to build a complete checks dictionary.156        """157        # methods must exist158        assert hasattr(CheckUnits, "_get_unit_checks")159        # setup default checks160        default_checks = {161            **self.check_defaults.copy(),162            "units": [self.check_defaults["units"]],163        }164        # setup test cases165        # 'setup' = arguments for `_get_unit_checks`166        # 'output' = expected return from `_get_unit_checks`167        # 'raises' = if `_get_unit_checks` raises an Exception168        # 'warns' = if `_get_unit_checks` issues a warning169        #170        equivs = [171            # list of astropy Equivalency objects172            [u.temperature_energy(), u.temperature()],173            # list of equivalencies (pre astropy v3.2.1 style)174            list(u.temperature()),175        ]176        _cases = [177            {178                "descr": "x units are defined via decorator kwarg of CheckUnits\n"179                "y units are defined via decorator annotations, additional\n"180                "  checks thru CheckUnits kwarg",181                "setup": {182                    "function": self.foo_partial_anno,183                    "args": (2 * u.cm, 3 * u.cm),184                    "kwargs": {},185                    "checks": {"x": {"units": [u.cm], "equivalencies": equivs[0][0]}},186                },187                "output": {188                    "x": {"units": [u.cm], "equivalencies": 
equivs[0][0]},189                    "y": {"units": [u.cm]},190                },191            },192            {193                "descr": "x units are defined via decorator kwarg of CheckUnits\n"194                "y units are defined via function annotations, additional\n"195                "  checks thru CheckUnits kwarg",196                "setup": {197                    "function": self.foo_partial_anno,198                    "args": (2 * u.cm, 3 * u.cm),199                    "kwargs": {},200                    "checks": {201                        "x": {"units": [u.cm], "equivalencies": equivs[0]},202                        "y": {"pass_equivalent_units": False},203                    },204                },205                "output": {206                    "x": {207                        "units": [u.cm],208                        "equivalencies": equivs[0][0] + equivs[0][1],209                    },210                    "y": {"units": [u.cm], "pass_equivalent_units": False},211                },212            },213            {214                "descr": "equivalencies are a list instead of astropy Equivalency objects",215                "setup": {216                    "function": self.foo_no_anno,217                    "args": (2 * u.K, 3 * u.K),218                    "kwargs": {},219                    "checks": {220                        "x": {"units": [u.K], "equivalencies": equivs[1][0]},221                        "y": {"units": [u.K], "equivalencies": equivs[1]},222                    },223                },224                "output": {225                    "x": {"units": [u.K], "equivalencies": [equivs[1][0]]},226                    "y": {"units": [u.K], "equivalencies": equivs[1]},227                },228            },229            {230                "descr": "number of checked arguments exceed number of function arguments",231                "setup": {232                    "function": self.foo_partial_anno,233                    
"args": (2 * u.cm, 3 * u.cm),234                    "kwargs": {},235                    "checks": {236                        "x": {"units": [u.cm]},237                        "y": {"units": [u.cm]},238                        "z": {"units": [u.cm]},239                    },240                },241                "warns": PlasmaPyWarning,242                "output": {"x": {"units": [u.cm]}, "y": {"units": [u.cm]}},243            },244            {245                "descr": "arguments passed via *args and **kwargs are ignored",246                "setup": {247                    "function": self.foo_stars,248                    "args": (2 * u.cm, "hello"),249                    "kwargs": {"z": None},250                    "checks": {251                        "x": {"units": [u.cm]},252                        "y": {"units": [u.cm]},253                        "z": {"units": [u.cm]},254                    },255                },256                "warns": PlasmaPyWarning,257                "output": {"x": {"units": [u.cm]}, "y": {"units": [u.cm]}},258            },259            {260                "descr": "arguments can be None values",261                "setup": {262                    "function": self.foo_with_none,263                    "args": (2 * u.cm, 3 * u.cm),264                    "kwargs": {},265                    "checks": {"x": {"units": [u.cm, None]}},266                },267                "output": {268                    "x": {"units": [u.cm], "none_shall_pass": True},269                    "y": {"units": [u.cm], "none_shall_pass": True},270                },271            },272            {273                "descr": "checks and annotations do not specify units",274                "setup": {275                    "function": self.foo_no_anno,276                    "args": (2 * u.cm, 3 * u.cm),277                    "kwargs": {},278                    "checks": {"x": {"pass_equivalent_units": True}},279                },280                "raises": 
ValueError,281            },282            {283                "descr": "units are directly assigned to the check kwarg",284                "setup": {285                    "function": self.foo_partial_anno,286                    "args": (2 * u.cm, 3 * u.cm),287                    "kwargs": {},288                    "checks": {"x": u.cm},289                },290                "output": {"x": {"units": [u.cm]}, "y": {"units": [u.cm]}},291            },292            {293                "descr": "return units are assigned via checks",294                "setup": {295                    "function": self.foo_no_anno,296                    "args": (2 * u.km, 3 * u.km),297                    "kwargs": {},298                    "checks": {"checks_on_return": u.km},299                },300                "output": {"checks_on_return": {"units": [u.km]}},301            },302            {303                "descr": "return units are assigned via annotations",304                "setup": {305                    "function": self.foo_return_anno,306                    "args": (2 * u.cm, 3 * u.cm),307                    "kwargs": {},308                    "checks": {},309                },310                "output": {"checks_on_return": {"units": [u.um]}},311            },312            {313                "descr": "return units are assigned via annotations and checks arg, but"314                "are not consistent",315                "setup": {316                    "function": self.foo_return_anno,317                    "args": (2 * u.cm, 3 * u.cm),318                    "kwargs": {},319                    "checks": {"checks_on_return": {"units": u.km}},320                },321                "raises": ValueError,322            },323            {324                "descr": "return units are not specified but other checks are",325                "setup": {326                    "function": self.foo_no_anno,327                    "args": (2 * u.cm, 3 * u.cm),328                  
  "kwargs": {},329                    "checks": {"checks_on_return": {"pass_equivalent_units": True}},330                },331                "raises": ValueError,332            },333            {334                "descr": "no parameter checks for x are defined, but a non-unit annotation"335                "is used",336                "setup": {337                    "function": self.foo_partial_anno,338                    "args": (2 * u.cm, 3 * u.cm),339                    "kwargs": {},340                    "checks": {},341                },342                "output": {"y": {"units": [u.cm]}},343            },344            {345                "descr": "parameter checks defined for x but unit checks calculated from"346                "function annotations. Function annotations do NOT define "347                "a proper unit type.",348                "setup": {349                    "function": self.foo_partial_anno,350                    "args": (2 * u.cm, 3 * u.cm),351                    "kwargs": {},352                    "checks": {"x": {"pass_equivalent_units": True}},353                },354                "raises": ValueError,355            },356            {357                "descr": "parameter checks defined for return argument but unit checks "358                "calculated from function annotations. 
Function annotations do "359                "NOT define a proper unit type.",360                "setup": {361                    "function": self.foo_partial_anno,362                    "args": (2 * u.cm, 3 * u.cm),363                    "kwargs": {},364                    "checks": {"checks_on_return": {"pass_equivalent_units": True}},365                },366                "raises": ValueError,367            },368        ]369        # perform tests370        for ii, case in enumerate(_cases):371            sig = inspect.signature(case["setup"]["function"])372            bound_args = sig.bind(*case["setup"]["args"], **case["setup"]["kwargs"])373            cu = CheckUnits(**case["setup"]["checks"])374            cu.f = case["setup"]["function"]375            if "warns" in case:376                with pytest.warns(case["warns"]):377                    checks = cu._get_unit_checks(bound_args)378            elif "raises" in case:379                with pytest.raises(case["raises"]):380                    cu._get_unit_checks(bound_args)381                continue382            else:383                checks = cu._get_unit_checks(bound_args)384            # only expected argument checks exist385            assert sorted(checks.keys()) == sorted(case["output"].keys())386            # if check key-value not specified then default is assumed387            for arg_name in case["output"].keys():388                arg_checks = checks[arg_name]389                for key in default_checks.keys():390                    if key in case["output"][arg_name]:391                        val = case["output"][arg_name][key]392                    else:393                        val = default_checks[key]394                    assert arg_checks[key] == val395    def test_cu_method__check_unit(self):396        """397        Test functionality/behavior of the methods `_check_unit` and `_check_unit_core`398        on `CheckUnits`.  
These methods do the actual checking of the argument units399        and should be called by `CheckUnits.__call__()`.400        """401        # methods must exist402        assert hasattr(CheckUnits, "_check_unit")403        assert hasattr(CheckUnits, "_check_unit_core")404        # setup default checks405        check = {**self.check_defaults, "units": [u.cm]}406        # check = self.check_defaults.copy()407        # check['units'] = [u.cm]408        # check['equivalencies'] = [None]409        # make a class w/ improper units410        class MyQuantity:411            unit = None412        # setup test cases413        # 'input' = arguments for `_check_unit_core` and `_check_unit`414        # 'output' = expected return from `_check_unit_core`415        #416        # add cases for 'units' checks417        _cases = [418            # argument does not have units419            {"input": (5.0, "arg", check), "output": (None, None, None, TypeError)},420            # argument does match desired units421            # * set arg_name = 'checks_on_return' to cover if-else statement422            #   in initializing error string423            {424                "input": (5.0 * u.kg, "checks_on_return", check),425                "output": (None, None, None, u.UnitTypeError),426            },427            # argument has equivalent but not matching unit428            {429                "input": (5.0 * u.km, "arg", check),430                "output": (5.0 * u.km, u.cm, None, u.UnitTypeError),431            },432            # argument is equivalent to many specified units but exactly matches one433            {434                "input": (5.0 * u.km, "arg", {**check, "units": [u.cm, u.km]}),435                "output": (5.0 * u.km, u.km, None, None),436            },437            # argument is equivalent to many specified units and438            # does NOT exactly match one439            {440                "input": (5.0 * u.m, "arg", {**check, "units": [u.cm, u.km]}),441       
         "output": (None, None, None, u.UnitTypeError),442            },443            # argument has attr unit but unit does not have is_equivalent444            {445                "input": (MyQuantity, "arg", check),446                "output": (None, None, None, TypeError),447            },448        ]449        # add cases for 'none_shall_pass' checks450        _cases.extend(451            [452                # argument is None and none_shall_pass = False453                {454                    "input": (None, "arg", {**check, "none_shall_pass": False}),455                    "output": (None, None, None, ValueError),456                },457                # argument is None and none_shall_pass = True458                {459                    "input": (None, "arg", {**check, "none_shall_pass": True}),460                    "output": (None, None, None, None),461                },462            ]463        )464        # add cases for 'pass_equivalent_units' checks465        _cases.extend(466            [467                # argument is equivalent to 1 to unit,468                # does NOT exactly match the unit,469                # and 'pass_equivalent_units' = True and argument470                {471                    "input": (472                        5.0 * u.km,473                        "arg",474                        {**check, "pass_equivalent_units": True},475                    ),476                    "output": (5.0 * u.km, u.cm, None, None),477                },478                # argument is equivalent to more than 1 unit,479                # does NOT exactly match any unit,480                # and 'pass_equivalent_units' = True and argument481                {482                    "input": (483                        5.0 * u.km,484                        "arg",485                        {**check, "units": [u.cm, u.m], "pass_equivalent_units": True},486                    ),487                    "output": (5.0 * u.km, None, None, None),488      
          },489            ]490        )491        # setup wrapped function492        cu = CheckUnits()493        cu.f = self.foo_no_anno494        # perform tests495        for ii, case in enumerate(_cases):496            arg, arg_name, arg_checks = case["input"]497            _results = cu._check_unit_core(arg, arg_name, arg_checks)498            assert _results[0:3] == case["output"][0:3]499            if _results[3] is None:500                assert _results[3] is case["output"][3]501                assert cu._check_unit(arg, arg_name, arg_checks) is None502            else:503                assert isinstance(_results[3], case["output"][3])504                with pytest.raises(case["output"][3]):505                    cu._check_unit(arg, arg_name, arg_checks)506    def test_cu_called_as_decorator(self):507        """508        Test behavior of `CheckUnits.__call__` (i.e. used as a decorator).509        """510        # setup test cases511        # 'setup' = arguments for `CheckUnits` and wrapped function512        # 'output' = expected return from wrapped function513        # 'raises' = if an Exception is expected to be raised514        # 'warns' = if a warning is expected to be issued515        #516        _cases = [517            # clean execution518            {519                "setup": {520                    "function": self.foo_no_anno,521                    "args": (2 * u.cm, 3 * u.cm),522                    "kwargs": {},523                    "checks": {"x": u.cm, "y": u.cm, "checks_on_return": u.cm},524                },525                "output": 5 * u.cm,526            },527            # argument fails checks528            {529                "setup": {530                    "function": self.foo_no_anno,531                    "args": (2 * u.cm, 3 * u.cm),532                    "kwargs": {},533                    "checks": {"x": u.g, "y": u.cm, "checks_on_return": u.cm},534                },535                "raises": u.UnitTypeError,536           
 },537            # return fails checks538            {539                "setup": {540                    "function": self.foo_no_anno,541                    "args": (2 * u.cm, 3 * u.cm),542                    "kwargs": {},543                    "checks": {"x": u.cm, "y": u.cm, "checks_on_return": u.km},544                },545                "raises": u.UnitTypeError,546            },547        ]548        # test549        for case in _cases:550            wfoo = CheckUnits(**case["setup"]["checks"])(case["setup"]["function"])551            args = case["setup"]["args"]552            kwargs = case["setup"]["kwargs"]553            if "raises" in case:554                with pytest.raises(case["raises"]):555                    wfoo(*args, **kwargs)556            else:557                assert wfoo(*args, **kwargs) == case["output"]558        # test on class method559        class Foo:560            @CheckUnits()561            def __init__(self, y: u.cm):562                self.y = y563            @CheckUnits(x=u.cm)564            def bar(self, x) -> u.cm:565                return x + self.y566        foo = Foo(10.0 * u.cm)567        assert foo.bar(-3 * u.cm) == 7 * u.cm568    def test_cu_preserves_signature(self):569        """Test `CheckValues` preserves signature of wrapped function."""570        # I'd like to directly test the @preserve_signature is used (??)571        wfoo = CheckUnits()(self.foo_no_anno)572        assert hasattr(wfoo, "__signature__")573        assert wfoo.__signature__ == inspect.signature(self.foo_no_anno)574    @mock.patch(575        CheckUnits.__module__ + "." 
+ CheckUnits.__qualname__,576        side_effect=CheckUnits,577        autospec=True,578    )579    def test_decorator_func_def(self, mock_cu_class):580        """581        Test that :func:`~plasmapy.utils.decorators.checks.check_units` is582        properly defined.583        """584        # create mock function (mock_foo) from function to mock (self.foo_no_anno)585        mock_foo = mock.Mock(586            side_effect=self.foo_no_anno, name="mock_foo", autospec=True587        )588        mock_foo.__name__ = "mock_foo"589        mock_foo.__signature__ = inspect.signature(self.foo_no_anno)590        # setup test cases591        # 'setup' = arguments for `check_units` and wrapped function592        # 'output' = expected return from wrapped function593        # 'raises' = a raised Exception is expected594        # 'warns' = an issued warning is expected595        #596        _cases = [597            # only argument checks598            {599                "setup": {600                    "args": (2 * u.cm, 3 * u.cm),601                    "kwargs": {},602                    "checks": {"x": u.cm, "y": u.cm},603                },604                "output": 5 * u.cm,605            },606            # argument and return checks607            {608                "setup": {609                    "args": (2 * u.cm, 3 * u.cm),610                    "kwargs": {},611                    "checks": {"x": u.cm, "checks_on_return": u.cm},612                },613                "output": 5 * u.cm,614            },615        ]616        for case in _cases:617            for ii in range(2):618                # decorate619                if ii == 0:620                    # functional decorator call621                    wfoo = check_units(mock_foo, **case["setup"]["checks"])622                elif ii == 1:623                    # sugar decorator call624                    #625                    #  @check_units(x=check)626                    #      def foo(x):627                    # 
         return x628                    #629                    wfoo = check_units(**case["setup"]["checks"])(mock_foo)630                else:631                    continue632                # test633                args = case["setup"]["args"]634                kwargs = case["setup"]["kwargs"]635                assert wfoo(*args, **kwargs) == case["output"]636                assert mock_cu_class.called637                assert mock_foo.called638                assert mock_cu_class.call_args[0] == ()639                assert sorted(mock_cu_class.call_args[1].keys()) == sorted(640                    case["setup"]["checks"].keys()641                )642                for arg_name, checks in case["setup"]["checks"].items():643                    assert mock_cu_class.call_args[1][arg_name] == checks644                # reset645                mock_cu_class.reset_mock()646                mock_foo.reset_mock()647# ----------------------------------------------------------------------------------------648# Test Decorator class `CheckValues` and decorator `check_values`649# ----------------------------------------------------------------------------------------650class TestCheckValues:651    """652    Tests for decorator :func:`~plasmapy.utils.decorators.checks.check_values` and653    decorator class :class:`~plasmapy.utils.decorators.checks.CheckValues`.654    """655    check_defaults = CheckValues._CheckValues__check_defaults  # type: Dict[str, bool]656    @staticmethod657    def foo(x, y):658        return x + y659    @staticmethod660    def foo_stars(x, *args, y=3, **kwargs):661        return x + y662    def test_inheritance(self):663        assert issubclass(CheckValues, CheckBase)664    def test_cv_default_check_values(self):665        """Test the default check dictionary for CheckValues"""666        cv = CheckValues()667        assert hasattr(cv, "_CheckValues__check_defaults")668        assert isinstance(cv._CheckValues__check_defaults, dict)669        _defaults 
= [670            ("can_be_negative", True),671            ("can_be_complex", False),672            ("can_be_inf", True),673            ("can_be_nan", True),674            ("none_shall_pass", False),675        ]676        for key, val in _defaults:677            assert cv._CheckValues__check_defaults[key] == val678    def test_cv_method__get_value_checks(self):679        """680        Test functionality/behavior of the method `_get_value_checks` on `CheckValues`.681        This method reviews the decorator `checks` arguments to build a complete682        checks dictionary.683        """684        # methods must exist685        assert hasattr(CheckValues, "_get_value_checks")686        # setup default checks687        default_checks = self.check_defaults.copy()688        # setup test cases689        # 'setup' = arguments for `_get_value_checks`690        # 'output' = expected return from `_get_value_checks`691        # 'raises' = if `_get_value_checks` raises an Exception692        # 'warns' = if `_get_value_checks` issues a warning693        #694        _cases = [695            # define some checks696            {697                "setup": {698                    "function": self.foo,699                    "args": (2, 3),700                    "kwargs": {},701                    "checks": {702                        "x": {703                            "can_be_negative": False,704                            "can_be_complex": True,705                            "can_be_inf": False,706                        },707                        "checks_on_return": {708                            "can_be_nan": False,709                            "none_shall_pass": True,710                        },711                    },712                },713                "output": {714                    "x": {715                        "can_be_negative": False,716                        "can_be_complex": True,717                        "can_be_inf": False,718                    
},719                    "checks_on_return": {"can_be_nan": False, "none_shall_pass": True},720                },721            },722            # arguments passed via *args and **kwargs are ignored723            {724                "setup": {725                    "function": self.foo_stars,726                    "args": (2, "hello"),727                    "kwargs": {"z": None},728                    "checks": {729                        "x": {"can_be_negative": False},730                        "y": {"can_be_inf": False},731                        "z": {"none_shall_pass": True},732                    },733                },734                "output": {"x": {"can_be_negative": False}, "y": {"can_be_inf": False}},735                "warns": PlasmaPyWarning,736            },737            # check argument is not a dictionary (default is assumed)738            {739                "setup": {740                    "function": self.foo,741                    "args": (2, 3),742                    "kwargs": {},743                    "checks": {"x": u.cm},744                },745                "output": {"x": {}},746            },747        ]748        # perform tests749        for case in _cases:750            sig = inspect.signature(case["setup"]["function"])751            args = case["setup"]["args"]752            kwargs = case["setup"]["kwargs"]753            bound_args = sig.bind(*args, **kwargs)754            cv = CheckValues(**case["setup"]["checks"])755            cv.f = case["setup"]["function"]756            if "warns" in case:757                with pytest.warns(case["warns"]):758                    checks = cv._get_value_checks(bound_args)759            elif "raises" in case:760                with pytest.raises(case["raises"]):761                    cv._get_value_checks(bound_args)762                continue763            else:764                checks = cv._get_value_checks(bound_args)765            # only expected keys exist766            assert 
sorted(checks.keys()) == sorted(case["output"].keys())767            # if check key-value not specified then default is assumed768            for arg_name in case["output"].keys():769                arg_checks = checks[arg_name]770                for key in default_checks.keys():771                    if key in case["output"][arg_name]:772                        val = case["output"][arg_name][key]773                    else:774                        val = default_checks[key]775                    assert arg_checks[key] == val776    def test_cv_method__check_value(self):777        """778        Test functionality/behavior of the `_check_value` method on `CheckValues`.779        This method does the actual checking of the argument values and should be780        called by `CheckValues.__call__()`.781        """782        # setup wrapped function783        cv = CheckValues()784        wfoo = cv(self.foo)785        # methods must exist786        assert hasattr(cv, "_check_value")787        # setup default checks788        default_checks = self.check_defaults.copy()789        # setup test cases790        # 'setup' = arguments for `CheckUnits` and wrapped function791        # 'raises' = if an Exception is expected to be raised792        # 'warns' = if a warning is expected to be issued793        #794        _cases = [795            # tests for check 'can_be_negative'796            {797                "input": {798                    "args": [799                        -5,800                        -5.0,801                        np.array([-1, 2]),802                        np.array([-3.0, 2.0]),803                        -3 * u.cm,804                        np.array([-4.0, 3.0]) * u.kg,805                    ],806                    "arg_name": "arg",807                    "checks": {**default_checks, "can_be_negative": False},808                },809                "raises": ValueError,810            },811            {812                "input": {813                    
"args": [814                        -5,815                        -5.0,816                        np.array([-1, 2]),817                        np.array([-3.0, 2.0]),818                        -3 * u.cm,819                        np.array([-4.0, 3.0]) * u.kg,820                    ],821                    "arg_name": "arg",822                    "checks": {**default_checks, "can_be_negative": True},823                }824            },825            # tests for check 'can_be_complex'826            {827                "input": {828                    "args": [829                        complex(5),830                        complex(2, 3),831                        np.complex(3.0),832                        complex(4.0, 2.0) * u.cm,833                        np.array([complex(4, 5), complex(1)]) * u.kg,834                    ],835                    "arg_name": "checks_on_return",836                    "checks": {**default_checks, "can_be_complex": False},837                },838                "raises": ValueError,839            },840            {841                "input": {842                    "args": [843                        complex(5),844                        complex(2, 3),845                        np.complex(3.0),846                        complex(4.0, 2.0) * u.cm,847                        np.array([complex(4, 5), complex(1)]) * u.kg,848                    ],849                    "arg_name": "checks_on_return",850                    "checks": {**default_checks, "can_be_complex": True},851                }852            },853            # tests for check 'can_be_inf'854            {855                "input": {856                    "args": [857                        np.inf,858                        np.inf * u.cm,859                        np.array([1.0, 2.0, np.inf, 10.0]),860                        np.array([1.0, 2.0, np.inf, np.inf]) * u.kg,861                    ],862                    "arg_name": "arg",863                    "checks": 
{**default_checks, "can_be_inf": False},864                },865                "raises": ValueError,866            },867            {868                "input": {869                    "args": [870                        np.inf,871                        np.inf * u.cm,872                        np.array([1.0, 2.0, np.inf, 10.0]),873                        np.array([1.0, 2.0, np.inf, np.inf]) * u.kg,874                    ],875                    "arg_name": "arg",876                    "checks": {**default_checks, "can_be_inf": True},877                }878            },879            # tests for check 'can_be_nan'880            {881                "input": {882                    "args": [883                        np.nan,884                        np.nan * u.cm,885                        np.array([1.0, 2.0, np.nan, 10.0]),886                        np.array([1.0, 2.0, np.nan, np.nan]) * u.kg,887                    ],888                    "arg_name": "arg",889                    "checks": {**default_checks, "can_be_nan": False},890                },891                "raises": ValueError,892            },893            {894                "input": {895                    "args": [896                        np.nan,897                        np.nan * u.cm,898                        np.array([1.0, 2.0, np.nan, 10.0]),899                        np.array([1.0, 2.0, np.nan, np.nan]) * u.kg,900                    ],901                    "arg_name": "arg",902                    "checks": {**default_checks, "can_be_nan": True},903                }904            },905            # tests for check 'none_shall_pass'906            {907                "input": {908                    "args": [None],909                    "arg_name": "arg",910                    "checks": {**default_checks, "none_shall_pass": False},911                },912                "raises": ValueError,913            },914            {915                "input": {916                    "args": 
[None],917                    "arg_name": "arg",918                    "checks": {**default_checks, "none_shall_pass": True},919                }920            },921        ]922        # test923        for case in _cases:924            arg_name = case["input"]["arg_name"]925            checks = case["input"]["checks"]926            for arg in case["input"]["args"]:927                if "raises" in case:928                    with pytest.raises(case["raises"]):929                        cv._check_value(arg, arg_name, checks)930                elif "warns" in case:931                    with pytest.warns(case["warns"]):932                        cv._check_value(arg, arg_name, checks)933                else:934                    assert cv._check_value(arg, arg_name, checks) is None935    def test_cv_called_as_decorator(self):936        """937        Test behavior of `CheckValues.__call__` (i.e. used as a decorator).938        """939        # setup test cases940        # 'setup' = arguments for `CheckUnits` and wrapped function941        # 'output' = expected return from wrapped function942        # 'raises' = if an Exception is expected to be raised943        # 'warns' = if a warning is expected to be issued944        #945        _cases = [946            # clean execution947            {948                "setup": {949                    "function": self.foo,950                    "args": (2, -3),951                    "kwargs": {},952                    "checks": {953                        "x": {"can_be_negative": True},954                        "y": {"can_be_negative": True},955                        "checks_on_return": {"can_be_negative": True},956                    },957                },958                "output": -1,959            },960            # argument fails checks961            {962                "setup": {963                    "function": self.foo,964                    "args": (2, -3),965                    "kwargs": {},966                    
"checks": {967                        "x": {"can_be_negative": True},968                        "y": {"can_be_negative": False},969                        "checks_on_return": {"can_be_negative": True},970                    },971                },972                "raises": ValueError,973            },974            # return fails checks975            {976                "setup": {977                    "function": self.foo,978                    "args": (2, -3),979                    "kwargs": {},980                    "checks": {981                        "x": {"can_be_negative": True},982                        "y": {"can_be_negative": True},983                        "checks_on_return": {"can_be_negative": False},984                    },985                },986                "raises": ValueError,987            },988        ]989        # test on function990        for case in _cases:991            wfoo = CheckValues(**case["setup"]["checks"])(case["setup"]["function"])992            args = case["setup"]["args"]993            kwargs = case["setup"]["kwargs"]994            if "raises" in case:995                with pytest.raises(case["raises"]):996                    wfoo(*args, **kwargs)997            else:998                assert wfoo(*args, **kwargs) == case["output"]999        # test on class method1000        class Foo:1001            @CheckValues(y={"can_be_negative": True})1002            def __init__(self, y):1003                self.y = y1004            @CheckValues(1005                x={"can_be_negative": True}, checks_on_return={"can_be_negative": False}1006            )1007            def bar(self, x):1008                return x + self.y1009        foo = Foo(-5)1010        assert foo.bar(6) == 11011        with pytest.raises(ValueError):1012            foo.bar(1)1013    def test_cv_preserves_signature(self):1014        """Test CheckValues preserves signature of wrapped function."""1015        # I'd like to directly test the @preserve_signature 
is used (??)1016        wfoo = CheckValues()(self.foo)1017        assert hasattr(wfoo, "__signature__")1018        assert wfoo.__signature__ == inspect.signature(self.foo)1019    @mock.patch(1020        CheckValues.__module__ + "." + CheckValues.__qualname__,1021        side_effect=CheckValues,1022        autospec=True,1023    )1024    def test_decorator_func_def(self, mock_cv_class):1025        """1026        Test that :func:`~plasmapy.utils.decorators.checks.check_values` is1027        properly defined.1028        """1029        # create mock function (mock_foo) from function to mock (self.foo)1030        mock_foo = mock.Mock(side_effect=self.foo, name="mock_foo", autospec=True)1031        mock_foo.__name__ = "mock_foo"1032        mock_foo.__signature__ = inspect.signature(self.foo)1033        # setup test cases1034        # 'setup' = arguments for `check_units` and wrapped function1035        # 'output' = expected return from wrapped function1036        # 'raises' = a raised Exception is expected1037        # 'warns' = an issued warning is expected1038        #1039        _cases = [1040            # only argument checks1041            {1042                "setup": {1043                    "args": (-4, 3),1044                    "kwargs": {},1045                    "checks": {1046                        "x": {"can_be_negative": True},1047                        "y": {"can_be_nan": False},1048                    },1049                },1050                "output": -1,1051            },1052            # argument and return checks1053            {1054                "setup": {1055                    "args": (-4, 3),1056                    "kwargs": {},1057                    "checks": {1058                        "x": {"can_be_negative": True},1059                        "checks_on_return": {"can_be_negative": True},1060                    },1061                },1062                "output": -1,1063            },1064        ]1065        for case in _cases:1066     
       for ii in range(2):1067                # decorate1068                if ii == 0:1069                    # functional decorator call1070                    wfoo = check_values(mock_foo, **case["setup"]["checks"])1071                elif ii == 1:1072                    # sugar decorator call1073                    #1074                    #  @check_values(x=check)1075                    #      def foo(x):1076                    #          return x1077                    #1078                    wfoo = check_values(**case["setup"]["checks"])(mock_foo)1079                else:1080                    continue1081                # test1082                args = case["setup"]["args"]1083                kwargs = case["setup"]["kwargs"]1084                assert wfoo(*args, **kwargs) == case["output"]1085                assert mock_cv_class.called1086                assert mock_foo.called1087                assert mock_cv_class.call_args[0] == ()1088                assert sorted(mock_cv_class.call_args[1].keys()) == sorted(1089                    case["setup"]["checks"].keys()1090                )1091                for arg_name, checks in case["setup"]["checks"].items():1092                    assert mock_cv_class.call_args[1][arg_name] == checks1093                # reset1094                mock_cv_class.reset_mock()1095                mock_foo.reset_mock()1096# ----------------------------------------------------------------------------------------1097# Test Decorator `check_relativistic` (& function `_check_relativistic`1098# ----------------------------------------------------------------------------------------1099# (speed, betafrac)1100non_relativistic_speed_examples = [1101    (0 * u.m / u.s, 0.1),1102    (0.0099999 * c, 0.1),1103    (-0.009 * c, 0.1),1104    (5 * u.AA / u.Gyr, 0.1),1105]1106# (speed, betafrac, error)1107relativistic_error_examples = [1108    (u.m / u.s, 0.1, TypeError),1109    (51513.35, 0.1, TypeError),1110    (5 * u.m, 0.1, 
u.UnitConversionError),1111    (1.0 * c, 0.1, RelativityError),1112    (1.1 * c, 0.1, RelativityError),1113    (np.inf * u.cm / u.s, 0.1, RelativityError),1114    (-1.0 * c, 0.1, RelativityError),1115    (-1.1 * c, 0.1, RelativityError),1116    (-np.inf * u.cm / u.s, 0.1, RelativityError),1117]1118# (speed, betafrac, warning)1119relativistic_warning_examples = [1120    (0.11 * c, 0.1),1121    (-0.11 * c, 0.1),1122    (2997924581 * u.cm / u.s, 0.1),1123    (0.02 * c, 0.01),1124]1125# Tests for _check_relativistic1126@pytest.mark.parametrize("speed, betafrac", non_relativistic_speed_examples)1127def test__check_relativisitc_valid(speed, betafrac):1128    _check_relativistic(speed, "f", betafrac=betafrac)1129@pytest.mark.parametrize("speed, betafrac, error", relativistic_error_examples)1130def test__check_relativistic_errors(speed, betafrac, error):1131    with pytest.raises(error):1132        _check_relativistic(speed, "f", betafrac=betafrac)1133@pytest.mark.parametrize("speed, betafrac", relativistic_warning_examples)1134def test__check_relativistic_warnings(speed, betafrac):1135    with pytest.warns(RelativityWarning):1136        _check_relativistic(speed, "f", betafrac=betafrac)1137# Tests for check_relativistic decorator1138@pytest.mark.parametrize("speed, betafrac", non_relativistic_speed_examples)1139def test_check_relativistic_decorator(speed, betafrac):1140    @check_relativistic(betafrac=betafrac)1141    def speed_func():1142        return speed1143    speed_func()1144@pytest.mark.parametrize("speed", [item[0] for item in non_relativistic_speed_examples])1145def test_check_relativistic_decorator_no_args(speed):1146    @check_relativistic1147    def speed_func():1148        return speed1149    speed_func()1150@pytest.mark.parametrize("speed", [item[0] for item in non_relativistic_speed_examples])1151def test_check_relativistic_decorator_no_args_parentheses(speed):1152    @check_relativistic()1153    def speed_func():1154        return speed1155    
speed_func()1156@pytest.mark.parametrize("speed, betafrac, error", relativistic_error_examples)1157def test_check_relativistic_decorator_errors(speed, betafrac, error):1158    @check_relativistic(betafrac=betafrac)1159    def speed_func():1160        return speed1161    with pytest.raises(error):...test_decorators.py
Source:test_decorators.py  
...330        with pytest.raises(expected_exception):331            decorated_function(particle)332    else:333        decorated_function(particle)334def test_none_shall_pass():335    """Tests the `none_shall_pass` keyword argument in is_particle.336    If `none_shall_pass=True`, then an annotated argument should allow337    `None` to be passed through to the decorated function."""338    @particle_input(none_shall_pass=True)339    def func_none_shall_pass(particle: Particle) -> Optional[Particle]:340        return particle341    @particle_input(none_shall_pass=True)342    def func_none_shall_pass_with_tuple(343        particles: (Particle, Particle)344    ) -> (Optional[Particle], Optional[Particle]):345        return particles346    @particle_input(none_shall_pass=True)347    def func_none_shall_pass_with_list(particles: [Particle]) -> [Optional[Particle]]:348        return particles349    assert func_none_shall_pass(None) is None, (350        "The none_shall_pass keyword in the particle_input decorator is set "351        "to True, but is not passing through None."352    )353    assert func_none_shall_pass_with_tuple((None, None)) == (None, None), (354        "The none_shall_pass keyword in the particle_input decorator is set "355        "to True, but is not passing through None."356    )357    assert func_none_shall_pass_with_list((None, None)) == (None, None), (358        "The none_shall_pass keyword in the particle_input decorator is set "359        "to True, but is not passing through None."360    )361def test_none_shall_not_pass():362    """Tests the `none_shall_pass` keyword argument in is_particle.363    If `none_shall_pass=False`, then particle_input should raise a...validators.py
Source:validators.py  
"""
Various decorators to validate input/output arguments to functions.
"""
__all__ = ["validate_quantities", "ValidateQuantities"]

import astropy.units as u
import functools
import inspect
import warnings

from typing import Any, Dict

from plasmapy.utils.decorators.checks import CheckUnits, CheckValues
from plasmapy.utils.decorators.helpers import preserve_signature


class ValidateQuantities(CheckUnits, CheckValues):
    """
    A decorator class to 'validate' -- control and convert -- the units and values
    of input and return arguments to a function or method.  Arguments are expected to
    be astropy :class:`~astropy.units.quantity.Quantity` objects.

    Parameters
    ----------
    validations_on_return: dictionary of validation specifications
        Specifications for unit and value validations on the return of the
        function being wrapped.  (see `quantity validations`_ for valid
        specifications)

    **validations: dictionary of validation specifications
        Specifications for unit and value validations on the input arguments of the
        function being wrapped.  Each keyword argument in `validations` is the
        name of a function argument to be validated and the keyword value contains
        the unit and value validation specifications.

        .. _`quantity validations`:

        Unit and value validations can be defined by passing one of the astropy
        :mod:`~astropy.units`, a list of astropy units, or a dictionary containing
        the keys defined below.  Units can also be defined with function annotations,
        but must be consistent with decorator `**validations` arguments if used
        concurrently.  If a key is omitted, then the default value will be assumed.

        ====================== ======= ================================================
        Key                    Type    Description
        ====================== ======= ================================================
        units                          list of desired astropy :mod:`~astropy.units`
        equivalencies                  | [DEFAULT `None`] A list of equivalent pairs to
                                         try if the units are not directly convertible.
                                       | (see :mod:`~astropy.units.equivalencies`,
                                         and/or `astropy equivalencies`_)
        pass_equivalent_units  `bool`  | [DEFAULT `False`] allow equivalent units
                                       | to pass
        can_be_negative        `bool`  [DEFAULT `True`] values can be negative
        can_be_complex         `bool`  [DEFAULT `False`] values can be complex numbers
        can_be_inf             `bool`  [DEFAULT `True`] values can be :data:`~numpy.inf`
        can_be_nan             `bool`  [DEFAULT `True`] values can be :data:`~numpy.nan`
        none_shall_pass        `bool`  [DEFAULT `False`] values can be a python `None`
        ====================== ======= ================================================

    Raises
    ------
    TypeError
        If the keyword argument ``checks_on_return`` is passed; use
        ``validations_on_return`` instead.

    Notes
    -----
    * Validation of function arguments `*args` and `**kwargs` is not supported.
    * `None` values will pass when `None` is included in the list of specified units,
      is set as a default value for the function argument, or `none_shall_pass` is
      set to `True`.  If `none_shall_pass` is doubly/triply defined through the
      mentioned options, then they all must be consistent with each other.
    * If units are not specified in `validations`, then the decorator will attempt
      to identify desired units by examining the function annotations.

    Examples
    --------
    Define unit and value validations with decorator parameters::

        import astropy.units as u
        from plasmapy.utils.decorators import ValidateQuantities

        @ValidateQuantities(mass={'units': u.g,
                                  'can_be_negative': False},
                            vel=u.cm / u.s,
                            validations_on_return=[u.g * u.cm / u.s, u.kg * u.m / u.s])
        def foo(mass, vel):
            return mass * vel

        # on a method
        class Foo:
            @ValidateQuantities(mass={'units': u.g,
                                      'can_be_negative': False},
                                vel=u.cm / u.s,
                                validations_on_return=[u.g * u.cm / u.s,
                                                       u.kg * u.m / u.s])
            def bar(self, mass, vel):
                return mass * vel

    Define units with function annotations::

        import astropy.units as u
        from plasmapy.utils.decorators import ValidateQuantities

        @ValidateQuantities(mass={'can_be_negative': False})
        def foo(mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s:
            return mass * vel

        # on a method
        class Foo:
            @ValidateQuantities(mass={'can_be_negative': False})
            def bar(self, mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s:
                return mass * vel

    Allow `None` values to pass::

        import astropy.units as u
        from plasmapy.utils.decorators import ValidateQuantities

        @ValidateQuantities(validations_on_return=[u.cm, None])
        def foo(arg1: u.cm = None):
            return arg1

    Allow return values to have equivalent units::

        import astropy.units as u
        from plasmapy.utils.decorators import ValidateQuantities

        @ValidateQuantities(arg1={'units': u.cm},
                            validations_on_return={'units': u.km,
                                                   'pass_equivalent_units': True})
        def foo(arg1):
            return arg1

    Allow equivalent units to pass with specified equivalencies::

        import astropy.units as u
        from plasmapy.utils.decorators import ValidateQuantities

        @ValidateQuantities(arg1={'units': u.K,
                                  'equivalencies': u.temperature(),
                                  'pass_equivalent_units': True})
        def foo(arg1):
            return arg1

    .. _astropy equivalencies:
        https://docs.astropy.org/en/stable/units/equivalencies.html
    """

    def __init__(self, validations_on_return=None, **validations: Dict[str, Any]):
        # the return specification must come through `validations_on_return`;
        # 'checks_on_return' is the key used internally by CheckUnits/CheckValues
        if "checks_on_return" in validations:
            raise TypeError(
                "keyword argument 'checks_on_return' is not allowed, "
                "use 'validations_on_return' to set validations "
                "on the return variable"
            )

        self._validations = validations

        # the "check" base classes expect the return spec under 'checks_on_return',
        # so hand them a (shallow) copy keyed the way they expect
        checks = validations.copy()
        if validations_on_return is not None:
            self._validations["validations_on_return"] = validations_on_return
            checks["checks_on_return"] = validations_on_return

        super().__init__(**checks)

    def __call__(self, f):
        """
        Wrap `f` so its input arguments and return are validated.

        Parameters
        ----------
        f
            Function to be wrapped

        Returns
        -------
        function
            wrapped function of `f`
        """
        # stored so `_validate_quantity` can name the function in its messages
        self.f = f
        wrapped_sign = inspect.signature(f)

        @preserve_signature
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # combine args and kwargs into dictionary
            bound_args = wrapped_sign.bind(*args, **kwargs)
            bound_args.apply_defaults()

            # get conditioned validations
            validations = self._get_validations(bound_args)

            # validate (input) argument units and values
            for arg_name in validations:
                # skip check of output/return
                if arg_name == "validations_on_return":
                    continue

                # validate argument & update for conversion
                arg = self._validate_quantity(
                    bound_args.arguments[arg_name], arg_name, validations[arg_name]
                )
                bound_args.arguments[arg_name] = arg

            # call function
            # * `.args` / `.kwargs` are derived from the (updated) `.arguments`
            #   and re-expand positional-only, *args and **kwargs parameters
            #   correctly, which `f(**bound_args.arguments)` would not
            _return = f(*bound_args.args, **bound_args.kwargs)

            # validate output
            if "validations_on_return" in validations:
                _return = self._validate_quantity(
                    _return,
                    "validations_on_return",
                    validations["validations_on_return"],
                )

            return _return

        return wrapper

    def _get_validations(
        self, bound_args: inspect.BoundArguments
    ) -> Dict[str, Dict[str, Any]]:
        """
        Review :attr:`validations` and function bound arguments to build a complete
        'validations' dictionary.  If a validation key is omitted from the argument
        validations, then a default value is assumed (see `quantity validations`_).

        Parameters
        ----------
        bound_args: :class:`inspect.BoundArguments`
            arguments passed into the function being wrapped

            .. code-block:: python

                bound_args = inspect.signature(f).bind(*args, **kwargs)

        Returns
        -------
        Dict[str, Dict[str, Any]]
            A complete 'validations' dictionary for validating function input arguments
            and return.

        Raises
        ------
        ValueError
            If a decorator argument sets ``none_shall_pass=False`` while the unit
            checks (e.g. function annotations) already allow `None`.
        """
        unit_checks = self._get_unit_checks(bound_args)
        value_checks = self._get_value_checks(bound_args)

        # combine all validations
        # * `unit_checks` will encompass all argument "checks" defined either by
        #   function annotations or **validations.
        # * `value_checks` may miss some arguments if **validations only defines
        #   unit validations or some validations come from function annotations
        validations = unit_checks.copy()
        for arg_name in validations:
            # augment 'none_shall_pass' (if needed)
            try:
                # if 'none_shall_pass' was in the original passed-in validations,
                # then override the value determined by CheckUnits
                _none_shall_pass = self.validations[arg_name]["none_shall_pass"]

                # the decorator argument may not forbid `None` when CheckUnits
                # has already determined that `None` is allowed
                if (
                    _none_shall_pass is False
                    and validations[arg_name]["none_shall_pass"] is True
                ):
                    raise ValueError(
                        f"Validation 'none_shall_pass' for argument '{arg_name}' is "
                        f"inconsistent between function annotations "
                        f"({validations[arg_name]['none_shall_pass']}) and decorator "
                        f"argument ({_none_shall_pass})."
                    )
                validations[arg_name]["none_shall_pass"] = _none_shall_pass
            except (KeyError, TypeError):
                # 'none_shall_pass' was not in the original passed-in validations
                # (or that validation is not a dict), so rely on the value
                # determined by CheckUnits
                pass
            finally:
                # 'none_shall_pass' is taken from the unit checks, so remove it
                # from the value checks; fall back to the CheckValues defaults
                # (minus 'none_shall_pass') when the argument has no value checks
                try:
                    del value_checks[arg_name]["none_shall_pass"]
                except KeyError:
                    dvc = self._CheckValues__check_defaults.copy()
                    del dvc["none_shall_pass"]
                    value_checks[arg_name] = dvc

            # update the validations dictionary
            validations[arg_name].update(value_checks[arg_name])

        if "checks_on_return" in validations:
            validations["validations_on_return"] = validations.pop("checks_on_return")

        return validations

    def _validate_quantity(self, arg, arg_name: str, arg_validations: Dict[str, Any]):
        """
        Perform validations `arg_validations` on function argument `arg`
        named `arg_name`.

        Parameters
        ----------
        arg
            The argument to be validated.

        arg_name: str
            The name of the argument to be validated

        arg_validations: Dict[str, Any]
            The requested validations for the argument

        Returns
        -------
        The (possibly unit-converted) argument.

        Raises
        ------
        TypeError
            if argument is not an AstroPy :class:`~astropy.units.Quantity`
            or not convertible to a :class:`~astropy.units.Quantity`

        ValueError
            if validations fail
        """
        # rename to work with "check" methods
        if arg_name == "validations_on_return":
            arg_name = "checks_on_return"

        # initialize str for error message
        if arg_name == "checks_on_return":
            err_msg = "The return value "
        else:
            err_msg = f"The argument '{arg_name}' "
        err_msg += f"to function {self.f.__name__}()"

        # initialize TypeError message
        units_str = ", ".join(f"{unit}" for unit in arg_validations["units"])
        typeerror_msg = (
            f"{err_msg} should be an astropy Quantity with units"
            f" equivalent to one of [{units_str}]"
        )

        # add units to arg if possible
        # * a None value will be taken care of by `_check_unit_core`
        # * units are only assumed when exactly one unit was requested
        if arg is None or hasattr(arg, "unit"):
            pass
        elif len(arg_validations["units"]) != 1:
            raise TypeError(typeerror_msg)
        else:
            try:
                arg = arg * arg_validations["units"][0]
            except (TypeError, ValueError) as err:
                raise TypeError(typeerror_msg) from err
            else:
                warnings.warn(
                    u.UnitsWarning(
                        f"{err_msg} has no specified units. Assuming units of "
                        f"{arg_validations['units'][0]}. To silence this warning, "
                        f"explicitly pass in an astropy Quantity "
                        f"(e.g. 5. * astropy.units.cm) "
                        f"(see http://docs.astropy.org/en/stable/units/)"
                    )
                )

        # check units
        arg, unit, equiv, err = self._check_unit_core(arg, arg_name, arg_validations)

        # convert quantity (unless equivalent units are allowed to pass as-is)
        if (
            arg is not None
            and unit is not None
            and not arg_validations["pass_equivalent_units"]
        ):
            arg = arg.to(unit, equivalencies=equiv)
        elif err is not None:
            raise err

        # check value
        self._check_value(arg, arg_name, arg_validations)

        return arg

    @property
    def validations(self):
        """
        Requested validations on the decorated function's input arguments and
        return variable.
        """
        return self._validations


def validate_quantities(func=None, validations_on_return=None, **validations):
    """
    A decorator to 'validate' -- control and convert -- the units and values
    of input and return arguments to a function or method.  Arguments are expected to
    be astropy :class:`~astropy.units.quantity.Quantity` objects.

    Parameters
    ----------
    func:
        The function to be decorated

    validations_on_return: dictionary of validation specifications
        Specifications for unit and value validations on the return of the
        function being wrapped.  (see `quantity validations`_ for valid
        specifications)

    **validations: dictionary of validation specifications
        Specifications for unit and value validations on the input arguments of the
        function being wrapped.  Each keyword argument in `validations` is the
        name of a function argument to be validated and the keyword value contains
        the unit and value validation specifications.

        .. _`quantity validations`:

        Unit and value validations can be defined by passing one of the astropy
        :mod:`~astropy.units`, a list of astropy units, or a dictionary containing
        the keys defined below.  Units can also be defined with function annotations,
        but must be consistent with decorator `**validations` arguments if used
        concurrently.  If a key is omitted, then the default value will be assumed.

        ====================== ======= ================================================
        Key                    Type    Description
        ====================== ======= ================================================
        units                          list of desired astropy :mod:`~astropy.units`
        equivalencies                  | [DEFAULT `None`] A list of equivalent pairs to
                                         try if the units are not directly convertible.
                                       | (see :mod:`~astropy.units.equivalencies`,
                                         and/or `astropy equivalencies`_)
        pass_equivalent_units  `bool`  | [DEFAULT `False`] allow equivalent units
                                       | to pass
        can_be_negative        `bool`  [DEFAULT `True`] values can be negative
        can_be_complex         `bool`  [DEFAULT `False`] values can be complex numbers
        can_be_inf             `bool`  [DEFAULT `True`] values can be :data:`~numpy.inf`
        can_be_nan             `bool`  [DEFAULT `True`] values can be :data:`~numpy.nan`
        none_shall_pass        `bool`  [DEFAULT `False`] values can be a python `None`
        ====================== ======= ================================================

    Notes
    -----
    * Validation of function arguments `*args` and `**kwargs` is not supported.
    * `None` values will pass when `None` is included in the list of specified units,
      is set as a default value for the function argument, or `none_shall_pass` is
      set to `True`.  If `none_shall_pass` is doubly/triply defined through the
      mentioned options, then they all must be consistent with each other.
    * If units are not specified in `validations`, then the decorator will attempt
      to identify desired units by examining the function annotations.
    * Full functionality is defined by the class :class:`ValidateQuantities`.

    Examples
    --------
    Define unit and value validations with decorator parameters::

        import astropy.units as u
        from plasmapy.utils.decorators import validate_quantities

        @validate_quantities(mass={'units': u.g,
                                   'can_be_negative': False},
                             vel=u.cm / u.s,
                             validations_on_return=[u.g * u.cm / u.s, u.kg * u.m / u.s])
        def foo(mass, vel):
            return mass * vel

        # on a method
        class Foo:
            @validate_quantities(mass={'units': u.g,
                                       'can_be_negative': False},
                                 vel=u.cm / u.s,
                                 validations_on_return=[u.g * u.cm / u.s,
                                                        u.kg * u.m / u.s])
            def bar(self, mass, vel):
                return mass * vel

    Define units with function annotations::

        import astropy.units as u
        from plasmapy.utils.decorators import validate_quantities

        @validate_quantities(mass={'can_be_negative': False})
        def foo(mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s:
            return mass * vel

        # rely only on annotations
        @validate_quantities
        def foo(x: u.cm, time: u.s) -> u.cm / u.s:
            return x / time

        # on a method
        class Foo:
            @validate_quantities(mass={'can_be_negative': False})
            def bar(self, mass: u.g, vel: u.cm / u.s) -> u.g * u.cm / u.s:
                return mass * vel

    Allow `None` values to pass::

        import astropy.units as u
        from plasmapy.utils.decorators import validate_quantities

        @validate_quantities(arg2={'none_shall_pass': True},
                             validations_on_return=[u.cm, None])
        def foo(arg2: u.cm, arg1: u.cm = None):
            return None

    Allow return values to have equivalent units::

        import astropy.units as u
        from plasmapy.utils.decorators import validate_quantities

        @validate_quantities(arg1={'units': u.cm},
                             validations_on_return={'units': u.km,
                                                    'pass_equivalent_units': True})
        def foo(arg1):
            return arg1

    Allow equivalent units to pass with specified equivalencies::

        import astropy.units as u
        from plasmapy.utils.decorators import validate_quantities

        @validate_quantities(arg1={'units': u.K,
                                   'equivalencies': u.temperature(),
                                   'pass_equivalent_units': True})
        def foo(arg1):
            return arg1

    .. _astropy equivalencies:
        https://docs.astropy.org/en/stable/units/equivalencies.html
    """
    # forwarded through **validations; Python binds the key to the named
    # `validations_on_return` parameter of ValidateQuantities.__init__
    if validations_on_return is not None:
        validations["validations_on_return"] = validations_on_return

    if func is not None:
        # `validate_quantities` called as a function
        return ValidateQuantities(**validations)(func)
    else:
        # `validate_quantities` called as a decorator "sugar-syntax"
        # NOTE(review): the tail of this branch was cut off in the scraped
        # source; reconstructed to mirror the functional branch above --
        # confirm against upstream.
        return ValidateQuantities(**validations)
LambdaTest Learning Hubs cover everything from setting up the prerequisites for your first automation test to following best practices and diving deeper into advanced test scenarios. They compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!
