import torch
import collections
from functools import partial
from tempfile import TemporaryFile
from typing import List, Union, Type, Set, Any
from argus import types
__all__ = ["deep_to", "deep_detach", "deep_chunk", "AverageMeter"]
class Default:
    """Sentinel class marking an unspecified argument value.

    Instances render as ``default`` so they read naturally in signatures
    and logs.
    """

    def __repr__(self) -> str:
        return "default"
class Identity:
    """Callable that returns its argument unchanged (a no-op transform)."""

    def __call__(self, x: types.TVar) -> types.TVar:
        return x

    def __repr__(self) -> str:
        return "Identity()"
# Module-level singletons: `default` marks an unspecified argument value,
# `identity` is a no-op transform — presumably shared across the package
# so identity checks (`x is default`) work; confirm against callers.
default = Default()
identity = Identity()
def deep_to(input: Any, *args, **kwargs) -> Any:
    """Recursively performs dtype and/or device conversion for
    tensors and nn modules.

    Args:
        input: Any input with tensors, tuples, lists, dicts, and other objects.
        *args: args arguments to :meth:`torch.Tensor.to`.
        **kwargs: kwargs arguments to :meth:`torch.Tensor.to`.

    Returns:
        Any: Output with converted tensors.

    Example:

        ::

            >>> x = [torch.ones(4, 2, device='cuda:1'),
            ...      {'target': torch.zeros(4, dtype=torch.uint8)}]
            >>> deep_to(x, 'cuda:0', dtype=torch.float16)
            [tensor([[1., 1.],
                     [1., 1.],
                     [1., 1.],
                     [1., 1.]], device='cuda:0', dtype=torch.float16),
             {'target': tensor([0., 0., 0., 0.], device='cuda:0', dtype=torch.float16)}]

    """
    if torch.is_tensor(input):
        return input.to(*args, **kwargs)
    elif isinstance(input, str):
        # str is a Sequence; return it untouched before the Sequence branch
        # would iterate it character by character.
        return input
    elif isinstance(input, collections.abc.Sequence):
        # NOTE: any sequence (including tuples) comes back as a list.
        return [deep_to(sample, *args, **kwargs) for sample in input]
    elif isinstance(input, collections.abc.Mapping):
        return {k: deep_to(sample, *args, **kwargs) for k, sample in input.items()}
    elif isinstance(input, torch.nn.Module):
        return input.to(*args, **kwargs)
    else:
        # Non-tensor leaves (ints, None, custom objects, ...) pass through.
        return input
def deep_detach(input: Any) -> Any:
    """Returns new tensors, detached from the current graph without gradient
    requirement. Recursively performs :meth:`torch.Tensor.detach`.

    Args:
        input: Any input with tensors, tuples, lists, dicts, and other objects.

    Returns:
        Any: Output with detached tensors.

    Example:

        ::

            >>> x = [torch.ones(4, 2),
            ...      {'target': torch.zeros(4, requires_grad=True)}]
            >>> deep_detach(x)
            [tensor([[1., 1.],
                     [1., 1.],
                     [1., 1.],
                     [1., 1.]]),
             {'target': tensor([0., 0., 0., 0.])}]

    """
    if torch.is_tensor(input):
        return input.detach()
    elif isinstance(input, str):
        # str is a Sequence; short-circuit before the Sequence branch.
        return input
    elif isinstance(input, collections.abc.Sequence):
        # NOTE: any sequence (including tuples) comes back as a list.
        return [deep_detach(sample) for sample in input]
    elif isinstance(input, collections.abc.Mapping):
        return {k: deep_detach(sample) for k, sample in input.items()}
    else:
        # Non-tensor leaves pass through unchanged.
        return input
def deep_chunk(input: Any, chunks: int, dim: int = 0) -> List[Any]:
    """Slice tensors into approximately equal chunks. Duplicates references to
    objects that are not tensors. Recursively performs :func:`torch.chunk`.

    Args:
        input: Any input with tensors, tuples, lists, dicts, and other objects.
        chunks (int): Number of chunks to return.
        dim (int): Dimension along which to split the tensors. Defaults to 0.

    Returns:
        list of Any: List length `chunks` with sliced tensors.

    Example:

        ::

            >>> x = [torch.ones(4, 2),
            ...      {'target': torch.zeros(4), 'weights': torch.ones(4)}]
            >>> deep_chunk(x, 2, 0)
            [[tensor([[1., 1.],
                      [1., 1.]]),
              {'target': tensor([0., 0.]), 'weights': tensor([1., 1.])}],
             [tensor([[1., 1.],
                      [1., 1.]]),
              {'target': tensor([0., 0.]), 'weights': tensor([1., 1.])}]]

    """
    partial_deep_chunk = partial(deep_chunk, chunks=chunks, dim=dim)
    if torch.is_tensor(input):
        # NOTE: torch.chunk may return fewer than `chunks` pieces when the
        # tensor is smaller than `chunks` along `dim`.
        return torch.chunk(input, chunks, dim=dim)
    if isinstance(input, str):
        # Strings are duplicated by reference into every chunk.
        return [input for _ in range(chunks)]
    if isinstance(input, collections.abc.Sequence) and len(input) > 0:
        # Chunk each element, then transpose: element-major -> chunk-major.
        return list(map(list, zip(*map(partial_deep_chunk, input))))
    if isinstance(input, collections.abc.Mapping) and len(input) > 0:
        # Each (key, value) item is chunked as a 2-sequence ([key]*chunks
        # zipped with value chunks); transposing yields per-chunk item
        # tuples, which the mapping's own type rebuilds into a mapping.
        return list(map(type(input), zip(*map(partial_deep_chunk, input.items()))))
    else:
        # Empty containers and other objects are duplicated by reference.
        return [input for _ in range(chunks)]
def device_to_str(device: types.Devices) -> Union[str, List[str]]:
    """Convert a device (or a list/tuple of devices) to its string form."""
    if not isinstance(device, (list, tuple)):
        return str(device)
    return [str(item) for item in device]
def inheritors(cls: Type[types.TVar]) -> Set[Type[types.TVar]]:
    """Collect every transitive subclass of ``cls`` via a worklist walk."""
    found: Set[type] = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        for subclass in current.__subclasses__():
            if subclass in found:
                continue
            found.add(subclass)
            pending.append(subclass)
    return found
def check_pickleble(obj):
    """Raise if ``obj`` cannot be serialized with :func:`torch.save`.

    The object is written to a throwaway temporary file, which is
    discarded as soon as the ``with`` block exits.
    """
    with TemporaryFile() as tmp:
        torch.save(obj, tmp)
def get_device_indices(devices: List[torch.device]) -> List[int]:
    """Extract the CUDA index of each device, validating the whole list.

    Raises:
        ValueError: if any device is not a CUDA device, lacks an explicit
            index, or if an index appears more than once.
    """
    indices: List[int] = []
    for device in devices:
        if device.type != 'cuda':
            raise ValueError("Non CUDA device in list of devices")
        if device.index is None:
            raise ValueError("CUDA device without index in list of devices")
        indices.append(device.index)
    if len(set(indices)) != len(indices):
        raise ValueError("CUDA device indices must be unique")
    return indices
class AverageMeter:
    """Computes and stores the running average by Welford's algorithm.

    The mean is updated incrementally, so individual values never need
    to be stored.
    """

    def __init__(self):
        # Running mean of all samples seen so far.
        self.average = 0
        # Total number of samples accumulated.
        self.count: int = 0

    def reset(self):
        """Clear the accumulated statistics."""
        self.average = 0
        self.count = 0

    def update(self, value, n: int = 1):
        """Fold ``value``, observed ``n`` times, into the running average.

        Args:
            value: The sample value (the mean over the ``n`` samples).
            n (int): Number of samples ``value`` represents. Defaults to 1.
        """
        self.count += n
        # Weighted Welford update. Without the ``* n`` factor a batch of
        # ``n`` samples would only carry the weight of a single sample,
        # skewing the mean whenever n > 1.
        self.average += (value - self.average) * n / self.count