# Source code for lale.type_checking

# Copyright 2019 IBM Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Lale uses `JSON Schema`_ to check machine-learning pipelines for correct types.

In general, there are two kinds of checks. The first is an instance
check (`v: s`), which checks whether a JSON value v is valid for a
schema s. The second is a subschema_ check (`s <: t`), which checks
whether one schema s is a subschema of another schema t.

Besides regular JSON values, Lale also supports certain JSON-like
values. For example, a ``np.ndarray`` of numbers is treated like a
JSON array of arrays of numbers. Furthermore, Lale supports an 'Any'
type for which all instance and subschema checks on the left as well
as the right side succeed. This is specified using ``{'laleType': 'Any'}``.

.. _`JSON Schema`: https://json-schema.org/

.. _subschema:
"""

import functools
import inspect
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple, overload

import jsonschema
import jsonschema.exceptions
import jsonschema.validators
import jsonsubschema
import numpy as np
import numpy.random
import sklearn.base

import lale.datasets.data_schemas
import lale.expressions
import lale.helpers
import lale.operators

# Type alias used throughout this module for JSON objects (such as JSON
# schemas) represented as Python dicts with string keys.
JSON_TYPE = Dict[str, Any]

def _validate_lale_type(
    validator, laleType, instance, schema
):  # pylint:disable=unused-argument
    """Validator for Lale's custom ``laleType`` JSON Schema keyword.

    This function is registered as the handler for the ``laleType`` keyword
    via ``jsonschema.validators.extend`` in ``_lale_validator``. It is a
    generator: it yields one ``ValidationError`` when ``instance`` does not
    match ``laleType`` and yields nothing otherwise.

    Parameters
    ----------
    validator:
        The active jsonschema validator (unused).
    laleType: str
        Value of the ``laleType`` keyword, one of ``"Any"``, ``"callable"``,
        ``"operator"``, ``"expression"``, ``"numpy.random.RandomState"`` or
        ``"CrossvalGenerator"``.
    instance:
        The value being checked.
    schema:
        The enclosing schema (unused).

    Yields
    ------
    jsonschema.exceptions.ValidationError
        When ``instance`` does not match ``laleType``. Unrecognized
        ``laleType`` values are accepted without error.
    """
    if laleType == "Any":
        # 'Any' accepts every value.
        return
    if laleType == "callable":
        valid = callable(instance)
    elif laleType == "operator":
        # Accept operator/estimator instances as well as sklearn estimator
        # classes. The isclass guard is required because issubclass raises
        # TypeError when its first argument is not a class.
        valid = isinstance(
            instance, (lale.operators.Operator, sklearn.base.BaseEstimator)
        ) or (
            inspect.isclass(instance)
            and issubclass(instance, sklearn.base.BaseEstimator)
        )
    elif laleType == "expression":
        valid = isinstance(instance, lale.expressions.Expr)
    elif laleType == "numpy.random.RandomState":
        valid = isinstance(instance, numpy.random.RandomState)
    elif laleType == "CrossvalGenerator":
        # Either an object with a 'split' method or any Iterable
        # (presumably of precomputed splits).
        valid = hasattr(instance, "split") or isinstance(instance, Iterable)
    else:
        # Unknown laleType values are not rejected.
        valid = True
    if not valid:
        yield jsonschema.exceptions.ValidationError(
            f"expected {laleType}, got {type(instance)}"
        )

def _is_extended_boolean(checker, instance):
    """Type-checker predicate for the JSON ``boolean`` type.

    Accepts NumPy booleans (``np.bool_``) in addition to Python ``bool``.
    ``checker`` is required by the type-checker calling convention and is
    not used.
    """
    boolean_types = (bool, np.bool_)
    return isinstance(instance, boolean_types)

# Draft-04 validator class extended with Lale's custom "laleType" keyword and
# with a "boolean" type that also accepts NumPy booleans.
_lale_validator = jsonschema.validators.extend(
    jsonschema.Draft4Validator,
    validators={"laleType": _validate_lale_type},
    type_checker=jsonschema.Draft4Validator.TYPE_CHECKER.redefine(
        "boolean", _is_extended_boolean
    ),
)

def always_validate_schema(value: Any, schema: JSON_TYPE, subsample_array: bool = True):
    """Validate that the value is an instance of the schema.

    Unlike ``validate_schema_directly``, this ignores the
    ``disable_hyperparams_schema_validation`` setting.

    Parameters
    ----------
    value: JSON (int, float, str, list, dict) or JSON-like (tuple, np.ndarray, pd.DataFrame ...).
        Left-hand side of instance check.

    schema: JSON schema
        Right-hand side of instance check.

    subsample_array: bool
        Speed up checking by doing only partial conversion to JSON.

    Raises
    ------
    jsonschema.ValidationError
        The value was invalid for the schema.
    """
    json_value = lale.helpers.data_to_json(value, subsample_array)
    sch: Any = lale.helpers.data_to_json(schema, False)
    try:
        # Fast path: validate without first checking the schema itself.
        validator = _lale_validator(sch)
        validator.validate(json_value)
    except Exception:
        # Slow path, taken on any failure: jsonschema.validate additionally
        # checks the schema against its meta-schema and reports the
        # best-matching error, giving a more helpful exception.
        jsonschema.validate(json_value, sch, _lale_validator)


def validate_schema_directly(
    value: Any, schema: JSON_TYPE, subsample_array: bool = True
):
    """Validate that the value is an instance of the schema.

    The check is skipped when ``lale.settings.disable_hyperparams_schema_validation``
    is set.

    Parameters
    ----------
    value: JSON (int, float, str, list, dict) or JSON-like (tuple, np.ndarray, pd.DataFrame ...).
        Left-hand side of instance check.

    schema: JSON schema
        Right-hand side of instance check.

    subsample_array: bool
        Speed up checking by doing only partial conversion to JSON.

    Returns
    -------
    True when validation is disabled, otherwise None.

    Raises
    ------
    jsonschema.ValidationError
        The value was invalid for the schema.
    """
    # Imported inside the function, so the setting's current value is read
    # on every call.
    from lale.settings import disable_hyperparams_schema_validation

    if disable_hyperparams_schema_validation:
        return True  # if schema validation is disabled, always return as valid

    return always_validate_schema(value, schema, subsample_array=subsample_array)


# "$schema" URL of the JSON Schema draft-04 meta-schema. validate_is_schema
# requires schemas that declare "$schema" to use exactly this value.
_JSON_META_SCHEMA_URL = "http://json-schema.org/draft-04/schema#"


def _json_meta_schema() -> Dict[str, Any]:
    """Return the JSON Schema draft-04 meta-schema bundled with jsonschema."""
    return jsonschema.Draft4Validator.META_SCHEMA


# Validator that checks whether a value is itself a well-formed draft-04 schema.
_validator = jsonschema.Draft4Validator(_json_meta_schema())


def validate_is_schema(value: Dict[str, Any]):
    """Check that ``value`` is a valid JSON Schema (draft-04).

    Does nothing when ``lale.settings.disable_hyperparams_schema_validation``
    is set.

    Raises
    ------
    AssertionError
        ``value`` declares a ``$schema`` other than ``_JSON_META_SCHEMA_URL``.
    jsonschema.ValidationError
        ``value`` does not conform to the draft-04 meta-schema.
    """
    # only checking hyperparams schema validation flag because it is likely to be true and this call is cheap.
    from lale.settings import disable_hyperparams_schema_validation

    if disable_hyperparams_schema_validation:
        return
    if "$schema" in value and value["$schema"] != _JSON_META_SCHEMA_URL:
        # Raised explicitly rather than with `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(
            f"unexpected $schema {value['$schema']!r}, expected {_JSON_META_SCHEMA_URL!r}"
        )
    _validator.validate(value)


def is_schema(value) -> bool:
    """Return True if ``value`` is a dict that is a valid draft-04 JSON Schema.

    Non-dict values are never schemas. Dicts are checked against the
    draft-04 meta-schema, and a validation failure yields False.
    """
    if not isinstance(value, dict):
        return False
    try:
        _validator.validate(value)
    except jsonschema.ValidationError:
        return False
    return True


def _json_replace(subject, old, new):
    """Return ``subject`` with every occurrence of ``old`` replaced by ``new``.

    The function recurses into lists, tuples and dict values. When ``old`` is
    a dict, any dict in ``subject`` that contains all of ``old``'s key/value
    pairs (possibly alongside other keys) is replaced by ``new`` as a whole.

    Sub-structures that contain no replacement are returned as the original
    objects, not copies. If nothing changed, ``subject`` itself is returned.
    """
    if subject == old:
        return new
    if isinstance(subject, list):
        new_list = [_json_replace(s, old, new) for s in subject]
        if any(s != r for s, r in zip(subject, new_list)):
            return new_list
    elif isinstance(subject, tuple):
        new_tuple = tuple(_json_replace(s, old, new) for s in subject)
        if any(s != r for s, r in zip(subject, new_tuple)):
            return new_tuple
    elif isinstance(subject, dict):
        # A dict "matches" a dict-valued `old` if it is a superset of it.
        if isinstance(old, dict) and all(
            k in subject and subject[k] == v for k, v in old.items()
        ):
            return new
        new_dict = {k: _json_replace(v, old, new) for k, v in subject.items()}
        if any(subject[k] != new_dict[k] for k in subject):
            return new_dict
    return subject  # nothing changed so share original object (not a copy)


def is_subschema(sub_schema: JSON_TYPE, super_schema: JSON_TYPE) -> bool:
    """Is sub_schema a subschema of super_schema?

    Parameters
    ----------
    sub_schema: JSON schema
        Left-hand side of subschema check.

    super_schema: JSON schema
        Right-hand side of subschema check.

    Returns
    -------
    bool
        True if `sub_schema <: super_schema`, False otherwise.

    Raises
    ------
    ValueError
        An error occurred while checking the subschema relation.
    """
    # On the left-hand side, Lale's 'Any' is replaced by {"not": {}}, the
    # schema that no value satisfies. That schema is a subschema of every
    # schema, so the check succeeds wherever 'Any' appears in sub_schema.
    new_sub = _json_replace(sub_schema, {"laleType": "Any"}, {"not": {}})
    try:
        return jsonsubschema.isSubschema(new_sub, super_schema)
    except Exception as e:
        raise ValueError(
            f"unexpected internal error checking ({new_sub} <: {super_schema})"
        ) from e


class SubschemaError(Exception):
    """Raised when a subschema check (sub `<:` sup) failed.

    Attributes
    ----------
    sub, sup:
        The left-hand and right-hand schemas of the failed check.
    sub_name, sup_name:
        Labels used for the two schemas in the error message.
    """

    def __init__(self, sub, sup, sub_name="sub", sup_name="super"):
        # Forward the arguments to Exception so that `self.args` is populated.
        # Pickling re-creates an exception as `type(e)(*e.args)`, which fails
        # with a TypeError when `args` is empty.
        super().__init__(sub, sup, sub_name, sup_name)
        self.sub = sub
        self.sup = sup
        self.sub_name = sub_name
        self.sup_name = sup_name

    def __str__(self):
        summary = f"Expected {self.sub_name} to be a subschema of {self.sup_name}."
        # Imported lazily; only needed when the message is rendered.
        from lale.pretty_print import json_to_string

        sub = json_to_string(self.sub)
        sup = json_to_string(self.sup)
        details = f"\n{self.sub_name} = {sub}\n{self.sup_name} = {sup}"
        return summary + details


def _validate_subschema(
    sub: JSON_TYPE, sup: JSON_TYPE, sub_name="sub", sup_name="super"
):
    """Raise SubschemaError unless ``sub`` is a subschema of ``sup``.

    ``sub_name`` and ``sup_name`` only label the schemas in the error message.
    A ValueError raised by ``is_subschema`` propagates unchanged.
    """
    if not is_subschema(sub, sup):
        raise SubschemaError(sub, sup, sub_name, sup_name)


def validate_schema(lhs: Any, super_schema: JSON_TYPE):
    """Validate that lhs is an instance of or a subschema of super_schema.

    Parameters
    ----------
    lhs: value
        Left-hand side of instance or subschema check.

    super_schema: JSON schema
        Right-hand side of instance or subschema check.

    Raises
    ------
    jsonschema.ValidationError
        The lhs was an invalid value for super_schema.

    SubschemaError
        The lhs had a schema that was not a subschema of super_schema.
    """
    from lale.settings import disable_data_schema_validation

    if disable_data_schema_validation:
        return  # If schema validation is disabled, always return as valid
    sub_schema: Optional[JSON_TYPE]
    try:
        sub_schema = lale.datasets.data_schemas._to_schema(lhs)
    except ValueError:
        sub_schema = None
    if sub_schema is None:
        # No schema could be derived from lhs: check the value itself.
        # NOTE(review): validate_schema_directly also consults
        # disable_hyperparams_schema_validation, so that flag can skip this
        # data check as well.
        validate_schema_directly(lhs, super_schema)
    else:
        # A schema was derived from lhs: check schema inclusion instead.
        _validate_subschema(sub_schema, super_schema)


def join_schemas(*schemas: JSON_TYPE) -> JSON_TYPE:
    """Compute the lattice join (union type, disjunction) of the arguments.

    Parameters
    ----------
    *schemas: list of JSON schemas
        Schemas to be joined.

    Returns
    -------
    JSON schema
        The joined schema. With no arguments this is ``{"not": {}}``, the
        empty type, which is the identity element of the join.
    """

    def join_two_schemas(s_a: JSON_TYPE, s_b: JSON_TYPE) -> JSON_TYPE:
        if s_a is None:
            return s_b
        # 'description' is annotation-only in JSON Schema. Drop it via
        # dict_without so it does not affect the comparison or the result.
        s_a = lale.helpers.dict_without(s_a, "description")
        s_b = lale.helpers.dict_without(s_b, "description")
        # If one schema already contains the other, it is the join.
        if is_subschema(s_a, s_b):
            return s_b
        if is_subschema(s_b, s_a):
            return s_a
        # we should improve the typing of the jsonsubschema API so that this ignore can be removed
        return jsonsubschema.joinSchemas(s_a, s_b)  # type: ignore

    if not schemas:
        return {"not": {}}
    result = functools.reduce(join_two_schemas, schemas)
    return result


def get_hyperparam_names(op: "lale.operators.IndividualOp") -> List[str]:
    """Names of the arguments to the constructor of the impl.

    Parameters
    ----------
    op: lale.operators.IndividualOp
        Operator whose hyperparameters to get.

    Returns
    -------
    List[str]
        List of hyperparameter names.

        - For implementations defined in a ``lale`` module, these are the
          property names of the first ``allOf`` entry of the hyperparameter
          schema, or ``[]`` if there is no such entry.
        - Otherwise, they are the parameter names of
          ``impl_class.__init__``, which include ``self``.
    """
    if op.impl_class.__module__.startswith("lale"):
        hp_schema = op.hyperparam_schema()
        # Default to {} so that a schema without "allOf" yields no names
        # instead of leaking StopIteration out of next().
        first_conjunct = next(iter(hp_schema.get("allOf", [])), {})
        params = first_conjunct.get("properties", {})
        return list(params.keys())
    c: Any = op.impl_class
    sig = inspect.signature(c.__init__)
    return list(sig.parameters.keys())


def validate_method(op: "lale.operators.IndividualOp", schema_name: str):
    """Check whether the operator has the given method schema.

    For implementations defined in a ``lale`` module, ``schema_name`` must be
    a property of the operator's combined schemas. For other implementations,
    the method named by stripping an ``input_`` or ``output_`` prefix from
    ``schema_name`` must exist on the implementation. Other names are not
    checked.

    Parameters
    ----------
    op: lale.operators.IndividualOp
        Operator whose methods to check.

    schema_name: 'input_fit' or 'input_predict' or 'input_predict_proba' or 'input_transform' or 'output_predict' or 'output_predict_proba' or 'output_transform'
        Name of schema to check.

    Raises
    ------
    AssertionError
        The operator does not have the given schema.
    """
    # AssertionError is raised explicitly (not via `assert`) so the checks
    # are not stripped under `python -O`.
    if op._impl.__module__.startswith("lale"):
        if schema_name not in op._schemas["properties"]:
            raise AssertionError(f"operator has no schema {schema_name!r}")
        return
    method_name = ""
    if schema_name.startswith("input_"):
        method_name = schema_name[len("input_") :]
    elif schema_name.startswith("output_"):
        method_name = schema_name[len("output_") :]
    if method_name and not hasattr(op._impl, method_name):
        raise AssertionError(
            f"operator implementation has no method {method_name!r}"
        )


def _get_args_schema(fun):
    """Build a JSON schema for the keyword arguments accepted by ``fun``.

    Each named parameter other than ``self`` becomes a property:
    - parameters without a default get ``{"laleType": "Any"}`` and are listed
      under ``"required"``;
    - parameters with a default get ``{"default": <value>}``.

    ``*args``/``**kwargs`` add no property; their presence suppresses
    ``"additionalProperties": False``.
    """
    ignored_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    sig = inspect.signature(fun)
    result: Dict[str, Any] = {"type": "object", "properties": {}}
    required = []
    additional_properties = False
    for name, param in sig.parameters.items():
        if name == "self":
            continue
        if param.kind in ignored_kinds:
            additional_properties = True
        # Identity test: `==` would compare element-wise (and then fail in
        # the truth test) for array-like default values.
        elif param.default is inspect.Parameter.empty:
            result["properties"][name] = {"laleType": "Any"}
            required.append(name)
        else:
            result["properties"][name] = {"default": param.default}
    if not additional_properties:
        result["additionalProperties"] = False
    if required:
        result["required"] = required
    return result


def get_hyperparam_defaults(impl):
    """Map each ``__init__`` parameter of ``impl`` that has a default to that default.

    Parameters without a default (including ``self``) are omitted. Returns an
    empty dict when ``impl`` has no ``__init__`` attribute.
    """
    if not hasattr(impl, "__init__"):
        return {}
    sig = inspect.signature(impl.__init__)
    # Identity test: `!=` would compare element-wise (and then fail in the
    # truth test) for array-like default values.
    return {
        name: param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect.Parameter.empty
    }


def get_default_schema(impl):
    """Creates combined schemas for a bare operator implementation class.

    Used when there were no explicit combined schemas provided when
    the operator was created. The default schema provides defaults by
    inspecting the signature of the ``__init__`` method, and uses
    'Any' types for the inputs and outputs of other methods.

    Returns
    -------
    JSON Schema
        Combined schema with properties for hyperparams and
        all applicable method inputs and outputs.
        The "$schema" field is set to ``_JSON_META_SCHEMA_URL``, matching
        what ``validate_is_schema`` expects.
    """
    if hasattr(impl, "__init__"):
        hyperparams_schema = _get_args_schema(impl.__init__)
    else:
        hyperparams_schema = {"type": "object", "properties": {}}
    hyperparams_schema["relevantToOptimizer"] = []
    method_schemas: Dict[str, JSON_TYPE] = {
        "hyperparams": {"allOf": [hyperparams_schema]}
    }
    if hasattr(impl, "fit"):
        method_schemas["input_fit"] = _get_args_schema(impl.fit)
    for method_name in ["predict", "predict_proba", "transform"]:
        if hasattr(impl, method_name):
            method_args_schema = _get_args_schema(getattr(impl, method_name))
            method_schemas["input_" + method_name] = method_args_schema
            method_schemas["output_" + method_name] = {"laleType": "Any"}
    # Tag the operator by the methods it provides.
    tags = {
        "pre": [],
        "op": (["transformer"] if hasattr(impl, "transform") else [])
        + (["estimator"] if hasattr(impl, "predict") else []),
        "post": [],
    }
    result = {
        "$schema": _JSON_META_SCHEMA_URL,
        "description": f"Schema for {type(impl)} auto-generated by lale.type_checking.get_default_schema().",
        "type": "object",
        "tags": tags,
        "properties": method_schemas,
    }
    return result


# Lale's data-dependent constraint keywords in hyperparameter schemas, each
# mapped to the standard JSON Schema keyword that replace_data_constraints
# rewrites it to.
_data_info_keys = {
    "laleMaximum": "maximum",
    "laleNot": "not",
}


def has_data_constraints(hyperparam_schema: JSON_TYPE) -> bool:
    """Return True if the schema uses any data-dependent constraint keyword.

    Searches nested dicts, lists and tuples for any key in
    ``_data_info_keys`` (e.g. ``laleMaximum``, ``laleNot``).
    """

    def recursive_check(subject: Any) -> bool:
        if isinstance(subject, (list, tuple)):
            return any(recursive_check(v) for v in subject)
        if isinstance(subject, dict):
            return any(
                k in _data_info_keys or recursive_check(v)
                for k, v in subject.items()
            )
        return False

    return recursive_check(hyperparam_schema)


def replace_data_constraints(
    hyperparam_schema: JSON_TYPE, data_schema: JSON_TYPE
) -> JSON_TYPE:
    """Resolve data-dependent constraints in a hyperparameter schema.

    The function walks nested dicts, lists and tuples. For every dict entry
    whose key is in ``_data_info_keys``, its (string) value ``v`` is looked up
    as ``"properties/" + v`` in ``data_schema`` with
    ``lale.helpers.json_lookup``.

    - If the lookup returns a non-None value, the entry is replaced by the
      mapped standard keyword with that value.
    - Otherwise the entry is kept unchanged.

    Sub-structures without replacements are shared with the input rather
    than copied. If nothing changed, ``hyperparam_schema`` itself is returned.
    """

    @overload
    def recursive_replace(subject: JSON_TYPE) -> JSON_TYPE:
        ...

    @overload
    def recursive_replace(subject: List) -> List:
        ...

    @overload
    def recursive_replace(subject: Tuple) -> Tuple:
        ...

    @overload
    def recursive_replace(subject: Any) -> Any:
        ...

    def recursive_replace(subject):
        any_changes = False
        if isinstance(subject, (list, tuple)):
            result = []
            for v in subject:
                new_v = recursive_replace(v)
                result.append(new_v)
                any_changes = any_changes or v is not new_v
            if isinstance(subject, tuple):
                result = tuple(result)
        elif isinstance(subject, dict):
            result = {}
            for k, v in subject.items():
                if k in _data_info_keys:
                    new_v = lale.helpers.json_lookup("properties/" + v, data_schema)
                    if new_v is None:
                        # json_lookup returned None (presumably path not
                        # found): keep the original entry.
                        new_k = k
                        new_v = v
                    else:
                        new_k = _data_info_keys[k]
                else:
                    new_v = recursive_replace(v)
                    new_k = k
                result[new_k] = new_v
                any_changes = any_changes or k != new_k or v is not new_v
        else:
            return subject
        # Share the original object when nothing inside it changed.
        return result if any_changes else subject

    result = recursive_replace(hyperparam_schema)
    return result