NEURON
ocmatrix.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <memory>
4 #include <utility>
5 #include <vector>
6 
7 #include <Eigen/Eigen>
8 #include <Eigen/Sparse>
9 #include <Eigen/LU>
10 
// Forward declarations: nothing in this header needs the complete
// definitions of these types.
struct Object;
class IvocVect;
class OcMatrix;
class OcFullMatrix;

// Short names used by the matrix interface declared below.
using Matrix = OcMatrix;
using Vect = IvocVect;
17 
18 class OcMatrix {
19  public:
20  enum { MFULL = 1, MSPARSE, MBAND };
21  static OcMatrix* instance(int nrow, int ncol, int type = MFULL);
22  virtual ~OcMatrix() = default;
23 
24  // This function is deprecated and should not be used!
25  // mep stands for 'matrix element pointer'
26  inline double* mep(int i, int j) {
27  return &coeff(i, j);
28  }
29 
30  inline double operator()(int i, int j) const {
31  return getval(i, j);
32  };
33 
34  virtual double& coeff(int i, int j) {
35  static double zero = 0.0;
36  unimp();
37  return zero;
38  }
39 
40  inline double& operator()(int i, int j) {
41  return coeff(i, j);
42  };
43 
44  virtual double getval(int i, int j) const {
45  unimp();
46  return 0.;
47  }
48  virtual int nrow() const {
49  unimp();
50  return 0;
51  }
52  virtual int ncol() const {
53  unimp();
54  return 0;
55  }
56  virtual void resize(int, int) {
57  unimp();
58  }
59 
60  virtual std::vector<std::pair<int, int>> nonzeros() const;
61 
62  OcFullMatrix* full();
63 
64  inline void mulv(Vect& in, Vect& out) const {
65  mulv(&in, &out);
66  };
67  virtual void mulv(Vect* in, Vect* out) const {
68  unimp();
69  }
70  virtual void mulm(Matrix* in, Matrix* out) const {
71  unimp();
72  }
73  virtual void muls(double, Matrix* out) const {
74  unimp();
75  }
76  virtual void add(Matrix*, Matrix* out) const {
77  unimp();
78  }
79  virtual void getrow(int, Vect* out) const {
80  unimp();
81  }
82  virtual void getcol(int, Vect* out) const {
83  unimp();
84  }
85  virtual void getdiag(int, Vect* out) const {
86  unimp();
87  }
88  virtual void setrow(int, Vect* in) {
89  unimp();
90  }
91  virtual void setcol(int, Vect* in) {
92  unimp();
93  }
94  virtual void setdiag(int, Vect* in) {
95  unimp();
96  }
97  virtual void setrow(int, double in) {
98  unimp();
99  }
100  virtual void setcol(int, double in) {
101  unimp();
102  }
103  virtual void setdiag(int, double in) {
104  unimp();
105  }
106  virtual void zero() {
107  unimp();
108  }
109  virtual void ident() {
110  unimp();
111  }
112  virtual void exp(Matrix* out) const {
113  unimp();
114  }
115  virtual void pow(int, Matrix* out) const {
116  unimp();
117  }
118  virtual void inverse(Matrix* out) const {
119  unimp();
120  }
121  virtual void solv(Vect* vin, Vect* vout, bool use_lu) {
122  unimp();
123  }
124  virtual void copy(Matrix* out) const {
125  unimp();
126  }
127  virtual void bcopy(Matrix* mout, int i0, int j0, int n0, int m0, int i1, int j1) const {
128  unimp();
129  }
130  virtual void transpose(Matrix* out) {
131  unimp();
132  }
133  virtual void symmeigen(Matrix* mout, Vect* vout) const {
134  unimp();
135  }
136  virtual void svd1(Matrix* u, Matrix* v, Vect* d) const {
137  unimp();
138  }
139  virtual double det(int* e) const {
140  unimp();
141  return 0.0;
142  }
143  virtual int sprowlen(int) const {
144  unimp();
145  return 0;
146  }
147  virtual double spgetrowval(int i, int jindx, int* j) const {
148  unimp();
149  return 0.;
150  }
151 
152  void unimp() const;
153 
154  protected:
155  OcMatrix(int type);
156 
157  public:
159 
160  private:
161  int type_{};
162 };
163 
164 extern Matrix* matrix_arg(int);
165 
166 class OcFullMatrix final: public OcMatrix { // type 1
167  public:
168  OcFullMatrix(int, int);
169  ~OcFullMatrix() override = default;
170 
171  double& coeff(int, int) override;
172  double getval(int i, int j) const override;
173  int nrow() const override;
174  int ncol() const override;
175  void resize(int, int) override;
176 
177  void mulv(Vect* in, Vect* out) const override;
178  void mulm(Matrix* in, Matrix* out) const override;
179  void muls(double, Matrix* out) const override;
180  void add(Matrix*, Matrix* out) const override;
181  void getrow(int, Vect* out) const override;
182  void getcol(int, Vect* out) const override;
183  void getdiag(int, Vect* out) const override;
184  void setrow(int, Vect* in) override;
185  void setcol(int, Vect* in) override;
186  void setdiag(int, Vect* in) override;
187  void setrow(int, double in) override;
188  void setcol(int, double in) override;
189  void setdiag(int, double in) override;
190  void zero() override;
191  void ident() override;
192  void exp(Matrix* out) const override;
193  void pow(int, Matrix* out) const override;
194  void inverse(Matrix* out) const override;
195  void solv(Vect* vin, Vect* vout, bool use_lu) override;
196  void copy(Matrix* out) const override;
197  void bcopy(Matrix* mout, int i0, int j0, int n0, int m0, int i1, int j1) const override;
198  void transpose(Matrix* out) override;
199  void symmeigen(Matrix* mout, Vect* vout) const override;
200  void svd1(Matrix* u, Matrix* v, Vect* d) const override;
201  double det(int* exponent) const override;
202 
203  private:
204  Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> m_{};
205  std::unique_ptr<Eigen::FullPivLU<decltype(m_)>> lu_{};
206 };
207 
208 class OcSparseMatrix final: public OcMatrix { // type 2
209  public:
210  OcSparseMatrix(int, int);
211  ~OcSparseMatrix() override = default;
212 
213  double& coeff(int, int) override;
214  int nrow() const override;
215  int ncol() const override;
216  double getval(int, int) const override;
217  void ident() override;
218  void mulv(Vect* in, Vect* out) const override;
219  void solv(Vect* vin, Vect* vout, bool use_lu) override;
220 
221  void setrow(int, Vect* in) override;
222  void setcol(int, Vect* in) override;
223  void setdiag(int, Vect* in) override;
224  void setrow(int, double in) override;
225  void setcol(int, double in) override;
226  void setdiag(int, double in) override;
227 
228  std::vector<std::pair<int, int>> nonzeros() const override;
229 
230  int sprowlen(int) const override; // how many elements in row
231  double spgetrowval(int i, int jindx, int* j) const override;
232 
233  void zero() override;
234 
235  private:
236  Eigen::SparseMatrix<double, Eigen::RowMajor> m_{};
237  std::unique_ptr<Eigen::SparseLU<decltype(m_)>> lu_{};
238 };
double getval(int i, int j) const override
Definition: ocmatrix.cpp:70
void exp(Matrix *out) const override
Definition: ocmatrix.cpp:207
void mulm(Matrix *in, Matrix *out) const override
Definition: ocmatrix.cpp:92
void ident() override
Definition: ocmatrix.cpp:203
void solv(Vect *vin, Vect *vout, bool use_lu) override
Definition: ocmatrix.cpp:219
int nrow() const override
Definition: ocmatrix.cpp:73
void getcol(int, Vect *out) const override
Definition: ocmatrix.cpp:145
double & coeff(int, int) override
Definition: ocmatrix.cpp:67
~OcFullMatrix() override=default
void inverse(Matrix *out) const override
Definition: ocmatrix.cpp:215
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > m_
Definition: ocmatrix.h:204
void setrow(int, Vect *in) override
Definition: ocmatrix.cpp:163
void pow(int, Matrix *out) const override
Definition: ocmatrix.cpp:211
void bcopy(Matrix *mout, int i0, int j0, int n0, int m0, int i1, int j1) const override
Definition: ocmatrix.cpp:108
int ncol() const override
Definition: ocmatrix.cpp:76
void mulv(Vect *in, Vect *out) const override
Definition: ocmatrix.cpp:86
std::unique_ptr< Eigen::FullPivLU< decltype(m_)> > lu_
Definition: ocmatrix.h:205
void resize(int, int) override
Definition: ocmatrix.cpp:80
void transpose(Matrix *out) override
Definition: ocmatrix.cpp:112
void symmeigen(Matrix *mout, Vect *vout) const override
Definition: ocmatrix.cpp:121
void add(Matrix *, Matrix *out) const override
Definition: ocmatrix.cpp:100
double det(int *exponent) const override
Definition: ocmatrix.cpp:228
void getrow(int, Vect *out) const override
Definition: ocmatrix.cpp:140
void setcol(int, Vect *in) override
Definition: ocmatrix.cpp:168
void muls(double, Matrix *out) const override
Definition: ocmatrix.cpp:96
void zero() override
Definition: ocmatrix.cpp:199
void getdiag(int, Vect *out) const override
Definition: ocmatrix.cpp:150
void setdiag(int, Vect *in) override
Definition: ocmatrix.cpp:173
void copy(Matrix *out) const override
Definition: ocmatrix.cpp:104
void svd1(Matrix *u, Matrix *v, Vect *d) const override
Definition: ocmatrix.cpp:128
OcFullMatrix(int, int)
Definition: ocmatrix.cpp:60
Object * obj_
Definition: ocmatrix.h:158
virtual void mulv(Vect *in, Vect *out) const
Definition: ocmatrix.h:67
virtual double det(int *e) const
Definition: ocmatrix.h:139
@ MFULL
Definition: ocmatrix.h:20
@ MSPARSE
Definition: ocmatrix.h:20
@ MBAND
Definition: ocmatrix.h:20
virtual void mulm(Matrix *in, Matrix *out) const
Definition: ocmatrix.h:70
virtual void transpose(Matrix *out)
Definition: ocmatrix.h:130
virtual int nrow() const
Definition: ocmatrix.h:48
virtual ~OcMatrix()=default
virtual void ident()
Definition: ocmatrix.h:109
virtual void pow(int, Matrix *out) const
Definition: ocmatrix.h:115
virtual double spgetrowval(int i, int jindx, int *j) const
Definition: ocmatrix.h:147
virtual void setcol(int, Vect *in)
Definition: ocmatrix.h:91
virtual void getcol(int, Vect *out) const
Definition: ocmatrix.h:82
virtual void resize(int, int)
Definition: ocmatrix.h:56
int type_
Definition: ocmatrix.h:161
virtual void setcol(int, double in)
Definition: ocmatrix.h:100
virtual double & coeff(int i, int j)
Definition: ocmatrix.h:34
virtual void svd1(Matrix *u, Matrix *v, Vect *d) const
Definition: ocmatrix.h:136
static OcMatrix * instance(int nrow, int ncol, int type=MFULL)
Definition: ocmatrix.cpp:27
virtual void exp(Matrix *out) const
Definition: ocmatrix.h:112
OcFullMatrix * full()
Definition: ocmatrix.cpp:53
void mulv(Vect &in, Vect &out) const
Definition: ocmatrix.h:64
OcMatrix(int type)
Definition: ocmatrix.cpp:24
virtual int sprowlen(int) const
Definition: ocmatrix.h:143
virtual void setrow(int, double in)
Definition: ocmatrix.h:97
virtual void copy(Matrix *out) const
Definition: ocmatrix.h:124
virtual void setdiag(int, double in)
Definition: ocmatrix.h:103
virtual void setrow(int, Vect *in)
Definition: ocmatrix.h:88
double & operator()(int i, int j)
Definition: ocmatrix.h:40
virtual void add(Matrix *, Matrix *out) const
Definition: ocmatrix.h:76
virtual int ncol() const
Definition: ocmatrix.h:52
double * mep(int i, int j)
Definition: ocmatrix.h:26
virtual std::vector< std::pair< int, int > > nonzeros() const
Definition: ocmatrix.cpp:41
virtual void inverse(Matrix *out) const
Definition: ocmatrix.h:118
virtual void getdiag(int, Vect *out) const
Definition: ocmatrix.h:85
void unimp() const
Definition: ocmatrix.cpp:37
virtual void muls(double, Matrix *out) const
Definition: ocmatrix.h:73
double operator()(int i, int j) const
Definition: ocmatrix.h:30
virtual void solv(Vect *vin, Vect *vout, bool use_lu)
Definition: ocmatrix.h:121
virtual void zero()
Definition: ocmatrix.h:106
virtual void getrow(int, Vect *out) const
Definition: ocmatrix.h:79
virtual double getval(int i, int j) const
Definition: ocmatrix.h:44
virtual void setdiag(int, Vect *in)
Definition: ocmatrix.h:94
virtual void bcopy(Matrix *mout, int i0, int j0, int n0, int m0, int i1, int j1) const
Definition: ocmatrix.h:127
virtual void symmeigen(Matrix *mout, Vect *vout) const
Definition: ocmatrix.h:133
int sprowlen(int) const override
Definition: ocmatrix.cpp:350
std::vector< std::pair< int, int > > nonzeros() const override
Definition: ocmatrix.cpp:370
int nrow() const override
Definition: ocmatrix.cpp:266
~OcSparseMatrix() override=default
Eigen::SparseMatrix< double, Eigen::RowMajor > m_
Definition: ocmatrix.h:236
void mulv(Vect *in, Vect *out) const override
Definition: ocmatrix.cpp:274
double spgetrowval(int i, int jindx, int *j) const override
Definition: ocmatrix.cpp:358
double getval(int, int) const override
Definition: ocmatrix.cpp:262
void setdiag(int, Vect *in) override
Definition: ocmatrix.cpp:304
int ncol() const override
Definition: ocmatrix.cpp:270
OcSparseMatrix(int, int)
Definition: ocmatrix.cpp:246
void setcol(int, Vect *in) override
Definition: ocmatrix.cpp:297
double & coeff(int, int) override
Definition: ocmatrix.cpp:250
std::unique_ptr< Eigen::SparseLU< decltype(m_)> > lu_
Definition: ocmatrix.h:237
void setrow(int, Vect *in) override
Definition: ocmatrix.cpp:290
void solv(Vect *vin, Vect *vout, bool use_lu) override
Definition: ocmatrix.cpp:280
void ident() override
Definition: ocmatrix.cpp:332
void zero() override
Definition: ocmatrix.cpp:254
#define v
Definition: md1redef.h:11
#define i
Definition: md1redef.h:19
fixed_vector< double > IvocVect
Definition: ivocvect.hpp:72
size_t j
short type
Definition: cabvars.h:10
Matrix * matrix_arg(int)
Definition: matrix.cpp:33
Definition: hocdec.h:173