NEURON
neuron_solve_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
12 #include "parser/diffeq_driver.hpp"
13 #include "symtab/symbol.hpp"
14 #include "utils/logger.hpp"
16 
17 
18 namespace nmodl {
19 namespace visitor {
20 
22  auto name = node.get_block_name()->get_node_name();
23  const auto& method = node.get_method();
24  solve_method = method ? method->get_value()->eval() : "";
26 }
27 
28 
30  derivative_block_name = node.get_name()->get_node_name();
31  derivative_block = true;
32  node.visit_children(*this);
33  derivative_block = false;
35  const auto& statement_block = node.get_statement_block();
36  for (auto& e: euler_solution_expressions) {
37  statement_block->emplace_back_statement(e);
38  }
39  }
40 }
41 
42 
44  differential_equation = true;
45  node.visit_children(*this);
46  differential_equation = false;
47 }
48 
49 
51  const auto& lhs = node.get_lhs();
52 
53  /// we only have to solve ODEs under a derivative block where the lhs is a variable
54  if (!derivative_block || !differential_equation || !lhs->is_var_name()) {
55  return;
56  }
57 
59  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
60 
61  if (name->is_prime_name()) {
62  auto equation = to_nmodl(node);
64  std::string solution;
65  /// check if ode can be solved with cnexp method
66  if (parser::DiffeqDriver::cnexp_possible(equation, solution)) {
67  auto statement = create_statement(solution);
68  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
69  statement);
70  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
71  expr_statement->get_expression());
72  node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
73  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
74  } else {
75  logger->warn("NeuronSolveVisitor :: cnexp solver not possible for {}",
76  to_nmodl(node));
77  }
79  // computation of the derivative in place
80  {
81  std::string solution = parser::DiffeqDriver::solve(equation, solve_method);
82  auto statement = create_statement(solution);
83  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
84  statement);
85  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
86  expr_statement->get_expression());
87  node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
88  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
89  }
90 
91  // create a new statement to compute the value based on the derivative
92  // this statement will be pushed at the end of the derivative block
93  {
94  std::string n = name->get_node_name();
95  auto statement = create_statement(fmt::format("{} = {} + dt * D{}", n, n, n));
96  euler_solution_expressions.emplace_back(statement);
97  }
99  auto varname = "D" + name->get_node_name();
100  node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
101  if (program_symtab->lookup(varname) == nullptr) {
102  auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
103  symbol->set_original_name(name->get_node_name());
104  symbol->created_from_state();
105  program_symtab->insert(symbol);
106  }
107  } else {
108  logger->error("NeuronSolveVisitor :: solver method '{}' not supported", solve_method);
109  }
110  }
111 }
112 
114  program_symtab = node.get_symbol_table();
115  node.visit_children(*this);
116 }
117 
118 } // namespace visitor
119 } // namespace nmodl
Auto generated AST classes declaration.
Represents a token returned by the scanner.
Definition: modtoken.hpp:50
Represents binary expression in the NMODL.
Represents DERIVATIVE block in the NMODL.
Represents differential equation in DERIVATIVE block.
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents a string.
Definition: string.hpp:52
static std::string solve(const std::string &equation, std::string method, bool debug=false)
solve equation using provided method
static bool cnexp_possible(const std::string &equation, std::string &solution)
check if given equation can be solved using cnexp method
void insert(const std::shared_ptr< Symbol > &symbol)
std::shared_ptr< Symbol > lookup(const std::string &name) const
check if a symbol with the given name exists in the current table (but not in parents)
void visit_solve_block(ast::SolveBlock &node) override
visit node of type ast::SolveBlock
std::vector< std::shared_ptr< ast::Statement > > euler_solution_expressions
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
std::string derivative_block_name
the derivative name currently being visited
symtab::SymbolTable * program_symtab
global symbol table
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
bool differential_equation
true while visiting differential equation
std::map< std::string, std::string > solve_blocks
a map holding solve block names and methods
void visit_program(ast::Program &node) override
visit node of type ast::Program
void visit_binary_expression(ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
std::string solve_method
method specified in solve block
bool derivative_block
true while visiting derivative block
const char * name
Definition: init.cpp:16
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
static constexpr char EULER_METHOD[]
euler method in nmodl
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
logger_type logger
Definition: logger.cpp:34
Visitor that solves ODEs using old solvers of NEURON
static Node * node(Object *)
Definition: netcvode.cpp:291
int const size_t const size_t n
Definition: nrngsl.h:10
Implement class to represent a symbol in Symbol Table.
Utility functions for visitors implementation.