Source code for argus.engine.engine

"""Events, State, and Engine in the current file are highly inspired by
pytorch-ignite (https://github.com/pytorch/ignite).
"""
import logging
from enum import Enum
from types import MethodType
from collections import defaultdict
from typing import Callable, Optional, Iterable, Tuple, List, Dict, Any

import argus

__all__ = [
    "EventEnum",
    "Events",
    "State",
    "Engine",
]


[docs]class EventEnum(Enum): """Base class for engine events. User defined custom events should also inherit this class. Example of creating custom events you can find `here <https://github.com/lRomul/argus/blob/master/examples/custom_events.py>`_. """
[docs]class Events(EventEnum): """Events that are fired by the :class:`argus.engine.Engine` during running. Built-in events: - ``START``: triggered when the engine's run is started. - ``COMPLETE``: triggered when the engine's run is completed. - ``EPOCH_START``: triggered when the epoch is started. - ``EPOCH_COMPLETE``: triggered when the epoch is ended. - ``ITERATION_START``: triggered when an iteration is started. - ``ITERATION_COMPLETE``: triggered when the iteration is ended. - ``CATCH_EXCEPTION``: triggered on catching of an exception. """ START = "start" COMPLETE = "complete" EPOCH_START = "epoch_start" EPOCH_COMPLETE = "epoch_complete" ITERATION_START = "iteration_start" ITERATION_COMPLETE = "iteration_complete" CATCH_EXCEPTION = "catch_exception"
def init_step_method( step_method: Callable ) -> Tuple[Callable, 'argus.model.Model', str]: if isinstance(step_method, MethodType): model = step_method.__self__ if isinstance(model, argus.model.Model): phase: str = step_method.__name__ if phase.endswith('_step'): phase = phase[:-len('_step')] return step_method, model, phase raise TypeError("step_method must be a method of 'argus.model.Model'.")
[docs]class State: """A state used to store internal and user-defined variables during a run of :class:`argus.engine.Engine`. The class is highly inspired by the State from `pytorch-ignite <https://github.com/pytorch/ignite>`_. Args: step_method (Callable): Method of :class:`argus.model.Model` that takes ``batch, state`` and returns step output. engine (Engine, optional): :class:`argus.engine.Engine` that uses this object as a state. phase_states (dict, optional): Dictionary with states for each training phase. **kwargs: Initial attributes of the state. By default, the state contains the following attributes. Attributes: iteration (int): Iteration, the first iteration is 0. epoch (int): Epoch, the first iteration is 0. model (:class:`argus.Model`): :class:`argus.Model` that uses :attr:`argus.engine.State.engine` and this object as a state. data_loader (Iterable, optional): A data passed to the :class:`argus.engine.Engine`. logger (logging.Logger): Logger. exception (BaseException, optional): Catched exception. engine (Engine, optional): :class:`argus.engine.Engine` that uses this object as a state. phase (str): A phase of training this state was created for. The value takes from the name of the method `step_method`. If the `step_method` name ends with `_step`, the postfix will be removed. For default steps of argus model values are 'train' and 'val'. phase_states (dict, optional): Dictionary with states for each training phase. batch (Any): Batch sample from a data loader on the current iteration. step_output (Any): Current output from `step_method` on current iteration. metrics (dict): Dictionary with metrics values. stopped (bool): Boolean indicates :class:`argus.engine.Engine` is stopped or not. """ def __init__(self, step_method: Callable[[Any, 'argus.engine.State'], Any], engine: Optional['argus.engine.Engine'] = None, phase_states: Optional[Dict[str, 'argus.engine.State']] = None, **kwargs): self.iteration: int = 0 self.epoch: int = 0 self.step_method, self.model, self.phase = init_step_method(step_method) if phase_states is not None: phase_states[self.phase] = self self.phase_states = phase_states self.logger: logging.Logger = self.model.logger self.data_loader: Optional[Iterable] = None self.exception: Optional[BaseException] = None self.engine: Optional[Engine] = engine self.batch: Any = None self.step_output: Any = None self.metrics: Dict[str, Any] = dict() self.stopped: bool = True self.update(**kwargs)
[docs] def update(self, **kwargs): """ Update state attributes. Args: **kwargs: Update attributes using kwargs """ for key, value in kwargs.items(): setattr(self, key, value)
[docs]class Engine: """Runs ``step_method`` over each batch of a data loader with triggering event handlers. The class is highly inspired by the Engine from `pytorch-ignite <https://github.com/pytorch/ignite>`_. Args: step_method (Callable): Method of :class:`argus.model.Model` that takes ``batch, state`` and returns step output. phase_states (dict, optional): Dictionary with states for each training phase. **kwargs: Initial attributes of the state. Attributes: state (State): Stores internal and user-defined variables during a run of the engine. step_method (Callable): Method of :class:`argus.model.Model` that takes ``batch, state`` and returns step output. event_handlers (dict of EventEnum: list): Dictionary that stores event handlers. """ def __init__(self, step_method: Callable[[Any, State], Any], phase_states: Optional[Dict[str, State]] = None, **kwargs): self.event_handlers: Dict[ EventEnum, List[Tuple[Callable, Tuple, Dict]] ] = defaultdict(list) self.step_method = step_method self.state = State( step_method=step_method, engine=self, phase_states=phase_states, **kwargs )
[docs] def add_event_handler(self, event: EventEnum, handler: Callable, *args, **kwargs): """Add an event handler to be executed when the event is triggered. Args: event (EventEnum): An event that will be associated with the handler. handler (Callable): A callable handler that will be executed on the event. The handler should take :class:`argus.engine.State` as the first argument. *args: optional args arguments to be passed to the handler. **kwargs: optional kwargs arguments to be passed to the handler. """ if not isinstance(event, EventEnum): raise TypeError("Event should be 'argus.engine.EventEnum' enum") self.event_handlers[event].append((handler, args, kwargs))
[docs] def raise_event(self, event: EventEnum): """Execute all the handlers associated with the given event. Args: event (EventEnum): An event that will be triggered. """ if not isinstance(event, EventEnum): raise TypeError("Event should be 'argus.engine.EventEnum' enum") if event in self.event_handlers: for handler, args, kwargs in self.event_handlers[event]: handler(self.state, *args, **kwargs)
[docs] def run(self, data_loader: Iterable, start_epoch: int = 0, end_epoch: int = 1) -> State: """Run ``step_method`` on each batch from data loader ``end_epoch - start_epoch`` times. Args: data_loader (Iterable): An iterable collection that returns batches. start_epoch (int): The first epoch number. end_epoch (int): One above the largest epoch number. Returns: State: An engine state. """ self.state.update(data_loader=data_loader, epoch=start_epoch, iteration=0, stopped=False) try: self.raise_event(Events.START) while self.state.epoch < end_epoch and not self.state.stopped: self.state.iteration = 0 self.state.metrics = dict() self.raise_event(Events.EPOCH_START) for batch in data_loader: self.state.batch = batch self.raise_event(Events.ITERATION_START) self.state.step_output = self.step_method(batch, self.state) self.raise_event(Events.ITERATION_COMPLETE) self.state.step_output = None if self.state.stopped: break self.state.iteration += 1 self.raise_event(Events.EPOCH_COMPLETE) self.state.epoch += 1 self.raise_event(Events.COMPLETE) except BaseException as exception: if self.state.logger is not None: self.state.logger.exception(exception) self.state.exception = exception self.raise_event(Events.CATCH_EXCEPTION) raise exception finally: self.state.stopped = True return self.state