Source code for argus.metrics.loss

from argus.metrics.metric import Metric
from argus.utils import AverageMeter

__all__ = ["Loss"]


[docs]class Loss(Metric): """Accumulates the average of a loss function. By default used by :meth:`argus.model.Model.fit`. """ name = 'loss' def __init__(self): self.avg_meter = AverageMeter() super().__init__() def reset(self): self.avg_meter.reset() def update(self, step_output: dict): self.avg_meter.update(step_output['loss']) def compute(self): if self.avg_meter.count == 0: raise RuntimeError('Must be at least one example for computation') return self.avg_meter.average