
semipy.tools.SSLTrainer

Warning

This section is under construction.

    class semipy.tools.SSLTrainer(config: Union[str, dict], model=None, rank: Optional[int] = None, world_size: Optional[int] = None) -> None

This class is a trainer for SSL that does not rely on PyTorch Lightning. It sets up an environment for training a model with SSL: it downloads the data, converts it into an SSL dataset if needed, downloads a model, and creates a scheduler, an optimizer, the dataloaders, etc. It offers newcomers a simple way to discover SemiPy.
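Converting a fully labeled dataset into an SSL dataset usually means hiding the labels of most samples so that only a small labeled subset remains. The sketch below illustrates that idea in plain Python; it is not SemiPy's implementation. The function name `to_ssl_dataset`, the `labeled_ratio` parameter and the convention of marking hidden labels with `-1` are assumptions made for illustration.

```python
import random
from typing import List, Sequence, Tuple

UNLABELED = -1  # hypothetical marker for a hidden label (assumption, not SemiPy's convention)


def to_ssl_dataset(
    samples: Sequence[Tuple[object, int]],
    labeled_ratio: float,
    seed: int = 0,
) -> List[Tuple[object, int]]:
    """Keep the label of a random subset of samples and hide the others.

    Hypothetical helper: illustrates the idea only, not SemiPy's API.
    """
    if not 0.0 <= labeled_ratio <= 1.0:
        raise ValueError("labeled_ratio must be in [0, 1]")
    rng = random.Random(seed)  # local RNG so the split is reproducible
    n_labeled = round(len(samples) * labeled_ratio)
    labeled_idx = set(rng.sample(range(len(samples)), n_labeled))
    # Samples outside the labeled subset keep their input but lose their label.
    return [
        (x, y) if i in labeled_idx else (x, UNLABELED)
        for i, (x, y) in enumerate(samples)
    ]


if __name__ == "__main__":
    data = [(i, i % 10) for i in range(100)]  # toy (input, label) pairs
    ssl_data = to_ssl_dataset(data, labeled_ratio=0.1)
    print(sum(1 for _, y in ssl_data if y != UNLABELED))
```

This split is uniform over the whole dataset; a real pipeline may instead keep a fixed number of labeled samples per class so that every class is represented.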

Parameters

  • config (str or dict) - A path to a configuration file or a dictionary containing the required (and desired) parameters.
  • model (optional) - The model to train. It can be a custom one. Default: None.
  • rank (int, optional) - In distributed mode, rank of the current process among the world_size processes. Default: None.
  • world_size (int, optional) - In distributed mode, total number of GPUs. Default: None.
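SSLTrainer already accepts config as either a path or a dictionary, but a caller who wants to inspect or modify the configuration first, and who runs in distributed mode, might prepare the arguments as in the sketch below. It assumes a JSON configuration file (the format SemiPy expects is not stated on this page) and reads the RANK and WORLD_SIZE environment variables that launchers such as torchrun set for each process. The helper names load_config and distributed_args are hypothetical.

```python
import json
import os
from typing import Optional, Tuple, Union


def load_config(config: Union[str, dict]) -> dict:
    """Return the configuration as a dict.

    Hypothetical helper; assumes configuration files are JSON.
    """
    if isinstance(config, dict):
        return dict(config)  # shallow copy so the caller's dict is untouched
    with open(config, "r", encoding="utf-8") as f:
        return json.load(f)


def distributed_args() -> Tuple[Optional[int], Optional[int]]:
    """Read (rank, world_size) from the environment, or (None, None) if absent.

    torchrun sets RANK and WORLD_SIZE for every process it launches.
    """
    rank = os.environ.get("RANK")
    world_size = os.environ.get("WORLD_SIZE")
    if rank is None or world_size is None:
        return None, None  # single-process run: keep the documented defaults
    return int(rank), int(world_size)


if __name__ == "__main__":
    cfg = load_config({"epochs": 10})  # hypothetical key, for illustration only
    rank, world_size = distributed_args()
    print(cfg, rank, world_size)
    # The values would then be passed using the documented signature, e.g.:
    # SSLTrainer(config=cfg, rank=rank, world_size=world_size)
```

Returning (None, None) outside a distributed launch matches the documented defaults, so the same script can run on a single GPU or under a multi-process launcher.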

Methods

    fit()

Launches training of the model with the parameters set in the configuration.