import os
from typing import Union, Optional, Callable
import torch
from argus import types
from argus.model.build import MODEL_REGISTRY, cast_device
from argus.utils import deep_to, device_to_str, Default, default, identity
# Only ``load_model`` is exported by ``from <module> import *``; the default
# hook functions defined below are left out of the star-import surface.
__all__ = ["load_model"]
def default_change_state_dict_func(nn_state_dict: dict,
                                   optimizer_state_dict: Optional[dict] = None):
    """Return both state dicts untouched.

    Used as the default value of ``change_state_dict_func`` in
    :func:`load_model`, i.e. the "no modification" hook.

    Args:
        nn_state_dict (dict): State dict of the nn_module.
        optimizer_state_dict (dict, optional): State dict of the optimizer.

    Returns:
        tuple: ``(nn_state_dict, optimizer_state_dict)``, the very same
        objects that were passed in.
    """
    unchanged = (nn_state_dict, optimizer_state_dict)
    return unchanged
def default_state_load_func(file_path: types.Path):
    """Load a saved state from ``file_path`` with :func:`torch.load`.

    Used as the default value of ``state_load_func`` in :func:`load_model`.

    Args:
        file_path: Path to the saved state file.

    Returns:
        Whatever :func:`torch.load` returns for the file.

    Raises:
        FileNotFoundError: If ``file_path`` is not an existing regular file.
    """
    # Guard first so a missing path (or a directory) gives a clear error.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No state found at {file_path}")
    # NOTE(review): torch.load relies on pickle; only load files from
    # sources you trust.
    return torch.load(file_path)
[docs]
def load_model(file_path: types.Path,
               nn_module: Union[Default, types.Param] = default,
               optimizer: Union[Default, None, types.Param] = default,
               loss: Union[Default, None, types.Param] = default,
               prediction_transform: Union[Default, None, types.Param] = default,
               device: Union[Default, types.InputDevices] = default,
               state_load_func: Callable = default_state_load_func,
               change_params_func: Callable = identity,
               change_state_dict_func: Callable = default_change_state_dict_func,
               model_name: Union[Default, str] = default,
               **kwargs):
    """Load an argus model from a file.

    The function allows loading an argus model, saved with
    :meth:`argus.model.Model.save`. The model is always loaded in *eval* mode.

    Args:
        file_path (str or :class:`pathlib.Path`): Path to the file to load.
        nn_module (dict, tuple or str, optional): Params of the nn_module to
            replace params in the state.
        optimizer (None, dict, tuple or str, optional): Params of the
            optimizer to replace params in the state. Optimizer is not
            created in the loaded model if it is set to `None`.
        loss (None, dict, tuple or str, optional): Params of the loss to
            replace params in the state. Loss is not created in the loaded
            model if it is set to `None`.
        prediction_transform (None, dict, tuple or str, optional): Params of
            the prediction_transform to replace params in the state.
            prediction_transform is not created in the loaded model if it is
            set to `None`.
        device (str, torch.device or list of devices, optional): The model
            device.
        state_load_func (function, optional): Function for loading state from
            file path.
        change_params_func (function, optional): Function for modification of
            the loaded params. It takes params from the loaded state as an
            input and outputs params to use during the model creation.
        change_state_dict_func (function, optional): Function for
            modification of nn_module and optimizer state dict. Takes
            `nn_state_dict` and `optimizer_state_dict` as inputs and outputs
            state dicts for the model creation.
        model_name (str, optional): Class name of :class:`argus.model.Model`.
            By default uses the name from the loaded state.
        **kwargs: Any other attribute params; each keyword replaces the
            entry with the same key in the loaded params.

    Returns:
        :class:`argus.model.Model`: Loaded argus model.

    Example:

        .. code-block:: python

            model = ArgusModel(params)
            model.save(model_path, optimizer_state=True)

            # restarting python...

            # ArgusModel class must be already in the scope
            model = argus.load_model(model_path, device="cuda:0")

        You can find more options how to use load_model
        `here <https://github.com/lRomul/argus/blob/master/examples/load_model.py>`_.

    Raises:
        ImportError: If the model is not available in the scope. Often it
            means that it is not imported or defined.
        FileNotFoundError: If the file is not found by the *file_path*.
        ValueError: If *nn_module* is explicitly set to `None`.
    """
    state = state_load_func(file_path)

    # Resolve the model class name: an explicit argument wins, otherwise use
    # the name stored in the saved state.
    if isinstance(model_name, Default):
        str_model_name = state['model_name']
    else:
        str_model_name = model_name

    if str_model_name not in MODEL_REGISTRY:
        # Report the resolved name; `model_name` may still be the `default`
        # sentinel here, which would make the message useless.
        raise ImportError(f"Model '{str_model_name}' not found in scope")

    # Override the stored params with everything the caller passed
    # explicitly. `None` is a meaningful value for optimizer, loss and
    # prediction_transform, hence the `default` sentinel checks.
    params = state['params']
    if not isinstance(device, Default):
        params['device'] = device_to_str(cast_device(device))

    if nn_module is not default:
        if nn_module is None:
            raise ValueError(
                "nn_module is a required attribute for argus.Model")
        params['nn_module'] = nn_module

    if optimizer is not default:
        params['optimizer'] = optimizer

    if loss is not default:
        params['loss'] = loss

    if prediction_transform is not default:
        params['prediction_transform'] = prediction_transform

    for attribute, attribute_params in kwargs.items():
        params[attribute] = attribute_params

    model_class = MODEL_REGISTRY[str_model_name]
    params = change_params_func(params)
    model = model_class(params)

    # Move the saved tensors to the device of the freshly built model before
    # handing them to the user hook and loading them.
    nn_state_dict = deep_to(state['nn_state_dict'], model.device)
    optimizer_state_dict = None
    if 'optimizer_state_dict' in state:
        optimizer_state_dict = deep_to(
            state['optimizer_state_dict'], model.device)

    nn_state_dict, optimizer_state_dict = change_state_dict_func(
        nn_state_dict, optimizer_state_dict
    )

    model.get_nn_module().load_state_dict(nn_state_dict)
    # The optimizer state is restored only when the model has an optimizer
    # and a state dict for it is available.
    if model.optimizer is not None and optimizer_state_dict is not None:
        model.optimizer.load_state_dict(optimizer_state_dict)

    model.eval()
    return model