Source code for argus.metrics.metric

import math
import operator
import warnings
from typing import Optional, Callable, Union, Tuple, List, Dict, Type

import torch

import argus
from argus.callbacks import Callback
from argus.engine import State, Engine


# Maps a metric ``name`` to its class. MetricMeta fills it in when metric
# classes are created; init_better and attach_metrics look names up here.
METRIC_REGISTRY: Dict[str, Type['argus.metrics.Metric']] = {}

__all__ = ["Metric", "attach_metrics"]


def init_better(better: str, monitor: str) -> Tuple[str, Callable, float]:
    """Resolve the optimization direction for a monitored value.

    Args:
        better (str): One of 'min', 'max' or 'auto'. With 'auto' the
            direction is read from the ``better`` attribute of the metric
            class registered under the monitored metric name.
        monitor (str): Name of the monitored value, e.g. 'val_loss'. In
            'auto' mode a leading 'val_' or 'train_' prefix, when present,
            is stripped before the registry lookup.

    Returns:
        tuple: ``(better, better_comp, best_value)``. ``better`` is the
        resolved 'min' or 'max'. ``better_comp(a, b)`` returns True when
        ``a`` is strictly better than ``b``. ``best_value`` is the starting
        value every real score improves on: ``inf`` for 'min', ``-inf``
        for 'max'.

    Raises:
        ValueError: If ``better`` is not 'min', 'max' or 'auto', or if the
            registered metric declares a ``better`` other than 'min'/'max'.
        ImportError: If ``better`` is 'auto' and no metric is registered
            under the monitored name.
    """
    if better not in ('min', 'max', 'auto'):
        raise ValueError(f"Unknown better option '{better}'")

    if better == 'auto':
        metric_name = monitor
        # Strip the phase prefix only when it is actually present. Blind
        # slicing would turn an unprefixed name like 'loss' into ''.
        for prefix in ('val_', 'train_'):
            if monitor.startswith(prefix):
                metric_name = monitor[len(prefix):]
                break
        if metric_name not in METRIC_REGISTRY:
            raise ImportError(f"Metric '{metric_name}' not found in scope")
        better = METRIC_REGISTRY[metric_name].better
        # Do not silently treat an unexpected value (e.g. 'Min') as 'max'.
        if better not in ('min', 'max'):
            raise ValueError(f"Metric '{metric_name}' has unknown better "
                             f"option '{better}'")

    if better == 'min':
        return better, operator.lt, math.inf
    return better, operator.gt, -math.inf


class MetricMeta(type):
    """Metaclass that registers metric classes under their ``name``.

    Every class created with this metaclass whose own class body sets a
    non-empty ``name`` is stored in ``METRIC_REGISTRY`` under that name, so
    the metric can later be referenced by string. Registering a name that is
    already taken replaces the previous entry and emits a warning.
    """

    def __new__(mcs, name, bases, attrs, *args, **kwargs):
        new_class = super().__new__(mcs, name, bases, attrs)
        # ``attrs`` only holds names defined in this class body. A subclass
        # that inherits ``name`` has no 'name' key here; plain indexing would
        # raise KeyError. Such subclasses are simply not (re-)registered.
        metric_name = attrs.get('name')
        if metric_name:
            if metric_name in METRIC_REGISTRY:
                # str(new_class) renders as "<class 'module.QualName'>".
                warnings.warn(f"{new_class} redefined '{metric_name}' "
                              "that was already registered by "
                              f"{METRIC_REGISTRY[metric_name]}",
                              stacklevel=2)
            METRIC_REGISTRY[metric_name] = new_class
        return new_class


class Metric(Callback, metaclass=MetricMeta):
    """Base metric class.

    One needs to create a class inherited from the Metric class to define a
    custom metric. In the basic use case scenarios, the following should be
    done:

    * Override three methods: reset, update, and compute.
    * Set class attributes: name, better.

    Attributes:
        name (str): Unique metric name. The name is used to reference the
            metric by other components, like Callbacks. Defaults to ''.
        better (str): Whether minimization or maximization is better. Should
            be 'min' or 'max'. It will be used, for example, by
            :class:`argus.callbacks.MonitorCheckpoint`. Defaults to 'min'.

    Example:

        MAP@k implementation:

        .. code-block:: python

            import torch
            import numpy as np
            from argus.metrics import Metric


            def apk(actual, predicted, k=3):
                if len(predicted) > k:
                    predicted = predicted[:k]

                score = 0.0
                num_hits = 0.0

                for i, p in enumerate(predicted):
                    if p in actual and p not in predicted[:i]:
                        num_hits += 1.0
                        score += num_hits / (i + 1.0)

                if not actual:
                    return 0.0

                return score / min(len(actual), k)


            def mapk(actual, predicted, k=3):
                return np.mean([apk(a, p, k)
                                for a, p in zip(actual, predicted)])


            class MAPatK(Metric):
                name = 'map_at_k'
                better = 'max'

                def __init__(self, k=3):
                    super().__init__()
                    self.k = k
                    self.scores = []

                def reset(self):
                    self.scores = []

                def update(self, step_output: dict):
                    preds = step_output['prediction'].cpu().numpy()
                    trgs = step_output['target'].cpu().numpy()

                    preds_idx = preds.argsort(axis=1)
                    preds_idx = np.fliplr(preds_idx)[:, :self.k]

                    self.scores += [apk([a], p, self.k)
                                    for a, p in zip(trgs, preds_idx)]

                def compute(self):
                    return np.mean(self.scores)

        Then you can use the metric like this:

        .. code-block:: python

            callbacks = [
                MonitorCheckpoint(dir_path='mnist', monitor='val_map_at_k')
            ]

            model.fit(train_loader,
                      val_loader=val_loader,
                      metrics=['map_at_k'],  # or the same: metrics=[MAPatK(k=3)]
                      callbacks=callbacks)

        In the case of name-based custom metric reference, it is enough to
        define or import the metric class in the module to use it. Note that
        the metric values saved into :class:`argus.engine.State` are
        prefixed with *val_* or *train_*. The full metric name, like
        *val_map_at_k* in the example, should therefore be used to retrieve
        the metric value, for instance as the value to monitor by
        :class:`argus.callbacks.MonitorCheckpoint`.

    """

    # Registry key used by MetricMeta; an empty name is not registered.
    name: str = ''
    # Optimization direction: 'min' or 'max'.
    better: str = 'min'

    def reset(self):
        """Init or reset internal variables and accumulators."""

    def update(self, step_output: dict):
        """Update internal variables with a provided *step_output*.

        *step_output* from default :meth:`argus.model.Model.train_step` and
        :meth:`argus.model.Model.val_step` looks like::

            {
                'prediction': The batch predictions,
                'target': The batch targets,
                'loss': Loss function value
            }

        """

    def compute(self):
        """Compute the custom metric and return the result.

        Called from :meth:`epoch_complete` inside ``torch.no_grad()``.
        """

    def epoch_start(self, state: State):
        """Reset the accumulators at the start of every epoch."""
        self.reset()

    def iteration_complete(self, state: State):
        """Feed the output of the finished step into :meth:`update`."""
        self.update(state.step_output)

    def epoch_complete(self, state: State):
        """Store metric value to :class:`argus.engine.State`.

        The value returned by :meth:`compute` is written to
        ``state.metrics`` under ``'<phase>_<name>'``, or just ``'<name>'``
        when ``state.phase`` is empty. You can override this method if you
        want, for example, to save several metric values in the state.
        """
        with torch.no_grad():
            score = self.compute()
        name_prefix = f"{state.phase}_" if state.phase else ''
        state.metrics[name_prefix + self.name] = score
def attach_metrics(engine: Engine,
                   metrics: Optional[List[Union[Metric, str]]]):
    """Attaches metrics to the :class:`argus.engine.Engine`.

    Finds the metric in the registry if it's a string. All items are
    resolved and validated before any of them is attached, so an invalid
    entry does not leave the engine with only part of the metrics attached.

    Args:
        engine (Engine): The engine to which metrics will be attached.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            List of metrics. Nothing is attached when it is None.

    Raises:
        ValueError: If a metric given by name is not registered.
        TypeError: If an item is neither a :class:`Metric` nor a str.
    """
    if metrics is None:
        return

    resolved = []
    for metric in metrics:
        if isinstance(metric, str):
            if metric not in METRIC_REGISTRY:
                raise ValueError(f"Metric '{metric}' not found in scope")
            # Name references are instantiated with default arguments.
            metric = METRIC_REGISTRY[metric]()
        if not isinstance(metric, Metric):
            raise TypeError(f"Expected metric type {Metric} or str, "
                            f"got {type(metric)}")
        resolved.append(metric)

    for metric in resolved:
        metric.attach(engine)