Source code for argus.callbacks.callback

"""Base class for Callbacks.
"""

from typing import Callable

from argus.utils import inheritors
from argus.engine import Events, EventEnum


class Callback:
    """Base callback class.

    An attribute whose name equals the ``value`` of an event member is
    treated as the handler for that event; ``attach`` registers every such
    attribute on an engine.

    Raises:
        TypeError: Attribute is not callable.
    """

    def attach(self, engine, handler_kwargs_dict=None):
        """Register the event-named attributes of this callback on ``engine``.

        Every member of every enum yielded by ``inheritors(EventEnum)`` is
        inspected. When ``self`` has an attribute named after the member's
        ``value``, that attribute is passed to ``engine.add_event_handler``
        for the member.

        Args:
            engine: Object providing
                ``add_event_handler(event, handler, **kwargs)``.
            handler_kwargs_dict (dict, optional): Maps an event member to a
                dict of extra keyword arguments forwarded to
                ``add_event_handler`` for that event. Events without an
                entry get no extra keyword arguments. Defaults to None.

        Raises:
            TypeError: An attribute named after an event value exists but is
                not callable.
        """
        if handler_kwargs_dict is None:
            handler_kwargs_dict = {}

        for event_enum in inheritors(EventEnum):
            # Only the members are needed, not their names.
            for event in event_enum.__members__.values():
                if not hasattr(self, event.value):
                    continue

                handler = getattr(self, event.value)
                if not callable(handler):
                    raise TypeError(f"Attribute {event.value} is not callable.")

                handler_kwargs = handler_kwargs_dict.get(event, {})
                engine.add_event_handler(event, handler, **handler_kwargs)
class FunctionCallback(Callback):
    """Callback that ties a single handler to a single event."""

    def __init__(self, event: EventEnum, handler):
        """Store the event and the handler that ``attach`` will register.

        Args:
            event (EventEnum): Event the handler is registered for.
            handler: Object passed to ``engine.add_event_handler`` as the
                handler.
        """
        self.event = event
        self.handler = handler

    def attach(self, engine, *args, **kwargs):
        """Add the stored handler to ``engine`` for the stored event.

        Args:
            engine: Object providing ``add_event_handler``.
            *args: Extra positional arguments forwarded unchanged to
                ``engine.add_event_handler``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``engine.add_event_handler``.
        """
        engine.add_event_handler(self.event, self.handler, *args, **kwargs)


def on_event(event):
    """Build a decorator that wraps a function in a ``FunctionCallback``.

    The decorated name ends up bound to the ``FunctionCallback`` instance,
    not to the original function.

    Args:
        event: Event the decorated function is registered for.

    Returns:
        A one-argument decorator returning ``FunctionCallback(event, func)``.
    """

    def decorator(func):
        return FunctionCallback(event, func)

    return decorator


def on_start(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.START``."""
    return FunctionCallback(Events.START, func)


def on_complete(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.COMPLETE``."""
    return FunctionCallback(Events.COMPLETE, func)


def on_epoch_start(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.EPOCH_START``."""
    return FunctionCallback(Events.EPOCH_START, func)


def on_epoch_complete(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.EPOCH_COMPLETE``."""
    return FunctionCallback(Events.EPOCH_COMPLETE, func)


def on_iteration_start(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.ITERATION_START``."""
    return FunctionCallback(Events.ITERATION_START, func)


def on_iteration_complete(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.ITERATION_COMPLETE``."""
    return FunctionCallback(Events.ITERATION_COMPLETE, func)


def on_catch_exception(func):
    """Wrap ``func`` in a ``FunctionCallback`` for ``Events.CATCH_EXCEPTION``."""
    return FunctionCallback(Events.CATCH_EXCEPTION, func)