1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
| import pytest
|
| from numpy import array_api as xp
|
|
@pytest.mark.parametrize(
    "obj, axis, expected",
    [
        ([0, 0], -1, [0, 1]),
        ([0, 1, 0], -1, [1, 0, 2]),
        # Several groups of ties: each group must keep ascending index order.
        ([1, 0, 1, 0], -1, [0, 2, 1, 3]),
        ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
        ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
    ],
)
def test_stable_desc_argsort(obj, axis, expected):
    """
    Indices respect relative order of a descending stable-sort

    A stable descending argsort orders elements from largest to smallest
    while equal elements keep their original relative order. For example,
    ``[0, 0]`` must give ``[0, 1]``; simply reversing an ascending argsort
    would instead give ``[1, 0]``.

    See https://github.com/numpy/numpy/issues/20778
    """
    x = xp.asarray(obj)
    out = xp.argsort(x, axis=axis, stable=True, descending=True)
    expected_arr = xp.asarray(expected)
    # ``==`` broadcasts, so a wrongly shaped result could still compare
    # all-True. Pin the shape before comparing the index values.
    assert out.shape == expected_arr.shape
    assert xp.all(out == expected_arr)
|
|