8 #include <catch2/catch_test_macros.hpp>
9 #include <catch2/matchers/catch_matchers_string.hpp>
11 #include <pybind11/embed.h>
12 #include <pybind11/stl.h>
30 using namespace nmodl;
32 using namespace visitor;
34 using namespace test_utils;
36 using Catch::Matchers::ContainsSubstring;
49 const std::string&
text,
52 AstNodeType ret_nodetype = AstNodeType::DIFF_EQ_EXPRESSION,
53 bool kinetic =
false) {
54 std::vector<std::string> results;
90 std::stringstream ss(
result);
93 std::unordered_set<std::string> old_vars;
95 while (getline(ss, token,
' ')) {
96 if (!old_vars.insert(token).second) {
130 const std::string& expected,
131 const bool require_fail =
false) {
132 using namespace pybind11::literals;
135 pybind11::dict(
"result"_a =
result,
"expected"_a = expected,
"is_equal"_a =
false);
137 # Comments are in the doxygen for better highlighting
138 def compare_blocks(result, expected):
142 d = {'\[(\d+)\]':'_\\1', 'pow\((\w+), ?(\d+)\)':'\\1**\\2', 'beta': 'beta_var', 'gamma': 'gamma_var'}
144 for key, val in d.items():
145 out = re.sub(key, val, out)
148 def compare_systems_of_eq(result_dict, expected_dict):
149 from sympy.parsing.sympy_parser import parse_expr
151 for k, v in result_dict.items():
152 if parse_expr(f'simplify(({v})-({expected_dict[k]}))'):
158 expected_dict.clear()
166 # split of sout and a dict with the tmp variables
167 for line in s.split('\n'):
168 line_split = line.lstrip().split('=')
170 if len(line_split) == 2 and line_split[0].startswith('tmp_'):
171 # back-substitution of tmp variables in tmp variables
172 tmp_var = line_split[0].strip()
176 max_tmp = max(max_tmp, int(tmp_var[4:]))
177 for k, v in d.items():
178 line_split[1] = line_split[1].replace(k, f'({v})')
179 d[tmp_var] = line_split[1]
180 elif 'LOCAL' in line:
181 sout += line.split('tmp_0')[0] + '\n'
185 # Back-substitution of the tmps
186 # so that we do not replace tmp_11 with (tmp_1)1
187 for j in range(max_tmp, -1, -1):
189 sout = sout.replace(k, f'({d[k]})')
193 result = reduce(sanitize(result)).split('\n')
194 expected = reduce(sanitize(expected)).split('\n')
196 if len(result) != len(expected):
201 for token1, token2 in zip(result, expected):
203 if not compare_systems_of_eq(result_dict, expected_dict):
207 eq1 = token1.split('=')
208 eq2 = token2.split('=')
209 if len(eq1) == 2 and len(eq2) == 2:
210 result_dict[eq1[0]] = eq1[1]
211 expected_dict[eq2[0]] = eq2[1]
215 return compare_systems_of_eq(result_dict, expected_dict)
217 is_equal = compare_blocks(result, expected))",
222 if (require_fail == locals[
"is_equal"].cast<bool>()) {
224 REQUIRE(
result != expected);
226 REQUIRE(
result == expected);
253 std::stringstream stream;
258 SCENARIO(
"Check compare_blocks in sympy unit tests",
"[visitor][sympy]") {
259 GIVEN(
"Empty strings") {
260 THEN(
"Strings are equal") {
264 GIVEN(
"Equivalent equation") {
265 THEN(
"Strings are equal") {
269 GIVEN(
"Equivalent systems of equations") {
273 std::string expected = R"(
276 THEN("Systems of equations are equal") {
280 GIVEN(
"Equivalent systems of equations with brackets") {
286 std::string expected = R"(
291 y = pow(a, 2)*a + 2*b-b
293 THEN("Blocks are equal") {
297 GIVEN(
"Different systems of equations (additional space)") {
303 std::string expected = R"(
308 THEN("Blocks are different") {
312 GIVEN(
"Different systems of equations") {
320 std::string expected = R"(
325 THEN("Blocks are different") {
331 SCENARIO(
"Check local vars name-clash prevention",
"[visitor][sympy]") {
338 SOLVE states METHOD sparse
345 THEN("There are no duplicate vars in LOCAL") {
352 GIVEN(
"LOCAL tmp_0") {
358 SOLVE states METHOD sparse
365 THEN("There are no duplicate vars in LOCAL") {
374 SCENARIO(
"Solve ODEs with cnexp or euler method using SympySolverVisitor",
375 "[visitor][sympy][cnexp][euler]") {
376 GIVEN(
"Derivative block without ODE, solver method cnexp") {
379 SOLVE states METHOD cnexp
385 THEN("No ODEs found - do nothing") {
390 GIVEN(
"Derivative block with ODES, solver method is euler") {
393 SOLVE states METHOD euler
401 THEN("Construct forwards Euler solutions") {
403 REQUIRE(
result.size() == 2);
404 REQUIRE(
result[0] ==
"m = (-dt*(m-mInf)+m*mTau)/mTau");
405 REQUIRE(
result[1] ==
"h = (-dt*(h-hInf)+h*hTau)/hTau");
408 GIVEN(
"Derivative block with calling external functions passes sympy") {
411 SOLVE states METHOD euler
419 THEN("Construct forward Euler interpreting external functions as symbols") {
421 REQUIRE(
result.size() == 3);
422 REQUIRE(
result[0] ==
"m = dt*sawtooth(m)+m");
423 REQUIRE(
result[1] ==
"n = dt*sin(n)+n");
424 REQUIRE(
result[2] ==
"p = dt*my_user_func(p)+p");
427 GIVEN(
"Derivative block with ODE, 1 state var in array, solver method euler") {
433 SOLVE states METHOD euler
436 m'[0] = (mInf-m[0])/mTau
439 THEN("Construct forwards Euler solutions") {
441 REQUIRE(
result.size() == 1);
442 REQUIRE(
result[0] ==
"m[0] = (dt*(mInf-m[0])+mTau*m[0])/mTau");
445 GIVEN(
"Derivative block with ODE, 1 state var in array, solver method cnexp") {
451 SOLVE states METHOD cnexp
454 m'[0] = (mInf-m[0])/mTau
457 THEN("Construct forwards Euler solutions") {
459 REQUIRE(
result.size() == 1);
460 REQUIRE(
result[0] ==
"m[0] = mInf-(mInf-m[0])*exp(-dt/mTau)");
463 GIVEN(
"Derivative block with linear ODES, solver method cnexp") {
466 SOLVE states METHOD cnexp
471 h' = hInf/hTau - h/hTau
474 THEN("Integrate equations analytically") {
476 REQUIRE(
result.size() == 2);
477 REQUIRE(
result[0] ==
"m = mInf-(-m+mInf)*exp(-dt/mTau)");
478 REQUIRE(
result[1] ==
"h = hInf-(-h+hInf)*exp(-dt/hTau)");
481 GIVEN(
"Derivative block including non-linear but solvable ODES, solver method cnexp") {
484 SOLVE states METHOD cnexp
491 THEN("Integrate equations analytically") {
493 REQUIRE(
result.size() == 2);
494 REQUIRE(
result[0] ==
"m = mInf-(-m+mInf)*exp(-dt/mTau)");
495 REQUIRE(
result[1] ==
"h = -h/(c2*dt*h-1.0)");
498 GIVEN(
"Derivative block including array of 2 state vars, solver method cnexp") {
501 SOLVE states METHOD cnexp
507 X'[0] = (mInf-X[0])/mTau
508 X'[1] = c2 * X[1]*X[1]
511 THEN("Integrate equations analytically") {
513 REQUIRE(
result.size() == 2);
514 REQUIRE(
result[0] ==
"X[0] = mInf-(mInf-X[0])*exp(-dt/mTau)");
515 REQUIRE(
result[1] ==
"X[1] = -X[1]/(c2*dt*X[1]-1.0)");
518 GIVEN(
"Derivative block including loop over array vars, solver method cnexp") {
522 SOLVE states METHOD cnexp
532 X'[i] = (mInf-X[i])/mTau[i]
536 THEN("Integrate equations analytically") {
538 REQUIRE(
result.size() == 3);
539 REQUIRE(
result[0] ==
"X[0] = mInf-(mInf-X[0])*exp(-dt/mTau[0])");
540 REQUIRE(
result[1] ==
"X[1] = mInf-(mInf-X[1])*exp(-dt/mTau[1])");
541 REQUIRE(
result[2] ==
"X[2] = mInf-(mInf-X[2])*exp(-dt/mTau[2])");
544 GIVEN(
"Derivative block including loop over array vars, solver method euler") {
548 SOLVE states METHOD euler
558 X'[i] = (mInf-X[i])/mTau[i]
562 THEN("Integrate equations analytically") {
564 REQUIRE(
result.size() == 3);
565 REQUIRE(
result[0] ==
"X[0] = (dt*(mInf-X[0])+X[0]*mTau[0])/mTau[0]");
566 REQUIRE(
result[1] ==
"X[1] = (dt*(mInf-X[1])+X[1]*mTau[1])/mTau[1]");
567 REQUIRE(
result[2] ==
"X[2] = (dt*(mInf-X[2])+X[2]*mTau[2])/mTau[2]");
570 GIVEN(
"Derivative block including ODES that can't currently be solved, solver method cnexp") {
573 SOLVE states METHOD cnexp
582 THEN("Integrate equations analytically where possible, otherwise leave untouched") {
584 REQUIRE(
result.size() == 4);
586 REQUIRE((
result[0] ==
"z' = a/z+b/z/z" ||
588 "z = (0.5*pow(a, 2)*pow(z, 2)-a*b*z+pow(b, 2)*log(a*z+b))/pow(a, 3)"));
589 REQUIRE(
result[1] ==
"h = -h/(c2*dt*h-1.0)");
590 REQUIRE(
result[2] ==
"x = a*dt+x");
592 REQUIRE((
result[3] ==
"y' = c3*y*y*y" ||
593 result[3] ==
"y = sqrt(-pow(y, 2)/(2.0*c3*dt*pow(y, 2)-1.0))"));
596 GIVEN(
"Derivative block with cnexp solver method, AST after SympySolver pass") {
599 SOLVE states METHOD cnexp
617 THEN(
"More SympySolver passes do nothing to the AST and don't throw") {
624 SCENARIO(
"Solve ODEs with derivimplicit method using SympySolverVisitor",
625 "[visitor][sympy][derivimplicit]") {
626 GIVEN(
"Derivative block with derivimplicit solver method and conditional block") {
632 SOLVE states METHOD derivimplicit
641 std::string expected_result = R"(
643 EIGEN_NEWTON_SOLVE[1]{
653 nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau)
654 nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau)
660 THEN("SympySolver correctly inserts ode to block") {
668 GIVEN(
"Derivative block, sparse, print in order") {
674 SOLVE states METHOD sparse
681 std::string expected_result = R"(
683 EIGEN_NEWTON_SOLVE[2]{
684 LOCAL a, b, old_y, old_x
692 nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt
694 nmodl_eigen_j[2] = -1/dt
695 nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt
696 nmodl_eigen_j[1] = -1/dt
705 THEN("Construct & solve linear system for backwards Euler") {
712 GIVEN(
"Derivative block, sparse, print in order, vectors") {
718 SOLVE states METHOD sparse
725 std::string expected_result = R"(
727 EIGEN_NEWTON_SOLVE[2]{
728 LOCAL a, b, old_M_1, old_M_0
733 nmodl_eigen_x[0] = M[0]
734 nmodl_eigen_x[1] = M[1]
736 nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt
738 nmodl_eigen_j[2] = -1/dt
739 nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt
740 nmodl_eigen_j[1] = -1/dt
743 M[0] = nmodl_eigen_x[0]
744 M[1] = nmodl_eigen_x[1]
749 THEN("Construct & solve linear system for backwards Euler") {
756 GIVEN(
"Derivative block, sparse, derivatives mixed with local variable reassignment") {
762 SOLVE states METHOD sparse
770 std::string expected_result = R"(
772 EIGEN_NEWTON_SOLVE[2]{
773 LOCAL a, b, old_x, old_y
781 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
782 nmodl_eigen_j[0] = -1/dt
785 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
787 nmodl_eigen_j[3] = -1/dt
795 THEN("Construct & solve linear system for backwards Euler") {
803 "Throw exception during derivative variable reassignment interleaved in the differential "
810 SOLVE states METHOD sparse
820 "Throw an error because state variable assignments are not allowed inside the system "
825 Catch::Matchers::ContainsSubstring(
826 "State variable assignment(s) interleaved in system of "
827 "equations/differential equations") &&
828 Catch::Matchers::StartsWith(
"SympyReplaceSolutionsVisitor"));
831 GIVEN(
"Derivative block in control flow block") {
837 SOLVE states METHOD sparse
846 std::string expected_result = R"(
850 EIGEN_NEWTON_SOLVE[2]{
859 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
860 nmodl_eigen_j[0] = -1/dt
862 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
864 nmodl_eigen_j[3] = -1/dt
873 THEN("Construct & solve linear system for backwards Euler") {
881 "Derivative block, sparse, coupled derivatives mixed with reassignment and control flow "
888 SOLVE states METHOD sparse
898 std::string expected_result = R"(
900 EIGEN_NEWTON_SOLVE[2]{
901 LOCAL a, b, old_x, old_y
909 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
910 nmodl_eigen_j[0] = -1/dt
915 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
916 nmodl_eigen_j[1] = 1.0
917 nmodl_eigen_j[3] = a-1/dt
924 std::string expected_result_cse = R"(
926 EIGEN_NEWTON_SOLVE[2]{
927 LOCAL a, b, old_x, old_y
935 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
936 nmodl_eigen_j[0] = -1/dt
941 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
942 nmodl_eigen_j[1] = 1.0
943 nmodl_eigen_j[3] = a-1/dt
951 THEN("Construct & solve linear system for backwards Euler") {
962 GIVEN(
"Derivative block of coupled & linear ODES, solver method sparse") {
968 SOLVE states METHOD sparse
977 std::string expected_result = R"(
979 EIGEN_NEWTON_SOLVE[3]{
980 LOCAL a, b, c, d, h, old_x, old_y, old_z
990 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
991 nmodl_eigen_j[0] = -1/dt
994 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
995 nmodl_eigen_j[1] = 2.0
996 nmodl_eigen_j[4] = -1/dt
998 nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1000 nmodl_eigen_j[5] = -1.0
1001 nmodl_eigen_j[8] = d-1/dt
1003 x = nmodl_eigen_x[0]
1004 y = nmodl_eigen_x[1]
1005 z = nmodl_eigen_x[2]
1009 std::string expected_cse_result = R"(
1011 EIGEN_NEWTON_SOLVE[3]{
1012 LOCAL a, b, c, d, h, old_x, old_y, old_z
1018 nmodl_eigen_x[0] = x
1019 nmodl_eigen_x[1] = y
1020 nmodl_eigen_x[2] = z
1022 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1023 nmodl_eigen_j[0] = -1/dt
1024 nmodl_eigen_j[3] = 0
1025 nmodl_eigen_j[6] = a
1026 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1027 nmodl_eigen_j[1] = 2.0
1028 nmodl_eigen_j[4] = -1/dt
1029 nmodl_eigen_j[7] = 0
1030 nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1031 nmodl_eigen_j[2] = 0
1032 nmodl_eigen_j[5] = -1.0
1033 nmodl_eigen_j[8] = d-1/dt
1035 x = nmodl_eigen_x[0]
1036 y = nmodl_eigen_x[1]
1037 z = nmodl_eigen_x[2]
1042 THEN("Construct & solve linear system for backwards Euler") {
1052 GIVEN(
"Derivative block including ODES with sparse method (from nmodl paper)") {
1058 SOLVE scheme1 METHOD sparse
1060 DERIVATIVE scheme1 {
1065 std::string expected_result = R"(
1066 DERIVATIVE scheme1 {
1067 EIGEN_NEWTON_SOLVE[2]{
1073 nmodl_eigen_x[0] = mc
1074 nmodl_eigen_x[1] = m
1076 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1077 nmodl_eigen_j[0] = -a-1/dt
1078 nmodl_eigen_j[2] = b
1079 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1080 nmodl_eigen_j[1] = a
1081 nmodl_eigen_j[3] = -b-1/dt
1083 mc = nmodl_eigen_x[0]
1084 m = nmodl_eigen_x[1]
1088 THEN("Construct & solve linear system") {
1095 GIVEN(
"Derivative block with ODES with sparse method, CONSERVE statement of form m = ...") {
1101 SOLVE scheme1 METHOD sparse
1103 DERIVATIVE scheme1 {
1109 std::string expected_result = R"(
1110 DERIVATIVE scheme1 {
1111 EIGEN_NEWTON_SOLVE[2]{
1116 nmodl_eigen_x[0] = mc
1117 nmodl_eigen_x[1] = m
1119 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1120 nmodl_eigen_j[0] = -a-1/dt
1121 nmodl_eigen_j[2] = b
1122 nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0
1123 nmodl_eigen_j[1] = -1.0
1124 nmodl_eigen_j[3] = -1.0
1126 mc = nmodl_eigen_x[0]
1127 m = nmodl_eigen_x[1]
1131 THEN("Construct & solve linear system, replace ODE for m with rhs of CONSERVE statement") {
1139 "Derivative block with ODES with sparse method, invalid CONSERVE statement of form m + mc "
1146 SOLVE scheme1 METHOD sparse
1148 DERIVATIVE scheme1 {
1154 std::string expected_result = R"(
1155 DERIVATIVE scheme1 {
1156 EIGEN_NEWTON_SOLVE[2]{
1162 nmodl_eigen_x[0] = mc
1163 nmodl_eigen_x[1] = m
1165 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1166 nmodl_eigen_j[0] = -a-1/dt
1167 nmodl_eigen_j[2] = b
1168 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1169 nmodl_eigen_j[1] = a
1170 nmodl_eigen_j[3] = -b-1/dt
1172 mc = nmodl_eigen_x[0]
1173 m = nmodl_eigen_x[1]
1177 THEN("Construct & solve linear system, ignore invalid CONSERVE statement") {
1184 GIVEN(
"Derivative block with ODES with sparse method, two CONSERVE statements") {
1190 SOLVE ihkin METHOD sparse
1193 LOCAL alpha, beta, k3p, k4, k1ca, k2
1194 evaluate_fct(v, cai)
1196 CONSERVE o2 = 1-c1-o1
1197 c1' = (-1*(alpha*c1-beta*o1))
1198 o1' = (1*(alpha*c1-beta*o1))+(-1*(k3p*o1-k4*o2))
1199 o2' = (1*(k3p*o1-k4*o2))
1200 p0' = (-1*(k1ca*p0-k2*p1))
1201 p1' = (1*(k1ca*p0-k2*p1))
1203 std::string expected_result = R"(
1205 EIGEN_NEWTON_SOLVE[5]{
1206 LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
1208 evaluate_fct(v, cai)
1213 nmodl_eigen_x[0] = c1
1214 nmodl_eigen_x[1] = o1
1215 nmodl_eigen_x[2] = o2
1216 nmodl_eigen_x[3] = p0
1217 nmodl_eigen_x[4] = p1
1219 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt
1220 nmodl_eigen_j[0] = -alpha-1/dt
1221 nmodl_eigen_j[5] = beta
1222 nmodl_eigen_j[10] = 0
1223 nmodl_eigen_j[15] = 0
1224 nmodl_eigen_j[20] = 0
1225 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt
1226 nmodl_eigen_j[1] = alpha
1227 nmodl_eigen_j[6] = -beta-k3p-1/dt
1228 nmodl_eigen_j[11] = k4
1229 nmodl_eigen_j[16] = 0
1230 nmodl_eigen_j[21] = 0
1231 nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0
1232 nmodl_eigen_j[2] = -1.0
1233 nmodl_eigen_j[7] = -1.0
1234 nmodl_eigen_j[12] = -1.0
1235 nmodl_eigen_j[17] = 0
1236 nmodl_eigen_j[22] = 0
1237 nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt
1238 nmodl_eigen_j[3] = 0
1239 nmodl_eigen_j[8] = 0
1240 nmodl_eigen_j[13] = 0
1241 nmodl_eigen_j[18] = -k1ca-1/dt
1242 nmodl_eigen_j[23] = k2
1243 nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0
1244 nmodl_eigen_j[4] = 0
1245 nmodl_eigen_j[9] = 0
1246 nmodl_eigen_j[14] = 0
1247 nmodl_eigen_j[19] = -1.0
1248 nmodl_eigen_j[24] = -1.0
1250 c1 = nmodl_eigen_x[0]
1251 o1 = nmodl_eigen_x[1]
1252 o2 = nmodl_eigen_x[2]
1253 p0 = nmodl_eigen_x[3]
1254 p1 = nmodl_eigen_x[4]
1259 "Construct & solve linear system, replacing ODEs for p1 and o2 with CONSERVE statement "
1260 "algebraic relations") {
1267 GIVEN(
"Derivative block including ODES with sparse method - single var in array") {
1277 SOLVE scheme1 METHOD sparse
1279 DERIVATIVE scheme1 {
1280 W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1283 std::string expected_result = R"(
1284 DERIVATIVE scheme1 {
1285 EIGEN_NEWTON_SOLVE[1]{
1290 nmodl_eigen_x[0] = W[0]
1292 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1293 nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1295 W[0] = nmodl_eigen_x[0]
1299 THEN("Construct & solver linear system") {
1306 GIVEN(
"Derivative block including ODES with sparse method - array vars") {
1316 SOLVE scheme1 METHOD sparse
1318 DERIVATIVE scheme1 {
1319 M'[0] = -A[0]*M[0] + B[0]*M[1]
1320 M'[1] = A[1]*M[0] - B[1]*M[1]
1323 std::string expected_result = R"(
1324 DERIVATIVE scheme1 {
1325 EIGEN_NEWTON_SOLVE[2]{
1326 LOCAL old_M_0, old_M_1
1331 nmodl_eigen_x[0] = M[0]
1332 nmodl_eigen_x[1] = M[1]
1334 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt
1335 nmodl_eigen_j[0] = -A[0]-1/dt
1336 nmodl_eigen_j[2] = B[0]
1337 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt
1338 nmodl_eigen_j[1] = A[1]
1339 nmodl_eigen_j[3] = -B[1]-1/dt
1341 M[0] = nmodl_eigen_x[0]
1342 M[1] = nmodl_eigen_x[1]
1346 THEN("Construct & solver linear system") {
1353 GIVEN(
"Derivative block including ODES with derivimplicit method - single var in array") {
1363 SOLVE scheme1 METHOD derivimplicit
1365 DERIVATIVE scheme1 {
1366 W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1369 std::string expected_result = R"(
1370 DERIVATIVE scheme1 {
1371 EIGEN_NEWTON_SOLVE[1]{
1376 nmodl_eigen_x[0] = W[0]
1378 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1379 nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1381 W[0] = nmodl_eigen_x[0]
1385 THEN("Construct newton solve block") {
1392 GIVEN(
"Derivative block including ODES with derivimplicit method") {
1398 SOLVE states METHOD derivimplicit
1402 m' = (minf-m)/mtau - 3*h
1403 h' = (hinf-h)/htau + m*m
1408 std::string expected_result = R
"(
1410 EIGEN_NEWTON_SOLVE[3]{
1411 LOCAL old_m, old_h, old_n
1418 nmodl_eigen_x[0] = m
1419 nmodl_eigen_x[1] = h
1420 nmodl_eigen_x[2] = n
1422 nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt
1423 nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1424 nmodl_eigen_j[3] = -3.0
1425 nmodl_eigen_j[6] = 0
1426 nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1427 nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1428 nmodl_eigen_j[4] = (-dt-htau)/(dt*htau)
1429 nmodl_eigen_j[7] = 0
1430 nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau)
1431 nmodl_eigen_j[2] = 0
1432 nmodl_eigen_j[5] = 0
1433 nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau)
1435 m = nmodl_eigen_x[0]
1436 h = nmodl_eigen_x[1]
1437 n = nmodl_eigen_x[2]
1441 THEN("Construct newton solve block") {
1448 GIVEN(
"Multiple derivative blocks each with derivimplicit method") {
1454 SOLVE states1 METHOD derivimplicit
1455 SOLVE states2 METHOD derivimplicit
1458 DERIVATIVE states1 {
1460 h' = (hinf-h)/htau + m*m
1463 DERIVATIVE states2 {
1464 h' = (hinf-h)/htau + m*m
1465 m' = (minf-m)/mtau + h
1469 std::string expected_result_0 = R
"(
1470 DERIVATIVE states1 {
1471 EIGEN_NEWTON_SOLVE[2]{
1477 nmodl_eigen_x[0] = m
1478 nmodl_eigen_x[1] = h
1480 nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau)
1481 nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1482 nmodl_eigen_j[2] = 0
1483 nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1484 nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1485 nmodl_eigen_j[3] = (-dt-htau)/(dt*htau)
1487 m = nmodl_eigen_x[0]
1488 h = nmodl_eigen_x[1]
1492 std::string expected_result_1 = R"(
1493 DERIVATIVE states2 {
1494 EIGEN_NEWTON_SOLVE[2]{
1500 nmodl_eigen_x[0] = m
1501 nmodl_eigen_x[1] = h
1503 nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1504 nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]
1505 nmodl_eigen_j[2] = (-dt-htau)/(dt*htau)
1506 nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt
1507 nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau)
1508 nmodl_eigen_j[3] = 1.0
1510 m = nmodl_eigen_x[0]
1511 h = nmodl_eigen_x[1]
1515 THEN("Construct newton solve block") {
1530 SCENARIO(
"LINEAR solve block (SympySolver Visitor)",
"[sympy][linear]") {
1531 GIVEN(
"1 state-var symbolic LINEAR solve block") {
1539 std::string expected_text = R"(
1543 THEN("solve analytically") {
1549 GIVEN(
"2 state-var LINEAR solve block") {
1558 std::string expected_text = R"(
1563 THEN("solve analytically") {
1569 GIVEN(
"Linear block, print in order, vectors") {
1578 std::string expected_result = R"(
1584 THEN("Construct & solve linear system") {
1591 GIVEN(
"Linear block, by value replacement, interleaved") {
1604 std::string expected_result = R"(
1614 THEN("Construct & solve linear system") {
1621 GIVEN(
"Linear block in control flow block") {
1633 std::string expected_result = R"(
1642 THEN("Construct & solve linear system") {
1649 GIVEN(
"Linear block, linear equations mixed with control flow blocks and reassignments") {
1663 std::string expected_result = R"(
1674 THEN("Construct & solve linear system") {
1681 GIVEN(
"4 state-var LINEAR solve block") {
1687 ~ w + z/3.2 = -2.0*y
1688 ~ x + 4*c*y = -5.343*a
1689 ~ a + x/b + z - y = 0.842*b*b
1690 ~ x + 1.3*y - 0.1*z/(a*a*b) = 1.43543/c
1692 std::string expected_text = R"(
1694 EIGEN_LINEAR_SOLVE[4]{
1697 nmodl_eigen_x[0] = w
1698 nmodl_eigen_x[1] = x
1699 nmodl_eigen_x[2] = y
1700 nmodl_eigen_x[3] = z
1701 nmodl_eigen_f[0] = 0
1702 nmodl_eigen_f[1] = 5.343*a
1703 nmodl_eigen_f[2] = a-0.84199999999999997*pow(b, 2)
1704 nmodl_eigen_f[3] = -1.43543/c
1705 nmodl_eigen_j[0] = -1.0
1706 nmodl_eigen_j[4] = 0
1707 nmodl_eigen_j[8] = -2.0
1708 nmodl_eigen_j[12] = -0.3125
1709 nmodl_eigen_j[1] = 0
1710 nmodl_eigen_j[5] = -1.0
1711 nmodl_eigen_j[9] = -4.0*c
1712 nmodl_eigen_j[13] = 0
1713 nmodl_eigen_j[2] = 0
1714 nmodl_eigen_j[6] = -1/b
1715 nmodl_eigen_j[10] = 1.0
1716 nmodl_eigen_j[14] = -1.0
1717 nmodl_eigen_j[3] = 0
1718 nmodl_eigen_j[7] = -1.0
1719 nmodl_eigen_j[11] = -1.3
1720 nmodl_eigen_j[15] = 0.10000000000000001/(pow(a, 2)*b)
1722 w = nmodl_eigen_x[0]
1723 x = nmodl_eigen_x[1]
1724 y = nmodl_eigen_x[2]
1725 z = nmodl_eigen_x[3]
1729 THEN("return matrix system to solve") {
1736 GIVEN(
"LINEAR solve block with an explicit SOLVEFOR statement") {
1743 LINEAR lin SOLVEFOR x, y {
1747 std::string expected_text = R"(
1748 LINEAR lin SOLVEFOR x,y{
1749 y = (v+15.0)/(3.0*z+1.0)
1750 x = (v*z-5.0)/(3.0*z+1.0)
1752 THEN("solve analytically") {
1764 SCENARIO(
"Solve NONLINEAR block using SympySolver Visitor",
"[visitor][solver][sympy][nonlinear]") {
1765 GIVEN(
"1 state-var numeric NONLINEAR solve block") {
1773 std::string expected_text = R"(
1775 EIGEN_NEWTON_SOLVE[1]{
1778 nmodl_eigen_x[0] = x
1780 nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
1781 nmodl_eigen_j[0] = -1.0
1783 x = nmodl_eigen_x[0]
1788 THEN("return F & J for newton solver") {
1794 GIVEN(
"array state-var numeric NONLINEAR solve block") {
1802 ~ s[2] + s[1] = s[0]
1804 std::string expected_text = R"(
1806 EIGEN_NEWTON_SOLVE[3]{
1809 nmodl_eigen_x[0] = s[0]
1810 nmodl_eigen_x[1] = s[1]
1811 nmodl_eigen_x[2] = s[2]
1813 nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
1814 nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
1815 nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
1816 nmodl_eigen_j[0] = -1.0
1817 nmodl_eigen_j[3] = 0
1818 nmodl_eigen_j[6] = 0
1819 nmodl_eigen_j[1] = 0
1820 nmodl_eigen_j[4] = -1.0
1821 nmodl_eigen_j[7] = 0
1822 nmodl_eigen_j[2] = 1.0
1823 nmodl_eigen_j[5] = -1.0
1824 nmodl_eigen_j[8] = -1.0
1826 s[0] = nmodl_eigen_x[0]
1827 s[1] = nmodl_eigen_x[1]
1828 s[2] = nmodl_eigen_x[2]
1832 THEN("return F & J for newton solver") {
1839 SCENARIO(
"Solve KINETIC block using SympySolver Visitor",
"[visitor][solver][sympy][kinetic]") {
1840 GIVEN(
"KINETIC block with not inlined function should work") {
1843 SOLVE kstates METHOD sparse
1849 FUNCTION alfa(v(mV)) {
1853 ~ C1 <-> C2 (alfa(v), alfa(v))
1855 std::string expected_text = R"(
1856 DERIVATIVE kstates {
1857 EIGEN_NEWTON_SOLVE[2]{
1858 LOCAL kf0_, kb0_, old_C1, old_C2
1865 nmodl_eigen_x[0] = C1
1866 nmodl_eigen_x[1] = C2
1868 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1869 nmodl_eigen_j[0] = -kf0_-1/dt
1870 nmodl_eigen_j[2] = kb0_
1871 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1872 nmodl_eigen_j[1] = kf0_
1873 nmodl_eigen_j[3] = -kb0_-1/dt
1875 C1 = nmodl_eigen_x[0]
1876 C2 = nmodl_eigen_x[1]
1880 THEN("Run Kinetic and Sympy Visitor") {
1881 std::vector<std::string>
result;
1883 nmodl_text,
false,
false, AstNodeType::DERIVATIVE_BLOCK,
true));
1887 GIVEN(
"Protected names in Sympy are respected") {
1890 SOLVE kstates METHOD sparse
1896 FUNCTION beta(v(mV)) {
1899 FUNCTION lowergamma(v(mV)) {
1903 ~ C1 <-> C2 (beta(v), lowergamma(v))
1905 std::string expected_text = R"(
1906 DERIVATIVE kstates {
1907 EIGEN_NEWTON_SOLVE[2]{
1908 LOCAL kf0_, kb0_, old_C1, old_C2
1911 kb0_ = lowergamma(v)
1915 nmodl_eigen_x[0] = C1
1916 nmodl_eigen_x[1] = C2
1918 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1919 nmodl_eigen_j[0] = -kf0_-1/dt
1920 nmodl_eigen_j[2] = kb0_
1921 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1922 nmodl_eigen_j[1] = kf0_
1923 nmodl_eigen_j[3] = -kb0_-1/dt
1925 C1 = nmodl_eigen_x[0]
1926 C2 = nmodl_eigen_x[1]
1930 THEN("Run Kinetic and Sympy Visitor") {
1931 std::vector<std::string>
result;
1933 nmodl_text,
false,
false, AstNodeType::DERIVATIVE_BLOCK,
true));
Visitor for checking parents of ast nodes
Represents top level AST node for whole NMODL input.
Class that binds all pieces together for parsing nmodl file.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Perform constant folding of integer/float/double expressions.
Visitor for kinetic block statements
void visit_program(ast::Program &node) override
visit node of type ast::Program
Unroll for loop in the AST.
Visitor for printing AST back to NMODL
void visit_program(const ast::Program &node) override
visit node of type ast::Program
Visitor for systems of algebraic and differential equations
void visit_program(ast::Program &node) override
visit node of type ast::Program
Concrete visitor for constructing symbol table from AST.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for checking parents of ast nodes
int check_ast(const ast::Ast &node)
A small wrapper to have a nicer call in parser.cpp.
Visitor for printing C++ code compatible with legacy api of CoreNEURON
Perform constant folding of integer/float/double expressions.
AstNodeType
Enum type for every AST node type.
bool parse_string(const std::string &input)
parser Units provided as string (used for testing)
Visitor to inline local procedure and function calls
Visitor for kinetic block statements
Unroll for loop in the AST.
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
encapsulates code generation backend implementations
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Visitor that solves ODEs using old solvers of NEURON
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
static Node * node(Object *)
static double remove(void *v)
Auto generated AST classes declaration.
Replace solve block statements with actual solution node in the AST.
void compare_blocks(const std::string &result, const std::string &expected, const bool require_fail=false)
Compare nmodl blocks that contain systems of equations (i.e.
std::string ast_to_string(ast::Program &node)
std::vector< std::string > run_sympy_solver_visitor(const std::string &text, bool pade=false, bool cse=false, AstNodeType ret_nodetype=AstNodeType::DIFF_EQ_EXPRESSION, bool kinetic=false)
bool is_unique_vars(std::string result)
SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]")
void run_sympy_visitor_passes(ast::Program &node)
Visitor for systems of algebraic and differential equations
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::parser::UnitDriver driver