NEURON
cvode_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
11 #include "lexer/token_mapping.hpp"
12 #include "pybind/pyembed.hpp"
13 #include "utils/logger.hpp"
15 #include <optional>
16 #include <regex>
17 #include <utility>
18 
19 namespace pywrap = nmodl::pybind_wrappers;
20 
21 namespace nmodl {
22 namespace visitor {
23 
24 static int get_index(const ast::IndexedName& node) {
25  return std::stoi(to_nmodl(node.get_length()));
26 }
27 
29  auto conserve_equations = collect_nodes(node, {ast::AstNodeType::CONSERVE});
30  if (!conserve_equations.empty()) {
31  std::unordered_set<ast::Statement*> eqs;
32  for (const auto& item: conserve_equations) {
33  eqs.insert(std::dynamic_pointer_cast<ast::Statement>(item).get());
34  }
35  node.erase_statement(eqs);
36  }
37 }
38 
39 // remove units from CVODE block so sympy can parse it properly
41  // matches either an int or a float, followed by any (including zero)
42  // number of spaces, followed by an expression in parentheses, that only
43  // has letters of the alphabet
44  std::regex unit_pattern(R"((\d+\.?\d*|\.\d+)\s*\‍([a-zA-Z]+\))");
45  auto rhs_string = to_nmodl(node.get_rhs());
46  auto rhs_string_no_units = fmt::format("{} = {}",
47  to_nmodl(node.get_lhs()),
48  std::regex_replace(rhs_string, unit_pattern, "$1"));
49  logger->debug("CvodeVisitor :: removing units from statement {}", to_nmodl(node));
50  logger->debug("CvodeVisitor :: result: {}", rhs_string_no_units);
51  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
52  create_statement(rhs_string_no_units));
53  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
54  expr_statement->get_expression());
55  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
56 }
57 
58 static std::pair<std::string, std::optional<int>> parse_independent_var(
59  std::shared_ptr<ast::Identifier> node) {
60  auto variable = std::make_pair(node->get_node_name(), std::optional<int>());
61  if (node->is_indexed_name()) {
62  variable.second = std::optional<int>(
63  get_index(*std::dynamic_pointer_cast<const ast::IndexedName>(node)));
64  }
65  return variable;
66 }
67 
68 /// set of all indexed variables not equal to ``ignored_name``
69 static std::unordered_set<std::string> get_indexed_variables(const ast::Expression& node,
70  const std::string& ignored_name) {
71  std::unordered_set<std::string> indexed_variables;
72  // all of the "reserved" vars
73  auto reserved_symbols = get_external_functions();
74  // all indexed vars
75  auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME});
76  for (const auto& var: indexed_vars) {
77  const auto& varname = var->get_node_name();
78  // skip if it's a reserved var
79  auto varname_not_reserved =
80  std::none_of(reserved_symbols.begin(),
81  reserved_symbols.end(),
82  [&varname](const auto item) { return varname == item; });
83  if (indexed_variables.count(varname) == 0 && varname != ignored_name &&
84  varname_not_reserved) {
85  indexed_variables.insert(varname);
86  }
87  }
88  return indexed_variables;
89 }
90 
92  const auto& lhs = node.get_lhs();
93 
94  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
95 
96  std::string varname;
97  if (name->is_prime_name()) {
98  varname = "D" + name->get_node_name();
99  node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
100  } else if (name->is_indexed_name()) {
102  // make sure the LHS isn't just a plain indexed var
103  if (!nodes.empty()) {
104  varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\'');
105  auto statement = fmt::format("{} = {}", varname, varname);
106  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
107  create_statement(statement));
108  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
109  expr_statement->get_expression());
110  node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
111  }
112  }
113  return varname;
114 }
115 
116 
118  protected:
121 
122  public:
125  node.visit_children(*this);
126  in_differential_equation = false;
127  }
128 };
129 
131  public:
133  program_symtab = symtab;
134  }
135 
137  const auto& lhs = node.get_lhs();
138 
139  if (!in_differential_equation || !lhs->is_var_name()) {
140  return;
141  }
142 
143  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
144  auto varname = cvode_set_lhs(node);
145 
146  if (program_symtab->lookup(varname) == nullptr) {
147  auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
148  symbol->set_original_name(name->get_node_name());
149  program_symtab->insert(symbol);
150  }
151  }
152 };
153 
155  public:
156  explicit StiffVisitor(symtab::SymbolTable* symtab) {
157  program_symtab = symtab;
158  }
159 
161  const auto& lhs = node.get_lhs();
162 
163  if (!in_differential_equation || !lhs->is_var_name()) {
164  return;
165  }
166 
167  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
168  auto varname = cvode_set_lhs(node);
169 
170  if (program_symtab->lookup(varname) == nullptr) {
171  auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
172  symbol->set_original_name(name->get_node_name());
173  program_symtab->insert(symbol);
174  }
175 
177 
178  auto rhs = node.get_rhs();
179 
180  // all indexed variables (need special treatment in SymPy)
181  auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name());
183  auto [jacobian, exception_message] =
184  diff2c(to_nmodl(*rhs), parse_independent_var(name), indexed_variables);
185  if (!exception_message.empty()) {
186  logger->warn("CvodeVisitor :: python exception: {}", exception_message);
187  }
188  // NOTE: LHS can be anything here, the equality is to keep `create_statement` from
189  // complaining, we discard the LHS later
190  auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian);
191  logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement);
192  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
193  create_statement(statement));
194  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
195  expr_statement->get_expression());
196  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
197  }
198 };
199 
200 static std::shared_ptr<ast::DerivativeBlock> get_derivative_block(ast::Program& node) {
201  auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK});
202  if (derivative_blocks.empty()) {
203  return nullptr;
204  }
205 
206  // steady state adds a DERIVATIVE block with a `_steadystate` suffix
207  auto not_steadystate = [](const auto& item) {
208  auto name = std::dynamic_pointer_cast<const ast::DerivativeBlock>(item)->get_node_name();
209  return !stringutils::ends_with(name, "_steadystate");
210  };
211  decltype(derivative_blocks) derivative_blocks_copy;
212  std::copy_if(derivative_blocks.begin(),
213  derivative_blocks.end(),
214  std::back_inserter(derivative_blocks_copy),
215  not_steadystate);
216  if (derivative_blocks_copy.size() > 1) {
217  auto message = "CvodeVisitor :: cannot have multiple DERIVATIVE blocks";
218  logger->error(message);
219  throw std::runtime_error(message);
220  }
221 
222  return std::dynamic_pointer_cast<ast::DerivativeBlock>(derivative_blocks_copy[0]);
223 }
224 
225 
227  auto derivative_block = get_derivative_block(node);
228  if (derivative_block == nullptr) {
229  return;
230  }
231 
232  auto non_stiff_block = derivative_block->get_statement_block()->clone();
233  remove_conserve_statements(*non_stiff_block);
234 
235  auto stiff_block = derivative_block->get_statement_block()->clone();
236  remove_conserve_statements(*stiff_block);
237 
238  NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block);
239  StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block);
240  auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME});
241  node.emplace_back_node(new ast::CvodeBlock(
242  derivative_block->get_name(),
243  std::shared_ptr<ast::Integer>(new ast::Integer(prime_vars.size(), nullptr)),
244  std::shared_ptr<ast::StatementBlock>(non_stiff_block),
245  std::shared_ptr<ast::StatementBlock>(stiff_block)));
246 }
247 
248 } // namespace visitor
249 } // namespace nmodl
Auto generated AST classes declaration.
Represent token returned by scanner.
Definition: modtoken.hpp:50
Represents binary expression in the NMODL.
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Definition: cvode_block.hpp:38
Represents differential equation in DERIVATIVE block.
Base class for all expressions in the NMODL.
Definition: expression.hpp:43
Represents specific element of an array variable.
Represents an integer variable.
Definition: integer.hpp:49
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents block encapsulating list of statements.
Represents a string.
Definition: string.hpp:52
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:29
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:136
Represent symbol table for a NMODL block.
void insert(const std::shared_ptr< Symbol > &symbol)
std::shared_ptr< Symbol > lookup(const std::string &name) const
Check if a symbol with the given name exists in the current table (but not in its parents)
Concrete visitor for all AST classes.
Definition: ast_visitor.hpp:37
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
void visit_diff_eq_expression(ast::DiffEqExpression &node)
visit node of type ast::DiffEqExpression
symtab::SymbolTable * program_symtab
void visit_program(ast::Program &node) override
visit node of type ast::Program
void visit_binary_expression(ast::BinaryExpression &node)
visit node of type ast::BinaryExpression
NonStiffVisitor(symtab::SymbolTable *symtab)
void visit_binary_expression(ast::BinaryExpression &node)
visit node of type ast::BinaryExpression
StiffVisitor(symtab::SymbolTable *symtab)
Visitor used for generating the necessary AST nodes for CVODE.
static double jacobian(void *v)
Definition: cvodeobj.cpp:245
virtual bool is_indexed_name() const noexcept
Check if the ast node is an instance of ast::IndexedName.
Definition: ast.cpp:88
virtual std::string get_node_name() const
Return the name of the node.
Definition: ast.cpp:28
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
@ INDEXED_NAME
type of ast::IndexedName
@ CONSERVE
type of ast::Conserve
@ PRIME_NAME
type of ast::PrimeName
static bool ends_with(const std::string &haystack, const std::string &needle)
Check if haystack ends with needle.
static std::string remove_character(std::string text, const char c)
Remove all occurrences of a given character in a text.
double var(InputIterator begin, InputIterator end)
Definition: ivocvect.h:108
#define rhs
Definition: lineq.h:6
const char * name
Definition: init.cpp:16
static std::shared_ptr< ast::DerivativeBlock > get_derivative_block(ast::Program &node)
static std::pair< std::string, std::optional< int > > parse_independent_var(std::shared_ptr< ast::Identifier > node)
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
static void remove_units(ast::BinaryExpression &node)
static void remove_conserve_statements(ast::StatementBlock &node)
static std::string cvode_set_lhs(ast::BinaryExpression &node)
static int get_index(const ast::IndexedName &node)
static std::unordered_set< std::string > get_indexed_variables(const ast::Expression &node, const std::string &ignored_name)
set of all indexed variables not equal to ignored_name
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
logger_type logger
Definition: logger.cpp:34
std::vector< std::string > get_external_functions()
Return functions that can be used in the NMODL.
static Node * node(Object *)
Definition: netcvode.cpp:291
int get()
Definition: units.cpp:918
decltype(&call_diff2c) diff2c
Definition: wrapper.hpp:67
Map different tokens from lexer to token types.
Utility functions for visitors implementation.