import numpy as np
import pytest

from pandas.core.dtypes.cast import find_common_type
from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.dtypes import (
    CategoricalDtype,
    DatetimeTZDtype,
    IntervalDtype,
    PeriodDtype,
)

from pandas import (
    Categorical,
    Index,
)
|
|
|
@pytest.mark.parametrize(
    "source_dtypes,expected_common_dtype",
    [
        ((np.int64,), np.int64),
        ((np.uint64,), np.uint64),
        ((np.float32,), np.float32),
        ((object,), object),
        # Into ints.
        ((np.int16, np.int64), np.int64),
        ((np.int32, np.uint32), np.int64),
        ((np.uint16, np.uint64), np.uint64),
        # Into floats.
        ((np.float16, np.float32), np.float32),
        ((np.float16, np.int16), np.float32),
        ((np.float32, np.int16), np.float32),
        ((np.uint64, np.int64), np.float64),
        ((np.int16, np.float64), np.float64),
        ((np.float16, np.int64), np.float64),
        # Into others.
        ((np.complex128, np.int32), np.complex128),
        ((object, np.float32), object),
        ((object, np.int16), object),
        # Bool with int.
        ((np.dtype("bool"), np.int64), object),
        ((np.dtype("bool"), np.int32), object),
        ((np.dtype("bool"), np.int16), object),
        ((np.dtype("bool"), np.int8), object),
        ((np.dtype("bool"), np.uint64), object),
        ((np.dtype("bool"), np.uint32), object),
        ((np.dtype("bool"), np.uint16), object),
        ((np.dtype("bool"), np.uint8), object),
        # Bool with float.
        ((np.dtype("bool"), np.float64), object),
        ((np.dtype("bool"), np.float32), object),
        # Datetime-like.
        (
            (np.dtype("datetime64[ns]"), np.dtype("datetime64[ns]")),
            np.dtype("datetime64[ns]"),
        ),
        (
            (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]")),
            np.dtype("timedelta64[ns]"),
        ),
        (
            (np.dtype("datetime64[ns]"), np.dtype("datetime64[ms]")),
            np.dtype("datetime64[ns]"),
        ),
        (
            (np.dtype("timedelta64[ms]"), np.dtype("timedelta64[ns]")),
            np.dtype("timedelta64[ns]"),
        ),
        ((np.dtype("datetime64[ns]"), np.dtype("timedelta64[ns]")), object),
        ((np.dtype("datetime64[ns]"), np.int64), object),
    ],
)
def test_numpy_dtypes(source_dtypes, expected_common_dtype):
    """
    Check ``find_common_type`` on plain NumPy dtypes.

    The cases cover NumPy-style promotion among int/float/complex/object,
    plus the combinations whose expected result is ``object`` instead:
    bool mixed with any int or float, and datetime64 mixed with timedelta64
    or int.

    The parametrized inputs mix scalar type classes (``np.int64``,
    ``object``) with ``np.dtype`` instances. They are normalised with
    ``pandas_dtype`` first so ``find_common_type`` receives actual dtype
    objects instead of relying on it to coerce type classes itself.
    """
    dtypes = [pandas_dtype(x) for x in source_dtypes]
    assert find_common_type(dtypes) == expected_common_dtype
|
|
|
def test_raises_empty_input():
    """An empty collection of dtypes is rejected with a ValueError."""
    no_dtypes = []
    with pytest.raises(ValueError, match="no types given"):
        find_common_type(no_dtypes)
|
|
|
@pytest.mark.parametrize(
    ("dtypes", "exp_type"),
    [
        ([CategoricalDtype()], "category"),
        ([object, CategoricalDtype()], object),
        ([CategoricalDtype(), CategoricalDtype()], "category"),
    ],
)
def test_categorical_dtype(dtypes, exp_type):
    """
    Categorical dtypes combine to "category" among themselves, and fall back
    to ``object`` when mixed with ``object``.
    """
    common = find_common_type(dtypes)
    assert common == exp_type
|
|
|
def test_datetimetz_dtype_match():
    """Two copies of the same tz-aware dtype have that dtype as common type."""
    eastern = DatetimeTZDtype(unit="ns", tz="US/Eastern")
    common = find_common_type([eastern] * 2)
    assert common == "datetime64[ns, US/Eastern]"
|
|
|
@pytest.mark.parametrize(
    "dtype2",
    [
        DatetimeTZDtype(unit="ns", tz="Asia/Tokyo"),
        np.dtype("datetime64[ns]"),
        object,
        np.int64,
    ],
)
def test_datetimetz_dtype_mismatch(dtype2):
    """
    A US/Eastern tz-aware dtype paired with any ``dtype2`` listed here yields
    ``object``, whichever order the two are given in.
    """
    eastern = DatetimeTZDtype(unit="ns", tz="US/Eastern")
    for ordered_pair in ([eastern, dtype2], [dtype2, eastern]):
        assert find_common_type(ordered_pair) == object
|
|
|
def test_period_dtype_match():
    """A daily PeriodDtype combined with itself stays ``period[D]``."""
    daily = PeriodDtype(freq="D")
    expected = "period[D]"
    assert find_common_type([daily, daily]) == expected
|
|
|
@pytest.mark.parametrize(
    "dtype2",
    [
        DatetimeTZDtype(unit="ns", tz="Asia/Tokyo"),
        PeriodDtype(freq="2D"),
        # Hourly period, spelled with the lowercase "h" alias; the uppercase
        # "H" spelling is deprecated in newer pandas and this list is built
        # at import/collection time.
        PeriodDtype(freq="h"),
        np.dtype("datetime64[ns]"),
        object,
        np.int64,
    ],
)
def test_period_dtype_mismatch(dtype2):
    """
    A daily PeriodDtype has no common dtype other than ``object`` with any
    ``dtype2`` here (other period frequencies, datetimes, object, int),
    regardless of argument order.
    """
    dtype = PeriodDtype(freq="D")
    assert find_common_type([dtype, dtype2]) == object
    assert find_common_type([dtype2, dtype]) == object
|
|
|
# Right-closed IntervalDtypes over a spread of subtypes: three numeric
# (int64, float64, uint64) followed by tz-aware datetime, naive datetime,
# and timedelta. Shared by the interval tests below via parametrize.
interval_dtypes = [
    IntervalDtype(subtype, "right")
    for subtype in (
        np.int64,
        np.float64,
        np.uint64,
        DatetimeTZDtype(unit="ns", tz="US/Eastern"),
        "M8[ns]",
        "m8[ns]",
    )
]
|
|
|
@pytest.mark.parametrize("left", interval_dtypes)
@pytest.mark.parametrize("right", interval_dtypes)
def test_interval_dtype(left, right):
    """
    Pairwise common type of IntervalDtypes.

    The very same dtype object comes back unchanged; two numeric-subtype
    intervals combine to a float64 interval; any other pairing is ``object``.
    """
    result = find_common_type([left, right])

    if left is right:
        assert result is left
    else:
        numeric_kinds = ("i", "u", "f")
        if left.subtype.kind in numeric_kinds and right.subtype.kind in numeric_kinds:
            # both numeric -> common numeric subtype
            expected = IntervalDtype(np.float64, "right")
        else:
            expected = object
        assert result == expected
|
|
|
@pytest.mark.parametrize("dtype", interval_dtypes)
def test_interval_dtype_with_categorical(dtype):
    """
    An IntervalDtype combined with a categorical whose (empty) categories
    have that same interval dtype resolves back to the interval dtype.
    """
    categories = Index([], dtype=dtype)
    cat_dtype = Categorical([], categories=categories).dtype
    assert find_common_type([dtype, cat_dtype]) == dtype
|