"""Move an array, series, or tensor to a device."""
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING, Literal, cast, overload
from .._import import get_library_name, raise_library_not_found
from .._validate import Library, validate_obj_type
from .base import Device, DeviceType, _raise_invalid_device_type
if TYPE_CHECKING:
    import cudf
    import pandas as pd
    from torch import Tensor
    from .._annotations import Array, DataFrame, Index, Series
def to_cupy_array(sequence: object, /, device_idx: int | None = None) -> object:
    """Build a CuPy array from ``sequence``.

    When ``device_idx`` is given, the conversion runs inside a
    ``cupy.cuda.Device(device_idx)`` context. Otherwise no device context is
    entered and CuPy's current device is used.

    Args:
        sequence: The object to hand to :py:func:`cupy.asarray`.
        device_idx: Optional CUDA device index to perform the conversion on.

    Returns:
        The result of :py:func:`cupy.asarray` applied to ``sequence``.

    Note:
        If CuPy cannot be imported, ``raise_library_not_found("cupy")`` is
        called (presumably raising an informative error).
    """
    try:
        import cupy as cp
    except ImportError:  # pragma: no cover
        raise_library_not_found("cupy")

    # No explicit index -> a no-op context, i.e. stay on the current device.
    scope = nullcontext() if device_idx is None else cp.cuda.Device(device_idx)
    with scope:
        return cp.asarray(sequence)
def array_to_device(array: object, /, device: Device | str) -> Array:
    """Move a NumPy/CuPy array to a device.
    Args:
        array: The array to move.
        device: The device to move the array to.
    Returns:
        The array on the specified device.
    """
    device = Device.parse(device)
    library = get_library_name(array)
    if library == "numpy" and validate_obj_type(array, Library.numpy):
        if device.type == "cpu":
            return array
        elif device.type == "cuda":
            return to_cupy_array(array, device.idx)
        else:
            _raise_invalid_device_type(device.type)
    elif library == "cupy" and validate_obj_type(array, Library.cupy):
        match device.type:
            case "cpu":
                return array.get()
            case "cuda":
                return to_cupy_array(array, device.idx)
            case _:
                _raise_invalid_device_type(device.type)
    else:  # guard
        msg = f"Expected a numpy or cupy array, but got '{type(array).__name__}'."
        raise TypeError(msg) 
@overload
[docs]
def frame_to_device(frame: Index, /, device_type: Literal["cpu"]) -> pd.Index: ... 
@overload
def frame_to_device(frame: Index, /, device_type: Literal["cuda"]) -> cudf.Index: ...
@overload
def frame_to_device(frame: Series, /, device_type: Literal["cpu"]) -> pd.Series: ...
@overload
def frame_to_device(frame: Series, /, device_type: Literal["cuda"]) -> cudf.Series: ...
@overload
def frame_to_device(
    frame: DataFrame, /, device_type: Literal["cpu"]
) -> pd.DataFrame: ...
@overload
def frame_to_device(
    frame: DataFrame, /, device_type: Literal["cuda"]
) -> cudf.DataFrame: ...
@overload
def frame_to_device(
    frame: object, /, device_type: DeviceType
) -> Index | Series | DataFrame: ...
def frame_to_device(
    frame: object, /, device_type: DeviceType
) -> Index | Series | DataFrame:
    """Move a Pandas/cuDF index, series or dataframe to a device.
    Args:
        frame: The series or dataframe to move.
        device_type: The device type to move the frame to.
    Returns:
        The series or dataframe on the specified device.
    """
    library = get_library_name(frame)
    if library == "pandas" and validate_obj_type(frame, Library.pandas):
        match device_type:
            case "cpu":
                return frame
            case "cuda":
                try:
                    import cudf
                except ImportError:  # pragma: no cover
                    raise_library_not_found("cudf")
                return cast(
                    cudf.Index | cudf.Series | cudf.DataFrame, cudf.from_pandas(frame)
                )
            case _:
                _raise_invalid_device_type(device_type)
    if library == "cudf" and validate_obj_type(frame, Library.cudf):
        match device_type:
            case "cpu":
                return frame.to_pandas()
            case "cuda":
                return frame
            case _:
                _raise_invalid_device_type(device_type)
    msg = (
        "Expected a pandas or cudf series or dataframe, "
        f"but got '{type(frame).__name__}'."
    )
    raise TypeError(msg)
def tensor_to_device(tensor: object, /, device: Device | str) -> Tensor:
    """Move a Torch tensor to a device.

    Args:
        tensor: The Torch tensor to move.
        device: Target device, either a :py:class:`Device` or a value accepted
            by :py:meth:`Device.parse`. It is converted with
            ``Device.to_torch()`` before being passed to ``tensor.to``.

    Returns:
        The result of ``tensor.to(...)`` for the target device.

    Raises:
        TypeError: If ``tensor`` is not a Torch tensor.
    """
    if get_library_name(tensor) != "torch" or not validate_obj_type(
        tensor, Library.torch
    ):
        # Report the bare type name, matching the other converters' messages
        # (``!r`` of a type produced a confusing "'<class '...'>'" string).
        msg = f"Expected a Torch tensor, but got '{type(tensor).__name__}'."
        raise TypeError(msg)
    return tensor.to(Device.parse(device).to_torch())
@overload
def to_device(sequence: Series, /, device: Device | str) -> Series: ...
@overload
def to_device(sequence: Tensor, /, device: Device | str) -> Tensor: ...
@overload
def to_device(sequence: DataFrame, /, device: Device | str) -> DataFrame: ...
@overload
def to_device(sequence: Index, /, device: Device | str) -> Index: ...
@overload
def to_device(sequence: Array, /, device: Device | str) -> Array: ...
@overload
def to_device(
    sequence: Array | Index | Series | DataFrame | Tensor, /, device: Device | str
) -> Index | Series | DataFrame | Array | Tensor: ...
def to_device(
    sequence: Array | Index | Series | DataFrame | Tensor, /, device: Device | str
) -> Index | Series | DataFrame | Array | Tensor:
    """Move an array, index, series, dataframe or tensor to a device.

    Call the appropriate function to move the element to the device:

    * :py:func:`abracudabra.device.conversion.array_to_device` for NumPy/CuPy arrays.
    * :py:func:`abracudabra.device.conversion.frame_to_device`
      for Pandas/cuDF index/series/dataframes.
    * :py:func:`abracudabra.device.conversion.tensor_to_device` for Torch tensors.

    Args:
        sequence: The sequence to move to the device.
        device: The device to move the sequence to.

    Returns:
        The sequence on the specified device.

    Raises:
        TypeError: If the sequence is not a NumPy/CuPy array, Pandas/cuDF
            index/series/dataframe or Torch tensor.

    Examples:
        Move a Pandas dataframe to the GPU (cuDF dataframe):

        >>> import pandas as pd
        >>> from abracudabra import to_device
        >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
        >>> df_gpu = to_device(df, "cuda")
        >>> print(type(df_gpu))
        <class 'cudf.core.dataframe.DataFrame'>

        Move a cuDF dataframe to the CPU (Pandas dataframe):

        >>> df_cpu = to_device(df_gpu, "cpu")
        >>> print(type(df_cpu))
        <class 'pandas.core.frame.DataFrame'>

        Move a numpy array to the GPU (cupy):

        >>> import numpy as np
        >>> arr = np.array([1, 2, 3])
        >>> arr_gpu = to_device(arr, "cuda")
        >>> print(type(arr_gpu))
        <class 'cupy.ndarray'>
    """
    library = get_library_name(sequence)
    if library in {"numpy", "cupy"}:
        return array_to_device(sequence, device)
    if library in {"pandas", "cudf"}:
        # NOTE(review): only the device *type* is forwarded to frame_to_device,
        # so any device index carried by ``device`` is not applied to
        # pandas/cuDF objects -- confirm this is intended.
        return frame_to_device(sequence, Device.parse(device).type)
    if library == "torch":
        return tensor_to_device(sequence, device)
    msg = (
        "Expected a NumPy/CuPy array, Pandas/cuDF series/dataframe, "
        f"or Torch tensor, but got '{type(sequence).__name__}'."
    )
    raise TypeError(msg)