Source code for optilab.utils.pickle_utils

"""
Functions related to loading and dumping optimization results to pickle files.
"""

import pickle
from pathlib import Path
from typing import Any, List

import zstandard as zstd


[docs] def dump_to_pickle( data: Any, pickle_path: Path, zstd_compression: int | None = 1, ) -> None: """ Dump data (such as List[OptimizationRun]) to a pickle file, with option to compress the data using zstandard. Compressed pickles should have *.zstd.pkl extension. Args: data: Data to save to a pickle file. pickle_path: Path to file to save the data. zstd_compression: Zstandard compression level. If None, then no compression is used. """ with open(pickle_path, "wb") as pickle_handle: if zstd_compression: compressor = zstd.ZstdCompressor(level=zstd_compression) with compressor.stream_writer(pickle_handle) as writer: pickle.dump(data, writer, protocol=pickle.HIGHEST_PROTOCOL) else: pickle.dump(data, pickle_handle, protocol=pickle.HIGHEST_PROTOCOL)
[docs] def load_from_pickle(pickle_path: Path) -> Any: """ Load data (such as List[OptimizationRun]) from a pickle file. Zstandard compression is detected from the file extension (*.zstd.pkl). Args: pickle_path: Pickle file path to read from. Returns: Data read from the pickle. """ with open(pickle_path, "rb") as pickle_handle: if pickle_path.suffixes == [".zstd", ".pkl"]: decompressor = zstd.ZstdDecompressor() with decompressor.stream_reader(pickle_handle) as reader: data = pickle.load(reader) else: data = pickle.load(pickle_handle) return data
def list_all_pickles(path: Path) -> List[Path]:
    """
    Collect the pickle files found at ``path``.

    A path to a single file is returned as a one-element list when it has the
    ``.pkl`` suffix (which includes ``*.zstd.pkl`` files). A path to a
    directory is scanned non-recursively. Every regular file directly inside it
    with the ``.pkl`` suffix is returned, in sorted path order. A path that is
    neither an existing file nor an existing directory yields an empty list.

    Args:
        path: Path to a pickle file or to a directory holding pickle files.

    Returns:
        List of paths of the pickle files found.

    Raises:
        ValueError: If ``path`` is a file without the ``.pkl`` suffix, or if it
            is a directory that holds no ``.pkl`` files.
    """
    if path.is_file():
        if path.suffix != ".pkl":
            raise ValueError("Provided file path is not a pickle file.")
        return [path]

    if not path.is_dir():
        # Neither a file nor a directory (e.g. the path does not exist).
        return []

    found = [
        candidate
        for candidate in sorted(path.iterdir())
        if candidate.is_file() and candidate.suffix == ".pkl"
    ]
    if not found:
        raise ValueError("No pickle file found in the provided directory.")
    return found