lale.util.numpy_to_torch_dataset module

class lale.util.numpy_to_torch_dataset.NumpyTorchDataset(*args: Any, **kwargs: Any)

Bases: Dataset

PyTorch Dataset subclass that wraps a NumPy array of features and an optional NumPy array of labels.

X holds the input features and y holds the corresponding labels.

Parameters
  • X (numpy array) – Two-dimensional array of input features.

  • y (numpy array) – Optional array of labels, one per row of X.
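The following is a minimal sketch of how such a dataset is typically used, assuming it follows the standard map-style PyTorch dataset protocol: __len__ returns the number of rows of X, and __getitem__ returns one feature row, paired with its label when y is given. The class name ArrayPairDataset and its shape checks are hypothetical illustrations, not part of lale, and the sketch uses plain NumPy so that it runs without PyTorch installed.

```python
import numpy as np


class ArrayPairDataset:
    """Hypothetical stand-in for NumpyTorchDataset (not lale's code).

    A real PyTorch version would subclass torch.utils.data.Dataset; this
    sketch only implements the two methods a map-style dataset needs.
    """

    def __init__(self, X, y=None):
        X = np.asarray(X)
        # Assumption: X must be 2-D (rows are samples, columns are features).
        if X.ndim != 2:
            raise ValueError("X must be two-dimensional")
        if y is not None:
            y = np.asarray(y)
            # Assumption: one label per row of X.
            if len(y) != len(X):
                raise ValueError("X and y must have the same number of rows")
        self.X = X
        self.y = y

    def __len__(self):
        # One sample per row of X.
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Unlabeled: return the feature row; labeled: (features, label).
        if self.y is None:
            return self.X[idx]
        return self.X[idx], self.y[idx]


X = np.arange(6, dtype=np.float32).reshape(3, 2)
y = np.array([0, 1, 0])
ds = ArrayPairDataset(X, y)
features, label = ds[1]
print(len(ds), features, label)  # 3 [2. 3.] 1
```

Because torch.utils.data.DataLoader only requires __len__ and __getitem__ from a map-style dataset, an object like this could be handed to it directly; subclassing Dataset, as NumpyTorchDataset does, makes that contract explicit.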

get_data()

lale.util.numpy_to_torch_dataset.numpy_collate_fn(batch)
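This entry has no description. Judging only by its name and its batch argument, the function is presumably meant to be passed as the collate_fn argument of torch.utils.data.DataLoader so that batches come back as NumPy arrays rather than the tensors produced by PyTorch's default collate function. The sketch below illustrates that idea under this assumption; stack_numpy_batch is a hypothetical name, and its behavior is not taken from lale's implementation.

```python
import numpy as np


def stack_numpy_batch(batch):
    """Hypothetical collate function (not lale's code).

    Each item in batch is either a feature row or a (feature_row, label)
    tuple, matching what a map-style dataset's __getitem__ returns.
    """
    if isinstance(batch[0], tuple):
        rows, labels = zip(*batch)
        # Stack feature rows into shape (batch_size, n_features).
        return np.stack(rows), np.asarray(labels)
    # Unlabeled items: just stack the feature rows.
    return np.stack(batch)


batch = [(np.array([0.0, 1.0]), 0), (np.array([2.0, 3.0]), 1)]
X_batch, y_batch = stack_numpy_batch(batch)
print(X_batch.shape, y_batch)  # (2, 2) [0 1]
```

With PyTorch installed, a function like this would be supplied as DataLoader(dataset, batch_size=..., collate_fn=stack_numpy_batch).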