Source code for optilab.functions.surrogate.knn_surrogate_objective_function

"""
Surrogate Objective function using FAISS for fast KNN-based regression.
"""

import faiss
import numpy as np

from ...data_classes import Point, PointList
from .surrogate_objective_function import SurrogateObjectiveFunction


class KNNSurrogateObjectiveFunction(SurrogateObjectiveFunction):
    """
    Surrogate objective function using FAISS for fast KNN-based regression.

    The prediction at a query point is the inverse-distance weighted mean of
    the target values of its ``num_neighbors`` nearest training points.
    """

    def __init__(
        self,
        num_neighbors: int,
        train_set: PointList | None = None,
    ) -> None:
        """
        Class constructor.

        Args:
            num_neighbors: Number of closest neighbors to use in regression.
            train_set: Training data for the model.
        """
        # Model state is set up *before* delegating to the base constructor.
        # NOTE(review): presumably the base constructor calls train() when a
        # train_set is supplied; resetting these afterwards would then wipe
        # the freshly built index - confirm against the base class.
        self.num_neighbors = num_neighbors
        self.faiss_index = None
        self.y_train = None

        function_name = f"FastKNN{num_neighbors}"
        hyperparameters = {"num_neighbors": num_neighbors}
        super().__init__(function_name, train_set, hyperparameters)

    def train(self, train_set: PointList) -> None:
        """
        Train the FAISS-based KNN Surrogate function with provided data.

        Builds a flat (exhaustive) L2 index over the training inputs and
        stores the matching target values for use at prediction time.

        Args:
            train_set: Training data for the model.
        """
        super().train(train_set)

        raw_inputs, raw_targets = self.train_set.pairs()
        inputs = np.array(raw_inputs, dtype=np.float64)
        targets = np.array(raw_targets, dtype=np.float64)

        # The index dimensionality is the number of columns of the input
        # matrix; the vectors are handed to FAISS as float32.
        dimension = inputs.shape[1]
        self.faiss_index = faiss.IndexFlatL2(dimension)
        self.faiss_index.add(  # pylint: disable=no-value-for-parameter
            inputs.astype(np.float32)
        )
        self.y_train = targets

    def __call__(self, point: Point) -> Point:
        """
        Estimate the function value at a given point using kNN regression.

        Args:
            point: Point to estimate.

        Returns:
            Estimated value of the function at the given point, returned as a
            new point marked as not evaluated.

        Raises:
            ValueError: If the train set holds fewer points than
                ``num_neighbors``.
        """
        super().__call__(point)

        if len(self.train_set) < self.num_neighbors:
            raise ValueError("Train set length is below number of neighbors.")

        # A single-row query matrix; FAISS is queried with float32 data.
        query = np.array([point.x], dtype=np.float64)
        squared_dists, neighbor_ids = (
            self.faiss_index.search(  # pylint: disable=no-value-for-parameter
                query.astype(np.float32),
                self.num_neighbors,
            )
        )

        # IndexFlatL2 reports squared L2 distances, hence the square root.
        # The small constant avoids division by zero for exact matches.
        squared_dists = squared_dists.astype(np.float64)
        weights = 1 / (np.sqrt(squared_dists) + 1e-8)

        # Inverse-distance weighted average of the neighbors' target values.
        weighted_targets = self.y_train[neighbor_ids] * weights
        numerator = np.sum(weighted_targets, axis=1)[0]
        y_pred = numerator / weights.sum()

        return Point(
            x=point.x,
            y=float(y_pred),
            is_evaluated=False,
        )