Source code for lale.schema2enums

# Copyright 2019 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from .schema_utils import JsonSchema, SchemaEnum

logger = logging.getLogger(__name__)


[docs]class DiscoveredEnums: def __init__( self, enums: Optional[SchemaEnum] = None, children: Optional[Dict[str, "DiscoveredEnums"]] = None, ) -> None: self.enums = enums self.children = children def __str__(self) -> str: def val_as_str(v): if v is None: return "null" elif isinstance(v, str): return f"'{v}'" else: return str(v) en = "" if self.enums: ens = [val_as_str(v) for v in self.enums] en = ", ".join(sorted(ens)) ch = "" if self.children: chs = [f"{str(k)}->{str(v)}" for k, v in self.children.items()] ch = ",".join(chs) if en and ch: en = en + "; " return "<" + en + ch + ">"
[docs]def schemaToDiscoveredEnums(schema: JsonSchema) -> Optional[DiscoveredEnums]: """Given a schema, returns a positive enumeration set. This is very conservative, and even includes negated enum constants (since the assumption is that they may, in some contexts, be valid) """ def combineDiscoveredEnums( combine: Callable[[Iterable[SchemaEnum]], Optional[SchemaEnum]], des: Iterable[Optional[DiscoveredEnums]], ) -> Optional[DiscoveredEnums]: enums: List[SchemaEnum] = [] children: Dict[str, List[DiscoveredEnums]] = {} for de in des: if de is None: continue if de.enums is not None: enums.append(de.enums) if de.children is not None: for cn, cv in de.children.items(): if cv is None: continue if cn in children: children[cn].append(cv) else: children[cn] = [cv] combined_enums: Optional[SchemaEnum] = None if enums: combined_enums = combine(enums) if not children: if combined_enums is None: return None else: return DiscoveredEnums(enums=combined_enums) else: combined_children: Dict[str, DiscoveredEnums] = {} for ccn, ccv in children.items(): if not ccv: continue ccvc = combineDiscoveredEnums(combine, ccv) if ccvc is not None: combined_children[ccn] = ccvc return DiscoveredEnums(enums=combined_enums, children=combined_children) def joinDiscoveredEnums( des: Iterable[Optional[DiscoveredEnums]], ) -> Optional[DiscoveredEnums]: def op(args: Iterable[SchemaEnum]) -> Optional[SchemaEnum]: return set.union(*args) return combineDiscoveredEnums(op, des) if schema is True or schema is False: return None if "enum" in schema: # TODO: we should validate the enum elements according to the schema, like schema2search_space does return DiscoveredEnums(enums=set(schema["enum"])) if "type" in schema: typ = schema["type"] if typ == "object" and "properties" in schema: props = schema["properties"] pret: Dict[str, DiscoveredEnums] = {} for p, s in props.items(): pos = schemaToDiscoveredEnums(s) if pos is not None: pret[p] = pos if pret: return DiscoveredEnums(children=pret) else: return None else: return None if "not" in schema: neg = schemaToDiscoveredEnums(schema["not"]) return neg if "allOf" in schema: posl = [schemaToDiscoveredEnums(s) for s in schema["allOf"]] return joinDiscoveredEnums(posl) if "anyOf" in schema: posl = [schemaToDiscoveredEnums(s) for s in schema["anyOf"]] return joinDiscoveredEnums(posl) if "oneOf" in schema: posl = [schemaToDiscoveredEnums(s) for s in schema["oneOf"]] return joinDiscoveredEnums(posl) return None
[docs]def accumulateDiscoveredEnumsToPythonEnums( de: Optional[DiscoveredEnums], path: List[str], acc: Dict[str, enum.Enum] ) -> None: def withEnumValue(e: str) -> Tuple[str, Any]: if isinstance(e, str): return (e.replace("-", "_"), e) elif isinstance(e, (int, float, complex)): return ("num" + str(e), e) else: logger.info( f"Unknown type ({type(e)}) of enumeration constant {e}, not handling very well" ) return (str(e), e) if de is None: return if de.enums is not None: ppath, _ = withEnumValue("_".join(path)) epath = ".".join(path) vals = (withEnumValue(x) for x in de.enums if x is not None) # pyright does not currently understand this form acc[ppath] = enum.Enum(epath, vals) # type: ignore if de.children is not None: for k in de.children: accumulateDiscoveredEnumsToPythonEnums(de.children[k], [k] + path, acc)
[docs]def discoveredEnumsToPythonEnums(de: Optional[DiscoveredEnums]) -> Dict[str, enum.Enum]: acc: Dict[str, enum.Enum] = {} accumulateDiscoveredEnumsToPythonEnums(de, [], acc) return acc
[docs]def schemaToPythonEnums(schema: JsonSchema) -> Dict[str, enum.Enum]: de = schemaToDiscoveredEnums(schema) enums = discoveredEnumsToPythonEnums(de) return enums
[docs]def addDictAsFields(obj: Any, d: Dict[str, Any], force=False) -> None: if d is None: return for k, v in d.items(): if k == "": logger.warning( f"There was a top level enumeration specified, so it is not being added to {getattr(obj, '_name', '???')}" ) elif hasattr(obj, k) and not force: logger.error( f"The object {getattr(obj, '_name', '???')} already has the field {k}. This conflicts with our attempt at adding that key as an enumeration field" ) else: setattr(obj, k, v)
[docs]def addSchemaEnumsAsFields(obj: Any, schema: JsonSchema, force=False) -> None: enums = schemaToPythonEnums(schema) addDictAsFields(obj, enums, force)