from __future__ import annotations

import json

import pyarrow

from pandas._typing import IntervalClosedType

from pandas.core.arrays.interval import VALID_CLOSED


class ArrowPeriodType(pyarrow.ExtensionType):
    """
    Arrow extension type registered under the name ``"pandas.period"``.

    The storage type is ``int64``. The frequency is kept on the type and
    round-tripped through the extension metadata as a JSON document of the
    form ``{"freq": ...}``.
    """

    def __init__(self, freq) -> None:
        """
        Parameters
        ----------
        freq
            Frequency stored on the type. It has to be JSON-serializable,
            since ``__arrow_ext_serialize__`` passes it to ``json.dumps``.
        """
        # attributes need to be set first before calling
        # super init (as that calls serialize)
        self._freq = freq
        pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")

    @property
    def freq(self):
        """Frequency this type was constructed with."""
        return self._freq

    def __arrow_ext_serialize__(self) -> bytes:
        """Encode the frequency as UTF-8 JSON bytes for the Arrow metadata."""
        metadata = {"freq": self.freq}
        return json.dumps(metadata).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
        """
        Rebuild the type from the bytes written by ``__arrow_ext_serialize__``.

        ``storage_type`` is not consulted; only the ``"freq"`` entry of the
        decoded JSON metadata is used.
        """
        metadata = json.loads(serialized.decode())
        return ArrowPeriodType(metadata["freq"])

    def __eq__(self, other):
        """
        Compare with another Arrow extension type.

        Two types are equal when they are of exactly the same class and have
        equal ``freq``. ``NotImplemented`` is returned for objects that are
        not ``pyarrow.BaseExtensionType`` instances.
        """
        if isinstance(other, pyarrow.BaseExtensionType):
            return type(self) == type(other) and self.freq == other.freq
        else:
            return NotImplemented

    def __ne__(self, other) -> bool:
        """Negation of ``==``."""
        return not self == other

    def __hash__(self) -> int:
        """Hash built from the type's string form and its frequency."""
        return hash((str(self), self.freq))

    def to_pandas_dtype(self):
        """Return the ``pandas.PeriodDtype`` with the same frequency."""
        # NOTE(review): imported locally, presumably to avoid a circular
        # import at module load time -- confirm before hoisting.
        import pandas as pd

        return pd.PeriodDtype(freq=self.freq)


# register the type with a dummy instance
_period_type = ArrowPeriodType("D")
pyarrow.register_extension_type(_period_type)


class ArrowIntervalType(pyarrow.ExtensionType):
    """
    Arrow extension type registered under the name ``"pandas.interval"``.

    The storage type is a struct with two fields, ``"left"`` and ``"right"``,
    both of the interval's subtype. The subtype (as a string) and the closed
    side are round-tripped through the extension metadata as JSON.
    """

    def __init__(self, subtype, closed: IntervalClosedType) -> None:
        """
        Parameters
        ----------
        subtype
            Either a ``pyarrow.DataType`` or any object whose ``str()`` is an
            alias understood by ``pyarrow.type_for_alias``.
        closed
            Which side(s) of the intervals are closed; must be a member of
            ``VALID_CLOSED``.

        Raises
        ------
        AssertionError
            If ``closed`` is not in ``VALID_CLOSED``.
        """
        # attributes need to be set first before calling
        # super init (as that calls serialize)
        if closed not in VALID_CLOSED:
            # Raised explicitly instead of via ``assert`` so the validation
            # is not stripped when Python runs with -O. AssertionError is
            # kept so the exception type seen by callers does not change.
            raise AssertionError(f"invalid value for closed: {closed!r}")
        self._closed: IntervalClosedType = closed
        if not isinstance(subtype, pyarrow.DataType):
            subtype = pyarrow.type_for_alias(str(subtype))
        self._subtype = subtype

        storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
        pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")

    @property
    def subtype(self):
        """Arrow data type of the interval bounds."""
        return self._subtype

    @property
    def closed(self) -> IntervalClosedType:
        """Closed side(s) this type was constructed with."""
        return self._closed

    def __arrow_ext_serialize__(self) -> bytes:
        """Encode subtype (as a string) and closed as UTF-8 JSON bytes."""
        metadata = {"subtype": str(self.subtype), "closed": self.closed}
        return json.dumps(metadata).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
        """
        Rebuild the type from the bytes written by ``__arrow_ext_serialize__``.

        ``storage_type`` is not consulted; the subtype is recovered by
        resolving the stored string with ``pyarrow.type_for_alias``.
        """
        metadata = json.loads(serialized.decode())
        subtype = pyarrow.type_for_alias(metadata["subtype"])
        closed = metadata["closed"]
        return ArrowIntervalType(subtype, closed)

    def __eq__(self, other):
        """
        Compare with another Arrow extension type.

        Two types are equal when they are of exactly the same class and have
        equal ``subtype`` and ``closed``. ``NotImplemented`` is returned for
        objects that are not ``pyarrow.BaseExtensionType`` instances.
        """
        if isinstance(other, pyarrow.BaseExtensionType):
            return (
                type(self) == type(other)
                and self.subtype == other.subtype
                and self.closed == other.closed
            )
        else:
            return NotImplemented

    def __ne__(self, other) -> bool:
        """Negation of ``==``."""
        return not self == other

    def __hash__(self) -> int:
        """Hash built from the type's string form, subtype string and closed."""
        return hash((str(self), str(self.subtype), self.closed))

    def to_pandas_dtype(self):
        """Return the ``pandas.IntervalDtype`` matching subtype and closed."""
        # NOTE(review): imported locally, presumably to avoid a circular
        # import at module load time -- confirm before hoisting.
        import pandas as pd

        return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.closed)


# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
pyarrow.register_extension_type(_interval_type)