from __future__ import annotations
|
|
from datetime import (
|
date,
|
datetime,
|
time,
|
timedelta,
|
)
|
from decimal import Decimal
|
import re
|
|
import numpy as np
|
|
from pandas._libs.tslibs import (
|
Timedelta,
|
Timestamp,
|
)
|
from pandas._typing import (
|
TYPE_CHECKING,
|
DtypeObj,
|
type_t,
|
)
|
from pandas.compat import pa_version_under7p0
|
from pandas.util._decorators import cache_readonly
|
|
from pandas.core.dtypes.base import (
|
StorageExtensionDtype,
|
register_extension_dtype,
|
)
|
from pandas.core.dtypes.dtypes import CategoricalDtypeType
|
|
if not pa_version_under7p0:
|
import pyarrow as pa
|
|
if TYPE_CHECKING:
|
from pandas.core.arrays.arrow import ArrowExtensionArray
|
|
|
@register_extension_dtype
class ArrowDtype(StorageExtensionDtype):
    """
    An ExtensionDtype for PyArrow data types.

    .. warning::

       ArrowDtype is considered experimental. The implementation and
       parts of the API may change without warning.

    While most ``dtype`` arguments can accept the "string"
    constructor, e.g. ``"int64[pyarrow]"``, ArrowDtype is useful
    if the data type contains parameters like ``pyarrow.timestamp``.

    Parameters
    ----------
    pyarrow_dtype : pa.DataType
        An instance of a `pyarrow.DataType <https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions>`__.

    Attributes
    ----------
    pyarrow_dtype

    Methods
    -------
    None

    Returns
    -------
    ArrowDtype

    Examples
    --------
    >>> import pyarrow as pa
    >>> pd.ArrowDtype(pa.int64())
    int64[pyarrow]

    Types with parameters must be constructed with ArrowDtype.

    >>> pd.ArrowDtype(pa.timestamp("s", tz="America/New_York"))
    timestamp[s, tz=America/New_York][pyarrow]
    >>> pd.ArrowDtype(pa.list_(pa.int64()))
    list<item: int64>[pyarrow]
    """  # noqa: E501

    _metadata = ("storage", "pyarrow_dtype")  # type: ignore[assignment]

    def __init__(self, pyarrow_dtype: pa.DataType) -> None:
        """
        Wrap ``pyarrow_dtype`` in an ArrowDtype with "pyarrow" storage.

        Raises
        ------
        ImportError
            If ``pa_version_under7p0`` is true.
        ValueError
            If ``pyarrow_dtype`` is not a ``pyarrow.DataType`` instance.
        """
        super().__init__("pyarrow")
        # Check the version flag first: ``pa`` is only imported at module
        # level when this flag is false.
        if pa_version_under7p0:
            raise ImportError("pyarrow>=7.0.0 is required for ArrowDtype")
        if not isinstance(pyarrow_dtype, pa.DataType):
            raise ValueError(
                f"pyarrow_dtype ({pyarrow_dtype}) must be an instance "
                f"of a pyarrow.DataType. Got {type(pyarrow_dtype)} instead."
            )
        self.pyarrow_dtype = pyarrow_dtype

    def __repr__(self) -> str:
        """Return the dtype name, e.g. ``int64[pyarrow]``."""
        return self.name

    @property
    def type(self):
        """
        Returns associated scalar type.

        The Python/pandas scalar class is chosen from the category of the
        wrapped pyarrow type. Nanosecond durations/timestamps map to
        ``Timedelta``/``Timestamp``; other units map to the stdlib
        ``timedelta``/``datetime``.

        Raises
        ------
        NotImplementedError
            If the pyarrow type falls in none of the handled categories.
        """
        pa_type = self.pyarrow_dtype
        if pa.types.is_integer(pa_type):
            return int
        elif pa.types.is_floating(pa_type):
            return float
        elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
            return str
        elif (
            pa.types.is_binary(pa_type)
            or pa.types.is_fixed_size_binary(pa_type)
            or pa.types.is_large_binary(pa_type)
        ):
            return bytes
        elif pa.types.is_boolean(pa_type):
            return bool
        elif pa.types.is_duration(pa_type):
            if pa_type.unit == "ns":
                return Timedelta
            else:
                return timedelta
        elif pa.types.is_timestamp(pa_type):
            if pa_type.unit == "ns":
                return Timestamp
            else:
                return datetime
        elif pa.types.is_date(pa_type):
            return date
        elif pa.types.is_time(pa_type):
            return time
        elif pa.types.is_decimal(pa_type):
            return Decimal
        elif pa.types.is_dictionary(pa_type):
            # TODO: Potentially change this & CategoricalDtype.type to
            #  something more representative of the scalar
            return CategoricalDtypeType
        elif pa.types.is_list(pa_type) or pa.types.is_large_list(pa_type):
            return list
        elif pa.types.is_map(pa_type):
            return dict
        elif pa.types.is_null(pa_type):
            # TODO: None? pd.NA? pa.null?
            return type(pa_type)
        else:
            raise NotImplementedError(pa_type)

    @property
    def name(self) -> str:  # type: ignore[override]
        """
        A string identifying the data type.

        Formatted as ``"<pyarrow type>[<storage>]"``.
        """
        return f"{self.pyarrow_dtype!s}[{self.storage}]"

    @cache_readonly
    def numpy_dtype(self) -> np.dtype:
        """
        Return an instance of the related numpy dtype.

        Falls back to ``np.dtype(object)`` when pyarrow's
        ``to_pandas_dtype`` raises ``NotImplementedError`` or the result
        cannot be turned into a numpy dtype (``TypeError``).
        """
        if pa.types.is_string(self.pyarrow_dtype):
            # pa.string().to_pandas_dtype() = object which we don't want
            return np.dtype(str)
        try:
            return np.dtype(self.pyarrow_dtype.to_pandas_dtype())
        except (NotImplementedError, TypeError):
            return np.dtype(object)

    @cache_readonly
    def kind(self) -> str:
        """Return the kind code: "M" for timestamps, else the numpy dtype's kind."""
        if pa.types.is_timestamp(self.pyarrow_dtype):
            # To mirror DatetimeTZDtype
            return "M"
        return self.numpy_dtype.kind

    @cache_readonly
    def itemsize(self) -> int:
        """Return the number of bytes in this dtype"""
        return self.numpy_dtype.itemsize

    @classmethod
    def construct_array_type(cls) -> type_t[ArrowExtensionArray]:
        """
        Return the array type associated with this dtype.

        Returns
        -------
        type
        """
        # Imported inside the method rather than at module level; the
        # module only imports this name under TYPE_CHECKING.
        from pandas.core.arrays.arrow import ArrowExtensionArray

        return ArrowExtensionArray

    @classmethod
    def construct_from_string(cls, string: str) -> ArrowDtype:
        """
        Construct this type from a string.

        Parameters
        ----------
        string : str
            string should follow the format f"{pyarrow_type}[pyarrow]"
            e.g. int64[pyarrow]

        Returns
        -------
        ArrowDtype

        Raises
        ------
        TypeError
            If ``string`` is not a str, does not end with ``[pyarrow]``,
            is exactly ``"string[pyarrow]"``, or names an unknown pyarrow
            type alias.
        NotImplementedError
            If the type portion carries parameters that cannot be parsed.
        """
        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )
        suffix = "[pyarrow]"
        if not string.endswith(suffix):
            raise TypeError(f"'{string}' must end with '[pyarrow]'")
        if string == "string[pyarrow]":
            # Ensure Registry.find skips ArrowDtype to use StringDtype instead
            raise TypeError("string[pyarrow] should be constructed by StringDtype")

        base_type = string[: -len(suffix)]  # get rid of "[pyarrow]"
        try:
            pa_dtype = pa.type_for_alias(base_type)
        except ValueError as err:
            has_parameters = re.search(r"[\[\(].*[\]\)]", base_type)
            if has_parameters:
                # Fallback to try common temporal types
                try:
                    return cls._parse_temporal_dtype_string(base_type)
                except (NotImplementedError, ValueError):
                    # Fall through to raise with nice exception message below
                    pass

                raise NotImplementedError(
                    "Passing pyarrow type specific parameters "
                    f"({has_parameters.group()}) in the string is not supported. "
                    "Please construct an ArrowDtype object with a pyarrow_dtype "
                    "instance with specific parameters."
                ) from err
            raise TypeError(f"'{base_type}' is not a valid pyarrow data type.") from err
        return cls(pa_dtype)

    # TODO(arrow#33642): This can be removed once supported by pyarrow
    @classmethod
    def _parse_temporal_dtype_string(cls, string: str) -> ArrowDtype:
        """
        Construct a temporal ArrowDtype from string.

        Only ``"timestamp[<unit>, <tz>]"`` (``<tz>`` optionally prefixed
        with ``tz=``) is handled.

        Parameters
        ----------
        string : str
            The type portion of a dtype string, with the trailing
            ``"[pyarrow]"`` already stripped by the caller.

        Raises
        ------
        ValueError
            If ``string`` is not of the form ``head[tail]``, or a
            timestamp specification has no ``,`` separating unit and tz.
        NotImplementedError
            If ``head`` is anything other than ``"timestamp"``.
        """
        head, bracket, tail = string.partition("[")
        if not bracket or not tail.endswith("]"):
            raise ValueError(f"'{string}' is not of the form 'type[parameters]'")
        tail = tail[:-1]

        if head == "timestamp":
            if "," not in tail:
                # construct_from_string only calls this after
                # pa.type_for_alias rejected the string, so a lone unit here
                # is presumably invalid. Raise ValueError (not assert) so the
                # caller's except clause turns it into its own error message.
                raise ValueError(
                    f"'{string}' must specify both a unit and a timezone"
                )
            unit, tz = tail.split(",", 1)
            unit = unit.strip()
            tz = tz.strip()
            if tz.startswith("tz="):
                tz = tz[3:]

            pa_type = pa.timestamp(unit, tz=tz)
            return cls(pa_type)

        raise NotImplementedError(string)

    @property
    def _is_numeric(self) -> bool:
        """
        Whether columns with this dtype should be considered numeric.

        True for integer, floating and decimal pyarrow types.
        """
        # TODO: pa.types.is_boolean?
        return (
            pa.types.is_integer(self.pyarrow_dtype)
            or pa.types.is_floating(self.pyarrow_dtype)
            or pa.types.is_decimal(self.pyarrow_dtype)
        )

    @property
    def _is_boolean(self) -> bool:
        """
        Whether this dtype should be considered boolean.
        """
        return pa.types.is_boolean(self.pyarrow_dtype)

    def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
        """
        Return a common ArrowDtype for ``dtypes``, or None if there is none.

        None is returned when the common type is not a numpy dtype or
        pyarrow cannot convert that numpy dtype.
        """
        # We unwrap any ArrowDtype to its numpy dtype, find the common dtype
        # we would use for those, then re-wrap the result.
        # Mirrors BaseMaskedDtype
        from pandas.core.dtypes.cast import find_common_type

        new_dtype = find_common_type(
            [
                dtype.numpy_dtype if isinstance(dtype, ArrowDtype) else dtype
                for dtype in dtypes
            ]
        )
        if not isinstance(new_dtype, np.dtype):
            return None
        try:
            pa_dtype = pa.from_numpy_dtype(new_dtype)
            return type(self)(pa_dtype)
        except NotImplementedError:
            return None

    def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
        """
        Construct this dtype's extension array from pyarrow Array/ChunkedArray.

        The input is cast to ``self.pyarrow_dtype`` with ``safe=True`` and
        wrapped in the class returned by ``construct_array_type``.
        """
        array_class = self.construct_array_type()
        arr = array.cast(self.pyarrow_dtype, safe=True)
        return array_class(arr)