piscis.convert

Functions

convert_jax_to_torch_state_dict(jax_model_name[, ...])

Convert Piscis JAX model weights to a PyTorch state dict.

convert_dataset(dataset_path, new_dataset_path) → None

Convert a Piscis dataset saved as a .npz file to directories of .tif and .csv files.

Module Contents

piscis.convert.convert_jax_to_torch_state_dict(jax_model_name, state_dict=None, verbose=False)

Convert Piscis JAX model weights to a PyTorch state dict.

Parameters:
jax_model_name : str

JAX model name.

state_dict : dict, optional

Template state dict from a PyTorch model. Default is None.

verbose : bool, optional

Whether to print conversion progress. Default is False.

Raises:
ModuleNotFoundError

If Flax cannot be imported.

ValueError

If there is a shape mismatch between JAX and PyTorch weights.
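A JAX-to-PyTorch weight conversion of this kind generally has to reorder kernel axes, because Flax stores `nn.Dense` kernels as (in, out) and `nn.Conv` kernels as (H, W, in, out), whereas PyTorch stores `nn.Linear` weights as (out, in) and `nn.Conv2d` weights as (out, in, H, W). The sketch below is not the Piscis implementation; it only illustrates the idea, including a shape check that raises `ValueError` in the same spirit as the Raises entry above. The helper names `flax_kernel_to_torch` and `convert_params`, the flat dicts, and the assumption that both frameworks use identical keys are all hypothetical simplifications.

```python
import numpy as np


def flax_kernel_to_torch(kernel: np.ndarray) -> np.ndarray:
    """Reorder a Flax parameter array into PyTorch's axis order.

    Assumes every 2-D array is a Dense kernel and every 4-D array is a
    2-D convolution kernel (hypothetical; embeddings, for example, differ).
    """
    if kernel.ndim == 2:
        # nn.Dense kernel (in, out) -> nn.Linear weight (out, in)
        return kernel.T
    if kernel.ndim == 4:
        # nn.Conv kernel (H, W, in, out) -> nn.Conv2d weight (out, in, H, W)
        return np.transpose(kernel, (3, 2, 0, 1))
    # Biases and normalization scales are 1-D and keep their layout.
    return kernel


def convert_params(flax_params: dict, template: dict, verbose: bool = False) -> dict:
    """Fill a PyTorch-shaped template from flat Flax arrays (hypothetical helper).

    Assumes both dicts are flat and share keys; real Flax and PyTorch
    parameter names differ and would need an explicit mapping.
    """
    converted = {}
    for name, target in template.items():
        value = flax_kernel_to_torch(np.asarray(flax_params[name]))
        expected = tuple(target.shape)
        if value.shape != expected:
            raise ValueError(
                f"Shape mismatch for {name!r}: JAX {value.shape} vs PyTorch {expected}"
            )
        converted[name] = value
        if verbose:
            print(f"Converted {name}: {value.shape}")
    return converted


# Toy parameters with hypothetical shared keys.
flax_params = {
    "conv.weight": np.zeros((3, 3, 1, 8)),  # (H, W, in, out)
    "head.weight": np.zeros((8, 2)),        # (in, out)
    "head.bias": np.zeros(2),
}
template = {
    "conv.weight": np.empty((8, 1, 3, 3)),  # (out, in, H, W)
    "head.weight": np.empty((2, 8)),        # (out, in)
    "head.bias": np.empty(2),
}
state = convert_params(flax_params, template, verbose=True)
```

Checking shapes against a template rather than inferring them means a layout mistake surfaces as an immediate error instead of a silently wrong model. With a real PyTorch template, the resulting arrays would still need wrapping (for example with `torch.from_numpy`) before being passed to `load_state_dict`.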

piscis.convert.convert_dataset(dataset_path: str, new_dataset_path: str) → None

Convert a Piscis dataset saved as a .npz file to directories of .tif and .csv files.

Parameters:
dataset_path : str

Path to the .npz dataset file.

new_dataset_path : str

Path to the directory for the converted dataset.
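The internal layout of a Piscis .npz dataset is not documented here, so the following is only a sketch of the general .npz to .tif/.csv split, not a reproduction of `convert_dataset`. It assumes a hypothetical layout with an `images` key holding an (N, H, W) stack and a `coords` key holding (M, 3) rows of (image_index, row, col). The per-image file names `{i}.tif` and `{i}.csv` are also hypothetical. TIFF writing uses the third-party `tifffile` package (`tifffile.imwrite`) when it is installed and is skipped otherwise; the CSV half uses only the standard library.

```python
import csv
import tempfile
from pathlib import Path

import numpy as np

try:  # optional dependency for the TIFF half of the conversion
    import tifffile
except ImportError:
    tifffile = None


def npz_to_tif_csv(npz_path, out_dir) -> int:
    """Split a hypothetical .npz layout into per-image .tif and .csv files.

    Assumed (hypothetical) keys:
      images: (N, H, W) image stack
      coords: (M, 3) rows of (image_index, row, col) spot positions
    Returns the number of images written.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with np.load(npz_path) as data:
        images, coords = data["images"], data["coords"]
    for i, image in enumerate(images):
        if tifffile is not None:
            tifffile.imwrite(str(out_dir / f"{i}.tif"), image)
        # Select this image's spots and drop the index column.
        spots = coords[coords[:, 0] == i, 1:]
        with open(out_dir / f"{i}.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["row", "col"])
            writer.writerows(spots.tolist())
    return len(images)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "toy.npz"
        np.savez(
            src,
            images=np.zeros((2, 4, 4), dtype=np.uint16),
            coords=np.array([[0, 1.0, 2.0], [1, 3.0, 0.5], [1, 2.0, 2.0]]),
        )
        n = npz_to_tif_csv(src, Path(tmp) / "converted")
        print(f"Wrote {n} image/coordinate pairs")
```

Storing every spot in one (M, 3) array with an image-index column avoids pickled object arrays in the .npz, at the cost of a boolean mask per image when splitting.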