diff --git a/src/Spice.g4 b/src/Spice.g4 index 14dc2b8ef..3616aa86c 100644 --- a/src/Spice.g4 +++ b/src/Spice.g4 @@ -78,7 +78,7 @@ bitwiseXorExpr: bitwiseAndExpr (BITWISE_XOR bitwiseAndExpr)*; bitwiseAndExpr: equalityExpr (BITWISE_AND equalityExpr)*; equalityExpr: relationalExpr ((EQUAL | NOT_EQUAL) relationalExpr)?; relationalExpr: shiftExpr ((LESS | GREATER | LESS_EQUAL | GREATER_EQUAL) shiftExpr)?; -shiftExpr: additiveExpr ((LESS LESS | GREATER GREATER) additiveExpr)?; +shiftExpr: additiveExpr ((LESS LESS | GREATER GREATER) additiveExpr)*; additiveExpr: multiplicativeExpr ((PLUS | MINUS) multiplicativeExpr)*; multiplicativeExpr: castExpr ((MUL | DIV | REM) castExpr)*; castExpr: (LPAREN dataType RPAREN)? prefixUnaryExpr; diff --git a/src/ast/ASTBuilder.cpp b/src/ast/ASTBuilder.cpp index 04ecd998a..60f9aa5fd 100644 --- a/src/ast/ASTBuilder.cpp +++ b/src/ast/ASTBuilder.cpp @@ -1076,11 +1076,30 @@ std::any ASTBuilder::visitShiftExpr(SpiceParser::ShiftExprContext *ctx) { // Visit children fetchChildrenIntoVector(shiftExprNode->operands, ctx->additiveExpr()); - // Extract operator - if (!ctx->LESS().empty()) - shiftExprNode->op = ShiftExprNode::OP_SHIFT_LEFT; - else if (!ctx->GREATER().empty()) - shiftExprNode->op = ShiftExprNode::OP_SHIFT_RIGHT; + bool seenFirstLess = false; + bool seenFirstGreater = false; + for (ParserRuleContext::ParseTree *subTree : ctx->children) { + const auto terminal = dynamic_cast(subTree); + if (!terminal) + continue; + + if (terminal->getSymbol()->getType() == SpiceParser::LESS) { + if (seenFirstLess) + shiftExprNode->opQueue.emplace(ShiftExprNode::ShiftOp::OP_SHIFT_LEFT, TY_INVALID); + seenFirstLess = !seenFirstLess; + continue; + } + + if (terminal->getSymbol()->getType() == SpiceParser::GREATER) { + if (seenFirstGreater) + shiftExprNode->opQueue.emplace(ShiftExprNode::ShiftOp::OP_SHIFT_RIGHT, TY_INVALID); + seenFirstGreater = !seenFirstGreater; + continue; + } + + assert_fail("Invalid terminal symbol for additive expression"); // GCOV_EXCL_LINE + } + assert(!seenFirstLess && !seenFirstGreater); return concludeNode(shiftExprNode); } diff --git a/src/ast/ASTNodes.cpp b/src/ast/ASTNodes.cpp index 95c8a3e71..783db2fe3 100644 --- a/src/ast/ASTNodes.cpp +++ b/src/ast/ASTNodes.cpp @@ -368,14 +368,21 @@ CompileTimeValue ShiftExprNode::getCompileTimeValue() const { if (operands.size() == 1) return operands.front()->getCompileTimeValue(); - const CompileTimeValue op0Value = operands.at(0)->getCompileTimeValue(); - const CompileTimeValue op1Value = operands.at(1)->getCompileTimeValue(); - if (op == OP_SHIFT_LEFT) - return CompileTimeValue{.longValue = op0Value.longValue << op1Value.longValue}; - if (op == OP_SHIFT_RIGHT) - return CompileTimeValue{.longValue = op0Value.longValue >> op1Value.longValue}; - - throw CompilerError(UNHANDLED_BRANCH, "ShiftExprNode::getCompileTimeValue()"); + CompileTimeValue result = operands.front()->getCompileTimeValue(); + OpQueue opQueueCopy = opQueue; + for (size_t i = 1; i < operands.size(); i++) { + assert(operands.at(i)->hasCompileTimeValue()); + const CompileTimeValue opCompileTimeValue = operands.at(i)->getCompileTimeValue(); + const ShiftOp op = opQueueCopy.front().first; + opQueueCopy.pop(); + if (op == ShiftOp::OP_SHIFT_LEFT) + result.longValue <<= opCompileTimeValue.longValue; + else if (op == ShiftOp::OP_SHIFT_RIGHT) + result.longValue >>= opCompileTimeValue.longValue; + else + throw CompilerError(UNHANDLED_BRANCH, "ShiftExprNode::getCompileTimeValue()"); + } + return result; } bool AdditiveExprNode::hasCompileTimeValue() const { diff --git a/src/ast/ASTNodes.h b/src/ast/ASTNodes.h index 381c6bf4e..22cbc2b4f 100644 --- a/src/ast/ASTNodes.h +++ b/src/ast/ASTNodes.h @@ -1738,12 +1738,15 @@ class RelationalExprNode final : public ExprNode { class ShiftExprNode final : public ExprNode { public: // Enums - enum ShiftOp : uint8_t { + enum class ShiftOp : uint8_t { OP_NONE, OP_SHIFT_LEFT, OP_SHIFT_RIGHT, }; + // Typedefs + using OpQueue = std::queue>; + // Constructors using ExprNode::ExprNode; @@ -1761,7 +1764,7 @@ class ShiftExprNode final : public ExprNode { // Public members std::vector operands; - ShiftOp op = OP_NONE; + OpQueue opQueue; std::vector> opFct; // Operator overloading functions }; diff --git a/src/irgenerator/GenExpressions.cpp b/src/irgenerator/GenExpressions.cpp index b4a751269..cb2ff6025 100644 --- a/src/irgenerator/GenExpressions.cpp +++ b/src/irgenerator/GenExpressions.cpp @@ -453,30 +453,41 @@ std::any IRGenerator::visitShiftExpr(const ShiftExprNode *node) { return visit(node->operands.front()); // It is a shift expression - // Evaluate lhs - const AdditiveExprNode *lhsNode = node->operands[0]; - const QualType lhsSTy = lhsNode->getEvaluatedSymbolType(manIdx); - auto result = std::any_cast(visit(lhsNode)); + // Evaluate first operand + const AdditiveExprNode *lhsNode = node->operands.front(); + QualType lhsSTy = lhsNode->getEvaluatedSymbolType(manIdx); + auto lhs = std::any_cast(visit(lhsNode)); - // Evaluate rhs - const AdditiveExprNode *rhsNode = node->operands[1]; - const QualType rhsSTy = rhsNode->getEvaluatedSymbolType(manIdx); - auto rhs = std::any_cast(visit(rhsNode)); + auto opQueue = node->opQueue; + size_t operandIndex = 1; + while (!opQueue.empty()) { + const size_t operatorIndex = operandIndex - 1; + // Evaluate next operand + const AdditiveExprNode *rhsNode = node->operands[operandIndex++]; + assert(rhsNode != nullptr); + const QualType rhsSTy = rhsNode->getEvaluatedSymbolType(manIdx); + auto rhs = std::any_cast(visit(rhsNode)); - // Retrieve the result value, based on the exact operator - switch (node->op) { - case ShiftExprNode::OP_SHIFT_LEFT: - result = conversionManager.getShiftLeftInst(node, result, lhsSTy, rhs, rhsSTy, 0); - break; - case ShiftExprNode::OP_SHIFT_RIGHT: - result = conversionManager.getShiftRightInst(node, result, lhsSTy, rhs, rhsSTy, 0); - break; - default: // GCOV_EXCL_LINE - throw CompilerError(UNHANDLED_BRANCH, "ShiftExpr fall-through"); // GCOV_EXCL_LINE + // Retrieve the result, based on the exact operator + switch (opQueue.front().first) { + case ShiftExprNode::ShiftOp::OP_SHIFT_LEFT: + lhs = conversionManager.getShiftLeftInst(node, lhs, lhsSTy, rhs, rhsSTy, operatorIndex); + break; + case ShiftExprNode::ShiftOp::OP_SHIFT_RIGHT: + lhs = conversionManager.getShiftRightInst(node, lhs, lhsSTy, rhs, rhsSTy, operatorIndex); + break; + default: // GCOV_EXCL_LINE + throw CompilerError(UNHANDLED_BRANCH, "AdditiveExpr fall-through"); // GCOV_EXCL_LINE + } + + // Retrieve the new lhs symbol type + lhsSTy = opQueue.front().second; + + opQueue.pop(); } // Return the result - return result; + return lhs; } std::any IRGenerator::visitAdditiveExpr(const AdditiveExprNode *node) { diff --git a/src/typechecker/TypeChecker.cpp b/src/typechecker/TypeChecker.cpp index c93df02db..9b549c5d3 100644 --- a/src/typechecker/TypeChecker.cpp +++ b/src/typechecker/TypeChecker.cpp @@ -1026,6 +1026,7 @@ std::any TypeChecker::visitLogicalOrExpr(LogicalOrExprNode *node) { // Visit leftmost operand auto currentOperand = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type) + // Loop through all remaining operands for (size_t i = 1; i < node->operands.size(); i++) { auto rhsOperand = std::any_cast(visit(node->operands[i])); @@ -1045,6 +1046,7 @@ std::any TypeChecker::visitLogicalAndExpr(LogicalAndExprNode *node) { // Visit leftmost operand auto currentOperand = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type) + // Loop through all remaining operands for (size_t i = 1; i < node->operands.size(); i++) { auto rhsOperand = std::any_cast(visit(node->operands[i])); @@ -1064,6 +1066,7 @@ std::any TypeChecker::visitBitwiseOrExpr(BitwiseOrExprNode *node) { // Visit leftmost operand auto currentOperand = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type) + // Loop through all remaining operands for (size_t i = 1; i < node->operands.size(); i++) { auto rhsOperand = std::any_cast(visit(node->operands[i])); @@ -1083,6 +1086,7 @@ std::any TypeChecker::visitBitwiseXorExpr(BitwiseXorExprNode *node) { // Visit leftmost operand auto currentOperand = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type) + // Loop through all remaining operands for (size_t i = 1; i < node->operands.size(); i++) { auto rhsOperand = std::any_cast(visit(node->operands[i])); @@ -1102,6 +1106,7 @@ std::any TypeChecker::visitBitwiseAndExpr(BitwiseAndExprNode *node) { // Visit leftmost operand auto currentOperand = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(currentOperand.type) + // Loop through all remaining operands for (size_t i = 1; i < node->operands.size(); i++) { auto rhsOperand = std::any_cast(visit(node->operands[i])); @@ -1119,9 +1124,9 @@ std::any TypeChecker::visitEqualityExpr(EqualityExprNode *node) { return visit(node->operands.front()); // Visit right side first, then left side - auto rhs = std::any_cast(visit(node->operands[1])); + const auto rhs = std::any_cast(visit(node->operands[1])); HANDLE_UNRESOLVED_TYPE_ER(rhs.type) - auto lhs = std::any_cast(visit(node->operands[0])); + const auto lhs = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(lhs.type) // Check if we need the string runtime to perform a string comparison @@ -1147,9 +1152,9 @@ std::any TypeChecker::visitRelationalExpr(RelationalExprNode *node) { return visit(node->operands.front()); // Visit right side first, then left side - auto rhs = std::any_cast(visit(node->operands[1])); + const auto rhs = std::any_cast(visit(node->operands[1])); HANDLE_UNRESOLVED_TYPE_ER(rhs.type) - auto lhs = std::any_cast(visit(node->operands[0])); + const auto lhs = std::any_cast(visit(node->operands[0])); HANDLE_UNRESOLVED_TYPE_ER(lhs.type) // Check operator @@ -1173,20 +1178,28 @@ std::any TypeChecker::visitShiftExpr(ShiftExprNode *node) { if (node->operands.size() == 1) return visit(node->operands.front()); - // Visit right side first, then left - auto rhs = std::any_cast(visit(node->operands[1])); - HANDLE_UNRESOLVED_TYPE_ER(rhs.type) - auto lhs = std::any_cast(visit(node->operands[0])); - HANDLE_UNRESOLVED_TYPE_ER(lhs.type) + // Visit leftmost operand + auto currentResult = std::any_cast(visit(node->operands[0])); + HANDLE_UNRESOLVED_TYPE_ER(currentResult.type) - // Check operator - ExprResult currentResult; - if (node->op == ShiftExprNode::OP_SHIFT_LEFT) // Operator was shl - currentResult = opRuleManager.getShiftLeftResultType(node, lhs, rhs, 0); - else if (node->op == ShiftExprNode::OP_SHIFT_RIGHT) // Operator was shr - currentResult = opRuleManager.getShiftRightResultType(node, lhs, rhs, 0); - else - throw CompilerError(UNHANDLED_BRANCH, "ShiftExpr fall-through"); // GCOV_EXCL_LINE + // Loop through remaining operands + for (size_t i = 0; i < node->opQueue.size(); i++) { + auto operandResult = std::any_cast(visit(node->operands[i + 1])); + HANDLE_UNRESOLVED_TYPE_ER(operandResult.type) + + // Check operator + const ShiftExprNode::ShiftOp &op = node->opQueue.front().first; + if (op == ShiftExprNode::ShiftOp::OP_SHIFT_LEFT) + currentResult = opRuleManager.getShiftLeftResultType(node, currentResult, operandResult, i); + else if (op == ShiftExprNode::ShiftOp::OP_SHIFT_RIGHT) + currentResult = opRuleManager.getShiftRightResultType(node, currentResult, operandResult, i); + else + throw CompilerError(UNHANDLED_BRANCH, "ShiftExpr fall-through"); // GCOV_EXCL_LINE + + // Push the new item and pop the old one on the other side of the queue + node->opQueue.emplace(op, currentResult.type); + node->opQueue.pop(); + } node->setEvaluatedSymbolType(currentResult.type, manIdx); return currentResult; @@ -1203,9 +1216,7 @@ std::any TypeChecker::visitAdditiveExpr(AdditiveExprNode *node) { // Loop through remaining operands for (size_t i = 0; i < node->opQueue.size(); i++) { - // Visit next operand - MultiplicativeExprNode *operand = node->operands[i + 1]; - auto operandResult = std::any_cast(visit(operand)); + auto operandResult = std::any_cast(visit(node->operands[i + 1])); HANDLE_UNRESOLVED_TYPE_ER(operandResult.type) // Check operator @@ -1236,9 +1247,7 @@ std::any TypeChecker::visitMultiplicativeExpr(MultiplicativeExprNode *node) { HANDLE_UNRESOLVED_TYPE_ER(currentResult.type) // Loop through remaining operands for (size_t i = 0; i < node->opQueue.size(); i++) { - // Visit next operand - CastExprNode *operand = node->operands[i + 1]; - auto operandResult = std::any_cast(visit(operand)); + auto operandResult = std::any_cast(visit(node->operands[i + 1])); HANDLE_UNRESOLVED_TYPE_ER(operandResult.type) // Check operator diff --git a/test/test-files/std/test/lifetime-object/source.spice b/test/test-files/std/test/lifetime-object/source.spice index c06e97865..88e5271c2 100644 --- a/test/test-files/std/test/lifetime-object/source.spice +++ b/test/test-files/std/test/lifetime-object/source.spice @@ -21,7 +21,7 @@ f decideOnLORef(bool cond, LifetimeObject& lo1, LifetimeObject& f decideOnLOVal(bool cond, LifetimeObject lo1, LifetimeObject lo2) { cond ? lo1 : lo2; // Shoud do a copy, although the result is ignored LifetimeObject loCopy = cond ? lo1 : lo2; // Shoud do a copy - const LifetimeObject& loRef = cond ? lo1 : lo2; // Should not do a copy + const LifetimeObject& loRef = cond ? lo1 : lo2; // Should do a copy return cond ? lo1 : lo2; // Return statement should do a copy }