# Source code for abracudabra.device.base
"""Define the base device class."""
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Literal,
    NamedTuple,
    NoReturn,
    TypeGuard,
    get_args,
)

if TYPE_CHECKING:
    from torch import device as torch_device
# The closed set of device-type names. ``DEVICE_TYPES`` below is derived from
# this alias, so adding a new type only requires editing this one line.
DeviceType = Literal["cpu", "cuda"]
"""The device type, e.g., ``"cpu"`` or ``"cuda"``."""

DEVICE_TYPES: frozenset[DeviceType] = frozenset(get_args(DeviceType))
"""The supported device types (the members of :data:`DeviceType`)."""
def _is_valid_device_type(device_type: str, /) -> TypeGuard[DeviceType]:
    """Tell whether ``device_type`` is one of the supported device types.

    Being a ``TypeGuard``, a ``True`` result narrows ``device_type`` to
    :data:`DeviceType` for static type checkers.

    Args:
        device_type: The candidate device type string.

    Returns:
        ``True`` if ``device_type`` is a member of ``DEVICE_TYPES``.
    """
    is_supported = device_type in DEVICE_TYPES
    return is_supported
def _raise_invalid_device_type(device_type: str, /) -> NoReturn:
    """Raise an error for an unsupported device type.

    Args:
        device_type: The rejected device type.

    Raises:
        ValueError: Always. The message lists the supported device types.
    """
    # Sort before joining: the iteration order of a frozenset of strings
    # depends on per-process hash randomization, which would otherwise make
    # the error message differ from one run to the next.
    supported = ", ".join(repr(name) for name in sorted(DEVICE_TYPES))
    msg = (
        f"Unsupported device type: {device_type!r}. Supported types are: "
        + supported
    )
    raise ValueError(msg)
class Device(NamedTuple):
    """A device with a type and an optional index.

    Being a named tuple, a device unpacks as ``(type, idx)``.
    """

    type: DeviceType
    """The device type, e.g., ``"cpu"`` or ``"cuda"``."""

    idx: int | None = None
    """The device index, e.g., ``0`` or ``None``."""

    def __str__(self) -> str:
        """Return the device name, e.g. ``"cpu"`` or ``"cuda:1"``."""
        type_ = self.type
        return f"{type_}:{idx}" if (idx := self.idx) is not None else type_

    @staticmethod
    def _validate_type(device_type: object, /) -> DeviceType:
        """Validate a device type.

        The value is converted with ``str`` before it is checked.

        Raises:
            ValueError: If the type is not a supported device type.
        """
        device_type = str(device_type)
        if not _is_valid_device_type(device_type):
            _raise_invalid_device_type(device_type)
        return device_type

    @staticmethod
    def _validate_idx(idx: object | None, /) -> int | None:
        """Validate a device index.

        ``None`` is passed through unchanged. Any other value is converted
        with ``int``, so numeric strings such as ``"1"`` are accepted.

        Raises:
            TypeError: If the index cannot be converted to an integer.
        """
        if idx is None:
            return None
        # ``int`` raises ValueError for malformed strings (e.g. ``"a"``) and
        # TypeError for unsupported objects (e.g. a list). Both get reported
        # with the same message and exception type.
        try:
            return int(idx)  # type: ignore[call-overload]
        except (TypeError, ValueError) as e:
            msg = (
                "Expected an integer index or None, but got "
                f"{idx!r} of type {type(idx).__name__}."
            )
            raise TypeError(msg) from e

    @classmethod
    def validate(cls, device: object, idx: object | None = None) -> Device:
        """Return a device, validating the device type and index.

        Args:
            device: The device type.
            idx: The optional device index.

        Returns:
            The device.

        Raises:
            ValueError: If the device type is not supported.
            TypeError: If the index cannot be converted to an integer.
        """
        device = cls._validate_type(device)
        idx = cls._validate_idx(idx)
        return cls(device, idx)

    @classmethod
    def from_str(cls, device: str, /) -> Device:
        """Return a device from a string.

        The string should be in the format ``"device[:idx]"``. Only the
        first ``":"`` separates the type from the index.

        Examples:
            >>> Device.from_str("cpu")
            Device(type='cpu', idx=None)
            >>> Device.from_str("cuda:1")
            Device(type='cuda', idx=1)
        """
        if ":" in device:
            name, idx = device.split(":", 1)
            return cls.validate(name, idx)
        return cls.validate(device)

    @classmethod
    def parse(cls, device: str | Device | torch_device, /) -> Device:
        """Return a device from a string or device.

        If the input is already a device, it is returned as is.
        Otherwise, the input is converted with ``str`` and then parsed.

        Args:
            device: The device or device string (e.g., ``"cpu"`` or ``"cuda:1"``).

        Returns:
            The device.
        """
        if isinstance(device, cls):
            return device
        # This works with strings and torch.device objects, since both
        # stringify to the ``"type[:index]"`` format.
        return cls.from_str(str(device))

    def to_torch(self) -> torch_device:
        """Return a torch device.

        ``torch`` is imported lazily, so it is only required when this
        method is called.

        Examples:
            >>> Device("cpu", None).to_torch()
            device(type='cpu')
            >>> Device("cuda", 1).to_torch()
            device(type='cuda', index=1)
        """
        from torch import device as torch_device

        return torch_device(*self)  # type: ignore[arg-type]