Source code for optilab.functions.surrogate.surrogate_objective_function

"""
Abstract base class for surrogate objective functions.
"""

from typing import Any, Dict

from ...data_classes import Point, PointList
from ..objective_function import ObjectiveFunction


class SurrogateObjectiveFunction(ObjectiveFunction):
    """
    Abstract base class for surrogate objective functions.

    A surrogate has to be trained on a set of evaluated points before it can
    be called. The ``is_ready`` attribute records whether a training set has
    been accepted by :meth:`train`.
    """

    def __init__(
        self,
        name: str,
        train_set: PointList | None = None,
        hyperparameters: Dict[str, Any] | None = None,
    ) -> None:
        """
        Class constructor. The dimensionality is deduced from the training points.

        Args:
            name: Name of the surrogate function.
            train_set: Training data for the model. When omitted (or falsy),
                the function stays untrained until train() is called.
            hyperparameters: Dictionary with hyperparameters of the function.
        """
        # Set before anything else so the flag exists even if training fails.
        self.is_ready = False

        # NOTE(review): the second argument is presumably the dimensionality;
        # 1 then acts as a placeholder that train() overwrites - confirm
        # against ObjectiveFunction.__init__.
        super().__init__(
            name,
            1,
            hyperparameters,
        )

        if train_set:
            self.train(train_set)

    def train(self, train_set: PointList) -> None:
        """
        Train the Surrogate function with provided data.

        The training set is fully validated before any state is changed, so a
        rejected training set never leaves the function flagged as ready.

        Args:
            train_set: Training data for the model.

        Raises:
            ValueError: If not all points are evaluated, if the points do not
                share exactly one dimensionality, or if that dimensionality
                is 0.
        """
        if not all(train_point.is_evaluated for train_point in train_set):
            raise ValueError("Not all points in the training set are evaluated!")

        dim_set = {point.dim() for point in train_set.points}

        # A train set without points yields an empty dim_set and is rejected
        # by this check as well.
        if len(dim_set) != 1:
            raise ValueError(
                "Provided train set has x-es with different dimensionalities."
            )

        if 0 in dim_set:
            raise ValueError("0-dim x values found in train set.")

        # Exactly one element is left at this point.
        (dim,) = dim_set

        # Commit state only after every check has passed.
        self.metadata.dim = dim
        self.train_set = train_set
        self.is_ready = True

    def __call__(self, point: Point) -> Point:
        """
        Estimate the value of a single point with the surrogate function.

        Args:
            point: Point to estimate.

        Raises:
            NotImplementedError: If the surrogate function has not been
                trained yet.
            ValueError: If dimensionality of x doesn't match self.dim
                (presumably raised by the base-class call - not visible here).

        Returns:
            Whatever the base-class ``__call__`` returns for the provided point.
        """
        if not self.is_ready:
            raise NotImplementedError("The surrogate function is not trained!")

        # Forward the base-class result instead of silently dropping it.
        return super().__call__(point)