# Source code for lale.lib.rasl.task_graphs

# Copyright 2021, 2022 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import functools
import itertools
import logging
import pathlib
import sys
import tempfile
import time
from abc import ABC, abstractmethod
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import graphviz
import numpy as np
import pandas as pd
import sklearn.model_selection
import sklearn.tree

import lale.helpers
import lale.json_operator
import lale.pretty_print
from lale.datasets import pandas2spark
from lale.operators import (
    TrainableIndividualOp,
    TrainablePipeline,
    TrainedIndividualOp,
    TrainedPipeline,
)

from .metrics import MetricMonoid, MetricMonoidFactory
from .monoid import Monoid, MonoidFactory

# Spark is an optional dependency; only import its DataFrame type when the
# helper reports that pyspark is installed.
if lale.helpers.spark_installed:
    from pyspark.sql.dataframe import DataFrame as SparkDataFrame

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Whether a batch's data is currently in memory or has been spilled to disk
# (see _Batch.spill / _Batch.load_spilled below).
_BatchStatus = enum.Enum("BatchStatus", "RESIDENT SPILLED")

# Scheduling lifecycle of a task in the task graph.
_TaskStatus = enum.Enum("_TaskStatus", "FRESH READY WAITING DONE")

# The kind of work a task performs when it runs.
_Operation = enum.Enum(
    "_Operation", "SCAN SPLIT TRANSFORM PREDICT FIT PARTIAL_FIT TO_MONOID COMBINE"
)

# Sentinel step ids for the pseudo-steps that feed input into the pipeline
# and that compute the final score, respectively.
_DUMMY_INPUT_STEP = -1
_DUMMY_SCORE_STEP = sys.maxsize
# Wildcard markers meaning "all folds" / "all batches".
_ALL_FOLDS = "*"
_ALL_BATCHES = -1


def is_pretrained(op: TrainableIndividualOp) -> bool:
    """Is the operator frozen-trained or does it lack a fit method?

    A pre-trained operator needs no training at all during task-graph
    execution, so its FIT tasks are effectively no-ops.
    """
    return isinstance(op, TrainedIndividualOp) and (
        op.is_frozen_trained() or not hasattr(op.impl, "fit")
    )
def is_incremental(op: TrainableIndividualOp) -> bool:
    """Does the operator have a partial_fit method or is it pre-trained?

    Incremental operators can be trained batch-by-batch instead of
    requiring all data at once.
    """
    return op.has_method("partial_fit") or is_pretrained(op)
def is_associative(op: TrainableIndividualOp) -> bool:
    """Is the operator pre-trained or does it implement MonoidFactory?

    Associative operators can train on batches independently (TO_MONOID)
    and merge the partial results afterwards (COMBINE).
    """
    return is_pretrained(op) or isinstance(op.impl, MonoidFactory)
def _batch_id(fold: str, idx: int) -> str:
    """Encode a (fold, batch index) pair as a compact string id.

    The fold is a single character; `_ALL_BATCHES` is encoded as "*".
    """
    return fold + ("*" if idx == _ALL_BATCHES else str(idx))


def _get_fold(batch_id: str) -> str:
    """Inverse of _batch_id: extract the one-character fold."""
    return batch_id[0]


def _get_idx(batch_id: str) -> int:
    """Inverse of _batch_id: extract the batch index (or _ALL_BATCHES)."""
    return _ALL_BATCHES if batch_id[1] == "*" else int(batch_id[1:])


class _Batch:
    """One batch of (X, y) data, owned by an apply task.

    The data may be resident in memory or spilled to disk; in the spilled
    state, self.X and self.y hold pathlib.Path objects instead of data.
    """

    def __init__(self, X, y, task: Optional["_ApplyTask"]):
        self.X = X
        self.y = y
        self.task = task
        # Estimate the memory footprint for the resource-aware scheduler.
        if isinstance(X, pd.DataFrame) and isinstance(y, pd.Series):
            space_X = int(cast(pd.DataFrame, X).memory_usage().sum())
            space_y = cast(pd.Series, y).memory_usage()
            self.space = space_X + space_y
        else:
            self.space = 1  # place-holder value for Spark

    def spill(self, spill_dir: pathlib.Path) -> None:
        """Write X and y to disk and replace them by their file paths."""
        name_X = spill_dir / f"X_{self}.pkl"
        name_y = spill_dir / f"y_{self}.pkl"
        if isinstance(self.X, pd.DataFrame):
            cast(pd.DataFrame, self.X).to_pickle(name_X)
        elif isinstance(self.X, np.ndarray):
            # np.save appends ".npy" to the given name.
            np.save(name_X, self.X, allow_pickle=True)
        else:
            raise ValueError(
                f"""Spilling of {type(self.X)} is not supported.
Supported types are: pandas DataFrame, numpy ndarray."""
            )
        if isinstance(self.y, pd.Series):
            cast(pd.Series, self.y).to_pickle(name_y)
        elif isinstance(self.y, np.ndarray):
            np.save(name_y, self.y, allow_pickle=True)
        else:
            raise ValueError(
                f"""Spilling of {type(self.y)} is not supported.
Supported types are: pandas DataFrame, pandas Series, and numpy ndarray."""
            )
        self.X, self.y = name_X, name_y

    def load_spilled(self) -> None:
        """Reload spilled X and y from disk, restoring in-memory data."""
        assert isinstance(self.X, pathlib.Path) and isinstance(self.y, pathlib.Path)
        # we know these are pickles written by us, so we can trust them
        try:
            data_X = pd.read_pickle(self.X)  # nosec B301
        except FileNotFoundError:
            # Not a pandas pickle; must have been saved with np.save,
            # which appended the ".npy" suffix.
            data_X = np.load(f"{self.X}" + ".npy", allow_pickle=True)
        try:
            data_y = pd.read_pickle(self.y)  # nosec B301
        except FileNotFoundError:
            data_y = np.load(f"{self.y}" + ".npy", allow_pickle=True)
        self.X, self.y = data_X, data_y

    def delete_if_spilled(self) -> None:
        """Remove the spill files from disk, if this batch was spilled."""
        if isinstance(self.X, pathlib.Path) and isinstance(self.y, pathlib.Path):
            self.X.unlink()
            self.y.unlink()

    def __str__(self) -> str:
        # Used to build unique spill file names; requires a single-batch task.
        assert self.task is not None
        assert len(self.task.batch_ids) == 1 and not self.task.has_all_batches()
        batch_id = self.task.batch_ids[0]
        return f"{self.task.step_id}_{batch_id}_{self.task.held_out}"

    @property
    def Xy(self) -> Tuple[Any, Any]:
        # Only valid while the data is resident in memory.
        assert self.status == _BatchStatus.RESIDENT
        return self.X, self.y

    @property
    def status(self) -> _BatchStatus:
        if isinstance(self.X, pathlib.Path) and isinstance(self.y, pathlib.Path):
            return _BatchStatus.SPILLED
        return _BatchStatus.RESIDENT


# Memoization key that uniquely identifies a task in the graph.
_MemoKey = Tuple[Type["_Task"], int, Tuple[str, ...], Optional[str]]


class _Task:
    """A node in the task graph, linked to predecessors and successors.

    A task is identified by its concrete subclass, the pipeline step it
    belongs to, the batch ids it operates on, and the held-out fold.
    """

    preds: List["_Task"]
    succs: List["_Task"]

    def __init__(
        self, step_id: int, batch_ids: Tuple[str, ...], held_out: Optional[str]
    ):
        assert len(batch_ids) >= 1
        self.step_id = step_id
        self.batch_ids = batch_ids
        self.held_out = held_out
        self.status = _TaskStatus.FRESH
        self.preds = []
        self.succs = []
        self.deletable_output = True

    @abstractmethod
    def get_operation(
        self, pipeline: TrainablePipeline[TrainableIndividualOp]
    ) -> _Operation:
        pass

    def add_pred(self, pred):
        """Link pred as a predecessor (and self as its successor), once."""
        if pred not in self.preds:
            self.preds.append(pred)
            pred.succs.append(self)

    def has_all_batches(self) -> bool:
        # A "*" index means the task ranges over every batch of a fold.
        return any(b[1] == "*" for b in self.batch_ids)

    def can_be_ready(self, end_of_scanned_batches) -> bool:
        """Can this task transition to READY?

        All predecessors must be DONE; a task over all batches must
        additionally wait until the last batch has been scanned.
        """
        if any(p.status is not _TaskStatus.DONE for p in self.preds):
            return False
        if end_of_scanned_batches:
            return True
        return not self.has_all_batches()

    def expand_batches(self, up_to) -> Tuple[str, ...]:
        """Replace any "*" batch id by the concrete ids 0..up_to-1."""
        if self.has_all_batches():
            result = tuple(
                itertools.chain.from_iterable(
                    (
                        (_batch_id(_get_fold(b), i) for i in range(up_to))
                        if b[1] == "*"
                        else [b]
                    )
                    for b in self.batch_ids
                )
            )
        else:
            result = self.batch_ids
        return result

    def memo_key(self) -> _MemoKey:
        return type(self), self.step_id, self.batch_ids, self.held_out


class _TrainTask(_Task):
    """Task that trains (or combines partial training results for) a step."""

    monoid: Optional[Monoid]
    trained: Optional[TrainedIndividualOp]

    def __init__(self, step_id: int, batch_ids: Tuple[str, ...], held_out: str):
        super().__init__(step_id, batch_ids, held_out)
        self.monoid = None
        self.trained = None

    def get_operation(
        self, pipeline: TrainablePipeline[TrainableIndividualOp]
    ) -> _Operation:
        step = pipeline.steps_list()[self.step_id]
        if is_pretrained(step):
            # Nothing to learn; FIT here is effectively a no-op.
            return _Operation.FIT
        if is_associative(step):
            # Single concrete batch -> summarize it; otherwise merge monoids.
            if len(self.batch_ids) == 1 and not self.has_all_batches():
                return _Operation.TO_MONOID
            return _Operation.COMBINE
        if is_incremental(step):
            return _Operation.PARTIAL_FIT
        return _Operation.FIT

    def get_trained(
        self, pipeline: TrainablePipeline[TrainableIndividualOp]
    ) -> TrainedIndividualOp:
        """Materialize (and cache) the trained operator from self.monoid."""
        if self.trained is None:
            assert self.monoid is not None
            trainable = pipeline.steps_list()[self.step_id]
            self.trained = trainable.convert_to_trained()
            # Build a fresh impl and inject the learned state from the monoid.
            hyperparams = trainable.impl._hyperparams
            self.trained._impl = trainable._impl_class()(**hyperparams)
            if trainable.has_method("_set_fit_attributes"):
                self.trained._impl._set_fit_attributes(self.monoid)
            elif trainable.has_method("from_monoid"):
                self.trained._impl.from_monoid(self.monoid)
            else:
                assert False, self.trained
        return self.trained


class _ApplyTask(_Task):
    """Task that applies a step (transform/predict) to a single batch."""

    batch: Optional[_Batch]
    splits: Optional[List[Tuple[List[int], List[int]]]]

    def __init__(self, step_id: int, batch_ids: Tuple[str, ...], held_out: str):
        super().__init__(step_id, batch_ids, held_out)
        # Apply tasks always operate on exactly one concrete batch.
        assert len(batch_ids) == 1 and not self.has_all_batches()
        self.batch = None
        self.splits = None  # for cross validation with scan tasks

    def get_operation(self, pipeline: TrainablePipeline) -> _Operation:
        if self.step_id == _DUMMY_INPUT_STEP:
            # No predecessors -> read raw input; otherwise split for CV.
            return _Operation.SCAN if len(self.preds) == 0 else _Operation.SPLIT
        step = pipeline.steps_list()[self.step_id]
        return _Operation.TRANSFORM if step.is_transformer() else _Operation.PREDICT


class _MetricTask(_Task):
    """Task that computes (or combines) metric results for scoring."""

    # NOTE: fixed from `mscore` — the attribute actually set and read
    # elsewhere (e.g. _TaskGraph.extract_scores) is `mmonoid`.
    mmonoid: Optional[MetricMonoid]

    def __init__(self, step_id: int, batch_ids: Tuple[str, ...], held_out: str):
        super().__init__(step_id, batch_ids, held_out)
        self.mmonoid = None

    def get_operation(self, pipeline: TrainablePipeline) -> _Operation:
        if len(self.batch_ids) == 1 and not self.has_all_batches():
            return _Operation.TO_MONOID
        return _Operation.COMBINE


def _task_type_prio(task: _Task) -> int:
    """Tie-breaking priority by task kind: train < apply < metric."""
    if isinstance(task, _TrainTask):
        return 0
    if isinstance(task, _ApplyTask):
        return 1
    assert isinstance(task, _MetricTask), type(task)
    return 2
class Prio(ABC):
    """Abstract base class for scheduling priority in task graphs.

    Priorities are tuples of length `arity`; smaller compares as more
    urgent, both for running tasks and for keeping batches resident.
    """

    arity: int

    def bottom(self) -> Any:
        # tuple of "inf" means all others are more important
        return self.arity * (float("inf"),)

    def batch_priority(self, batch: _Batch) -> Any:
        """Priority of keeping a batch resident (prefer to keep if lower).

        A batch is as important as its most urgent pending successor task.
        """
        assert batch.task is not None
        return min(
            (
                self.task_priority(s)
                for s in batch.task.succs
                if s.status in [_TaskStatus.READY, _TaskStatus.WAITING]
            ),
            default=self.bottom(),
        )

    @abstractmethod
    def task_priority(self, task: _Task) -> Any:
        # prefer to do first if lower
        pass
class PrioStep(Prio):
    """Execute tasks from earlier steps first, like nested-loop algorithm."""

    arity = 6

    def task_priority(self, task: _Task) -> Any:
        if task.has_all_batches():
            max_batch_idx = sys.maxsize
        else:
            max_batch_idx = max(_get_idx(b) for b in task.batch_ids)
        # Primary key is the step id, so an entire step finishes before the
        # next one starts; status comes first so non-FRESH work is preferred.
        result = (
            task.status.value,
            task.step_id,
            max_batch_idx,
            len(task.batch_ids),
            task.batch_ids,
            _task_type_prio(task),
        )
        assert len(result) == self.arity
        return result
class PrioBatch(Prio):
    """Execute tasks from earlier batches first."""

    arity = 6

    def task_priority(self, task: _Task) -> Any:
        if task.has_all_batches():
            max_batch_idx = sys.maxsize
        else:
            max_batch_idx = max(_get_idx(b) for b in task.batch_ids)
        # Primary key is the batch index, so all steps process a batch
        # before moving on to the next batch.
        result = (
            task.status.value,
            max_batch_idx,
            len(task.batch_ids),
            task.batch_ids,
            task.step_id,
            _task_type_prio(task),
        )
        assert len(result) == self.arity
        return result
class PrioResourceAware(Prio):
    """Execute tasks with less non-resident data first."""

    arity = 5

    def task_priority(self, task: _Task) -> Any:
        # Total size of predecessor batches that would have to be reloaded
        # from disk before this task can run.
        non_res = sum(
            p.batch.space
            for p in task.preds
            if isinstance(p, _ApplyTask) and p.batch is not None
            if p.batch.status != _BatchStatus.RESIDENT
        )
        result = (
            task.status.value,
            non_res,
            task.batch_ids,
            task.step_id,
            _task_type_prio(task),
        )
        assert len(result) == self.arity
        return result
def _step_id_to_string(
    step_id: int,
    pipeline: TrainablePipeline,
    cls2label: Optional[Dict[str, str]] = None,
) -> str:
    """Human-readable label for a step id, including the two dummy steps."""
    if step_id == _DUMMY_INPUT_STEP:
        return "INP"
    if step_id == _DUMMY_SCORE_STEP:
        return "SCR"
    step = pipeline.steps_list()[step_id]
    cls = step.class_name()
    return cls2label[cls] if cls2label and cls in cls2label else step.name()


def _task_to_string(
    task: _Task,
    pipeline: TrainablePipeline,
    cls2label: Optional[Dict[str, str]] = None,
    sep: str = "\n",
    trace_id: Optional[int] = None,
) -> str:
    """Render a task as `[trace] operation<sep>step(batches)\\held_out`."""
    trace_id_s = "" if trace_id is None else f"{trace_id} "
    operation_s = task.get_operation(pipeline).name.lower()
    step_s = _step_id_to_string(task.step_id, pipeline, cls2label)
    batches_s = ",".join(task.batch_ids)
    held_out_s = "" if task.held_out is None else f"\\\\{task.held_out}"
    return f"{trace_id_s}{operation_s}{sep}{step_s}({batches_s}){held_out_s}"


# TODO: Maybe we can address this another way?
# pylint: disable=E1101
class _RunStats:
    """Attribute-style counters for one run; only the predeclared keys exist."""

    _values: Dict[str, float]

    def __init__(self):
        # object.__setattr__ bypasses our own __setattr__ during construction
        object.__setattr__(
            self,
            "_values",
            {
                "spill_count": 0,
                "load_count": 0,
                "spill_space": 0,
                "load_space": 0,
                "min_resident": 0,
                "max_resident": 0,
                "train_count": 0,
                "apply_count": 0,
                "metric_count": 0,
                "train_time": 0,
                "apply_time": 0,
                "metric_time": 0,
                "critical_count": 0,
                "critical_time": 0,
            },
        )

    def __getattr__(self, name: str) -> float:
        if name in self._values:
            return self._values[name]
        raise AttributeError(f"'{name}' not in {self._values.keys()}")

    def __setattr__(self, name: str, value: float) -> None:
        if name in self._values:
            self._values[name] = value
        else:
            raise AttributeError(f"'{name}' not in {self._values.keys()}")

    def __repr__(self) -> str:
        return lale.pretty_print.json_to_string(self._values)


class _TraceRecord:
    """One executed task plus its wall-clock time and output space."""

    task: _Task
    time: float

    def __init__(self, task: _Task, task_time: float):
        self.task = task
        self.time = task_time
        if isinstance(task, _ApplyTask) and task.batch is not None:
            self.space = task.batch.space
        else:
            self.space = 0  # TODO: size for train tasks and metrics tasks


class _TaskGraph:
    """Mutable graph of train/apply/metric tasks for one pipeline run."""

    step_ids: Dict[TrainableIndividualOp, int]
    step_id_preds: Dict[int, List[int]]
    fresh_tasks: List[_Task]
    all_tasks: Dict[_MemoKey, _Task]
    tasks_with_all_batches: List[_Task]

    def __init__(
        self,
        pipeline: TrainablePipeline[TrainableIndividualOp],
        folds: List[str],
        partial_transform: Union[bool, str],
        same_fold: bool,
    ):
        self.pipeline = pipeline
        self.folds = folds
        self.partial_transform = partial_transform
        self.same_fold = same_fold
        self.step_ids = {step: i for i, step in enumerate(pipeline.steps_list())}
        # steps without pipeline predecessors read from the dummy input step
        self.step_id_preds = {
            self.step_ids[s]: (
                [_DUMMY_INPUT_STEP]
                if len(pipeline._preds[s]) == 0
                else [self.step_ids[p] for p in pipeline._preds[s]]
            )
            for s in pipeline.steps_list()
        }
        self.fresh_tasks = []
        self.all_tasks = {}
        self.tasks_with_all_batches = []

    def __enter__(self) -> "_TaskGraph":
        return self

    # NOTE: parameters renamed to the conventional context-manager order
    # (exc_type, exc_value, traceback); they were previously misnamed
    # (exc_value, exc_type, ...) but are unused, so behavior is unchanged.
    def __exit__(self, exc_type, exc_value, traceback) -> None:
        for task in self.all_tasks.values():
            # preds form a garbage collection cycle with succs
            task.preds.clear()
            task.succs.clear()
            # tasks form a garbage collection cycle with batches
            if isinstance(task, _ApplyTask) and task.batch is not None:
                task.batch.task = None
                task.batch = None
        self.all_tasks.clear()

    def extract_scores(self, scoring: MetricMonoidFactory) -> List[float]:
        """Read one final score per fold out of the completed metric tasks."""

        def extract_score(held_out: str) -> float:
            batch_ids = (_batch_id(held_out, _ALL_BATCHES),)
            task = self.all_tasks[(_MetricTask, _DUMMY_SCORE_STEP, batch_ids, held_out)]
            assert isinstance(task, _MetricTask) and task.mmonoid is not None
            return scoring.from_monoid(task.mmonoid)

        scores = [extract_score(held_out) for held_out in self.folds]
        return scores

    def extract_trained_pipeline(
        self, held_out: Optional[str], up_to: int
    ) -> TrainedPipeline:
        """Assemble a TrainedPipeline from the per-step train tasks.

        `up_to` is either `_ALL_BATCHES` or a batch count (single fold only).
        """
        if up_to == _ALL_BATCHES:
            batch_ids = _batch_ids_except(self.folds, held_out)
        else:
            assert len(self.folds) == 1 and held_out is None
            batch_ids = tuple(_batch_id(self.folds[0], i) for i in range(up_to))

        def extract_trained_step(step_id: int) -> TrainedIndividualOp:
            task = cast(
                _TrainTask, self.all_tasks[(_TrainTask, step_id, batch_ids, held_out)]
            )
            return task.get_trained(self.pipeline)

        step_map = {
            old_step: extract_trained_step(step_id)
            for step_id, old_step in enumerate(self.pipeline.steps_list())
        }
        trained_edges = [(step_map[x], step_map[y]) for x, y in self.pipeline.edges()]
        result = TrainedPipeline(
            list(step_map.values()), trained_edges, ordered=True, _lale_trained=True
        )
        return result

    def find_or_create(
        self,
        task_class: Type["_Task"],
        step_id: int,
        batch_ids: Tuple[str, ...],
        held_out: Optional[str],
    ) -> _Task:
        """Memoized task constructor; new tasks start out as fresh."""
        memo_key = task_class, step_id, batch_ids, held_out
        if memo_key not in self.all_tasks:
            task = task_class(step_id, batch_ids, held_out)
            self.all_tasks[memo_key] = task
            self.fresh_tasks.append(task)
            if task.has_all_batches():
                self.tasks_with_all_batches.append(task)
        return self.all_tasks[memo_key]

    def visualize(
        self, prio: Prio, call_depth: int, trace: Optional[List[_TraceRecord]]
    ) -> None:
        """Render the task graph with graphviz in a notebook (IPython display)."""
        cls2label = lale.json_operator._get_cls2label(call_depth + 1)
        dot = graphviz.Digraph()
        dot.attr("graph", rankdir="LR", nodesep="0.1")
        dot.attr("node", fontsize="11", margin="0.03,0.03", shape="box", height="0.1")
        next_task = min(self.all_tasks.values(), key=prio.task_priority)
        task_key2trace_id: Dict[_MemoKey, int] = {}
        if trace is not None:
            task_key2trace_id = {r.task.memo_key(): i for i, r in enumerate(trace)}
        for task in self.all_tasks.values():
            if task.status is _TaskStatus.FRESH:
                color = "white"
            elif task.status is _TaskStatus.READY:
                color = "lightgreen" if task is next_task else "yellow"
            elif task.status is _TaskStatus.WAITING:
                color = "coral"
            else:
                assert task.status is _TaskStatus.DONE
                color = "lightgray"
            # https://www.graphviz.org/doc/info/shapes.html
            if isinstance(task, _TrainTask):
                style = "filled,rounded"
            elif isinstance(task, _ApplyTask):
                style = "filled"
            elif isinstance(task, _MetricTask):
                style = "filled,diagonals"
            else:
                assert False, type(task)
            trace_id = task_key2trace_id.get(task.memo_key(), None)
            task_s = _task_to_string(task, self.pipeline, cls2label, trace_id=trace_id)
            dot.node(task_s, style=style, fillcolor=color)
        for task in self.all_tasks.values():
            trace_id = task_key2trace_id.get(task.memo_key(), None)
            task_s = _task_to_string(task, self.pipeline, cls2label, trace_id=trace_id)
            for succ in task.succs:
                succ_id = task_key2trace_id.get(succ.memo_key(), None)
                succ_s = _task_to_string(
                    succ, self.pipeline, cls2label, trace_id=succ_id
                )
                dot.edge(task_s, succ_s)

        import IPython.display

        IPython.display.display(dot)


def _batch_ids_except(folds: List[str], held_out: Optional[str]) -> Tuple[str, ...]:
    """All-batches batch ids for every fold except the held-out one."""
    return tuple(_batch_id(f, _ALL_BATCHES) for f in folds if f != held_out)


def _create_initial_tasks(
    tg: _TaskGraph, need_metrics: bool, keep_estimator: bool
) -> None:
    """Seed the graph with the scan task plus requested sink tasks."""
    held_out: Optional[str]
    _ = tg.find_or_create(
        _ApplyTask,
        _DUMMY_INPUT_STEP,
        (_batch_id(tg.folds[0] if len(tg.folds) == 1 else _ALL_FOLDS, 0),),
        None,
    )
    if need_metrics:
        for held_out in tg.folds:
            task = tg.find_or_create(
                _MetricTask,
                _DUMMY_SCORE_STEP,
                (_batch_id(held_out, _ALL_BATCHES),),
                None if len(tg.folds) == 1 else held_out,
            )
            task.deletable_output = False
    if keep_estimator:
        for step_id in tg.step_ids.values():
            held_outs = cast(
                List[Optional[str]], [None] if len(tg.folds) == 1 else tg.folds
            )
            for held_out in held_outs:
                task = tg.find_or_create(
                    _TrainTask,
                    step_id,
                    _batch_ids_except(tg.folds, held_out),
                    held_out,
                )
                assert isinstance(task, _TrainTask)
                task.deletable_output = False
                trainable = tg.pipeline.steps_list()[task.step_id]
                if is_pretrained(trainable):
                    task.trained = cast(TrainedIndividualOp, trainable)
                    task.status = _TaskStatus.DONE


def _backward_chain_tasks(
    tg: _TaskGraph, n_batches_scanned: int, end_of_scanned_batches: bool
) -> None:
    """Expand fresh tasks backward, creating their predecessor tasks.

    Train tasks chain differently depending on whether the step is
    associative (combine of monoids), incremental (partial_fit chain),
    or neither (plain fit over all batches).
    """

    def apply_pred_ho(task, pred_batch_id, pred_step_id):
        # held-out fold for an apply-task predecessor of a train task
        assert isinstance(task, _TrainTask), type(task)
        if len(tg.folds) == 1 or pred_step_id == _DUMMY_INPUT_STEP:
            result = None
        elif tg.same_fold:
            result = task.held_out
        else:
            result = _get_fold(pred_batch_id)
        return result

    def train_pred_ho(task, pred_batch_ids):
        # held-out fold for a train-task predecessor of a train task
        assert isinstance(task, _TrainTask), type(task)
        if len(pred_batch_ids) == 1 and (
            tg.step_id_preds[task.step_id] == [_DUMMY_INPUT_STEP] or not tg.same_fold
        ):
            result = None
        else:
            result = task.held_out
        return result

    pred_batch_ids: Tuple[str, ...]
    while len(tg.fresh_tasks) > 0:
        task = tg.fresh_tasks.pop()
        if isinstance(task, _TrainTask):
            step = tg.pipeline.steps_list()[task.step_id]
            if is_pretrained(step):
                pass  # no predecessors needed, operator is already trained
            elif len(task.batch_ids) == 1 and not task.has_all_batches():
                # single-batch train task reads its inputs directly
                for pred_step_id in tg.step_id_preds[task.step_id]:
                    task.add_pred(
                        tg.find_or_create(
                            _ApplyTask,
                            pred_step_id,
                            task.batch_ids,
                            apply_pred_ho(task, task.batch_ids[0], pred_step_id),
                        )
                    )
            else:
                if is_associative(step):
                    if tg.partial_transform in ["score", True]:
                        if task.has_all_batches():
                            if n_batches_scanned > 0:
                                expanded_batch_ids = task.expand_batches(
                                    n_batches_scanned
                                )
                                last_combine_task = tg.find_or_create(
                                    _TrainTask,
                                    task.step_id,
                                    expanded_batch_ids,
                                    train_pred_ho(task, expanded_batch_ids),
                                )
                                last_combine_task.deletable_output = False
                                if end_of_scanned_batches:
                                    task.add_pred(last_combine_task)
                        else:
                            # combine prefix monoid with last batch's monoid
                            if len(task.batch_ids) > 1:
                                pred_batch_ids = task.batch_ids[:-1]
                                task.add_pred(
                                    tg.find_or_create(
                                        _TrainTask,
                                        task.step_id,
                                        pred_batch_ids,
                                        train_pred_ho(task, pred_batch_ids),
                                    )
                                )
                            pred_batch_ids = task.batch_ids[-1:]
                            task.add_pred(
                                tg.find_or_create(
                                    _TrainTask,
                                    task.step_id,
                                    pred_batch_ids,
                                    train_pred_ho(task, pred_batch_ids),
                                )
                            )
                    else:
                        # combine one per-batch monoid per scanned batch
                        for batch_id in task.expand_batches(n_batches_scanned):
                            pred_batch_ids = (batch_id,)
                            task.add_pred(
                                tg.find_or_create(
                                    _TrainTask,
                                    task.step_id,
                                    pred_batch_ids,
                                    train_pred_ho(task, pred_batch_ids),
                                )
                            )
                elif is_incremental(step):
                    if task.has_all_batches():
                        if n_batches_scanned > 0:
                            expanded_batch_ids = task.expand_batches(n_batches_scanned)
                            last_partial_fit_task = tg.find_or_create(
                                _TrainTask,
                                task.step_id,
                                expanded_batch_ids,
                                train_pred_ho(task, expanded_batch_ids),
                            )
                            last_partial_fit_task.deletable_output = False
                            if end_of_scanned_batches:
                                task.add_pred(last_partial_fit_task)
                    else:
                        # partial_fit: prior model plus last batch's data
                        if len(task.batch_ids) > 1:
                            pred_batch_ids = task.batch_ids[:-1]
                            task.add_pred(
                                tg.find_or_create(
                                    _TrainTask,
                                    task.step_id,
                                    pred_batch_ids,
                                    train_pred_ho(task, pred_batch_ids),
                                )
                            )
                        pred_batch_id = task.batch_ids[-1]
                        for pred_step_id in tg.step_id_preds[task.step_id]:
                            task.add_pred(
                                tg.find_or_create(
                                    _ApplyTask,
                                    pred_step_id,
                                    (pred_batch_id,),
                                    apply_pred_ho(task, pred_batch_id, pred_step_id),
                                )
                            )
                else:
                    # neither associative nor incremental: fit on all data at once
                    for pred_step_id in tg.step_id_preds[task.step_id]:
                        for pred_batch_id in task.expand_batches(n_batches_scanned):
                            task.add_pred(
                                tg.find_or_create(
                                    _ApplyTask,
                                    pred_step_id,
                                    (pred_batch_id,),
                                    apply_pred_ho(task, pred_batch_id, pred_step_id),
                                )
                            )
        elif isinstance(task, _ApplyTask):
            assert len(task.batch_ids) == 1 and not task.has_all_batches()
            if task.step_id == _DUMMY_INPUT_STEP:
                assert task.held_out is None, task.held_out
                batch_id = task.batch_ids[0]
                # per-fold input batches are split from the all-folds scan
                if len(tg.folds) > 1 and _get_fold(batch_id) != _ALL_FOLDS:
                    task.add_pred(
                        tg.find_or_create(
                            _ApplyTask,
                            task.step_id,
                            (_batch_id(_ALL_FOLDS, _get_idx(batch_id)),),
                            None,
                        )
                    )
            else:
                # pick the model to apply: partially trained or fully trained
                if (
                    tg.partial_transform is True
                    or tg.partial_transform == "score"
                    and all(isinstance(s, _MetricTask) for s in task.succs)
                ):
                    fit_upto = _get_idx(task.batch_ids[0])
                    if end_of_scanned_batches and fit_upto == n_batches_scanned - 1:
                        pred_batch_ids = _batch_ids_except(tg.folds, task.held_out)
                    else:
                        pred_batch_ids = tuple(
                            _batch_id(fold, idx)
                            for fold in tg.folds
                            if fold != task.held_out
                            for idx in range(fit_upto + 1)
                        )
                else:
                    pred_batch_ids = _batch_ids_except(tg.folds, task.held_out)
                task.add_pred(
                    tg.find_or_create(
                        _TrainTask,
                        task.step_id,
                        pred_batch_ids,
                        task.held_out,
                    )
                )
                for pred_step_id in tg.step_id_preds[task.step_id]:
                    if len(tg.folds) == 1 or pred_step_id == _DUMMY_INPUT_STEP:
                        pred_held_out = None
                    else:
                        pred_held_out = task.held_out
                    task.add_pred(
                        tg.find_or_create(
                            _ApplyTask, pred_step_id, task.batch_ids, pred_held_out
                        )
                    )
        elif isinstance(task, _MetricTask):
            if len(task.batch_ids) == 1 and not task.has_all_batches():
                # per-batch metric: needs ground truth and sink predictions
                task.add_pred(
                    tg.find_or_create(
                        _ApplyTask, _DUMMY_INPUT_STEP, task.batch_ids, None
                    )
                )
                sink = tg.pipeline.get_last()
                assert sink is not None
                task.add_pred(
                    tg.find_or_create(
                        _ApplyTask, tg.step_ids[sink], task.batch_ids, task.held_out
                    )
                )
            else:
                # combined metric: one monoid per batch
                for batch_id in task.expand_batches(n_batches_scanned):
                    task.add_pred(
                        tg.find_or_create(
                            _MetricTask, task.step_id, (batch_id,), task.held_out
                        )
                    )
        else:
            assert False, type(task)
        if task.status is not _TaskStatus.DONE:
            if task.can_be_ready(end_of_scanned_batches):
                task.status = _TaskStatus.READY
            else:
                task.status = _TaskStatus.WAITING


def _create_tasks(
    pipeline: TrainablePipeline[TrainableIndividualOp],
    folds: List[str],
    need_metrics: bool,
    keep_estimator: bool,
    partial_transform: Union[bool, str],
    same_fold: bool,
) -> _TaskGraph:
    """Build the initial task graph before any batches are scanned."""
    tg = _TaskGraph(pipeline, folds, partial_transform, same_fold)
    _create_initial_tasks(tg, need_metrics, keep_estimator)
    _backward_chain_tasks(tg, 0, False)
    return tg


def _analyze_run_trace(stats: _RunStats, trace: List[_TraceRecord]) -> _RunStats:
    """Accumulate per-task-kind counts/times and critical-path statistics."""
    memo_key2critical_count: Dict[_MemoKey, int] = {}
    memo_key2critical_time: Dict[_MemoKey, float] = {}
    for record in trace:
        if isinstance(record.task, _TrainTask):
            stats.train_count += 1
            stats.train_time += record.time
        elif isinstance(record.task, _ApplyTask):
            stats.apply_count += 1
            stats.apply_time += record.time
        elif isinstance(record.task, _MetricTask):
            stats.metric_count += 1
            stats.metric_time += record.time
        else:
            assert False, type(record.task)
        # BUGFIX: the dicts are keyed by memo keys, so membership must be
        # tested with p.memo_key(), not the task object itself; the old test
        # `p in memo_key2critical_count` was always false, collapsing the
        # critical path to a single task.
        critical_count = 1 + max(
            (
                memo_key2critical_count[p.memo_key()]
                for p in record.task.preds
                if p.memo_key() in memo_key2critical_count
            ),
            default=0,
        )
        stats.critical_count = max(critical_count, stats.critical_count)
        memo_key2critical_count[record.task.memo_key()] = critical_count
        critical_time = record.time + max(
            (
                memo_key2critical_time[p.memo_key()]
                for p in record.task.preds
                if p.memo_key() in memo_key2critical_time
            ),
            default=0,
        )
        stats.critical_time = max(critical_time, stats.critical_time)
        memo_key2critical_time[record.task.memo_key()] = critical_time
    return stats


class _BatchCache:
    """Keeps apply-task batches within a memory budget by spilling to disk."""

    spill_dir: Optional[tempfile.TemporaryDirectory]
    spill_path: Optional[pathlib.Path]

    def __init__(
        self,
        tasks: Dict[_MemoKey, _Task],
        max_resident: Optional[int],
        prio: Prio,
        verbose: int,
    ):
        self.tasks = tasks
        # None means unlimited budget (no spilling will ever happen)
        self.max_resident = sys.maxsize if max_resident is None else max_resident
        self.prio = prio
        self.spill_dir = None
        self.spill_path = None
        self.verbose = verbose
        self.stats = _RunStats()
        self.stats.max_resident = self.max_resident

    def __enter__(self) -> "_BatchCache":
        if self.max_resident < sys.maxsize:
            self.spill_dir = tempfile.TemporaryDirectory()
            self.spill_path = pathlib.Path(self.spill_dir.name)
        return self

    # NOTE: parameters renamed to the conventional context-manager order
    # (exc_type, exc_value, traceback); they are unused, so behavior is
    # unchanged.
    def __exit__(self, exc_type, exc_value, traceback) -> None:
        if self.spill_dir is not None:
            self.spill_dir.cleanup()

    def _get_apply_preds(self, task: _Task) -> List[_ApplyTask]:
        """Apply-task predecessors of `task`; all must already have a batch."""
        result = [t for t in task.preds if isinstance(t, _ApplyTask)]
        assert all(t.batch is not None for t in result)
        return result

    def estimate_space(self, task: _ApplyTask) -> int:
        """Guess the output size of `task` before it runs.

        Prefer the size of another batch produced by the same step;
        otherwise fall back to the combined size of the inputs.
        """
        other_tasks_with_similar_output = (
            t
            for t in self.tasks.values()
            if t is not task and isinstance(t, _ApplyTask)
            if t.step_id == task.step_id and t.batch is not None
        )
        try:
            surrogate = next(other_tasks_with_similar_output)
            assert isinstance(surrogate, _ApplyTask) and surrogate.batch is not None
            return surrogate.batch.space
        except StopIteration:  # the iterator was empty
            if task.step_id == _DUMMY_INPUT_STEP:
                return 1  # safe to underestimate on first batch scanned
            apply_preds = self._get_apply_preds(task)
            return sum(cast(_Batch, t.batch).space for t in apply_preds)

    def ensure_space(self, amount_needed: int, no_spill_set: Set[_Batch]) -> None:
        """Spill least-urgent resident batches until `amount_needed` fits."""
        no_spill_space = sum(b.space for b in no_spill_set)
        min_resident = amount_needed + no_spill_space
        self.stats.min_resident = max(self.stats.min_resident, min_resident)
        resident_batches = [
            t.batch
            for t in self.tasks.values()
            if isinstance(t, _ApplyTask) and t.batch is not None
            if t.batch.status == _BatchStatus.RESIDENT
        ]
        resident_batches.sort(key=self.prio.batch_priority)
        resident_batches_space = sum(b.space for b in resident_batches)
        while resident_batches_space + amount_needed > self.max_resident:
            if len(resident_batches) == 0:
                logger.warning(
                    f"ensure_space() failed, amount_needed {amount_needed}, no_spill_space {no_spill_space}, min_resident {min_resident}, max_resident {self.max_resident}"
                )
                break
            batch = resident_batches.pop()
            assert batch.status == _BatchStatus.RESIDENT and batch.task is not None
            if batch in no_spill_set:
                logger.warning(f"aborted spill of batch {batch}")
            else:
                assert self.spill_path is not None, self.max_resident
                batch.spill(self.spill_path)
                self.stats.spill_count += 1
                self.stats.spill_space += batch.space
                if self.verbose >= 2:
                    print(f"spill {batch.X} {batch.y}")
            resident_batches_space -= batch.space

    def load_input_batches(self, task: _Task) -> None:
        """Bring all of `task`'s input batches back into memory."""
        apply_preds = self._get_apply_preds(task)
        no_spill_set = cast(Set[_Batch], set(t.batch for t in apply_preds))
        for pred in apply_preds:
            assert pred.batch is not None
            if pred.batch.status == _BatchStatus.SPILLED:
                self.ensure_space(pred.batch.space, no_spill_set)
                if self.verbose >= 2:
                    print(f"load {pred.batch.X} {pred.batch.y}")
                pred.batch.load_spilled()
                self.stats.load_count += 1
                self.stats.load_space += pred.batch.space
        for pred in apply_preds:
            assert pred.batch is not None
            assert pred.batch.status == _BatchStatus.RESIDENT


def _run_tasks_inner(
    tg: _TaskGraph,
    batches_train: Iterable[Tuple[Any, Any]],
    batches_valid: Optional[List[Tuple[Any, Any]]],
    scoring: Optional[MetricMonoidFactory],
    cv,
    unique_class_labels: List[Union[str, int, float]],
    cache: _BatchCache,
    prio: Prio,
    verbose: int,
    progress_callback: Optional[Callable[[float, float, int, bool], None]],
    call_depth: int,
) -> None:
    """Greedy scheduler: repeatedly run the highest-priority ready task."""
    for task in tg.all_tasks.values():
        assert task.status is not _TaskStatus.FRESH
    n_batches_scanned = 0
    end_of_scanned_batches = False
    ready_keys = {k for k, t in tg.all_tasks.items() if t.status is _TaskStatus.READY}

    def find_task(
        task_class: Type["_Task"], task_list: List[_Task]
    ) -> Union[_Task, List[_Task]]:
        # returns the single matching task, or the (possibly empty) list
        task_list = [t for t in task_list if isinstance(t, task_class)]
        if len(task_list) == 1:
            return task_list[0]
        else:
            return task_list

    def try_to_delete_output(task: _Task) -> None:
        # free a task's output once no successor will ever need it again
        if task.deletable_output:
            if all(s.status is _TaskStatus.DONE for s in task.succs):
                if isinstance(task, _ApplyTask):
                    if task.batch is not None:
                        task.batch.delete_if_spilled()
                        task.batch = None
                elif isinstance(task, _TrainTask):
                    task.monoid = None
                    if batches_valid is None:
                        task.trained = None
                elif isinstance(task, _MetricTask):
                    task.mmonoid = None
                else:
                    assert False, type(task)

    def mark_done(task: _Task) -> None:
        try_to_delete_output(task)
        if task.status is _TaskStatus.DONE:
            return
        if task.status is _TaskStatus.READY:
            ready_keys.remove(task.memo_key())
        task.status = _TaskStatus.DONE
        for succ in task.succs:
            if succ.status is _TaskStatus.WAITING:
                if succ.can_be_ready(end_of_scanned_batches):
                    succ.status = _TaskStatus.READY
                    ready_keys.add(succ.memo_key())
        for pred in task.preds:
            if all(s.status is _TaskStatus.DONE for s in pred.succs):
                mark_done(pred)
        # absorbing monoids make all sibling TO_MONOID tasks redundant
        if isinstance(task, _TrainTask):
            if task.get_operation(tg.pipeline) is _Operation.TO_MONOID:
                if task.monoid is not None and task.monoid.is_absorbing:

                    def is_moot(task2):  # same modulo batch_ids
                        type1, step1, _, hold1 = task.memo_key()
                        type2, step2, _, hold2 = task2.memo_key()
                        return type1 == type2 and step1 == step2 and hold1 == hold2

                    task_monoid = task.monoid  # prevent accidental None assignment
                    for task2 in tg.all_tasks.values():
                        if task2.status is not _TaskStatus.DONE and is_moot(task2):
                            assert isinstance(task2, _TrainTask)
                            task2.monoid = task_monoid
                            mark_done(task2)

    trace: Optional[List[_TraceRecord]] = [] if verbose >= 2 else None
    batches_iterator = iter(batches_train)
    while len(ready_keys) > 0:
        task = tg.all_tasks[
            min(ready_keys, key=lambda k: prio.task_priority(tg.all_tasks[k]))
        ]
        if verbose >= 3:
            tg.visualize(prio, call_depth + 1, trace)
            print(_task_to_string(task, tg.pipeline, sep=" "))
        operation = task.get_operation(tg.pipeline)
        start_time = time.time() if verbose >= 2 else float("nan")
        if operation is _Operation.SCAN:
            assert not end_of_scanned_batches
            assert isinstance(task, _ApplyTask)
            assert len(task.batch_ids) == 1 and len(task.preds) == 0
            cache.ensure_space(cache.estimate_space(task), set())
            try:
                X, y = next(batches_iterator)
                task.batch = _Batch(X, y, task)
                n_batches_scanned += 1
                # eagerly create the scan task for the next batch
                _ = tg.find_or_create(
                    _ApplyTask,
                    _DUMMY_INPUT_STEP,
                    (_batch_id(_get_fold(task.batch_ids[0]), n_batches_scanned),),
                    None,
                )
            except StopIteration:
                end_of_scanned_batches = True
                assert n_batches_scanned >= 1
                # re-chain tasks that were waiting for the full batch count
                for task_with_ab in tg.tasks_with_all_batches:
                    if task_with_ab.status is _TaskStatus.WAITING:
                        task_with_ab.status = _TaskStatus.FRESH
                        tg.fresh_tasks.append(task_with_ab)
                    else:
                        assert task_with_ab.status is _TaskStatus.DONE
            _backward_chain_tasks(tg, n_batches_scanned, end_of_scanned_batches)
            ready_keys = {
                k for k, t in tg.all_tasks.items() if t.status is _TaskStatus.READY
            }
        elif operation is _Operation.SPLIT:
            assert isinstance(task, _ApplyTask)
            assert len(task.batch_ids) == 1 and len(task.preds) == 1
            batch_id = task.batch_ids[0]
            scan_pred = cast(_ApplyTask, task.preds[0])
            cache.load_input_batches(task)
            assert scan_pred.batch is not None
            cache.ensure_space(cache.estimate_space(task), {scan_pred.batch})
            input_X, input_y = scan_pred.batch.Xy
            is_sparky = lale.helpers.spark_installed and isinstance(
                input_X, SparkDataFrame
            )
            if is_sparky:  # TODO: use Spark native split instead
                input_X, input_y = input_X.toPandas(), input_y.toPandas().squeeze()
            if scan_pred.splits is None:
                scan_pred.splits = list(cv.split(input_X, input_y))
            # fold names are consecutive letters starting at "d"
            train, test = scan_pred.splits[ord(_get_fold(batch_id)) - ord("d")]
            dummy_estimator = sklearn.tree.DecisionTreeClassifier()
            output_X, output_y = lale.helpers.split_with_schemas(
                dummy_estimator, input_X, input_y, test, train
            )
            if is_sparky:  # TODO: use Spark native split instead
                output_X, output_y = pandas2spark(output_X), pandas2spark(output_y)
            task.batch = _Batch(output_X, output_y, task)
        elif operation in [_Operation.TRANSFORM, _Operation.PREDICT]:
            assert isinstance(task, _ApplyTask)
            assert len(task.batch_ids) == 1
            train_pred = cast(_TrainTask, find_task(_TrainTask, task.preds))
            trained = train_pred.get_trained(tg.pipeline)
            apply_preds = [t for t in task.preds if isinstance(t, _ApplyTask)]
            cache.load_input_batches(task)
            if len(apply_preds) == 1:
                assert apply_preds[0].batch is not None
                input_X, input_y = apply_preds[0].batch.Xy
            else:
                assert not any(pred.batch is None for pred in apply_preds)
                input_X = [cast(_Batch, pred.batch).X for pred in apply_preds]
                # The assumption is that input_y is not changed by the preds, so we can
                # use it from any one of them.
                input_y = cast(_Batch, apply_preds[0].batch).y
            no_spill_set = cast(Set[_Batch], set(t.batch for t in apply_preds))
            cache.ensure_space(cache.estimate_space(task), no_spill_set)
            if operation is _Operation.TRANSFORM:
                if trained.has_method("transform_X_y"):
                    output_X, output_y = trained.transform_X_y(input_X, input_y)
                else:
                    output_X, output_y = trained.transform(input_X), input_y
                task.batch = _Batch(output_X, output_y, task)
            else:
                y_pred = trained.predict(input_X)
                if isinstance(y_pred, np.ndarray):
                    y_pred = pd.Series(
                        y_pred,
                        cast(pd.Series, input_y).index,
                        cast(pd.Series, input_y).dtype,
                        "y_pred",
                    )
                task.batch = _Batch(input_X, y_pred, task)
        elif operation is _Operation.FIT:
            assert isinstance(task, _TrainTask)
            assert all(isinstance(p, _ApplyTask) for p in task.preds)
            apply_preds = [cast(_ApplyTask, p) for p in task.preds]
            assert not any(p.batch is None for p in apply_preds)
            trainable = tg.pipeline.steps_list()[task.step_id]
            if is_pretrained(trainable):
                assert len(task.preds) == 0
                if task.trained is None:
                    task.trained = cast(TrainedIndividualOp, trainable)
            else:
                cache.load_input_batches(task)
                if len(task.preds) == 1:
                    input_X, input_y = cast(_Batch, apply_preds[0].batch).Xy
                else:
                    # concatenate all input batches for a one-shot fit
                    assert not is_incremental(trainable)
                    list_X = [cast(_Batch, p.batch).X for p in apply_preds]
                    list_y = [cast(_Batch, p.batch).y for p in apply_preds]
                    if all(isinstance(X, pd.DataFrame) for X in list_X):
                        input_X = pd.concat(list_X)
                        input_y = pd.concat(list_y)
                    elif lale.helpers.spark_installed and all(
                        isinstance(X, SparkDataFrame) for X in list_X
                    ):
                        input_X = functools.reduce(lambda a, b: a.union(b), list_X)  # type: ignore
                        input_y = functools.reduce(lambda a, b: a.union(b), list_y)  # type: ignore
                    elif all(isinstance(X, np.ndarray) for X in list_X):
                        input_X = np.concatenate(list_X)
                        input_y = np.concatenate(list_y)
                    else:
                        raise ValueError(
                            f"""Input of {type(list_X[0])} is not supported for fit on a non-incremental operator. Supported types are: pandas DataFrame, numpy ndarray, and spark DataFrame."""
                        )
                task.trained = trainable.fit(input_X, input_y)
        elif operation is _Operation.PARTIAL_FIT:
            assert isinstance(task, _TrainTask)
            if task.has_all_batches():
                # alias for the last link of the partial_fit chain
                assert len(task.preds) == 1, (
                    _task_to_string(task, tg.pipeline, sep=" "),
                    len(task.preds),
                )
                train_pred = cast(_TrainTask, task.preds[0])
                task.trained = train_pred.get_trained(tg.pipeline)
            else:
                assert len(task.preds) in [1, 2]
                if len(task.preds) == 1:
                    trainee = tg.pipeline.steps_list()[task.step_id]
                else:
                    train_pred = cast(_TrainTask, find_task(_TrainTask, task.preds))
                    trainee = train_pred.get_trained(tg.pipeline)
                apply_pred = cast(_ApplyTask, find_task(_ApplyTask, task.preds))
                assert apply_pred.batch is not None
                cache.load_input_batches(task)
                input_X, input_y = apply_pred.batch.Xy
                if trainee.is_supervised() and trainee.is_classifier():
                    task.trained = trainee.partial_fit(
                        input_X, input_y, classes=unique_class_labels
                    )
                else:
                    task.trained = trainee.partial_fit(input_X, input_y)
        elif operation is _Operation.TO_MONOID:
            assert len(task.batch_ids) == 1
            assert all(isinstance(p, _ApplyTask) for p in task.preds)
            assert all(cast(_ApplyTask, p).batch is not None for p in task.preds)
            cache.load_input_batches(task)
            if isinstance(task, _TrainTask):
                assert len(task.preds) == 1
                trainable = tg.pipeline.steps_list()[task.step_id]
                input_X, input_y = task.preds[0].batch.Xy  # type: ignore
                task.monoid = trainable.impl.to_monoid((input_X, input_y))
            elif isinstance(task, _MetricTask):
                assert len(task.preds) == 2
                assert task.preds[0].step_id == _DUMMY_INPUT_STEP
                assert scoring is not None
                X, y_true = task.preds[0].batch.Xy  # type: ignore
                y_pred = task.preds[1].batch.y  # type: ignore
                task.mmonoid = scoring.to_monoid((y_true, y_pred, X))
                if progress_callback is not None:
                    if batches_valid is None or len(batches_valid) == 0:
                        score_valid = float("nan")
                    else:
                        partially_trained = tg.extract_trained_pipeline(
                            None, n_batches_scanned
                        )
                        score_valid = scoring.score_estimator_batched(
                            partially_trained, batches_valid
                        )
                    progress_callback(
                        scoring.from_monoid(task.mmonoid),
                        score_valid,
                        n_batches_scanned,
                        end_of_scanned_batches,
                    )
            else:
                assert False, type(task)
        elif operation is _Operation.COMBINE:
            cache.load_input_batches(task)
            if isinstance(task, _TrainTask):
                assert all(isinstance(p, _TrainTask) for p in task.preds)
                trainable = tg.pipeline.steps_list()[task.step_id]
                monoids = (cast(_TrainTask, p).monoid for p in task.preds)
                task.monoid = functools.reduce(lambda a, b: a.combine(b), monoids)  # type: ignore
            elif isinstance(task, _MetricTask):
                scores = (cast(_MetricTask, p).mmonoid for p in task.preds)
                task.mmonoid = functools.reduce(lambda a, b: a.combine(b), scores)  # type: ignore
            else:
                assert False, type(task)
        else:
            assert False, operation
        if verbose >= 2:
            finish_time = time.time()
            assert trace is not None
            trace.append(_TraceRecord(task, finish_time - start_time))
        mark_done(task)
    if verbose >= 2:
        tg.visualize(prio, call_depth + 1, trace)
        assert trace is not None
        print(_analyze_run_trace(cache.stats, trace))


def _run_tasks(
    tg: _TaskGraph,
    batches_train: Iterable[Tuple[Any, Any]],
    batches_valid: Optional[List[Tuple[Any, Any]]],
    scoring: Optional[MetricMonoidFactory],
    cv,
    unique_class_labels: List[Union[str, int, float]],
    max_resident: Optional[int],
    prio: Prio,
    verbose: int,
    progress_callback: Optional[Callable[[float, float, int, bool], None]],
    call_depth: int,
) -> None:
    """Run the scheduler with a batch cache scoped to this call."""
    if scoring is None and progress_callback is not None:
        logger.warning("progress_callback only gets called if scoring is not None")
    with _BatchCache(tg.all_tasks, max_resident, prio, verbose) as cache:
        _run_tasks_inner(
            tg,
            batches_train,
            batches_valid,
            scoring,
            cv,
            unique_class_labels,
            cache,
            prio,
            verbose,
            progress_callback,
            call_depth + 1,
        )
def fit_with_batches(
    pipeline: TrainablePipeline[TrainableIndividualOp],
    batches_train: Iterable[Tuple[Any, Any]],
    batches_valid: Optional[List[Tuple[Any, Any]]],
    scoring: Optional[MetricMonoidFactory],
    unique_class_labels: List[Union[str, int, float]],
    max_resident: Optional[int],
    prio: Prio,
    partial_transform: Union[bool, str],
    verbose: int,
    progress_callback: Optional[Callable[[float, float, int, bool], None]],
) -> TrainedPipeline[TrainedIndividualOp]:
    """Replacement for the `fit` method on a pipeline (early interface, subject to change)."""
    assert partial_transform in [False, "score", True]
    # A single fold named "d" means plain training without cross validation.
    single_fold = ["d"]
    need_metrics = scoring is not None
    with _create_tasks(
        pipeline, single_fold, need_metrics, True, partial_transform, False
    ) as task_graph:
        _run_tasks(
            task_graph,
            batches_train,
            batches_valid,
            scoring,
            None,
            unique_class_labels,
            max_resident,
            prio,
            verbose,
            progress_callback,
            call_depth=2,
        )
        result = task_graph.extract_trained_pipeline(None, _ALL_BATCHES)
    return result
def cross_val_score(
    pipeline: TrainablePipeline[TrainableIndividualOp],
    batches: Iterable[Tuple[Any, Any]],
    scoring: MetricMonoidFactory,
    cv,
    unique_class_labels: List[Union[str, int, float]],
    max_resident: Optional[int],
    prio: Prio,
    same_fold: bool,
    verbose: int,
) -> List[float]:
    """Replacement for sklearn's `cross_val_score`_ function (early interface, subject to change).

    .. _`cross_val_score`: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html
    """
    cv = sklearn.model_selection.check_cv(cv)
    # Folds are named with consecutive letters starting at "d".
    first_fold = ord("d")
    fold_names = [chr(first_fold + i) for i in range(cv.get_n_splits())]
    with _create_tasks(
        pipeline, fold_names, True, False, False, same_fold
    ) as task_graph:
        _run_tasks(
            task_graph,
            batches,
            None,
            scoring,
            cv,
            unique_class_labels,
            max_resident,
            prio,
            verbose,
            None,
            call_depth=2,
        )
        fold_scores = task_graph.extract_scores(scoring)
    return fold_scores
def cross_validate(
    pipeline: TrainablePipeline[TrainableIndividualOp],
    batches: Iterable[Tuple[Any, Any]],
    scoring: MetricMonoidFactory,
    cv,
    unique_class_labels: List[Union[str, int, float]],
    max_resident: Optional[int],
    prio: Prio,
    same_fold: bool,
    return_estimator: bool,
    verbose: int,
) -> Dict[str, Union[List[float], List[TrainedPipeline]]]:
    """Replacement for sklearn's `cross_validate`_ function (early interface, subject to change).

    .. _`cross_validate`: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html
    """
    cv = sklearn.model_selection.check_cv(cv)
    # Folds are named with consecutive letters starting at "d".
    first_fold = ord("d")
    fold_names = [chr(first_fold + i) for i in range(cv.get_n_splits())]
    with _create_tasks(
        pipeline, fold_names, True, return_estimator, False, same_fold
    ) as task_graph:
        _run_tasks(
            task_graph,
            batches,
            None,
            scoring,
            cv,
            unique_class_labels,
            max_resident,
            prio,
            verbose,
            None,
            call_depth=2,
        )
        result: Dict[str, Union[List[float], List[TrainedPipeline]]] = {
            "test_score": task_graph.extract_scores(scoring)
        }
        if return_estimator:
            result["estimator"] = [
                task_graph.extract_trained_pipeline(held_out, _ALL_BATCHES)
                for held_out in task_graph.folds
            ]
    return result