
semipy.methods.utils.pseudolabel_loss

Warning

This section is under construction.

    semipy.methods.utils.pseudolabel_loss(model, X, y, lbda, threshold, debiased=False)

This function computes the labelled (supervised) and unlabelled (pseudo-labelled) loss terms of the PseudoLabel method.
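For reference, the thresholded PseudoLabel objective can be written as below. Here $\mathcal{D}_l$ and $\mathcal{D}_u$ are the labelled ($y \neq -1$) and unlabelled ($y = -1$) parts of the batch, $n_l$ and $n_u$ are their sizes, $p_\theta$ is the model's softmax output, $\ell$ is the cross-entropy, $\lambda$ is `lbda` and $\tau$ is `threshold`. This is one common formulation, not a transcription of semipy's code. Several choices are assumptions: cross-entropy as the loss, averaging over all unlabelled samples, and the form of the $\delta$ term. The $\delta$ term is our reading of the debiasing of Schmutz et al., with $\delta = 1$ when `debiased=True` and $\delta = 0$ otherwise.

```latex
\begin{aligned}
\hat y(x) &= \arg\max_{k} p_\theta(k \mid x), \qquad
m(x) = \mathbb{1}\!\left[\max_{k} p_\theta(k \mid x) \ge \tau\right], \\
h(x) &= m(x)\,\ell\big(p_\theta(\cdot \mid x), \hat y(x)\big), \\
\mathcal{L}(\theta) &= \frac{1}{n_l}\sum_{(x_i, y_i) \in \mathcal{D}_l} \ell\big(p_\theta(\cdot \mid x_i), y_i\big)
  + \lambda \left( \frac{1}{n_u}\sum_{x_j \in \mathcal{D}_u} h(x_j)
  - \delta\,\frac{1}{n_l}\sum_{(x_i, y_i) \in \mathcal{D}_l} h(x_i) \right).
\end{aligned}
```

Under this formulation the debiased unlabelled contribution can be negative, because the subtracted term may exceed the unlabelled one.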

Parameters

  • model - The model to train.
  • X (torch.Tensor) - One batch of data containing both labelled and unlabelled samples.
  • y (torch.Tensor) - Labels of the current batch; unlabelled samples carry the label -1.
  • lbda (float) - Weight of the unlabelled loss relative to the labelled loss.
  • threshold (float) - Confidence (probability) threshold a prediction must reach to be used as a pseudo-label.
  • debiased (bool) - Whether to enable safe semi-supervised learning (Schmutz et al.) through debiasing. Default: False
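The sketch below illustrates what the parameters control, written in plain Python so it runs without PyTorch. It is not semipy's implementation, and several details are assumptions:

- `pseudolabel_loss_sketch`, `_softmax` and `_cross_entropy` are hypothetical names.
- `model` is assumed to map one feature vector to a list of logits, whereas the real `model` presumably receives a whole `torch.Tensor` batch.
- The unlabelled term is averaged over all unlabelled samples, with low-confidence samples contributing zero.
- The function returns the pair (labelled loss, `lbda`-weighted unlabelled loss).
- The `debiased` branch reflects our reading of Schmutz et al.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def _softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def _cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], computed via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def pseudolabel_loss_sketch(
    model: Callable[[Vector], Vector],
    X: Sequence[Vector],
    y: Sequence[int],
    lbda: float,
    threshold: float,
    debiased: bool = False,
) -> Tuple[float, float]:
    """Hypothetical stand-in for semipy's pseudolabel_loss (not its real code).

    Returns (labelled_loss, weighted_unlabelled_loss).
    """
    logits = [model(x) for x in X]
    labelled = [i for i, t in enumerate(y) if t != -1]
    unlabelled = [i for i, t in enumerate(y) if t == -1]

    # Supervised term: mean cross-entropy against the true labels.
    sup = (
        sum(_cross_entropy(logits[i], y[i]) for i in labelled) / len(labelled)
        if labelled
        else 0.0
    )

    def pl_term(indices: List[int]) -> float:
        """Mean pseudo-label cross-entropy; low-confidence samples add 0."""
        if not indices:
            return 0.0
        total = 0.0
        for i in indices:
            probs = _softmax(logits[i])
            confidence = max(probs)
            if confidence >= threshold:
                pseudo = probs.index(confidence)  # argmax used as a hard label
                total += _cross_entropy(logits[i], pseudo)
        return total / len(indices)

    unsup = pl_term(unlabelled)
    if debiased:
        # Assumed debiasing: subtract the same pseudo-label loss evaluated
        # on the labelled inputs (their true labels are ignored here).
        unsup -= pl_term(labelled)
    return sup, lbda * unsup


if __name__ == "__main__":
    # Toy "model" that treats each feature vector as its own logits.
    X = [[2.0, 0.0], [0.0, 3.0], [5.0, 0.0], [0.1, 0.0]]
    y = [0, 1, -1, -1]
    print(pseudolabel_loss_sketch(lambda x: x, X, y, lbda=1.0, threshold=0.9))
```

In a PyTorch version, the pseudo-labels would normally be computed under `torch.no_grad()` or from detached predictions, so that no gradient flows through the argmax targets.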