NEURON
sympy_replace_solutions_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
10 
11 #include "ast/all.hpp"
12 #include "utils/logger.hpp"
14 
15 namespace nmodl {
16 namespace visitor {
17 
18 /**
19  * \details SympyReplaceSolutionsVisitor tells us that a new equation appears and, depending on where
20  * it is located, whether it is part of the main system of equations or something
21  * else. Every time we are out of the system and we print a new equation that is in the system,
22  * we update the counter. \ref in_system follows \p is_in_system with a lag of one call, and every
23  * time they are false and true respectively we detect a switch.
24  *
25  * \param is_in_system is a bool provided from outside that tells us if a new equation is indeed
26  * part of the main system of equations
27  */
29  n_interleaves += (!in_system && is_in_system); // count an interleave only if in_system ==
30  // false and is_in_system == true
31  in_system = is_in_system; // update in_system
32 }
33 
35  const std::vector<std::string>& pre_solve_statements,
36  const std::vector<std::string>& solutions,
37  const std::unordered_set<ast::Statement*>& to_be_removed,
38  const ReplacePolicy policy,
39  const size_t n_next_equations,
40  const std::string& tmp_unique_prefix)
44  , policy(policy)
46  , replaced_statements_range(-1, -1) {
47  // if tmp_unique_prefix is empty we do not expect tmp_statements
48  const auto ss_tmp_delimeter =
49  tmp_unique_prefix.empty()
50  ? solutions.begin()
51  : std::find_if(solutions.begin(),
52  solutions.end(),
53  [&tmp_unique_prefix](const std::string& statement) {
54  return statement.substr(0, tmp_unique_prefix.size()) !=
55  tmp_unique_prefix;
56  });
57  tmp_statements = StatementDispenser(solutions.begin(), ss_tmp_delimeter, -1);
58  solution_statements = StatementDispenser(ss_tmp_delimeter, solutions.end(), -1);
59 
60  replacements.clear();
61 }
62 
63 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
65  const bool current_is_top_level_statement_block =
66  is_top_level_statement_block; // we mark it down since we are going to change it for
67  // visiting the children
69 
70  if (current_is_top_level_statement_block) {
71  logger->debug("SympyReplaceSolutionsVisitor :: visit statements. Matching policy: {}",
72  (policy == ReplacePolicy::VALUE ? "VALUE" : "GREEDY"));
74  node.visit_children(*this);
75 
77  logger->debug(
78  "SympyReplaceSolutionsVisitor :: not all solutions were replaced. Policy: GREEDY");
81  node.visit_children(*this);
82 
83  if (interleaves_counter.n() > 0) {
84  logger->warn(
85  "SympyReplaceSolutionsVisitor :: Found ambiguous system of equations "
86  "interleaved with {} assignment statements. I do not know what equations go "
87  "before and what "
88  "equations go after the assignment statements. Either put all the equations "
89  "that need to be solved "
90  "in the form: x = f(...) and with distinct variable assignments or do not "
91  "interleave the system with assignments.",
93  }
94  }
95  } else {
96  node.visit_children(*this);
97  }
98 
99  auto const& old_statements = node.get_statements();
100 
101  ast::StatementVector new_statements;
102  new_statements.reserve(2 * old_statements.size());
103  for (auto& old_statement: old_statements) {
104  const auto& replacement_ptr = replacements.find(old_statement);
105  if (replacement_ptr != replacements.end()) {
106  if (replaced_statements_range.first == -1) {
107  replaced_statements_range.first = static_cast<int>(new_statements.size());
108  }
109 
110  new_statements.insert(new_statements.end(),
111  replacement_ptr->second.begin(),
112  replacement_ptr->second.end());
113 
114  replaced_statements_range.second = static_cast<int>(new_statements.size());
115 
116  logger->debug("SympyReplaceSolutionsVisitor :: erasing {}", to_nmodl(old_statement));
117  for (const auto& replacement: replacement_ptr->second) {
118  logger->debug("SympyReplaceSolutionsVisitor :: adding {}", to_nmodl(replacement));
119  }
120  } else if (to_be_removed == nullptr ||
121  to_be_removed->find(&(*old_statement)) == to_be_removed->end()) {
122  logger->debug("SympyReplaceSolutionsVisitor :: found {}, nothing to do",
123  to_nmodl(old_statement));
124  new_statements.emplace_back(old_statement);
125  } else {
126  logger->debug("SympyReplaceSolutionsVisitor :: erasing {}", to_nmodl(old_statement));
127  }
128  }
129 
130  if (current_is_top_level_statement_block) {
131  if (!solution_statements.tags.empty()) {
132  std::ostringstream ss;
133  for (const auto ii: solution_statements.tags) {
134  ss << to_nmodl(solution_statements.statements[ii]) << '\n';
135  }
136  throw std::runtime_error(fmt::format(
137  "Not all solutions were replaced! Sympy returned {} equations but I could not find "
138  "a place "
139  "for all of them. In particular, the following equations remain to be replaced "
140  "somewhere:\n{}This is "
141  "probably a bug and I invite you to report it to a developer. Possible causes:\n"
142  " - I did not do a GREEDY pass and some solutions could not be replaced by VALUE\n "
143  "sympy "
144  "returned more equations than what we expected\n - There is a bug in the GREEDY "
145  "pass\n - some "
146  "solutions were replaced but not untagged",
148  ss.str()));
149  }
150 
151  if (replaced_statements_range.first == -1) {
152  replaced_statements_range.first = static_cast<int>(new_statements.size());
153  }
154  if (replaced_statements_range.second == -1) {
155  replaced_statements_range.second = static_cast<int>(new_statements.size());
156  }
157  }
158 
159  node.set_statements(std::move(new_statements));
160 }
161 
163  const ast::Node& node,
164  std::shared_ptr<ast::Expression> get_lhs(const ast::Node& node)) {
166 
167  const auto& statement = std::static_pointer_cast<ast::Statement>(
168  node.get_parent()->get_shared_ptr());
169 
170  // do not visit if already marked
171  if (replacements.find(statement) != replacements.end()) {
172  return;
173  }
174 
175 
176  switch (policy) {
177  case ReplacePolicy::VALUE: {
178  const auto key = statement_dependencies_key(get_lhs(node));
179 
181  logger->debug("SympyReplaceSolutionsVisitor :: marking for replacement {}",
182  to_nmodl(statement));
183 
184  ast::StatementVector new_statements;
185 
189 
190  replacements.emplace(statement, new_statements);
191  }
192  break;
193  }
194  case ReplacePolicy::GREEDY: {
195  logger->debug("SympyReplaceSolutionsVisitor :: marking for replacement {}",
196  to_nmodl(statement));
197 
198  ast::StatementVector new_statements;
199 
203 
204  replacements.emplace(statement, new_statements);
205  break;
206  }
207  }
208 }
209 
210 
212  logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node));
213  auto get_lhs = [](const ast::Node& node) -> std::shared_ptr<ast::Expression> {
214  return dynamic_cast<const ast::DiffEqExpression&>(node).get_expression()->get_lhs();
215  };
216 
218 }
219 
221  logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node));
222  auto get_lhs = [](const ast::Node& node) -> std::shared_ptr<ast::Expression> {
223  return dynamic_cast<const ast::LinEquation&>(node).get_lhs();
224  };
225 
227 }
228 
229 
231  logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node));
232  auto get_lhs = [](const ast::Node& node) -> std::shared_ptr<ast::Expression> {
233  return dynamic_cast<const ast::NonLinEquation&>(node).get_lhs();
234  };
235 
237 }
238 
239 
241  logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node));
242  if (node.get_op().get_value() == ast::BinaryOp::BOP_ASSIGN && node.get_lhs()->is_var_name()) {
244 
245  const auto& var =
246  std::static_pointer_cast<ast::VarName>(node.get_lhs())->get_name()->get_node_name();
249  }
250 }
251 
252 
254  const std::vector<std::string>::const_iterator& statements_str_beg,
255  const std::vector<std::string>::const_iterator& statements_str_end,
256  const int error_on_n_flushes)
257  : statements(create_statements(statements_str_beg, statements_str_end))
258  , error_on_n_flushes(error_on_n_flushes) {
260  build_maps();
261 }
262 
263 
264 /**
265  * \details Here we construct a map variable -> affected equations. In other words, this map tells
266  * me what equations need to be updated when I change a particular variable. To do that we build a
267  * graph of dependencies var -> vars and in the meantime we reduce it to the root variables. This
268  * is ensured by the fact that the tmp variables are sorted so that the next tmp variable may depend
269  * on the previous one. Since both maps describe the same relation from opposite directions (if an
270  * equation depends on a variable, it needs to be updated if the variable changes), we build them at the same time.
271  *
272  * An example:
273  *
274  * \code{.mod}
275  * tmp0 = x + a
276  * tmp1 = tmp0 + b
277  * tmp2 = y
278  * \endcode
279  *
280  * dependency_map should be (the order of the equation is unimportant since we are building
281  * a map):
282  *
283  * - tmp0 : x, a
284  * - tmp1 : x, a, b
285  * - tmp2 : y
286  *
287  * and the var2dependants map should be (the order of the following equations is unimportant
288  * since we are building a map. The numbers represent the indices of the original equations):
289  *
290  * - x : 0, 1
291  * - y : 2
292  * - a : 0, 1
293  * - b : 1
294  *
295  */
296 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
298  for (size_t ii = 0; ii < statements.size(); ++ii) {
299  const auto& statement = statements[ii];
300 
301  if (statement->is_expression_statement()) {
302  const auto& e_statement =
303  std::static_pointer_cast<ast::ExpressionStatement>(statement)->get_expression();
304  if (e_statement->is_binary_expression()) {
305  const auto& bin_exp = std::static_pointer_cast<ast::BinaryExpression>(e_statement);
306  const auto& dependencies = statement_dependencies(bin_exp->get_lhs(),
307  bin_exp->get_rhs());
308 
309  const auto& key = dependencies.first;
310  const auto& vars = dependencies.second;
311  if (!key.empty()) {
312  var2statement.emplace(key, ii);
313  for (const auto& var: vars) {
314  const auto& var_already_inserted = dependency_map.find(var);
315  if (var_already_inserted != dependency_map.end()) {
316  dependency_map[key].insert(var_already_inserted->second.begin(),
317  var_already_inserted->second.end());
318  for (const auto& root_var: var_already_inserted->second) {
319  var2dependants[root_var].insert(ii);
320  }
321  } else {
322  dependency_map[key].insert(var);
323  var2dependants[var].insert(ii);
324  }
325  }
326  }
327  }
328  }
329  }
330 
331  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: var2dependants map");
332  for (const auto& entry: var2dependants) {
333  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: var `{}` used in:",
334  entry.first);
335  for (const auto ii: entry.second) {
336  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: -> {}",
337  to_nmodl(statements[ii]));
338  }
339  }
340  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: var2statement map");
341  for (const auto& entry: var2statement) {
342  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: var `{}` defined in:",
343  entry.first);
344  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: -> {}",
345  to_nmodl(statements[entry.second]));
346  }
347 }
348 
350  ast::StatementVector& new_statements,
351  const std::string& var) {
352  auto ptr = var2statement.find(var);
353  bool emplaced = false;
354  if (ptr != var2statement.end()) {
355  const auto ii = ptr->second;
356  const auto tag_ptr = tags.find(ii);
357  if (tag_ptr != tags.end()) {
358  new_statements.emplace_back(statements[ii]->clone());
359  tags.erase(tag_ptr);
360  emplaced = true;
361 
362  logger->debug(
363  "SympyReplaceSolutionsVisitor::StatementDispenser :: adding to replacement rule {}",
364  to_nmodl(statements[ii]));
365  } else {
366  logger->error(
367  "SympyReplaceSolutionsVisitor::StatementDispenser :: tried adding to replacement "
368  "rule {} but statement is not "
369  "tagged",
370  to_nmodl(statements[ii]));
371  }
372  }
373  return emplaced;
374 }
375 
377  ast::StatementVector& new_statements,
378  const size_t n_next_statements) {
379  size_t counter = 0;
380  for (size_t next_statement_ii = 0;
381  next_statement_ii < statements.size() && counter < n_next_statements;
382  ++next_statement_ii) {
383  const auto tag_ptr = tags.find(next_statement_ii);
384  if (tag_ptr != tags.end()) {
385  logger->debug(
386  "SympyReplaceSolutionsVisitor::StatementDispenser :: adding to replacement rule {}",
387  to_nmodl(statements[next_statement_ii]));
388  new_statements.emplace_back(statements[next_statement_ii]->clone());
389  tags.erase(tag_ptr);
390  ++counter;
391  }
392  }
393  return counter;
394 }
395 
397  ast::StatementVector& new_statements) {
398  for (const auto ii: tags) {
399  new_statements.emplace_back(statements[ii]->clone());
400  logger->debug(
401  "SympyReplaceSolutionsVisitor::StatementDispenser :: adding to replacement rule {}",
402  to_nmodl(statements[ii]));
403  }
404 
405  n_flushes += (!tags.empty());
406  if (error_on_n_flushes > 0 && n_flushes >= error_on_n_flushes) {
407  throw std::runtime_error(
408  "SympyReplaceSolutionsVisitor::StatementDispenser :: State variable assignment(s) "
409  "interleaved in system "
410  "of "
411  "equations/differential equations. It is not allowed due to possible numerical "
412  "instability and undefined "
413  "behavior. Erase the assignment statement(s) or move them before/after the"
414  " set of equations/differential equations.");
415  }
416 
417  const auto n_replacements = tags.size();
418 
419  tags.clear();
420 
421  return n_replacements;
422 }
423 
425  const std::string& var) {
426  auto ptr = var2dependants.find(var);
427  size_t n = 0;
428  if (ptr != var2dependants.end()) {
429  for (const auto ii: ptr->second) {
430  const auto pos = tags.insert(ii);
431  if (pos.second) {
432  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: tagging {}",
433  to_nmodl(statements[ii]));
434  }
435  ++n;
436  }
437  }
438  return n;
439 }
440 
442  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: tagging all statements");
443  for (size_t i = 0; i < statements.size(); ++i) {
444  tags.insert(i);
445  logger->debug("SympyReplaceSolutionsVisitor::StatementDispenser :: tagging {}",
446  to_nmodl(statements[i]));
447  }
448 }
449 
450 
451 } // namespace visitor
452 } // namespace nmodl
Auto generated AST classes declaration.
Represents binary expression in the NMODL.
Represents differential equation in DERIVATIVE block.
One equation in a system of equations that collectively form a LINEAR block.
Base class for all AST node.
Definition: node.hpp:40
One equation in a system of equations that collectively make a NONLINEAR block.
Represents block encapsulating list of statements.
void try_replace_tagged_statement(const ast::Node &node, std::shared_ptr< ast::Expression > get_lhs(const ast::Node &node))
Try to replace a statement.
StatementDispenser solution_statements
solutions that we want to replace
@ VALUE
Replace statements matching by lhs varName.
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
StatementDispenser tmp_statements
tmp statements that appear with --cse (i.e. )
std::pair< int, int > replaced_statements_range
{begin index, end index} of the added statements. -1 means that it is invalid
ReplacePolicy policy
Replacement policy used by the various visitors.
bool is_top_level_statement_block
Used to notify visit_statement_block whether it was called by the user (or another visitor) or re-called in a...
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
void visit_binary_expression(ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
size_t n_next_equations
Number of solutions that match each old_statement with the greedy policy.
InterleavesCounter interleaves_counter
counts how many times the solution statements are interleaved with assignment expressions
const std::unordered_set< ast::Statement * > * to_be_removed
group of old statements that need replacing
StatementDispenser pre_solve_statements
Update state variable statements (i.e. )
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
std::unordered_map< std::shared_ptr< ast::Statement >, ast::StatementVector > replacements
Replacements found by the visitor.
#define key
Definition: tqueue.hpp:45
#define i
Definition: md1redef.h:19
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:302
double var(InputIterator begin, InputIterator end)
Definition: ivocvect.h:108
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
void move(Item *q1, Item *q2, Item *q3)
Definition: list.cpp:200
int ii
Definition: cellorder.cpp:631
std::vector< std::shared_ptr< Statement > > create_statements(const std::vector< std::string >::const_iterator &code_statements_beg, const std::vector< std::string >::const_iterator &code_statements_end)
Same as for create_statement but for vectors of strings.
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::pair< std::string, std::unordered_set< std::string > > statement_dependencies(const std::shared_ptr< ast::Expression > &lhs, const std::shared_ptr< ast::Expression > &rhs)
If lhs and rhs combined represent an assignment (we assume to have an "=" in between them) we extract...
std::string statement_dependencies_key(const std::shared_ptr< ast::Expression > &lhs)
The result.first of statement_dependencies.
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
logger_type logger
Definition: logger.cpp:34
static Node * node(Object *)
Definition: netcvode.cpp:291
int const size_t const size_t n
Definition: nrngsl.h:10
Count interleaves of assignment statement inside the system of equations.
bool in_system
Bool that keeps track of whether we just wrote an equation of the system of equations (true) or not (false)
void new_equation(const bool is_in_system)
Count interleaves defined as a switch false -> true for in_system.
size_t n_interleaves
Number of interleaves of assignment statements in between equations of the system of equations.
Sorts and maps statements to variables keeping track of what needs updating.
std::vector< std::shared_ptr< ast::Statement > > statements
Vector of statements.
bool try_emplace_back_tagged_statement(ast::StatementVector &new_statements, const std::string &var)
Look for var in var2statement and emplace back that statement in new_statements.
std::set< size_t > tags
Keeps track of what statements need updating.
size_t emplace_back_all_tagged_statements(ast::StatementVector &new_statements)
Emplace back all the statements that are marked for updating in tags.
void tag_all_statements()
Mark that all the statements need updating (probably unused)
void build_maps()
Construct the maps var2dependants, var2statement and dependency_map for easy access and classificatio...
size_t tag_dependant_statements(const std::string &var)
Tag all the statements that depend on var for updating.
size_t emplace_back_next_tagged_statements(ast::StatementVector &new_statements, const size_t n_next_statements)
Emplace back the next n_next_statements solutions in statements that are marked for updating in tags.
bool is_var_assigned_here(const std::string &var) const
Check if one of the statements assigns this variable (i.e.
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
Utility functions for visitors implementation.