12 #include "utils/logger.hpp"
19 const std::shared_ptr<ast::SolveBlock>& solve_block,
20 const std::vector<std::shared_ptr<ast::Ast>>& deriv_blocks) {
22 std::shared_ptr<ast::DerivativeBlock> ss_block;
25 const auto solve_block_name = solve_block->get_block_name()->get_value()->eval();
26 const auto steadystate_method = solve_block->get_steadystate()->get_value()->eval();
28 logger->debug(
"SteadystateVisitor :: Found STEADYSTATE SOLVE statement: using {} for {}",
33 for (
const auto& block_ptr: deriv_blocks) {
34 auto deriv_block = std::dynamic_pointer_cast<ast::DerivativeBlock>(block_ptr);
35 if (deriv_block->get_node_name() == solve_block_name) {
36 logger->debug(
"SteadystateVisitor :: -> found corresponding DERIVATIVE block: {}",
38 deriv_block_ptr = deriv_block.get();
43 if (deriv_block_ptr !=
nullptr) {
45 ss_block = std::shared_ptr<ast::DerivativeBlock>(deriv_block_ptr->
clone());
46 auto ss_name = ss_block->get_name();
47 ss_name->set_name(ss_name->get_value()->get_value() +
"_steadystate");
48 auto ss_name_clone = std::shared_ptr<ast::Name>(ss_name->clone());
50 logger->debug(
"SteadystateVisitor :: -> adding new DERIVATIVE block: {}",
51 ss_block->get_node_name());
59 logger->warn(
"SteadystateVisitor :: solve method {} not supported for STEADYSTATE",
64 auto statement_block = ss_block->get_statement_block();
65 auto statements = statement_block->get_statements();
68 auto update_dt_statement = std::make_shared<ast::UpdateDt>(
new ast::Double(new_dt));
69 statements.insert(statements.begin(), update_dt_statement);
72 statement_block->set_statements(
std::move(statements));
76 solve_block->set_block_name(
std::move(ss_name_clone));
78 solve_block->set_method(solve_block->get_steadystate());
79 solve_block->set_steadystate(
nullptr);
81 logger->warn(
"SteadystateVisitor :: Could not find derivative block {} for STEADYSTATE",
96 for (
const auto& solve_block_ptr: solve_block_nodes) {
97 if (
auto solve_block = std::dynamic_pointer_cast<ast::SolveBlock>(solve_block_ptr)) {
98 if (solve_block->get_steadystate()) {
100 if (ss_block !=
nullptr) {
101 node.emplace_back_node(ss_block);
Auto generated AST classes declaration.
Represents DERIVATIVE block in the NMODL.
DerivativeBlock * clone() const override
Return a copy of the current node.
Represents a double variable.
Represents top level AST node for whole NMODL input.
const double STEADYSTATE_SPARSE_DT
const double STEADYSTATE_DERIVIMPLICIT_DT
void visit_program(ast::Program &node) override
visit node of type ast::Program
std::shared_ptr< ast::DerivativeBlock > create_steadystate_block(const std::shared_ptr< ast::SolveBlock > &solve_block, const std::vector< std::shared_ptr< ast::Ast >> &deriv_blocks)
create new steady state derivative block for given solve block
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
@ SOLVE_BLOCK
type of ast::SolveBlock
void move(Item *q1, Item *q2, Item *q3)
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
encapsulates code generation backend implementations
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
static Node * node(Object *)
Visitor for STEADYSTATE solve statements
Utility functions for visitors implementation.