piscis.convert
Functions
- convert_jax_to_torch_state_dict: Convert Piscis JAX model weights to PyTorch state dict.
- Convert a Piscis dataset saved as a .npz file to directories of .tif and .csv files.
Module Contents
- piscis.convert.convert_jax_to_torch_state_dict(jax_model_name, state_dict=None, verbose=False)
Convert Piscis JAX model weights to PyTorch state dict.
- Parameters:
- jax_model_name : str
JAX model name.
- state_dict : dict, optional
Template state dict from the PyTorch model. Default is None.
- verbose : bool, optional
Whether to print conversion progress. Default is False.
- Raises:
- ModuleNotFoundError
If Flax cannot be imported.
- ValueError
If there is a shape mismatch between JAX and PyTorch weights.
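To illustrate the kind of shape check that leads to the ValueError above, here is a minimal sketch using NumPy only. It is not the Piscis implementation: the helper name `convert_conv_kernel` is hypothetical, and the sketch assumes the usual layouts, with Flax storing 2D convolution kernels as (height, width, in, out) and PyTorch expecting (out, in, height, width).

```python
import numpy as np


def convert_conv_kernel(jax_kernel, torch_shape):
    """Hypothetical helper, not part of Piscis.

    Transposes a Flax-style conv kernel (H, W, in, out) to the PyTorch
    layout (out, in, H, W) and checks it against the template shape.
    """
    converted = np.transpose(jax_kernel, (3, 2, 0, 1))
    if converted.shape != tuple(torch_shape):
        # Mirrors the documented behaviour: a mismatch raises ValueError.
        raise ValueError(
            f"Shape mismatch: converted JAX weight {converted.shape} "
            f"vs. PyTorch weight {tuple(torch_shape)}"
        )
    return converted


# A 3x3 kernel with 2 input channels and 4 output channels.
jax_kernel = np.arange(3 * 3 * 2 * 4, dtype=np.float32).reshape(3, 3, 2, 4)
torch_kernel = convert_conv_kernel(jax_kernel, (4, 2, 3, 3))
```

In this sketch the template shape plays the role that the `state_dict` argument presumably plays in the real function: it states what shape each converted weight must end up with.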