NEURON
sympy_solver_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
10 
11 #include "ast/all.hpp"
13 #include "pybind/pyembed.hpp"
14 #include "symtab/symbol.hpp"
15 #include "utils/logger.hpp"
16 #include "utils/string_utils.hpp"
18 
19 
20 namespace pywrap = nmodl::pybind_wrappers;
21 
22 namespace nmodl {
23 namespace visitor {
24 
26 
28  // clear any previous data
29  expression_statements.clear();
30  eq_system.clear();
31  state_vars_in_block.clear();
32  last_expression_statement = nullptr;
34  eq_system_is_valid = true;
35  conserve_equation.clear();
36  // get set of local block vars & global vars
37  vars = global_vars;
38  if (auto symtab = node->get_statement_block()->get_symbol_table()) {
39  auto localvars = symtab->get_variables_with_properties(NmodlType::local_var);
40  for (const auto& localvar: localvars) {
41  std::string var_name = localvar->get_name();
42  if (localvar->is_array()) {
43  var_name += "[" + std::to_string(localvar->get_length()) + "]";
44  }
45  vars.insert(var_name);
46  }
47  }
48  const auto& fcall_nodes = collect_nodes(*node->get_statement_block(),
49  {ast::AstNodeType::FUNCTION_CALL});
50  for (const auto& call: fcall_nodes) {
51  function_calls.insert(call->get_node_name());
52  }
53 }
54 
56  state_vars.clear();
57  for (const auto& state_var: all_state_vars) {
58  if (state_vars_in_block.find(state_var) != state_vars_in_block.cend()) {
59  state_vars.push_back(state_var);
60  }
61  }
62 
63  // in case we have a SOLVEFOR in the block, we need to set `state_vars` to those instead
64  if (node->is_linear_block()) {
65  const auto& solvefor_vars = dynamic_cast<const ast::LinearBlock*>(node)->get_solvefor();
66  if (!solvefor_vars.empty()) {
67  state_vars.clear();
68  for (const auto& solvefor_var: solvefor_vars) {
69  state_vars.push_back(solvefor_var->get_node_name());
70  }
71  }
72  } else if (node->is_non_linear_block()) {
73  const auto& solvefor_vars = dynamic_cast<const ast::NonLinearBlock*>(node)->get_solvefor();
74  if (!solvefor_vars.empty()) {
75  state_vars.clear();
76  for (const auto& solvefor_var: solvefor_vars) {
77  state_vars.push_back(solvefor_var->get_node_name());
78  }
79  }
80  }
81 }
82 
84  const std::string& new_expr) {
85  auto new_statement = create_statement(new_expr);
86  auto new_expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(new_statement);
87  auto new_bin_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(
88  new_expr_statement->get_expression());
89  expr.set_expression(std::move(new_bin_expr));
90 }
91 
93  /// all ode/kinetic/(non)linear statements (typically) appear in the same statement block
94  /// if this is not the case, for now return an error (and should instead use fallback solver)
95  if (block_with_expression_statements != nullptr &&
97  logger->warn(
98  "SympySolverVisitor :: Coupled equations are appearing in different blocks - not "
99  "supported");
100  eq_system_is_valid = false;
101  }
103 }
104 
105 ast::StatementVector::const_iterator SympySolverVisitor::get_solution_location_iterator(
106  const ast::StatementVector& statements) {
107  // find out where to insert solutions in statement block
108  // returns iterator pointing to the first element after the last (non)linear eq
109  // so if there are no such elements, it returns statements.end()
110  auto it = statements.begin();
111  if (last_expression_statement != nullptr) {
112  while ((it != statements.end()) &&
113  (std::dynamic_pointer_cast<ast::ExpressionStatement>(*it).get() !=
115  logger->debug("SympySolverVisitor :: {} != {}",
116  to_nmodl(*it),
118  ++it;
119  }
120  if (it != statements.end()) {
121  logger->debug("SympySolverVisitor :: {} == {}",
122  to_nmodl(std::dynamic_pointer_cast<ast::ExpressionStatement>(*it)),
124  ++it;
125  }
126  }
127  return it;
128 }
129 
130 /**
131  * Check if provided statement is local variable declaration statement
132  * @param statement AST node representing statement in the MOD file
133  * @return True if statement is local variable declaration else False
134  *
135  * Statement declaration could be wrapped into another statement type like
136  * expression statement and hence we try to look inside if it's really a
137  * variable declaration.
138  */
139 static bool is_local_statement(const std::shared_ptr<ast::Statement>& statement) {
140  if (statement->is_local_list_statement()) {
141  return true;
142  }
143  if (statement->is_expression_statement()) {
144  auto e_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(statement);
145  auto expression = e_statement->get_expression();
146  if (expression->is_local_list_statement()) {
147  return true;
148  }
149  }
150  return false;
151 }
152 
153 std::string& SympySolverVisitor::replaceAll(std::string& context,
154  const std::string& from,
155  const std::string& to) {
156  std::size_t lookHere = 0;
157  std::size_t foundHere{};
158  while ((foundHere = context.find(from, lookHere)) != std::string::npos) {
159  context.replace(foundHere, from.size(), to);
160  lookHere = foundHere + to.size();
161  }
162  return context;
163 }
164 
166  const std::vector<std::string>& original_vector,
167  const std::string& original_string,
168  const std::string& substitution_string) {
169  std::vector<std::string> filtered_vector;
170  for (auto element: original_vector) {
171  std::string filtered_element = replaceAll(element, original_string, substitution_string);
172  filtered_vector.push_back(filtered_element);
173  }
174  return filtered_vector;
175 }
176 
178  const std::vector<std::string>& pre_solve_statements,
179  const std::vector<std::string>& solutions,
180  bool linear) {
181  auto solutions_filtered = filter_string_vector(solutions, "X[", "nmodl_eigen_x[");
182  solutions_filtered = filter_string_vector(solutions_filtered, "dX_[", "nmodl_eigen_dx[");
183  solutions_filtered = filter_string_vector(solutions_filtered, "J[", "nmodl_eigen_j[");
184  solutions_filtered = filter_string_vector(solutions_filtered, "Jm[", "nmodl_eigen_jm[");
185  solutions_filtered = filter_string_vector(solutions_filtered, "F[", "nmodl_eigen_f[");
186 
187  for (const auto& sol: solutions_filtered) {
188  logger->debug("SympySolverVisitor :: -> adding statement: {}", sol);
189  }
190 
191  std::vector<std::string> pre_solve_statements_and_setup_x_eqs = pre_solve_statements;
192  std::vector<std::string> update_statements;
193 
194  for (int i = 0; i < state_vars.size(); i++) {
195  auto eigen_name = fmt::format("nmodl_eigen_x[{}]", i);
196 
197  auto update_state = fmt::format("{} = {}", state_vars[i], eigen_name);
198  update_statements.push_back(update_state);
199  logger->debug("SympySolverVisitor :: update_state: {}", update_state);
200 
201  auto setup_x = fmt::format("{} = {}", eigen_name, state_vars[i]);
202  pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
203  logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
204  }
205 
206  visitor::SympyReplaceSolutionsVisitor solution_replacer(
207  pre_solve_statements_and_setup_x_eqs,
208  solutions_filtered,
211  state_vars.size() + 1,
212  "");
214 
215  // split in the various blocks for eigen
216  auto n_state_vars = std::make_shared<ast::Integer>(state_vars.size(), nullptr);
217 
218  const auto& statements = block_with_expression_statements->get_statements();
219 
220  ast::StatementVector variable_statements; // LOCAL //
221  ast::StatementVector initialize_statements; // pre_solve_statements //
222  ast::StatementVector setup_x_statements; // old_x = x, X[0] = x //
223  ast::StatementVector functor_statements; // J[0]_row * X = F[0], additional assignments during
224  // computation //
225  ast::StatementVector finalize_statements; // assignments at the end //
226  std::ptrdiff_t const sr_begin{solution_replacer.replaced_statements_begin()};
227  std::ptrdiff_t const sr_end{solution_replacer.replaced_statements_end()};
228 
229  // initialize and edge case where the system of equations is empty
230  for (size_t idx = 0; idx < statements.size(); ++idx) {
231  auto& s = statements[idx];
232  if (is_local_statement(s)) {
233  variable_statements.push_back(s);
234  } else if (sr_begin == statements.size() || idx < sr_begin) {
235  initialize_statements.push_back(s);
236  }
237  }
238 
239  if (sr_begin != statements.size()) {
240  auto init_begin = statements.begin() + sr_begin;
241  auto init_end = init_begin + static_cast<std::ptrdiff_t>(pre_solve_statements.size());
242  initialize_statements.insert(initialize_statements.end(), init_begin, init_end);
243 
244  auto setup_x_begin = init_end;
245  auto setup_x_end = setup_x_begin + static_cast<std::ptrdiff_t>(state_vars.size());
246  setup_x_statements = ast::StatementVector(setup_x_begin, setup_x_end);
247 
248  auto functor_begin = setup_x_end;
249  auto functor_end = statements.begin() + sr_end;
250  functor_statements = ast::StatementVector(functor_begin, functor_end);
251 
252  auto finalize_begin = functor_end;
253  auto finalize_end = statements.end();
254  finalize_statements = ast::StatementVector(finalize_begin, finalize_end);
255  }
256 
257  const size_t total_statements_size = variable_statements.size() + initialize_statements.size() +
258  setup_x_statements.size() + functor_statements.size() +
259  finalize_statements.size();
260  if (statements.size() != total_statements_size) {
261  logger->error(
262  "SympySolverVisitor :: statement number missmatch ({} =/= {}) during splitting before "
263  "creation of "
264  "eigen "
265  "solver block.",
266  statements.size(),
267  total_statements_size);
268  return;
269  }
270 
271  auto variable_block = std::make_shared<ast::StatementBlock>(std::move(variable_statements));
272  auto initialize_block = std::make_shared<ast::StatementBlock>(std::move(initialize_statements));
273  auto update_state_block = create_statement_block(update_statements);
274  auto finalize_block = std::make_shared<ast::StatementBlock>(std::move(finalize_statements));
275  if (linear) {
276  /// functor and initialize block converge in the same block
277  setup_x_statements.insert(setup_x_statements.end(),
278  functor_statements.begin(),
279  functor_statements.end());
280  auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
281  auto solver_block = std::make_shared<ast::EigenLinearSolverBlock>(n_state_vars,
282  variable_block,
283  initialize_block,
284  setup_x_block,
285  update_state_block,
286  finalize_block);
287  /// replace statement block with solver block as it contains all statements
288  ast::StatementVector solver_block_statements{
289  std::make_shared<ast::ExpressionStatement>(solver_block)};
290  block_with_expression_statements->set_statements(std::move(solver_block_statements));
291  } else {
292  /// create eigen newton solver block
293  auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
294  auto functor_block = std::make_shared<ast::StatementBlock>(std::move(functor_statements));
295  auto solver_block = std::make_shared<ast::EigenNewtonSolverBlock>(n_state_vars,
296  variable_block,
297  initialize_block,
298  setup_x_block,
299  functor_block,
300  update_state_block,
301  finalize_block);
302  /// replace statement block with solver block as it contains all statements
303  ast::StatementVector solver_block_statements{
304  std::make_shared<ast::ExpressionStatement>(solver_block)};
305  block_with_expression_statements->set_statements(std::move(solver_block_statements));
306  }
307 }
308 
309 
311  const std::vector<std::string>& pre_solve_statements) {
312  // construct ordered vector of state vars used in linear system
314  // call sympy linear solver
315  bool small_system = (eq_system.size() <= SMALL_LINEAR_SYSTEM_MAX_STATES);
317  // this is necessary after we destroy the solver
318  const auto tmp_unique_prefix = suffix_random_string(vars, "tmp");
319 
320  // returns a vector of solutions, i.e. new statements to add to block;
321  // and a vector of new local variables that need to be declared in the block;
322  // may also return a python exception message:
323  auto [solutions, new_local_vars, exception_message] = solver(
324  eq_system, state_vars, vars, small_system, elimination, tmp_unique_prefix, function_calls);
325 
326  if (!exception_message.empty()) {
327  logger->warn(
328  "SympySolverVisitor :: solve_lin_system python exception occured. (--verbose=info)");
329  logger->info(exception_message +
330  "\n (Note: line numbers are of by a few compared to `ode.py`.)");
331  return;
332  }
333  // find out where to insert solutions in statement block
334  if (small_system) {
335  // for small number of state vars, linear solver
336  // directly returns solution by solving symbolically at compile time
337  logger->debug("SympySolverVisitor :: Solving *small* linear system of eqs");
338  // declare new local vars
339  if (!new_local_vars.empty()) {
340  for (const auto& new_local_var: new_local_vars) {
341  logger->debug("SympySolverVisitor :: -> declaring new local variable: {}",
342  new_local_var);
344  }
345  }
346  visitor::SympyReplaceSolutionsVisitor solution_replacer(
347  pre_solve_statements,
348  solutions,
351  1,
352  tmp_unique_prefix);
354  } else {
355  // otherwise it returns a linear matrix system to solve
356  logger->debug("SympySolverVisitor :: Constructing linear newton solve block");
357  construct_eigen_solver_block(pre_solve_statements, solutions, true);
358  }
359 }
360 
362  const ast::Node& node,
363  const std::vector<std::string>& pre_solve_statements) {
364  // construct ordered vector of state vars used in non-linear system
366 
368  auto [solutions, exception_message] = solver(eq_system, state_vars, vars, function_calls);
369 
370  if (!exception_message.empty()) {
371  logger->warn(
372  "SympySolverVisitor :: solve_non_lin_system python exception. (--verbose=info)");
373  logger->info(exception_message +
374  "\n (Note: line numbers are of by a few compared to `ode.py`.)");
375  return;
376  }
377  logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
378 
379  construct_eigen_solver_block(pre_solve_statements, solutions, false);
380 }
381 
383  if (collect_state_vars) {
384  std::string var_name = node.get_node_name();
385  if (node.get_name()->is_indexed_name()) {
386  auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(node.get_name());
387  var_name +=
388  "[" +
390  std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
391  "]";
392  }
393  // if var_name is a state var, add it to set
394  if (std::find(all_state_vars.cbegin(), all_state_vars.cend(), var_name) !=
395  all_state_vars.cend()) {
396  logger->debug("SympySolverVisitor :: adding state var: {}", var_name);
397  state_vars_in_block.insert(var_name);
398  }
399  }
400 }
401 
402 // Skip visiting CVODE block
404 
406  const auto& lhs = node.get_expression()->get_lhs();
407 
408  if (!lhs->is_var_name()) {
409  logger->warn("SympySolverVisitor :: LHS of differential equation is not a VariableName");
410  return;
411  }
412  auto lhs_name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
413  if ((lhs_name->is_indexed_name() &&
414  !std::dynamic_pointer_cast<ast::IndexedName>(lhs_name)->get_name()->is_prime_name()) ||
415  (!lhs_name->is_indexed_name() && !lhs_name->is_prime_name())) {
416  logger->warn("SympySolverVisitor :: LHS of differential equation is not a PrimeName");
417  return;
418  }
419 
421 
422  const auto node_as_nmodl = to_nmodl_for_sympy(node);
424 
426  auto [solution, exception_message] = (*diffeq_solver)(
427  node_as_nmodl, dt_var, vars, use_pade_approx, function_calls, solve_method);
429  // replace x' = f(x) differential equation
430  // with forwards Euler timestep:
431  // x = x + f(x) * dt
432  logger->debug("SympySolverVisitor :: EULER - solving: {}", node_as_nmodl);
434  // replace x' = f(x) differential equation
435  // with analytic solution for x(t+dt) in terms of x(t)
436  // x = ...
437  logger->debug("SympySolverVisitor :: CNEXP - solving: {}", node_as_nmodl);
438  } else {
439  // for other solver methods: just collect the ODEs & return
440  std::string eq_str = to_nmodl_for_sympy(node);
441  std::string var_name = lhs_name->get_node_name();
442  if (lhs_name->is_indexed_name()) {
443  auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(lhs_name);
444  var_name +=
445  "[" +
447  std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
448  "]";
449  }
450  logger->debug("SympySolverVisitor :: adding ODE system: {}", eq_str);
451  eq_system.push_back(eq_str);
452  logger->debug("SympySolverVisitor :: adding state var: {}", var_name);
453  state_vars_in_block.insert(var_name);
456  return;
457  }
458 
459  // replace ODE with solution in AST
460  logger->debug("SympySolverVisitor :: -> solution: {}", solution);
461 
462  if (!exception_message.empty()) {
463  logger->warn("SympySolverVisitor :: python exception. (--verbose=info)");
464  logger->info(exception_message +
465  "\n (Note: line numbers are of by a few compared to `ode.py`.)");
466  return;
467  }
468 
469  if (!solution.empty()) {
470  replace_diffeq_expression(node, solution);
471  } else {
472  logger->warn("SympySolverVisitor :: solution to differential equation not possible");
473  }
474 }
475 
477  // Replace ODE for state variable on LHS of CONSERVE statement with
478  // algebraic expression on RHS (see p244 of NEURON book)
479  logger->debug("SympySolverVisitor :: CONSERVE statement: {}", to_nmodl(node));
480  expression_statements.insert(&node);
481  std::string conserve_equation_statevar;
482  if (node.get_react()->is_react_var_name()) {
483  conserve_equation_statevar = node.get_react()->get_node_name();
484  }
485  if (std::find(all_state_vars.cbegin(), all_state_vars.cend(), conserve_equation_statevar) ==
486  all_state_vars.cend()) {
487  logger->error(
488  "SympySolverVisitor :: Invalid CONSERVE statement for DERIVATIVE block, LHS should be "
489  "a state variable, instead found: {}. Ignoring CONSERVE statement",
490  to_nmodl(node.get_react()));
491  return;
492  }
493  auto conserve_equation_str = to_nmodl_for_sympy(*node.get_expr());
494  logger->debug("SympySolverVisitor :: --> replace ODE for state var {} with equation {}",
495  conserve_equation_statevar,
496  conserve_equation_str);
497  conserve_equation[conserve_equation_statevar] = conserve_equation_str;
498 }
499 
501  /// clear information from previous block, get global vars + block local vars
503 
504  // get user specified solve method for this block
506 
507  // visit each differential equation:
508  // - for CNEXP or EULER, each equation is independent & is replaced with its solution
509  // - otherwise, each equation is added to eq_system
510  node.visit_children(*this);
511 
512  if (eq_system_is_valid && !eq_system.empty()) {
513  // solve system of ODEs in eq_system
514  logger->debug("SympySolverVisitor :: Solving {} system of ODEs", solve_method);
515 
516  // construct implicit Euler equations from ODEs
517  std::vector<std::string> pre_solve_statements;
518  for (auto& eq: eq_system) {
519  auto split_eq = stringutils::split_string(eq, '=');
520  auto x_prime_split = stringutils::split_string(split_eq[0], '\'');
521  auto x = stringutils::trim(x_prime_split[0]);
522  std::string x_array_index;
523  std::string x_array_index_i;
524  if (x_prime_split.size() > 1 && stringutils::trim(x_prime_split[1]).size() > 2) {
525  x_array_index = stringutils::trim(x_prime_split[1]);
526  x_array_index_i = "_" + x_array_index.substr(1, x_array_index.size() - 2);
527  }
528  std::string state_var_name = x + x_array_index;
529  auto var_eq_pair = conserve_equation.find(state_var_name);
530  if (var_eq_pair != conserve_equation.cend()) {
531  // replace the ODE for this state var with corresponding CONSERVE equation
532  eq = state_var_name + " = " + var_eq_pair->second;
533  logger->debug(
534  "SympySolverVisitor :: -> instead of Euler eq using CONSERVE equation: {} = {}",
535  state_var_name,
536  var_eq_pair->second);
537  } else {
538  // no CONSERVE equation, construct Euler equation
539  auto dxdt = stringutils::trim(split_eq[1]);
540 
541  auto const old_x = [&]() {
542  std::string old_x_name{"old_"};
543  old_x_name.append(x);
544  old_x_name.append(x_array_index_i);
545  return suffix_random_string(vars, old_x_name);
546  }();
547  // declare old_x
548  logger->debug("SympySolverVisitor :: -> declaring new local variable: {}", old_x);
550  // assign old_x = x
551  {
552  std::string expression{old_x};
553  expression.append(" = ");
554  expression.append(x);
555  expression.append(x_array_index);
556  pre_solve_statements.push_back(std::move(expression));
557  }
558  // replace ODE with Euler equation
559  eq = "(";
560  eq.append(x);
561  eq.append(x_array_index);
562  eq.append(" - ");
563  eq.append(old_x);
564  eq.append(") / ");
566  eq.append(" = ");
567  eq.append(dxdt);
568  logger->debug("SympySolverVisitor :: -> constructed Euler eq: {}", eq);
569  }
570  }
571 
574  solve_non_linear_system(node, pre_solve_statements);
575  } else {
576  logger->error("SympySolverVisitor :: Solve method {} not supported", solve_method);
577  }
578  }
579 }
580 
583  std::string lin_eq = to_nmodl_for_sympy(*node.get_lhs());
584  lin_eq += " = ";
585  lin_eq += to_nmodl_for_sympy(*node.get_rhs());
586  eq_system.push_back(lin_eq);
589  logger->debug("SympySolverVisitor :: adding linear eq: {}", lin_eq);
590  collect_state_vars = true;
591  node.visit_children(*this);
592  collect_state_vars = false;
593 }
594 
596  logger->debug("SympySolverVisitor :: found LINEAR block: {}", node.get_node_name());
597 
598  /// clear information from previous block, get global vars + block local vars
600 
601  // collect linear equations
602  node.visit_children(*this);
603 
604  if (eq_system_is_valid && !eq_system.empty()) {
606  }
607 }
608 
611  std::string non_lin_eq = to_nmodl_for_sympy(*node.get_lhs());
612  non_lin_eq += " = ";
613  non_lin_eq += to_nmodl_for_sympy(*node.get_rhs());
614  eq_system.push_back(non_lin_eq);
617  logger->debug("SympySolverVisitor :: adding non-linear eq: {}", non_lin_eq);
618  collect_state_vars = true;
619  node.visit_children(*this);
620  collect_state_vars = false;
621 }
622 
624  logger->debug("SympySolverVisitor :: found NONLINEAR block: {}", node.get_node_name());
625 
626  /// clear information from previous block, get global vars + block local vars
628 
629  // collect non-linear equations
630  node.visit_children(*this);
631 
632  if (eq_system_is_valid && !eq_system.empty()) {
634  }
635 }
636 
638  auto prev_expression_statement = current_expression_statement;
640  node.visit_children(*this);
641  current_expression_statement = prev_expression_statement;
642 }
643 
645  auto prev_statement_block = current_statement_block;
647  node.visit_children(*this);
648  current_statement_block = prev_statement_block;
649 }
650 
653 
655 
656  // get list of solve statements with names & methods
657  const auto& solve_block_nodes = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
658  for (const auto& block: solve_block_nodes) {
659  if (auto block_ptr = std::dynamic_pointer_cast<const ast::SolveBlock>(block)) {
660  const auto& block_name = block_ptr->get_block_name()->get_value()->eval();
661  if (block_ptr->get_method()) {
662  // Note: solve method name is an optional parameter
663  // LINEAR and NONLINEAR blocks do not have solve method specified
664  const auto& solve_method = block_ptr->get_method()->get_value()->eval();
665  logger->debug("SympySolverVisitor :: Found SOLVE statement: using {} for {}",
666  solve_method,
667  block_name);
669  }
670  }
671  }
672 
673  // get set of all state vars
674  all_state_vars.clear();
675  if (auto symtab = node.get_symbol_table()) {
676  auto statevars = symtab->get_variables_with_properties(NmodlType::state_var);
677  for (const auto& v: statevars) {
678  std::string var_name = v->get_name();
679  if (v->is_array()) {
680  for (int i = 0; i < v->get_length(); ++i) {
681  std::string var_name_i = var_name + "[" + std::to_string(i) + "]";
682  all_state_vars.push_back(var_name_i);
683  }
684  } else {
685  all_state_vars.push_back(var_name);
686  }
687  }
688  }
689 
690  node.visit_children(*this);
691 }
692 
693 } // namespace visitor
694 } // namespace nmodl
Auto generated AST classes declaration.
Represent CONSERVE statement in NMODL.
Definition: conserve.hpp:38
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Definition: cvode_block.hpp:38
Represents DERIVATIVE block in the NMODL.
Represents differential equation in DERIVATIVE block.
void set_expression(std::shared_ptr< BinaryExpression > &&expression)
Setter for member variable DiffEqExpression::expression (rvalue reference)
Definition: ast.cpp:6700
One equation in a system of equations that collectively form a LINEAR block.
Represents LINEAR block in the NMODL.
Base class for all AST node.
Definition: node.hpp:40
One equation in a system of equations that collectively make a NONLINEAR block.
Represents NONLINEAR block in the NMODL.
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents block encapsulating list of statements.
void set_statements(StatementVector &&statements)
Setter for member variable StatementBlock::statements (rvalue reference)
Definition: ast.cpp:3222
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Represents a variable.
Definition: var_name.hpp:43
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:29
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:136
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
@ VALUE
Replace statements matching by lhs varName.
int replaced_statements_begin() const
idx (in the new statementVector) of the first statement that was added.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
int replaced_statements_end() const
idx (in the new statementVector) of the last statement that was added.
std::vector< std::string > all_state_vars
vector of all state variables (in order specified in STATE block in mod file)
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
ast::StatementBlock * current_statement_block
current statement block being visited
void solve_non_linear_system(const ast::Node &node, const std::vector< std::string > &pre_solve_statements={})
solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
bool elimination
optionally do CSE (common subexpression elimination) for sparse solver
std::set< std::string > function_calls
custom function calls used in ODE block
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
void visit_linear_block(ast::LinearBlock &node) override
visit node of type ast::LinearBlock
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
void init_block_data(ast::Node *node)
clear any data from previous block & get set of block local vars + global vars
void solve_linear_system(const ast::Node &node, const std::vector< std::string > &pre_solve_statements={})
solve linear system (for "LINEAR")
std::set< std::string > state_vars_in_block
set of state variables used in block
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
std::unordered_map< std::string, std::string > derivative_block_solve_method
map between derivative block names and associated solver method
ast::ExpressionStatement * current_expression_statement
current expression statement being visited (to track ODEs / (non)lineqs)
std::vector< std::string > state_vars
vector of state vars used in block (in same order as all_state_vars)
void visit_program(ast::Program &node) override
visit node of type ast::Program
static std::string & replaceAll(std::string &context, const std::string &from, const std::string &to)
Function used by SympySolverVisitor::filter_string_vector to replace every occurrence of one substring in a std::string with another (e.g. `X[` with `nmodl_eigen_x[`).
bool collect_state_vars
true for (non)linear eqs, to identify all state vars used in equations
int SMALL_LINEAR_SYSTEM_MAX_STATES
max number of state vars allowed for small system linear solver
void visit_var_name(ast::VarName &node) override
visit node of type ast::VarName
void visit_expression_statement(ast::ExpressionStatement &node) override
visit node of type ast::ExpressionStatement
static std::string to_nmodl_for_sympy(ast::Ast &node)
return NMODL string version of node, excluding any units
void check_expr_statements_in_same_block()
raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
ast::StatementVector::const_iterator get_solution_location_iterator(const ast::StatementVector &statements)
return iterator pointing to where solution should be inserted in statement block
void construct_eigen_solver_block(const std::vector< std::string > &pre_solve_statements, const std::vector< std::string > &solutions, bool linear)
construct solver block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
std::set< std::string > global_vars
global variables
std::string solve_method
method specified in solve block
void visit_cvode_block(ast::CvodeBlock &node) override
visit node of type ast::CvodeBlock
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
void visit_non_linear_block(ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
bool eq_system_is_valid
only solve eq_system system of equations if this is true:
std::unordered_set< ast::Statement * > expression_statements
expression statements appearing in the block (these can be of type DiffEqExpression,...
std::unordered_map< std::string, std::string > conserve_equation
map from state vars to the algebraic equation from CONSERVE statement that should replace their ODE,...
void init_state_vars_vector(const ast::Node *node)
construct vector from set of state vars in correct order
static void replace_diffeq_expression(ast::DiffEqExpression &expr, const std::string &new_expr)
replace binary expression with new expression provided as string
std::set< std::string > vars
local variables in current block + globals
std::vector< std::string > eq_system
vector of {ODE, linear eq, non-linear eq} system to solve
bool use_pade_approx
optionally replace cnexp solution with (1,1) pade approx
ast::ExpressionStatement * last_expression_statement
last expression statement visited (to know where to insert solutions in statement block)
static std::vector< std::string > filter_string_vector(const std::vector< std::string > &original_vector, const std::string &original_string, const std::string &substitution_string)
Check original_vector for elements that contain a variable named original_string and rename it to sub...
ast::StatementBlock * block_with_expression_statements
block where expression statements appear (to check there is only one)
#define v
Definition: md1redef.h:11
#define i
Definition: md1redef.h:19
@ SOLVE_BLOCK
type of ast::SolveBlock
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:302
static std::string trim(std::string text)
static std::vector< std::string > split_string(const std::string &text, char delimiter)
Split a text in a list of words, using a given delimiter character.
static void call(Symbol *s, Node *nd, Prop *p)
Definition: hocmech.cpp:170
void move(Item *q1, Item *q2, Item *q3)
Definition: list.cpp:200
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
static constexpr char EULER_METHOD[]
euler method in nmodl
static constexpr char NTHREAD_DT_VARIABLE[]
dt variable in neuron thread structure
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
std::string to_string(const T &obj)
NmodlType
NMODL variable properties.
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
static bool is_local_statement(const std::shared_ptr< ast::Statement > &statement)
Check if provided statement is local variable declaration statement.
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert given code statement (in string format) to corresponding ast node.
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
logger_type logger
Definition: logger.cpp:34
static Node * node(Object *)
Definition: netcvode.cpp:291
s
Definition: multisend.cpp:521
static double context(void *v)
Definition: ocbbs.cpp:171
int find(const int, const int, const int, const int, const int)
Implement string manipulation functions.
decltype(&call_diffeq_solver) diffeq_solver
Definition: wrapper.hpp:65
decltype(&call_solve_linear_system) solve_linear_system
Definition: wrapper.hpp:64
decltype(&call_solve_nonlinear_system) solve_nonlinear_system
Definition: wrapper.hpp:63
Implement class to represent a symbol in Symbol Table.
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
Visitor for systems of algebraic and differential equations
Utility functions for visitors implementation.