piscis.convert
==============

.. py:module:: piscis.convert


Functions
---------

.. autoapisummary::

   piscis.convert.convert_jax_to_torch_state_dict
   piscis.convert.convert_dataset


Module Contents
---------------

.. py:function:: convert_jax_to_torch_state_dict(jax_model_name, state_dict=None, verbose=False)

   Convert the weights of a Piscis JAX model into a PyTorch state dict.


   :Parameters:

       **jax_model_name** : str
           Name of the Piscis JAX model whose weights are converted.

       **state_dict** : dict, optional
           Template state dict from a PyTorch model. Default is None.

       **verbose** : bool, optional
           Whether to print conversion progress. Default is False.

   :Raises:

       ModuleNotFoundError
           If Flax cannot be imported.

       ValueError
           If there is a shape mismatch between JAX and PyTorch weights.
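
   As a rough, hypothetical illustration of the layout change and shape check that such a conversion typically involves, the NumPy sketch below reorders a single Flax kernel into PyTorch layout and raises ``ValueError`` when the result does not match a template shape. The helper ``flax_kernel_to_torch`` is an assumption and is not part of :py:mod:`piscis.convert`. It covers only plain ``Conv``/``Dense`` kernels and 1-D parameters; transposed convolutions and piscis's own parameter-name mapping are out of scope.

   .. code-block:: python

      import numpy as np


      def flax_kernel_to_torch(kernel: np.ndarray, torch_shape: tuple[int, ...]) -> np.ndarray:
          """Reorder a Flax parameter into PyTorch layout and check it against a template shape.

          Hypothetical helper: piscis's actual conversion rules may differ.
          """
          if kernel.ndim == 4:
              # Flax Conv kernel (H, W, in, out) -> PyTorch Conv2d weight (out, in, H, W).
              converted = np.transpose(kernel, (3, 2, 0, 1))
          elif kernel.ndim == 2:
              # Flax Dense kernel (in, out) -> PyTorch Linear weight (out, in).
              converted = kernel.T
          elif kernel.ndim == 1:
              # Biases and normalization scales/offsets are 1-D in both frameworks.
              converted = kernel
          else:
              raise ValueError(f"Unsupported parameter rank: {kernel.ndim}")
          if converted.shape != tuple(torch_shape):
              # Mirrors the documented behaviour: a shape mismatch raises ValueError.
              raise ValueError(
                  f"Shape mismatch: converted {converted.shape} vs template {tuple(torch_shape)}"
              )
          # Return a contiguous copy so it can be wrapped as a tensor without stride issues.
          return np.ascontiguousarray(converted)

   Flax stores convolution kernels as ``(height, width, in, out)`` while PyTorch ``Conv2d`` expects ``(out, in, height, width)``, which is why the values must be transposed rather than merely reshaped before they are copied into the template.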


.. py:function:: convert_dataset(dataset_path: str, new_dataset_path: str) -> None

   Convert a Piscis dataset saved as a ``.npz`` file into directories of ``.tif`` and ``.csv`` files.


   :Parameters:

       **dataset_path** : str
           Path to the ``.npz`` dataset file.

       **new_dataset_path** : str
           Path to the directory for the converted dataset.
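
   Before converting, it can help to check which arrays a dataset file actually contains. The sketch below relies only on NumPy and lists the name, shape, and dtype of every array in a ``.npz`` archive; the helper ``describe_npz`` is hypothetical, and no particular key names are assumed for Piscis datasets. Archives that store object arrays (for example, ragged lists of coordinates) can only be read with ``allow_pickle=True``, which should be used only for trusted files.

   .. code-block:: python

      import numpy as np


      def describe_npz(path: str, allow_pickle: bool = False) -> dict[str, tuple[tuple[int, ...], str]]:
          """Return ``{name: (shape, dtype)}`` for every array stored in a .npz archive.

          Hypothetical helper for inspection only; it does not convert anything.
          """
          summary = {}
          # NpzFile is a context manager, so the underlying file is closed on exit.
          with np.load(path, allow_pickle=allow_pickle) as archive:
              for name in archive.files:
                  array = archive[name]  # loads this member into memory
                  summary[name] = (array.shape, str(array.dtype))
          return summary

   The conversion itself is then a single call such as ``convert_dataset('dataset.npz', 'dataset_converted')``, where both paths are placeholders.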

