
semipy.methods.utils.fixmatch_loss

Warning

This section is under construction.

    semipy.methods.utils.fixmatch_loss(model, weak_x, strong_x, y, lbda, threshold, debiased=False)
This function computes the labelled (supervised) and unlabelled (consistency) losses of the FixMatch method.
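For reference, the FixMatch objective of Sohn et al. (2020) can be written as below. Here L and U are the index sets of the labelled and unlabelled samples in the batch, α(x_b) and 𝒜(x_b) are the corresponding rows of weak_x and strong_x, H is the cross-entropy, λ is lbda and τ is threshold. Whether this function normalises each term exactly this way is an assumption.

```latex
q_b = p_\theta(\cdot \mid \alpha(x_b)), \qquad \hat{y}_b = \arg\max_c q_{b,c}

\ell_s = \frac{1}{|L|} \sum_{b \in L} H\big(y_b,\; p_\theta(\cdot \mid \alpha(x_b))\big)

\ell_u = \frac{1}{|U|} \sum_{b \in U} \mathbb{1}\big[\max_c q_{b,c} \ge \tau\big]\; H\big(\hat{y}_b,\; p_\theta(\cdot \mid \mathcal{A}(x_b))\big)

\ell = \ell_s + \lambda\, \ell_u
```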

Parameters

  • model - The model to train.
  • weak_x - A torch.Tensor containing one batch of weakly augmented data (both labelled and unlabelled samples).
  • strong_x - A torch.Tensor containing the same batch as weak_x, but with strong augmentation.
  • y - A torch.Tensor containing the labels of the current batch; unlabelled samples are labelled -1.
  • lbda (float) - Weight of the unlabelled loss relative to the labelled loss.
  • threshold (float) - Minimum predicted class probability required for a pseudo-label to be used.
  • debiased (bool) - Whether to enable safe SSL via debiasing (Schmutz et al.). Default: False
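The following is a minimal PyTorch sketch of a loss with this signature, not semipy's actual implementation. It assumes that model returns raw logits of shape (batch, n_classes), that unlabelled samples are marked with -1 in y, and that the function returns the tuple (total, labelled loss, unlabelled loss); the name fixmatch_loss_sketch is hypothetical. The debiased branch follows our reading of Schmutz et al.: the same consistency loss is also evaluated on the labelled samples and subtracted from the unlabelled one.

```python
import torch
import torch.nn.functional as F


def fixmatch_loss_sketch(model, weak_x, strong_x, y, lbda, threshold, debiased=False):
    """Hypothetical FixMatch-style loss (not the semipy implementation).

    Assumes unlabelled samples have label -1 and that `model` returns logits.
    Returns (total_loss, labelled_loss, unlabelled_loss).
    """
    labelled = y != -1
    unlabelled = ~labelled

    # Supervised term: cross-entropy on the weakly augmented labelled samples.
    logits_weak = model(weak_x)
    loss_s = F.cross_entropy(logits_weak[labelled], y[labelled])

    # Pseudo-labels come from the weak views and are treated as fixed targets.
    with torch.no_grad():
        probs = torch.softmax(logits_weak, dim=1)
        max_probs, pseudo = probs.max(dim=1)
        mask = (max_probs >= threshold).float()

    # Consistency term: strong views must predict the confident pseudo-labels.
    logits_strong = model(strong_x)
    per_sample = F.cross_entropy(logits_strong, pseudo, reduction="none") * mask

    # Average over all unlabelled samples; unconfident ones contribute 0.
    if unlabelled.any():
        loss_u = per_sample[unlabelled].mean()
    else:
        loss_u = logits_strong.new_zeros(())

    if debiased:
        # Assumed debiasing: subtract the same consistency loss on labelled data.
        loss_u = loss_u - per_sample[labelled].mean()

    total = loss_s + lbda * loss_u
    return total, loss_s, loss_u


# Usage on a toy batch: 3 labelled and 5 unlabelled samples, 3 classes.
torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
weak_x = torch.randn(8, 4)
strong_x = weak_x + 0.1 * torch.randn(8, 4)
y = torch.tensor([0, 2, 1, -1, -1, -1, -1, -1])
total, loss_s, loss_u = fixmatch_loss_sketch(model, weak_x, strong_x, y, lbda=1.0, threshold=0.95)
total.backward()
```

Computing the pseudo-labels under torch.no_grad() keeps gradients from flowing through the targets, so only the predictions on the labelled weak views and on the strong views are trained.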