Skip to content

Commit

Permalink
Add support for shift operator chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer committed Jan 19, 2025
1 parent e3c4651 commit 0d130fc
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/Spice.g4
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ bitwiseXorExpr: bitwiseAndExpr (BITWISE_XOR bitwiseAndExpr)*;
bitwiseAndExpr: equalityExpr (BITWISE_AND equalityExpr)*;
equalityExpr: relationalExpr ((EQUAL | NOT_EQUAL) relationalExpr)?;
relationalExpr: shiftExpr ((LESS | GREATER | LESS_EQUAL | GREATER_EQUAL) shiftExpr)?;
shiftExpr: additiveExpr ((LESS LESS | GREATER GREATER) additiveExpr)?;
shiftExpr: additiveExpr ((LESS LESS | GREATER GREATER) additiveExpr)*;
additiveExpr: multiplicativeExpr ((PLUS | MINUS) multiplicativeExpr)*;
multiplicativeExpr: castExpr ((MUL | DIV | REM) castExpr)*;
castExpr: (LPAREN dataType RPAREN)? prefixUnaryExpr;
Expand Down
29 changes: 24 additions & 5 deletions src/ast/ASTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1076,11 +1076,30 @@ std::any ASTBuilder::visitShiftExpr(SpiceParser::ShiftExprContext *ctx) {
// Visit children
fetchChildrenIntoVector(shiftExprNode->operands, ctx->additiveExpr());

// Extract operator
if (!ctx->LESS().empty())
shiftExprNode->op = ShiftExprNode::OP_SHIFT_LEFT;
else if (!ctx->GREATER().empty())
shiftExprNode->op = ShiftExprNode::OP_SHIFT_RIGHT;
bool seenFirstLess = false;
bool seenFirstGreater = false;
for (ParserRuleContext::ParseTree *subTree : ctx->children) {
const auto terminal = dynamic_cast<TerminalNode *>(subTree);
if (!terminal)
continue;

if (terminal->getSymbol()->getType() == SpiceParser::LESS) {
if (seenFirstLess)
shiftExprNode->opQueue.emplace(ShiftExprNode::ShiftOp::OP_SHIFT_LEFT, TY_INVALID);
seenFirstLess = !seenFirstLess;
continue;
}

if (terminal->getSymbol()->getType() == SpiceParser::GREATER) {
if (seenFirstGreater)
shiftExprNode->opQueue.emplace(ShiftExprNode::ShiftOp::OP_SHIFT_RIGHT, TY_INVALID);
seenFirstGreater = !seenFirstGreater;
continue;
}

assert_fail("Invalid terminal symbol for additive expression"); // GCOV_EXCL_LINE
}
assert(!seenFirstLess && !seenFirstGreater);

return concludeNode(shiftExprNode);
}
Expand Down
23 changes: 15 additions & 8 deletions src/ast/ASTNodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,21 @@ CompileTimeValue ShiftExprNode::getCompileTimeValue() const {
if (operands.size() == 1)
return operands.front()->getCompileTimeValue();

const CompileTimeValue op0Value = operands.at(0)->getCompileTimeValue();
const CompileTimeValue op1Value = operands.at(1)->getCompileTimeValue();
if (op == OP_SHIFT_LEFT)
return CompileTimeValue{.longValue = op0Value.longValue << op1Value.longValue};
if (op == OP_SHIFT_RIGHT)
return CompileTimeValue{.longValue = op0Value.longValue >> op1Value.longValue};

throw CompilerError(UNHANDLED_BRANCH, "ShiftExprNode::getCompileTimeValue()");
CompileTimeValue result = operands.front()->getCompileTimeValue();
OpQueue opQueueCopy = opQueue;
for (size_t i = 1; i < operands.size(); i++) {
assert(operands.at(i)->hasCompileTimeValue());
const CompileTimeValue opCompileTimeValue = operands.at(i)->getCompileTimeValue();
const ShiftOp op = opQueueCopy.front().first;
opQueueCopy.pop();
if (op == ShiftOp::OP_SHIFT_LEFT)
result.longValue <<= opCompileTimeValue.longValue;
else if (op == ShiftOp::OP_SHIFT_RIGHT)
result.longValue >>= opCompileTimeValue.longValue;
else
throw CompilerError(UNHANDLED_BRANCH, "ShiftExprNode::getCompileTimeValue()");
}
return result;
}

bool AdditiveExprNode::hasCompileTimeValue() const {
Expand Down
7 changes: 5 additions & 2 deletions src/ast/ASTNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1738,12 +1738,15 @@ class RelationalExprNode final : public ExprNode {
class ShiftExprNode final : public ExprNode {
public:
// Enums
enum ShiftOp : uint8_t {
enum class ShiftOp : uint8_t {
OP_NONE,
OP_SHIFT_LEFT,
OP_SHIFT_RIGHT,
};

// Typedefs
using OpQueue = std::queue<std::pair<ShiftOp, QualType>>;

// Constructors
using ExprNode::ExprNode;

Expand All @@ -1761,7 +1764,7 @@ class ShiftExprNode final : public ExprNode {

// Public members
std::vector<AdditiveExprNode *> operands;
ShiftOp op = OP_NONE;
OpQueue opQueue;
std::vector<std::vector<const Function *>> opFct; // Operator overloading functions
};

Expand Down
49 changes: 30 additions & 19 deletions src/irgenerator/GenExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,30 +453,41 @@ std::any IRGenerator::visitShiftExpr(const ShiftExprNode *node) {
return visit(node->operands.front());

// It is a shift expression
// Evaluate lhs
const AdditiveExprNode *lhsNode = node->operands[0];
const QualType lhsSTy = lhsNode->getEvaluatedSymbolType(manIdx);
auto result = std::any_cast<LLVMExprResult>(visit(lhsNode));
// Evaluate first operand
const AdditiveExprNode *lhsNode = node->operands.front();
QualType lhsSTy = lhsNode->getEvaluatedSymbolType(manIdx);
auto lhs = std::any_cast<LLVMExprResult>(visit(lhsNode));

// Evaluate rhs
const AdditiveExprNode *rhsNode = node->operands[1];
const QualType rhsSTy = rhsNode->getEvaluatedSymbolType(manIdx);
auto rhs = std::any_cast<LLVMExprResult>(visit(rhsNode));
auto opQueue = node->opQueue;
size_t operandIndex = 1;
while (!opQueue.empty()) {
const size_t operatorIndex = operandIndex - 1;
// Evaluate next operand
const AdditiveExprNode *rhsNode = node->operands[operandIndex++];
assert(rhsNode != nullptr);
const QualType rhsSTy = rhsNode->getEvaluatedSymbolType(manIdx);
auto rhs = std::any_cast<LLVMExprResult>(visit(rhsNode));

// Retrieve the result value, based on the exact operator
switch (node->op) {
case ShiftExprNode::OP_SHIFT_LEFT:
result = conversionManager.getShiftLeftInst(node, result, lhsSTy, rhs, rhsSTy, 0);
break;
case ShiftExprNode::OP_SHIFT_RIGHT:
result = conversionManager.getShiftRightInst(node, result, lhsSTy, rhs, rhsSTy, 0);
break;
default: // GCOV_EXCL_LINE
throw CompilerError(UNHANDLED_BRANCH, "ShiftExpr fall-through"); // GCOV_EXCL_LINE
// Retrieve the result, based on the exact operator
switch (opQueue.front().first) {
case ShiftExprNode::ShiftOp::OP_SHIFT_LEFT:
lhs = conversionManager.getShiftLeftInst(node, lhs, lhsSTy, rhs, rhsSTy, operatorIndex);
break;
case ShiftExprNode::ShiftOp::OP_SHIFT_RIGHT:
lhs = conversionManager.getShiftRightInst(node, lhs, lhsSTy, rhs, rhsSTy, operatorIndex);
break;
default: // GCOV_EXCL_LINE
throw CompilerError(UNHANDLED_BRANCH, "AdditiveExpr fall-through"); // GCOV_EXCL_LINE
}

// Retrieve the new lhs symbol type
lhsSTy = opQueue.front().second;

opQueue.pop();
}

// Return the result
return result;
return lhs;
}

std::any IRGenerator::visitAdditiveExpr(const AdditiveExprNode *node) {
Expand Down
55 changes: 32 additions & 23 deletions src/typechecker/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ std::any TypeChecker::visitLogicalOrExpr(LogicalOrExprNode *node) {
// Visit leftmost operand
auto currentOperand = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type)

// Loop through all remaining operands
for (size_t i = 1; i < node->operands.size(); i++) {
auto rhsOperand = std::any_cast<ExprResult>(visit(node->operands[i]));
Expand All @@ -1045,6 +1046,7 @@ std::any TypeChecker::visitLogicalAndExpr(LogicalAndExprNode *node) {
// Visit leftmost operand
auto currentOperand = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type)

// Loop through all remaining operands
for (size_t i = 1; i < node->operands.size(); i++) {
auto rhsOperand = std::any_cast<ExprResult>(visit(node->operands[i]));
Expand All @@ -1064,6 +1066,7 @@ std::any TypeChecker::visitBitwiseOrExpr(BitwiseOrExprNode *node) {
// Visit leftmost operand
auto currentOperand = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type)

// Loop through all remaining operands
for (size_t i = 1; i < node->operands.size(); i++) {
auto rhsOperand = std::any_cast<ExprResult>(visit(node->operands[i]));
Expand All @@ -1083,6 +1086,7 @@ std::any TypeChecker::visitBitwiseXorExpr(BitwiseXorExprNode *node) {
// Visit leftmost operand
auto currentOperand = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type)

// Loop through all remaining operands
for (size_t i = 1; i < node->operands.size(); i++) {
auto rhsOperand = std::any_cast<ExprResult>(visit(node->operands[i]));
Expand All @@ -1102,6 +1106,7 @@ std::any TypeChecker::visitBitwiseAndExpr(BitwiseAndExprNode *node) {
// Visit leftmost operand
auto currentOperand = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type)

// Loop through all remaining operands
for (size_t i = 1; i < node->operands.size(); i++) {
auto rhsOperand = std::any_cast<ExprResult>(visit(node->operands[i]));
Expand All @@ -1119,9 +1124,9 @@ std::any TypeChecker::visitEqualityExpr(EqualityExprNode *node) {
return visit(node->operands.front());

// Visit right side first, then left side
auto rhs = std::any_cast<ExprResult>(visit(node->operands[1]));
const auto rhs = std::any_cast<ExprResult>(visit(node->operands[1]));
HANDLE_UNRESOLVED_TYPE_ER(rhs.type)
auto lhs = std::any_cast<ExprResult>(visit(node->operands[0]));
const auto lhs = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(lhs.type)

// Check if we need the string runtime to perform a string comparison
Expand All @@ -1147,9 +1152,9 @@ std::any TypeChecker::visitRelationalExpr(RelationalExprNode *node) {
return visit(node->operands.front());

// Visit right side first, then left side
auto rhs = std::any_cast<ExprResult>(visit(node->operands[1]));
const auto rhs = std::any_cast<ExprResult>(visit(node->operands[1]));
HANDLE_UNRESOLVED_TYPE_ER(rhs.type)
auto lhs = std::any_cast<ExprResult>(visit(node->operands[0]));
const auto lhs = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(lhs.type)

// Check operator
Expand All @@ -1173,20 +1178,28 @@ std::any TypeChecker::visitShiftExpr(ShiftExprNode *node) {
if (node->operands.size() == 1)
return visit(node->operands.front());

// Visit right side first, then left
auto rhs = std::any_cast<ExprResult>(visit(node->operands[1]));
HANDLE_UNRESOLVED_TYPE_ER(rhs.type)
auto lhs = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(lhs.type)
// Visit leftmost operand
auto currentResult = std::any_cast<ExprResult>(visit(node->operands[0]));
HANDLE_UNRESOLVED_TYPE_ER(currentResult.type)

// Check operator
ExprResult currentResult;
if (node->op == ShiftExprNode::OP_SHIFT_LEFT) // Operator was shl
currentResult = opRuleManager.getShiftLeftResultType(node, lhs, rhs, 0);
else if (node->op == ShiftExprNode::OP_SHIFT_RIGHT) // Operator was shr
currentResult = opRuleManager.getShiftRightResultType(node, lhs, rhs, 0);
else
throw CompilerError(UNHANDLED_BRANCH, "ShiftExpr fall-through"); // GCOV_EXCL_LINE
// Loop through remaining operands
for (size_t i = 0; i < node->opQueue.size(); i++) {
auto operandResult = std::any_cast<ExprResult>(visit(node->operands[i + 1]));
HANDLE_UNRESOLVED_TYPE_ER(operandResult.type)

// Check operator
const ShiftExprNode::ShiftOp &op = node->opQueue.front().first;
if (op == ShiftExprNode::ShiftOp::OP_SHIFT_LEFT)
currentResult = opRuleManager.getShiftLeftResultType(node, currentResult, operandResult, i);
else if (op == ShiftExprNode::ShiftOp::OP_SHIFT_RIGHT)
currentResult = opRuleManager.getShiftRightResultType(node, currentResult, operandResult, i);
else
throw CompilerError(UNHANDLED_BRANCH, "ShiftExpr fall-through"); // GCOV_EXCL_LINE

// Push the new item and pop the old one on the other side of the queue
node->opQueue.emplace(op, currentResult.type);
node->opQueue.pop();
}

node->setEvaluatedSymbolType(currentResult.type, manIdx);
return currentResult;
Expand All @@ -1203,9 +1216,7 @@ std::any TypeChecker::visitAdditiveExpr(AdditiveExprNode *node) {

// Loop through remaining operands
for (size_t i = 0; i < node->opQueue.size(); i++) {
// Visit next operand
MultiplicativeExprNode *operand = node->operands[i + 1];
auto operandResult = std::any_cast<ExprResult>(visit(operand));
auto operandResult = std::any_cast<ExprResult>(visit(node->operands[i + 1]));
HANDLE_UNRESOLVED_TYPE_ER(operandResult.type)

// Check operator
Expand Down Expand Up @@ -1236,9 +1247,7 @@ std::any TypeChecker::visitMultiplicativeExpr(MultiplicativeExprNode *node) {
HANDLE_UNRESOLVED_TYPE_ER(currentResult.type)
// Loop through remaining operands
for (size_t i = 0; i < node->opQueue.size(); i++) {
// Visit next operand
CastExprNode *operand = node->operands[i + 1];
auto operandResult = std::any_cast<ExprResult>(visit(operand));
auto operandResult = std::any_cast<ExprResult>(visit(node->operands[i + 1]));
HANDLE_UNRESOLVED_TYPE_ER(operandResult.type)

// Check operator
Expand Down
2 changes: 1 addition & 1 deletion test/test-files/std/test/lifetime-object/source.spice
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ f<LifetimeObject> decideOnLORef(bool cond, LifetimeObject& lo1, LifetimeObject&
f<LifetimeObject> decideOnLOVal(bool cond, LifetimeObject lo1, LifetimeObject lo2) {
cond ? lo1 : lo2; // Shoud do a copy, although the result is ignored
LifetimeObject loCopy = cond ? lo1 : lo2; // Shoud do a copy
const LifetimeObject& loRef = cond ? lo1 : lo2; // Should not do a copy
const LifetimeObject& loRef = cond ? lo1 : lo2; // Should do a copy
return cond ? lo1 : lo2; // Return statement should do a copy
}

Expand Down

0 comments on commit 0d130fc

Please sign in to comment.