Source code for argus.callbacks.logging

"""Callbacks for logging argus model training process.
"""
import os
import csv
import logging
from datetime import datetime
from typing import Optional, Union, List, IO

from argus import types
from argus.engine import State
from argus.callbacks.callback import Callback, on_epoch_complete

__all__ = ["LoggingToFile", "LoggingToCSV"]


def _format_lr_to_str(lr: Union[float, List[float]],
                      precision: int = 5) -> str:
    if isinstance(lr, (list, tuple)):
        str_lrs = [f'{_lr:.{precision}g}' for _lr in lr]
        str_lr = "[" + ", ".join(str_lrs) + "]"
    else:
        str_lr = f'{lr:.{precision}g}'
    return str_lr


@on_epoch_complete
def default_logging(state: State):
    """Log a one-line summary of the finished epoch via ``state.logger``.

    The message starts with the phase name and the epoch number. In the
    ``'train'`` phase it also contains the model's current learning rate.
    It then lists every metric whose name starts with ``"<phase>_"``. When
    the phase is empty, the prefix is empty too, so all metrics are listed.
    Metric values are formatted with 7 significant digits.
    """
    phase = state.phase
    message = f"{phase} - epoch: {state.epoch}"

    if phase == 'train':
        message += f', lr: {_format_lr_to_str(state.model.get_lr())}'

    # Only report the metrics that belong to the phase being logged.
    prefix = f"{phase}_" if phase else ''
    message += ''.join(
        f", {metric_name}: {metric_value:.7g}"
        for metric_name, metric_value in state.metrics.items()
        if metric_name.startswith(prefix)
    )
    state.logger.info(message)


class LoggingToFile(Callback):
    """Write the argus model training progress into a file.

    It attaches a file handler to the state's standard Python logger. All
    messages sent through that logger go to the file: losses and metrics
    values logged during training, and other messages such as info from
    callbacks and errors.

    Args:
        file_path (str): Path to the logging file.
        create_dir (bool, optional): Create the directory for the logging
            file if it does not exist. Defaults to True.
        formatter (str, optional): Standard Python logging formatter to
            format the log messages. Defaults to
            '[%(asctime)s][%(levelname)s]: %(message)s'.
        append (bool, optional): Append the log file if it already exists
            or rewrite it. Defaults to False.

    """

    def __init__(self, file_path: types.Path,
                 create_dir: bool = True,
                 formatter: str = '[%(asctime)s][%(levelname)s]: %(message)s',
                 append: bool = False):
        self.file_path = file_path
        self.create_dir = create_dir
        self.formatter = logging.Formatter(formatter)
        self.append = append
        self.file_handler: Optional[logging.FileHandler] = None

    def start(self, state: State):
        """Open the log file and attach its handler to ``state.logger``."""
        if self.create_dir:
            dir_path = os.path.dirname(self.file_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
        # Mode 'w' truncates an existing file (rewrite), mode 'a' keeps
        # its content (append).
        file_mode = 'a' if self.append else 'w'
        self.file_handler = logging.FileHandler(self.file_path,
                                                mode=file_mode)
        self.file_handler.setFormatter(self.formatter)
        state.logger.addHandler(self.file_handler)

    def complete(self, state: State):
        """Detach the file handler from ``state.logger`` and close the file."""
        if self.file_handler is not None:
            state.logger.removeHandler(self.file_handler)
            # removeHandler does not close the handler's stream; without
            # this the log file stays open as long as the callback lives.
            self.file_handler.close()
            self.file_handler = None

    def catch_exception(self, state: State):
        """Release the log file when training stops with an exception."""
        self.complete(state)
class LoggingToCSV(Callback):
    """Write the argus model training progress into a CSV file.

    It logs all losses and metrics values during training into a .csv file
    for further analysis or visualization. Each row contains the time, the
    epoch number, the learning rate and the values of ``state.metrics``.
    The column set of a file is fixed by its first row. Metrics that appear
    only in later epochs are not written, and missing values are left empty,
    so every row stays aligned with the header.

    Args:
        file_path (str): Path to the .csv logging file.
        create_dir (bool, optional): Create the directory for the logging
            file if it does not exist. Defaults to True.
        separator (str, optional): Values separator character to use.
            Defaults to ','.
        write_header (bool, optional): Write the column headers. The header
            is written only at the beginning of a file, so appending to a
            non-empty file does not duplicate it. Defaults to True.
        append (bool, optional): Append the log file if it already exists
            or rewrite it. Defaults to False.

    """

    def __init__(self, file_path: types.Path,
                 create_dir: bool = True,
                 separator: str = ',',
                 write_header: bool = True,
                 append: bool = False):
        self.file_path = file_path
        self.separator = separator
        self.write_header = write_header
        self.append = append
        self.csv_file: Optional[IO] = None
        self.create_dir = create_dir
        # Columns of the currently open file; set by its first written row.
        self._fieldnames: Optional[List[str]] = None

    def start(self, state: State):
        """Open (truncate or append) the CSV file for a new run."""
        file_mode = 'a' if self.append else 'w'
        if self.create_dir:
            dir_path = os.path.dirname(self.file_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
        self.csv_file = open(self.file_path, file_mode, newline='')
        # A newly opened file gets its own column set (and header).
        self._fieldnames = None

    def epoch_complete(self, state: State):
        """Append one row with the epoch's learning rate and metrics."""
        if self.csv_file is None:
            return
        lr = state.model.get_lr()
        fields = {
            'time': str(datetime.now()),
            'epoch': state.epoch,
            'lr': _format_lr_to_str(lr),
            **state.metrics
        }
        first_row = self._fieldnames is None
        if first_row:
            self._fieldnames = list(fields)
        # Fixed fieldnames keep columns aligned with the header even when
        # the metric set changes between epochs.
        writer = csv.DictWriter(self.csv_file,
                                fieldnames=self._fieldnames,
                                delimiter=self.separator,
                                restval='',
                                extrasaction='ignore')
        # tell() == 0 means the file is empty: a freshly truncated file gets
        # a header on every run, while a non-empty file opened for
        # appending does not get a second header in the middle.
        if first_row and self.write_header and self.csv_file.tell() == 0:
            writer.writeheader()
        writer.writerow(fields)
        self.csv_file.flush()

    def complete(self, state: State):
        """Close the CSV file."""
        if self.csv_file is not None:
            self.csv_file.close()
            # Reset so a later epoch_complete is a no-op instead of writing
            # to a closed file.
            self.csv_file = None

    def catch_exception(self, state: State):
        """Close the CSV file when training stops with an exception."""
        self.complete(state)