zmc
2023-08-08 e792e9a60d958b93aef96050644f369feb25d61b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import pytest
 
from pandas.core.dtypes import dtypes
from pandas.core.dtypes.common import is_extension_array_dtype
 
import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import ExtensionArray
 
 
class DummyDtype(dtypes.ExtensionDtype):
    pass
 
 
class DummyArray(ExtensionArray):
    def __init__(self, data) -> None:
        self.data = data
 
    def __array__(self, dtype):
        return self.data
 
    @property
    def dtype(self):
        return DummyDtype()
 
    def astype(self, dtype, copy=True):
        # we don't support anything but a single dtype
        if isinstance(dtype, DummyDtype):
            if copy:
                return type(self)(self.data)
            return self
 
        return np.array(self, dtype=dtype, copy=copy)
 
 
class TestExtensionArrayDtype:
    @pytest.mark.parametrize(
        "values",
        [
            pd.Categorical([]),
            pd.Categorical([]).dtype,
            pd.Series(pd.Categorical([])),
            DummyDtype(),
            DummyArray(np.array([1, 2])),
        ],
    )
    def test_is_extension_array_dtype(self, values):
        assert is_extension_array_dtype(values)
 
    @pytest.mark.parametrize("values", [np.array([]), pd.Series(np.array([]))])
    def test_is_not_extension_array_dtype(self, values):
        assert not is_extension_array_dtype(values)
 
 
def test_astype():
    arr = DummyArray(np.array([1, 2, 3]))
    expected = np.array([1, 2, 3], dtype=object)
 
    result = arr.astype(object)
    tm.assert_numpy_array_equal(result, expected)
 
    result = arr.astype("object")
    tm.assert_numpy_array_equal(result, expected)
 
 
def test_astype_no_copy():
    arr = DummyArray(np.array([1, 2, 3], dtype=np.int64))
    result = arr.astype(arr.dtype, copy=False)
 
    assert arr is result
 
    result = arr.astype(arr.dtype)
    assert arr is not result
 
 
@pytest.mark.parametrize("dtype", [dtypes.CategoricalDtype(), dtypes.IntervalDtype()])
def test_is_extension_array_dtype(dtype):
    assert isinstance(dtype, dtypes.ExtensionDtype)
    assert is_extension_array_dtype(dtype)