"""
Surrogate objective function using FAISS for fast KNN-based regression.
"""
import faiss
import numpy as np
from ...data_classes import Point, PointList
from .surrogate_objective_function import SurrogateObjectiveFunction
class KNNSurrogateObjectiveFunction(SurrogateObjectiveFunction):
    """
    Surrogate objective function using FAISS for fast KNN-based regression.

    A prediction is the inverse-distance-weighted average of the target
    values of the ``num_neighbors`` closest training points, looked up in a
    FAISS ``IndexFlatL2`` built over the training inputs.
    """

    # Added to each Euclidean distance before inverting it, so a query that
    # coincides with a training point does not divide by zero.
    _DISTANCE_EPSILON = 1e-8

    def __init__(
        self,
        num_neighbors: int,
        train_set: PointList | None = None,
    ) -> None:
        """
        Class constructor.

        Args:
            num_neighbors: Number of closest neighbors to use in regression.
            train_set: Training data for the model.
        """
        self.num_neighbors = num_neighbors
        # Set before the base constructor runs. NOTE(review): presumably the
        # base class calls ``train`` when ``train_set`` is given, so these
        # must not be reset afterwards -- confirm in SurrogateObjectiveFunction.
        self.faiss_index = None
        self.y_train = None
        super().__init__(
            f"FastKNN{num_neighbors}",
            train_set,
            {"num_neighbors": num_neighbors},
        )

    def train(self, train_set: PointList) -> None:
        """
        Train the FAISS-based KNN surrogate function with provided data.

        Builds a fresh L2 index over the training inputs and stores the
        training targets for use in prediction.

        Args:
            train_set: Training data for the model.

        Raises:
            ValueError: If the training inputs are empty or do not form a
                2-D array of shape (num_points, num_dims).
        """
        super().train(train_set)
        x_train, y_train = self.train_set.pairs()
        # FAISS works on float32 vectors; convert once instead of going
        # through an intermediate float64 copy.
        x_train = np.array(x_train, dtype=np.float32)
        y_train = np.array(y_train, dtype=np.float64)
        # Without this check an empty or 1-D input fails later with an
        # opaque "tuple index out of range" on ``x_train.shape[1]``.
        if x_train.ndim != 2 or x_train.shape[0] == 0:
            raise ValueError(
                "Training inputs must be a non-empty 2-D array of shape "
                f"(num_points, num_dims), got shape {x_train.shape}."
            )
        faiss_index = faiss.IndexFlatL2(x_train.shape[1])
        faiss_index.add(  # pylint: disable=no-value-for-parameter
            x_train
        )
        # Assign index and targets together, after every conversion has
        # succeeded, so the two never get out of sync.
        self.faiss_index = faiss_index
        self.y_train = y_train

    def __call__(self, point: Point) -> Point:
        """
        Estimate the function value at a given point using kNN regression.

        Args:
            point: Point to estimate.

        Returns:
            Estimated value of the function at the given point, marked as
            not evaluated.

        Raises:
            ValueError: If the training set has fewer points than
                ``num_neighbors``.
        """
        super().__call__(point)
        if len(self.train_set) < self.num_neighbors:
            raise ValueError("Train set length is below number of neighbors.")
        x_query = np.array([point.x], dtype=np.float32)
        sq_distances, indices = (
            self.faiss_index.search(  # pylint: disable=no-value-for-parameter
                x_query,
                self.num_neighbors,
            )
        )
        # Exactly one query was issued, so only row 0 of the results is used.
        # The index reports squared L2 distances (hence the sqrt). Clamp at
        # zero first: a dot-product-based distance evaluation can round to
        # tiny negatives, which np.sqrt would turn into NaN.
        sq_row = np.maximum(sq_distances[0].astype(np.float64), 0.0)
        weights = 1.0 / (np.sqrt(sq_row) + self._DISTANCE_EPSILON)
        neighbor_y = self.y_train[indices[0]]
        y_pred = np.sum(neighbor_y * weights) / np.sum(weights)
        return Point(
            x=point.x,
            y=float(y_pred),
            is_evaluated=False,
        )