semipy.methods.utils.pimodel_loss

Warning

This section is under construction.

    semipy.methods.utils.pimodel_loss(model, weak_x, strong_x, y, lbda, current_iter, pimodel_n_batches, warmup_length, debiased=False)
This function computes the labelled and unlabelled losses for the PiModel method.

Parameters

  • model - The model to train.
  • weak_x - A torch.Tensor containing one batch of weakly augmented data (both labelled and unlabelled).
  • strong_x - A torch.Tensor containing the same batch as weak_x, with strong augmentation applied.
  • y - A torch.Tensor containing the labels of the current batch (unlabelled items are marked with -1).
  • lbda (float) - Balancing weight for the unlabelled loss.
  • current_iter (int) - Current training iteration (number of batches per epoch * current epoch + current batch).
  • pimodel_n_batches (int) - Number of batches in a 'PiModel-defined epoch' (one pass in which all items, labelled and unlabelled, have been seen).
  • warmup_length (int) - Number of epochs for the PiModel-defined warmup.
  • debiased (bool) - Whether to activate safe-SSL (Schmutz et al.) via debiasing. Default: False
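
To make the bookkeeping parameters more concrete, here is a minimal pure-Python sketch of how current_iter, pimodel_n_batches, warmup_length and lbda might interact. It is not semipy's implementation. The helper names (global_iteration, rampup_weight, split_labelled) are hypothetical, the warmup is assumed to last pimodel_n_batches * warmup_length iterations, and the Gaussian ramp-up exp(-5 (1 - t)^2) is the schedule from the original Pi-Model paper (Laine & Aila), which this library may or may not use.

```python
import math


def global_iteration(batches_per_epoch, epoch, batch_idx):
    """current_iter as described above: batches per epoch * current epoch + current batch."""
    return batches_per_epoch * epoch + batch_idx


def rampup_weight(lbda, current_iter, pimodel_n_batches, warmup_length):
    """Hypothetical warmup: scale lbda by a Gaussian ramp-up exp(-5 * (1 - t)^2).

    Assumption: the warmup spans pimodel_n_batches * warmup_length iterations,
    and t is the fraction of that span already completed, clipped to [0, 1].
    """
    warmup_iters = pimodel_n_batches * warmup_length
    if warmup_iters <= 0:
        # No warmup configured: use the full weight from the start.
        return lbda
    t = min(max(current_iter / warmup_iters, 0.0), 1.0)
    return lbda * math.exp(-5.0 * (1.0 - t) ** 2)


def split_labelled(y):
    """Return (labelled, unlabelled) index lists; unlabelled items carry the label -1."""
    labelled = [i for i, label in enumerate(y) if label != -1]
    unlabelled = [i for i, label in enumerate(y) if label == -1]
    return labelled, unlabelled


if __name__ == "__main__":
    it = global_iteration(batches_per_epoch=100, epoch=2, batch_idx=5)
    print("current_iter:", it)
    print("unlabelled weight:", rampup_weight(1.0, it, pimodel_n_batches=100, warmup_length=5))
    print("index split:", split_labelled([0, -1, 2, -1]))
```

Under these assumptions, the unlabelled term would typically be a consistency penalty between the model's predictions on weak_x and strong_x, multiplied by the ramped weight, while the supervised term would use only the items whose label is not -1.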