NEURON
constant_folder_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
11 #include "utils/logger.hpp"
13 
14 
15 namespace nmodl {
16 namespace visitor {
17 
18 /// check if given expression is a number
19 /// note that the DEFINE node is already expanded to integer
20 static inline bool is_number(const std::shared_ptr<ast::Expression>& node) {
21  return node->is_integer() || node->is_double() || node->is_float();
22 }
23 
24 /// get value of a number node
25 /// TODO : eval method can be added to virtual base class
26 static double get_value(const std::shared_ptr<ast::Expression>& node) {
27  if (node->is_integer()) {
28  return std::dynamic_pointer_cast<ast::Integer>(node)->eval();
29  } else if (node->is_float()) {
30  return std::dynamic_pointer_cast<ast::Float>(node)->to_double();
31  } else if (node->is_double()) {
32  return std::dynamic_pointer_cast<ast::Double>(node)->to_double();
33  }
34  throw std::runtime_error("Invalid type passed to is_number()");
35 }
36 
37 /// operators that currently implemented
38 static inline bool supported_operator(ast::BinaryOp op) {
39  return op == ast::BOP_ADDITION || op == ast::BOP_SUBTRACTION || op == ast::BOP_MULTIPLICATION ||
40  op == ast::BOP_DIVISION;
41 }
42 
43 /// Evaluate binary operation
44 /// TODO : add support for other binary operators like ^ (pow)
45 static double compute(double lhs, ast::BinaryOp op, double rhs) {
46  switch (op) {
47  case ast::BOP_ADDITION:
48  return lhs + rhs;
49 
51  return lhs - rhs;
52 
54  return lhs * rhs;
55 
56  case ast::BOP_DIVISION:
57  return lhs / rhs;
58 
59  default:
60  throw std::logic_error("Invalid binary operator in constant folding");
61  }
62 }
63 
64 /**
65  * Visit parenthesis expression and simplify it
66  * @param node AST node representing an expression with parenthesis
67  *
68  * AST could have expression like (1+2). In this case, it has following
69  * form in the AST :
70  *
71  * parenthesis_exp => wrapped_expr => binary_expression => ...
72  *
73  * To make constant folding simple, we can remove intermediate wrapped_expr
74  * and directly replace binary_expression inside parenthesis_exp :
75  *
76  * parenthesis_exp => binary_expression => ...
77  */
79  node.visit_children(*this);
80  auto expr = node.get_expression();
81  if (expr->is_wrapped_expression()) {
82  auto e = std::dynamic_pointer_cast<ast::WrappedExpression>(expr);
83  node.set_expression(e->get_expression());
84  }
85 }
86 
87 /**
88  * Visit wrapped node type and perform constant folding
89  * @param node AST node that wrap other node types
90  *
91  * MOD file has expressions like
92  *
93  * a = 1 + 2
94  * DEFINE NN 10
95  * FROM i=0 TO NN-2 {
96  *
97  * }
98  *
99  * which need to be turned into
100  *
101  * a = 1 + 2
102  * DEFINE NN 10
103  * FROM i=0 TO 8 {
104  *
105  * }
106  */
108  node.visit_children(*this);
109  node.visit_children(*this);
110 
111  /// first expression which is wrapped
112  auto expr = node.get_expression();
113 
114  /// if wrapped expression is parentheses
115  bool is_parentheses = false;
116 
117  /// opposite to visit_paren_expression, we might have
118  /// a = (2+1)
119  /// in this case we can pick inner expression.
120  if (expr->is_paren_expression()) {
121  auto e = std::dynamic_pointer_cast<ast::ParenExpression>(expr);
122  expr = e->get_expression();
123  is_parentheses = true;
124  }
125 
126  /// we want to simplify binary expressions only
127  if (!expr->is_binary_expression()) {
128  /// wrapped expression might be parenthesis expression like (2)
129  /// which we can simplify to 2 to help next evaluations
130  if (is_parentheses) {
131  node.set_expression(std::move(expr));
132  }
133  return;
134  }
135 
136  auto binary_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(expr);
137  auto lhs = binary_expr->get_lhs();
138  auto rhs = binary_expr->get_rhs();
139  auto op = binary_expr->get_op().get_value();
140 
141  /// in case of expression like
142  /// a = 2 + ((1) + (3))
143  /// we are in the innermost expression i.e. ((1) + (3))
144  /// where (1) and (3) are wrapped expression themself. we can
145  /// remove these extra wrapped expressions
146 
147  if (lhs->is_wrapped_expression()) {
148  auto e = std::dynamic_pointer_cast<ast::WrappedExpression>(lhs);
149  lhs = e->get_expression();
150  }
151 
152  if (rhs->is_wrapped_expression()) {
153  auto e = std::dynamic_pointer_cast<ast::WrappedExpression>(rhs);
154  rhs = e->get_expression();
155  }
156 
157  /// once we simplify, lhs and rhs must be numbers for constant folding
158  if (!is_number(lhs) || !is_number(rhs) || !supported_operator(op)) {
159  return;
160  }
161 
162  const std::string& nmodl_before = to_nmodl(binary_expr);
163 
164  /// compute the value of expression
165  auto value = compute(get_value(lhs), op, get_value(rhs));
166 
167  /// if both operands are not integers or floats, result is double
168  if (lhs->is_integer() && rhs->is_integer()) {
169  node.set_expression(std::make_shared<ast::Integer>(static_cast<int>(value), nullptr));
170  } else if (lhs->is_double() || rhs->is_double()) {
171  node.set_expression(std::make_shared<ast::Double>(stringutils::to_string(value)));
172  } else {
173  node.set_expression(std::make_shared<ast::Float>(stringutils::to_string(value)));
174  }
175 
176  const std::string& nmodl_after = to_nmodl(node.get_expression());
177  logger->debug("ConstantFolderVisitor : expression {} folded to {}", nmodl_before, nmodl_after);
178 }
179 
180 } // namespace visitor
181 } // namespace nmodl
Auto generated AST classes declaration.
Wrap any other expression type.
void visit_wrapped_expression(ast::WrappedExpression &node) override
Visit wrapped node type and perform constant folding.
void visit_paren_expression(ast::ParenExpression &node) override
Visit parenthesis expression and simplify it.
Perform constant folding of integer/float/double expressions.
virtual bool is_integer() const noexcept
Check if the ast node is an instance of ast::Integer.
Definition: ast.cpp:76
virtual bool is_float() const noexcept
Check if the ast node is an instance of ast::Float.
Definition: ast.cpp:78
virtual bool is_double() const noexcept
Check if the ast node is an instance of ast::Double.
Definition: ast.cpp:80
BinaryOp
enum Type for binary operators in NMODL
Definition: ast_common.hpp:47
@ BOP_DIVISION
/
Definition: ast_common.hpp:51
@ BOP_SUBTRACTION
Definition: ast_common.hpp:49
@ BOP_MULTIPLICATION
*
Definition: ast_common.hpp:50
std::string to_string(double value, const std::string &format_spec)
Convert double value to string without trailing zeros.
#define rhs
Definition: lineq.h:6
void move(Item *q1, Item *q2, Item *q3)
Definition: list.cpp:200
static bool is_number(const std::shared_ptr< ast::Expression > &node)
check if given expression is a number note that the DEFINE node is already expanded to integer
static double compute(double lhs, ast::BinaryOp op, double rhs)
Evaluate binary operation TODO : add support for other binary operators like ^ (pow)
static bool supported_operator(ast::BinaryOp op)
operators that are currently implemented
static double get_value(const std::shared_ptr< ast::Expression > &node)
get value of a number node TODO : eval method can be added to virtual base class
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
logger_type logger
Definition: logger.cpp:34
static Node * node(Object *)
Definition: netcvode.cpp:291
static uint32_t value
Definition: scoprand.cpp:25
Utility functions for visitors implementation.