Training

This guide to training Piscis follows the code in the train_piscis.ipynb notebook.

Step 1: Import Required Libraries

First, import the Piscis functions needed for training.

from piscis.downloads import download_dataset
from piscis.training import train_model

Step 2: Download the Piscis Dataset

Download the dataset used in this example, labeled 20251212. The download_dataset function fetches the specified dataset from our Hugging Face Dataset Repository.

download_dataset('20251212', '')
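Before starting a long training run, it can help to confirm that the dataset is where train_model will look for it. The sketch below assumes that the second argument of download_dataset is the parent directory for the download, so that download_dataset('20251212', '') creates 20251212 in the current working directory. That matches dataset_path='20251212' in Step 3, but it is our assumption rather than a documented guarantee, and resolve_dataset_path is a hypothetical helper, not part of Piscis.

```python
import os


def resolve_dataset_path(dataset_name: str, download_root: str = '') -> str:
    """Return the local path of a downloaded dataset, or raise if it is missing.

    ASSUMPTION: download_dataset(dataset_name, download_root) writes the dataset
    to os.path.join(download_root, dataset_name). With download_root='' the
    result is just dataset_name (e.g. '20251212'), relative to the working directory.
    """
    path = os.path.join(download_root, dataset_name)
    # os.path.exists accepts both files and directories, since the on-disk
    # layout of a downloaded dataset is not specified here.
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"No dataset found at {path!r}; run download_dataset first."
        )
    return path
```

For example, dataset_path = resolve_dataset_path('20251212') fails fast if the download did not land where expected, and otherwise returns the path to pass to train_model.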

Step 3: Train Piscis

Train a new Piscis model with the train_model function. The parameters below are the same ones we used to train the 20251212 model.

train_model(
    model_name='new_model',
    dataset_path='20251212',
    initial_model_name=None,
    adjustment='standardize',
    input_size=(256, 256),
    random_seed=0,
    batch_size=4,
    num_workers=0,
    learning_rate=0.1,
    weight_decay=1e-5,
    epochs=500,
    warmup_fraction=0.04,
    decay_fraction=0.4,
    decay_transitions=10,
    decay_factor=0.5,
    l2_loss_weight=0.1,
    dilation_iterations=1,
    max_distance=3.0,
    temperature=0.05,
    epsilon=1e-7,
    checkpoint_every=10,
    device='cuda'
)

See the API reference for the train_model function for more information on each training parameter.
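To build intuition for how learning_rate, epochs, warmup_fraction, decay_fraction, decay_transitions, and decay_factor might interact, here is a standalone sketch of one common pattern: a linear warmup over the first warmup_fraction of epochs, a constant plateau, then a staircase decay over the final decay_fraction of epochs in which the rate is multiplied by decay_factor once per step, for decay_transitions steps. This is our reading of the parameter names, not the Piscis implementation. The exact boundaries, and whether Piscis updates the rate per optimizer step or per epoch, are assumptions, and the train_model API reference is authoritative. The function lr_at_epoch is hypothetical.

```python
def lr_at_epoch(
    epoch: float,
    learning_rate: float = 0.1,
    epochs: int = 500,
    warmup_fraction: float = 0.04,
    decay_fraction: float = 0.4,
    decay_transitions: int = 10,
    decay_factor: float = 0.5,
) -> float:
    """Learning rate at a (possibly fractional) epoch under an ASSUMED schedule:
    linear warmup, constant plateau, then staircase decay. Defaults are Step 3 values.
    """
    warmup_epochs = warmup_fraction * epochs   # length of the linear warmup
    decay_epochs = decay_fraction * epochs     # length of the final decay phase
    decay_start = epochs - decay_epochs        # epoch at which decay begins

    # Phase 1: ramp linearly from 0 up to learning_rate.
    if epoch < warmup_epochs:
        return learning_rate * epoch / warmup_epochs

    # Phase 2: hold learning_rate constant (also covers a disabled decay phase).
    if epoch < decay_start or decay_epochs <= 0 or decay_transitions <= 0:
        return learning_rate

    # Phase 3: staircase decay, one multiplication by decay_factor per
    # step_length epochs, capped at decay_transitions multiplications.
    step_length = decay_epochs / decay_transitions
    steps_taken = min(int((epoch - decay_start) // step_length), decay_transitions)
    return learning_rate * decay_factor ** steps_taken
```

Under this reading, the Step 3 values give 20 warmup epochs, a plateau at 0.1 until epoch 300, and a halving every 20 epochs after that.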

Once training is complete, new_model is saved to the .piscis/models folder in your home directory, where the Piscis class can load it for inference.
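Because the save location is fixed, you can check for a trained model from Python. The sketch below uses only pathlib. The layout inside .piscis/models (a single file or a directory per model, and any file extension) is not described here, so it treats each entry's stem as a model name; piscis_models_dir and saved_model_names are hypothetical helpers, not part of Piscis.

```python
from pathlib import Path
from typing import List, Optional


def piscis_models_dir(home: Optional[Path] = None) -> Path:
    """Return the folder where trained models are saved: ~/.piscis/models."""
    base = home if home is not None else Path.home()
    return base / '.piscis' / 'models'


def saved_model_names(home: Optional[Path] = None) -> List[str]:
    """List model names found in the models folder (empty if it does not exist).

    ASSUMPTION: each entry's stem is the model name, whether the model is
    stored as a file (e.g. 'new_model.<ext>') or as a directory ('new_model').
    """
    models_dir = piscis_models_dir(home)
    if not models_dir.is_dir():
        return []
    return sorted(entry.stem for entry in models_dir.iterdir())
```

After the Step 3 run, 'new_model' in saved_model_names() should be true if the model was saved under that name.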

Custom Datasets

In addition to our preformatted datasets, you can create your own custom datasets using the generate_dataset function.

If you use NimbusImage and want to convert exported annotations into a custom dataset, see the generate_dataset.ipynb notebook for an example.

Fine-tuning

Instead of training a new model from scratch, consider fine-tuning a pre-trained model such as 20251212.

Piscis allows you to initialize training with the weights of an existing model via the initial_model_name parameter of the train_model function.
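A fine-tuning run can reuse the Step 3 parameters and change only what differs. The sketch below copies them into a dictionary and overrides model_name, dataset_path, and initial_model_name. The output name and dataset path you pass are placeholders (for example, a custom dataset built with generate_dataset), and make_finetune_kwargs is a hypothetical helper, not part of Piscis. Every other hyperparameter keeps its Step 3 value; this sketch makes no claim about which values work best for fine-tuning.

```python
# Parameters from Step 3, which were used to train the 20251212 model.
BASE_TRAINING_KWARGS = dict(
    model_name='new_model',
    dataset_path='20251212',
    initial_model_name=None,
    adjustment='standardize',
    input_size=(256, 256),
    random_seed=0,
    batch_size=4,
    num_workers=0,
    learning_rate=0.1,
    weight_decay=1e-5,
    epochs=500,
    warmup_fraction=0.04,
    decay_fraction=0.4,
    decay_transitions=10,
    decay_factor=0.5,
    l2_loss_weight=0.1,
    dilation_iterations=1,
    max_distance=3.0,
    temperature=0.05,
    epsilon=1e-7,
    checkpoint_every=10,
    device='cuda',
)


def make_finetune_kwargs(model_name: str, dataset_path: str,
                         initial_model_name: str = '20251212') -> dict:
    """Copy the Step 3 parameters, overriding the output model name, the dataset,
    and the pre-trained model whose weights initialize training."""
    kwargs = dict(BASE_TRAINING_KWARGS)  # copy so the base settings stay unchanged
    kwargs.update(
        model_name=model_name,
        dataset_path=dataset_path,
        initial_model_name=initial_model_name,
    )
    return kwargs
```

The resulting dictionary can be unpacked into the training call, e.g. train_model(**make_finetune_kwargs('finetuned_model', 'my_dataset')), where 'finetuned_model' and 'my_dataset' are hypothetical names for your output model and your own dataset.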