# Source code for argus.callbacks.checkpoints

"""Callbacks for argus model saving."""
import os
import math
import warnings

from argus.engine import State
from argus.callbacks.callback import Callback
from argus.metrics.metric import init_better

class Checkpoint(Callback):
    """Save the model with a given period.

    In the simplest case, the callback can be used to save the model after
    each epoch.

    Args:
        dir_path (str, optional): Directory to save checkpoints. The desired
            directory will be created if it does not exist. Defaults to ''.
        file_format (str, optional): Model saving filename format. Any valid
            value names from the model State may be used. Defaults to
            'model-{epoch:03d}-{train_loss:.6f}.pth'.
        max_saves (int, optional): Number of last saved models to keep.
            Should be positive. If None - save all models. Defaults to None.
        period (int, optional): Interval (number of epochs) between
            checkpoint saves. Defaults to 1.
        save_after_exception (bool, optional): Save the model checkpoint
            after an exception occurs. Defaults to False.

    Raises:
        ValueError: If ``max_saves`` is neither ``None`` nor positive.

    """

    def __init__(self, dir_path='',
                 file_format='model-{epoch:03d}-{train_loss:.6f}.pth',
                 max_saves=None,
                 period=1,
                 save_after_exception=False):
        if not (max_saves is None or max_saves > 0):
            raise ValueError("max_saves should be positive or 'None'")

        self.dir_path = dir_path
        self.file_format = file_format
        self.max_saves = max_saves
        # Paths written by this callback, oldest first; used for rotation.
        self.saved_files_paths = []

        if self.dir_path:
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)
            else:
                warnings.warn(f"Directory '{dir_path}' already exists")

        self.period = period
        self.save_after_exception = save_after_exception
        self.epochs_since_last_save = 0

    def save_model(self, state: State, file_path):
        """Save model to file.

        Override the method if you need custom checkpoint saving.

        Args:
            state (:class:`argus.engine.State`): State.
            file_path (str): Checkpoint file path.

        """
        # Delegate serialization to the model held by the state.
        state.model.save(file_path)

    def _format_file_path(self, state: State):
        """Build the checkpoint path from ``file_format`` and the state.

        The format string may reference ``epoch`` and any key present in
        ``state.metrics``.
        """
        format_state = {'epoch': state.epoch, **state.metrics}
        file_name = self.file_format.format(**format_state)
        return os.path.join(self.dir_path, file_name)

    def start(self, state: State):
        # Reset per-fit bookkeeping so a new fit starts a fresh rotation.
        self.epochs_since_last_save = 0
        self.saved_files_paths = []

    def save_checkpoint(self, state: State):
        """Save a checkpoint once ``period`` epochs have passed since the last.

        When ``max_saves`` is set, the oldest checkpoint written by this
        callback is deleted once more than ``max_saves`` are tracked.

        Args:
            state (:class:`argus.engine.State`): State.

        """
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0

        file_path = self._format_file_path(state)
        self.save_model(state, file_path)

        # A file_format that does not vary between saves (e.g. 'best.pth')
        # overwrites the same file. Drop the stale entry so rotation never
        # deletes the checkpoint that was just written.
        if file_path in self.saved_files_paths:
            self.saved_files_paths.remove(file_path)
        self.saved_files_paths.append(file_path)

        if self.max_saves is not None:
            if len(self.saved_files_paths) > self.max_saves:
                old_file_path = self.saved_files_paths.pop(0)
                if os.path.exists(old_file_path):
                    os.remove(old_file_path)
                    state.logger.info(f"Model removed '{old_file_path}'")

    def epoch_complete(self, state: State):
        self.save_checkpoint(state)

    def catch_exception(self, state: State):
        # Optionally keep the model state at the moment training failed.
        if self.save_after_exception:
            exception_model_path = os.path.join(self.dir_path,
                                                'model-after-exception.pth')
            self.save_model(state, exception_model_path)
class MonitorCheckpoint(Checkpoint):
    """Save the model checkpoints after a metric is improved.

    The MonitorCheckpoint augments the simple Checkpoint with a metric
    monitoring. It saves the model after the defined metric is improved. It
    is possible to monitor loss values during training as well as any metric
    available in the model State.

    Args:
        dir_path (str, optional): Directory to save checkpoints. The desired
            directory will be created if it does not exist. Defaults to ''.
        file_format (str, optional): Model saving filename format. Any valid
            value names from the model State may be used. Defaults to
            'model-{epoch:03d}-{monitor:.6f}.pth'.
        max_saves (int, optional): Number of last saved models to keep.
            Should be positive. If None - save all models. Defaults to None.
        save_after_exception (bool, optional): Save the model checkpoint
            after an exception occurs. Defaults to False.
        monitor (str, optional): Metric name to monitor. It should be
            prepended with *val_* for the metric value on validation data and
            *train_* for the metric value on the data from the train loader.
            A val_loader should be provided during the model fit to make it
            possible to monitor metrics starting with *val_*. Defaults to
            *val_loss*.
        better (str, optional): The metric improvement criterion. Should be
            'min', 'max' or 'auto'. 'auto' means the criterion should be
            taken from the metric itself, which is appropriate behavior in
            most cases. Defaults to 'auto'.

    Raises:
        ValueError: If ``monitor`` does not start with 'val_' or 'train_'.

    """

    def __init__(self, dir_path='',
                 file_format='model-{epoch:03d}-{monitor:.6f}.pth',
                 max_saves=None,
                 save_after_exception=False,
                 monitor='val_loss',
                 better='auto'):
        if not monitor.startswith(('val_', 'train_')):
            raise ValueError(
                "monitor should be prepended with 'val_' or 'train_'")
        # A monitored checkpoint is evaluated every epoch; saving is gated by
        # metric improvement, not by a period.
        super().__init__(dir_path=dir_path,
                         file_format=file_format,
                         max_saves=max_saves,
                         period=1,
                         save_after_exception=save_after_exception)
        self.monitor = monitor
        self.better, self.better_comp, self.best_value = init_better(
            better, monitor)

    def _format_file_path(self, state: State):
        """Build the checkpoint path, exposing the monitored value as
        ``monitor`` in addition to ``epoch`` and the metrics in the state.
        """
        format_state = {'epoch': state.epoch,
                        'monitor': state.metrics[self.monitor],
                        **state.metrics}
        file_name = self.file_format.format(**format_state)
        return os.path.join(self.dir_path, file_name)

    def start(self, state: State):
        # Reset the parent's rotation bookkeeping too, so checkpoints from a
        # previous fit are not tracked (and later deleted) by this one.
        super().start(state)
        self.best_value = math.inf if self.better == 'min' else -math.inf

    def epoch_complete(self, state: State):
        """Save a checkpoint if the monitored metric improved this epoch.

        Raises:
            ValueError: If the monitored metric is missing from the state.

        """
        if self.monitor not in state.metrics:
            raise ValueError(
                f"Monitor '{self.monitor}' metric not found in state")
        current_value = state.metrics[self.monitor]
        if self.better_comp(current_value, self.best_value):
            self.best_value = current_value
            self.save_checkpoint(state)