zmc
2023-12-22 9fdbf60165db0400c2e8e6be2dc6e88138ac719a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
""" Test functions for linalg module using the matrix class."""
import numpy as np
 
from numpy.linalg.tests.test_linalg import (
    LinalgCase, apply_tag, TestQR as _TestQR, LinalgTestCase,
    _TestNorm2D, _TestNormDoubleBase, _TestNormSingleBase, _TestNormInt64Base,
    SolveCases, InvCases, EigvalsCases, EigCases, SVDCases, CondCases,
    PinvCases, DetCases, LstsqCases)
 
 
CASES = []
 
# square test cases
CASES += apply_tag('square', [
    LinalgCase("0x0_matrix",
               np.empty((0, 0), dtype=np.double).view(np.matrix),
               np.empty((0, 1), dtype=np.double).view(np.matrix),
               tags={'size-0'}),
    LinalgCase("matrix_b_only",
               np.array([[1., 2.], [3., 4.]]),
               np.matrix([2., 1.]).T),
    LinalgCase("matrix_a_and_b",
               np.matrix([[1., 2.], [3., 4.]]),
               np.matrix([2., 1.]).T),
])
 
# hermitian test-cases
CASES += apply_tag('hermitian', [
    LinalgCase("hmatrix_a_and_b",
               np.matrix([[1., 2.], [2., 1.]]),
               None),
])
# No need to make generalized or strided cases for matrices.
 
 
class MatrixTestCase(LinalgTestCase):
    TEST_CASES = CASES
 
 
class TestSolveMatrix(SolveCases, MatrixTestCase):
    pass
 
 
class TestInvMatrix(InvCases, MatrixTestCase):
    pass
 
 
class TestEigvalsMatrix(EigvalsCases, MatrixTestCase):
    pass
 
 
class TestEigMatrix(EigCases, MatrixTestCase):
    pass
 
 
class TestSVDMatrix(SVDCases, MatrixTestCase):
    pass
 
 
class TestCondMatrix(CondCases, MatrixTestCase):
    pass
 
 
class TestPinvMatrix(PinvCases, MatrixTestCase):
    pass
 
 
class TestDetMatrix(DetCases, MatrixTestCase):
    pass
 
 
class TestLstsqMatrix(LstsqCases, MatrixTestCase):
    pass
 
 
class _TestNorm2DMatrix(_TestNorm2D):
    array = np.matrix
 
 
class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase):
    pass
 
 
class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase):
    pass
 
 
class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base):
    pass
 
 
class TestQRMatrix(_TestQR):
    array = np.matrix