"""
Functions related to loading and dumping optimization results to pickle files.
"""
import pickle
from pathlib import Path
from typing import Any, List
import zstandard as zstd
[docs]
def dump_to_pickle(
data: Any,
pickle_path: Path,
zstd_compression: int | None = 1,
) -> None:
"""
Dump data (such as List[OptimizationRun]) to a pickle file, with option to compress
the data using zstandard. Compressed pickles should have *.zstd.pkl extension.
Args:
data: Data to save to a pickle file.
pickle_path: Path to file to save the data.
zstd_compression: Zstandard compression level. If None, then
no compression is used.
"""
with open(pickle_path, "wb") as pickle_handle:
if zstd_compression:
compressor = zstd.ZstdCompressor(level=zstd_compression)
with compressor.stream_writer(pickle_handle) as writer:
pickle.dump(data, writer, protocol=pickle.HIGHEST_PROTOCOL)
else:
pickle.dump(data, pickle_handle, protocol=pickle.HIGHEST_PROTOCOL)
[docs]
def load_from_pickle(pickle_path: Path) -> Any:
    """
    Load data (such as List[OptimizationRun]) from a pickle file. Zstandard compression
    is detected from the file extension (*.zstd.pkl).

    Warning: unpickling can execute arbitrary code, so only load files that come
    from a trusted source.

    Args:
        pickle_path: Pickle file path to read from. A plain string path is accepted
            as well and is converted to a Path.

    Returns:
        Data read from the pickle.
    """
    # Normalise the argument so that string paths also expose ``.suffixes``.
    pickle_path = Path(pickle_path)

    # Only the last two suffixes matter: a name with extra dots, such as
    # "run.v1.zstd.pkl", has suffixes [".v1", ".zstd", ".pkl"] and must still be
    # recognised as compressed.
    is_compressed = pickle_path.suffixes[-2:] == [".zstd", ".pkl"]

    with open(pickle_path, "rb") as pickle_handle:
        if is_compressed:
            decompressor = zstd.ZstdDecompressor()
            with decompressor.stream_reader(pickle_handle) as reader:
                return pickle.load(reader)
        return pickle.load(pickle_handle)
[docs]
def list_all_pickles(path: Path) -> List[Path]:
    """
    Collect the pickle files that a path refers to: the file itself when the path is a
    single pickle, or every pickle located directly inside it when the path is a directory.

    A file counts as a pickle when its last suffix is ".pkl" (so "*.zstd.pkl" files
    qualify too). Directories are not searched recursively, and the result for a
    directory is sorted by path.

    Args:
        path: Either a path to a pickle file or path to directory containing pickle files.

    Returns:
        List of paths to found pickle files.

    Raises:
        ValueError: If the path is a file and not a pickle, or when the path is a directory and
            contains no pickles.
    """
    # Single-file case: either it is a pickle or the input is rejected.
    if path.is_file():
        if path.suffix != ".pkl":
            raise ValueError("Provided file path is not a pickle file.")
        return [path]

    # Directory case: keep regular files ending in ".pkl", in sorted order.
    if path.is_dir():
        pickles_in_dir = [
            entry
            for entry in sorted(path.iterdir())
            if entry.is_file() and entry.suffix == ".pkl"
        ]
        if not pickles_in_dir:
            raise ValueError("No pickle file found in the provided directory.")
        return pickles_in_dir

    # NOTE(review): a path that is neither a file nor a directory (e.g. it does not
    # exist) falls through to an empty list - confirm this is the intended behaviour.
    return []