torch_utils

swem.utils.torch_utils.to_device

Transfer a whole container of tensors to a device.

Some additional utilities.

swem.utils.torch_utils.to_device(tensors: Any, device: str | torch.device, non_tensors: str = 'error') → Any

Transfer a whole container of tensors to a device.

The function takes an arbitrarily nested structure of tuples, lists, and dicts whose leaves are tensors (or other data types, see the argument ‘non_tensors’). It returns the same nested structure with every tensor transferred to the given device.

Parameters
  • tensors (Any) – The (possibly nested) container of tensors to be transferred.

  • device (str | torch.device) – The target device.

  • non_tensors (str, optional) – How to handle a non-tensor value, i.e. anything other than a list, tuple, dict, or tensor. If ‘error’, a ValueError is raised; if ‘ignore’, the value is returned as is; if ‘drop’, the value is omitted from the output. Note that ‘drop’ may subtly change the nested structure of the output, since lists and tuples may be shorter than in the input and dicts may be missing keys. Defaults to ‘error’.
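The structural caveat of ‘drop’ is easiest to see on a single level of nesting. The snippet below mimics that option by filtering containers by hand; it does not call to_device, and the data are made up for illustration.

```python
import torch

x = torch.tensor([1.0])
y = torch.tensor([2.0])

# Illustration only: one level of the 'drop' policy written out by hand.
# Non-tensor entries are left out when the container is rebuilt.
batch = [x, "label", y]
dropped_batch = [v for v in batch if isinstance(v, torch.Tensor)]
# The list is now shorter: y moved from index 2 to index 1.

record = {"x": x, "name": "sample-0"}
dropped_record = {k: v for k, v in record.items() if isinstance(v, torch.Tensor)}
# The key "name" is gone, so dropped_record["name"] would raise KeyError.
```

Code that relies on fixed positions or keys in the output should therefore prefer ‘error’ or ‘ignore’.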

Raises
  • ValueError – If ‘non_tensors’ is ‘error’ and a value not of type list, tuple, dict, or tensor is encountered.

  • ValueError – If an unsupported option for ‘non_tensors’ is given.
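The traversal described above, including the three ‘non_tensors’ modes and both ValueError cases, can be sketched as a short recursive function. This is a minimal sketch under assumptions, not the swem implementation: the name to_device_sketch is hypothetical, and details such as rebuilding dict subclasses as plain dicts and returning None for a top-level non-tensor in ‘drop’ mode are guesses rather than documented behaviour.

```python
from __future__ import annotations

from typing import Any

import torch


def to_device_sketch(tensors: Any, device: str | torch.device, non_tensors: str = "error") -> Any:
    """Hypothetical re-implementation of the documented to_device behaviour."""
    if non_tensors not in ("error", "ignore", "drop"):
        raise ValueError(f"Unsupported option for non_tensors: {non_tensors!r}")

    # Private sentinel marking values to omit in 'drop' mode; None would be
    # ambiguous because None can be a legitimate value under 'ignore'.
    dropped = object()

    def move(obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
            return obj.to(device)
        if isinstance(obj, (list, tuple)):
            # Rebuild with the same container type; plain lists/tuples only
            # (a namedtuple would need type(obj)(*items) instead).
            items = [move(v) for v in obj]
            return type(obj)(v for v in items if v is not dropped)
        if isinstance(obj, dict):
            # Rebuilt as a plain dict, skipping keys whose values were dropped.
            items = {k: move(v) for k, v in obj.items()}
            return {k: v for k, v in items.items() if v is not dropped}
        # Any other value is a non-tensor leaf.
        if non_tensors == "error":
            raise ValueError(f"Encountered a non-tensor value of type {type(obj).__name__}")
        if non_tensors == "ignore":
            return obj
        return dropped

    result = move(tensors)
    # Assumption: a top-level non-tensor in 'drop' mode yields None.
    return None if result is dropped else result
```

With this sketch, to_device_sketch([x, "a", {"k": 3, "x": x}], "cpu", non_tensors="drop") returns a two-element list whose second entry is a dict holding only the key "x".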

Returns

The same type and nested structure as the input, but with all tensors on the given device. With ‘drop’, non-tensor entries are omitted, so the structure may differ from the input.

Return type

Any

Examples

>>> import torch
>>> from swem.utils.torch_utils import to_device
>>> x = torch.tensor([1.0])
>>> y = torch.tensor([2.0])
>>> to_device([x, (x, y), {"x": x, "y": y}], device="cuda:0")
[tensor([1.], device='cuda:0'),
(tensor([1.], device='cuda:0'), tensor([2.], device='cuda:0')),
{'x': tensor([1.], device='cuda:0'), 'y': tensor([2.], device='cuda:0')}]