NEURON
solve_block_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "utils/fmt.h"
11 #include <cassert>
12 #include <memory>
13 
14 #include "ast/all.hpp"
16 #include "visitor_utils.hpp"
17 
18 namespace nmodl {
19 namespace visitor {
20 
22  in_breakpoint_block = true;
23  node.visit_children(*this);
24  in_breakpoint_block = false;
25 }
26 
27 /// check if given node contains sympy solution
28 static bool has_sympy_solution(const ast::Ast& node) {
30 }
31 
32 /**
33  * Create solution expression node that will be used for solve block
34  * \param solve_block solve block used to describe node to solve and method
35  * \return solution expression that will be used to replace the solve block
36  *
37  * Depending on the solver used, solve block is converted to solve expression statement
38  * that will be used to replace solve block. Note that the blocks are clones instead of
39  * shared_ptr because DerivimplicitCallback is currently contain whole node
40  * instead of just pointer.
41  */
43  ast::SolveBlock& solve_block) {
44  /// find out the block that is going to solved
45  const auto& block_name = solve_block.get_block_name()->get_node_name();
46  const auto& solve_node_symbol = symtab->lookup(block_name);
47  if (solve_node_symbol == nullptr) {
48  throw std::runtime_error(
49  fmt::format("SolveBlockVisitor :: cannot find the block '{}' to solve it", block_name));
50  }
51  auto node_to_solve = solve_node_symbol->get_nodes().front();
52 
53  /// in case of derivimplicit method if neuron solver is used (i.e. not sympy) then
54  /// the solution is not in place but we have to create a callback to newton solver
55  const auto& method = solve_block.get_method();
56  std::string solve_method = method ? method->get_node_name() : "";
57  if (solve_method == codegen::naming::DERIVIMPLICIT_METHOD &&
58  !has_sympy_solution(*node_to_solve)) {
59  /// typically derivimplicit is used for derivative block only
60  if (node_to_solve->get_node_type() != ast::AstNodeType::DERIVATIVE_BLOCK) {
61  const std::string node_name = node_to_solve->get_node_name();
62  const std::string node_type = node_to_solve->get_node_type_name();
63  throw std::runtime_error(fmt::format(
64  "Method {} cannot be used for {} {}", solve_method, node_type, node_name));
65  }
66  auto derivative_block = dynamic_cast<ast::DerivativeBlock*>(node_to_solve);
67  auto callback_expr = new ast::DerivimplicitCallback(derivative_block->clone());
68  return new ast::SolutionExpression(solve_block.clone(), callback_expr);
69  }
70 
71  if (node_to_solve->get_node_type() == ast::AstNodeType::PROCEDURE_BLOCK) {
72  auto procedure_call = new ast::FunctionCall(solve_block.get_block_name()->clone(), {});
73  auto statement = std::make_shared<ast::ExpressionStatement>(procedure_call);
74  auto statement_block = new ast::StatementBlock({statement});
75 
76  return new ast::SolutionExpression(solve_block.clone(), statement_block);
77  }
78 
79  auto block_to_solve = node_to_solve->get_statement_block();
80  return new ast::SolutionExpression(solve_block.clone(), block_to_solve->clone());
81 }
82 
83 /**
84  * Replace solve blocks with solution expression
85  * @param node Ast node for SOLVE statement in the mod file
86  */
88  node.visit_children(*this);
89  if (node.get_expression()->is_solve_block()) {
90  auto solve_block = dynamic_cast<ast::SolveBlock*>(node.get_expression().get());
91  auto sol_expr = create_solution_expression(*solve_block);
92  if (in_breakpoint_block) {
93  nrn_state_solve_statements.emplace_back(new ast::ExpressionStatement(sol_expr));
94  } else {
95  node.set_expression(std::shared_ptr<ast::SolutionExpression>(sol_expr));
96  }
97  }
98 }
99 
101  symtab = node.get_symbol_table();
102  node.visit_children(*this);
103  /// add new node NrnState with solve blocks from breakpoint block
104  if (!nrn_state_solve_statements.empty()) {
106  node.emplace_back_node(nrn_state);
107  }
108 }
109 
110 } // namespace visitor
111 } // namespace nmodl
Auto generated AST classes declaration.
Represents a BREAKPOINT block in NMODL.
Represents a DERIVATIVE block in NMODL.
Represents a callback to NEURON's derivimplicit solver.
Represents the coreneuron nrn_state callback function.
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents the solution of a block in the AST.
std::shared_ptr< Name > get_block_name() const noexcept
Getter for member variable SolveBlock::block_name.
std::shared_ptr< Name > get_method() const noexcept
Getter for member variable SolveBlock::method.
SolveBlock * clone() const override
Return a copy of the current node.
Definition: solve_block.hpp:79
Represents block encapsulating list of statements.
std::shared_ptr< Symbol > lookup(const std::string &name) const
check if a symbol with the given name exists in the current table (but not in parents)
void visit_breakpoint_block(ast::BreakpointBlock &node) override
visit node of type ast::BreakpointBlock
ast::SolutionExpression * create_solution_expression(ast::SolveBlock &solve_block)
Create solution expression node that will be used for solve block.
void visit_program(ast::Program &node) override
visit node of type ast::Program
ast::StatementVector nrn_state_solve_statements
solve expression statements for NrnState block
void visit_expression_statement(ast::ExpressionStatement &node) override
Replace solve blocks with solution expression.
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
@ EIGEN_NEWTON_SOLVER_BLOCK
type of ast::EigenNewtonSolverBlock
@ PROCEDURE_BLOCK
type of ast::ProcedureBlock
static void nrn_state(neuron::model_sorted_token const &, NrnThread *nt, Memb_list *ml, int type)
Definition: kschan.cpp:87
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
static bool has_sympy_solution(const ast::Ast &node)
check if given node contains sympy solution
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
static Node * node(Object *)
Definition: netcvode.cpp:291
Replace solve block statements with actual solution node in the AST.
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:52
Utility functions for visitors implementation.