"""
Calculating and plotting the convergence curve.
"""
from typing import Dict, List
from matplotlib import pyplot as plt
from ..data_classes import PointList
[docs]
def convergence_curve(log: PointList) -> List[float]:
"""
For a given log return a convergence curve - the lowest value achieved so far.
Args:
log: Results log - the values of errors of the optimized function.
Returns:
y values of the convergence curve.
"""
min_so_far = float("inf")
new_log = []
for value in log:
min_so_far = min(min_so_far, value.y)
new_log.append(min_so_far)
return new_log
[docs]
def plot_convergence_curve(
data: Dict[str, List[PointList]],
savepath: str | None = None,
*,
show: bool = True,
function_name: str | None = None,
) -> None:
"""
Plot the convergence curves of a few methods using pyplot.
Args:
data: Lists of error logs of a few methods
expressed as {method name: [log]}.
savepath: Path to save the plot, optional.
show: Wheather to show the plot, default True.
function_name: Name of the optimized function, used in title.
"""
plt.clf()
for name, loglist in data.items():
ys = [convergence_curve(log) for log in loglist]
max_len = max((len(log) for log in ys))
for log in ys:
log.extend([log[-1]] * (max_len - len(log)))
averaged_y = [sum(values) / len(values) for values in zip(*ys)]
plt.plot(averaged_y, label=name)
plt.yscale("log")
plt.xlabel("evaluations")
plt.ylabel("value")
plt.legend()
if function_name:
plt.title(f"Convergence curves for function {function_name}")
else:
plt.title("Convergence curves")
if savepath:
plt.savefig(savepath)
if show:
plt.show()