Source code for lale.expressions

# Copyright 2020-2022 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast  # see also https://greentreesnakes.readthedocs.io/
import pprint
import typing
from copy import deepcopy
from io import StringIO
from typing import Any, Dict, Literal, Optional, Union, overload

import astunparse

# AST node classes accepted as literals.
# ``ast.Num`` and ``ast.Str`` are intentionally not listed: since Python 3.8
# they are deprecated aliases whose instances are ``ast.Constant`` objects, so
# ``isinstance(node, AstLits)`` already covers them, and merely referencing the
# aliases warns on 3.12+ and fails on interpreters that removed them.
AstLits = (ast.List, ast.Tuple, ast.Set, ast.Dict, ast.Constant)
AstLit = Union[ast.List, ast.Tuple, ast.Set, ast.Dict, ast.Constant]

# Runtime tuple (for isinstance checks) of every node class an expression
# may be built from: the literals above plus the compound forms below.
AstExprs = (
    *AstLits,
    ast.Name,
    ast.Expr,
    ast.UnaryOp,
    ast.BinOp,
    ast.BoolOp,
    ast.Compare,
    ast.Call,
    ast.Attribute,
    ast.Subscript,
)
# Static-typing counterpart of ``AstExprs``.
AstExpr = Union[
    AstLit,
    ast.Name,
    ast.Expr,
    ast.UnaryOp,
    ast.BinOp,
    ast.BoolOp,
    ast.Compare,
    ast.Call,
    ast.Attribute,
    ast.Subscript,
]


# !! WORKAROUND !!
# There is a bug with astunparse and Python 3.8.
# https://github.com/simonpercivall/astunparse/issues/43
# Until it is fixed (which may be never), here is a workaround,
# based on the workaround found in https://github.com/juanlao7/codeclose
class FixUnparser(astunparse.Unparser):
    """Unparser that tolerates ``Constant`` nodes lacking a ``kind`` attribute."""

    def _Constant(self, t):
        # Give the node a ``kind`` of None when it has none, then let the
        # stock implementation do the actual unparsing.
        if not hasattr(t, "kind"):
            t.kind = None
        super()._Constant(t)
# !! WORKAROUND !!
# This function should be called instead of astunparse.unparse
def fixedUnparse(tree):
    """Return the source text that ``FixUnparser`` writes for ``tree``."""
    with StringIO() as buffer:
        # Constructing the unparser is what writes the text into ``buffer``.
        FixUnparser(tree, file=buffer)
        return buffer.getvalue()
def _make_compare(op, left: "Expr", other, track_identity: bool = False):
    """Build the symbolic comparison ``left <op> other``.

    Returns the plain value ``False`` when ``other`` is ``None``. Otherwise
    returns an ``Expr`` wrapping an ``ast.Compare`` whose right-hand side is
    ``other``'s AST (for an ``Expr``) or an ``ast.Constant`` (for anything else).

    With ``track_identity`` (used by ``==`` and ``!=``) the result carries a
    truth value: ``left is other`` for an ``Expr`` operand, ``False`` for a
    constant. Without it the result has no truth value, so ``bool()`` raises.
    """
    if other is None:
        return False
    if isinstance(other, Expr):
        right = other._expr
        istrue = (left is other) if track_identity else None
    else:
        right = ast.Constant(value=other)
        istrue = False if track_identity else None
    comp = ast.Compare(left=left._expr, ops=[op], comparators=[right])
    return Expr(comp, istrue=istrue)


class Expr:
    """Symbolic expression backed by a Python AST node.

    Operators applied to an ``Expr`` do not compute anything; they build a
    larger AST and wrap it in a new ``Expr``. Comparing or combining with
    ``None`` yields the plain value ``False`` instead of an ``Expr``.
    """

    _expr: AstExpr

    @property
    def expr(self):
        """The wrapped AST node."""
        return self._expr

    def __init__(self, expr: AstExpr, istrue=None):
        # _istrue variable is used to check the boolean nature of
        # '==' and '!=' operator's results.
        self._expr = expr
        self._istrue = istrue

    def __bool__(self) -> bool:
        """Return the recorded truth value, or raise ``TypeError`` if there is none."""
        if self._istrue is not None:
            return self._istrue
        raise TypeError(
            f"Cannot convert expression e1=`{str(self)}` to bool. "
            "Instead of `e1 and e2`, try writing `[e1, e2]`."
        )

    def __deepcopy__(self, memo):
        """Deep-copy every instance attribute into a fresh instance."""
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    # the type: ignore statements are needed because the type of object.__eq__
    # in typeshed is overly restrictive (to catch common errors)
    @overload  # type: ignore
    def __eq__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload  # type: ignore
    def __eq__(self, other: None) -> Literal[False]: ...

    def __eq__(self, other: Union["Expr", str, int, float, bool, None]):
        return _make_compare(ast.Eq(), self, other, track_identity=True)

    @overload
    def __ge__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __ge__(self, other: None) -> Literal[False]: ...

    def __ge__(self, other):
        return _make_compare(ast.GtE(), self, other)

    def __getattr__(self, name: str) -> "Expr":
        """Build the attribute access ``<self>.<name>``.

        Python only calls this when normal lookup fails, so any name that is
        not a real attribute of the instance or class becomes an AST node.
        """
        # Reaching here for the instance's own slots means the object is not
        # initialised yet (e.g. while copy/pickle probe a freshly created
        # instance). Reading ``self._expr`` below would then re-enter this
        # method forever, so report the attribute as missing instead.
        if name in ("_expr", "_istrue"):
            raise AttributeError(name)
        attr = ast.Attribute(value=self._expr, attr=name)
        return Expr(attr)

    def __getitem__(self, key: Union[int, str, slice]) -> "Expr":
        """Build the subscript ``<self>[key]``.

        Raises ``TypeError`` when ``key`` is not an int, str, or slice.
        """

        def bound_ast(bound):
            # ast.Slice fields must be AST nodes (or None when omitted);
            # raw Python values cannot be unparsed.
            if bound is None:
                return None
            if isinstance(bound, Expr):
                return bound._expr
            if isinstance(bound, ast.AST):
                return bound
            return ast.Constant(value=bound)

        key_ast: Union[ast.Index, ast.Slice]
        if isinstance(key, (int, str)):
            # ast.Constant replaces the deprecated ast.Num / ast.Str aliases,
            # which construct the very same node.
            key_ast = ast.Index(ast.Constant(value=key))
        elif isinstance(key, slice):
            key_ast = ast.Slice(
                bound_ast(key.start), bound_ast(key.stop), bound_ast(key.step)
            )
        else:
            raise TypeError(f"expected int, str, or slice, got {type(key)}")
        subscript = ast.Subscript(value=self._expr, slice=key_ast)
        return Expr(subscript)

    @overload
    def __gt__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __gt__(self, other: None) -> Literal[False]: ...

    def __gt__(self, other):
        return _make_compare(ast.Gt(), self, other)

    @overload
    def __le__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __le__(self, other: None) -> Literal[False]: ...

    def __le__(self, other):
        return _make_compare(ast.LtE(), self, other)

    @overload
    def __lt__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __lt__(self, other: None) -> Literal[False]: ...

    def __lt__(self, other):
        return _make_compare(ast.Lt(), self, other)

    @overload  # type: ignore
    def __ne__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload  # type: ignore
    def __ne__(self, other: None) -> Literal[False]: ...

    def __ne__(self, other):
        # NOTE(review): the truth value is computed exactly as for ``==``
        # (``self is other``), so ``bool(e != e)`` is True. Kept as-is;
        # confirm whether this is intended.
        return _make_compare(ast.NotEq(), self, other, track_identity=True)

    def __str__(self) -> str:
        """Unparse the wrapped AST, dropping one pair of enclosing parentheses."""
        result = fixedUnparse(self._expr).strip()
        if isinstance(self._expr, (ast.UnaryOp, ast.BinOp, ast.Compare, ast.BoolOp)):
            if result.startswith("(") and result.endswith(")"):
                result = result[1:-1]
        return result

    @overload
    def __add__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __add__(self, other: None) -> Literal[False]: ...

    def __add__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.Add(), self._expr, other)

    @overload
    def __sub__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __sub__(self, other: None) -> Literal[False]: ...

    def __sub__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.Sub(), self._expr, other)

    @overload
    def __mul__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __mul__(self, other: None) -> Literal[False]: ...

    def __mul__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.Mult(), self._expr, other)

    @overload
    def __truediv__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __truediv__(self, other: None) -> Literal[False]: ...

    def __truediv__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.Div(), self._expr, other)

    @overload
    def __floordiv__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __floordiv__(self, other: None) -> Literal[False]: ...

    def __floordiv__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.FloorDiv(), self._expr, other)

    @overload
    def __mod__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __mod__(self, other: None) -> Literal[False]: ...

    def __mod__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.Mod(), self._expr, other)

    @overload
    def __pow__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __pow__(self, other: None) -> Literal[False]: ...

    def __pow__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.Pow(), self._expr, other)

    @overload
    def __and__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __and__(self, other: None) -> Literal[False]: ...

    def __and__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.BitAnd(), self._expr, other)

    @overload
    def __or__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __or__(self, other: None) -> Literal[False]: ...

    def __or__(self, other) -> Union["Expr", Literal[False]]:
        return _make_binop(ast.BitOr(), self._expr, other)
@overload
def _make_binop(op, left: Any, other: Union[Expr, str, int, float, bool]) -> Expr: ...


@overload
def _make_binop(op, left: Any, other: None) -> Literal[False]: ...


def _make_binop(
    op, left: Any, other: Union[Expr, str, int, float, bool, None]
) -> Union["Expr", Literal[False]]:
    """Build the symbolic binary operation ``left <op> other``.

    ``left`` is used as an AST node as-is. ``other`` contributes its wrapped
    AST when it is an ``Expr`` and an ``ast.Constant`` otherwise. Returns the
    plain value ``False`` when ``other`` is ``None``.
    """
    if other is None:
        return False
    right = other.expr if isinstance(other, Expr) else ast.Constant(value=other)
    return Expr(ast.BinOp(left=left, op=op, right=right))


def _make_ast_expr(arg: Union[None, Expr, int, float, str, AstExpr]) -> AstExpr:
    """Convert ``arg`` to an AST node.

    ``None``, numbers and strings become ``ast.Constant`` nodes, an ``Expr``
    yields its wrapped node, and anything else must already be one of the
    ``AstExprs`` node classes (checked with an assertion).
    """
    if arg is None:
        return ast.Constant(value=None)
    if isinstance(arg, Expr):
        return arg.expr
    if isinstance(arg, (int, float, str)):
        # ast.Constant replaces the deprecated ast.Num / ast.Str aliases,
        # which construct the very same node.
        return ast.Constant(value=arg)
    assert isinstance(arg, AstExprs), type(arg)
    return arg


def _make_call_expr(
    name: str, *args: Union[Expr, AstExpr, int, float, bool, str, None]
) -> Expr:
    """Build the symbolic call ``name(*args)`` with no keyword arguments."""
    call_ast = ast.Call(
        func=ast.Name(id=name),
        args=[_make_ast_expr(arg) for arg in args],
        keywords=[],
    )
    return Expr(call_ast)
def string_indexer(subject: Expr) -> Expr:
    """Build the symbolic call ``string_indexer(subject)``."""
    return _make_call_expr("string_indexer", subject)
def collect_set(group: Expr) -> Expr:
    """Build the symbolic call ``collect_set(group)``."""
    return _make_call_expr("collect_set", group)
def count(group: Expr) -> Expr:
    """Build the symbolic call ``count(group)``."""
    return _make_call_expr("count", group)
def day_of_month(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Build the symbolic call ``day_of_month(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("day_of_month", *call_args)
def day_of_week(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Build the symbolic call ``day_of_week(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("day_of_week", *call_args)
def day_of_year(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Build the symbolic call ``day_of_year(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("day_of_year", *call_args)
def distinct_count(group: Expr) -> Expr:
    """Build the symbolic call ``distinct_count(group)``."""
    return _make_call_expr("distinct_count", group)
def hour(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Build the symbolic call ``hour(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("hour", *call_args)
def item(group: Expr, value: Union[int, str]) -> Expr:
    """Build the symbolic call ``item(group, value)``."""
    return _make_call_expr("item", group, value)
def max(group: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Build the symbolic call ``max(group)``."""
    return _make_call_expr("max", group)
def max_gap_to_cutoff(group: Expr, cutoff: Expr) -> Expr:
    """Build the symbolic call ``max_gap_to_cutoff(group, cutoff)``."""
    return _make_call_expr("max_gap_to_cutoff", group, cutoff)
def mean(group: Expr) -> Expr:
    """Build the symbolic call ``mean(group)``."""
    return _make_call_expr("mean", group)
def min(group: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Build the symbolic call ``min(group)``."""
    return _make_call_expr("min", group)
def minute(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Build the symbolic call ``minute(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("minute", *call_args)
def month(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Build the symbolic call ``month(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("month", *call_args)
def normalized_count(group: Expr) -> Expr:
    """Build the symbolic call ``normalized_count(group)``."""
    return _make_call_expr("normalized_count", group)
def normalized_sum(group: Expr) -> Expr:
    """Build the symbolic call ``normalized_sum(group)``."""
    return _make_call_expr("normalized_sum", group)
def recent(series: Expr, age: int) -> Expr:
    """Build the symbolic call ``recent(series, age)``."""
    return _make_call_expr("recent", series, age)
def recent_gap_to_cutoff(series: Expr, cutoff: Expr, age: int) -> Expr:
    """Build the symbolic call ``recent_gap_to_cutoff(series, cutoff, age)``."""
    return _make_call_expr("recent_gap_to_cutoff", series, cutoff, age)
def replace(
    subject: Expr,
    old2new: Dict[Any, Any],
    handle_unknown: str = "identity",
    unknown_value=None,
) -> Expr:
    """Build the symbolic call ``replace(subject, old2new, handle_unknown, unknown_value)``.

    ``old2new`` is turned into an AST by pretty-printing it and parsing the
    resulting text, so its printed form must be valid Python source.
    ``handle_unknown`` is asserted to be ``"identity"`` or
    ``"use_encoded_value"``.
    """
    mapping_source = pprint.pformat(old2new)
    # The parsed module's first statement is the expression statement
    # holding the mapping literal.
    first_stmt = ast.parse(mapping_source).body[0]
    mapping_ast = typing.cast(ast.Expr, first_stmt)
    assert handle_unknown in ["identity", "use_encoded_value"]
    return _make_call_expr(
        "replace", subject, mapping_ast, handle_unknown, unknown_value
    )
def ite(
    cond: Expr,
    v1: Union[Expr, int, float, bool, str],
    v2: Union[Expr, int, float, bool, str],
) -> Expr:
    """Build the symbolic call ``ite(cond, v1, v2)``.

    A ``v1`` or ``v2`` that is not already an ``Expr`` is first wrapped as a
    constant expression.
    """
    branches = [
        value if isinstance(value, Expr) else Expr(ast.Constant(value=value))
        for value in (v1, v2)
    ]
    return _make_call_expr("ite", cond, *branches)
def identity(subject: Expr) -> Expr:
    """Build the symbolic call ``identity(subject)``."""
    return _make_call_expr("identity", subject)
def astype(dtype, subject: Expr) -> Expr:
    """Build the symbolic call ``astype(dtype, subject)``; ``dtype`` comes first."""
    return _make_call_expr("astype", dtype, subject)
def hash(hash_method: str, subject: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Build the symbolic call ``hash(hash_method, subject)``."""
    return _make_call_expr("hash", hash_method, subject)
def hash_mod(hash_method: str, subject: Expr, n: Expr) -> Expr:
    """Build the symbolic call ``hash_mod(hash_method, subject, n)``."""
    return _make_call_expr("hash_mod", hash_method, subject, n)
def sum(group: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Build the symbolic call ``sum(group)``."""
    return _make_call_expr("sum", group)
def trend(series: Expr) -> Expr:
    """Build the symbolic call ``trend(series)``."""
    return _make_call_expr("trend", series)
def variance(group: Expr) -> Expr:
    """Build the symbolic call ``variance(group)``."""
    return _make_call_expr("variance", group)
def window_max(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_max(series, size)``."""
    return _make_call_expr("window_max", series, size)
def window_max_trend(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_max_trend(series, size)``."""
    return _make_call_expr("window_max_trend", series, size)
def window_mean(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_mean(series, size)``."""
    return _make_call_expr("window_mean", series, size)
def window_mean_trend(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_mean_trend(series, size)``."""
    return _make_call_expr("window_mean_trend", series, size)
def window_min(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_min(series, size)``."""
    return _make_call_expr("window_min", series, size)
def window_min_trend(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_min_trend(series, size)``."""
    return _make_call_expr("window_min_trend", series, size)
def window_variance(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_variance(series, size)``."""
    return _make_call_expr("window_variance", series, size)
def window_variance_trend(series: Expr, size: int) -> Expr:
    """Build the symbolic call ``window_variance_trend(series, size)``."""
    return _make_call_expr("window_variance_trend", series, size)
def first(group: Expr) -> Expr:
    """Build the symbolic call ``first(group)``."""
    return _make_call_expr("first", group)
def isnan(column: Expr) -> Expr:
    """Build the symbolic call ``isnan(column)``."""
    return _make_call_expr("isnan", column)
def isnotnan(column: Expr) -> Expr:
    """Build the symbolic call ``isnotnan(column)``."""
    return _make_call_expr("isnotnan", column)
def isnull(column: Expr) -> Expr:
    """Build the symbolic call ``isnull(column)``."""
    return _make_call_expr("isnull", column)
def isnotnull(column: Expr) -> Expr:
    """Build the symbolic call ``isnotnull(column)``."""
    return _make_call_expr("isnotnull", column)
def asc(column: Union[Expr, str]) -> Expr:
    """Build the symbolic call ``asc(column)``; ``column`` may be an ``Expr`` or a string."""
    return _make_call_expr("asc", column)
def desc(column: Union[Expr, str]) -> Expr:
    """Build the symbolic call ``desc(column)``; ``column`` may be an ``Expr`` or a string."""
    return _make_call_expr("desc", column)
def median(group: Expr) -> Expr:
    """Build the symbolic call ``median(group)``."""
    return _make_call_expr("median", group)
def mode(group: Expr) -> Expr:
    """Build the symbolic call ``mode(group)``."""
    return _make_call_expr("mode", group)
# Root placeholder expression: the bare name ``it``.
it = Expr(ast.Name(id="it"))


def _it_column(expr):
    """Return the column name that ``expr`` selects from ``it``.

    Accepts an attribute access ``it.<name>`` (returns ``<name>``) or a
    subscript whose index is a constant (returns the constant's value).
    Anything else raises ``ValueError``.
    """
    if isinstance(expr, ast.Attribute):
        if _is_ast_name_it(expr.value):
            return expr.attr
    elif isinstance(expr, ast.Subscript):
        index = expr.slice
        if isinstance(index, ast.Constant):
            # NOTE(review): this path does not check that the subscripted
            # value is ``it``, unlike the ast.Index path below. Behavior is
            # kept as-is; confirm whether that is intended.
            return getattr(index, "value", None)
        if _is_ast_name_it(expr.value) and isinstance(index, ast.Index):
            inner = getattr(index, "value", None)
            # String literals are ast.Constant nodes too, so no separate
            # check against the deprecated ast.Str alias is needed.
            if isinstance(inner, ast.Constant):
                return inner.value
    raise ValueError(
        f"Illegal {fixedUnparse(expr)}. Only the access to `it` is supported"
    )


def _is_ast_name_it(expr):
    """Tell whether ``expr`` is the bare name ``it``."""
    return isinstance(expr, ast.Name) and expr.id == "it"