1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import pytest
 
from pandas import array
import pandas._testing as tm
from pandas.core.arrays.sparse import SparseArray
 
 
@pytest.mark.parametrize(
    "kwargs",
    [
        {},  # Default is check_exact=False
        {"check_exact": False},
        {"check_exact": True},
    ],
)
def test_assert_extension_array_equal_not_exact(kwargs):
    # see gh-23709
    arr1 = SparseArray([-0.17387645482451206, 0.3414148016424936])
    arr2 = SparseArray([-0.17387645482451206, 0.3414148016424937])
 
    if kwargs.get("check_exact", False):
        msg = """\
ExtensionArray are different
 
ExtensionArray values are different \\(50\\.0 %\\)
\\[left\\]:  \\[-0\\.17387645482.*, 0\\.341414801642.*\\]
\\[right\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]"""
 
        with pytest.raises(AssertionError, match=msg):
            tm.assert_extension_array_equal(arr1, arr2, **kwargs)
    else:
        tm.assert_extension_array_equal(arr1, arr2, **kwargs)
 
 
@pytest.mark.parametrize("decimals", range(10))
def test_assert_extension_array_equal_less_precise(decimals):
    rtol = 0.5 * 10**-decimals
    arr1 = SparseArray([0.5, 0.123456])
    arr2 = SparseArray([0.5, 0.123457])
 
    if decimals >= 5:
        msg = """\
ExtensionArray are different
 
ExtensionArray values are different \\(50\\.0 %\\)
\\[left\\]:  \\[0\\.5, 0\\.123456\\]
\\[right\\]: \\[0\\.5, 0\\.123457\\]"""
 
        with pytest.raises(AssertionError, match=msg):
            tm.assert_extension_array_equal(arr1, arr2, rtol=rtol)
    else:
        tm.assert_extension_array_equal(arr1, arr2, rtol=rtol)
 
 
def test_assert_extension_array_equal_dtype_mismatch(check_dtype):
    end = 5
    kwargs = {"check_dtype": check_dtype}
 
    arr1 = SparseArray(np.arange(end, dtype="int64"))
    arr2 = SparseArray(np.arange(end, dtype="int32"))
 
    if check_dtype:
        msg = """\
ExtensionArray are different
 
Attribute "dtype" are different
\\[left\\]:  Sparse\\[int64, 0\\]
\\[right\\]: Sparse\\[int32, 0\\]"""
 
        with pytest.raises(AssertionError, match=msg):
            tm.assert_extension_array_equal(arr1, arr2, **kwargs)
    else:
        tm.assert_extension_array_equal(arr1, arr2, **kwargs)
 
 
def test_assert_extension_array_equal_missing_values():
    arr1 = SparseArray([np.nan, 1, 2, np.nan])
    arr2 = SparseArray([np.nan, 1, 2, 3])
 
    msg = """\
ExtensionArray NA mask are different
 
ExtensionArray NA mask values are different \\(25\\.0 %\\)
\\[left\\]:  \\[True, False, False, True\\]
\\[right\\]: \\[True, False, False, False\\]"""
 
    with pytest.raises(AssertionError, match=msg):
        tm.assert_extension_array_equal(arr1, arr2)
 
 
@pytest.mark.parametrize("side", ["left", "right"])
def test_assert_extension_array_equal_non_extension_array(side):
    numpy_array = np.arange(5)
    extension_array = SparseArray(numpy_array)
 
    msg = f"{side} is not an ExtensionArray"
    args = (
        (numpy_array, extension_array)
        if side == "left"
        else (extension_array, numpy_array)
    )
 
    with pytest.raises(AssertionError, match=msg):
        tm.assert_extension_array_equal(*args)
 
 
@pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
def test_assert_extension_array_equal_ignore_dtype_mismatch(right_dtype):
    # https://github.com/pandas-dev/pandas/issues/35715
    left = array([1, 2, 3], dtype="Int64")
    right = array([1, 2, 3], dtype=right_dtype)
    tm.assert_extension_array_equal(left, right, check_dtype=False)