Best Python code snippet using pandera_python
checks.py
Source:checks.py  
1"""Data validation checks."""2import inspect3import operator4import re5from collections import ChainMap, namedtuple6from functools import partial, wraps7from itertools import chain8from typing import (9    Any,10    Callable,11    Dict,12    Iterable,13    List,14    Optional,15    Type,16    TypeVar,17    Union,18    no_type_check,19)20import pandas as pd21from . import check_utils, constants, errors22from . import strategies as st23CheckResult = namedtuple(24    "CheckResult",25    ["check_output", "check_passed", "checked_object", "failure_cases"],26)27GroupbyObject = Union[28    pd.core.groupby.SeriesGroupBy, pd.core.groupby.DataFrameGroupBy29]30SeriesCheckObj = Union[pd.Series, Dict[str, pd.Series]]31DataFrameCheckObj = Union[pd.DataFrame, Dict[str, pd.DataFrame]]32def register_check_statistics(statistics_args):33    """Decorator to set statistics based on Check method."""34    def register_check_statistics_decorator(class_method):35        @wraps(class_method)36        def _wrapper(cls, *args, **kwargs):37            args = list(args)38            arg_names = inspect.getfullargspec(class_method).args[1:]39            if not arg_names:40                arg_names = statistics_args41            args_dict = {**dict(zip(arg_names, args)), **kwargs}42            check = class_method(cls, *args, **kwargs)43            check.statistics = {44                stat: args_dict.get(stat) for stat in statistics_args45            }46            check.statistics_args = statistics_args47            return check48        return _wrapper49    return register_check_statistics_decorator50_T = TypeVar("_T", bound="_CheckBase")51class _CheckMeta(type):  # pragma: no cover52    """Check metaclass."""53    REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {}  # noqa54    def __getattr__(cls, name: str) -> Any:55        """Prevent attribute errors for registered checks."""56        attr = ChainMap(cls.__dict__, cls.REGISTERED_CUSTOM_CHECKS).get(name)57        if attr is None:58            raise AttributeError(59                f"'{cls}' object has no attribute '{name}'. 
"60                "Make sure any custom checks have been registered "61                "using the extensions api."62            )63        return attr64    def __dir__(cls) -> Iterable[str]:65        """Allow custom checks to show up as attributes when autocompleting."""66        return chain(super().__dir__(), cls.REGISTERED_CUSTOM_CHECKS.keys())67    # pylint: disable=line-too-long68    # mypy has limited metaclass support so this doesn't pass typecheck69    # see https://mypy.readthedocs.io/en/stable/metaclasses.html#gotchas-and-limitations-of-metaclass-support70    # pylint: enable=line-too-long71    @no_type_check72    def __contains__(cls: Type[_T], item: Union[_T, str]) -> bool:73        """Allow lookups for registered checks."""74        if isinstance(item, cls):75            name = item.name76            return hasattr(cls, name)77        # assume item is str78        return hasattr(cls, item)79class _CheckBase(metaclass=_CheckMeta):80    """Check base class."""81    def __init__(82        self,83        check_fn: Union[84            Callable[[pd.Series], Union[pd.Series, bool]],85            Callable[[pd.DataFrame], Union[pd.DataFrame, pd.Series, bool]],86        ],87        groups: Optional[Union[str, List[str]]] = None,88        groupby: Optional[Union[str, List[str], Callable]] = None,89        ignore_na: bool = True,90        element_wise: bool = False,91        name: str = None,92        error: Optional[str] = None,93        raise_warning: bool = False,94        n_failure_cases: Union[int, None] = constants.N_FAILURE_CASES,95        title: Optional[str] = None,96        description: Optional[str] = None,97        **check_kwargs,98    ) -> None:99        """Apply a validation function to each element, Series, or DataFrame.100        :param check_fn: A function to check pandas data structure. For Column101            or SeriesSchema checks, if element_wise is True, this function102            should have the signature: ``Callable[[pd.Series],103            Union[pd.Series, bool]]``, where the output series is a boolean104            vector.105            If element_wise is False, this function should have the signature:106            ``Callable[[Any], bool]``, where ``Any`` is an element in the107            column.108            For DataFrameSchema checks, if element_wise=True, fn109            should have the signature: ``Callable[[pd.DataFrame],110            Union[pd.DataFrame, pd.Series, bool]]``, where the output dataframe111            or series contains booleans.112            If element_wise is True, fn is applied to each row in113            the dataframe with the signature ``Callable[[pd.Series], bool]``114            where the series input is a row in the dataframe.115        :param groups: The dict input to the `fn` callable will be constrained116            to the groups specified by `groups`.117        :param groupby: If a string or list of strings is provided, these118            columns are used to group the Column series. 
If a119            callable is passed, the expected signature is: ``Callable[120            [pd.DataFrame], pd.core.groupby.DataFrameGroupBy]``121            The the case of ``Column`` checks, this function has access to the122            entire dataframe, but ``Column.name`` is selected from this123            DataFrameGroupby object so that a SeriesGroupBy object is passed124            into ``check_fn``.125            Specifying the groupby argument changes the ``check_fn`` signature126            to:127            ``Callable[[Dict[Union[str, Tuple[str]], pd.Series]], Union[bool, pd.Series]]``  # noqa128            where the input is a dictionary mapping129            keys to subsets of the column/dataframe.130        :param ignore_na: If True, null values will be ignored when determining131            if a check passed or failed. For dataframes, ignores rows with any132            null value. *New in version 0.4.0*133        :param element_wise: Whether or not to apply validator in an134            element-wise fashion. If bool, assumes that all checks should be135            applied to the column element-wise. If list, should be the same136            number of elements as checks.137        :param name: optional name for the check.138        :param error: custom error message if series fails validation139            check.140        :param raise_warning: if True, raise a UserWarning and do not throw141            exception instead of raising a SchemaError for a specific check.142            This option should be used carefully in cases where a failing143            check is informational and shouldn't stop execution of the program.144        :param n_failure_cases: report the first n unique failure cases. If145            None, report all failure cases.146        :param title: A human-readable label for the check.147        :param description: An arbitrary textual description of the check.148        :param check_kwargs: key-word arguments to pass into ``check_fn``149        :example:150        >>> import pandas as pd151        >>> import pandera as pa152        >>>153        >>>154        >>> # column checks are vectorized by default155        >>> check_positive = pa.Check(lambda s: s > 0)156        >>>157        >>> # define an element-wise check158        >>> check_even = pa.Check(lambda x: x % 2 == 0, element_wise=True)159        >>>160        >>> # checks can be given human-readable metadata161        >>> check_with_metadata = pa.Check(162        ...     lambda x: True,163        ...     title="Always passes",164        ...     description="This check always passes."165        ... )166        >>>167        >>> # specify assertions across categorical variables using `groupby`,168        >>> # for example, make sure the mean measure for group "A" is always169        >>> # larger than the mean measure for group "B"170        >>> check_by_group = pa.Check(171        ...     lambda measures: measures["A"].mean() > measures["B"].mean(),172        ...     groupby=["group"],173        ... )174        >>>175        >>> # define a wide DataFrame-level check176        >>> check_dataframe = pa.Check(177        ...     lambda df: df["measure_1"] > df["measure_2"])178        >>>179        >>> measure_checks = [check_positive, check_even, check_by_group]180        >>>181        >>> schema = pa.DataFrameSchema(182        ...     columns={183        ...         "measure_1": pa.Column(int, checks=measure_checks),184        ...         
"measure_2": pa.Column(int, checks=measure_checks),185        ...         "group": pa.Column(str),186        ...     },187        ...     checks=check_dataframe188        ... )189        >>>190        >>> df = pd.DataFrame({191        ...     "measure_1": [10, 12, 14, 16],192        ...     "measure_2": [2, 4, 6, 8],193        ...     "group": ["B", "B", "A", "A"]194        ... })195        >>>196        >>> schema.validate(df)[["measure_1", "measure_2", "group"]]197            measure_1  measure_2 group198        0         10          2     B199        1         12          4     B200        2         14          6     A201        3         16          8     A202        See :ref:`here<checks>` for more usage details.203        """204        if element_wise and groupby is not None:205            raise errors.SchemaInitError(206                "Cannot use groupby when element_wise=True."207            )208        self._check_fn = check_fn209        self._check_kwargs = check_kwargs210        self.element_wise = element_wise211        self.error = error212        self.name = name or getattr(213            self._check_fn, "__name__", self._check_fn.__class__.__name__214        )215        self.ignore_na = ignore_na216        self.raise_warning = raise_warning217        self.n_failure_cases = n_failure_cases218        self.title = title219        self.description = description220        if groupby is None and groups is not None:221            raise ValueError(222                "`groupby` argument needs to be provided when `groups` "223                "argument is defined"224            )225        if isinstance(groupby, str):226            groupby = [groupby]227        self.groupby = groupby228        if isinstance(groups, str):229            groups = [groups]230        self.groups = groups231        self.failure_cases = None232        self._statistics = None233    @property234    def statistics(self) -> Dict[str, Any]:235        """Get check statistics."""236        return getattr(self, "_statistics")237    @statistics.setter238    def statistics(self, statistics):239        """Set check statistics."""240        self._statistics = statistics241    @staticmethod242    def _format_groupby_input(243        groupby_obj: GroupbyObject,244        groups: Optional[List[str]],245    ) -> Union[Dict[str, Union[pd.Series, pd.DataFrame]]]:246        """Format groupby object into dict of groups to Series or DataFrame.247        :param groupby_obj: a pandas groupby object.248        :param groups: only include these groups in the output.249        :returns: dictionary mapping group names to Series or DataFrame.250        """251        if groups is None:252            return dict(list(groupby_obj))253        group_keys = set(group_key for group_key, _ in groupby_obj)254        invalid_groups = [g for g in groups if g not in group_keys]255        if invalid_groups:256            raise KeyError(257                f"groups {invalid_groups} provided in `groups` argument not a valid group "258                f"key. 
Valid group keys: {group_keys}"259            )260        return {261            group_key: group262            for group_key, group in groupby_obj263            if group_key in groups264        }265    def _prepare_series_input(266        self,267        df_or_series: Union[pd.Series, pd.DataFrame],268        column: Optional[str] = None,269    ) -> SeriesCheckObj:270        """Prepare input for Column check.271        :param pd.Series series: one-dimensional ndarray with axis labels272            (including time series).273        :param pd.DataFrame dataframe_context: optional dataframe to supply274            when checking a Column in a DataFrameSchema.275        :returns: a Series, or a dictionary mapping groups to Series276            to be used by `_check_fn` and `_vectorized_check`277        """278        if check_utils.is_field(df_or_series):279            return df_or_series280        elif self.groupby is None:281            return df_or_series[column]282        elif isinstance(self.groupby, list):283            return self._format_groupby_input(284                df_or_series.groupby(self.groupby)[column],285                self.groups,286            )287        elif callable(self.groupby):288            return self._format_groupby_input(289                self.groupby(df_or_series)[column],290                self.groups,291            )292        raise TypeError("Type %s not recognized for `groupby` argument.")293    def _prepare_dataframe_input(294        self, dataframe: pd.DataFrame295    ) -> DataFrameCheckObj:296        """Prepare input for DataFrameSchema check.297        :param dataframe: dataframe to validate.298        :returns: a DataFrame, or a dictionary mapping groups to pd.DataFrame299            to be used by `_check_fn` and `_vectorized_check`300        """301        if self.groupby is None:302            return dataframe303        groupby_obj = dataframe.groupby(self.groupby)304        return self._format_groupby_input(groupby_obj, self.groups)305    def __call__(306        self,307        df_or_series: Union[pd.DataFrame, pd.Series],308        column: Optional[str] = None,309    ) -> CheckResult:310        # pylint: disable=too-many-branches311        """Validate pandas DataFrame or Series.312        :param df_or_series: pandas DataFrame of Series to validate.313        :param column: for dataframe checks, apply the check function to this314            column.315        :returns: CheckResult tuple containing:316            ``check_output``: boolean scalar, ``Series`` or ``DataFrame``317            indicating which elements passed the check.318            ``check_passed``: boolean scalar that indicating whether the check319            passed overall.320            ``checked_object``: the checked object itself. 
Depending on the321            options provided to the ``Check``, this will be a pandas Series,322            DataFrame, or if the ``groupby`` option is specified, a323            ``Dict[str, Series]`` or ``Dict[str, DataFrame]`` where the keys324            are distinct groups.325            ``failure_cases``: subset of the check_object that failed.326        """327        # prepare check object328        if check_utils.is_field(df_or_series) or (329            column is not None and check_utils.is_table(df_or_series)330        ):331            check_obj = self._prepare_series_input(df_or_series, column)332        elif check_utils.is_table(df_or_series):333            check_obj = self._prepare_dataframe_input(df_or_series)334        else:335            raise ValueError(336                f"object of type {type(df_or_series)} not supported. Must be "337                "a Series, a dictionary of Series, or DataFrame"338            )339        # apply check function to check object340        check_fn = partial(self._check_fn, **self._check_kwargs)341        if self.element_wise:342            check_output = (343                check_obj.apply(check_fn, axis=1)  # type: ignore344                if check_utils.is_table(check_obj)345                else check_obj.map(check_fn)  # type: ignore346                if check_utils.is_field(check_obj)347                else check_fn(check_obj)348            )349        else:350            # vectorized check function case351            check_output = check_fn(check_obj)352        # failure cases only apply when the check function returns a boolean353        # series that matches the shape and index of the check_obj354        if (355            isinstance(check_obj, dict)356            or isinstance(check_output, bool)357            or not check_utils.is_supported_check_obj(check_output)358            or check_obj.shape[0] != check_output.shape[0]359            or (check_obj.index != check_output.index).all()360        ):361            failure_cases = None362        elif check_utils.is_field(check_output):363            (364                check_output,365                failure_cases,366            ) = check_utils.prepare_series_check_output(367                check_obj,368                check_output,369                ignore_na=self.ignore_na,370                n_failure_cases=self.n_failure_cases,371            )372        elif check_utils.is_table(check_output):373            (374                check_output,375                failure_cases,376            ) = check_utils.prepare_dataframe_check_output(377                check_obj,378                check_output,379                df_orig=df_or_series,380                ignore_na=self.ignore_na,381                n_failure_cases=self.n_failure_cases,382            )383        else:384            raise TypeError(385                f"output type of check_fn not recognized: {type(check_output)}"386            )387        check_passed = (388            check_output.all()389            if check_utils.is_field(check_output)390            else check_output.all(axis=None)391            if check_utils.is_table(check_output)392            else check_output393        )394        return CheckResult(395            check_output, check_passed, check_obj, failure_cases396        )397    def __eq__(self, other: object) -> bool:398        if not isinstance(other, type(self)):399            return NotImplemented400        are_check_fn_objects_equal = (401            self._get_check_fn_code() == 
other._get_check_fn_code()402        )403        try:404            are_strategy_fn_objects_equal = all(405                getattr(self.__dict__.get("strategy"), attr)406                == getattr(other.__dict__.get("strategy"), attr)407                for attr in ["func", "args", "keywords"]408            )409        except AttributeError:410            are_strategy_fn_objects_equal = True411        are_all_other_check_attributes_equal = {412            k: v413            for k, v in self.__dict__.items()414            if k not in ["_check_fn", "strategy"]415        } == {416            k: v417            for k, v in other.__dict__.items()418            if k not in ["_check_fn", "strategy"]419        }420        return (421            are_check_fn_objects_equal422            and are_strategy_fn_objects_equal423            and are_all_other_check_attributes_equal424        )425    def _get_check_fn_code(self):426        check_fn = self.__dict__["_check_fn"]427        try:428            code = check_fn.__code__.co_code429        except AttributeError:430            # try accessing the functools.partial wrapper431            code = check_fn.func.__code__.co_code432        return code433    def __hash__(self) -> int:434        return hash(self._get_check_fn_code())435    def __repr__(self) -> str:436        return (437            f"<Check {self.name}: {self.error}>"438            if self.error is not None439            else f"<Check {self.name}>"440        )441class Check(_CheckBase):442    """Check a pandas Series or DataFrame for certain properties."""443    @classmethod444    @st.register_check_strategy(st.eq_strategy)445    @register_check_statistics(["value"])446    def equal_to(cls, value, **kwargs) -> "Check":447        """Ensure all elements of a series equal a certain value.448        *New in version 0.4.5*449        Alias: ``eq``450        :param value: All elements of a given :class:`pandas.Series` must have451            this value452        :param kwargs: key-word arguments passed into the `Check` initializer.453        :returns: :class:`Check` object454        """455        def _equal(series: pd.Series) -> pd.Series:456            """Comparison function for check"""457            return series == value458        return cls(459            _equal,460            name=cls.equal_to.__name__,461            error=f"equal_to({value})",462            **kwargs,463        )464    eq = equal_to465    @classmethod466    @st.register_check_strategy(st.ne_strategy)467    @register_check_statistics(["value"])468    def not_equal_to(cls, value, **kwargs) -> "Check":469        """Ensure no elements of a series equals a certain value.470        *New in version 0.4.5*471        Alias: ``ne``472        :param value: This value must not occur in the checked473            :class:`pandas.Series`.474        :param kwargs: key-word arguments passed into the `Check` initializer.475        :returns: :class:`Check` object476        """477        def _not_equal(series: pd.Series) -> pd.Series:478            """Comparison function for check"""479            return series != value480        return cls(481            _not_equal,482            name=cls.not_equal_to.__name__,483            error=f"not_equal_to({value})",484            **kwargs,485        )486    ne = not_equal_to487    @classmethod488    @st.register_check_strategy(st.gt_strategy)489    @register_check_statistics(["min_value"])490    def greater_than(cls, min_value, **kwargs) -> "Check":491        """Ensure values of a series are strictly greater 
than a minimum value.492        *New in version 0.4.5*493        Alias: ``gt``494        :param min_value: Lower bound to be exceeded. Must be a type comparable495            to the dtype of the :class:`pandas.Series` to be validated (e.g. a496            numerical type for float or int and a datetime for datetime).497        :param kwargs: key-word arguments passed into the `Check` initializer.498        :returns: :class:`Check` object499        """500        if min_value is None:501            raise ValueError("min_value must not be None")502        def _greater_than(series: pd.Series) -> pd.Series:503            """Comparison function for check"""504            return series > min_value505        return cls(506            _greater_than,507            name=cls.greater_than.__name__,508            error=f"greater_than({min_value})",509            **kwargs,510        )511    gt = greater_than512    @classmethod513    @st.register_check_strategy(st.ge_strategy)514    @register_check_statistics(["min_value"])515    def greater_than_or_equal_to(cls, min_value, **kwargs) -> "Check":516        """Ensure all values are greater or equal a certain value.517        *New in version 0.4.5*518        Alias: ``ge``519        :param min_value: Allowed minimum value for values of a series. Must be520            a type comparable to the dtype of the :class:`pandas.Series` to be521            validated.522        :param kwargs: key-word arguments passed into the `Check` initializer.523        :returns: :class:`Check` object524        """525        if min_value is None:526            raise ValueError("min_value must not be None")527        def _greater_or_equal(series: pd.Series) -> pd.Series:528            """Comparison function for check"""529            return series >= min_value530        return cls(531            _greater_or_equal,532            name=cls.greater_than_or_equal_to.__name__,533            error=f"greater_than_or_equal_to({min_value})",534            **kwargs,535        )536    ge = greater_than_or_equal_to537    @classmethod538    @st.register_check_strategy(st.lt_strategy)539    @register_check_statistics(["max_value"])540    def less_than(cls, max_value, **kwargs) -> "Check":541        """Ensure values of a series are strictly below a maximum value.542        *New in version 0.4.5*543        Alias: ``lt``544        :param max_value: All elements of a series must be strictly smaller545            than this. Must be a type comparable to the dtype of the546            :class:`pandas.Series` to be validated.547        :param kwargs: key-word arguments passed into the `Check` initializer.548        :returns: :class:`Check` object549        """550        if max_value is None:551            raise ValueError("max_value must not be None")552        def _less_than(series: pd.Series) -> pd.Series:553            """Comparison function for check"""554            return series < max_value555        return cls(556            _less_than,557            name=cls.less_than.__name__,558            error=f"less_than({max_value})",559            **kwargs,560        )561    lt = less_than562    @classmethod563    @st.register_check_strategy(st.le_strategy)564    @register_check_statistics(["max_value"])565    def less_than_or_equal_to(cls, max_value, **kwargs) -> "Check":566        """Ensure values are less than or equal to a maximum value.567        *New in version 0.4.5*568        Alias: ``le``569        :param max_value: Upper bound not to be exceeded. 
Must be a type570            comparable to the dtype of the :class:`pandas.Series` to be571            validated.572        :param kwargs: key-word arguments passed into the `Check` initializer.573        :returns: :class:`Check` object574        """575        if max_value is None:576            raise ValueError("max_value must not be None")577        def _less_or_equal(series: pd.Series) -> pd.Series:578            """Comparison function for check"""579            return series <= max_value580        return cls(581            _less_or_equal,582            name=cls.less_than_or_equal_to.__name__,583            error=f"less_than_or_equal_to({max_value})",584            **kwargs,585        )586    le = less_than_or_equal_to587    @classmethod588    @st.register_check_strategy(st.in_range_strategy)589    @register_check_statistics(590        ["min_value", "max_value", "include_min", "include_max"]591    )592    def in_range(593        cls, min_value, max_value, include_min=True, include_max=True, **kwargs594    ) -> "Check":595        """Ensure all values of a series are within an interval.596        :param min_value: Left / lower endpoint of the interval.597        :param max_value: Right / upper endpoint of the interval. Must not be598            smaller than min_value.599        :param include_min: Defines whether min_value is also an allowed value600            (the default) or whether all values must be strictly greater than601            min_value.602        :param include_max: Defines whether min_value is also an allowed value603            (the default) or whether all values must be strictly smaller than604            max_value.605        :param kwargs: key-word arguments passed into the `Check` initializer.606        Both endpoints must be a type comparable to the dtype of the607        :class:`pandas.Series` to be validated.608        :returns: :class:`Check` object609        """610        if min_value is None:611            raise ValueError("min_value must not be None")612        if max_value is None:613            raise ValueError("max_value must not be None")614        if max_value < min_value or (615            min_value == max_value and (not include_min or not include_max)616        ):617            raise ValueError(618                f"The combination of min_value = {min_value} and max_value = {max_value} "619                "defines an empty interval!"620            )621        # Using functions from operator module to keep conditions out of the622        # closure623        left_op = operator.le if include_min else operator.lt624        right_op = operator.ge if include_max else operator.gt625        def _in_range(series: pd.Series) -> pd.Series:626            """Comparison function for check"""627            return left_op(min_value, series) & right_op(max_value, series)628        return cls(629            _in_range,630            name=cls.in_range.__name__,631            error=f"in_range({min_value}, {max_value})",632            **kwargs,633        )634    @classmethod635    @st.register_check_strategy(st.isin_strategy)636    @register_check_statistics(["allowed_values"])637    def isin(cls, allowed_values: Iterable, **kwargs) -> "Check":638        """Ensure only allowed values occur within a series.639        :param allowed_values: The set of allowed values. May be any iterable.640        :param kwargs: key-word arguments passed into the `Check` initializer.641        :returns: :class:`Check` object642        .. 
note::643            It is checked whether all elements of a :class:`pandas.Series`644            are part of the set of elements of allowed values. If allowed645            values is a string, the set of elements consists of all distinct646            characters of the string. Thus only single characters which occur647            in allowed_values at least once can meet this condition. If you648            want to check for substrings use :func:`Check.str_is_substring`.649        """650        # Turn allowed_values into a set. Not only for performance but also651        # avoid issues with a mutable argument passed by reference which may be652        # changed from outside.653        try:654            allowed_values = frozenset(allowed_values)655        except TypeError as exc:656            raise ValueError(657                f"Argument allowed_values must be iterable. Got {allowed_values}"658            ) from exc659        def _isin(series: pd.Series) -> pd.Series:660            """Comparison function for check"""661            return series.isin(allowed_values)662        return cls(663            _isin,664            name=cls.isin.__name__,665            error=f"isin({set(allowed_values)})",666            **kwargs,667        )668    @classmethod669    @st.register_check_strategy(st.notin_strategy)670    @register_check_statistics(["forbidden_values"])671    def notin(cls, forbidden_values: Iterable, **kwargs) -> "Check":672        """Ensure some defined values don't occur within a series.673        :param forbidden_values: The set of values which should not occur. May674            be any iterable.675        :param raise_warning: if True, check raises UserWarning instead of676            SchemaError on validation.677        :returns: :class:`Check` object678        .. note::679            Like :func:`Check.isin` this check operates on single characters if680            it is applied on strings. A string as paraforbidden_valuesmeter681            forbidden_values is understood as set of prohibited characters. Any682            string of length > 1 can't be in it by design.683        """684        # Turn forbidden_values into a set. Not only for performance but also685        # avoid issues with a mutable argument passed by reference which may be686        # changed from outside.687        try:688            forbidden_values = frozenset(forbidden_values)689        except TypeError as exc:690            raise ValueError(691                f"Argument forbidden_values must be iterable. 
Got {forbidden_values}"692            ) from exc693        def _notin(series: pd.Series) -> pd.Series:694            """Comparison function for check"""695            return ~series.isin(forbidden_values)696        return cls(697            _notin,698            name=cls.notin.__name__,699            error=f"notin({set(forbidden_values)})",700            **kwargs,701        )702    @classmethod703    @st.register_check_strategy(st.str_matches_strategy)704    @register_check_statistics(["pattern"])705    def str_matches(cls, pattern: str, **kwargs) -> "Check":706        """Ensure that string values match a regular expression.707        :param pattern: Regular expression pattern to use for matching708        :param kwargs: key-word arguments passed into the `Check` initializer.709        :returns: :class:`Check` object710        The behaviour is as of :func:`pandas.Series.str.match`.711        """712        # By compiling the regex we get the benefit of an early argument check713        try:714            regex = re.compile(pattern)715        except TypeError as exc:716            raise ValueError(717                f'pattern="{pattern}" cannot be compiled as regular expression'718            ) from exc719        def _match(series: pd.Series) -> pd.Series:720            """721            Check if all strings in the series match the regular expression.722            """723            return series.str.match(regex, na=False)724        return cls(725            _match,726            name=cls.str_matches.__name__,727            error=f"str_matches({regex})",728            **kwargs,729        )730    @classmethod731    @st.register_check_strategy(st.str_contains_strategy)732    @register_check_statistics(["pattern"])733    def str_contains(cls, pattern: str, **kwargs) -> "Check":734        """Ensure that a pattern can be found within each row.735        :param pattern: Regular expression pattern to use for searching736        :param kwargs: key-word arguments passed into the `Check` initializer.737        :returns: :class:`Check` object738        The behaviour is as of :func:`pandas.Series.str.contains`.739        """740        # By compiling the regex we get the benefit of an early argument check741        try:742            regex = re.compile(pattern)743        except TypeError as exc:744            raise ValueError(745                f'pattern="{pattern}" cannot be compiled as regular expression'746            ) from exc747        def _contains(series: pd.Series) -> pd.Series:748            """Check if a regex search is successful within each value"""749            return series.str.contains(regex, na=False)750        return cls(751            _contains,752            name=cls.str_contains.__name__,753            error=f"str_contains({regex})",754            **kwargs,755        )756    @classmethod757    @st.register_check_strategy(st.str_startswith_strategy)758    @register_check_statistics(["string"])759    def str_startswith(cls, string: str, **kwargs) -> "Check":760        """Ensure that all values start with a certain string.761        :param string: String all values should start with762        :param kwargs: key-word arguments passed into the `Check` initializer.763        :returns: :class:`Check` object764        """765        def _startswith(series: pd.Series) -> pd.Series:766            """Returns true only for strings starting with string"""767            return series.str.startswith(string, na=False)768        return cls(769            _startswith,770            
name=cls.str_startswith.__name__,771            error=f"str_startswith({string})",772            **kwargs,773        )774    @classmethod775    @st.register_check_strategy(st.str_endswith_strategy)776    @register_check_statistics(["string"])777    def str_endswith(cls, string: str, **kwargs) -> "Check":778        """Ensure that all values end with a certain string.779        :param string: String all values should end with780        :param kwargs: key-word arguments passed into the `Check` initializer.781        :returns: :class:`Check` object782        """783        def _endswith(series: pd.Series) -> pd.Series:784            """Returns true only for strings ending with string"""785            return series.str.endswith(string, na=False)786        return cls(787            _endswith,788            name=cls.str_endswith.__name__,789            error=f"str_endswith({string})",790            **kwargs,791        )792    @classmethod793    @st.register_check_strategy(st.str_length_strategy)794    @register_check_statistics(["min_value", "max_value"])795    def str_length(796        cls, min_value: int = None, max_value: int = None, **kwargs797    ) -> "Check":798        """Ensure that the length of strings is within a specified range.799        :param min_value: Minimum length of strings (default: no minimum)800        :param max_value: Maximum length of strings (default: no maximum)801        :param kwargs: key-word arguments passed into the `Check` initializer.802        :returns: :class:`Check` object803        """804        if min_value is None and max_value is None:805            raise ValueError(806                "At least a minimum or a maximum need to be specified. Got "807                "None."808            )809        if max_value is None:810            def _str_length(series: pd.Series) -> pd.Series:811                """Check for the minimum string length"""812                return series.str.len() >= min_value813        elif min_value is None:814            def _str_length(series: pd.Series) -> pd.Series:815                """Check for the maximum string length"""816                return series.str.len() <= max_value817        else:818            def _str_length(series: pd.Series) -> pd.Series:819                """Check for both, minimum and maximum string length"""820                return (series.str.len() <= max_value) & (821                    series.str.len() >= min_value822                )823        return cls(824            _str_length,825            name=cls.str_length.__name__,826            error=f"str_length({min_value}, {max_value})",827            **kwargs,...test_strategies.py
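The built-in constructors above (Check.in_range, Check.isin, Check.str_matches, and the rest) are normally attached to columns in a DataFrameSchema. A minimal usage sketch follows; the column names, bounds, regex, and sample data are illustrative and not taken from the snippet above:

import pandas as pd
import pandera as pa

# Schema combining several of the built-in checks defined above.
# Column names, bounds, and the regex are assumptions for illustration only.
schema = pa.DataFrameSchema(
    {
        "score": pa.Column(int, checks=pa.Check.in_range(0, 100)),
        "grade": pa.Column(str, checks=pa.Check.isin(["A", "B", "C"])),
        "code": pa.Column(str, checks=pa.Check.str_matches(r"[A-Z]{3}-[0-9]+")),
    }
)

df = pd.DataFrame(
    {
        "score": [88, 94, 72],
        "grade": ["A", "A", "B"],
        "code": ["ABC-1", "XYZ-22", "QRS-3"],
    }
)

# raises a SchemaError if any check fails
validated = schema.validate(df)

When a check fails, the failing rows are reported through the failure_cases field, which is exactly what the CheckResult plumbing in _CheckBase.__call__ computes.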
test_strategies.py
Source:test_strategies.py
1# pylint: disable=undefined-variable,redefined-outer-name,invalid-name,undefined-loop-variable,too-many-lines  # noqa2"""Unit tests for pandera data generating strategies."""3import datetime4import operator5import re6from typing import Any, Callable, Optional, Set7from unittest.mock import MagicMock8import numpy as np9import pandas as pd10import pytest11import pandera as pa12from pandera import strategies13from pandera.checks import _CheckBase, register_check_statistics14from pandera.dtypes import is_category, is_complex, is_float15from pandera.engines import pandas_engine16try:17    import hypothesis18    import hypothesis.extra.numpy as npst19    import hypothesis.strategies as st20except ImportError:21    HAS_HYPOTHESIS = False22    hypothesis = MagicMock()23    st = MagicMock()24else:25    HAS_HYPOTHESIS = True26UNSUPPORTED_DTYPE_CLS: Set[Any] = set(27    [28        pandas_engine.Interval,29        pandas_engine.Period,30        pandas_engine.Sparse,31    ]32)33SUPPORTED_DTYPES = set()34for data_type in pandas_engine.Engine.get_registered_dtypes():35    if (36        # valid hypothesis.strategies.floats <=6437        getattr(data_type, "bit_width", -1) > 6438        or is_category(data_type)39        or data_type in UNSUPPORTED_DTYPE_CLS40        or (41            pandas_engine.GEOPANDAS_INSTALLED42            and data_type == pandas_engine.Geometry43        )44    ):45        continue46    SUPPORTED_DTYPES.add(pandas_engine.Engine.dtype(data_type))47SUPPORTED_DTYPES.add(pandas_engine.Engine.dtype("datetime64[ns, UTC]"))48NUMERIC_DTYPES = [49    data_type for data_type in SUPPORTED_DTYPES if data_type.continuous50]51NULLABLE_DTYPES = [52    data_type53    for data_type in SUPPORTED_DTYPES54    if not is_complex(data_type)55    and not is_category(data_type)56    and not data_type == pandas_engine.Engine.dtype("object")57]58@pytest.mark.parametrize(59    "data_type",60    [61        pa.Category,62        pandas_engine.Interval(  # type: ignore # pylint:disable=unexpected-keyword-arg,no-value-for-parameter63            subtype=np.int6464        ),65    ],66)67def test_unsupported_pandas_dtype_strategy(data_type):68    """Test unsupported pandas dtype strategy raises error."""69    with pytest.raises(TypeError, match=r"is currently unsupported"):70        strategies.pandas_dtype_strategy(data_type)71@pytest.mark.parametrize("data_type", SUPPORTED_DTYPES)72@hypothesis.given(st.data())73@hypothesis.settings(74    suppress_health_check=[75        hypothesis.HealthCheck.too_slow,76        hypothesis.HealthCheck.data_too_large,77    ],78    max_examples=20,79)80def test_pandas_dtype_strategy(data_type, data):81    """Test that series can be constructed from pandas dtype."""82    strategy = strategies.pandas_dtype_strategy(data_type)83    example = data.draw(strategy)84    expected_type = strategies.to_numpy_dtype(data_type).type85    if isinstance(example, pd.Timestamp):86        example = example.to_numpy()87    assert example.dtype.type == expected_type88    chained_strategy = strategies.pandas_dtype_strategy(data_type, strategy)89    chained_example = data.draw(chained_strategy)90    if isinstance(chained_example, pd.Timestamp):91        chained_example = chained_example.to_numpy()92    assert chained_example.dtype.type == expected_type93@pytest.mark.parametrize("data_type", NUMERIC_DTYPES)94@hypothesis.given(st.data())95@hypothesis.settings(96    suppress_health_check=[hypothesis.HealthCheck.too_slow],97)98def test_check_strategy_continuous(data_type, data):99    """Test built-in check 
strategies can generate continuous data."""100    np_dtype = strategies.to_numpy_dtype(data_type)101    value = data.draw(102        npst.from_dtype(103            strategies.to_numpy_dtype(data_type),104            allow_nan=False,105            allow_infinity=False,106        )107    )108    # don't overstep bounds of representation109    hypothesis.assume(np.finfo(np_dtype).min < value < np.finfo(np_dtype).max)110    assert data.draw(strategies.ne_strategy(data_type, value=value)) != value111    assert data.draw(strategies.eq_strategy(data_type, value=value)) == value112    assert (113        data.draw(strategies.gt_strategy(data_type, min_value=value)) > value114    )115    assert (116        data.draw(strategies.ge_strategy(data_type, min_value=value)) >= value117    )118    assert (119        data.draw(strategies.lt_strategy(data_type, max_value=value)) < value120    )121    assert (122        data.draw(strategies.le_strategy(data_type, max_value=value)) <= value123    )124def value_ranges(data_type: pa.DataType):125    """Strategy to generate value range based on PandasDtype"""126    kwargs = dict(127        allow_nan=False,128        allow_infinity=False,129        exclude_min=False,130        exclude_max=False,131    )132    return (133        st.tuples(134            strategies.pandas_dtype_strategy(135                data_type, strategy=None, **kwargs136            ),137            strategies.pandas_dtype_strategy(138                data_type, strategy=None, **kwargs139            ),140        )141        .map(sorted)142        .filter(lambda x: x[0] < x[1])143    )144@pytest.mark.parametrize("data_type", NUMERIC_DTYPES)145@pytest.mark.parametrize(146    "strat_fn, arg_name, base_st_type, compare_op",147    [148        [strategies.ne_strategy, "value", "type", operator.ne],149        [strategies.eq_strategy, "value", "just", operator.eq],150        [strategies.gt_strategy, "min_value", "limit", operator.gt],151        [strategies.ge_strategy, "min_value", "limit", operator.ge],152        [strategies.lt_strategy, "max_value", "limit", operator.lt],153        [strategies.le_strategy, "max_value", "limit", operator.le],154    ],155)156@hypothesis.given(st.data())157@hypothesis.settings(158    suppress_health_check=[hypothesis.HealthCheck.too_slow],159)160def test_check_strategy_chained_continuous(161    data_type, strat_fn, arg_name, base_st_type, compare_op, data162):163    """164    Test built-in check strategies can generate continuous data building off165    of a parent strategy.166    """167    min_value, max_value = data.draw(value_ranges(data_type))168    hypothesis.assume(min_value < max_value)169    value = min_value170    base_st = strategies.pandas_dtype_strategy(171        data_type,172        min_value=min_value,173        max_value=max_value,174        allow_nan=False,175        allow_infinity=False,176    )177    if base_st_type == "type":178        assert_base_st = base_st179    elif base_st_type == "just":180        assert_base_st = st.just(value)181    elif base_st_type == "limit":182        assert_base_st = strategies.pandas_dtype_strategy(183            data_type,184            min_value=min_value,185            max_value=max_value,186            allow_nan=False,187            allow_infinity=False,188        )189    else:190        raise RuntimeError(f"base_st_type {base_st_type} not recognized")191    local_vars = locals()192    assert_value = local_vars[arg_name]193    example = data.draw(194        strat_fn(data_type, assert_base_st, **{arg_name: 
assert_value})195    )196    assert compare_op(example, assert_value)197@pytest.mark.parametrize("data_type", NUMERIC_DTYPES)198@pytest.mark.parametrize("chained", [True, False])199@hypothesis.given(st.data())200@hypothesis.settings(201    suppress_health_check=[hypothesis.HealthCheck.too_slow],202)203def test_in_range_strategy(data_type, chained, data):204    """Test the built-in in-range strategy can correctly generate data."""205    min_value, max_value = data.draw(value_ranges(data_type))206    hypothesis.assume(min_value < max_value)207    base_st_in_range = None208    if chained:209        if is_float(data_type):210            base_st_kwargs = {211                "exclude_min": False,212                "exclude_max": False,213            }214        else:215            base_st_kwargs = {}216        # constraining the strategy this way makes testing more efficient217        base_st_in_range = strategies.pandas_dtype_strategy(218            data_type,219            min_value=min_value,220            max_value=max_value,221            **base_st_kwargs,  # type: ignore[arg-type]222        )223    strat = strategies.in_range_strategy(224        data_type,225        base_st_in_range,226        min_value=min_value,227        max_value=max_value,228    )229    assert min_value <= data.draw(strat) <= max_value230@pytest.mark.parametrize(231    "data_type",232    [data_type for data_type in SUPPORTED_DTYPES if data_type.continuous],233)234@pytest.mark.parametrize("chained", [True, False])235@hypothesis.given(st.data())236@hypothesis.settings(237    suppress_health_check=[hypothesis.HealthCheck.too_slow],238)239def test_isin_notin_strategies(data_type, chained, data):240    """Test built-in check strategies that rely on discrete values."""241    value_st = strategies.pandas_dtype_strategy(242        data_type,243        allow_nan=False,244        allow_infinity=False,245        exclude_min=False,246        exclude_max=False,247    )248    values = [data.draw(value_st) for _ in range(10)]249    isin_base_st = None250    notin_base_st = None251    if chained:252        base_values = values + [data.draw(value_st) for _ in range(10)]253        isin_base_st = strategies.isin_strategy(254            data_type, allowed_values=base_values255        )256        notin_base_st = strategies.notin_strategy(257            data_type, forbidden_values=base_values258        )259    isin_st = strategies.isin_strategy(260        data_type, isin_base_st, allowed_values=values261    )262    notin_st = strategies.notin_strategy(263        data_type, notin_base_st, forbidden_values=values264    )265    assert data.draw(isin_st) in values266    assert data.draw(notin_st) not in values267@pytest.mark.parametrize(268    "str_strat, pattern_fn",269    [270        [271            strategies.str_matches_strategy,272            lambda patt: f"^{patt}$",273        ],274        [strategies.str_contains_strategy, None],275        [strategies.str_startswith_strategy, None],276        [strategies.str_endswith_strategy, None],277    ],278)279@pytest.mark.parametrize("chained", [True, False])280@hypothesis.given(st.data(), st.text())281def test_str_pattern_checks(282    str_strat: Callable,283    pattern_fn: Optional[Callable[..., str]],284    chained: bool,285    data,286    pattern,287) -> None:288    """Test built-in check strategies for string pattern checks."""289    try:290        re.compile(pattern)291        re_compiles = True292    except re.error:293        re_compiles = False294    hypothesis.assume(re_compiles)295    
pattern = pattern if pattern_fn is None else pattern_fn(pattern)296    base_st = None297    if chained:298        try:299            base_st = str_strat(pa.String, pattern=pattern)300        except TypeError:301            base_st = str_strat(pa.String, string=pattern)302    try:303        st = str_strat(pa.String, base_st, pattern=pattern)304    except TypeError:305        st = str_strat(pa.String, base_st, string=pattern)306    example = data.draw(st)307    assert re.search(pattern, example)308@pytest.mark.parametrize("chained", [True, False])309@hypothesis.given(310    st.data(),311    (312        st.tuples(313            st.integers(min_value=0, max_value=100),314            st.integers(min_value=0, max_value=100),315        )316        .map(sorted)  # type: ignore[arg-type]317        .filter(lambda x: x[0] < x[1])  # type: ignore318    ),319)320@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow])321def test_str_length_checks(chained, data, value_range):322    """Test built-in check strategies for string length."""323    min_value, max_value = value_range324    base_st = None325    if chained:326        base_st = strategies.str_length_strategy(327            pa.String,328            min_value=max(0, min_value - 5),329            max_value=max_value + 5,330        )331    str_length_st = strategies.str_length_strategy(332        pa.String, base_st, min_value=min_value, max_value=max_value333    )334    example = data.draw(str_length_st)335    assert min_value <= len(example) <= max_value336@hypothesis.given(st.data())337def test_register_check_strategy(data) -> None:338    """Test registering check strategy on a custom check."""339    # pylint: disable=unused-argument340    def custom_eq_strategy(341        pandas_dtype: pa.DataType,342        strategy: st.SearchStrategy = None,343        *,344        value: Any,345    ):346        return st.just(value).map(strategies.to_numpy_dtype(pandas_dtype).type)347    # pylint: disable=no-member348    class CustomCheck(_CheckBase):349        """Custom check class."""350        @classmethod351        @strategies.register_check_strategy(custom_eq_strategy)352        @register_check_statistics(["value"])353        def custom_equals(cls, value, **kwargs) -> "CustomCheck":354            """Define a built-in check."""355            def _custom_equals(series: pd.Series) -> pd.Series:356                """Comparison function for check"""357                return series == value358            return cls(359                _custom_equals,360                name=cls.custom_equals.__name__,361                error=f"equal_to({value})",362                **kwargs,363            )364    check = CustomCheck.custom_equals(100)365    result = data.draw(check.strategy(pa.Int()))366    assert result == 100367def test_register_check_strategy_exception() -> None:368    """Check method needs statistics attr to register a strategy."""369    def custom_strat() -> None:370        pass371    class CustomCheck(_CheckBase):372        """Custom check class."""373        @classmethod374        @strategies.register_check_strategy(custom_strat)  # type: ignore[arg-type]375        # mypy correctly identifies the error376        def custom_check(cls, **kwargs) -> "CustomCheck":377            """Built-in check with no statistics."""378            def _custom_check(series: pd.Series) -> pd.Series:379                """Some check function."""380                return series381            return cls(382                _custom_check,383                
name=cls.custom_check.__name__,384                **kwargs,385            )386    with pytest.raises(387        AttributeError,388        match="check object doesn't have a defined statistics property",389    ):390        CustomCheck.custom_check()391@hypothesis.given(st.data())392@hypothesis.settings(393    suppress_health_check=[hypothesis.HealthCheck.too_slow],394)395def test_series_strategy(data) -> None:396    """Test SeriesSchema strategy."""397    series_schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0))398    series_schema(data.draw(series_schema.strategy()))399def test_series_example() -> None:400    """Test SeriesSchema example method generate examples that pass."""401    series_schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0))402    for _ in range(10):403        series_schema(series_schema.example())404@hypothesis.given(st.data())405@hypothesis.settings(406    suppress_health_check=[hypothesis.HealthCheck.too_slow],407)408def test_column_strategy(data) -> None:409    """Test Column schema strategy."""410    column_schema = pa.Column(pa.Int(), pa.Check.gt(0), name="column")411    column_schema(data.draw(column_schema.strategy()))412def test_column_example():413    """Test Column schema example method generate examples that pass."""414    column_schema = pa.Column(pa.Int(), pa.Check.gt(0), name="column")415    for _ in range(10):416        column_schema(column_schema.example())417@pytest.mark.parametrize("data_type", SUPPORTED_DTYPES)418@pytest.mark.parametrize("size", [None, 0, 1, 3, 5])419@hypothesis.given(st.data())420@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow])421def test_dataframe_strategy(data_type, size, data):422    """Test DataFrameSchema strategy."""423    dataframe_schema = pa.DataFrameSchema(424        {f"{data_type}_col": pa.Column(data_type)}425    )426    df_sample = data.draw(dataframe_schema.strategy(size=size))427    if size == 0:428        assert df_sample.empty429    elif size is None:430        assert df_sample.empty or isinstance(431            dataframe_schema(df_sample), pd.DataFrame432        )433    else:434        assert isinstance(dataframe_schema(df_sample), pd.DataFrame)435    with pytest.raises(pa.errors.BaseStrategyOnlyError):436        strategies.dataframe_strategy(437            data_type, strategies.pandas_dtype_strategy(data_type)438        )439@hypothesis.given(st.data())440@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow])441def test_dataframe_example(data) -> None:442    """Test DataFrameSchema example method generate examples that pass."""443    schema = pa.DataFrameSchema({"column": pa.Column(int, pa.Check.gt(0))})444    df_sample = data.draw(schema.strategy(size=10))445    schema(df_sample)446@pytest.mark.parametrize("size", [3, 5, 10])447@hypothesis.given(st.data())448@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow])449def test_dataframe_unique(size, data) -> None:450    """Test that DataFrameSchemas with unique columns are actually unique."""451    schema = pa.DataFrameSchema(452        {453            "col1": pa.Column(int),454            "col2": pa.Column(float),455            "col3": pa.Column(str),456            "col4": pa.Column(int),457        },458        unique=["col1", "col2", "col3"],459    )460    df_sample = data.draw(schema.strategy(size=size))461    schema(df_sample)462@pytest.mark.parametrize(463    "regex",464    [465        "col_[0-9]{1,4}",466        "[a-zA-Z]+_foobar",467        "[a-z]+_[0-9]+_[a-z]+",468    
],469)470@hypothesis.given(st.data(), st.integers(min_value=-5, max_value=5))471@hypothesis.settings(472    suppress_health_check=[hypothesis.HealthCheck.too_slow],473)474def test_dataframe_with_regex(regex: str, data, n_regex_columns: int) -> None:475    """Test DataFrameSchema strategy with regex columns"""476    dataframe_schema = pa.DataFrameSchema({regex: pa.Column(int, regex=True)})477    if n_regex_columns < 1:478        with pytest.raises(ValueError):479            dataframe_schema.strategy(size=5, n_regex_columns=n_regex_columns)480    else:481        df = dataframe_schema(482            data.draw(483                dataframe_schema.strategy(484                    size=5, n_regex_columns=n_regex_columns485                )486            )487        )488        assert df.shape[1] == n_regex_columns489@pytest.mark.parametrize("data_type", NUMERIC_DTYPES)490@hypothesis.settings(491    suppress_health_check=[hypothesis.HealthCheck.too_slow],492)493@hypothesis.given(st.data())494def test_dataframe_checks(data_type, data):495    """Test dataframe strategy with checks defined at the dataframe level."""496    min_value, max_value = data.draw(value_ranges(data_type))497    dataframe_schema = pa.DataFrameSchema(498        {f"{data_type}_col": pa.Column(data_type) for _ in range(5)},499        checks=pa.Check.in_range(min_value, max_value),500    )501    strat = dataframe_schema.strategy(size=5)502    example = data.draw(strat)503    dataframe_schema(example)504@pytest.mark.parametrize(505    "data_type", [pa.Int(), pa.Float, pa.String, pa.DateTime]506)507@hypothesis.given(st.data())508@hypothesis.settings(509    suppress_health_check=[hypothesis.HealthCheck.too_slow],510)511def test_dataframe_strategy_with_indexes(data_type, data):512    """Test dataframe strategy with index and multiindex components."""513    dataframe_schema_index = pa.DataFrameSchema(index=pa.Index(data_type))514    dataframe_schema_multiindex = pa.DataFrameSchema(515        index=pa.MultiIndex(516            [pa.Index(data_type, name=f"index{i}") for i in range(3)]517        )518    )519    dataframe_schema_index(data.draw(dataframe_schema_index.strategy(size=10)))520    dataframe_schema_multiindex(521        data.draw(dataframe_schema_multiindex.strategy(size=10))522    )523@hypothesis.given(st.data())524@hypothesis.settings(525    suppress_health_check=[hypothesis.HealthCheck.too_slow],526)527def test_index_strategy(data) -> None:528    """Test Index schema component strategy."""529    data_type = pa.Int()530    index_schema = pa.Index(data_type, unique=True, name="index")531    strat = index_schema.strategy(size=10)532    example = data.draw(strat)533    assert (~example.duplicated()).all()534    actual_data_type = pandas_engine.Engine.dtype(example.dtype)535    assert data_type.check(actual_data_type)536    index_schema(pd.DataFrame(index=example))537def test_index_example() -> None:538    """539    Test Index schema component example method generates examples that pass.540    """541    data_type = pa.Int()542    index_schema = pa.Index(data_type, unique=True)543    for _ in range(10):544        index_schema(pd.DataFrame(index=index_schema.example()))545@hypothesis.given(st.data())546@hypothesis.settings(547    suppress_health_check=[hypothesis.HealthCheck.too_slow],548)549def test_multiindex_strategy(data) -> None:550    """Test MultiIndex schema component strategy."""551    data_type = pa.Float()552    multiindex = pa.MultiIndex(553        indexes=[554            pa.Index(data_type, unique=True, 
name="level_0"),555            pa.Index(data_type, nullable=True),556            pa.Index(data_type),557        ]558    )559    strat = multiindex.strategy(size=10)560    example = data.draw(strat)561    for i in range(example.nlevels):562        actual_data_type = pandas_engine.Engine.dtype(563            example.get_level_values(i).dtype564        )565        assert data_type.check(actual_data_type)566    with pytest.raises(pa.errors.BaseStrategyOnlyError):567        strategies.multiindex_strategy(568            data_type, strategies.pandas_dtype_strategy(data_type)569        )570def test_multiindex_example() -> None:571    """572    Test MultiIndex schema component example method generates examples that573    pass.574    """575    data_type = pa.Float()576    multiindex = pa.MultiIndex(577        indexes=[578            pa.Index(data_type, unique=True, name="level_0"),579            pa.Index(data_type, nullable=True),580            pa.Index(data_type),581        ]582    )583    for _ in range(10):584        example = multiindex.example()585        multiindex(pd.DataFrame(index=example))586@pytest.mark.parametrize("data_type", NULLABLE_DTYPES)587@hypothesis.given(st.data())588def test_field_element_strategy(data_type, data):589    """Test strategy for generating elements in columns/indexes."""590    strategy = strategies.field_element_strategy(data_type)591    element = data.draw(strategy)592    expected_type = strategies.to_numpy_dtype(data_type).type593    assert element.dtype.type == expected_type594    with pytest.raises(pa.errors.BaseStrategyOnlyError):595        strategies.field_element_strategy(596            data_type, strategies.pandas_dtype_strategy(data_type)597        )598@pytest.mark.parametrize("data_type", NULLABLE_DTYPES)599@pytest.mark.parametrize(600    "field_strategy",601    [strategies.index_strategy, strategies.series_strategy],602)603@pytest.mark.parametrize("nullable", [True, False])604@hypothesis.given(st.data())605@hypothesis.settings(606    suppress_health_check=[hypothesis.HealthCheck.too_slow],607)608def test_check_nullable_field_strategy(609    data_type, field_strategy, nullable, data610):611    """Test strategies for generating nullable column/index data."""612    size = 5613    strat = field_strategy(data_type, nullable=nullable, size=size)614    example = data.draw(strat)615    if nullable:616        assert example.isna().sum() >= 0617    else:618        assert example.notna().all()619@pytest.mark.parametrize("data_type", NULLABLE_DTYPES)620@pytest.mark.parametrize("nullable", [True, False])621@hypothesis.given(st.data())622@hypothesis.settings(623    suppress_health_check=[hypothesis.HealthCheck.too_slow],624)625def test_check_nullable_dataframe_strategy(data_type, nullable, data):626    """Test strategies for generating nullable DataFrame data."""627    size = 5628    # pylint: disable=no-value-for-parameter629    strat = strategies.dataframe_strategy(630        columns={"col": pa.Column(data_type, nullable=nullable, name="col")},631        size=size,632    )633    example = data.draw(strat)634    if nullable:635        assert example.isna().sum(axis=None).item() >= 0636    else:637        assert example.notna().all(axis=None)638@pytest.mark.parametrize(639    "schema, warning",640    [641        [642            pa.SeriesSchema(643                pa.Int(),644                checks=[645                    pa.Check(lambda x: x > 0, element_wise=True),646                    pa.Check(lambda x: x > -10, element_wise=True),647                ],648            
),649            "Element-wise",650        ],651        [652            pa.SeriesSchema(653                pa.Int(),654                checks=[655                    pa.Check(lambda s: s > -10000),656                    pa.Check(lambda s: s > -9999),657                ],658            ),659            "Vectorized",660        ],661    ],662)663@hypothesis.settings(664    suppress_health_check=[665        hypothesis.HealthCheck.filter_too_much,666        hypothesis.HealthCheck.too_slow,667    ],668)669@hypothesis.given(st.data())670def test_series_strategy_undefined_check_strategy(671    schema: pa.SeriesSchema, warning: str, data672) -> None:673    """Test case where series check strategy is undefined."""674    with pytest.warns(675        UserWarning, match=f"{warning} check doesn't have a defined strategy"676    ):677        strat = schema.strategy(size=5)678    example = data.draw(strat)679    schema(example)680@pytest.mark.parametrize(681    "schema, warning",682    [683        [684            pa.DataFrameSchema(685                columns={"column": pa.Column(int)},686                checks=[687                    pa.Check(lambda x: x > 0, element_wise=True),688                    pa.Check(lambda x: x > -10, element_wise=True),689                ],690            ),691            "Element-wise",692        ],693        [694            pa.DataFrameSchema(695                columns={696                    "column": pa.Column(697                        int,698                        checks=[699                            pa.Check(lambda s: s > -10000),700                            pa.Check(lambda s: s > -9999),701                        ],702                    )703                },704            ),705            "Column",706        ],707        # schema with regex column and custom undefined strategy708        [709            pa.DataFrameSchema(710                columns={711                    "[0-9]+": pa.Column(712                        int,713                        checks=[pa.Check(lambda s: True)],714                        regex=True,715                    )716                },717            ),718            "Column",719        ],720        [721            pa.DataFrameSchema(722                columns={"column": pa.Column(int)},723                checks=[724                    pa.Check(lambda s: s > -10000),725                    pa.Check(lambda s: s > -9999),726                ],727            ),728            "Dataframe",729        ],730    ],731)732@hypothesis.settings(733    suppress_health_check=[734        hypothesis.HealthCheck.filter_too_much,735        hypothesis.HealthCheck.too_slow,736    ],737)738@hypothesis.given(st.data())739def test_dataframe_strategy_undefined_check_strategy(740    schema: pa.DataFrameSchema, warning: str, data741) -> None:742    """Test case where dataframe check strategy is undefined."""743    strat = schema.strategy(size=5)744    with pytest.warns(745        UserWarning, match=f"{warning} check doesn't have a defined strategy"746    ):747        example = data.draw(strat)748    schema(example)749def test_unsatisfiable_checks():750    """Test that unsatisfiable checks raise an exception."""751    schema = pa.DataFrameSchema(752        columns={753            "col1": pa.Column(int, checks=[pa.Check.gt(0), pa.Check.lt(0)])754        }755    )756    for _ in range(5):757        with pytest.raises(hypothesis.errors.Unsatisfiable):758            schema.example(size=10)759class Schema(pa.SchemaModel):760    """Schema model for strategy testing."""761  
    col1: pa.typing.Series[int]
    col2: pa.typing.Series[float]
    col3: pa.typing.Series[str]

@hypothesis.given(st.data())
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.too_slow],
)
def test_schema_model_strategy(data) -> None:
    """Test that a strategy can be created from a SchemaModel."""
    strat = Schema.strategy(size=10)
    sample_data = data.draw(strat)
    Schema.validate(sample_data)

@hypothesis.given(st.data())
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.too_slow],
)
def test_schema_model_strategy_df_check(data) -> None:
    """Test that a schema with custom checks produces valid data."""

    class SchemaWithDFCheck(Schema):
        """Schema with a custom dataframe-level check with no strategy."""

        # pylint:disable=no-self-use
        @pa.dataframe_check
        @classmethod
        def non_empty(cls, df: pd.DataFrame) -> bool:
            """Checks that the dataframe is not empty."""
            return not df.empty

    strat = SchemaWithDFCheck.strategy(size=10)
    sample_data = data.draw(strat)
    Schema.validate(sample_data)

def test_schema_model_example() -> None:
    """Test that examples can be drawn from a SchemaModel."""
    sample_data = Schema.example(size=10)
    Schema.validate(sample_data)

def test_schema_component_with_no_pdtype() -> None:
    """
    Test that SchemaDefinitionError is raised when trying to create a strategy
    where the pandas_dtype property is not specified.
    """
    for schema_component_strategy in [
        strategies.column_strategy,
        strategies.index_strategy,
    ]:
        with pytest.raises(pa.errors.SchemaDefinitionError):
            schema_component_strategy(pandera_dtype=None)  # type: ignore

@pytest.mark.parametrize(
    "check_arg", [pd.Timestamp("2006-01-01"), np.datetime64("2006-01-01")]
)
@hypothesis.given(st.data())
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.too_slow],
)
def test_datetime_example(check_arg, data) -> None:
    """Test that the Column schema example method generates examples of
    timezone-naive datetimes that pass."""
    for checks in [
        pa.Check.le(check_arg),
        pa.Check.ge(check_arg),
        pa.Check.eq(check_arg),
        pa.Check.isin([check_arg]),
    ]:
        column_schema = pa.Column(
            "datetime", checks=checks, name="test_datetime"
        )
        column_schema(data.draw(column_schema.strategy()))

@pytest.mark.parametrize(
    "dtype",
    (
        pd.DatetimeTZDtype(tz="UTC"),
        pd.DatetimeTZDtype(tz="dateutil/US/Central"),
    ),
)
@pytest.mark.parametrize(
    "check_arg",
    [
        pd.Timestamp("2006-01-01", tz="CET"),
        pd.Timestamp("2006-01-01", tz="UTC"),
    ],
)
@hypothesis.given(st.data())
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.too_slow],
)
def test_datetime_tz_example(dtype, check_arg, data) -> None:
    """Test that the Column schema example method generates examples of
    timezone-aware datetimes that pass."""
    for checks in [
        pa.Check.le(check_arg),
        pa.Check.ge(check_arg),
        pa.Check.eq(check_arg),
        pa.Check.isin([check_arg]),
    ]:
        column_schema = pa.Column(
            dtype,
            checks=checks,
            name="test_datetime_tz",
        )
        column_schema(data.draw(column_schema.strategy()))

@pytest.mark.parametrize(
    "dtype",
    (pd.Timedelta,),
)
@pytest.mark.parametrize(
    "check_arg",
    [
        # nanoseconds
        pd.Timedelta(int(1e9), unit="nanoseconds"),
        np.timedelta64(int(1e9), "ns"),
        # microseconds
        pd.Timedelta(int(1e6), unit="microseconds"),
        datetime.timedelta(microseconds=int(1e6)),
        # milliseconds
        pd.Timedelta(int(1e3), unit="milliseconds"),
        np.timedelta64(int(1e3), "ms"),
        datetime.timedelta(milliseconds=int(1e3)),
        # seconds
        pd.Timedelta(1, unit="s"),
        np.timedelta64(1, "s"),
        datetime.timedelta(seconds=1),
        # minutes
        pd.Timedelta(1, unit="m"),
        np.timedelta64(1, "m"),
        datetime.timedelta(minutes=1),
        # hours
        pd.Timedelta(1, unit="h"),
        np.timedelta64(1, "h"),
        datetime.timedelta(hours=1),
        # days
        pd.Timedelta(1, unit="day"),
        np.timedelta64(1, "D"),
        datetime.timedelta(days=1),
        # weeks
        pd.Timedelta(1, unit="W"),
        np.timedelta64(1, "W"),
        datetime.timedelta(weeks=1),
    ],
)
@hypothesis.given(st.data())
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.too_slow],
)
def test_timedelta(dtype, check_arg, data):
    """
    Test that the Column schema example method generates examples of
    timedeltas that pass.
    """
    for checks in [
        pa.Check.le(check_arg),
        pa.Check.ge(check_arg),
        pa.Check.eq(check_arg),
        pa.Check.isin([check_arg]),
        pa.Check.in_range(check_arg, check_arg + check_arg),
    ]:
        column_schema = pa.Column(
            dtype,
            checks=checks,
            name="test_datetime_tz",
        )
        column_schema(data.draw(column_schema.strategy()))

@pytest.mark.parametrize("dtype", [int, float, str])
@hypothesis.given(st.data())
@hypothesis.settings(
    suppress_health_check=[hypothesis.HealthCheck.too_slow],
)
def test_empty_nullable_schema(dtype, data):
    """Test that empty nullable schema strategy draws empty examples."""
    schema = pa.DataFrameSchema({"myval": pa.Column(dtype, nullable=True)})
    ...
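The tests above exercise pandera's hypothesis-backed data-synthesis API. As a minimal illustrative sketch (not part of the test module, and assuming pandera is installed with its hypothesis-based strategies extra), the same strategy()/example() methods can be used to generate valid data directly from a schema; the schema and column names below are hypothetical:

import pandera as pa

# Hypothetical schema for demonstration purposes only.
demo_schema = pa.DataFrameSchema(
    {
        "user_id": pa.Column(int, pa.Check.ge(0)),
        "score": pa.Column(float, pa.Check.in_range(0.0, 1.0)),
    }
)

# Draw a small synthetic DataFrame that satisfies the schema, then
# round-trip it through validation; a schema-level strategy can also be
# passed to @hypothesis.given, as the tests above do.
synthetic_df = demo_schema.example(size=3)
demo_schema.validate(synthetic_df)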
