NEURON
codegen_helper_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include <algorithm>
11 #include <cmath>
12 #include <memory>
13 
14 #include "ast/all.hpp"
15 #include "ast/constant_var.hpp"
17 #include "parser/c11_driver.hpp"
19 
20 #include "utils/logger.hpp"
21 
22 namespace nmodl {
23 namespace codegen {
24 
25 using namespace ast;
26 
29 
30 /**
31  * Check whether a given SOLVE block solves a PROCEDURE with any of the CVode methods
32  */
33 static bool check_procedure_has_cvode(const std::shared_ptr<const ast::Ast>& solve_node,
34  const std::shared_ptr<const ast::Ast>& procedure_node) {
35  const auto& solve_block = std::dynamic_pointer_cast<const ast::SolveBlock>(solve_node);
36  const auto& method = solve_block->get_method();
37  if (!method) {
38  return false;
39  }
40  const auto& method_name = method->get_node_name();
41 
42  return procedure_node->get_node_name() == solve_block->get_block_name()->get_node_name() &&
43  (method_name == codegen::naming::AFTER_CVODE_METHOD ||
44  method_name == codegen::naming::CVODE_T_METHOD ||
45  method_name == codegen::naming::CVODE_T_V_METHOD);
46 }
47 
48 /**
49  * How symbols are stored in NEURON? See notes written in markdown file.
50  *
51  * Some variables get printed by iterating over symbol table in mod2c.
52  * The example of this is thread variables (and also ions?). In this
53  * case we must have to arrange order if we are going keep compatibility
54  * with NEURON.
55  *
56  * Suppose there are three global variables: bcd, abc, abd, abe
57  * They will be in the 'a' bucket in order:
58  * abe, abd, abc
59  * and in 'b' bucket
60  * bcd
61  * So when we print thread variables, we first have to sort in the opposite
62  * order in which they come and then again order by first character in increasing
63  * order.
64  *
65  * Note that variables in double array do not need this transformation
66  * and it seems like they should just follow definition order.
67  */
68 void CodegenHelperVisitor::sort_with_mod2c_symbol_order(std::vector<SymbolType>& symbols) {
69  /// first sort by global id to get in reverse order
70  std::sort(symbols.begin(),
71  symbols.end(),
72  [](const SymbolType& first, const SymbolType& second) -> bool {
73  return first->get_id() > second->get_id();
74  });
75 
76  /// now order by name (to be same as neuron's bucket)
77  std::sort(symbols.begin(),
78  symbols.end(),
79  [](const SymbolType& first, const SymbolType& second) -> bool {
80  return first->get_name()[0] < second->get_name()[0];
81  });
82 }
83 
84 
85 /**
86  * Find all ions used in mod file
    *
    * Collects every USEION statement of the program, builds one Ion object per
    * collected statement (with the read/write variables whose names match that ion
    * and the valence, when one is given) and appends it to `info.ions`. Afterwards
    * `info.currents`, `info.write_concentration` and `info.require_wrote_conc` are
    * derived from the collected data.
87  */
88 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
90  // collect all use ion statements
91  const auto& ion_nodes = collect_nodes(node, {AstNodeType::USEION});
92 
93  // ion names, read ion variables and write ion variables; the read/write names of
    // all USEION statements are pooled here and matched back to an ion by name below
94  std::vector<std::string> ion_vars;
95  std::vector<std::string> read_ion_vars;
96  std::vector<std::string> write_ion_vars;
    // valence per ion name, filled only for statements that provide one
97  std::map<std::string, double> valences;
98 
99  for (const auto& ion_node: ion_nodes) {
100  const auto& ion = std::dynamic_pointer_cast<const ast::Useion>(ion_node);
101  auto ion_name = ion->get_node_name();
102  ion_vars.push_back(ion_name);
103  for (const auto& var: ion->get_readlist()) {
104  read_ion_vars.push_back(var->get_node_name());
105  }
106  for (const auto& var: ion->get_writelist()) {
107  write_ion_vars.push_back(var->get_node_name());
108  }
109 
     // remember the valence only when the statement carries one
110  if (ion->get_valence() != nullptr) {
111  valences[ion_name] = ion->get_valence()->get_value()->to_double();
112  }
113  }
114 
115  /**
116  * Check if given variable belongs to given ion.
117  * For example, eca belongs to ca ion, nai belongs to na ion.
118  * The variable matches when dropping either its first or its last character
     * leaves exactly the ion name.
     * NOTE(review): for an empty `var`, `var.size() - 1` wraps around and
     * `var.substr(1, len)` throws std::out_of_range - presumably variable names
     * are never empty here; confirm.
119  */
120  auto ion_variable = [](const std::string& var, const std::string& ion) -> bool {
121  auto len = var.size() - 1;
122  return (var.substr(1, len) == ion || var.substr(0, len) == ion);
123  };
124 
125  /// iterate over all ion types and construct the Ion objects
126  for (auto& ion_name: ion_vars) {
127  Ion ion(ion_name);
128  for (auto& read_var: read_ion_vars) {
129  if (ion_variable(read_var, ion_name)) {
130  ion.reads.push_back(read_var);
131  }
132  }
133  for (auto& write_var: write_ion_vars) {
134  if (ion_variable(write_var, ion_name)) {
135  ion.writes.push_back(write_var);
     // a written variable accepted by is_intra_cell_conc()/is_extra_cell_conc()
     // marks the ion as needing style and records that a concentration is written
136  if (ion.is_intra_cell_conc(write_var) || ion.is_extra_cell_conc(write_var)) {
137  ion.need_style = true;
138  info.write_concentration = true;
139  }
140  }
141  }
     // ion.valence is assigned only if a valence was recorded for this ion name
142  if (auto it = valences.find(ion_name); it != valences.end()) {
143  ion.valence = it->second;
144  }
145 
146  info.ions.push_back(std::move(ion));
147  }
148 
149  /// once ions are populated, we can find all currents; they are appended in this order:
     /// variables with the nonspecific_cur_var property, variables with the
     /// electrode_cur_var property, then every written ion variable that is an ionic current
150  auto vars = psymtab->get_variables_with_properties(NmodlType::nonspecific_cur_var);
151  for (auto& var: vars) {
152  info.currents.push_back(var->get_name());
153  }
154  vars = psymtab->get_variables_with_properties(NmodlType::electrode_cur_var);
155  for (auto& var: vars) {
156  info.currents.push_back(var->get_name());
157  }
158  for (auto& ion: info.ions) {
159  for (auto& var: ion.writes) {
160  if (ion.is_ionic_current(var)) {
161  info.currents.push_back(var);
162  }
163  }
164  }
165 
166  /// check if write_conc(...) will be needed: it is as soon as some ion writes a
     /// variable that is neither an ionic current nor a reversal potential
167  for (const auto& ion: info.ions) {
168  for (const auto& var: ion.writes) {
169  if (!ion.is_ionic_current(var) && !ion.is_rev_potential(var)) {
170  info.require_wrote_conc = true;
171  }
172  }
173  }
174 }
175 
176 /**
177  * Find whether or not we need to emit CVODE-related code for NEURON
178  * Notes: we generate CVODE-related code if and only if:
179  * - there is exactly ONE block being SOLVEd
180  * - the block is one of the following types:
181  * - DERIVATIVE
182  * - KINETIC
183  * - PROCEDURE being solved with the `after_cvode`, `cvode_t`, or `cvode_t_v` methods
     *
     * When these conditions hold, `info.emit_cvode` is assigned the value of
     * `enable_cvode`; otherwise `info.emit_cvode` is left untouched.
184  */
186  // find the breakpoint block
187  const auto& breakpoint_nodes = collect_nodes(node, {AstNodeType::BREAKPOINT_BLOCK});
188 
189  // do nothing if there are no BREAKPOINT nodes
190  if (breakpoint_nodes.empty()) {
191  return;
192  }
193 
194  // there can only be one BREAKPOINT block in the entire program
     // (this is verified only in builds where assert() is active)
195  assert(breakpoint_nodes.size() == 1);
196 
197  const auto& breakpoint_node = std::dynamic_pointer_cast<const ast::BreakpointBlock>(
198  breakpoint_nodes[0]);
199 
200  // all (global) kinetic/derivative nodes, collected from the whole program
201  const auto& kinetic_or_derivative_nodes =
202  collect_nodes(node, {AstNodeType::KINETIC_BLOCK, AstNodeType::DERIVATIVE_BLOCK});
203 
204  // all (global) procedure nodes, collected from the whole program
205  const auto& procedure_nodes = collect_nodes(node, {AstNodeType::PROCEDURE_BLOCK});
206 
207  // find all SOLVE blocks in that BREAKPOINT block
208  const auto& solve_nodes = collect_nodes(*breakpoint_node, {AstNodeType::SOLVE_BLOCK});
209 
210  // check whether any of the SOLVE blocks are solving any PROCEDURE with `after_cvode`,
211  // `cvode_t`, or `cvode_t_v` methods (every SOLVE block is tested against every PROCEDURE)
212  const auto using_cvode = std::any_of(
213  solve_nodes.begin(), solve_nodes.end(), [&procedure_nodes](const auto& solve_node) {
214  return std::any_of(procedure_nodes.begin(),
215  procedure_nodes.end(),
216  [&solve_node](const auto& procedure_node) {
217  return check_procedure_has_cvode(solve_node, procedure_node);
218  });
219  });
220 
221  // CVODE code is emitted only when the BREAKPOINT block contains exactly one SOLVE
222  // statement and either the program defines at least one KINETIC/DERIVATIVE block (with
223  // any method) or a PROCEDURE is solved with `after_cvode`, `cvode_t` or `cvode_t_v`.
     // NOTE(review): the KINETIC/DERIVATIVE test is program-wide; it does not verify that
     // the SOLVEd block itself is a KINETIC/DERIVATIVE block - confirm this is intended.
224  if (solve_nodes.size() == 1 && (kinetic_or_derivative_nodes.size() || using_cvode)) {
225  logger->debug("Will emit code for CVODE");
     // the flag takes the value of enable_cvode rather than being set unconditionally
226  info.emit_cvode = enable_cvode;
227  }
228 }
229 
230 /**
231  * Find non-range variables, i.e. ones that do not belong to the per-instance allocation
232  *
233  * Certain variables like pointers, globals and parameters do not need to be per-instance
234  * variables. NEURON applies certain rules to determine which variables become
235  * thread, static or global variables. Here we construct those variables.
236  */
238  /**
239  * Top local variables are local variables appear in global scope. All local
240  * variables in program symbol table are in global scope.
241  */
242  info.constant_variables = psymtab->get_variables_with_properties(NmodlType::constant_var);
245 
246  /**
247  * All global variables remain global if mod file is not marked thread safe.
248  * Otherwise, global variables written at least once gets promoted to thread variables.
249  */
250 
251  std::string variables;
252 
253  auto vars = psymtab->get_variables_with_properties(NmodlType::global_var);
254  for (auto& var: vars) {
255  if (info.vectorize && info.declared_thread_safe && var->get_write_count() > 0) {
256  var->mark_thread_safe();
257  info.thread_variables.push_back(var);
258  info.thread_var_data_size += var->get_length();
259  variables += " " + var->get_name();
260  } else {
261  info.global_variables.push_back(var);
262  }
263  }
264 
265  /**
266  * If parameter is not a range and used only as read variable then it becomes global
267  * variable. To qualify it as thread variable it must be written at least once and
268  * mod file must be marked as thread safe.
269  * To exclusively get parameters only, we exclude all other variables (in without)
270  * and then sort them with neuron/mod2c order.
271  */
272  // clang-format off
273  auto with = NmodlType::param_assign;
274  auto without = NmodlType::range_var
275  | NmodlType::assigned_definition
276  | NmodlType::global_var
277  | NmodlType::pointer_var
278  | NmodlType::bbcore_pointer_var
279  | NmodlType::read_ion_var
280  | NmodlType::write_ion_var;
281  // clang-format on
282  vars = psymtab->get_variables(with, without);
283  for (auto& var: vars) {
284  // some variables like area and diam are declared in parameter
285  // block but they are not global
286  if (var->get_name() == naming::DIAM_VARIABLE || var->get_name() == naming::AREA_VARIABLE ||
287  var->has_any_property(NmodlType::extern_neuron_variable)) {
288  continue;
289  }
290 
291  // if model is thread safe and if parameter is being written then
292  // those variables should be promoted to thread safe variable
293  if (info.vectorize && info.declared_thread_safe && var->get_write_count() > 0) {
294  var->mark_thread_safe();
295  info.thread_variables.push_back(var);
296  info.thread_var_data_size += var->get_length();
297  } else {
298  info.global_variables.push_back(var);
299  }
300  }
302 
303  /**
304  * \todo Below we calculate thread related id and sizes. This will
305  * need to do from global analysis pass as here we are handling
306  * top local variables, global variables, derivimplicit method.
307  * There might be more use cases with other solver methods.
308  */
309 
310  /**
311  * If derivimplicit is used, then first three thread ids get assigned to:
312  * 1st thread is used for: deriv_advance
313  * 2nd thread is used for: dith
314  * 3rd thread is used for: newtonspace
315  *
316  * slist and dlist represent the offsets for prime variables used. For
317  * euler or derivimplicit methods its always first number.
318  */
319  if (info.derivimplicit_used()) {
324  }
325 
326  /// next thread id is allocated for top local variables
327  if (info.vectorize && !info.top_local_variables.empty()) {
330  }
331 
332  /// next thread id is allocated for thread promoted variables
333  if (info.vectorize && !info.thread_variables.empty()) {
336  }
337 
338  /// find total size of local variables in global scope
339  for (auto& var: info.top_local_variables) {
340  info.top_local_thread_size += var->get_length();
341  }
342 
343  /// find number of prime variables and total size
344  auto primes = psymtab->get_variables_with_properties(NmodlType::prime_name);
345  info.num_primes = static_cast<int>(primes.size());
346  for (auto& variable: primes) {
347  info.primes_size += variable->get_length();
348  }
349 
350  /// find pointer or bbcore pointer variables
351  auto properties = NmodlType::pointer_var | NmodlType::bbcore_pointer_var;
353 
354  /// find RANDOM variables
355  properties = NmodlType::random_var;
357 
358  // find special variables like diam, area
359  properties = NmodlType::assigned_definition | NmodlType::param_assign;
360  vars = psymtab->get_variables_with_properties(properties);
361  for (auto& var: vars) {
362  if (var->get_name() == naming::AREA_VARIABLE) {
363  info.area_used = true;
364  }
365  if (var->get_name() == naming::DIAM_VARIABLE) {
366  info.diam_used = true;
367  }
368  }
369 }
370 
371 /**
372  * Find range variables, i.e. ones that belong to the per-instance allocation
373  *
374  * In order to be compatible with NEURON, we need to print all variables in
375  * exact order as NEURON/MOD2C implementation. This is important because memory
376  * for all variables is allocated in single 1-D array with certain offset
377  * for each variable. The order of variables determine the offset and hence
378  * they must be in same order as NEURON.
379  *
380  * Here is how order is determined into NEURON/MOD2C implementation:
381  *
382  * First, following three lists are created
383  * - variables with parameter and range property (List 1)
384  * - variables with state and range property (List 2)
385  * - variables with assigned and range property (List 3)
386  *
387  * Once created, we remove some variables due to the following criteria:
388  * - In NEURON/MOD2C implementation, we remove variables with NRNPRANGEIN
389  * or NRNPRANGEOUT type
390  * - So who has NRNPRANGEIN and NRNPRANGEOUT type? these are USEION read
391  * or write variables that are not ionic currents.
392  * - This is the reason for mod files CaDynamics_E2.mod or cal_mig.mod, ica variable
393  * is printed earlier in the list but other variables like cai, cao don't appear
394  * in same order.
395  *
396  * Finally we create 4th list:
397  * - variables with assigned property and not in the previous 3 lists
398  *
399  * We now print the variables in following order:
400  *
401  * - List 1 i.e. range + parameter variables are printed first
402  * - List 3 i.e. range + assigned variables are printed next
403  * - List 2 i.e. range + state variables are printed next
404  * - List 4 i.e. assigned and ion variables not present in the previous 3 lists
405  *
406  * NOTE:
407  * - State variables also have the property `assigned_definition` but these variables
408  * are not from ASSIGNED block.
409  * - Variable can not be range as well as state, it's redeclaration error
410  * - Variable can be parameter as well as range. Without range, parameter
411  * is considered as global variable i.e. one value for all instances.
412  * - If a variable is only defined as RANGE and not in assigned or parameter
413  * or state block then it's not printed.
414  * - Note that a variable property is different than the variable type. For example,
415  * if variable has range property, it doesn't mean the variable is declared as RANGE.
416  * Other variables like STATE and ASSIGNED block variables also get range property
417  * without being explicitly declared as RANGE in the mod file.
418  * - Also, there is difference between declaration order vs. definition order. For
419  * example, POINTER variable in NEURON block is just declaration and doesn't
420  * determine the order in which they will get printed. Below we query symbol table
421  * and order all instance variables into certain order.
422  */
424  /// comparator to decide the order based on definition
425  auto comparator = [](const SymbolType& first, const SymbolType& second) -> bool {
426  return first->get_definition_order() < second->get_definition_order();
427  };
428 
429  /// from symbols vector `vars`, remove all ion variables which are not ionic currents
430  auto remove_non_ioncur_vars = [](SymbolVectorType& vars, const CodegenInfo& info) -> void {
431  vars.erase(std::remove_if(vars.begin(),
432  vars.end(),
433  [&](SymbolType& s) {
434  return info.is_ion_variable(s->get_name()) &&
435  !info.is_ionic_current(s->get_name());
436  }),
437  vars.end());
438  };
439 
440  /// if `secondary` vector contains any symbol that exist in the `primary` then remove it
441  auto remove_var_exist = [](SymbolVectorType& primary, SymbolVectorType& secondary) -> void {
442  secondary.erase(std::remove_if(secondary.begin(),
443  secondary.end(),
444  [&primary](const SymbolType& tosearch) {
445  return std::find_if(primary.begin(),
446  primary.end(),
447  // compare by symbol name
448  [&tosearch](
449  const SymbolType& symbol) {
450  return tosearch->get_name() ==
451  symbol->get_name();
452  }) != primary.end();
453  }),
454  secondary.end());
455  };
456 
457  /**
458  * First come parameters which are range variables.
459  */
460  // clang-format off
461  auto with = NmodlType::range_var
462  | NmodlType::param_assign;
463  auto without = NmodlType::global_var
464  | NmodlType::pointer_var
465  | NmodlType::bbcore_pointer_var
466  | NmodlType::state_var;
467 
468  // clang-format on
470  remove_non_ioncur_vars(info.range_parameter_vars, info);
471  std::sort(info.range_parameter_vars.begin(), info.range_parameter_vars.end(), comparator);
472 
473  /**
474  * Second come assigned variables which are range variables.
475  */
476  // clang-format off
477  with = NmodlType::range_var
478  | NmodlType::assigned_definition;
479  without = NmodlType::global_var
480  | NmodlType::pointer_var
481  | NmodlType::bbcore_pointer_var
482  | NmodlType::state_var
483  | NmodlType::param_assign;
484 
485  // clang-format on
486  info.range_assigned_vars = psymtab->get_variables(with, without);
487  remove_non_ioncur_vars(info.range_assigned_vars, info);
488  std::sort(info.range_assigned_vars.begin(), info.range_assigned_vars.end(), comparator);
489 
490 
491  /**
492  * Third come state variables. All state variables are kind of range by default.
493  * Note that some mod files like CaDynamics_E2.mod use cai as state variable which
494  * appear in USEION read/write list. These variables are not considered in this
495  * variables because non ionic-current variables are removed and printed later.
496  */
497  // clang-format off
498  with = NmodlType::state_var;
499  without = NmodlType::global_var
500  | NmodlType::pointer_var
501  | NmodlType::bbcore_pointer_var;
502 
503  // clang-format on
504  info.state_vars = psymtab->get_variables(with, without);
505  std::sort(info.state_vars.begin(), info.state_vars.end(), comparator);
506 
507  /// range_state_vars is copy of state variables but without non ionic current variables
509  remove_non_ioncur_vars(info.range_state_vars, info);
510 
511  /// Remaining variables are assigned and ion variables which are not in the previous 3 lists
512 
513  // clang-format off
514  with = NmodlType::assigned_definition
515  | NmodlType::read_ion_var
516  | NmodlType::write_ion_var;
517  without = NmodlType::global_var
518  | NmodlType::pointer_var
519  | NmodlType::bbcore_pointer_var
520  | NmodlType::extern_neuron_variable;
521  // clang-format on
522  const auto& variables = psymtab->get_variables_with_properties(with, false);
523  for (const auto& variable: variables) {
524  if (!variable->has_any_property(without)) {
525  info.assigned_vars.push_back(variable);
526  }
527  }
528 
529  /// make sure that variables already present in previous lists
530  /// are removed to avoid any duplication
531  remove_var_exist(info.range_parameter_vars, info.assigned_vars);
532  remove_var_exist(info.range_assigned_vars, info.assigned_vars);
533  remove_var_exist(info.range_state_vars, info.assigned_vars);
534 
535  /// sort variables with their definition order
536  std::sort(info.assigned_vars.begin(), info.assigned_vars.end(), comparator);
537 }
538 
539 
541  auto property = NmodlType::table_statement_var;
543  property = NmodlType::table_assigned_var;
545 }
546 
548  // TODO: it would be nicer not to have this hardcoded list
549  using pair = std::pair<const char*, const char*>;
550  for (const auto& [var, type]: {pair{naming::CELSIUS_VARIABLE, "double"},
551  pair{"secondorder", "int"},
552  pair{"pi", "double"}}) {
553  auto sym = psymtab->lookup(var);
554  if (sym && (sym->get_read_count() || sym->get_write_count())) {
555  info.neuron_global_variables.emplace_back(std::move(sym), type);
556  }
557  }
558 }
559 
560 
562  const auto& type = node.get_type()->get_node_name();
563  if (type == naming::POINT_PROCESS) {
564  info.point_process = true;
565  }
566  if (type == naming::ARTIFICIAL_CELL) {
567  info.artificial_cell = true;
568  info.point_process = true;
569  }
570  info.mod_suffix = node.get_node_name();
571 }
572 
574  info.electrode_current = true;
575 }
576 
577 
581  } else {
583  }
584  node.visit_children(*this);
585 }
586 
587 
590  node.visit_children(*this);
591 }
592 
593 
596  node.visit_children(*this);
597 }
598 
599 
603  info.num_net_receive_parameters = static_cast<int>(node.get_parameters().size());
604  node.visit_children(*this);
605  under_net_receive_block = false;
606 }
607 
608 
610  under_derivative_block = true;
611  node.visit_children(*this);
612  under_derivative_block = false;
613 }
614 
616  info.derivimplicit_callbacks.push_back(&node);
617 }
618 
619 
621  under_breakpoint_block = true;
623  node.visit_children(*this);
624  under_breakpoint_block = false;
625 }
626 
627 
630  node.visit_children(*this);
631 }
632 
634  info.cvode_block = &node;
635  node.visit_children(*this);
636 }
637 
638 
640  info.procedures.push_back(&node);
641  node.visit_children(*this);
642  if (table_statement_used) {
643  table_statement_used = false;
644  info.functions_with_table.push_back(&node);
645  }
646 }
647 
648 
650  info.functions.push_back(&node);
651  node.visit_children(*this);
652  if (table_statement_used) {
653  table_statement_used = false;
654  info.functions_with_table.push_back(&node);
655  }
656 }
657 
658 
660  info.function_tables.push_back(&node);
661 }
662 
663 
667  // Avoid extra declaration for `functor` corresponding to the DERIVATIVE block which is not
668  // printed to the generated CPP file
669  if (!under_derivative_block) {
670  const auto new_unique_functor_name = "functor_" + info.mod_suffix + "_" +
672  info.functor_names[&node] = new_unique_functor_name;
673  }
674  node.visit_children(*this);
675 }
676 
680  node.visit_children(*this);
681 }
682 
684  auto name = node.get_node_name();
685  if (name == naming::NET_SEND_METHOD) {
686  info.net_send_used = true;
687  }
689  info.net_event_used = true;
690  }
691 }
692 
693 
695  const auto& ion = node.get_ion();
696  const auto& variable = node.get_conductance();
697  std::string ion_name;
698  if (ion) {
699  ion_name = ion->get_node_name();
700  }
701  info.conductances.push_back({ion_name, variable->get_node_name()});
702 }
703 
704 
705 /**
706  * Visit statement block and find prime symbols appear in derivative block
707  *
708  * Equation statements in a derivative block have a prime on the lhs. The order
709  * of primes could be different than the declaration order in the STATE block. Also,
710  * not all state variables need to appear in the equation block. In this case, we want
711  * to find out the primes in the order of equation definition. This is
712  * just to keep the same order as the NEURON implementation.
713  *
714  * The primes are already solved and replaced by Dstate or name. And hence
715  * we need to check if the lhs variable is derived from prime name. If it's
716  * Dstate then we have to lookup state to find out corresponding symbol. This
717  * is because prime_variables_by_order should contain state variable name and
718  * not the one replaced by solver pass.
719  *
720  * \todo AST can have duplicate DERIVATIVE blocks if a mod file uses SOLVE
721  * statements in its INITIAL block (e.g. in case of kinetic schemes using
722  * `STEADYSTATE sparse` solver). Such duplicated DERIVATIVE blocks could
723  * be removed by `SolveBlockVisitor`, or we have to avoid visiting them
724  * here. See e.g. SH_na8st.mod test and original reduced_dentate .mod.
725  */
727  const auto& statements = node.get_statements();
728  for (auto& statement: statements) {
729  statement->accept(*this);
731  (assign_lhs->is_name() || assign_lhs->is_var_name())) {
732  auto name = assign_lhs->get_node_name();
733  auto symbol = psymtab->lookup(name);
734  if (symbol != nullptr) {
735  auto is_prime = symbol->has_any_property(NmodlType::prime_name);
736  auto from_state = symbol->has_any_status(Status::from_state);
737  if (is_prime || from_state) {
738  if (from_state) {
739  symbol = psymtab->lookup(name.substr(1, name.size()));
740  }
741  // See the \todo note above.
742  if (std::find_if(info.prime_variables_by_order.begin(),
744  [&](auto const& sym) {
745  return sym->get_name() == symbol->get_name();
746  }) == info.prime_variables_by_order.end()) {
747  info.prime_variables_by_order.push_back(symbol);
749  }
750  }
751  }
752  }
753  assign_lhs = nullptr;
754  }
755 }
756 
758  info.factor_definitions.push_back(&node);
759 }
760 
761 
763  if (node.get_op().eval() == "=") {
764  assign_lhs = node.get_lhs();
765  }
766  node.get_lhs()->accept(*this);
767  node.get_rhs()->accept(*this);
768 }
769 
770 
772  info.bbcore_pointer_used = true;
773 }
774 
776  info.declared_thread_safe = true;
777 }
778 
779 
781  info.watch_count++;
782 }
783 
784 
786  info.watch_statements.push_back(&node);
787  node.visit_children(*this);
788 }
789 
790 
792  info.for_netcon_used = true;
793 }
794 
795 
797  info.table_count++;
798  table_statement_used = true;
799 }
800 
801 
803  psymtab = node.get_symbol_table();
804  auto blocks = node.get_blocks();
805  for (auto& block: blocks) {
806  info.top_blocks.push_back(block.get());
807  if (block->is_verbatim()) {
808  info.top_verbatim_blocks.push_back(block.get());
809  }
810  }
811  node.visit_children(*this);
812  find_ion_variables(node); // Keep this before find_*_range_variables()
818 }
819 
820 
822  node.accept(*this);
823  return info;
824 }
825 
827  info.vectorize = true;
828 }
829 
831  info.vectorize = true;
832 }
833 
835  info.vectorize = false;
836 }
837 
839  info.changed_dt = node.get_value()->eval();
840 }
841 
842 /// visit verbatim block and find all symbols used
844  const auto& text = node.get_statement()->eval();
845  // use C parser to get all tokens
847  driver.scan_string(text);
848  const auto& tokens = driver.all_tokens();
849 
850  // check if the token exist in the symbol table
851  for (auto& token: tokens) {
852  auto symbol = psymtab->lookup(token);
853  if (symbol != nullptr) {
854  info.variables_in_verbatim.insert(token);
855  }
856  }
857 }
858 
860  info.before_after_blocks.push_back(&node);
861 }
862 
864  info.before_after_blocks.push_back(&node);
865 }
866 
867 static std::shared_ptr<ast::Compartment> find_compartment(
869  const std::string& var_name) {
870  const auto& compartment_block = node.get_compartment_statements();
871  for (const auto& stmt: compartment_block->get_statements()) {
872  auto comp = std::dynamic_pointer_cast<ast::Compartment>(stmt);
873 
874  auto species = comp->get_species();
875  auto it = std::find_if(species.begin(), species.end(), [&var_name](auto var) {
876  return var->get_node_name() == var_name;
877  });
878 
879  if (it != species.end()) {
880  return comp;
881  }
882  }
883 
884  return nullptr;
885 }
886 
889  auto longitudinal_diffusion_block = node.get_longitudinal_diffusion_statements();
890  for (auto stmt: longitudinal_diffusion_block->get_statements()) {
891  auto diffusion = std::dynamic_pointer_cast<ast::LonDiffuse>(stmt);
892  auto rate_index_name = diffusion->get_index_name();
893  auto rate_expr = diffusion->get_rate();
894  auto species = diffusion->get_species();
895 
896  auto process_compartment = [](const std::shared_ptr<ast::Compartment>& compartment)
897  -> std::pair<std::shared_ptr<ast::Name>, std::shared_ptr<ast::Expression>> {
898  std::shared_ptr<ast::Expression> volume_expr;
899  std::shared_ptr<ast::Name> volume_index_name;
900  if (!compartment) {
901  volume_index_name = nullptr;
902  volume_expr = std::make_shared<ast::Double>("1.0");
903  } else {
904  volume_index_name = compartment->get_index_name();
905  volume_expr = std::shared_ptr<ast::Expression>(compartment->get_volume()->clone());
906  }
907  return {std::move(volume_index_name), std::move(volume_expr)};
908  };
909 
910  for (auto var: species) {
911  std::string state_name = var->get_value()->get_value();
912  auto compartment = find_compartment(node, state_name);
913  auto [volume_index_name, volume_expr] = process_compartment(compartment);
914 
916  {state_name,
917  LongitudinalDiffusionInfo(volume_index_name,
918  std::shared_ptr<ast::Expression>(volume_expr),
919  rate_index_name,
920  std::shared_ptr<ast::Expression>(rate_expr->clone()))});
921  }
922  }
923 }
924 
925 } // namespace codegen
926 } // namespace nmodl
Auto generated AST classes declaration.
Represents an AFTER block in NMODL.
Definition: after_block.hpp:51
Represents BBCOREPOINTER statement in NMODL.
Represents a BEFORE block in NMODL.
Represents binary expression in the NMODL.
Represents a BREAKPOINT block in NMODL.
Represents CONDUCTANCE statement in NMODL.
Represents a CONSTRUCTOR block in the NMODL.
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Definition: cvode_block.hpp:38
Represents DERIVATIVE block in the NMODL.
Represent a callback to NEURON's derivimplicit solver.
Represents a DESTRUCTOR block in the NMODL.
Represent linear solver solution block based on Eigen.
Represent newton solver solution block based on Eigen.
Represents ELECTRODE_CURRENT variables statement in NMODL.
Represents an INITIAL block in the NMODL.
Represents LINEAR block in the NMODL.
Extracts information required for LONGITUDINAL_DIFFUSION for each KINETIC block.
Represents NONLINEAR block in the NMODL.
Represents the coreneuron nrn_state callback function.
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents block encapsulating list of statements.
Represents SUFFIX statement in NMODL.
Definition: suffix.hpp:38
Represents TABLE statement in NMODL.
Represents THREADSAFE statement in NMODL.
Definition: thread_safe.hpp:38
Statement to indicate a change in timestep in a given block.
Definition: update_dt.hpp:38
Represents a C code block.
Definition: verbatim.hpp:38
Represent WATCH statement in NMODL.
static void sort_with_mod2c_symbol_order(std::vector< SymbolType > &symbols)
How symbols are stored in NEURON? See notes written in markdown file.
void visit_derivimplicit_callback(const ast::DerivimplicitCallback &node) override
visit node of type ast::DerivimplicitCallback
void visit_derivative_block(const ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
void check_cvode_codegen(const ast::Program &node)
Find whether or not we need to emit CVODE-related code for NEURON Notes: we generate CVODE-related co...
void visit_non_linear_block(const ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
void visit_breakpoint_block(const ast::BreakpointBlock &node) override
visit node of type ast::BreakpointBlock
void visit_thread_safe(const ast::ThreadSafe &) override
visit node of type ast::ThreadSafe
bool table_statement_used
table statement found
void visit_before_block(const ast::BeforeBlock &node) override
visit node of type ast::BeforeBlock
void visit_update_dt(const ast::UpdateDt &node) override
visit node of type ast::UpdateDt
void visit_discrete_block(const ast::DiscreteBlock &node) override
visit node of type ast::DiscreteBlock
void visit_nrn_state_block(const ast::NrnStateBlock &node) override
visit node of type ast::NrnStateBlock
void visit_function_table_block(const ast::FunctionTableBlock &node) override
visit node of type ast::FunctionTableBlock
void visit_function_call(const ast::FunctionCall &node) override
visit node of type ast::FunctionCall
bool under_derivative_block
if visiting derivative block
void visit_linear_block(const ast::LinearBlock &node) override
visit node of type ast::LinearBlock
void visit_longitudinal_diffusion_block(const ast::LongitudinalDiffusionBlock &node) override
visit node of type ast::LongitudinalDiffusionBlock
void visit_eigen_linear_solver_block(const ast::EigenLinearSolverBlock &node) override
visit node of type ast::EigenLinearSolverBlock
void visit_conductance_hint(const ast::ConductanceHint &node) override
visit node of type ast::ConductanceHint
bool under_breakpoint_block
if visiting breakpoint block
void visit_after_block(const ast::AfterBlock &node) override
visit node of type ast::AfterBlock
std::shared_ptr< ast::Expression > assign_lhs
lhs of assignment in derivative block
void find_ion_variables(const ast::Program &node)
Find all ions used in mod file.
void visit_table_statement(const ast::TableStatement &node) override
visit node of type ast::TableStatement
void visit_suffix(const ast::Suffix &node) override
visit node of type ast::Suffix
void visit_verbatim(const ast::Verbatim &node) override
visit verbatim block and find all symbols used
void visit_constructor_block(const ast::ConstructorBlock &node) override
visit node of type ast::ConstructorBlock
void visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock &node) override
visit node of type ast::EigenNewtonSolverBlock
codegen::CodegenInfo analyze(const ast::Program &node)
run visitor and return information for code generation
void visit_electrode_current(const ast::ElectrodeCurrent &node) override
visit node of type ast::ElectrodeCurrent
void visit_factor_def(const ast::FactorDef &node) override
visit node of type ast::FactorDef
void visit_statement_block(const ast::StatementBlock &node) override
Visit statement block and find prime symbols appear in derivative block.
void visit_net_receive_block(const ast::NetReceiveBlock &node) override
visit node of type ast::NetReceiveBlock
std::shared_ptr< symtab::Symbol > SymbolType
void visit_watch_statement(const ast::WatchStatement &node) override
visit node of type ast::WatchStatement
std::vector< std::shared_ptr< symtab::Symbol > > SymbolVectorType
void find_range_variables()
Find range variables i.e.
void visit_watch(const ast::Watch &node) override
visit node of type ast::Watch
symtab::SymbolTable * psymtab
symbol table for the program
void visit_program(const ast::Program &node) override
visit node of type ast::Program
codegen::CodegenInfo info
holds all codegen related information
void visit_function_block(const ast::FunctionBlock &node) override
visit node of type ast::FunctionBlock
void find_non_range_variables()
Find non-range variables i.e.
void visit_binary_expression(const ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
void visit_for_netcon(const ast::ForNetcon &node) override
visit node of type ast::ForNetcon
void visit_procedure_block(const ast::ProcedureBlock &node) override
visit node of type ast::ProcedureBlock
bool under_net_receive_block
if visiting net receive block
void visit_bbcore_pointer(const ast::BbcorePointer &node) override
visit node of type ast::BbcorePointer
void visit_initial_block(const ast::InitialBlock &node) override
visit node of type ast::InitialBlock
void visit_destructor_block(const ast::DestructorBlock &node) override
visit node of type ast::DestructorBlock
void visit_cvode_block(const ast::CvodeBlock &node) override
visit node of type ast::CvodeBlock
Information required to print LONGITUDINAL_DIFFUSION callbacks.
Class that binds all pieces together for parsing C verbatim blocks.
Definition: c11_driver.hpp:37
std::vector< std::shared_ptr< Symbol > > get_variables(syminfo::NmodlType with=syminfo::NmodlType::empty, syminfo::NmodlType without=syminfo::NmodlType::empty) const
get variables
std::vector< std::shared_ptr< Symbol > > get_variables_with_properties(syminfo::NmodlType properties, bool all=false) const
get variables with properties
std::shared_ptr< Symbol > lookup(const std::string &name) const
check if a symbol with the given name exists in the current table (but not in parents)
Helper visitor to gather AST information to help code generation.
Auto generated AST classes declaration.
#define assert(ex)
Definition: hocassrt.h:24
double var(InputIterator begin, InputIterator end)
Definition: ivocvect.h:108
const char * name
Definition: init.cpp:16
void move(Item *q1, Item *q2, Item *q3)
Definition: list.cpp:200
static constexpr char AREA_VARIABLE[]
similar to node_area but user can explicitly declare it as area
static constexpr char POINT_PROCESS[]
point process keyword in nmodl
static constexpr char NET_EVENT_METHOD[]
net_event function call in nmodl
static constexpr char DIAM_VARIABLE[]
inbuilt neuron variable for diameter of the compartment
static constexpr char ARTIFICIAL_CELL[]
artificial cell keyword in nmodl
static constexpr char CVODE_T_METHOD[]
cvode_t method in nmodl
static constexpr char NET_SEND_METHOD[]
net_send function call in nmodl
static constexpr char CELSIUS_VARIABLE[]
global temperature variable
static constexpr char CVODE_T_V_METHOD[]
cvode_t_v method in nmodl
static constexpr char AFTER_CVODE_METHOD[]
cvode method in nmodl
static bool check_procedure_has_cvode(const std::shared_ptr< const ast::Ast > &solve_node, const std::shared_ptr< const ast::Ast > &procedure_node)
Check whether a given SOLVE block solves a PROCEDURE with any of the CVode methods.
static std::shared_ptr< ast::Compartment > find_compartment(const ast::LongitudinalDiffusionBlock &node, const std::string &var_name)
Status
state during various compiler passes
std::string to_string(const T &obj)
NmodlType
NMODL variable properties.
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
logger_type logger
Definition: logger.cpp:34
static List * info
static int using_cvode
Definition: nocpout.cpp:2682
static Node * node(Object *)
Definition: netcvode.cpp:291
s
Definition: multisend.cpp:521
short type
Definition: cabvars.h:10
#define text
Definition: plot.cpp:60
unsigned char diffusion
Definition: rxd.cpp:52
Represent information collected from AST for code generation.
int thread_var_data_size
sum of length of thread promoted variables
std::vector< SymbolType > range_assigned_vars
range variables which are assigned variables as well
std::vector< std::pair< SymbolType, std::string > > neuron_global_variables
[Core]NEURON global variables used (e.g. celsius) and their types
bool bbcore_pointer_used
if bbcore pointer is used
std::vector< const ast::FactorDef * > factor_definitions
all factors defined in the mod file
std::vector< SymbolType > assigned_vars
remaining assigned variables
const ast::CvodeBlock * cvode_block
the CVODE block
std::vector< SymbolType > range_state_vars
state variables excluding such useion read/write variables that are not ionic currents.
bool is_ion_variable(const std::string &name) const noexcept
if either read or write variable
int num_equations
number of equations (i.e.
bool artificial_cell
if mod file is artificial cell
std::vector< const ast::FunctionTableBlock * > function_tables
all function tables defined in the mod file
std::vector< SymbolType > pointer_variables
pointer or bbcore pointer variables
bool point_process
if mod file is point process
bool electrode_current
if electrode current specified
std::vector< ast::Node * > top_blocks
all top level global blocks
bool diam_used
if diam is used
std::unordered_map< const ast::EigenNewtonSolverBlock *, std::string > functor_names
unique functor names for all the EigenNewtonSolverBlocks
std::vector< SymbolType > global_variables
global variables
int thread_data_index
thread_data_index indicates number of threads being allocated.
bool vectorize
true if mod file is vectorizable (which should be always true for coreneuron) But there are some bloc...
bool net_event_used
if net_event function is used
const ast::BreakpointBlock * breakpoint_node
breakpoint block
bool net_send_used
if net_send function is used
std::vector< const ast::WatchStatement * > watch_statements
all watch statements
bool declared_thread_safe
A mod file can be declared to be thread safe using the keyword THREADSAFE.
bool thread_callback_register
if thread callback routines need to be registered
int num_primes
number of primes (not all state variables are necessarily prime)
const ast::DestructorBlock * destructor_node
destructor block only for point process
std::vector< SymbolType > external_variables
external variables
std::vector< const ast::ProcedureBlock * > procedures
all procedures defined in the mod file
const ast::NrnStateBlock * nrn_state_block
nrn_state block
const ast::InitialBlock * net_receive_initial_node
initial block within net receive block
std::vector< SymbolType > thread_variables
thread variables (e.g. global variables promoted to thread)
bool eigen_newton_solver_exist
true if eigen newton solver is used
std::vector< SymbolType > constant_variables
constant variables
std::vector< ast::Node * > top_verbatim_blocks
all top level verbatim blocks
int num_net_receive_parameters
number of arguments to net_receive block
std::vector< const ast::DerivimplicitCallback * > derivimplicit_callbacks
derivimplicit callbacks that need to be emitted
bool is_ionic_current(const std::string &name) const noexcept
if given variable is an ionic current
std::vector< const ast::Block * > before_after_blocks
all before after blocks
bool derivimplicit_used() const
if legacy derivimplicit solver from coreneuron is to be used
int table_count
number of table statements
std::vector< SymbolType > table_statement_variables
table variables
std::vector< const ast::Block * > functions_with_table
function or procedures with table statement
std::vector< SymbolType > random_variables
RANDOM variables.
std::vector< SymbolType > top_local_variables
local variables in the global scope
const ast::NetReceiveBlock * net_receive_node
net receive block for point process
std::vector< SymbolType > range_parameter_vars
range variables which are parameter as well
std::vector< const ast::FunctionBlock * > functions
all functions defined in the mod file
int top_local_thread_id
Top local variables are those local variables that appear in global scope.
int derivimplicit_var_thread_id
thread id for derivimplicit variables
bool eigen_linear_solver_exist
true if eigen linear solver is used
int thread_var_thread_id
thread id for thread promoted variables
int primes_size
sum of length of all prime variables
std::map< std::string, LongitudinalDiffusionInfo > longitudinal_diffusion_info
for each state, the information needed to print the callbacks.
int watch_count
number of watch expressions
std::vector< Conductance > conductances
represent conductance statements used in mod file
std::vector< SymbolType > prime_variables_by_order
this is the order in which they appear in derivative block this is required while printing them in in...
bool area_used
if area is used
std::string mod_suffix
name of the suffix
std::string changed_dt
updated dt to use with steadystate solver (in initial block) empty string means no change in dt
int derivimplicit_list_num
slist/dlist id for derivimplicit block
bool for_netcon_used
if for_netcon is used
const ast::ConstructorBlock * constructor_node
constructor block
std::vector< SymbolType > state_vars
all state variables
std::unordered_set< std::string > variables_in_verbatim
all variables/symbols used in the verbatim block
std::vector< SymbolType > table_assigned_variables
const ast::InitialBlock * initial_node
initial block
int top_local_thread_size
total length of all top local variables
Represent ions used in mod file.
bool need_style
if style semantic needed
bool is_intra_cell_conc(const std::string &text) const
Check if variable name is internal cell concentration.
std::optional< double > valence
ion valence
bool is_extra_cell_conc(const std::string &text) const
Check if variable name is external cell concentration.
std::vector< std::string > reads
ion variables that are being read
std::vector< std::string > writes
ion variables that are being written
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
Utility functions for visitors implementation.