Source code for argus.metrics.categorical_accuracy

import torch

from argus.metrics.metric import Metric

__all__ = ["CategoricalAccuracy"]


class CategoricalAccuracy(Metric):
    """Calculates the accuracy for multiclass classification.

    Over successive :meth:`update` calls, this counts how many elements have
    an arg-max prediction (taken over dim 1) equal to the target label.
    :meth:`compute` returns ``correct / count``.

    ``step_output`` passed to :meth:`update` must contain:

    * ``'prediction'`` - tensor of class scores with classes along dim 1,
      e.g. ``(batch, num_classes)``;
    * ``'target'`` - tensor of class labels with exactly one label per
      arg-max element, e.g. ``(batch,)`` or ``(batch, 1)``. It is reshaped
      to the arg-max shape before comparison.
    """

    name = 'accuracy'
    # 'max': larger metric values are presumably treated as better by the
    # Metric base class / callbacks (convention defined outside this file).
    better = 'max'

    def __init__(self):
        self.correct = 0  # number of correctly classified elements so far
        self.count = 0  # total number of compared elements so far

    def reset(self):
        """Clear the accumulated correct/total counts."""
        self.correct = 0
        self.count = 0

    def update(self, step_output: dict):
        """Accumulate correct and total counts from one step's output.

        Args:
            step_output: Mapping with ``'prediction'`` (class scores along
                dim 1) and ``'target'`` (class labels) tensors.

        Raises:
            RuntimeError: If the target does not hold exactly one label per
                arg-max element of the prediction (raised by ``reshape``).
        """
        prediction = step_output['prediction']
        target = step_output['target']
        indices = torch.max(prediction, dim=1)[1]
        # Align the target with the predicted labels before comparing.
        # Without this, a (batch, 1) target broadcasts against the (batch,)
        # index tensor into a (batch, batch) comparison and silently
        # inflates both counts. reshape (not view) also accepts
        # non-contiguous targets.
        target = target.reshape(indices.shape)
        correct = torch.eq(indices, target).reshape(-1)
        self.correct += torch.sum(correct).item()
        self.count += correct.numel()

    def compute(self) -> float:
        """Return the fraction of correctly classified elements.

        Raises:
            RuntimeError: If no elements have been accumulated since the last
                reset.
        """
        if self.count == 0:
            raise RuntimeError('Must be at least one example for computation')
        return self.correct / self.count