# Source code for abracudabra.device.base
"""Define the base device class."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, NamedTuple, NoReturn, TypeGuard
if TYPE_CHECKING:
from torch import device as torch_device
[docs]
# Closed set of device type names, spelled as a ``Literal`` so that static type
# checkers can narrow a plain ``str`` to it (see the ``TypeGuard`` return of
# ``is_valid_device_type``).
DeviceType = Literal["cpu", "cuda"]
"""The device type, e.g., ``"cpu"`` or ``"cuda"``."""
[docs]
# Runtime counterpart of the ``DeviceType`` literal: the same names held in an
# immutable set so membership checks are O(1).
DEVICE_TYPES: frozenset[DeviceType] = frozenset({"cpu", "cuda"})
"""The device types, e.g., ``{"cpu", "cuda"}``."""
[docs]
def is_valid_device_type(device_type: str, /) -> TypeGuard[DeviceType]:
    """Tell whether ``device_type`` names a supported device type.

    Args:
        device_type: Candidate device type name (positional-only).

    Returns:
        ``True`` if the name is a member of ``DEVICE_TYPES``, ``False``
        otherwise. Because the return type is a ``TypeGuard``, a ``True``
        result lets type checkers treat the argument as a ``DeviceType``.
    """
    supported = DEVICE_TYPES
    return device_type in supported
[docs]
def _raise_invalid_device_type(device_type: str, /) -> NoReturn:
    """Raise an error for an invalid device type.

    Args:
        device_type: The rejected device type name (positional-only); it is
            echoed in the error message via ``repr``.

    Raises:
        ValueError: Always. The message names the offending value and lists
            the supported device types.
    """
    # ``DEVICE_TYPES`` is a frozenset of strings, whose iteration order can
    # differ between interpreter runs (string hash randomization). Sort so the
    # error message is reproducible.
    supported = ", ".join(map(repr, sorted(DEVICE_TYPES)))
    msg = (
        f"Unsupported device type: {device_type!r}. Supported types are: "
        + supported
    )
    raise ValueError(msg)
[docs]
class Device(NamedTuple):
    """A device with a name and index.

    Instances are plain named tuples ``(type, idx)``. Constructing one directly
    performs no validation; use :meth:`validate`, :meth:`from_str` or
    :meth:`parse` to obtain a checked instance.
    """

    type: DeviceType
    """The device type, e.g., ``"cpu"`` or ``"cuda"``."""

    idx: int | None = None
    """The device index, e.g., ``0`` or ``None``."""

    def __str__(self) -> str:
        """Return the device name.

        Returns:
            ``"<type>:<idx>"`` when an index is set, otherwise just
            ``"<type>"``.
        """
        if self.idx is None:
            return self.type
        return f"{self.type}:{self.idx}"

    @staticmethod
    def _validate_device(device: object, /) -> DeviceType:
        """Validate a device name.

        Args:
            device: Any object; it is converted with ``str`` before checking.

        Returns:
            The device type name as a string.

        Raises:
            ValueError: If the string form is not a supported device type.
        """
        device = str(device)
        if not is_valid_device_type(device):
            _raise_invalid_device_type(device)
        return device

    @staticmethod
    def _validate_idx(idx: object | None, /) -> int | None:
        """Validate a device index.

        Args:
            idx: ``None`` or any object accepted by ``int`` (for example an
                integer or a numeric string).

        Returns:
            ``None`` if ``idx`` is ``None``, otherwise ``int(idx)``.

        Raises:
            TypeError: If ``idx`` cannot be converted to an integer.
        """
        if idx is None:
            return None

        try:
            return int(idx)  # type: ignore[call-overload]
        except (TypeError, ValueError) as e:
            # ``int`` raises ValueError for unparsable strings but TypeError
            # for unsupported types (e.g. a list); report both identically.
            msg = (
                "Expected an integer index or None, but got "
                f"{idx!r} of type {type(idx).__name__}."
            )
            raise TypeError(msg) from e

    @classmethod
    def validate(cls, device: object, idx: object | None = None) -> Device:
        """Return a device, validating the device name.

        Args:
            device: Device type; checked by :meth:`_validate_device`.
            idx: Optional device index; checked by :meth:`_validate_idx`.

        Returns:
            A new instance built from the validated type and index.

        Raises:
            ValueError: If the device type is not supported.
            TypeError: If the index cannot be converted to an integer.
        """
        device = cls._validate_device(device)
        idx = cls._validate_idx(idx)
        return cls(device, idx)

    @classmethod
    def from_str(cls, device: str, /) -> Device:
        """Return a device from a string.

        Args:
            device: Either ``"<type>"`` or ``"<type>:<idx>"``. Only the first
                ``":"`` separates the type from the index.

        Returns:
            The validated device.

        Raises:
            ValueError: If the device type is not supported.
            TypeError: If the index part is not an integer (this includes an
                empty index such as ``"cuda:"``).
        """
        name, sep, idx = device.partition(":")
        # Without a separator there is no index at all; with one, even an
        # empty index string is handed to validation so that it is rejected.
        return cls.validate(name, idx if sep else None)

    @classmethod
    def parse(cls, device: str | Device, /) -> Device:
        """Return a device from a string or device.

        Args:
            device: An existing instance, or a value whose string form is
                parsed with :meth:`from_str`.

        Returns:
            ``device`` itself, unvalidated, if it is already an instance of
            this class; otherwise the parsed and validated device.
        """
        if isinstance(device, cls):
            return device
        return cls.from_str(str(device))

    def to_torch(self) -> torch_device:
        """Return a torch device.

        Returns:
            A ``torch.device`` constructed from this tuple's ``(type, idx)``.
        """
        # Imported lazily so this module can be imported without torch.
        from torch import device as torch_device

        return torch_device(*self)  # type: ignore[arg-type]