# Source code for abracudabra.conversion.ctensor
"""Convert to Torch tensor."""
from __future__ import annotations
from typing import TYPE_CHECKING
from .._import import get_library_name, raise_library_not_found
from .._validate import Library, validate_obj_type
from ..device.base import Device
from ..device.conversion import tensor_to_device
from ._cdtype import from_numpy_to_torch_dtype, get_frame_result_dtype
if TYPE_CHECKING:
    from torch import Tensor
    from .._annotations import Array, DataFrame, Series
def _to_tensor(
    sequence: Array | Series | DataFrame | Tensor, /, *, strict: bool = False
) -> Tensor:
    """Convert an array, series, or dataframe to a Torch tensor.

    The device of the tensor is determined by the device of the input: no
    explicit device move is performed here.

    Args:
        sequence: The object to convert. Supported types are Torch tensors,
            NumPy/CuPy arrays and Pandas/cuDF series or dataframes.
        strict: If True, raise for unsupported types instead of falling back
            to ``torch.as_tensor``.

    Returns:
        A Torch tensor. A Torch tensor input is returned as-is (same object),
        and NumPy inputs go through ``torch.from_numpy``, so the result may
        share memory with the input.

    Raises:
        TypeError: If ``strict`` is True and the input type is unsupported.
    """
    try:
        import torch
    except ImportError:  # pragma: no cover
        # presumably raises an informative error about the missing library
        raise_library_not_found("torch")

    from torch.utils.dlpack import from_dlpack

    library = get_library_name(sequence)

    # Dispatch on the library that owns the object; ``validate_obj_type``
    # double-checks the object really is one of that library's types.
    if library == "torch" and validate_obj_type(sequence, Library.torch):
        return sequence
    elif library == "numpy" and validate_obj_type(sequence, Library.numpy):
        return torch.from_numpy(sequence)
    elif library == "cupy" and validate_obj_type(sequence, Library.cupy):
        tensor = torch.as_tensor(sequence)
        if tensor.numel() == 0:
            # Work around a dtype bug: empty CuPy arrays do not keep their
            # dtype through ``torch.as_tensor``, so cast explicitly.
            tensor = tensor.to(dtype=from_numpy_to_torch_dtype(sequence.dtype))
        return tensor
    elif library == "pandas" and validate_obj_type(sequence, Library.pandas):
        return torch.from_numpy(sequence.to_numpy())
    elif library == "cudf" and validate_obj_type(sequence, Library.cudf):
        # Resolve the target dtype once, before the DLPack export.
        dtype = get_frame_result_dtype(sequence)
        tensor = from_dlpack(sequence.to_dlpack())
        # Ensure correct dtype for whatever cudf series/dataframe
        return tensor.to(dtype=from_numpy_to_torch_dtype(dtype))

    if strict:
        msg = (
            "Expected a NumPy/CuPy array, Pandas/cuDF series or dataframe, "
            f"or Torch tensor, but got '{type(sequence)!r}'."
        )
        raise TypeError(msg)

    # Unsupported type in non-strict mode: hope for the best
    return torch.as_tensor(sequence)
# [docs]
def to_tensor(
    sequence: Array | Series | DataFrame | Tensor,
    /,
    device: Device | str | None = None,
    *,
    strict: bool = False,
) -> Tensor:
    """Convert an array, series, or dataframe to a Torch tensor.

    Args:
        sequence: The sequence to convert.
        device: The device to convert the sequence to. If None, the tensor stays
            on the same device as the input.
        strict: Whether to raise an error if the sequence is not a valid type.
            A NumPy/CuPy array, Pandas/cuDF series or dataframe, or Torch tensor
            are valid types.
            If False, the sequence is converted to a Torch tensor if possible,
            but it might raise an error if the conversion is not possible.

    Returns:
        A Torch tensor. When ``device`` is None the result may share memory
        with the input (e.g. a Torch tensor input is returned as-is).

    Raises:
        TypeError: If the sequence is not a valid type and ``strict`` is True.

    Examples:
        Build a Torch tensor from a sequence

        >>> import torch
        >>> to_tensor([1, 2, 3])
        tensor([1, 2, 3])

        Build a Torch tensor from a CuPy array

        >>> import cupy as cp
        >>> cupy_array = cp.array([4, 5, 6])
        >>> torch_tensor = to_tensor(cupy_array)
        >>> print(torch_tensor.device)
        cuda:0
        >>> torch_tensor
        tensor([4, 5, 6], device='cuda:0')
    """
    tensor = _to_tensor(sequence, strict=strict)
    if device is not None:
        # Accept both Device objects and strings such as "cuda:0".
        tensor = tensor_to_device(tensor, Device.parse(device))
    return tensor