13 #include "utils/logger.hpp"
30 if (!conserve_equations.empty()) {
31 std::unordered_set<ast::Statement*> eqs;
32 for (
const auto& item: conserve_equations) {
33 eqs.insert(std::dynamic_pointer_cast<ast::Statement>(item).
get());
35 node.erase_statement(eqs);
44 std::regex unit_pattern(R
"((\d+\.?\d*|\.\d+)\s*\([a-zA-Z]+\))");
46 auto rhs_string_no_units = fmt::format(
"{} = {}",
48 std::regex_replace(rhs_string, unit_pattern,
"$1"));
50 logger->debug(
"CvodeVisitor :: result: {}", rhs_string_no_units);
51 auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
53 const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
54 expr_statement->get_expression());
55 node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
59 std::shared_ptr<ast::Identifier>
node) {
62 variable.second = std::optional<int>(
63 get_index(*std::dynamic_pointer_cast<const ast::IndexedName>(
node)));
70 const std::string& ignored_name) {
71 std::unordered_set<std::string> indexed_variables;
76 for (
const auto&
var: indexed_vars) {
77 const auto& varname =
var->get_node_name();
79 auto varname_not_reserved =
80 std::none_of(reserved_symbols.begin(),
81 reserved_symbols.end(),
82 [&varname](
const auto item) { return varname == item; });
83 if (indexed_variables.count(varname) == 0 && varname != ignored_name &&
84 varname_not_reserved) {
85 indexed_variables.insert(varname);
88 return indexed_variables;
92 const auto& lhs =
node.get_lhs();
94 auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
97 if (
name->is_prime_name()) {
98 varname =
"D" +
name->get_node_name();
100 }
else if (
name->is_indexed_name()) {
103 if (!nodes.empty()) {
105 auto statement = fmt::format(
"{} = {}", varname, varname);
106 auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
108 const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
109 expr_statement->get_expression());
110 node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
125 node.visit_children(*
this);
137 const auto& lhs =
node.get_lhs();
143 auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
147 auto symbol = std::make_shared<symtab::Symbol>(varname,
ModToken());
148 symbol->set_original_name(
name->get_node_name());
161 const auto& lhs =
node.get_lhs();
167 auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
171 auto symbol = std::make_shared<symtab::Symbol>(varname,
ModToken());
172 symbol->set_original_name(
name->get_node_name());
183 auto [
jacobian, exception_message] =
185 if (!exception_message.empty()) {
186 logger->warn(
"CvodeVisitor :: python exception: {}", exception_message);
190 auto statement = fmt::format(
"{} = {} / (1 - dt * ({}))", varname, varname,
jacobian);
191 logger->debug(
"CvodeVisitor :: replacing statement {} with {}",
to_nmodl(
node), statement);
192 auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
194 const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
195 expr_statement->get_expression());
196 node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
202 if (derivative_blocks.empty()) {
207 auto not_steadystate = [](
const auto& item) {
208 auto name = std::dynamic_pointer_cast<const ast::DerivativeBlock>(item)->get_node_name();
211 decltype(derivative_blocks) derivative_blocks_copy;
212 std::copy_if(derivative_blocks.begin(),
213 derivative_blocks.end(),
214 std::back_inserter(derivative_blocks_copy),
216 if (derivative_blocks_copy.size() > 1) {
217 auto message =
"CvodeVisitor :: cannot have multiple DERIVATIVE blocks";
219 throw std::runtime_error(message);
222 return std::dynamic_pointer_cast<ast::DerivativeBlock>(derivative_blocks_copy[0]);
228 if (derivative_block ==
nullptr) {
232 auto non_stiff_block = derivative_block->get_statement_block()->clone();
235 auto stiff_block = derivative_block->get_statement_block()->clone();
242 derivative_block->get_name(),
243 std::shared_ptr<ast::Integer>(
new ast::Integer(prime_vars.size(),
nullptr)),
244 std::shared_ptr<ast::StatementBlock>(non_stiff_block),
245 std::shared_ptr<ast::StatementBlock>(stiff_block)));
Auto generated AST classes declaration.
Represent token returned by scanner.
Represents binary expression in the NMODL.
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Represents differential equation in DERIVATIVE block.
Base class for all expressions in the NMODL.
Represents specific element of an array variable.
Represents an integer variable.
Represents top level AST node for whole NMODL input.
Represents block encapsulating list of statements.
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
Represent symbol table for a NMODL block.
void insert(const std::shared_ptr< Symbol > &symbol)
std::shared_ptr< Symbol > lookup(const std::string &name) const
check if symbol with given name exist in the current table (but not in parents)
Concrete visitor for all AST classes.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
bool in_differential_equation
void visit_diff_eq_expression(ast::DiffEqExpression &node)
visit node of type ast::DiffEqExpression
symtab::SymbolTable * program_symtab
void visit_program(ast::Program &node) override
visit node of type ast::Program
void visit_binary_expression(ast::BinaryExpression &node)
visit node of type ast::BinaryExpression
NonStiffVisitor(symtab::SymbolTable *symtab)
void visit_binary_expression(ast::BinaryExpression &node)
visit node of type ast::BinaryExpression
StiffVisitor(symtab::SymbolTable *symtab)
Visitor used for generating the necessary AST nodes for CVODE.
static double jacobian(void *v)
virtual bool is_indexed_name() const noexcept
Check if the ast node is an instance of ast::IndexedName.
virtual std::string get_node_name() const
Return name of of the node.
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
@ INDEXED_NAME
type of ast::IndexedName
@ CONSERVE
type of ast::Conserve
@ PRIME_NAME
type of ast::PrimeName
static bool ends_with(const std::string &haystack, const std::string &needle)
Check if haystack ends with needle.
static std::string remove_character(std::string text, const char c)
Remove all occurrences of a given character in a text.
double var(InputIterator begin, InputIterator end)
static std::shared_ptr< ast::DerivativeBlock > get_derivative_block(ast::Program &node)
static std::pair< std::string, std::optional< int > > parse_independent_var(std::shared_ptr< ast::Identifier > node)
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
static void remove_units(ast::BinaryExpression &node)
static void remove_conserve_statements(ast::StatementBlock &node)
static std::string cvode_set_lhs(ast::BinaryExpression &node)
static int get_index(const ast::IndexedName &node)
static std::unordered_set< std::string > get_indexed_variables(const ast::Expression &node, const std::string &ignored_name)
set of all indexed variables not equal to ignored_name
encapsulates code generation backend implementations
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
std::vector< std::string > get_external_functions()
Return functions that can be used in the NMODL.
static Node * node(Object *)
decltype(&call_diff2c) diff2c
Map different tokens from lexer to token types.
Utility functions for visitors implementation.