
semipy.methods.utils.adamatch_loss

Warning

This section is under construction.

    semipy.methods.utils.adamatch_loss(model, weak_x, strong_x, y, lbda, threshold, dist_align, debiased=False)
Computes the labelled and unlabelled losses of the AdaMatch method.
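The following is a rough illustration of what such a loss involves: a minimal sketch of an AdaMatch-style objective in PyTorch. It is not semipy's implementation. The name adamatch_loss_sketch is hypothetical. It takes precomputed logits instead of model, weak_x and strong_x, and it replaces DistAlign with a simple batch-mean distribution alignment. Its supervised term also uses only the weakly augmented labelled samples. Under those assumptions, it shows the two AdaMatch ingredients that threshold interacts with. First, unlabelled predictions are aligned to the labelled class distribution. Second, a pseudo-label is kept only when its confidence exceeds threshold scaled by the mean top-1 confidence on labelled samples.

```python
import torch
import torch.nn.functional as F


def adamatch_loss_sketch(logits_weak, logits_strong, y, lbda, threshold):
    """Hypothetical AdaMatch-style loss (illustration only, not semipy's code).

    logits_weak, logits_strong: (N, C) logits for the weak/strong views of one batch.
    y: (N,) int64 labels, with -1 marking unlabelled samples.
    Assumes the batch contains both labelled and unlabelled samples.
    """
    labelled = y != -1
    unlabelled = ~labelled

    # Supervised term: cross-entropy on the weakly augmented labelled samples.
    loss_l = F.cross_entropy(logits_weak[labelled], y[labelled])

    with torch.no_grad():
        probs_l = logits_weak[labelled].softmax(dim=-1)
        probs_u = logits_weak[unlabelled].softmax(dim=-1)

        # Distribution alignment (stand-in for DistAlign): rescale unlabelled
        # probabilities by mean labelled / mean unlabelled prediction, renormalise.
        probs_u = probs_u * probs_l.mean(dim=0) / probs_u.mean(dim=0)
        probs_u = probs_u / probs_u.sum(dim=-1, keepdim=True)

        # Relative confidence threshold: scale `threshold` by the mean top-1
        # confidence on the labelled samples.
        rel_threshold = threshold * probs_l.max(dim=-1).values.mean()
        conf, pseudo = probs_u.max(dim=-1)
        mask = (conf >= rel_threshold).float()

    # Unlabelled term: masked cross-entropy between strong-view predictions and
    # hard pseudo-labels, averaged over all unlabelled samples.
    per_sample = F.cross_entropy(logits_strong[unlabelled], pseudo, reduction="none")
    loss_u = (per_sample * mask).mean()

    return loss_l + lbda * loss_u, loss_l, loss_u
```

Because the threshold is relative, threshold=0.0 keeps every pseudo-label and a very large threshold masks them all. This gives an easy sanity check on the returned unlabelled term.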

Parameters

  • model - The model to train.
  • weak_x - A torch.Tensor holding one batch of weakly augmented data (labelled and unlabelled samples together).
  • strong_x - A torch.Tensor holding the same batch as weak_x, with strong augmentation applied.
  • y - A torch.Tensor of labels for the current batch; unlabelled samples have label -1.
  • lbda (float) - Balancing weight for the unlabelled loss.
  • threshold (float) - Probability threshold for pseudo-labelling.
  • dist_align - An instance of semipy.methods.utils.DistAlign used to perform distribution alignment.
  • debiased (bool) - Whether to enable safe semi-supervised learning via debiasing (Schmutz et al.). Default: False
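The debiased option refers to the safe semi-supervised learning idea of Schmutz et al. The unlabelled loss is also evaluated on the labelled samples and subtracted, with the aim of keeping the combined objective an unbiased estimate of the supervised risk. The sketch below illustrates that idea on a simplified pseudo-labelling loss with a fixed threshold and no distribution alignment. The function names are hypothetical, and semipy does not necessarily wire debiasing into AdaMatch this way.

```python
import torch
import torch.nn.functional as F


def pseudo_label_loss(logits_weak, logits_strong, threshold):
    """Hypothetical helper: masked cross-entropy between strong-view predictions
    and confident weak-view hard pseudo-labels (fixed threshold, no alignment)."""
    with torch.no_grad():
        conf, pseudo = logits_weak.softmax(dim=-1).max(dim=-1)
        mask = (conf >= threshold).float()
    per_sample = F.cross_entropy(logits_strong, pseudo, reduction="none")
    return (per_sample * mask).mean()


def debiased_total_loss(logits_weak, logits_strong, y, lbda, threshold):
    """Hypothetical debiased objective: L_l + lbda * (L_u(unlabelled) - L_u(labelled))."""
    labelled = y != -1
    unlabelled = ~labelled
    # Supervised cross-entropy on the labelled samples.
    loss_l = F.cross_entropy(logits_weak[labelled], y[labelled])
    # Pseudo-labelling loss on the unlabelled samples...
    loss_u = pseudo_label_loss(logits_weak[unlabelled], logits_strong[unlabelled], threshold)
    # ...minus the same loss evaluated on the labelled samples (debiasing term).
    loss_u_on_l = pseudo_label_loss(logits_weak[labelled], logits_strong[labelled], threshold)
    return loss_l + lbda * (loss_u - loss_u_on_l)
```

The term lbda * (loss_u - loss_u_on_l) can be negative. That is expected for the debiased form and is not a bug.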