zmc
2023-10-12 ed135d79df12a2466b52dae1a82326941211dcc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import numpy as np
import pytest
 
from pandas import (
    DataFrame,
    Series,
)
import pandas._testing as tm
 
 
class SharedSetAxisTests:
    """``set_axis`` tests shared by the Series and DataFrame test classes.

    Subclasses supply the object under test by overriding the ``obj``
    fixture; every test below is then run against that object.
    """

    @pytest.fixture
    def obj(self):
        """Object under test; subclasses must override this fixture."""
        raise NotImplementedError("Implemented by subclasses")

    @staticmethod
    def _shared_memory_flags(left, right):
        """Report, per column, whether *left* shares memory with *right*.

        Returns a single-element list for 1-D objects and one flag per column
        position for 2-D objects, so callers can apply ``all``/``any``
        uniformly regardless of ``ndim``.
        """
        if left.ndim == 1:
            return [tm.shares_memory(left, right)]
        return [
            tm.shares_memory(left.iloc[:, i], right.iloc[:, i])
            for i in range(left.shape[1])
        ]

    def test_set_axis(self, obj):
        # GH14636; this tests setting index for both Series and DataFrame
        new_index = list("abcd")[: len(obj)]
        expected = obj.copy()
        expected.index = new_index
        result = obj.set_axis(new_index, axis=0)
        tm.assert_equal(expected, result)

    def test_set_axis_copy(self, obj, using_copy_on_write):
        # Test copy keyword GH#47932
        new_index = list("abcd")[: len(obj)]

        # Taken before any set_axis call; the final copy=False result is
        # checked against this to confirm data is still shared.
        orig = obj.iloc[:]
        expected = obj.copy()
        expected.index = new_index

        result = obj.set_axis(new_index, axis=0, copy=True)
        tm.assert_equal(expected, result)
        assert result is not obj
        # check we DID make a copy; the memory check is skipped under
        # Copy-on-Write, where copies are deferred
        if not using_copy_on_write:
            assert not any(self._shared_memory_flags(result, obj))

        result = obj.set_axis(new_index, axis=0, copy=False)
        tm.assert_equal(expected, result)
        assert result is not obj
        # check we did NOT make a copy
        assert all(self._shared_memory_flags(result, obj))

        # default copy: copies eagerly only when Copy-on-Write is disabled
        result = obj.set_axis(new_index, axis=0)
        tm.assert_equal(expected, result)
        assert result is not obj
        if using_copy_on_write:
            # check we DID NOT make a copy -- every column must still share
            # memory, matching the copy=False check above
            assert all(self._shared_memory_flags(result, obj))
        else:
            # check we DID make a copy
            assert not any(self._shared_memory_flags(result, obj))

        res = obj.set_axis(new_index, copy=False)
        tm.assert_equal(expected, res)
        # check we did NOT make a copy
        assert all(self._shared_memory_flags(res, orig))

    def test_set_axis_unnamed_kwarg_warns(self, obj):
        # omitting the "axis" parameter; the name is historical -- no warning
        # is asserted, only that the labels are applied to the index
        new_index = list("abcd")[: len(obj)]

        expected = obj.copy()
        expected.index = new_index

        result = obj.set_axis(new_index)
        tm.assert_equal(result, expected)

    @pytest.mark.parametrize("axis", [3, "foo"])
    def test_set_axis_invalid_axis_name(self, axis, obj):
        # wrong values for the "axis" parameter
        with pytest.raises(ValueError, match="No axis named"):
            obj.set_axis(list("abc"), axis=axis)

    def test_set_axis_setattr_index_not_collection(self, obj):
        # wrong type
        msg = (
            r"Index\(\.\.\.\) must be called with a collection of some "
            r"kind, None was passed"
        )
        with pytest.raises(TypeError, match=msg):
            obj.index = None

    def test_set_axis_setattr_index_wrong_length(self, obj):
        # wrong length
        msg = (
            f"Length mismatch: Expected axis has {len(obj)} elements, "
            f"new values have {len(obj) - 1} elements"
        )
        with pytest.raises(ValueError, match=msg):
            obj.index = np.arange(len(obj) - 1)

        if obj.ndim == 2:
            # drop exactly one label so the length is always wrong, however
            # many columns the fixture has
            with pytest.raises(ValueError, match="Length mismatch"):
                obj.columns = obj.columns[:-1]
 
 
class TestDataFrameSetAxis(SharedSetAxisTests):
    """Run the shared ``set_axis`` checks against a small DataFrame."""

    @pytest.fixture
    def obj(self):
        # Three float columns over a three-row integer index.
        columns = {
            "A": [1.1, 2.2, 3.3],
            "B": [5.0, 6.1, 7.2],
            "C": [4.4, 5.5, 6.6],
        }
        row_labels = [2010, 2011, 2012]
        return DataFrame(columns, index=row_labels)
 
 
class TestSeriesSetAxis(SharedSetAxisTests):
    """Run the shared ``set_axis`` checks against a small int64 Series."""

    @pytest.fixture
    def obj(self):
        # Four int64 values labelled with the odd integers 1 through 7.
        labels = [1, 3, 5, 7]
        return Series(np.arange(4), index=labels, dtype="int64")