Source code for hpt.utils.load_yaml

"""Utils specific to the tuners module.

Author: @sgpjesus
"""
import logging
from inspect import signature
from numbers import Number
from typing import Union
from pathlib import Path

import yaml
from schema import And
from schema import Optional as Optional_
from schema import Or, Schema

from .classpath import import_object


[docs] class YamlValidator: def __init__(self): self.obs_class = None
[docs] def assert_class_exists(self, path: str) -> bool: """Checks if a given module and Class exists in the current python environment. Saves class in global object to assert arguments. Parameters ---------- path : str Classpath to the Class to be checked. Returns AttributeError If Class does not exist within module. ValueError If classpath is malformed. ------- bool True if check passes. Raises ------ ModuleNotFoundError If module does not exist. AttributeError If Class does not exist within module. ValueError If classpath is malformed. """ try: self.obs_class = import_object(path) except ModuleNotFoundError as e: module_name = path[: path.rindex(".")] raise ModuleNotFoundError( f"Provided module for model in YAML file does not exist. Check " f"for errors in '{module_name}'." ) from e except AttributeError as e: class_name = path[path.rindex(".") + 1 :] raise AttributeError( f"Provided class for model in YAML file is not contained in " f"module. Check for errors in '{class_name}'." ) from e except ValueError as e: raise ValueError( f"Provided classpath is malformed. You should specify the " f"module and class name. Check for errors in '{path}'." ) from e return True
[docs] def assert_argument_exists(self, argument: str) -> bool: """Checks if a given argument for the Model Class to be isntantiated is expected by the Class signature. Uses global object of model to check signature. Parameters ---------- argument : str Argument to be checked. Returns ------- bool True if check passes. Raises ------ TypeError If argument is not expected in Class. """ signature_params = list(signature(self.obs_class).parameters) if "kwargs" in signature_params or argument in signature_params: return True else: raise TypeError( f"Unexpected argument for model '{self.obs_class}':" f" '{argument}'" )
validator = YamlValidator() # Schema for YAML file. HYPERPARAMETER_SPACE_SCHEMA = Schema( { str: { # Model name given by the user "classpath": And(str, validator.assert_class_exists), "kwargs": { Optional_(And(str, validator.assert_argument_exists)): Or( Or(str, Number), And(list, lambda x: len(x) >= 1), And( { "type": Or( "int", "float", ), "range": And( list, lambda values: len(values) == 2, lambda values: all( isinstance(x, Number) for x in values ), ), Optional_("log", default=False): Or(True, False), } ), ) }, }, } )
[docs] def load_hyperparameter_space(path_or_dict: Union[str, Path, dict]) -> dict: """Loads the hyperparameter space encoded as a YAML in the given path. If given a dict, space is already loaded and this function will return the same object. Parameters ---------- path_or_dict : Union[str, dict] Either the path to the YAML file, or a dictionary containing a hyperparameter space following the expected structure. Returns ------- The loaded hyperparameter space. """ # Read hyperparameter space from the YAML file (if given) if isinstance(path_or_dict, (str, Path)): path_obj = Path(path_or_dict).resolve() logging.debug( f"Loading hyperparameter space from the following YAML file: " f"{path_obj}" ) try: with open(str(path_obj), "r") as f_in: hyperparameter_space = yaml.safe_load(f_in) except yaml.YAMLError as err: raise ValueError(f"{err}. Did you pass a valid YAML file ?") from err # Else, assume the given dictionary describes a hyperparameter space elif isinstance(path_or_dict, dict): hyperparameter_space = path_or_dict else: raise ValueError( "Invalid value for `learner_hyperparams`. " "Must be either a path to a YAML file or a dict following the same structure." ) # validate the configuration file HYPERPARAMETER_SPACE_SCHEMA.validate(hyperparameter_space) return hyperparameter_space