# Copyright 2020-2022 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast # see also https://greentreesnakes.readthedocs.io/
import pprint
import typing
from copy import deepcopy
from io import StringIO
from typing import Any, Dict, Literal, Optional, Union, overload
import astunparse
# Node classes accepted as literals.  Since Python 3.8 every number/string
# literal is an ``ast.Constant``; the old ``ast.Num``/``ast.Str`` aliases are
# deprecated (and removed in Python 3.14), so they are intentionally not
# referenced here -- ``isinstance(node, AstLits)`` still matches those literals
# through ``ast.Constant``.
AstLits = (ast.List, ast.Tuple, ast.Set, ast.Dict, ast.Constant)
AstLit = Union[ast.List, ast.Tuple, ast.Set, ast.Dict, ast.Constant]

# Runtime tuple (for ``isinstance`` checks) of every node class that may be
# wrapped in an ``Expr``.
AstExprs = (
    *AstLits,
    ast.Name,
    ast.Expr,
    ast.UnaryOp,
    ast.BinOp,
    ast.BoolOp,
    ast.Compare,
    ast.Call,
    ast.Attribute,
    ast.Subscript,
)

# Static-typing counterpart of ``AstExprs``.
AstExpr = Union[
    AstLit,
    ast.Name,
    ast.Expr,
    ast.UnaryOp,
    ast.BinOp,
    ast.BoolOp,
    ast.Compare,
    ast.Call,
    ast.Attribute,
    ast.Subscript,
]
# !! WORKAROUND !!
# There is a bug with astunparse and Python 3.8.
# https://github.com/simonpercivall/astunparse/issues/43
# Until it is fixed (which may be never), here is a workaround,
# based on the workaround found in https://github.com/juanlao7/codeclose
class FixUnparser(astunparse.Unparser):
    """``astunparse.Unparser`` variant that tolerates constants without ``kind``."""

    def _Constant(self, t):
        # Give the node a ``kind`` attribute (defaulting to None) when it has
        # none, then defer to the stock implementation.
        if not hasattr(t, "kind"):
            t.kind = None
        super()._Constant(t)
# !! WORKAROUND !!
# This method should be called instead of astunparse.unparse
def fixedUnparse(tree):
    """Unparse ``tree`` with ``FixUnparser`` and return the generated text."""
    with StringIO() as sink:
        # Constructing the unparser writes the source text into ``sink``.
        FixUnparser(tree, file=sink)
        return sink.getvalue()
class Expr:
    """Symbolic expression wrapping a Python ``ast`` node.

    Comparison and arithmetic operators, attribute access and subscripting do
    not compute a value; each returns a new ``Expr`` whose AST combines the
    operands.  Plain Python operands are embedded as ``ast.Constant`` nodes.
    Binary operators and comparisons return ``False`` when the right-hand
    operand is ``None``.
    """

    _expr: "AstExpr"

    @property
    def expr(self):
        """The wrapped ``ast`` node."""
        return self._expr

    def __init__(self, expr: "AstExpr", istrue=None):
        # _istrue variable is used to check the boolean nature of
        # '==' and '!=' operator's results.
        self._expr = expr
        self._istrue = istrue

    def __bool__(self) -> bool:
        """Return the recorded truth value, or raise ``TypeError`` if none was recorded."""
        if self._istrue is not None:
            return self._istrue
        raise TypeError(
            f"Cannot convert expression e1=`{str(self)}` to bool. "
            "Instead of `e1 and e2`, try writing `[e1, e2]`."
        )

    def __deepcopy__(self, memo):
        # Explicit implementation: build a bare instance and deep-copy every
        # instance attribute into it.
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    # the type: ignore statements are needed because the type of object.__eq__
    # in typeshed is overly restrictive (to catch common errors)
    @overload  # type: ignore
    def __eq__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload  # type: ignore
    def __eq__(self, other: None) -> Literal[False]: ...

    def __eq__(self, other: Union["Expr", str, int, float, bool, None]):
        """Build ``self == other``; truthy only when ``other`` is this very object."""
        if isinstance(other, Expr):
            comp = ast.Compare(
                left=self._expr, ops=[ast.Eq()], comparators=[other._expr]
            )
            return Expr(comp, istrue=self is other)
        elif other is not None:
            comp = ast.Compare(
                left=self._expr, ops=[ast.Eq()], comparators=[ast.Constant(value=other)]
            )
            return Expr(comp, istrue=False)
        else:
            return False

    @overload
    def __ge__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __ge__(self, other: None) -> Literal[False]: ...

    def __ge__(self, other):
        """Build ``self >= other``."""
        if isinstance(other, Expr):
            comp = ast.Compare(
                left=self._expr, ops=[ast.GtE()], comparators=[other._expr]
            )
            return Expr(comp)
        elif other is not None:
            comp = ast.Compare(
                left=self._expr,
                ops=[ast.GtE()],
                comparators=[ast.Constant(value=other)],
            )
            return Expr(comp)
        else:
            return False

    def __getattr__(self, name: str) -> "Expr":
        """Build the attribute access ``self.<name>`` as a new ``Expr``."""
        # __getattr__ only runs when normal lookup fails.  If the instance's
        # own slots are missing (an uninitialised object, e.g. while being
        # copied or unpickled), reading self._expr below would re-enter this
        # method forever; report the attribute as absent instead.
        if name in ("_expr", "_istrue"):
            raise AttributeError(name)
        attr = ast.Attribute(value=self._expr, attr=name)
        return Expr(attr)

    def __getitem__(self, key: Union[int, str, slice]) -> "Expr":
        """Build the subscript ``self[key]`` as a new ``Expr``.

        Raises ``TypeError`` when ``key`` is not an int, str, or slice.
        """
        key_ast: Union[ast.Index, ast.Slice]
        if isinstance(key, (int, str)):
            # ast.Constant replaces the deprecated ast.Num / ast.Str aliases.
            key_ast = ast.Index(ast.Constant(value=key))
        elif isinstance(key, slice):

            def bound(value):
                # Slice bounds must themselves be AST nodes (or None when
                # omitted); raw Python values cannot be unparsed.
                if value is None:
                    return None
                if isinstance(value, Expr):
                    return value._expr
                return ast.Constant(value=value)

            key_ast = ast.Slice(bound(key.start), bound(key.stop), bound(key.step))
        else:
            raise TypeError(f"expected int, str, or slice, got {type(key)}")
        subscript = ast.Subscript(value=self._expr, slice=key_ast)
        return Expr(subscript)

    @overload
    def __gt__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __gt__(self, other: None) -> Literal[False]: ...

    def __gt__(self, other):
        """Build ``self > other``."""
        if isinstance(other, Expr):
            comp = ast.Compare(
                left=self._expr, ops=[ast.Gt()], comparators=[other._expr]
            )
            return Expr(comp)
        elif other is not None:
            comp = ast.Compare(
                left=self._expr, ops=[ast.Gt()], comparators=[ast.Constant(value=other)]
            )
            return Expr(comp)
        else:
            return False

    @overload
    def __le__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __le__(self, other: None) -> Literal[False]: ...

    def __le__(self, other):
        """Build ``self <= other``."""
        if isinstance(other, Expr):
            comp = ast.Compare(
                left=self._expr, ops=[ast.LtE()], comparators=[other._expr]
            )
            return Expr(comp)
        elif other is not None:
            comp = ast.Compare(
                left=self._expr,
                ops=[ast.LtE()],
                comparators=[ast.Constant(value=other)],
            )
            return Expr(comp)
        else:
            return False

    @overload
    def __lt__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __lt__(self, other: None) -> Literal[False]: ...

    def __lt__(self, other):
        """Build ``self < other``."""
        if isinstance(other, Expr):
            comp = ast.Compare(
                left=self._expr, ops=[ast.Lt()], comparators=[other._expr]
            )
            return Expr(comp)
        elif other is not None:
            comp = ast.Compare(
                left=self._expr, ops=[ast.Lt()], comparators=[ast.Constant(value=other)]
            )
            return Expr(comp)
        else:
            return False

    @overload  # type: ignore
    def __ne__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload  # type: ignore
    def __ne__(self, other: None) -> Literal[False]: ...

    def __ne__(self, other):
        """Build ``self != other``."""
        # NOTE(review): the recorded truth value mirrors __eq__ (True when
        # ``other`` is this very object), which looks inverted for ``!=``.
        # Left unchanged because callers may rely on it -- confirm.
        if isinstance(other, Expr):
            comp = ast.Compare(
                left=self._expr, ops=[ast.NotEq()], comparators=[other._expr]
            )
            return Expr(comp, istrue=self is other)
        elif other is not None:
            comp = ast.Compare(
                left=self._expr,
                ops=[ast.NotEq()],
                comparators=[ast.Constant(value=other)],
            )
            return Expr(comp, istrue=False)
        else:
            return False

    def __str__(self) -> str:
        """Return the unparsed source text of the wrapped node."""
        result = fixedUnparse(self._expr).strip()
        # Drop one pair of enclosing parentheses from operator expressions.
        if isinstance(self._expr, (ast.UnaryOp, ast.BinOp, ast.Compare, ast.BoolOp)):
            if result.startswith("(") and result.endswith(")"):
                result = result[1:-1]
        return result

    @overload
    def __add__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __add__(self, other: None) -> Literal[False]: ...

    def __add__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self + other``."""
        return _make_binop(ast.Add(), self._expr, other)

    @overload
    def __sub__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __sub__(self, other: None) -> Literal[False]: ...

    def __sub__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self - other``."""
        return _make_binop(ast.Sub(), self._expr, other)

    @overload
    def __mul__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __mul__(self, other: None) -> Literal[False]: ...

    def __mul__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self * other``."""
        return _make_binop(ast.Mult(), self._expr, other)

    @overload
    def __truediv__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __truediv__(self, other: None) -> Literal[False]: ...

    def __truediv__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self / other``."""
        return _make_binop(ast.Div(), self._expr, other)

    @overload
    def __floordiv__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __floordiv__(self, other: None) -> Literal[False]: ...

    def __floordiv__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self // other``."""
        return _make_binop(ast.FloorDiv(), self._expr, other)

    @overload
    def __mod__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __mod__(self, other: None) -> Literal[False]: ...

    def __mod__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self % other``."""
        return _make_binop(ast.Mod(), self._expr, other)

    @overload
    def __pow__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __pow__(self, other: None) -> Literal[False]: ...

    def __pow__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self ** other``."""
        return _make_binop(ast.Pow(), self._expr, other)

    @overload
    def __and__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __and__(self, other: None) -> Literal[False]: ...

    def __and__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self & other``."""
        return _make_binop(ast.BitAnd(), self._expr, other)

    @overload
    def __or__(self, other: Union["Expr", str, int, float, bool]) -> "Expr": ...

    @overload
    def __or__(self, other: None) -> Literal[False]: ...

    def __or__(self, other) -> Union["Expr", Literal[False]]:
        """Build ``self | other``."""
        return _make_binop(ast.BitOr(), self._expr, other)
@overload
def _make_binop(op, left: Any, other: Union[Expr, str, int, float, bool]) -> Expr: ...


@overload
def _make_binop(op, left: Any, other: None) -> Literal[False]: ...


def _make_binop(
    op, left: Any, other: Union[Expr, str, int, float, bool, None]
) -> Union["Expr", Literal[False]]:
    """Combine ``left`` and ``other`` with the binary operator node ``op``.

    ``other`` contributes its wrapped AST when it is an ``Expr`` and is
    embedded as an ``ast.Constant`` otherwise.  Returns ``False`` when
    ``other`` is ``None``.
    """
    if other is None:
        return False
    if isinstance(other, Expr):
        rhs = other.expr
    else:
        rhs = ast.Constant(value=other)
    return Expr(ast.BinOp(left=left, op=op, right=rhs))
def _make_ast_expr(arg: Union[None, Expr, int, float, str, AstExpr]) -> AstExpr:
    """Convert ``arg`` into an AST node.

    An ``Expr`` yields its wrapped node; ``None``, numbers and strings become
    ``ast.Constant`` nodes; anything else must already be one of the node
    classes in ``AstExprs`` and is returned unchanged.
    """
    if isinstance(arg, Expr):
        return arg.expr
    if arg is None or isinstance(arg, (int, float, str)):
        # ast.Constant instead of the deprecated ast.Num / ast.Str aliases
        # (removed in Python 3.14); on 3.8+ those aliases already produced
        # ast.Constant nodes, so the resulting tree is the same.
        return ast.Constant(value=arg)
    assert isinstance(arg, AstExprs), type(arg)
    return arg
def _make_call_expr(
    name: str, *args: Union[Expr, AstExpr, int, float, bool, str, None]
) -> Expr:
    """Return an ``Expr`` for a call to the function named ``name``.

    Each positional argument is converted with ``_make_ast_expr``; the call
    carries no keyword arguments.
    """
    converted = list(map(_make_ast_expr, args))
    return Expr(ast.Call(func=ast.Name(id=name), args=converted, keywords=[]))
def string_indexer(subject: Expr) -> Expr:
    """Return an ``Expr`` for the call ``string_indexer(subject)``."""
    call = _make_call_expr("string_indexer", subject)
    return call
def collect_set(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``collect_set(group)``."""
    call = _make_call_expr("collect_set", group)
    return call
def count(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``count(group)``."""
    call = _make_call_expr("count", group)
    return call
def day_of_month(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Return an ``Expr`` for the call ``day_of_month(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("day_of_month", *call_args)
def day_of_week(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Return an ``Expr`` for the call ``day_of_week(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("day_of_week", *call_args)
def day_of_year(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Return an ``Expr`` for the call ``day_of_year(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("day_of_year", *call_args)
def distinct_count(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``distinct_count(group)``."""
    call = _make_call_expr("distinct_count", group)
    return call
def hour(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Return an ``Expr`` for the call ``hour(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("hour", *call_args)
def item(group: Expr, value: Union[int, str]) -> Expr:
    """Return an ``Expr`` for the call ``item(group, value)``."""
    call = _make_call_expr("item", group, value)
    return call
def max(group: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Return an ``Expr`` for the call ``max(group)``."""
    call = _make_call_expr("max", group)
    return call
def max_gap_to_cutoff(group: Expr, cutoff: Expr) -> Expr:
    """Return an ``Expr`` for the call ``max_gap_to_cutoff(group, cutoff)``."""
    call = _make_call_expr("max_gap_to_cutoff", group, cutoff)
    return call
def mean(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``mean(group)``."""
    call = _make_call_expr("mean", group)
    return call
def min(group: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Return an ``Expr`` for the call ``min(group)``."""
    call = _make_call_expr("min", group)
    return call
def minute(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Return an ``Expr`` for the call ``minute(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("minute", *call_args)
def month(subject: Expr, fmt: Optional[str] = None) -> Expr:
    """Return an ``Expr`` for the call ``month(subject[, fmt])``.

    ``fmt`` is passed as a second argument only when it is not ``None``.
    """
    call_args = (subject,) if fmt is None else (subject, fmt)
    return _make_call_expr("month", *call_args)
def normalized_count(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``normalized_count(group)``."""
    call = _make_call_expr("normalized_count", group)
    return call
def normalized_sum(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``normalized_sum(group)``."""
    call = _make_call_expr("normalized_sum", group)
    return call
def recent(series: Expr, age: int) -> Expr:
    """Return an ``Expr`` for the call ``recent(series, age)``."""
    call = _make_call_expr("recent", series, age)
    return call
def recent_gap_to_cutoff(series: Expr, cutoff: Expr, age: int) -> Expr:
    """Return an ``Expr`` for the call ``recent_gap_to_cutoff(series, cutoff, age)``."""
    call = _make_call_expr("recent_gap_to_cutoff", series, cutoff, age)
    return call
def replace(
    subject: Expr,
    old2new: Dict[Any, Any],
    handle_unknown: str = "identity",
    unknown_value=None,
) -> Expr:
    """Return an ``Expr`` for ``replace(subject, old2new, handle_unknown, unknown_value)``.

    The mapping is turned into an AST by pretty-printing it and parsing the
    resulting text.  ``handle_unknown`` must be ``"identity"`` or
    ``"use_encoded_value"``; this is checked with an ``assert``.
    """
    # Round-trip the dict through its printed form to obtain an AST for it.
    parsed = ast.parse(pprint.pformat(old2new))
    mapping_ast = typing.cast(ast.Expr, parsed.body[0])
    assert handle_unknown in ["identity", "use_encoded_value"]
    call_args = (subject, mapping_ast, handle_unknown, unknown_value)
    return _make_call_expr("replace", *call_args)
def ite(
    cond: Expr,
    v1: Union[Expr, int, float, bool, str],
    v2: Union[Expr, int, float, bool, str],
) -> Expr:
    """Return an ``Expr`` for the call ``ite(cond, v1, v2)``.

    ``v1`` and ``v2`` that are not already ``Expr`` instances are wrapped as
    ``ast.Constant`` expressions first.
    """

    def lift(value):
        # Leave Expr operands alone; wrap plain values as constants.
        if isinstance(value, Expr):
            return value
        return Expr(ast.Constant(value=value))

    first_branch = lift(v1)
    second_branch = lift(v2)
    return _make_call_expr("ite", cond, first_branch, second_branch)
def identity(subject: Expr) -> Expr:
    """Return an ``Expr`` for the call ``identity(subject)``."""
    call = _make_call_expr("identity", subject)
    return call
def astype(dtype, subject: Expr) -> Expr:
    """Return an ``Expr`` for the call ``astype(dtype, subject)``."""
    call = _make_call_expr("astype", dtype, subject)
    return call
def hash(hash_method: str, subject: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Return an ``Expr`` for the call ``hash(hash_method, subject)``."""
    call = _make_call_expr("hash", hash_method, subject)
    return call
def hash_mod(hash_method: str, subject: Expr, n: Expr) -> Expr:
    """Return an ``Expr`` for the call ``hash_mod(hash_method, subject, n)``."""
    call = _make_call_expr("hash_mod", hash_method, subject, n)
    return call
def sum(group: Expr) -> Expr:  # pylint:disable=redefined-builtin
    """Return an ``Expr`` for the call ``sum(group)``."""
    call = _make_call_expr("sum", group)
    return call
def trend(series: Expr) -> Expr:
    """Return an ``Expr`` for the call ``trend(series)``."""
    call = _make_call_expr("trend", series)
    return call
def variance(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``variance(group)``."""
    call = _make_call_expr("variance", group)
    return call
def window_max(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_max(series, size)``."""
    call = _make_call_expr("window_max", series, size)
    return call
def window_max_trend(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_max_trend(series, size)``."""
    call = _make_call_expr("window_max_trend", series, size)
    return call
def window_mean(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_mean(series, size)``."""
    call = _make_call_expr("window_mean", series, size)
    return call
def window_mean_trend(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_mean_trend(series, size)``."""
    call = _make_call_expr("window_mean_trend", series, size)
    return call
def window_min(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_min(series, size)``."""
    call = _make_call_expr("window_min", series, size)
    return call
def window_min_trend(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_min_trend(series, size)``."""
    call = _make_call_expr("window_min_trend", series, size)
    return call
def window_variance(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_variance(series, size)``."""
    call = _make_call_expr("window_variance", series, size)
    return call
def window_variance_trend(series: Expr, size: int) -> Expr:
    """Return an ``Expr`` for the call ``window_variance_trend(series, size)``."""
    call = _make_call_expr("window_variance_trend", series, size)
    return call
def first(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``first(group)``."""
    call = _make_call_expr("first", group)
    return call
def isnan(column: Expr) -> Expr:
    """Return an ``Expr`` for the call ``isnan(column)``."""
    call = _make_call_expr("isnan", column)
    return call
def isnotnan(column: Expr) -> Expr:
    """Return an ``Expr`` for the call ``isnotnan(column)``."""
    call = _make_call_expr("isnotnan", column)
    return call
def isnull(column: Expr) -> Expr:
    """Return an ``Expr`` for the call ``isnull(column)``."""
    call = _make_call_expr("isnull", column)
    return call
def isnotnull(column: Expr) -> Expr:
    """Return an ``Expr`` for the call ``isnotnull(column)``."""
    call = _make_call_expr("isnotnull", column)
    return call
def asc(column: Union[Expr, str]) -> Expr:
    """Return an ``Expr`` for the call ``asc(column)``."""
    call = _make_call_expr("asc", column)
    return call
def desc(column: Union[Expr, str]) -> Expr:
    """Return an ``Expr`` for the call ``desc(column)``."""
    call = _make_call_expr("desc", column)
    return call
def mode(group: Expr) -> Expr:
    """Return an ``Expr`` for the call ``mode(group)``."""
    call = _make_call_expr("mode", group)
    return call
# Module-level expression wrapping the bare AST name `it`; attribute and
# subscript access on it (Expr.__getattr__ / Expr.__getitem__) build the
# `it.<name>` / `it[<key>]` forms that _it_column recognizes.
it = Expr(ast.Name(id="it"))
def _it_column(expr):
if isinstance(expr, ast.Attribute):
if _is_ast_name_it(expr.value):
return expr.attr
else:
raise ValueError(
f"Illegal {fixedUnparse(expr)}. Only the access to `it` is supported"
)
elif isinstance(expr, ast.Subscript):
if isinstance(expr.slice, ast.Constant) or (
_is_ast_name_it(expr.value) and isinstance(expr.slice, ast.Index)
):
v = getattr(expr.slice, "value", None)
if isinstance(expr.slice, ast.Constant):
return v
elif isinstance(v, ast.Constant):
return v.value
elif isinstance(v, ast.Str):
return v.s
else:
raise ValueError(
f"Illegal {fixedUnparse(expr)}. Only the access to `it` is supported"
)
else:
raise ValueError(
f"Illegal {fixedUnparse(expr)}. Only the access to `it` is supported"
)
else:
raise ValueError(
f"Illegal {fixedUnparse(expr)}. Only the access to `it` is supported"
)
def _is_ast_name_it(expr):
return isinstance(expr, ast.Name) and expr.id == "it"