metrics

Implementation of certain useful metrics.

swem.metrics.ClassificationReport

A class for tracking various metrics in a classification task.

ClassificationReport

class swem.metrics.ClassificationReport(target_names: list[str | int] | None = None, binary: bool = False, from_probas: bool = False)

A class for tracking various metrics in a classification task.

The class is particularly useful when processing a dataset batch-wise, where metrics should be aggregated over the whole dataset rather than per batch.
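To illustrate why aggregation over the whole dataset matters, here is a minimal sketch in plain Python. `RunningAccuracy` is a hypothetical helper written for this illustration only; it is not part of swem and does not reflect how ClassificationReport is implemented. It keeps running counts across batches and compares the result with a naive mean of per-batch accuracies.

```python
class RunningAccuracy:
    """Hypothetical tracker (not swem's API): accuracy from running counts."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, labels):
        # Accumulate raw counts instead of per-batch ratios.
        self.correct += sum(p == y for p, y in zip(preds, labels))
        self.total += len(labels)

    def get(self):
        return self.correct / self.total if self.total else None


# Two batches of different size: (predictions, labels).
batches = [
    ([0, 1, 0, 0], [0, 1, 1, 0]),  # 3 of 4 correct
    ([1], [0]),                    # 0 of 1 correct
]

tracker = RunningAccuracy()
per_batch = []
for preds, labels in batches:
    tracker.update(preds, labels)
    per_batch.append(sum(p == y for p, y in zip(preds, labels)) / len(labels))

dataset_accuracy = tracker.get()                   # 3 correct out of 5 samples
mean_of_batches = sum(per_batch) / len(per_batch)  # mean of 0.75 and 0.0
```

With unequal batch sizes the two numbers differ, which is why a tracker that accumulates counts and computes the metrics once at the end is the safer choice.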

Parameters
  • target_names (list[str | int] | None) – Labels in the classification task in the same order as the output of the model.

  • binary (bool) – Whether or not we are doing binary classification (i.e. the model output is the pre-sigmoid logit for the positive class). Defaults to False.

  • from_probas (bool) – If True, the input is interpreted as probabilities instead of logits. This is only relevant for binary classification: in the multiclass setting the predicted label is an argmax, which is the same for logits and probabilities. Defaults to False.
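The remark about from_probas can be checked with a short plain-Python sketch. It assumes the conventional decision thresholds (logit > 0, probability > 0.5), which is an assumption about typical binary classifiers and not confirmed from the swem source; the helpers `sigmoid`, `softmax` and `argmax` are defined here for illustration only.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(xs):
    return max(range(len(xs)), key=lambda i: xs[i])


# Binary case: the threshold depends on whether the input is a logit or a probability.
logit = -0.3
proba = sigmoid(logit)                # roughly 0.43
label_from_logit = int(logit > 0)     # assumed logit threshold: 0
label_from_proba = int(proba > 0.5)   # assumed probability threshold: 0.5
mislabelled = int(proba > 0)          # a probability treated as a logit is always positive

# Multiclass case: softmax is monotonic, so the argmax does not change.
logits = [2.0, -1.0, 0.5]
same_prediction = argmax(logits) == argmax(softmax(logits))
```

Thresholding a probability at 0 would label every sample as positive, which is the kind of mistake the flag is meant to prevent.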

Examples

>>> report = ClassificationReport(target_names=["A", "B"])
>>> logits = torch.tensor([[1,0], [0,1], [1,0]])
>>> labels = torch.tensor([0, 0, 1])
>>> report.update(logits, labels)
>>> report.get()
{
    "num_samples": 3,
    "accuracy": 0.3333333333333333,
    "recall_macro_avg": 0.25,
    "recall_weighted_avg": 0.3333333333333333,
    "precision_macro_avg": 0.25,
    "precision_weighted_avg": 0.3333333333333333,
    "f1_score_macro_avg": 0.25,
    "f1_score_weighted_avg": 0.3333333333333333,
    "class_metrics": {
        "A": {
            "support": 2,
            "recall": 0.5,
            "precision": 0.5,
            "f1_score": 0.5
        },
        "B": {
            "support": 1,
            "recall": 0.0,
            "precision": 0.0,
            "f1_score": 0
        }
    }
}
>>> mask = torch.tensor([1,1,0])
>>> report.reset()
>>> report.update(logits, labels, mask)
>>> report.get()
{
    "num_samples": 2,
    "accuracy": 0.5,
    "recall_macro_avg": 0.25,
    "recall_weighted_avg": 0.5,
    "precision_macro_avg": 0.5,
    "precision_weighted_avg": 1.0,
    "f1_score_macro_avg": 0.3333333333333333,
    "f1_score_weighted_avg": 0.6666666666666666,
    "class_metrics": {
        "A": {
            "support": 2,
            "recall": 0.5,
            "precision": 1.0,
            "f1_score": 0.6666666666666666
        },
        "B": {
            "support": 0,
            "recall": 0,
            "precision": 0.0,
            "f1_score": 0
        }
    }
}
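The per-class numbers in the first report above can be reproduced by hand. The following is a minimal plain-Python sketch, not swem's implementation; `class_metrics` is a hypothetical helper, and returning 0 when a denominator is zero is an assumption inferred from the output shown for class "B" in the masked example.

```python
def class_metrics(preds, labels, cls):
    """Hypothetical helper: support, recall, precision and F1 for one class."""
    tp = sum(p == cls and y == cls for p, y in zip(preds, labels))
    predicted = sum(p == cls for p in preds)
    support = sum(y == cls for y in labels)
    # Assumed convention: a metric with a zero denominator is reported as 0.
    precision = tp / predicted if predicted else 0.0
    recall = tp / support if support else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"support": support, "recall": recall,
            "precision": precision, "f1_score": f1}


preds = [0, 1, 0]    # argmax of [[1, 0], [0, 1], [1, 0]]
labels = [0, 0, 1]

per_class = [class_metrics(preds, labels, c) for c in (0, 1)]
num_samples = len(labels)

# Macro average: unweighted mean over classes.
recall_macro = sum(m["recall"] for m in per_class) / len(per_class)
# Weighted average: each class weighted by its support.
recall_weighted = sum(m["support"] * m["recall"] for m in per_class) / num_samples
```

Class 0 ("A") has one true positive among two predictions and two true labels, giving 0.5 for all three metrics; class 1 ("B") has no true positives. The macro and weighted recall then come out as 0.25 and one third, matching the report.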
get() → dict[str, int | float | dict] | None

Get the current state of all tracked metrics.

Return type

dict[str, int | float | dict] | None

reset()

Reset all tracked values.

update(logits: array_like, labels: array_like, mask: array_like | None = None)

Update the tracked metrics with the results from a new batch.

The inputs can be any type that can be turned into a torch.Tensor (“array_like”).

Parameters
  • logits (array_like) – Output of the model in the classification task (pre-softmax/sigmoid if self.from_probas is False, or probabilities if self.from_probas is True).

  • labels (array_like) – The correct labels.

  • mask (array_like | None) – A 0/1-mask telling us which samples to take into account (where mask is 1). Defaults to None.

Shapes:
  • logits: \((*, \text{num_classes})\) if self.binary is False, otherwise \((*,)\), where * is any number of dimensions.

  • labels: \((*,)\) where * is the same as for logits.

  • mask: \((*,)\) where * is the same as for logits.
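As a rough illustration of the shape convention, the sketch below uses nested Python lists in place of tensors, for example a sequence-tagging batch of shape (2, 2, num_classes). `flatten_masked` is a hypothetical helper, not part of swem, and the flatten-then-filter order is an assumption about how leading dimensions and the mask could be handled.

```python
def flatten_masked(logits, labels, mask=None):
    """Hypothetical helper: flatten (batch, seq, num_classes) scores into
    per-sample predictions and labels, skipping positions where mask is 0."""
    preds, kept_labels = [], []
    for b, row in enumerate(logits):
        for t, scores in enumerate(row):
            if mask is not None and mask[b][t] == 0:
                continue  # masked-out sample, e.g. padding
            preds.append(max(range(len(scores)), key=lambda i: scores[i]))
            kept_labels.append(labels[b][t])
    return preds, kept_labels


logits = [[[1, 0], [0, 1]],
          [[1, 0], [0, 0]]]   # shape (2, 2, 2)
labels = [[0, 0],
          [1, 0]]             # shape (2, 2)
mask = [[1, 1],
        [1, 0]]               # shape (2, 2); the last position is ignored

preds, kept_labels = flatten_masked(logits, labels, mask)
```

Here three of the four positions survive the mask, so the metrics would be computed over three samples regardless of how the leading dimensions were arranged.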