NEURON
matrix.cpp
Go to the documentation of this file.
#include <ocmatrix.h>

#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "ivocvect.h"

#include <catch2/catch_test_macros.hpp>
using namespace Catch::literals;
8 
9 template <typename T = double>
10 class ApproxOrOpposite: public Catch::MatcherBase<std::vector<T>> {
11  std::vector<T> vec;
12 
13  public:
14  ApproxOrOpposite(std::vector<T> vec)
15  : vec(vec) {}
16 
17  bool match(std::vector<T> const& in) const override {
18  if (in.size() != vec.size()) {
19  return false;
20  }
21  bool matched = true;
22  for (int i = 0; i < in.size(); ++i) {
23  if (in[i] != approx(vec[i])) {
24  matched = false;
25  break;
26  }
27  }
28  if (matched) {
29  return true;
30  }
31  matched = true;
32  for (int i = 0; i < in.size(); ++i) {
33  if (in[i] != approx(-1 * vec[i])) {
34  matched = false;
35  break;
36  }
37  }
38 
39  return matched;
40  }
41 
42  std::string describe() const override {
43  std::ostringstream ss;
44  ss << "is not approx or opposite approx of " << Catch::Detail::stringify(vec);
45  return ss.str();
46  }
48  ApproxOrOpposite& epsilon(T const& newEpsilon) {
49  approx.epsilon(newEpsilon);
50  return *this;
51  }
53  ApproxOrOpposite& margin(T const& newMargin) {
54  approx.margin(newMargin);
55  return *this;
56  }
58  ApproxOrOpposite& scale(T const& newScale) {
59  approx.scale(newScale);
60  return *this;
61  }
62 
63  mutable Catch::Detail::Approx approx = Catch::Detail::Approx::custom();
64 };
65 
66 bool compareMatrix(OcMatrix& m, const std::vector<std::vector<double>>& ref) {
67  REQUIRE(m.nrow() == ref.size());
68  for (int i = 0; i < m.nrow(); ++i) {
69  REQUIRE(m.ncol() == ref[i].size());
70  for (int j = 0; j < m.ncol(); ++j) {
71  REQUIRE(m.getval(i, j) == Catch::Detail::Approx(ref[i][j]).margin(1e-10));
72  }
73  }
74  return true;
75 }
76 
// Exercises the OcMatrix API on a full (dense) matrix and on a sparse matrix.
// Inside each GIVEN the same matrix `m` is mutated step by step: every braced
// sub-block applies one or more operations and checks the resulting state, so
// each block relies on the state left behind by the blocks before it.
77 SCENARIO("A Matrix", "[neuron_ivoc][OcMatrix]") {
78  GIVEN("A 3x3 Full matrix") {
79  OcFullMatrix m{3, 3};
80  REQUIRE(m.nrow() == 3);
81  REQUIRE(m.ncol() == 3);
82  {
83  m.ident();
84  REQUIRE(compareMatrix(m, {{1., 0., 0.}, {0., 1., 0}, {0., 0., 1.}}));
85  }
86  {
// mep() returns a pointer to an element; writing through it is visible via getval().
87  double* value = m.mep(0, 0);
88  REQUIRE(*value == 1);
89  *value = 3;
90  REQUIRE(m.getval(0, 0) == 3);
91  }
// Grow to 4x3; the check below expects the existing 3x3 entries to be kept.
92  m.resize(4, 3);
93  {
94  m.setrow(3, 2.0);
95  REQUIRE(compareMatrix(m, {{3., 0., 0.}, {0., 1., 0.}, {0., 0., 1.}, {2., 2., 2.}}));
96  }
97  {
// nonzeros() lists the (row, col) positions of nonzero entries in row-major order.
98  std::vector<std::pair<int, int>> nzs = m.nonzeros();
99  std::vector<std::pair<int, int>> res = {{0, 0}, {1, 1}, {2, 2}, {3, 0}, {3, 1}, {3, 2}};
100  REQUIRE(nzs == res);
101  }
102  {
// Same positions as above, returned as separate row (x) and column (y) vectors.
103  std::vector<int> x, y;
104  m.nonzeros(x, y);
105  std::vector<int> res_x = {0, 1, 2, 3, 3, 3};
106  std::vector<int> res_y = {0, 1, 2, 0, 1, 2};
107  REQUIRE(x == res_x);
108  REQUIRE(y == res_y);
109  }
110  {
111  m.setcol(1, 4.0);
112  REQUIRE(compareMatrix(m, {{3., 4., 0.}, {0., 4., 0.}, {0., 4., 1.}, {2., 4., 2.}}));
113  }
// setdiag(k, value): k == 0 is the main diagonal, k > 0 a super-diagonal,
// k < 0 a sub-diagonal (see the expected matrices below).
114  {
115  m.setdiag(0, 5.0);
116  REQUIRE(compareMatrix(m, {{5., 4., 0.}, {0., 5., 0.}, {0., 4., 5.}, {2., 4., 2.}}));
117  }
118  {
119  m.setdiag(1, 6.0);
120  REQUIRE(compareMatrix(m, {{5., 6., 0.}, {0., 5., 6.}, {0., 4., 5.}, {2., 4., 2.}}));
121  }
122  {
123  m.setdiag(-1, 7.0);
124  REQUIRE(compareMatrix(m, {{5., 6., 0.}, {7., 5., 6.}, {0., 7., 5.}, {2., 4., 7.}}));
125  }
126 
127  {
// m = m + n with n the 4x3 identity; the sum is written back into m.
128  OcFullMatrix n(4, 3);
129  n.ident();
130  m.add(&n, &m);
131  REQUIRE(compareMatrix(m, {{6., 6., 0.}, {7., 6., 6.}, {0., 7., 6.}, {2., 4., 7.}}));
132  }
133  {
// Copies the 3x2 sub-block of m starting at (1, 1) into n at (0, 0); the rest of n stays 0.
134  OcFullMatrix n(4, 3);
135  m.bcopy(&n, 1, 1, 3, 2, 0, 0);
136  REQUIRE(compareMatrix(n, {{6., 6., 0.}, {7., 6., 0.}, {4., 7., 0.}, {0., 0., 0.}}));
137  }
138  {
// n is created 4x3 but is expected to be 3x4 afterwards, i.e. transpose() sizes its output.
139  OcFullMatrix n(4, 3);
140  m.transpose(&n);
141  REQUIRE(compareMatrix(n, {{6., 7., 0., 2.}, {6., 6., 7., 4.}, {0., 6., 6., 7.}}));
142  }
143  {
// Copy row 2 into v, then write it back as row 0.
144  IvocVect v(3);
145  m.getrow(2, &v);
146  REQUIRE_THAT(v.vec(), Catch::Matchers::Approx(std::vector<double>({0., 7., 6.})));
147  m.setrow(0, &v);
148  REQUIRE(compareMatrix(m, {{0., 7., 6.}, {7., 6., 6.}, {0., 7., 6.}, {2., 4., 7.}}));
149  }
150  {
// Copy column 2 into v, then write it back as column 1.
151  IvocVect v(4);
152  m.getcol(2, &v);
153  REQUIRE_THAT(v.vec(), Catch::Matchers::Approx(std::vector<double>({6., 6., 6., 7.})));
154  m.setcol(1, &v);
155  REQUIRE(compareMatrix(m, {{0., 6., 6.}, {7., 6., 6.}, {0., 6., 6.}, {2., 7., 7.}}));
156  }
157  {
// Shrinking drops the last row and keeps the remaining entries.
158  m.resize(3, 3);
159  REQUIRE(compareMatrix(m, {{0., 6., 6.}, {7., 6., 6.}, {0., 6., 6.}}));
160  }
161  {
// Matrix exponential; n is created 4x3 but compared as 3x3, so exp() is expected to size it.
162  OcFullMatrix n(4, 3);
163  m.exp(&n);
164  REQUIRE(n(0, 0) == Catch::Detail::Approx(442925.));
165  REQUIRE(compareMatrix(n,
166  {{442925., 938481., 938481.},
167  {651970., 1381407., 1381407.},
168  {442926., 938481., 938482.}}));
169  }
170  {
// Square m in place (output is m itself).
171  m.pow(2, &m);
172  REQUIRE(compareMatrix(m, {{42., 72., 72.}, {42., 114., 114.}, {42., 72., 72.}}));
173  }
174  {
// Rows 0 and 2 are equal, so the matrix is singular.
175  int e{};
176  double det = m.det(&e);
177  REQUIRE(det == 0.);
178  REQUIRE(e == 0);
179  }
// Change row 2 to {1, 72, 2} so the matrix becomes non-singular.
180  *m.mep(2, 0) = 1;
181  *m.mep(2, 2) = 2;
182  {
// The exact determinant is -123480 = -1.2348e5; the test expects det() to return the
// mantissa and to store the base-10 exponent in e.
183  int e{};
184  double det = m.det(&e);
185  REQUIRE(det == -1.2348_a);
186  REQUIRE(e == 5);
187  }
188  {
189  OcFullMatrix n(4, 3);
190  m.inverse(&n);
// NOTE(review): n was created 4x3; it is unclear from this test whether inverse() already
// sizes its output to 3x3 (exp() and transpose() above are expected to) - confirm.
191  n.resize(3, 3); // match the 3x3 shape of the inverse before comparing
192  REQUIRE(compareMatrix(n,
193  {{0.064625850, -0.040816326, 0.},
194  {-0.00024295432, -0.0000971817, 0.01428571},
195  {-0.023566569, 0.0239067055, -0.014285714}}));
196  n.zero();
197  REQUIRE(compareMatrix(n, {{0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}));
198  }
199  {
// First super-diagonal: (0,1) and (1,2); the last slot has no element and is expected to be 0.
200  IvocVect v(3);
201  m.getdiag(1, &v);
202  REQUIRE_THAT(v.vec(), Catch::Matchers::Approx(std::vector<double>({72., 114., 0.})));
// For a sub-diagonal the vector is indexed by row: v[1] -> (1,0), v[2] -> (2,1); v[0] is unused.
203  v.vec() = {0., 72., 114.};
204  m.setdiag(-1, &v);
205  REQUIRE(compareMatrix(m, {{42., 72., 72.}, {72., 114., 114.}, {1., 114., 2.}}));
206  }
207  {
// Second sub-diagonal only has element (2,0), which lands in v[2].
208  IvocVect v(3);
209  m.getdiag(-2, &v);
210  REQUIRE(v.vec()[2] == Catch::Detail::Approx(1.0));
// Second super-diagonal only has element (0,2), which is taken from v[0].
211  v.vec() = {1., 0., 0.};
212  m.setdiag(2, &v);
213  REQUIRE(compareMatrix(m, {{42., 72., 1.}, {72., 114., 114.}, {1., 114., 2.}}));
214  }
215  {
// vout = m * (1, 1, 1), i.e. the row sums of m.
216  IvocVect v(3);
217  v.vec() = {1, 1, 1};
218  IvocVect vout(3);
219  m.mulv(&v, &vout);
220  REQUIRE_THAT(vout.vec(),
221  Catch::Matchers::Approx(std::vector<double>({115., 300., 117.})));
222  }
223  {
// n is a copy of m, so o = m * m.
224  OcFullMatrix n(3, 3);
225  m.copy(&n);
226  REQUIRE(compareMatrix(n, {{42., 72., 1.}, {72., 114., 114.}, {1., 114., 2.}}));
227  OcFullMatrix o(3, 3);
228  m.mulm(&n, &o);
229  REQUIRE(compareMatrix(
230  o, {{6949., 11346., 8252.}, {11346., 31176., 13296.}, {8252., 13296., 13001.}}));
231  }
232  {
// n = 2 * m.
233  OcFullMatrix n(3, 3);
234  m.muls(2, &n);
235  REQUIRE(compareMatrix(n, {{84., 144., 2.}, {144., 228., 228.}, {2., 228., 4.}}));
236  }
237  {
// Solve m * vout = v; both values of the boolean flag are expected to give the same solution.
238  IvocVect v(3);
239  v.vec() = {1, 1, 1};
240  IvocVect vout(3);
241  m.solv(&v, &vout, false);
242  REQUIRE_THAT(vout.vec(),
243  Catch::Matchers::Approx(
244  std::vector<double>({0.0088700, 0.0087927, -0.00562299})));
245  m.solv(&v, &vout, true);
246  REQUIRE_THAT(vout.vec(),
247  Catch::Matchers::Approx(
248  std::vector<double>({0.0088700, 0.0087927, -0.00562299})));
249  }
250  {
// Build the symmetric matrix {{1, 2, 3}, {2, 1, 4}, {3, 4, 1}} row by row.
251  IvocVect v(3);
252  OcFullMatrix n(3, 3);
253  v.vec() = {1, 2, 3};
254  m.setrow(0, &v);
255  v.vec() = {2, 1, 4};
256  m.setrow(1, &v);
257  v.vec() = {3, 4, 1};
258  m.setrow(2, &v);
// Symmetric eigen-decomposition: eigenvalues are returned in v (they sum to the trace, 3) and
// eigenvectors in the columns of n. Eigenvectors are only defined up to sign, hence ApproxOrOpposite.
259  m.symmeigen(&n, &v);
260  REQUIRE_THAT(v.vec(),
261  Catch::Matchers::Approx(
262  std::vector<double>({7.074673, -0.88679, -3.18788})));
263  n.getcol(0, &v);
264  REQUIRE_THAT(v.vec(), ApproxOrOpposite({0.50578, 0.5843738, 0.634577}));
265  n.getcol(1, &v);
266  REQUIRE_THAT(v.vec(), ApproxOrOpposite({-0.8240377, 0.544925, 0.154978}));
267  n.getcol(2, &v);
268  REQUIRE_THAT(v.vec(), ApproxOrOpposite({-0.255231, -0.601301, 0.7571611}));
269  }
270  {
// Keep the top-left block {{1, 2}, {2, 1}}, whose singular values are 3 and 1.
271  m.resize(2, 2);
272  OcFullMatrix u(2, 2);
273  OcFullMatrix v(2, 2);
274  IvocVect d(2);
275  m.svd1(&u, &v, &d);
276  REQUIRE_THAT(d.vec(), Catch::Matchers::Approx(std::vector<double>({3., 1.})));
277  // For comparison of u and v and problems with signs, see:
278  // https://www.educative.io/blog/sign-ambiguity-in-singular-value-decomposition
279  IvocVect c(4);
280  c.vec() = {u(0, 0), u(0, 1), v(0, 0), v(0, 1)};
281  CHECK_THAT(c.vec(), ApproxOrOpposite({0.70710, 0.70710, 0.70710, 0.70710}));
282  c.vec() = {u(1, 0), u(1, 1), v(1, 0), v(1, 1)};
283  CHECK_THAT(c.vec(), ApproxOrOpposite({0.70710, -0.70710, -0.70710, 0.70710}));
284  }
285  {
// Non-square case: the 2x3 matrix {{3, 2, 2}, {2, 3, -2}}, whose singular values are 5 and 3.
286  m.resize(2, 3);
287  {
288  IvocVect s(3);
289  s.vec() = {3., 2., 2.};
290  m.setrow(0, &s);
291  s.vec() = {2., 3., -2.};
292  m.setrow(1, &s);
293  }
294  OcFullMatrix u(2, 2);
295  OcFullMatrix v(3, 3);
296  IvocVect d(2);
297  m.svd1(&u, &v, &d);
298  REQUIRE_THAT(d.vec(), Catch::Matchers::Approx(std::vector<double>({5., 3.})));
299  // For comparison of u and v and problems with signs, see:
300  // https://www.educative.io/blog/sign-ambiguity-in-singular-value-decomposition
301  IvocVect c(5);
302  c.vec() = {u(0, 0), u(0, 1), v(0, 0), v(0, 1), v(0, 2)};
// An absolute margin is needed because one expected entry is exactly 0, where a purely
// relative tolerance would not accept a small nonzero result.
303  CHECK_THAT(c.vec(),
304  ApproxOrOpposite({0.70710, 0.70710, 0.70710, 0.70710, 0.}).margin(1e-10));
305  c.vec() = {u(1, 0), u(1, 1), v(1, 0), v(1, 1), v(1, 2)};
306  CHECK_THAT(c.vec(),
307  ApproxOrOpposite({0.70710, -0.70710, 0.235702, -0.235702, 0.942809}));
// u only has two rows, so the first two slots are padded with 0 to reuse the 5-element layout.
308  c.vec() = {0., 0., v(2, 0), v(2, 1), v(2, 2)};
309  CHECK_THAT(c.vec(), ApproxOrOpposite({0., 0., 0.66666, -0.66666, -0.3333333}));
310  }
311  { // Try with vectors too short: the matrix is expected to stay unchanged
312  IvocVect s(2);
313  s.vec() = {1., 2.};
314  m.setrow(0, &s);
315  REQUIRE(compareMatrix(m, {{1., 2., 2.}, {2., 3., -2.}}));
316  m.setcol(0, &s);
317  REQUIRE(compareMatrix(m, {{1., 2., 2.}, {2., 3., -2.}}));
318  IvocVect d(1);
319  d.vec() = {1.};
320  m.setdiag(0, &d);
321  REQUIRE(compareMatrix(m, {{1., 2., 2.}, {2., 3., -2.}}));
322  }
323  }
324  GIVEN("A 3x3 Sparse matrix") {
325  OcSparseMatrix m{3, 3};
326  REQUIRE(m.nrow() == 3);
327  REQUIRE(m.ncol() == 3);
328  {
// sprowlen() counts the stored entries of a row; after ident() each row holds one.
329  m.ident();
330  REQUIRE(compareMatrix(m, {{1., 0., 0.}, {0., 1., 0}, {0., 0., 1.}}));
331  REQUIRE(m.sprowlen(1) == 1);
332  }
333  {
334  std::vector<int> x, y, result = {0, 1, 2};
335  m.nonzeros(x, y);
336  REQUIRE(x == result);
337  REQUIRE(y == result);
338  }
339  {
340  double* pmep = m.mep(1, 1);
341  REQUIRE(*pmep == 1);
// (1, 0) had no stored entry; the pointer is expected to read 0.
// NOTE(review): whether mep() inserts a stored entry here is not visible from this test.
342  pmep = m.mep(1, 0);
343  REQUIRE(*pmep == 0);
344  }
345  {
// Returns the value of a stored entry of row 2 and writes its column into col;
// index 0 presumably selects the first stored entry of the row.
346  int col{};
347  double value = m.spgetrowval(2, 0, &col);
348  REQUIRE(col == 2);
349  REQUIRE(value == Catch::Detail::Approx(1.0));
350  }
351  { // m.zero() doesn't erase stored entries; it only sets the existing values to 0.
352  m.zero();
353  REQUIRE(m.sprowlen(2) == 1);
354  REQUIRE(compareMatrix(m, {{0., 0., 0.}, {0., 0., 0}, {0., 0., 0.}}));
355  }
356  {
357  m.setrow(1, 2);
358  REQUIRE(compareMatrix(m, {{0., 0., 0.}, {2., 2., 2.}, {0., 0., 0.}}));
359  }
360  {
361  m.setcol(0, 3);
362  REQUIRE(compareMatrix(m, {{3., 0., 0.}, {3., 2., 2.}, {3., 0., 0.}}));
363  }
364  {
365  m.setdiag(0, 1);
366  REQUIRE(compareMatrix(m, {{1., 0., 0.}, {3., 1., 2.}, {3., 0., 1.}}));
367  }
368  {
369  m.setdiag(-1, 4);
370  REQUIRE(compareMatrix(m, {{1., 0., 0.}, {4., 1., 2.}, {3., 4., 1.}}));
371  }
372  {
373  m.setdiag(2, 5);
374  REQUIRE(compareMatrix(m, {{1., 0., 5.}, {4., 1., 2.}, {3., 4., 1.}}));
375  }
// Row 1 is now {4, 1, 2}: every column holds a stored entry.
376  REQUIRE(m.sprowlen(1) == 3);
377 
378  {
379  IvocVect v(3);
380  v.vec() = {1, 2, 3};
381  m.setrow(0, &v);
382  REQUIRE(compareMatrix(m, {{1., 2., 3.}, {4., 1., 2.}, {3., 4., 1.}}));
383  }
384  {
385  IvocVect v(3);
386  v.vec() = {1, 2, 3};
387  m.setcol(0, &v);
388  REQUIRE(compareMatrix(m, {{1., 2., 3.}, {2., 1., 2.}, {3., 4., 1.}}));
389  }
390  {
391  IvocVect v(3);
392  v.vec() = {1, 2, 3};
393  m.setdiag(0, &v);
394  REQUIRE(compareMatrix(m, {{1., 2., 3.}, {2., 2., 2.}, {3., 4., 3.}}));
395  }
396  {
// As for the full matrix, a sub-diagonal is indexed by row: v[1] -> (1,0), v[2] -> (2,1).
397  IvocVect v(3);
398  v.vec() = {0., 1., 2.};
399  m.setdiag(-1, &v);
400  REQUIRE(compareMatrix(m, {{1., 2., 3.}, {1., 2., 2.}, {3., 2., 3.}}));
401  }
402  {
// out = m * (1, 2, 3).
403  IvocVect v(3);
404  v.vec() = {1, 2, 3};
405  IvocVect out(3);
406  m.mulv(&v, &out);
407  REQUIRE_THAT(out.vec(), Catch::Matchers::Approx(std::vector<double>({14., 11., 16.})));
408  }
409  {
// Solve m * vout = (1, 1, 1); both values of the boolean flag are expected to agree.
410  IvocVect v(3);
411  v.vec() = {1, 1, 1};
412  IvocVect vout(3);
413  m.solv(&v, &vout, false);
414  REQUIRE_THAT(vout.vec(), Catch::Matchers::Approx(std::vector<double>({0., 0.5, 0.})));
415  m.solv(&v, &vout, true);
416  REQUIRE_THAT(vout.vec(), Catch::Matchers::Approx(std::vector<double>({0., 0.5, 0.})));
417  }
418  }
419 }
std::string describe() const override
Definition: matrix.cpp:42
ApproxOrOpposite & epsilon(T const &newEpsilon)
Definition: matrix.cpp:48
ApproxOrOpposite(std::vector< T > vec)
Definition: matrix.cpp:14
ApproxOrOpposite & scale(T const &newScale)
Definition: matrix.cpp:58
ApproxOrOpposite & margin(T const &newMargin)
Definition: matrix.cpp:53
bool match(std::vector< T > const &in) const override
Definition: matrix.cpp:17
std::vector< T > vec
Definition: matrix.cpp:11
std::vector< double > & vec()
Definition: ivocvect.h:30
virtual int nrow() const
Definition: ocmatrix.h:48
virtual int ncol() const
Definition: ocmatrix.h:52
virtual double getval(int i, int j) const
Definition: ocmatrix.h:44
#define v
Definition: md1redef.h:11
#define i
Definition: md1redef.h:19
static int c
Definition: hoc.cpp:169
int const size_t const size_t n
Definition: nrngsl.h:10
size_t j
s
Definition: multisend.cpp:521
short type
Definition: cabvars.h:10
static double ref(void *v)
Definition: ocbox.cpp:381
static uint32_t value
Definition: scoprand.cpp:25
SCENARIO("A Matrix", "[neuron_ivoc][OcMatrix]")
Definition: matrix.cpp:77
bool compareMatrix(OcMatrix &m, const std::vector< std::vector< double >> &ref)
Definition: matrix.cpp:66