zmc
2023-12-22 9fdbf60165db0400c2e8e6be2dc6e88138ac719a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import contextlib
from typing import Generator
 
import pytest
 
import pandas as pd
import pandas._testing as tm
from pandas.core import accessor
 
 
def test_dirname_mixin():
    # GH37173
 
    class X(accessor.DirNamesMixin):
        x = 1
        y: int
 
        def __init__(self) -> None:
            self.z = 3
 
    result = [attr_name for attr_name in dir(X()) if not attr_name.startswith("_")]
 
    assert result == ["x", "z"]
 
 
@contextlib.contextmanager
def ensure_removed(obj, attr) -> Generator[None, None, None]:
    """Ensure that an attribute added to 'obj' during the test is
    removed when we're done
    """
    try:
        yield
    finally:
        try:
            delattr(obj, attr)
        except AttributeError:
            pass
        obj._accessors.discard(attr)
 
 
class MyAccessor:
    def __init__(self, obj) -> None:
        self.obj = obj
        self.item = "item"
 
    @property
    def prop(self):
        return self.item
 
    def method(self):
        return self.item
 
 
@pytest.mark.parametrize(
    "obj, registrar",
    [
        (pd.Series, pd.api.extensions.register_series_accessor),
        (pd.DataFrame, pd.api.extensions.register_dataframe_accessor),
        (pd.Index, pd.api.extensions.register_index_accessor),
    ],
)
def test_register(obj, registrar):
    with ensure_removed(obj, "mine"):
        before = set(dir(obj))
        registrar("mine")(MyAccessor)
        o = obj([]) if obj is not pd.Series else obj([], dtype=object)
        assert o.mine.prop == "item"
        after = set(dir(obj))
        assert (before ^ after) == {"mine"}
        assert "mine" in obj._accessors
 
 
def test_accessor_works():
    with ensure_removed(pd.Series, "mine"):
        pd.api.extensions.register_series_accessor("mine")(MyAccessor)
 
        s = pd.Series([1, 2])
        assert s.mine.obj is s
 
        assert s.mine.prop == "item"
        assert s.mine.method() == "item"
 
 
def test_overwrite_warns():
    # Need to restore mean
    mean = pd.Series.mean
    try:
        with tm.assert_produces_warning(UserWarning) as w:
            pd.api.extensions.register_series_accessor("mean")(MyAccessor)
            s = pd.Series([1, 2])
            assert s.mean.prop == "item"
        msg = str(w[0].message)
        assert "mean" in msg
        assert "MyAccessor" in msg
        assert "Series" in msg
    finally:
        pd.Series.mean = mean
 
 
def test_raises_attribute_error():
    with ensure_removed(pd.Series, "bad"):
 
        @pd.api.extensions.register_series_accessor("bad")
        class Bad:
            def __init__(self, data) -> None:
                raise AttributeError("whoops")
 
        with pytest.raises(AttributeError, match="whoops"):
            pd.Series([], dtype=object).bad