Source code for lale.search.search_space_grid

# Copyright 2019 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import math
import random
import warnings
from collections import ChainMap
from typing import Any, Dict, Iterable, List, Optional, Union

from lale.helpers import (
    DUMMY_SEARCH_SPACE_GRID_PARAM_NAME,
    discriminant_name,
    make_indexed_name,
    nest_all_HPparams,
    nest_choice_all_HPparams,
    structure_type_dict,
    structure_type_list,
    structure_type_name,
    structure_type_tuple,
)
from lale.operators import PlannedOperator
from lale.search.PGO import PGO
from lale.search.schema2search_space import op_to_search_space
from lale.search.search_space import (
    SearchSpace,
    SearchSpaceArray,
    SearchSpaceConstant,
    SearchSpaceDict,
    SearchSpaceEmpty,
    SearchSpaceError,
    SearchSpaceObject,
    SearchSpaceOperator,
    SearchSpacePrimitive,
    SearchSpaceProduct,
    SearchSpaceSum,
    should_print_search_space,
)
from lale.util.Visitor import Visitor, accept

# A single grid: a flat dictionary mapping each string key to the primitive
# search space associated with that key.
SearchSpaceGrid = Dict[str, SearchSpacePrimitive]


def search_space_grid_to_string(grid: SearchSpaceGrid) -> str:
    """Render one grid as ``{key->value;key->value;...}``.

    Each entry is formatted as ``key->str(value)``; entries are joined with
    ``;`` in the grid's iteration order and wrapped in curly braces.
    An empty grid renders as ``{}``.
    """
    rendered_entries = [f"{k}->{str(v)}" for k, v in grid.items()]
    joined_entries = ";".join(rendered_entries)
    return "{" + joined_entries + "}"
def search_space_grids_to_string(grids: List[SearchSpaceGrid]) -> str:
    """Render a list of grids as a single string.

    Every grid is rendered with ``search_space_grid_to_string`` and the
    results are joined with ``|``. An empty list renders as the empty string.
    """
    rendered_grids = map(search_space_grid_to_string, grids)
    return "|".join(rendered_grids)
def get_search_space_grids(
    op: "PlannedOperator",
    num_grids: Optional[float] = None,
    pgo: Optional[PGO] = None,
    data_schema: Optional[Dict[str, Any]] = None,
) -> List[SearchSpaceGrid]:
    """Top level function: given a lale operator, returns a list of hp grids.

    Parameters
    ----------
    op : PlannedOperator
        The lale PlannedOperator.

    num_grids : integer or float, optional
        If set to a number >= 1, it determines how many parameter grids
        will be returned (at most); non-integral values are rounded up.
        If set to a float strictly between 0 and 1, it determines what
        fraction of the grids should be returned (rounded to the nearest
        whole number of grids).
        Note that setting it to 1 is treated as an integer.
        A non-positive value emits a warning and returns an empty list.
        To return all results, use None.

    pgo : PGO, optional
        Profile Guided Optimization data that can be used when
        discretizing continuous parameters.

    data_schema : dict, optional
        A schema for the actual data. If provided, it is used to
        instantiate data dependent schema hyperparameter specifications.

    Returns
    -------
    List[SearchSpaceGrid]
        All grids, or a random sample of them (drawn without replacement
        via ``random.sample``) when ``num_grids`` asks for fewer grids
        than are available.
    """
    all_parameters = op_to_search_space_grids(op, pgo=pgo, data_schema=data_schema)

    if should_print_search_space("true", "all", "search_space_grids", "grids"):
        name = op.name()
        if not name:
            name = "an operator"
        print(
            f"search space grids for {name}:\n{search_space_grids_to_string(all_parameters)}"
        )

    if num_grids is None:
        return all_parameters

    if num_grids <= 0:
        warnings.warn(
            f"get_search_space_grids(num_grids={num_grids}) called with a non-positive value for lale_num_grids"
        )
        return []

    total = len(all_parameters)
    if num_grids >= 1:
        # Absolute count requested; nothing to sample if it covers everything.
        samples = math.ceil(num_grids)
        if samples >= total:
            return all_parameters
    else:
        # Fraction requested: convert it to a whole number of grids.
        samples = round(total * num_grids)

    warnings.warn(
        f"get_search_space_grids(num_grids={num_grids}) sampling {samples}/{total}"
    )
    return random.sample(all_parameters, samples)
def search_space_to_grids(hp: SearchSpace) -> List[SearchSpaceGrid]:
    """Convert a search space into a list of grids.

    Delegates to ``SearchSpaceToGridVisitor.run``.
    """
    grids = SearchSpaceToGridVisitor.run(hp)
    return grids
def op_to_search_space_grids(
    op: PlannedOperator,
    pgo: Optional[PGO] = None,
    data_schema: Optional[Dict[str, Any]] = None,
) -> List[SearchSpaceGrid]:
    """Build the search space for ``op`` and convert it into grids.

    ``pgo`` and ``data_schema`` are forwarded unchanged to
    ``op_to_search_space``; its result is then passed to
    ``search_space_to_grids``.
    """
    return search_space_to_grids(
        op_to_search_space(op, pgo=pgo, data_schema=data_schema)
    )
# lets handle the general case SearchSpaceGridInternalType = Union[List[SearchSpaceGrid], SearchSpacePrimitive]
class SearchSpaceToGridVisitor(Visitor):
    """Visitor that turns a search space into a list of grids.

    Visiting a primitive search space returns the primitive itself; every
    other visit method returns a list of grids (or, for wrappers, whatever
    the wrapped space returns).
    """

    @classmethod
    def run(cls, space: SearchSpace) -> List[SearchSpaceGrid]:
        """Visit ``space`` with a fresh visitor and always return a list of grids."""
        raw_result: SearchSpaceGridInternalType = accept(space, cls())
        return cls.fixupDegenerateSearchSpaces(raw_result)

    @classmethod
    def fixupDegenerateSearchSpaces(
        cls, space: SearchSpaceGridInternalType
    ) -> List[SearchSpaceGrid]:
        """Wrap a bare primitive into a one-grid list; pass lists through.

        A primitive is stored under ``DUMMY_SEARCH_SPACE_GRID_PARAM_NAME``.
        """
        if not isinstance(space, SearchSpacePrimitive):
            return space
        return [{DUMMY_SEARCH_SPACE_GRID_PARAM_NAME: space}]

    def visitSearchSpacePrimitive(
        self, space: SearchSpacePrimitive
    ) -> SearchSpacePrimitive:
        """Primitives are returned unchanged."""
        return space

    # All primitive kinds share the same pass-through handler.
    visitSearchSpaceEnum = visitSearchSpacePrimitive
    visitSearchSpaceConstant = visitSearchSpacePrimitive
    visitSearchSpaceBool = visitSearchSpacePrimitive
    visitSearchSpaceNumber = visitSearchSpacePrimitive

    @staticmethod
    def _cross_product(
        param_grids: List[List[SearchSpaceGrid]],
    ) -> List[SearchSpaceGrid]:
        """Merge one grid from each entry of ``param_grids``, for every combination."""
        return [dict(ChainMap(*combo)) for combo in itertools.product(*param_grids)]

    def _searchSpaceList(
        self, space: SearchSpaceArray, *, size: int
    ) -> List[SearchSpaceGrid]:
        """Build the grids for an array search space limited to ``size`` elements.

        Each element's grids are nested under the element's index (as a
        string), all combinations are merged, and every merged grid is tagged
        under ``structure_type_name`` with the tuple or list marker.
        """
        per_element_grids: List[List[SearchSpaceGrid]] = []
        for position, element in enumerate(space.items(max_elts=size)):
            element_name = str(position)
            element_grids = self.fixupDegenerateSearchSpaces(accept(element, self))
            per_element_grids.append(nest_all_HPparams(element_name, element_grids))

        merged_grids = self._cross_product(per_element_grids)

        st_val = structure_type_tuple if space.is_tuple else structure_type_list

        tagged_grids: List[SearchSpaceGrid] = []
        for merged in merged_grids:
            tagged_grids.append(
                {**merged, structure_type_name: SearchSpaceConstant(st_val)}
            )
        return tagged_grids

    def visitSearchSpaceArray(self, space: SearchSpaceArray) -> List[SearchSpaceGrid]:
        """Return the grids for every allowed array size, smallest size first."""
        if space.minimum == space.maximum:
            return self._searchSpaceList(space, size=space.minimum)

        return [
            grid
            for size in range(space.minimum, space.maximum + 1)
            for grid in self._searchSpaceList(space, size=size)
        ]

    def visitSearchSpaceObject(self, space: SearchSpaceObject) -> List[SearchSpaceGrid]:
        """Return the grids for an object search space, one choice at a time.

        Within a choice, values that visit to a primitive are stored directly
        under their key; other values are nested under their key and combined
        via a cross product. Nested entries take precedence over the directly
        stored ones when keys collide.
        """
        keys = space.keys
        expected_len = len(keys)
        collected: List[SearchSpaceGrid] = []

        for choice in space.choices:
            assert expected_len == len(choice)

            direct_entries: SearchSpaceGrid = {}
            nested_entries: List[List[SearchSpaceGrid]] = []
            for key, value in zip(keys, choice):
                visited: Union[List[SearchSpaceGrid], SearchSpacePrimitive] = accept(
                    value, self
                )
                if isinstance(visited, SearchSpacePrimitive):
                    direct_entries[key] = visited
                    continue
                nested: List[SearchSpaceGrid] = nest_all_HPparams(key, visited)
                if nested:
                    nested_entries.append(nested)

            # Empty combinations (e.g. when nothing was nested) are dropped.
            combos: List[List[SearchSpaceGrid]] = [
                list(combo) for combo in itertools.product(*nested_entries) if combo
            ]

            if not combos:
                collected.append(direct_entries)
                continue

            for combo in combos:
                collected.append(dict(ChainMap(*combo, direct_entries)))

        return collected

    def visitSearchSpaceSum(self, op: SearchSpaceSum) -> SearchSpaceGridInternalType:
        """Return the grids for a choice between several sub-spaces.

        A sum with exactly one sub-space returns that sub-space's result
        unchanged. Otherwise every grid is tagged under ``discriminant_name``
        with the index of the sub-space it came from.
        """
        sub_spaces: List[SearchSpace] = op.sub_spaces
        visited: List[SearchSpaceGridInternalType] = []
        for sub_space in sub_spaces:
            visited.append(accept(sub_space, self))

        if len(sub_spaces) == 1:
            return visited[0]

        result: List[SearchSpaceGrid] = []
        for choice_index, raw in enumerate(visited):
            grids = SearchSpaceToGridVisitor.fixupDegenerateSearchSpaces(raw)
            if grids:
                # we need to add in this nesting
                # in case a higher order operator directly contains
                # another
                grids = nest_choice_all_HPparams(grids)
            else:
                grids = [{}]

            for grid in grids:
                result.append(
                    {**grid, discriminant_name: SearchSpaceConstant(choice_index)}
                )
        return result

    def visitSearchSpaceProduct(
        self, op: SearchSpaceProduct
    ) -> SearchSpaceGridInternalType:
        """Return the merged grids of all indexed sub-spaces of a product.

        Each sub-space's grids are nested under its indexed name before the
        cross product is taken.
        """
        per_space_grids: List[List[SearchSpaceGrid]] = []
        for name, index, sub_space in op.get_indexed_spaces():
            indexed_name = make_indexed_name(name, index)
            sub_grids = self.fixupDegenerateSearchSpaces(accept(sub_space, self))
            per_space_grids.append(nest_all_HPparams(indexed_name, sub_grids))

        return self._cross_product(per_space_grids)

    def visitSearchSpaceDict(self, op: SearchSpaceDict) -> SearchSpaceGridInternalType:
        """Return the merged grids of a dictionary of sub-spaces.

        Each sub-space's grids are nested under its dictionary key, all
        combinations are merged, and every merged grid is tagged under
        ``structure_type_name`` with the dict marker.
        """
        per_key_grids: List[List[SearchSpaceGrid]] = []
        for name, sub_space in op.space_dict.items():
            sub_grids = self.fixupDegenerateSearchSpaces(accept(sub_space, self))
            per_key_grids.append(nest_all_HPparams(name, sub_grids))

        tagged_grids: List[SearchSpaceGrid] = []
        for merged in self._cross_product(per_key_grids):
            tagged_grids.append(
                {**merged, structure_type_name: SearchSpaceConstant(structure_type_dict)}
            )
        return tagged_grids

    def visitSearchSpaceOperator(
        self, op: SearchSpaceOperator
    ) -> SearchSpaceGridInternalType:
        """Operators are transparent: visit the wrapped sub-space."""
        wrapped = op.sub_space
        return accept(wrapped, self)

    def visitSearchSpaceEmpty(self, op: SearchSpaceEmpty):
        """Empty search spaces cannot be expressed as grids; always raises.

        Raises
        ------
        SearchSpaceError
            Unconditionally.
        """
        raise SearchSpaceError(
            None, "Grid based backends can't compile an empty (sub-) search space"
        )