Source code for argus.callbacks.logging

"""Callbacks for logging argus model training process.
"""
import os
import csv
import logging
from datetime import datetime

from argus.engine import State
from argus.callbacks.callback import Callback, on_epoch_complete


def _format_lr_to_str(lr, precision=5):
    if isinstance(lr, (list, tuple)):
        lr = [f'{l:.{precision}g}' for l in lr]
        lr = "[" + ", ".join(lr) + "]"
    else:
        lr = f'{lr:.{precision}g}'
    return lr
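# Example: _format_lr_to_str(0.001) returns '0.001', while a list of
# per-parameter-group learning rates such as [0.01, 0.001] is rendered
# as '[0.01, 0.001]'.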


@on_epoch_complete
def metrics_logging(state: State, train=False, print_epoch=True):
    if train:
        epoch_name = 'Train'
        prefix = 'train_'
    else:
        epoch_name = 'Validation'
        prefix = 'val_'

    if print_epoch:
        train_epoch = state.epoch
        message = [f"{epoch_name} - Epoch: {train_epoch}"]
    else:
        message = [epoch_name]

    if train:
        lr = state.model.get_lr()
        lr = _format_lr_to_str(lr)
        message.append(f'LR: {lr}')

    for metric_name, metric_value in state.metrics.items():
        if not metric_name.startswith(prefix):
            continue
        message.append(f"{metric_name}: {metric_value:.7g}")
    state.logger.info(", ".join(message))
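# Runs at the end of every epoch via the ``on_epoch_complete`` decorator and
# writes a single line per epoch, e.g. "Train - Epoch: 1, LR: 0.001,
# train_loss: 0.1234567" or "Validation - Epoch: 1, val_loss: 0.1234567"
# (the exact metrics depend on the model configuration).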


class LoggingToFile(Callback):
    """Write the argus model training progress into a file.

    It adds a standard Python logger to log all loss and metric values during
    training. The logger is also used to output other messages, like info
    from callbacks and errors.

    Args:
        file_path (str): Path to the logging file.
        create_dir (bool, optional): Create the directory for the logging
            file if it does not exist. Defaults to True.
        formatter (str, optional): Standard Python logging formatter used to
            format the log messages.
            Defaults to '[%(asctime)s][%(levelname)s]: %(message)s'.
        append (bool, optional): Append to the log file if it already exists,
            otherwise rewrite it. Defaults to False.

    """

    def __init__(self, file_path,
                 create_dir=True,
                 formatter='[%(asctime)s][%(levelname)s]: %(message)s',
                 append=False):
        self.file_path = file_path
        self.create_dir = create_dir
        self.formatter = logging.Formatter(formatter)
        self.append = append
        self.file_handler = None

    def start(self, state: State):
        if self.create_dir:
            dir_path = os.path.dirname(self.file_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
        if not self.append and os.path.exists(self.file_path):
            os.remove(self.file_path)
        self.file_handler = logging.FileHandler(self.file_path)
        self.file_handler.setFormatter(self.formatter)
        state.logger.addHandler(self.file_handler)

    def complete(self, state: State):
        state.logger.removeHandler(self.file_handler)

    def catch_exception(self, state: State):
        self.complete(state)
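
# Usage sketch (illustrative; assumes an argus model instance ``model`` and
# PyTorch data loaders ``train_loader``/``val_loader``):
#
#     from argus.callbacks import LoggingToFile
#
#     model.fit(train_loader,
#               val_loader=val_loader,
#               num_epochs=50,
#               callbacks=[LoggingToFile('train.log')])
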
class LoggingToCSV(Callback):
    """Write the argus model training progress into a CSV file.

    It logs all loss and metric values during training into a .csv file for
    further analysis or visualization.

    Args:
        file_path (str): Path to the .csv logging file.
        create_dir (bool, optional): Create the directory for the logging
            file if it does not exist. Defaults to True.
        separator (str, optional): Values separator character to use.
            Defaults to ','.
        write_header (bool, optional): Write the column headers.
            Defaults to True.
        append (bool, optional): Append to the log file if it already exists,
            otherwise rewrite it. Defaults to False.

    """

    def __init__(self, file_path,
                 create_dir=True,
                 separator=',',
                 write_header=True,
                 append=False):
        self.file_path = file_path
        self.create_dir = create_dir
        self.separator = separator
        self.write_header = write_header
        self.append = append
        self.csv_file = None

    def start(self, state: State):
        file_mode = 'a' if self.append else 'w'
        if self.create_dir:
            dir_path = os.path.dirname(self.file_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
        self.csv_file = open(self.file_path, file_mode, newline='')

    def epoch_complete(self, state: State):
        lr = state.model.get_lr()
        fields = {
            'time': str(datetime.now()),
            'epoch': state.epoch,
            'lr': _format_lr_to_str(lr),
            **state.metrics
        }
        writer = csv.DictWriter(self.csv_file,
                                fieldnames=fields,
                                delimiter=self.separator)
        if self.write_header:
            writer.writeheader()
            self.write_header = False
        writer.writerow(fields)
        self.csv_file.flush()

    def complete(self, state: State):
        if self.csv_file is not None:
            self.csv_file.close()

    def catch_exception(self, state: State):
        self.complete(state)
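
# Usage sketch (illustrative; same assumed ``model`` and data loaders as
# above): write per-epoch metrics to a CSV file for later analysis, e.g.
# loading it with pandas.
#
#     from argus.callbacks import LoggingToCSV
#
#     model.fit(train_loader,
#               val_loader=val_loader,
#               num_epochs=50,
#               callbacks=[LoggingToCSV('metrics.csv')])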