NEURON
sympy_solver.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <catch2/catch_test_macros.hpp>
9 #include <catch2/matchers/catch_matchers_string.hpp>
10 
11 #include <pybind11/embed.h>
12 #include <pybind11/stl.h>
13 
14 #include "ast/program.hpp"
16 #include "parser/nmodl_driver.hpp"
17 #include "utils/test_utils.hpp"
28 
29 
30 using namespace nmodl;
31 using namespace codegen;
32 using namespace visitor;
33 using namespace test;
34 using namespace test_utils;
35 
36 using Catch::Matchers::ContainsSubstring; // ContainsSubstring in newer Catch2
37 
39 
40 using ast::AstNodeType;
42 
43 
44 //=============================================================================
45 // SympySolver visitor tests
46 //=============================================================================
47 
/**
 * \brief Parse an NMODL snippet, run SympySolverVisitor on it and return selected nodes as text
 *
 * \param text NMODL source handed to `driver.parse_string`
 * \param pade forwarded as first constructor argument of SympySolverVisitor
 * \param cse forwarded as second constructor argument of SympySolverVisitor
 * \param ret_nodetype type of the AST nodes that are collected after the solver pass
 * \param kinetic guards an additional step that runs before the solver pass
 * \return `to_nmodl()` of every collected node, in the order returned by `collect_nodes`
 *
 * NOTE(review): several statements are not visible in this listing (the declaration of
 * `driver` and the bodies of the symbol-table, unroll/fold, kinetic and parent-check steps
 * announced by the comments below) - confirm them against the real file.
 */
48 std::vector<std::string> run_sympy_solver_visitor(
49  const std::string& text,
50  bool pade = false,
51  bool cse = false,
52  AstNodeType ret_nodetype = AstNodeType::DIFF_EQ_EXPRESSION,
53  bool kinetic = false) {
54  std::vector<std::string> results;
55 
56  // construct AST from text
58  const auto& ast = driver.parse_string(text);
59 
60  // construct symbol table from AST
62 
63  // unroll loops and fold constants
68 
69  if (kinetic) {
71  }
72 
73  // run SympySolver on AST
74  SympySolverVisitor(pade, cse).visit_program(*ast);
75 
76  // check that, after visitor rearrangement, parents are still up-to-date
78 
79  // run lookup visitor to extract results from AST
 // every node of type ret_nodetype is printed back as NMODL text
80  for (const auto& eq: collect_nodes(*ast, {ret_nodetype})) {
81  results.push_back(to_nmodl(eq));
82  }
83 
84  return results;
85 }
86 
/**
 * \brief Check that a list of variable names (like the text of a LOCAL statement) has no
 * duplicates
 *
 * Commas and whitespace both act as separators, so "LOCAL a, b" is split into the tokens
 * "LOCAL", "a" and "b". Empty tokens are never produced, hence repeated, leading or trailing
 * blanks cannot be mistaken for a duplicated name.
 *
 * \param result Text with the names, separated by commas and/or whitespace
 * \return true if every token occurs exactly once (also for empty input), false as soon as a
 *         token is seen a second time
 */
bool is_unique_vars(std::string result) {
    // turn commas into blanks instead of deleting them: "a,a" must be read as two names and
    // not glued together into the single token "aa"
    std::replace(result.begin(), result.end(), ',', ' ');
    std::istringstream names(result);

    std::unordered_set<std::string> seen;

    // operator>> skips runs of whitespace, so no empty token is ever extracted
    for (std::string name; names >> name;) {
        if (!seen.insert(name).second) {
            return false;
        }
    }
    return true;
}
102 
103 
104 /**
105  * \brief Compare nmodl blocks that contain systems of equations (i.e. derivative, linear, etc.)
106  *
107  * This is basically and advanced string == string comparison where we detect the (various) systems
108  * of equations and check if they are equivalent. Implemented mostly in python since we need a call
109  * to sympy to simplify the equations.
110  *
111  * - compare_systems_of_eq The core of the code. \p result_dict and \p expected_dict are
112  * dictionaries that represent the systems of equations in this way:
113  *
114  * a = b*x + c -> result_dict['a'] = 'b*x + c'
115  *
116  * where the variable \p a become a key \p k of the dictionary.
117  *
118  * In there we go over all the equations in \p result_dict and \p expected_dict and check that
119  * result_dict[k] - expected_dict[k] simplifies to 0.
120  *
121  * - sanitize is to transform the equations in something treatable by sympy (i.e. pow(dt, 3) ->
122  * dt**3
123  * - reduce back-substitution of the temporary variables
124  *
125  * \p require_fail requires that the equations are different. Used only for unit-test this function
126  *
127  * \warning do not use this method when there are tmp variables not in the form: tmp_<number>
128  */
129 void compare_blocks(const std::string& result,
130  const std::string& expected,
131  const bool require_fail = false) {
132  using namespace pybind11::literals;
133 
134  auto locals =
135  pybind11::dict("result"_a = result, "expected"_a = expected, "is_equal"_a = false);
136  pybind11::exec(R"(
137  # Comments are in the doxygen for better highlighting
138  def compare_blocks(result, expected):
139 
140  def sanitize(s):
141  import re
142  d = {'\[(\d+)\]':'_\\1', 'pow\‍((\w+), ?(\d+)\)':'\\1**\\2', 'beta': 'beta_var', 'gamma': 'gamma_var'}
143  out = s
144  for key, val in d.items():
145  out = re.sub(key, val, out)
146  return out
147 
148  def compare_systems_of_eq(result_dict, expected_dict):
149  from sympy.parsing.sympy_parser import parse_expr
150  try:
151  for k, v in result_dict.items():
152  if parse_expr(f'simplify(({v})-({expected_dict[k]}))'):
153  return False
154  except KeyError:
155  return False
156 
157  result_dict.clear()
158  expected_dict.clear()
159  return True
160 
161  def reduce(s):
162  max_tmp = -1
163  d = {}
164 
165  sout = ""
166  # split of sout and a dict with the tmp variables
167  for line in s.split('\n'):
168  line_split = line.lstrip().split('=')
169 
170  if len(line_split) == 2 and line_split[0].startswith('tmp_'):
171  # back-substitution of tmp variables in tmp variables
172  tmp_var = line_split[0].strip()
173  if tmp_var in d:
174  continue
175 
176  max_tmp = max(max_tmp, int(tmp_var[4:]))
177  for k, v in d.items():
178  line_split[1] = line_split[1].replace(k, f'({v})')
179  d[tmp_var] = line_split[1]
180  elif 'LOCAL' in line:
181  sout += line.split('tmp_0')[0] + '\n'
182  else:
183  sout += line + '\n'
184 
185  # Back-substitution of the tmps
186  # so that we do not replace tmp_11 with (tmp_1)1
187  for j in range(max_tmp, -1, -1):
188  k = f'tmp_{j}'
189  sout = sout.replace(k, f'({d[k]})')
190 
191  return sout
192 
193  result = reduce(sanitize(result)).split('\n')
194  expected = reduce(sanitize(expected)).split('\n')
195 
196  if len(result) != len(expected):
197  return False
198 
199  result_dict = {}
200  expected_dict = {}
201  for token1, token2 in zip(result, expected):
202  if token1 == token2:
203  if not compare_systems_of_eq(result_dict, expected_dict):
204  return False
205  continue
206 
207  eq1 = token1.split('=')
208  eq2 = token2.split('=')
209  if len(eq1) == 2 and len(eq2) == 2:
210  result_dict[eq1[0]] = eq1[1]
211  expected_dict[eq2[0]] = eq2[1]
212  continue
213 
214  return False
215  return compare_systems_of_eq(result_dict, expected_dict)
216 
217  is_equal = compare_blocks(result, expected))",
218  pybind11::globals(),
219  locals);
220 
221  // Error log
222  if (require_fail == locals["is_equal"].cast<bool>()) {
223  if (require_fail) {
224  REQUIRE(result != expected);
225  } else {
226  REQUIRE(result == expected);
227  }
228  } else { // so that we signal to ctest that an assert was performed
229  REQUIRE(true);
230  }
231 }
232 
233 
235  // construct symbol table from AST
236  SymtabVisitor v_symtab;
237  v_symtab.visit_program(node);
238 
239  // run SympySolver on AST several times
240  SympySolverVisitor v_sympy1;
241  v_sympy1.visit_program(node);
242  v_sympy1.visit_program(node);
243 
244  // also use a second instance of SympySolver
245  SympySolverVisitor v_sympy2;
246  v_sympy2.visit_program(node);
247  v_sympy1.visit_program(node);
248  v_sympy2.visit_program(node);
249 }
250 
251 
253  std::stringstream stream;
255  return stream.str();
256 }
257 
// Self-test of the compare_blocks() helper: the first four GIVENs pass inputs that must be
// accepted as equal, the last two pass require_fail = true and therefore expect the helper
// to report the blocks as different.
258 SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]") {
259  GIVEN("Empty strings") {
260  THEN("Strings are equal") {
261  compare_blocks("", "");
262  }
263  }
264  GIVEN("Equivalent equation") {
265  THEN("Strings are equal") {
266  compare_blocks("a = 3*b + c", "a = 2*b + b + c");
267  }
268  }
269  GIVEN("Equivalent systems of equations") {
270  std::string result = R"(
271  x = 3*b + c
272  y = 2*a + b)";
273  std::string expected = R"(
274  x = b+2*b + c
275  y = 2*a + 2*b-b)";
276  THEN("Systems of equations are equal") {
277  compare_blocks(result, expected);
278  }
279  }
 // covers the array index (A[0]), pow(...) and tmp_<number> handling in a single block
280  GIVEN("Equivalent systems of equations with brackets") {
281  std::string result = R"(
282  DERIVATIVE {
283  A[0] = 3*b + c
284  y = pow(a, 3) + b
285  })";
286  std::string expected = R"(
287  DERIVATIVE {
288  tmp_0 = a + c
289  tmp_1 = tmp_0 - a
290  A[0] = b+2*b + tmp_1
291  y = pow(a, 2)*a + 2*b-b
292  })";
293  THEN("Blocks are equal") {
294  compare_blocks(result, expected);
295  }
296  }
 // NOTE(review): per the title the two blocks should differ by an additional space; that
 // space cannot be told apart in this listing - confirm against the real file.
297  GIVEN("Different systems of equations (additional space)") {
298  std::string result = R"(
299  DERIVATIVE {
300  x = 3*b + c
301  y = 2*a + b
302  })";
303  std::string expected = R"(
304  DERIVATIVE {
305  x = b+2*b + c
306  y = 2*a + 2*b-b
307  })";
308  THEN("Blocks are different") {
309  compare_blocks(result, expected, true);
310  }
311  }
 // after substituting the temporaries x is 3*b - c, while the expected x is 3*b + c
312  GIVEN("Different systems of equations") {
313  std::string result = R"(
314  DERIVATIVE {
315  tmp_0 = a - c
316  tmp_1 = tmp_0 - a
317  x = 3*b + tmp_1
318  y = 2*a + b
319  })";
320  std::string expected = R"(
321  DERIVATIVE {
322  x = b+2*b + c
323  y = 2*a + 2*b-b
324  })";
325  THEN("Blocks are different") {
326  compare_blocks(result, expected, true);
327  }
328  }
329 }
330 
// The DERIVATIVE blocks below already declare a LOCAL named `tmp` / `tmp_0`. The solver is
// run with pade = true and cse = true and the LOCAL_LIST_STATEMENT nodes are collected; the
// first of them must not list any name twice.
331 SCENARIO("Check local vars name-clash prevention", "[visitor][sympy]") {
332  GIVEN("LOCAL tmp") {
333  std::string nmodl_text = R"(
334  STATE {
335  x y
336  }
337  BREAKPOINT {
338  SOLVE states METHOD sparse
339  }
340  DERIVATIVE states {
341  LOCAL tmp, b
342  x' = tmp + b
343  y' = tmp + b
344  })";
345  THEN("There are no duplicate vars in LOCAL") {
346  auto result =
347  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::LOCAL_LIST_STATEMENT);
 // guard the result[0] access below
348  REQUIRE(!result.empty());
349  REQUIRE(is_unique_vars(result[0]));
350  }
351  }
 // same check with a user variable called tmp_0
352  GIVEN("LOCAL tmp_0") {
353  std::string nmodl_text = R"(
354  STATE {
355  x y
356  }
357  BREAKPOINT {
358  SOLVE states METHOD sparse
359  }
360  DERIVATIVE states {
361  LOCAL tmp_0, b
362  x' = tmp_0 + b
363  y' = tmp_0 + b
364  })";
365  THEN("There are no duplicate vars in LOCAL") {
366  auto result =
367  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::LOCAL_LIST_STATEMENT);
 // guard the result[0] access below
368  REQUIRE(!result.empty());
369  REQUIRE(is_unique_vars(result[0]));
370  }
371  }
372 }
373 
// ODEs in DERIVATIVE blocks solved with METHOD cnexp or euler. Each THEN checks the number of
// returned statements and compares every statement verbatim against the expected string.
// NOTE(review): most THEN sections read a `result` vector whose declaration is not visible in
// this listing - presumably the return value of run_sympy_solver_visitor(nmodl_text); confirm
// against the real file.
374 SCENARIO("Solve ODEs with cnexp or euler method using SympySolverVisitor",
375  "[visitor][sympy][cnexp][euler]") {
376  GIVEN("Derivative block without ODE, solver method cnexp") {
377  std::string nmodl_text = R"(
378  BREAKPOINT {
379  SOLVE states METHOD cnexp
380  }
381  DERIVATIVE states {
382  m = m + h
383  }
384  )";
385  THEN("No ODEs found - do nothing") {
387  REQUIRE(result.empty());
388  }
389  }
390  GIVEN("Derivative block with ODES, solver method is euler") {
391  std::string nmodl_text = R"(
392  BREAKPOINT {
393  SOLVE states METHOD euler
394  }
395  DERIVATIVE states {
396  m' = (mInf-m)/mTau
397  h' = (hInf-h)/hTau
398  z = a*b + c
399  }
400  )";
401  THEN("Construct forwards Euler solutions") {
 // two results are expected for the two primed equations
403  REQUIRE(result.size() == 2);
404  REQUIRE(result[0] == "m = (-dt*(m-mInf)+m*mTau)/mTau");
405  REQUIRE(result[1] == "h = (-dt*(h-hInf)+h*hTau)/hTau");
406  }
407  }
408  GIVEN("Derivative block with calling external functions passes sympy") {
409  std::string nmodl_text = R"(
410  BREAKPOINT {
411  SOLVE states METHOD euler
412  }
413  DERIVATIVE states {
414  m' = sawtooth(m)
415  n' = sin(n)
416  p' = my_user_func(p)
417  }
418  )";
419  THEN("Construct forward Euler interpreting external functions as symbols") {
421  REQUIRE(result.size() == 3);
422  REQUIRE(result[0] == "m = dt*sawtooth(m)+m");
423  REQUIRE(result[1] == "n = dt*sin(n)+n");
424  REQUIRE(result[2] == "p = dt*my_user_func(p)+p");
425  }
426  }
427  GIVEN("Derivative block with ODE, 1 state var in array, solver method euler") {
428  std::string nmodl_text = R"(
429  STATE {
430  m[1]
431  }
432  BREAKPOINT {
433  SOLVE states METHOD euler
434  }
435  DERIVATIVE states {
436  m'[0] = (mInf-m[0])/mTau
437  }
438  )";
439  THEN("Construct forwards Euler solutions") {
441  REQUIRE(result.size() == 1);
442  REQUIRE(result[0] == "m[0] = (dt*(mInf-m[0])+mTau*m[0])/mTau");
443  }
444  }
445  GIVEN("Derivative block with ODE, 1 state var in array, solver method cnexp") {
446  std::string nmodl_text = R"(
447  STATE {
448  m[1]
449  }
450  BREAKPOINT {
451  SOLVE states METHOD cnexp
452  }
453  DERIVATIVE states {
454  m'[0] = (mInf-m[0])/mTau
455  }
456  )";
 // NOTE(review): the THEN title mentions forwards Euler, but this GIVEN uses cnexp and the
 // expected statement has the exponential form - the title looks copy-pasted.
457  THEN("Construct forwards Euler solutions") {
459  REQUIRE(result.size() == 1);
460  REQUIRE(result[0] == "m[0] = mInf-(mInf-m[0])*exp(-dt/mTau)");
461  }
462  }
463  GIVEN("Derivative block with linear ODES, solver method cnexp") {
464  std::string nmodl_text = R"(
465  BREAKPOINT {
466  SOLVE states METHOD cnexp
467  }
468  DERIVATIVE states {
469  m' = (mInf-m)/mTau
470  z = a*b + c
471  h' = hInf/hTau - h/hTau
472  }
473  )";
474  THEN("Integrate equations analytically") {
476  REQUIRE(result.size() == 2);
477  REQUIRE(result[0] == "m = mInf-(-m+mInf)*exp(-dt/mTau)");
478  REQUIRE(result[1] == "h = hInf-(-h+hInf)*exp(-dt/hTau)");
479  }
480  }
481  GIVEN("Derivative block including non-linear but solvable ODES, solver method cnexp") {
482  std::string nmodl_text = R"(
483  BREAKPOINT {
484  SOLVE states METHOD cnexp
485  }
486  DERIVATIVE states {
487  m' = (mInf-m)/mTau
488  h' = c2 * h*h
489  }
490  )";
491  THEN("Integrate equations analytically") {
493  REQUIRE(result.size() == 2);
494  REQUIRE(result[0] == "m = mInf-(-m+mInf)*exp(-dt/mTau)");
495  REQUIRE(result[1] == "h = -h/(c2*dt*h-1.0)");
496  }
497  }
498  GIVEN("Derivative block including array of 2 state vars, solver method cnexp") {
499  std::string nmodl_text = R"(
500  BREAKPOINT {
501  SOLVE states METHOD cnexp
502  }
503  STATE {
504  X[2]
505  }
506  DERIVATIVE states {
507  X'[0] = (mInf-X[0])/mTau
508  X'[1] = c2 * X[1]*X[1]
509  }
510  )";
511  THEN("Integrate equations analytically") {
513  REQUIRE(result.size() == 2);
514  REQUIRE(result[0] == "X[0] = mInf-(mInf-X[0])*exp(-dt/mTau)");
515  REQUIRE(result[1] == "X[1] = -X[1]/(c2*dt*X[1]-1.0)");
516  }
517  }
 // the FROM loop over i is expected to yield one statement per array element (N = 3)
518  GIVEN("Derivative block including loop over array vars, solver method cnexp") {
519  std::string nmodl_text = R"(
520  DEFINE N 3
521  BREAKPOINT {
522  SOLVE states METHOD cnexp
523  }
524  ASSIGNED {
525  mTau[N]
526  }
527  STATE {
528  X[N]
529  }
530  DERIVATIVE states {
531  FROM i=0 TO N-1 {
532  X'[i] = (mInf-X[i])/mTau[i]
533  }
534  }
535  )";
536  THEN("Integrate equations analytically") {
538  REQUIRE(result.size() == 3);
539  REQUIRE(result[0] == "X[0] = mInf-(mInf-X[0])*exp(-dt/mTau[0])");
540  REQUIRE(result[1] == "X[1] = mInf-(mInf-X[1])*exp(-dt/mTau[1])");
541  REQUIRE(result[2] == "X[2] = mInf-(mInf-X[2])*exp(-dt/mTau[2])");
542  }
543  }
544  GIVEN("Derivative block including loop over array vars, solver method euler") {
545  std::string nmodl_text = R"(
546  DEFINE N 3
547  BREAKPOINT {
548  SOLVE states METHOD euler
549  }
550  ASSIGNED {
551  mTau[N]
552  }
553  STATE {
554  X[N]
555  }
556  DERIVATIVE states {
557  FROM i=0 TO N-1 {
558  X'[i] = (mInf-X[i])/mTau[i]
559  }
560  }
561  )";
 // NOTE(review): the THEN title says "analytically", but the method is euler and the
 // expected statements contain no exp() term - the title looks copy-pasted.
562  THEN("Integrate equations analytically") {
564  REQUIRE(result.size() == 3);
565  REQUIRE(result[0] == "X[0] = (dt*(mInf-X[0])+X[0]*mTau[0])/mTau[0]");
566  REQUIRE(result[1] == "X[1] = (dt*(mInf-X[1])+X[1]*mTau[1])/mTau[1]");
567  REQUIRE(result[2] == "X[2] = (dt*(mInf-X[2])+X[2]*mTau[2])/mTau[2]");
568  }
569  }
570  GIVEN("Derivative block including ODES that can't currently be solved, solver method cnexp") {
571  std::string nmodl_text = R"(
572  BREAKPOINT {
573  SOLVE states METHOD cnexp
574  }
575  DERIVATIVE states {
576  z' = a/z + b/z/z
577  h' = c2 * h*h
578  x' = a
579  y' = c3 * y*y*y
580  }
581  )";
582  THEN("Integrate equations analytically where possible, otherwise leave untouched") {
584  REQUIRE(result.size() == 4);
 // for z and y two outcomes are accepted: the untouched ODE or a solved statement
585  /// sympy 1.9 able to solve ode but not older versions
586  REQUIRE((result[0] == "z' = a/z+b/z/z" ||
587  result[0] ==
588  "z = (0.5*pow(a, 2)*pow(z, 2)-a*b*z+pow(b, 2)*log(a*z+b))/pow(a, 3)"));
589  REQUIRE(result[1] == "h = -h/(c2*dt*h-1.0)");
590  REQUIRE(result[2] == "x = a*dt+x");
591  /// sympy 1.4 able to solve ode but not older versions
592  REQUIRE((result[3] == "y' = c3*y*y*y" ||
593  result[3] == "y = sqrt(-pow(y, 2)/(2.0*c3*dt*pow(y, 2)-1.0))"));
594  }
595  }
596  GIVEN("Derivative block with cnexp solver method, AST after SympySolver pass") {
597  std::string nmodl_text = R"(
598  BREAKPOINT {
599  SOLVE states METHOD cnexp
600  }
601  DERIVATIVE states {
602  m' = (mInf-m)/mTau
603  }
604  )";
 // NOTE(review): the declaration of `driver` and the statements for the symbol-table and
 // solver steps announced below are not visible in this listing - confirm against the file.
605  // construct AST from text
607  auto ast = driver.parse_string(nmodl_text);
608 
609  // construct symbol table from AST
611 
612  // run SympySolver on AST
614 
 // textual snapshot of the AST, compared again after the additional passes
615  std::string AST_string = ast_to_string(*ast);
616 
617  THEN("More SympySolver passes do nothing to the AST and don't throw") {
618  REQUIRE_NOTHROW(run_sympy_visitor_passes(*ast));
619  REQUIRE(AST_string == ast_to_string(*ast));
620  }
621  }
622 }
623 
624 SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
625  "[visitor][sympy][derivimplicit]") {
626  GIVEN("Derivative block with derivimplicit solver method and conditional block") {
627  std::string nmodl_text = R"(
628  STATE {
629  m
630  }
631  BREAKPOINT {
632  SOLVE states METHOD derivimplicit
633  }
634  DERIVATIVE states {
635  IF (mInf == 1) {
636  mInf = mInf+1
637  }
638  m' = (mInf-m)/mTau
639  }
640  )";
641  std::string expected_result = R"(
642  DERIVATIVE states {
643  EIGEN_NEWTON_SOLVE[1]{
644  LOCAL old_m
645  }{
646  IF (mInf == 1) {
647  mInf = mInf+1
648  }
649  old_m = m
650  }{
651  nmodl_eigen_x[0] = m
652  }{
653  nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau)
654  nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau)
655  }{
656  m = nmodl_eigen_x[0]
657  }{
658  }
659  })";
660  THEN("SympySolver correctly inserts ode to block") {
661  CAPTURE(nmodl_text);
662  auto result =
663  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
664  compare_blocks(result[0], reindent_text(expected_result));
665  }
666  }
667 
668  GIVEN("Derivative block, sparse, print in order") {
669  std::string nmodl_text = R"(
670  STATE {
671  x y
672  }
673  BREAKPOINT {
674  SOLVE states METHOD sparse
675  }
676  DERIVATIVE states {
677  LOCAL a, b
678  y' = a
679  x' = b
680  })";
681  std::string expected_result = R"(
682  DERIVATIVE states {
683  EIGEN_NEWTON_SOLVE[2]{
684  LOCAL a, b, old_y, old_x
685  }{
686  old_y = y
687  old_x = x
688  }{
689  nmodl_eigen_x[0] = x
690  nmodl_eigen_x[1] = y
691  }{
692  nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt
693  nmodl_eigen_j[0] = 0
694  nmodl_eigen_j[2] = -1/dt
695  nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt
696  nmodl_eigen_j[1] = -1/dt
697  nmodl_eigen_j[3] = 0
698  }{
699  x = nmodl_eigen_x[0]
700  y = nmodl_eigen_x[1]
701  }{
702  }
703  })";
704 
705  THEN("Construct & solve linear system for backwards Euler") {
706  auto result =
707  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
708 
709  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
710  }
711  }
712  GIVEN("Derivative block, sparse, print in order, vectors") {
713  std::string nmodl_text = R"(
714  STATE {
715  M[2]
716  }
717  BREAKPOINT {
718  SOLVE states METHOD sparse
719  }
720  DERIVATIVE states {
721  LOCAL a, b
722  M'[1] = a
723  M'[0] = b
724  })";
725  std::string expected_result = R"(
726  DERIVATIVE states {
727  EIGEN_NEWTON_SOLVE[2]{
728  LOCAL a, b, old_M_1, old_M_0
729  }{
730  old_M_1 = M[1]
731  old_M_0 = M[0]
732  }{
733  nmodl_eigen_x[0] = M[0]
734  nmodl_eigen_x[1] = M[1]
735  }{
736  nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt
737  nmodl_eigen_j[0] = 0
738  nmodl_eigen_j[2] = -1/dt
739  nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt
740  nmodl_eigen_j[1] = -1/dt
741  nmodl_eigen_j[3] = 0
742  }{
743  M[0] = nmodl_eigen_x[0]
744  M[1] = nmodl_eigen_x[1]
745  }{
746  }
747  })";
748 
749  THEN("Construct & solve linear system for backwards Euler") {
750  auto result =
751  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
752 
753  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
754  }
755  }
756  GIVEN("Derivative block, sparse, derivatives mixed with local variable reassignment") {
757  std::string nmodl_text = R"(
758  STATE {
759  x y
760  }
761  BREAKPOINT {
762  SOLVE states METHOD sparse
763  }
764  DERIVATIVE states {
765  LOCAL a, b
766  x' = a
767  b = b + 1
768  y' = b
769  })";
770  std::string expected_result = R"(
771  DERIVATIVE states {
772  EIGEN_NEWTON_SOLVE[2]{
773  LOCAL a, b, old_x, old_y
774  }{
775  old_x = x
776  old_y = y
777  }{
778  nmodl_eigen_x[0] = x
779  nmodl_eigen_x[1] = y
780  }{
781  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
782  nmodl_eigen_j[0] = -1/dt
783  nmodl_eigen_j[2] = 0
784  b = b+1
785  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
786  nmodl_eigen_j[1] = 0
787  nmodl_eigen_j[3] = -1/dt
788  }{
789  x = nmodl_eigen_x[0]
790  y = nmodl_eigen_x[1]
791  }{
792  }
793  })";
794 
795  THEN("Construct & solve linear system for backwards Euler") {
796  auto result =
797  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
798 
799  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
800  }
801  }
802  GIVEN(
803  "Throw exception during derivative variable reassignment interleaved in the differential "
804  "equation set") {
805  std::string nmodl_text = R"(
806  STATE {
807  x y
808  }
809  BREAKPOINT {
810  SOLVE states METHOD sparse
811  }
812  DERIVATIVE states {
813  LOCAL a, b
814  x' = a
815  x = x + 1
816  y' = b + x
817  })";
818 
819  THEN(
820  "Throw an error because state variable assignments are not allowed inside the system "
821  "of differential "
822  "equations") {
823  REQUIRE_THROWS_WITH(
824  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK),
825  Catch::Matchers::ContainsSubstring(
826  "State variable assignment(s) interleaved in system of "
827  "equations/differential equations") &&
828  Catch::Matchers::StartsWith("SympyReplaceSolutionsVisitor"));
829  }
830  }
831  GIVEN("Derivative block in control flow block") {
832  std::string nmodl_text = R"(
833  STATE {
834  x y
835  }
836  BREAKPOINT {
837  SOLVE states METHOD sparse
838  }
839  DERIVATIVE states {
840  LOCAL a, b
841  if (a == 1) {
842  x' = a
843  y' = b
844  }
845  })";
846  std::string expected_result = R"(
847  DERIVATIVE states {
848  LOCAL a, b
849  IF (a == 1) {
850  EIGEN_NEWTON_SOLVE[2]{
851  LOCAL old_x, old_y
852  }{
853  old_x = x
854  old_y = y
855  }{
856  nmodl_eigen_x[0] = x
857  nmodl_eigen_x[1] = y
858  }{
859  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
860  nmodl_eigen_j[0] = -1/dt
861  nmodl_eigen_j[2] = 0
862  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
863  nmodl_eigen_j[1] = 0
864  nmodl_eigen_j[3] = -1/dt
865  }{
866  x = nmodl_eigen_x[0]
867  y = nmodl_eigen_x[1]
868  }{
869  }
870  }
871  })";
872 
873  THEN("Construct & solve linear system for backwards Euler") {
874  auto result =
875  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
876 
877  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
878  }
879  }
880  GIVEN(
881  "Derivative block, sparse, coupled derivatives mixed with reassignment and control flow "
882  "block") {
883  std::string nmodl_text = R"(
884  STATE {
885  x y
886  }
887  BREAKPOINT {
888  SOLVE states METHOD sparse
889  }
890  DERIVATIVE states {
891  LOCAL a, b
892  x' = a * y+b
893  if (b == 1) {
894  a = a + 1
895  }
896  y' = x + a*y
897  })";
898  std::string expected_result = R"(
899  DERIVATIVE states {
900  EIGEN_NEWTON_SOLVE[2]{
901  LOCAL a, b, old_x, old_y
902  }{
903  old_x = x
904  old_y = y
905  }{
906  nmodl_eigen_x[0] = x
907  nmodl_eigen_x[1] = y
908  }{
909  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
910  nmodl_eigen_j[0] = -1/dt
911  nmodl_eigen_j[2] = a
912  IF (b == 1) {
913  a = a+1
914  }
915  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
916  nmodl_eigen_j[1] = 1.0
917  nmodl_eigen_j[3] = a-1/dt
918  }{
919  x = nmodl_eigen_x[0]
920  y = nmodl_eigen_x[1]
921  }{
922  }
923  })";
924  std::string expected_result_cse = R"(
925  DERIVATIVE states {
926  EIGEN_NEWTON_SOLVE[2]{
927  LOCAL a, b, old_x, old_y
928  }{
929  old_x = x
930  old_y = y
931  }{
932  nmodl_eigen_x[0] = x
933  nmodl_eigen_x[1] = y
934  }{
935  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
936  nmodl_eigen_j[0] = -1/dt
937  nmodl_eigen_j[2] = a
938  IF (b == 1) {
939  a = a+1
940  }
941  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
942  nmodl_eigen_j[1] = 1.0
943  nmodl_eigen_j[3] = a-1/dt
944  }{
945  x = nmodl_eigen_x[0]
946  y = nmodl_eigen_x[1]
947  }{
948  }
949  })";
950 
951  THEN("Construct & solve linear system for backwards Euler") {
952  auto result =
953  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
954  auto result_cse =
955  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::DERIVATIVE_BLOCK);
956 
957  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
958  compare_blocks(reindent_text(result_cse[0]), reindent_text(expected_result_cse));
959  }
960  }
961 
962  GIVEN("Derivative block of coupled & linear ODES, solver method sparse") {
963  std::string nmodl_text = R"(
964  STATE {
965  x y z
966  }
967  BREAKPOINT {
968  SOLVE states METHOD sparse
969  }
970  DERIVATIVE states {
971  LOCAL a, b, c, d, h
972  x' = a*z + b*h
973  y' = c + 2*x
974  z' = d*z - y
975  }
976  )";
977  std::string expected_result = R"(
978  DERIVATIVE states {
979  EIGEN_NEWTON_SOLVE[3]{
980  LOCAL a, b, c, d, h, old_x, old_y, old_z
981  }{
982  old_x = x
983  old_y = y
984  old_z = z
985  }{
986  nmodl_eigen_x[0] = x
987  nmodl_eigen_x[1] = y
988  nmodl_eigen_x[2] = z
989  }{
990  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
991  nmodl_eigen_j[0] = -1/dt
992  nmodl_eigen_j[3] = 0
993  nmodl_eigen_j[6] = a
994  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
995  nmodl_eigen_j[1] = 2.0
996  nmodl_eigen_j[4] = -1/dt
997  nmodl_eigen_j[7] = 0
998  nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
999  nmodl_eigen_j[2] = 0
1000  nmodl_eigen_j[5] = -1.0
1001  nmodl_eigen_j[8] = d-1/dt
1002  }{
1003  x = nmodl_eigen_x[0]
1004  y = nmodl_eigen_x[1]
1005  z = nmodl_eigen_x[2]
1006  }{
1007  }
1008  })";
1009  std::string expected_cse_result = R"(
1010  DERIVATIVE states {
1011  EIGEN_NEWTON_SOLVE[3]{
1012  LOCAL a, b, c, d, h, old_x, old_y, old_z
1013  }{
1014  old_x = x
1015  old_y = y
1016  old_z = z
1017  }{
1018  nmodl_eigen_x[0] = x
1019  nmodl_eigen_x[1] = y
1020  nmodl_eigen_x[2] = z
1021  }{
1022  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1023  nmodl_eigen_j[0] = -1/dt
1024  nmodl_eigen_j[3] = 0
1025  nmodl_eigen_j[6] = a
1026  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1027  nmodl_eigen_j[1] = 2.0
1028  nmodl_eigen_j[4] = -1/dt
1029  nmodl_eigen_j[7] = 0
1030  nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1031  nmodl_eigen_j[2] = 0
1032  nmodl_eigen_j[5] = -1.0
1033  nmodl_eigen_j[8] = d-1/dt
1034  }{
1035  x = nmodl_eigen_x[0]
1036  y = nmodl_eigen_x[1]
1037  z = nmodl_eigen_x[2]
1038  }{
1039  }
1040  })";
1041 
1042  THEN("Construct & solve linear system for backwards Euler") {
1043  auto result =
1044  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1045  auto result_cse =
1046  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::DERIVATIVE_BLOCK);
1047 
1048  compare_blocks(result[0], reindent_text(expected_result));
1049  compare_blocks(result_cse[0], reindent_text(expected_cse_result));
1050  }
1051  }
1052  GIVEN("Derivative block including ODES with sparse method (from nmodl paper)") {
1053  std::string nmodl_text = R"(
1054  STATE {
1055  mc m
1056  }
1057  BREAKPOINT {
1058  SOLVE scheme1 METHOD sparse
1059  }
1060  DERIVATIVE scheme1 {
1061  mc' = -a*mc + b*m
1062  m' = a*mc - b*m
1063  }
1064  )";
1065  std::string expected_result = R"(
1066  DERIVATIVE scheme1 {
1067  EIGEN_NEWTON_SOLVE[2]{
1068  LOCAL old_mc, old_m
1069  }{
1070  old_mc = mc
1071  old_m = m
1072  }{
1073  nmodl_eigen_x[0] = mc
1074  nmodl_eigen_x[1] = m
1075  }{
1076  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1077  nmodl_eigen_j[0] = -a-1/dt
1078  nmodl_eigen_j[2] = b
1079  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1080  nmodl_eigen_j[1] = a
1081  nmodl_eigen_j[3] = -b-1/dt
1082  }{
1083  mc = nmodl_eigen_x[0]
1084  m = nmodl_eigen_x[1]
1085  }{
1086  }
1087  })";
1088  THEN("Construct & solve linear system") {
1089  CAPTURE(nmodl_text);
1090  auto result =
1091  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1092  compare_blocks(result[0], reindent_text(expected_result));
1093  }
1094  }
1095  GIVEN("Derivative block with ODES with sparse method, CONSERVE statement of form m = ...") {
1096  std::string nmodl_text = R"(
1097  STATE {
1098  mc m
1099  }
1100  BREAKPOINT {
1101  SOLVE scheme1 METHOD sparse
1102  }
1103  DERIVATIVE scheme1 {
1104  mc' = -a*mc + b*m
1105  m' = a*mc - b*m
1106  CONSERVE m = 1 - mc
1107  }
1108  )";
1109  std::string expected_result = R"(
1110  DERIVATIVE scheme1 {
1111  EIGEN_NEWTON_SOLVE[2]{
1112  LOCAL old_mc
1113  }{
1114  old_mc = mc
1115  }{
1116  nmodl_eigen_x[0] = mc
1117  nmodl_eigen_x[1] = m
1118  }{
1119  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1120  nmodl_eigen_j[0] = -a-1/dt
1121  nmodl_eigen_j[2] = b
1122  nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0
1123  nmodl_eigen_j[1] = -1.0
1124  nmodl_eigen_j[3] = -1.0
1125  }{
1126  mc = nmodl_eigen_x[0]
1127  m = nmodl_eigen_x[1]
1128  }{
1129  }
1130  })";
1131  THEN("Construct & solve linear system, replace ODE for m with rhs of CONSERVE statement") {
1132  CAPTURE(nmodl_text);
1133  auto result =
1134  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1135  compare_blocks(result[0], reindent_text(expected_result));
1136  }
1137  }
1138  GIVEN(
1139  "Derivative block with ODES with sparse method, invalid CONSERVE statement of form m + mc "
1140  "= ...") {
1141  std::string nmodl_text = R"(
1142  STATE {
1143  mc m
1144  }
1145  BREAKPOINT {
1146  SOLVE scheme1 METHOD sparse
1147  }
1148  DERIVATIVE scheme1 {
1149  mc' = -a*mc + b*m
1150  m' = a*mc - b*m
1151  CONSERVE m + mc = 1
1152  }
1153  )";
1154  std::string expected_result = R"(
1155  DERIVATIVE scheme1 {
1156  EIGEN_NEWTON_SOLVE[2]{
1157  LOCAL old_mc, old_m
1158  }{
1159  old_mc = mc
1160  old_m = m
1161  }{
1162  nmodl_eigen_x[0] = mc
1163  nmodl_eigen_x[1] = m
1164  }{
1165  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1166  nmodl_eigen_j[0] = -a-1/dt
1167  nmodl_eigen_j[2] = b
1168  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1169  nmodl_eigen_j[1] = a
1170  nmodl_eigen_j[3] = -b-1/dt
1171  }{
1172  mc = nmodl_eigen_x[0]
1173  m = nmodl_eigen_x[1]
1174  }{
1175  }
1176  })";
1177  THEN("Construct & solve linear system, ignore invalid CONSERVE statement") {
1178  CAPTURE(nmodl_text);
1179  auto result =
1180  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1181  compare_blocks(result[0], reindent_text(expected_result));
1182  }
1183  }
1184  GIVEN("Derivative block with ODES with sparse method, two CONSERVE statements") {
1185  std::string nmodl_text = R"(
1186  STATE {
1187  c1 o1 o2 p0 p1
1188  }
1189  BREAKPOINT {
1190  SOLVE ihkin METHOD sparse
1191  }
1192  DERIVATIVE ihkin {
1193  LOCAL alpha, beta, k3p, k4, k1ca, k2
1194  evaluate_fct(v, cai)
1195  CONSERVE p1 = 1-p0
1196  CONSERVE o2 = 1-c1-o1
1197  c1' = (-1*(alpha*c1-beta*o1))
1198  o1' = (1*(alpha*c1-beta*o1))+(-1*(k3p*o1-k4*o2))
1199  o2' = (1*(k3p*o1-k4*o2))
1200  p0' = (-1*(k1ca*p0-k2*p1))
1201  p1' = (1*(k1ca*p0-k2*p1))
1202  })";
1203  std::string expected_result = R"(
1204  DERIVATIVE ihkin {
1205  EIGEN_NEWTON_SOLVE[5]{
1206  LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
1207  }{
1208  evaluate_fct(v, cai)
1209  old_c1 = c1
1210  old_o1 = o1
1211  old_p0 = p0
1212  }{
1213  nmodl_eigen_x[0] = c1
1214  nmodl_eigen_x[1] = o1
1215  nmodl_eigen_x[2] = o2
1216  nmodl_eigen_x[3] = p0
1217  nmodl_eigen_x[4] = p1
1218  }{
1219  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt
1220  nmodl_eigen_j[0] = -alpha-1/dt
1221  nmodl_eigen_j[5] = beta
1222  nmodl_eigen_j[10] = 0
1223  nmodl_eigen_j[15] = 0
1224  nmodl_eigen_j[20] = 0
1225  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt
1226  nmodl_eigen_j[1] = alpha
1227  nmodl_eigen_j[6] = -beta-k3p-1/dt
1228  nmodl_eigen_j[11] = k4
1229  nmodl_eigen_j[16] = 0
1230  nmodl_eigen_j[21] = 0
1231  nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0
1232  nmodl_eigen_j[2] = -1.0
1233  nmodl_eigen_j[7] = -1.0
1234  nmodl_eigen_j[12] = -1.0
1235  nmodl_eigen_j[17] = 0
1236  nmodl_eigen_j[22] = 0
1237  nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt
1238  nmodl_eigen_j[3] = 0
1239  nmodl_eigen_j[8] = 0
1240  nmodl_eigen_j[13] = 0
1241  nmodl_eigen_j[18] = -k1ca-1/dt
1242  nmodl_eigen_j[23] = k2
1243  nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0
1244  nmodl_eigen_j[4] = 0
1245  nmodl_eigen_j[9] = 0
1246  nmodl_eigen_j[14] = 0
1247  nmodl_eigen_j[19] = -1.0
1248  nmodl_eigen_j[24] = -1.0
1249  }{
1250  c1 = nmodl_eigen_x[0]
1251  o1 = nmodl_eigen_x[1]
1252  o2 = nmodl_eigen_x[2]
1253  p0 = nmodl_eigen_x[3]
1254  p1 = nmodl_eigen_x[4]
1255  }{
1256  }
1257  })";
1258  THEN(
1259  "Construct & solve linear system, replacing ODEs for p1 and o2 with CONSERVE statement "
1260  "algebraic relations") {
1261  CAPTURE(nmodl_text);
1262  auto result =
1263  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1264  compare_blocks(result[0], reindent_text(expected_result));
1265  }
1266  }
1267  GIVEN("Derivative block including ODES with sparse method - single var in array") {
1268  std::string nmodl_text = R"(
1269  STATE {
1270  W[1]
1271  }
1272  ASSIGNED {
1273  A[2]
1274  B[1]
1275  }
1276  BREAKPOINT {
1277  SOLVE scheme1 METHOD sparse
1278  }
1279  DERIVATIVE scheme1 {
1280  W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1281  }
1282  )";
1283  std::string expected_result = R"(
1284  DERIVATIVE scheme1 {
1285  EIGEN_NEWTON_SOLVE[1]{
1286  LOCAL old_W_0
1287  }{
1288  old_W_0 = W[0]
1289  }{
1290  nmodl_eigen_x[0] = W[0]
1291  }{
1292  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1293  nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1294  }{
1295  W[0] = nmodl_eigen_x[0]
1296  }{
1297  }
1298  })";
1299  THEN("Construct & solver linear system") {
1300  CAPTURE(nmodl_text);
1301  auto result =
1302  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1303  compare_blocks(result[0], reindent_text(expected_result));
1304  }
1305  }
1306  GIVEN("Derivative block including ODES with sparse method - array vars") {
1307  std::string nmodl_text = R"(
1308  STATE {
1309  M[2]
1310  }
1311  ASSIGNED {
1312  A[2]
1313  B[2]
1314  }
1315  BREAKPOINT {
1316  SOLVE scheme1 METHOD sparse
1317  }
1318  DERIVATIVE scheme1 {
1319  M'[0] = -A[0]*M[0] + B[0]*M[1]
1320  M'[1] = A[1]*M[0] - B[1]*M[1]
1321  }
1322  )";
1323  std::string expected_result = R"(
1324  DERIVATIVE scheme1 {
1325  EIGEN_NEWTON_SOLVE[2]{
1326  LOCAL old_M_0, old_M_1
1327  }{
1328  old_M_0 = M[0]
1329  old_M_1 = M[1]
1330  }{
1331  nmodl_eigen_x[0] = M[0]
1332  nmodl_eigen_x[1] = M[1]
1333  }{
1334  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt
1335  nmodl_eigen_j[0] = -A[0]-1/dt
1336  nmodl_eigen_j[2] = B[0]
1337  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt
1338  nmodl_eigen_j[1] = A[1]
1339  nmodl_eigen_j[3] = -B[1]-1/dt
1340  }{
1341  M[0] = nmodl_eigen_x[0]
1342  M[1] = nmodl_eigen_x[1]
1343  }{
1344  }
1345  })";
1346  THEN("Construct & solver linear system") {
1347  CAPTURE(nmodl_text);
1348  auto result =
1349  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1350  compare_blocks(result[0], reindent_text(expected_result));
1351  }
1352  }
1353  GIVEN("Derivative block including ODES with derivimplicit method - single var in array") {
1354  std::string nmodl_text = R"(
1355  STATE {
1356  W[1]
1357  }
1358  ASSIGNED {
1359  A[2]
1360  B[1]
1361  }
1362  BREAKPOINT {
1363  SOLVE scheme1 METHOD derivimplicit
1364  }
1365  DERIVATIVE scheme1 {
1366  W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1367  }
1368  )";
1369  std::string expected_result = R"(
1370  DERIVATIVE scheme1 {
1371  EIGEN_NEWTON_SOLVE[1]{
1372  LOCAL old_W_0
1373  }{
1374  old_W_0 = W[0]
1375  }{
1376  nmodl_eigen_x[0] = W[0]
1377  }{
1378  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1379  nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1380  }{
1381  W[0] = nmodl_eigen_x[0]
1382  }{
1383  }
1384  })";
1385  THEN("Construct newton solve block") {
1386  CAPTURE(nmodl_text);
1387  auto result =
1388  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1389  compare_blocks(result[0], reindent_text(expected_result));
1390  }
1391  }
1392  GIVEN("Derivative block including ODES with derivimplicit method") {
1393  std::string nmodl_text = R"(
1394  STATE {
1395  m h n
1396  }
1397  BREAKPOINT {
1398  SOLVE states METHOD derivimplicit
1399  }
1400  DERIVATIVE states {
1401  rates(v)
1402  m' = (minf-m)/mtau - 3*h
1403  h' = (hinf-h)/htau + m*m
1404  n' = (ninf-n)/ntau
1405  }
1406  )";
1407  /// new derivative block with EigenNewtonSolverBlock node
1408  std::string expected_result = R"(
1409  DERIVATIVE states {
1410  EIGEN_NEWTON_SOLVE[3]{
1411  LOCAL old_m, old_h, old_n
1412  }{
1413  rates(v)
1414  old_m = m
1415  old_h = h
1416  old_n = n
1417  }{
1418  nmodl_eigen_x[0] = m
1419  nmodl_eigen_x[1] = h
1420  nmodl_eigen_x[2] = n
1421  }{
1422  nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt
1423  nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1424  nmodl_eigen_j[3] = -3.0
1425  nmodl_eigen_j[6] = 0
1426  nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1427  nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1428  nmodl_eigen_j[4] = (-dt-htau)/(dt*htau)
1429  nmodl_eigen_j[7] = 0
1430  nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau)
1431  nmodl_eigen_j[2] = 0
1432  nmodl_eigen_j[5] = 0
1433  nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau)
1434  }{
1435  m = nmodl_eigen_x[0]
1436  h = nmodl_eigen_x[1]
1437  n = nmodl_eigen_x[2]
1438  }{
1439  }
1440  })";
1441  THEN("Construct newton solve block") {
1442  CAPTURE(nmodl_text);
1443  auto result =
1444  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1445  compare_blocks(result[0], reindent_text(expected_result));
1446  }
1447  }
1448  GIVEN("Multiple derivative blocks each with derivimplicit method") {
1449  std::string nmodl_text = R"(
1450  STATE {
1451  m h
1452  }
1453  BREAKPOINT {
1454  SOLVE states1 METHOD derivimplicit
1455  SOLVE states2 METHOD derivimplicit
1456  }
1457 
1458  DERIVATIVE states1 {
1459  m' = (minf-m)/mtau
1460  h' = (hinf-h)/htau + m*m
1461  }
1462 
1463  DERIVATIVE states2 {
1464  h' = (hinf-h)/htau + m*m
1465  m' = (minf-m)/mtau + h
1466  }
1467  )";
1468  /// EigenNewtonSolverBlock in each derivative block
1469  std::string expected_result_0 = R"(
1470  DERIVATIVE states1 {
1471  EIGEN_NEWTON_SOLVE[2]{
1472  LOCAL old_m, old_h
1473  }{
1474  old_m = m
1475  old_h = h
1476  }{
1477  nmodl_eigen_x[0] = m
1478  nmodl_eigen_x[1] = h
1479  }{
1480  nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau)
1481  nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1482  nmodl_eigen_j[2] = 0
1483  nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1484  nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1485  nmodl_eigen_j[3] = (-dt-htau)/(dt*htau)
1486  }{
1487  m = nmodl_eigen_x[0]
1488  h = nmodl_eigen_x[1]
1489  }{
1490  }
1491  })";
1492  std::string expected_result_1 = R"(
1493  DERIVATIVE states2 {
1494  EIGEN_NEWTON_SOLVE[2]{
1495  LOCAL old_h, old_m
1496  }{
1497  old_h = h
1498  old_m = m
1499  }{
1500  nmodl_eigen_x[0] = m
1501  nmodl_eigen_x[1] = h
1502  }{
1503  nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1504  nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]
1505  nmodl_eigen_j[2] = (-dt-htau)/(dt*htau)
1506  nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt
1507  nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau)
1508  nmodl_eigen_j[3] = 1.0
1509  }{
1510  m = nmodl_eigen_x[0]
1511  h = nmodl_eigen_x[1]
1512  }{
1513  }
1514  })";
1515  THEN("Construct newton solve block") {
1516  auto result =
1517  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1518  CAPTURE(nmodl_text);
1519  compare_blocks(result[0], reindent_text(expected_result_0));
1520  compare_blocks(result[1], reindent_text(expected_result_1));
1521  }
1522  }
1523 }
1524 
1525 
1526 //=============================================================================
1527 // LINEAR solve block tests
1528 //=============================================================================
1529 
// Tests of LINEAR blocks. Each GIVEN passes an NMODL snippet to
// run_sympy_solver_visitor with pade and cse disabled and LINEAR_BLOCK as the
// node type to collect, then checks the first returned block against the
// expected NMODL text. Some cases require exact string equality after
// reindent_text; the others go through compare_blocks.
1530 SCENARIO("LINEAR solve block (SympySolver Visitor)", "[sympy][linear]") {
     // One unknown, one equation: the expected block holds the direct
     // assignment x = 0.5/a in place of the `~` equation.
1531  GIVEN("1 state-var symbolic LINEAR solve block") {
1532  std::string nmodl_text = R"(
1533  STATE {
1534  x
1535  }
1536  LINEAR lin {
1537  ~ 2*a*x = 1
1538  })";
1539  std::string expected_text = R"(
1540  LINEAR lin {
1541  x = 0.5/a
1542  })";
1543  THEN("solve analytically") {
1544  auto result =
1545  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1546  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1547  }
1548  }
     // Two unknowns, two equations: both are expected to resolve to the symbol a.
1549  GIVEN("2 state-var LINEAR solve block") {
1550  std::string nmodl_text = R"(
1551  STATE {
1552  x y
1553  }
1554  LINEAR lin {
1555  ~ x + 4*y = 5*a
1556  ~ x - y = 0
1557  })";
1558  std::string expected_text = R"(
1559  LINEAR lin {
1560  x = a
1561  y = a
1562  })";
1563  THEN("solve analytically") {
1564  auto result =
1565  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1566  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1567  }
1568  }
     // The unknowns are elements of the array state M[2]. The expected
     // assignments follow the order of the source equations (M[1] first).
1569  GIVEN("Linear block, print in order, vectors") {
1570  std::string nmodl_text = R"(
1571  STATE {
1572  M[2]
1573  }
1574  LINEAR lin {
1575  ~ M[1] = M[0] + 1
1576  ~ M[0] = 2
1577  })";
1578  std::string expected_result = R"(
1579  LINEAR lin {
1580  M[1] = 3.0
1581  M[0] = 2.0
1582  })";
1583 
1584  THEN("Construct & solve linear system") {
1585  auto result =
1586  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1587 
1588  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1589  }
1590  }
     // Plain assignments to the LOCAL a sit between the `~` equations. The
     // expected block keeps those assignments where they were and writes each
     // solution at the position of its equation, in terms of the symbol a
     // (x = 2.0*a) rather than of the numbers assigned to a.
1591  GIVEN("Linear block, by value replacement, interleaved") {
1592  std::string nmodl_text = R"(
1593  STATE {
1594  x y
1595  }
1596  LINEAR lin {
1597  LOCAL a
1598  a = 0
1599  ~ x = y + a
1600  a = 1
1601  ~ y = a
1602  a = 2
1603  })";
1604  std::string expected_result = R"(
1605  LINEAR lin {
1606  LOCAL a
1607  a = 0
1608  x = 2.0*a
1609  a = 1
1610  y = a
1611  a = 2
1612  })";
1613 
1614  THEN("Construct & solve linear system") {
1615  auto result =
1616  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1617 
1618  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1619  }
1620  }
     // Both equations live inside an if-block and the expected assignments
     // stay inside it. The lower-case `if` of the input appears as `IF` in the
     // expected text.
1621  GIVEN("Linear block in control flow block") {
1622  std::string nmodl_text = R"(
1623  STATE {
1624  x y
1625  }
1626  LINEAR lin {
1627  LOCAL a
1628  if (a == 1) {
1629  ~ x = y + a
1630  ~ y = a
1631  }
1632  })";
1633  std::string expected_result = R"(
1634  LINEAR lin {
1635  LOCAL a
1636  IF (a == 1) {
1637  x = 2.0*a
1638  y = a
1639  }
1640  })";
1641 
1642  THEN("Construct & solve linear system") {
1643  auto result =
1644  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1645 
1646  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1647  }
1648  }
     // An if-block holding ordinary assignments (one of them to the state x)
     // sits between the two equations. The expected text keeps that block
     // between the two solved assignments, unchanged apart from formatting.
1649  GIVEN("Linear block, linear equations mixed with control flow blocks and reassignments") {
1650  std::string nmodl_text = R"(
1651  STATE {
1652  x y
1653  }
1654  LINEAR lin {
1655  LOCAL a
1656  ~ x = y + a
1657  if (a == 1) {
1658  a = a + 1
1659  x = a + 1
1660  }
1661  ~ y = a
1662  })";
1663  std::string expected_result = R"(
1664  LINEAR lin {
1665  LOCAL a
1666  x = 2.0*a
1667  IF (a == 1) {
1668  a = a+1
1669  x = a+1
1670  }
1671  y = a
1672  })";
1673 
1674  THEN("Construct & solve linear system") {
1675  auto result =
1676  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1677 
1678  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1679  }
1680  }
     // Four unknowns with symbolic coefficients. Here the expected output is
     // an EIGEN_LINEAR_SOLVE[4] block instead of per-variable assignments.
1681  GIVEN("4 state-var LINEAR solve block") {
1682  std::string nmodl_text = R"(
1683  STATE {
1684  w x y z
1685  }
1686  LINEAR lin {
1687  ~ w + z/3.2 = -2.0*y
1688  ~ x + 4*c*y = -5.343*a
1689  ~ a + x/b + z - y = 0.842*b*b
1690  ~ x + 1.3*y - 0.1*z/(a*a*b) = 1.43543/c
1691  })";
     // In the expected text the entry relating equation i to unknown k is at
     // nmodl_eigen_j[i + 4*k] (the first equation fills indices 0, 4, 8, 12).
     // The input literals 0.842 and 0.1 are expected in their full
     // double-precision form (0.84199999999999997, 0.10000000000000001).
1692  std::string expected_text = R"(
1693  LINEAR lin {
1694  EIGEN_LINEAR_SOLVE[4]{
1695  }{
1696  }{
1697  nmodl_eigen_x[0] = w
1698  nmodl_eigen_x[1] = x
1699  nmodl_eigen_x[2] = y
1700  nmodl_eigen_x[3] = z
1701  nmodl_eigen_f[0] = 0
1702  nmodl_eigen_f[1] = 5.343*a
1703  nmodl_eigen_f[2] = a-0.84199999999999997*pow(b, 2)
1704  nmodl_eigen_f[3] = -1.43543/c
1705  nmodl_eigen_j[0] = -1.0
1706  nmodl_eigen_j[4] = 0
1707  nmodl_eigen_j[8] = -2.0
1708  nmodl_eigen_j[12] = -0.3125
1709  nmodl_eigen_j[1] = 0
1710  nmodl_eigen_j[5] = -1.0
1711  nmodl_eigen_j[9] = -4.0*c
1712  nmodl_eigen_j[13] = 0
1713  nmodl_eigen_j[2] = 0
1714  nmodl_eigen_j[6] = -1/b
1715  nmodl_eigen_j[10] = 1.0
1716  nmodl_eigen_j[14] = -1.0
1717  nmodl_eigen_j[3] = 0
1718  nmodl_eigen_j[7] = -1.0
1719  nmodl_eigen_j[11] = -1.3
1720  nmodl_eigen_j[15] = 0.10000000000000001/(pow(a, 2)*b)
1721  }{
1722  w = nmodl_eigen_x[0]
1723  x = nmodl_eigen_x[1]
1724  y = nmodl_eigen_x[2]
1725  z = nmodl_eigen_x[3]
1726  }{
1727  }
1728  })";
1729  THEN("return matrix system to solve") {
1730  auto result =
1731  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1732  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1733  }
1734  }
1735 
     // STATE declares x, y and z, but the block names only x and y after
     // SOLVEFOR. The expected solutions are written in terms of z and v, with
     // y listed before x. This case requires exact string equality after
     // reindent_text.
1736  GIVEN("LINEAR solve block with an explicit SOLVEFOR statement") {
1737  std::string nmodl_text = R"(
1738  STATE {
1739  x
1740  y
1741  z
1742  }
1743  LINEAR lin SOLVEFOR x, y {
1744  ~ 3 * x = v - y
1745  ~ x = z * y - 5
1746  })";
1747  std::string expected_text = R"(
1748  LINEAR lin SOLVEFOR x,y{
1749  y = (v+15.0)/(3.0*z+1.0)
1750  x = (v*z-5.0)/(3.0*z+1.0)
1751  })";
1752  THEN("solve analytically") {
1753  auto result =
1754  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1755  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1756  }
1757  }
1758 }
1759 
1760 //=============================================================================
1761 // NONLINEAR solve block tests
1762 //=============================================================================
1763 
// Tests of NONLINEAR blocks. Each GIVEN runs run_sympy_solver_visitor with
// pade and cse disabled, collects nodes of type NON_LINEAR_BLOCK and compares
// the first one, via compare_blocks, with an expected EIGEN_NEWTON_SOLVE[n]
// block made of nmodl_eigen_x, nmodl_eigen_f and nmodl_eigen_j entries.
1764 SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][sympy][nonlinear]") {
     // Scalar case: for the equation x = 5 the expected residual is
     // nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0], with the single entry
     // nmodl_eigen_j[0] = -1.0.
1765  GIVEN("1 state-var numeric NONLINEAR solve block") {
1766  std::string nmodl_text = R"(
1767  STATE {
1768  x
1769  }
1770  NONLINEAR nonlin {
1771  ~ x = 5
1772  })";
1773  std::string expected_text = R"(
1774  NONLINEAR nonlin {
1775  EIGEN_NEWTON_SOLVE[1]{
1776  }{
1777  }{
1778  nmodl_eigen_x[0] = x
1779  }{
1780  nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
1781  nmodl_eigen_j[0] = -1.0
1782  }{
1783  x = nmodl_eigen_x[0]
1784  }{
1785  }
1786  })";
1787 
1788  THEN("return F & J for newton solver") {
1789  auto result =
1790  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::NON_LINEAR_BLOCK);
1791  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1792  }
1793  }
     // Array case: the three elements of s[3] map to nmodl_eigen_x[0..2]. In
     // the expected text nmodl_eigen_j[i + 3*k] is the derivative of
     // nmodl_eigen_f[i] with respect to nmodl_eigen_x[k]; all f entries are
     // listed before the j entries.
1794  GIVEN("array state-var numeric NONLINEAR solve block") {
1795  std::string nmodl_text = R"(
1796  STATE {
1797  s[3]
1798  }
1799  NONLINEAR nonlin {
1800  ~ s[0] = 1
1801  ~ s[1] = 3
1802  ~ s[2] + s[1] = s[0]
1803  })";
1804  std::string expected_text = R"(
1805  NONLINEAR nonlin {
1806  EIGEN_NEWTON_SOLVE[3]{
1807  }{
1808  }{
1809  nmodl_eigen_x[0] = s[0]
1810  nmodl_eigen_x[1] = s[1]
1811  nmodl_eigen_x[2] = s[2]
1812  }{
1813  nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
1814  nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
1815  nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
1816  nmodl_eigen_j[0] = -1.0
1817  nmodl_eigen_j[3] = 0
1818  nmodl_eigen_j[6] = 0
1819  nmodl_eigen_j[1] = 0
1820  nmodl_eigen_j[4] = -1.0
1821  nmodl_eigen_j[7] = 0
1822  nmodl_eigen_j[2] = 1.0
1823  nmodl_eigen_j[5] = -1.0
1824  nmodl_eigen_j[8] = -1.0
1825  }{
1826  s[0] = nmodl_eigen_x[0]
1827  s[1] = nmodl_eigen_x[1]
1828  s[2] = nmodl_eigen_x[2]
1829  }{
1830  }
1831  })";
1832  THEN("return F & J for newton solver") {
1833  auto result =
1834  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::NON_LINEAR_BLOCK);
1835  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1836  }
1837  }
1838 }
// Tests of KINETIC blocks solved with METHOD sparse. Both cases call
// run_sympy_solver_visitor with kinetic = true and collect DERIVATIVE_BLOCK
// nodes; the expected text is a `DERIVATIVE kstates` block wrapping an
// EIGEN_NEWTON_SOLVE[2] block. The call is wrapped in REQUIRE_NOTHROW, so an
// exception thrown by the helper is reported as a test failure.
// NOTE(review): kinetic = true presumably enables a KINETIC-to-DERIVATIVE
// conversion pass inside the helper; that call is not visible in this
// listing -- confirm against run_sympy_solver_visitor.
1839 SCENARIO("Solve KINETIC block using SympySolver Visitor", "[visitor][solver][sympy][kinetic]") {
     // The reaction rates are calls to the user-defined FUNCTION alfa. The
     // expected text keeps them as calls, assigned to the locals kf0_ and kb0_.
     // NOTE(review): kb0_ is listed before kf0_ here, unlike in the next
     // GIVEN; whether compare_blocks is sensitive to that order cannot be
     // told from this chunk.
1840  GIVEN("KINETIC block with not inlined function should work") {
1841  std::string nmodl_text = R"(
1842  BREAKPOINT {
1843  SOLVE kstates METHOD sparse
1844  }
1845  STATE {
1846  C1
1847  C2
1848  }
1849  FUNCTION alfa(v(mV)) {
1850  alfa = v
1851  }
1852  KINETIC kstates {
1853  ~ C1 <-> C2 (alfa(v), alfa(v))
1854  })";
1855  std::string expected_text = R"(
1856  DERIVATIVE kstates {
1857  EIGEN_NEWTON_SOLVE[2]{
1858  LOCAL kf0_, kb0_, old_C1, old_C2
1859  }{
1860  kb0_ = alfa(v)
1861  kf0_ = alfa(v)
1862  old_C1 = C1
1863  old_C2 = C2
1864  }{
1865  nmodl_eigen_x[0] = C1
1866  nmodl_eigen_x[1] = C2
1867  }{
1868  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1869  nmodl_eigen_j[0] = -kf0_-1/dt
1870  nmodl_eigen_j[2] = kb0_
1871  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1872  nmodl_eigen_j[1] = kf0_
1873  nmodl_eigen_j[3] = -kb0_-1/dt
1874  }{
1875  C1 = nmodl_eigen_x[0]
1876  C2 = nmodl_eigen_x[1]
1877  }{
1878  }
1879  })";
1880  THEN("Run Kinetic and Sympy Visitor") {
1881  std::vector<std::string> result;
1882  REQUIRE_NOTHROW(result = run_sympy_solver_visitor(
1883  nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK, true));
1884  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1885  }
1886  }
     // Same reaction, but the rate functions are named beta and lowergamma.
     // The expected text keeps both calls verbatim.
     // NOTE(review): going by the GIVEN title, these names presumably clash
     // with functions SymPy defines itself -- confirm.
1887  GIVEN("Protected names in Sympy are respected") {
1888  std::string nmodl_text = R"(
1889  BREAKPOINT {
1890  SOLVE kstates METHOD sparse
1891  }
1892  STATE {
1893  C1
1894  C2
1895  }
1896  FUNCTION beta(v(mV)) {
1897  beta = v
1898  }
1899  FUNCTION lowergamma(v(mV)) {
1900  lowergamma = v
1901  }
1902  KINETIC kstates {
1903  ~ C1 <-> C2 (beta(v), lowergamma(v))
1904  })";
1905  std::string expected_text = R"(
1906  DERIVATIVE kstates {
1907  EIGEN_NEWTON_SOLVE[2]{
1908  LOCAL kf0_, kb0_, old_C1, old_C2
1909  }{
1910  kf0_ = beta(v)
1911  kb0_ = lowergamma(v)
1912  old_C1 = C1
1913  old_C2 = C2
1914  }{
1915  nmodl_eigen_x[0] = C1
1916  nmodl_eigen_x[1] = C2
1917  }{
1918  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1919  nmodl_eigen_j[0] = -kf0_-1/dt
1920  nmodl_eigen_j[2] = kb0_
1921  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1922  nmodl_eigen_j[1] = kf0_
1923  nmodl_eigen_j[3] = -kb0_-1/dt
1924  }{
1925  C1 = nmodl_eigen_x[0]
1926  C2 = nmodl_eigen_x[1]
1927  }{
1928  }
1929  })";
1930  THEN("Run Kinetic and Sympy Visitor") {
1931  std::vector<std::string> result;
1932  REQUIRE_NOTHROW(result = run_sympy_solver_visitor(
1933  nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK, true));
1934  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1935  }
1936  }
1937 }
Visitor for checking parents of ast nodes
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Class that binds all pieces together for parsing nmodl file.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Perform constant folding of integer/float/double expressions.
Visitor for kinetic block statements
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for printing AST back to NMODL
void visit_program(const ast::Program &node) override
visit node of type ast::Program
Visitor for systems of algebraic and differential equations
void visit_program(ast::Program &node) override
visit node of type ast::Program
Concrete visitor for constructing symbol table from AST.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for checking parents of ast nodes
int check_ast(const ast::Ast &node)
A small wrapper to have a nicer call in parser.cpp.
Visitor for printing C++ code compatible with legacy api of CoreNEURON
Perform constant folding of integer/float/double expressions.
int nmodl_text
Definition: modl.cpp:58
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
bool parse_string(const std::string &input)
parser Units provided as string (used for testing)
Definition: unit_driver.cpp:40
Visitor to inline local procedure and function calls
Visitor for kinetic block statements
Unroll for loop in the AST.
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
Definition: test_utils.cpp:55
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Visitor that solves ODEs using old solvers of NEURON
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
static Node * node(Object *)
Definition: netcvode.cpp:291
static double remove(void *v)
Definition: ocdeck.cpp:205
#define text
Definition: plot.cpp:60
Auto generated AST classes declaration.
Replace solve block statements with actual solution node in the AST.
void compare_blocks(const std::string &result, const std::string &expected, const bool require_fail=false)
Compare nmodl blocks that contain systems of equations (i.e.
std::string ast_to_string(ast::Program &node)
std::vector< std::string > run_sympy_solver_visitor(const std::string &text, bool pade=false, bool cse=false, AstNodeType ret_nodetype=AstNodeType::DIFF_EQ_EXPRESSION, bool kinetic=false)
bool is_unique_vars(std::string result)
SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]")
void run_sympy_visitor_passes(ast::Program &node)
Visitor for systems of algebraic and differential equations
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28