Skip to content

Commit

Permalink
Add support for overloading subscript operator (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer authored Jan 12, 2025
1 parent 1afd140 commit 75d3241
Show file tree
Hide file tree
Showing 17 changed files with 280 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .run/spice run.run.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="spice run" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="run -O2 -d ../../src-bootstrap/main.spice" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" EMULATE_TERMINAL="false" PASS_PARENT_ENVS_2="true" PROJECT_NAME="Spice" TARGET_NAME="spice" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="Spice" RUN_TARGET_NAME="spice">
<configuration default="false" name="spice run" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="run -O2 -d ../../media/test-project/test.spice" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" EMULATE_TERMINAL="false" PASS_PARENT_ENVS_2="true" PROJECT_NAME="Spice" TARGET_NAME="spice" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="Spice" RUN_TARGET_NAME="spice">
<envs>
<env name="LLVM_ADDITIONAL_FLAGS" value="-lole32 -lws2_32" />
<env name="LLVM_BUILD_INCLUDE_DIR" value="$PROJECT_DIR$/../llvm-project-latest/build/include" />
Expand Down
90 changes: 86 additions & 4 deletions media/test-project/test.spice
Original file line number Diff line number Diff line change
@@ -1,6 +1,88 @@
type Size alias long;

type Counter struct {
Size value
}

p Counter.ctor(long initialValue = 0l) {
this.value = initialValue;
}

f<Size> Counter.getValue() {
return this.value;
}

f<Counter> operator+(const Counter c1, const Counter c2) {
return Counter(c1.value + c2.value);
}

f<Counter> operator-(const Counter c1, const Counter c2) {
return Counter(c1.value - c2.value);
}

f<Counter> operator*(const Counter c1, const Counter c2) {
return Counter(c1.value * c2.value);
}

f<Counter> operator/(const Counter c1, const Counter c2) {
return Counter(c1.value / c2.value);
}

f<Counter> operator<<(const Counter c1, const Counter c2) {
return Counter(c1.value << c2.value);
}

f<Counter> operator>>(const Counter c1, const Counter c2) {
return Counter(c1.value >> c2.value);
}

p operator+=(Counter& c1, const Counter c2) {
c1.value += c2.value;
}

p operator-=(Counter& c1, const Counter c2) {
c1.value -= c2.value;
}

p operator*=(Counter& c1, const Counter c2) {
c1.value *= c2.value;
}

p operator/=(Counter& c1, const Counter c2) {
c1.value /= c2.value;
}

f<Size&> operator[](Counter& c, unsigned int summand) {
c.value += summand;
return c.value;
}

f<int> main() {
String test = String("String to be trimmed ");
printf("'%s'\n", test);
String trimmed = test.trim();
printf("'%s'\n", trimmed);
Counter counter1 = Counter(2l);
Counter counter2 = Counter(3l);
printf("Counter1 value: %d\n", counter1.getValue());
printf("Counter2 value: %d\n", counter2.getValue());
Counter counter3 = counter1 + counter2; // Here we call the overloaded operator
printf("Counter3 value: %d\n", counter3.getValue());
Counter counter4 = counter3 - counter2; // Here we call the overloaded operator
printf("Counter4 value: %d\n", counter4.getValue());
Counter counter5 = counter4 * counter2; // Here we call the overloaded operator
printf("Counter5 value: %d\n", counter5.getValue());
Counter counter6 = counter5 / counter2; // Here we call the overloaded operator
printf("Counter6 value: %d\n", counter6.getValue());
Counter counter7 = counter6 << counter2; // Here we call the overloaded operator
printf("Counter7 value: %d\n", counter7.getValue());
Counter counter8 = counter7 >> counter2; // Here we call the overloaded operator
printf("Counter8 value: %d\n", counter8.getValue());
counter8 += counter2; // Here we call the overloaded operator
printf("Counter8 value: %d\n", counter8.getValue());
counter8 -= counter2; // Here we call the overloaded operator
printf("Counter8 value: %d\n", counter8.getValue());
counter8 *= counter2; // Here we call the overloaded operator
printf("Counter8 value: %d\n", counter8.getValue());
counter8 /= counter2; // Here we call the overloaded operator
printf("Counter8 value: %d\n", counter8.getValue());
Size res = counter8[12];
assert res == 14;
printf("Counter8 value: %d\n", counter8.getValue());
}
2 changes: 1 addition & 1 deletion src/Spice.g4
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ functionDataType: (F LESS dataType GREATER | P) LPAREN typeLst? RPAREN;

// Shorthands
assignOp: ASSIGN | PLUS_EQUAL | MINUS_EQUAL | MUL_EQUAL | DIV_EQUAL | REM_EQUAL | SHL_EQUAL | SHR_EQUAL | AND_EQUAL | OR_EQUAL | XOR_EQUAL;
overloadableOp: PLUS | MINUS | MUL | DIV | EQUAL | NOT_EQUAL | LESS LESS | GREATER GREATER | PLUS_EQUAL | MINUS_EQUAL | MUL_EQUAL | DIV_EQUAL | PLUS_PLUS | MINUS_MINUS;
overloadableOp: PLUS | MINUS | MUL | DIV | EQUAL | NOT_EQUAL | LESS LESS | GREATER GREATER | PLUS_EQUAL | MINUS_EQUAL | MUL_EQUAL | DIV_EQUAL | PLUS_PLUS | MINUS_MINUS | LBRACKET RBRACKET;

// Keyword tokens
TYPE_DOUBLE: 'double';
Expand Down
2 changes: 2 additions & 0 deletions src/ast/ASTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,8 @@ std::any ASTBuilder::visitOverloadableOp(SpiceParser::OverloadableOpContext *ctx
fctNameNode->name = OP_FCT_POSTFIX_PLUS_PLUS;
else if (ctx->MINUS_MINUS())
fctNameNode->name = OP_FCT_POSTFIX_MINUS_MINUS;
else if (ctx->LBRACKET())
fctNameNode->name = OP_FCT_SUBSCRIPT;
else
assert_fail("Unsupported overloadable operator"); // GCOV_EXCL_LINE

Expand Down
5 changes: 3 additions & 2 deletions src/ast/ASTNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ constexpr const char *const OP_FCT_MUL_EQUAL = "op.mulequal";
constexpr const char *const OP_FCT_DIV_EQUAL = "op.divequal";
constexpr const char *const OP_FCT_POSTFIX_PLUS_PLUS = "op.plusplus.post";
constexpr const char *const OP_FCT_POSTFIX_MINUS_MINUS = "op.minusminus.post";
constexpr const char *const OP_FCT_SUBSCRIPT = "op.subscript";

/**
* Saves a constant value for an AST node to realize features like array-out-of-bounds checks
Expand Down Expand Up @@ -72,7 +73,7 @@ class ASTNode {
virtual std::any accept(AbstractASTVisitor *visitor) = 0;
virtual std::any accept(ParallelizableASTVisitor *visitor) const = 0;

template <typename... Args> ALWAYS_INLINE std::vector<ASTNode *> collectChildren(Args &&...args) const {
template <typename... Args> [[nodiscard]] ALWAYS_INLINE std::vector<ASTNode *> collectChildren(Args &&...args) const {
std::vector<ASTNode *> children;

// Lambda to handle each argument
Expand All @@ -92,7 +93,7 @@ class ASTNode {
return children;
}

virtual std::vector<ASTNode *> getChildren() const = 0;
[[nodiscard]] virtual std::vector<ASTNode *> getChildren() const = 0;

void resizeToNumberOfManifestations(const size_t manifestationCount) { // NOLINT(misc-no-recursion)
// Resize children
Expand Down
15 changes: 13 additions & 2 deletions src/irgenerator/GenExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ std::any IRGenerator::visitAssignExpr(const AssignExprNode *node) {
lhs.value = result.value;
insertStore(lhs.value, lhs.ptr, lhs.entry && lhs.entry->isVolatile);
}
return LLVMExprResult{.value = lhs.value, .ptr = lhs.ptr, .refPtr = lhs.refPtr, .entry = lhs.entry};
return lhs;
}

// This is a fallthrough case -> throw an error
Expand Down Expand Up @@ -660,10 +660,21 @@ std::any IRGenerator::visitPostfixUnaryExpr(const PostfixUnaryExprNode *node) {

switch (node->op) {
case PostfixUnaryExprNode::OP_SUBSCRIPT: {
const AssignExprNode *indexExpr = node->subscriptIndexExpr;

// Check if we need to generate a call to an overloaded operator function
if (conversionManager.callsOverloadedOpFct(node, 0)) {
ResolverFct lhsV = [&] { return resolveValue(lhsSTy, lhs); };
ResolverFct lhsP = [&] { return resolveAddress(lhs); };
ResolverFct idxV = [&] { return resolveValue(indexExpr); };
ResolverFct idxP = [&] { return nullptr; };
lhs = conversionManager.callOperatorOverloadFct<2>(node, {lhsV, lhsP, idxV, idxP}, 0);
break;
}

lhsSTy = lhsSTy.removeReferenceWrapper();

// Get the index value
const AssignExprNode *indexExpr = node->subscriptIndexExpr;
llvm::Value *indexValue = resolveValue(indexExpr);
// Come up with the address
if (lhsSTy.isArray() && lhsSTy.getArraySize() != ARRAY_SIZE_UNKNOWN) { // Array
Expand Down
15 changes: 8 additions & 7 deletions src/irgenerator/OpRuleConversionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ LLVMExprResult OpRuleConversionManager::getDivInst(const ASTNode *node, LLVMExpr
}

LLVMExprResult OpRuleConversionManager::getRemInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy, LLVMExprResult &rhs,
QualType rhsSTy) {
QualType rhsSTy) const {
ResolverFct lhsV = [&] { return irGenerator->resolveValue(lhsSTy, lhs); };
ResolverFct rhsV = [&] { return irGenerator->resolveValue(rhsSTy, rhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();
Expand Down Expand Up @@ -1464,7 +1464,7 @@ LLVMExprResult OpRuleConversionManager::getRemInst(const ASTNode *node, LLVMExpr
}
}

LLVMExprResult OpRuleConversionManager::getPrefixMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) {
LLVMExprResult OpRuleConversionManager::getPrefixMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const {
ResolverFct lhsV = [&] { return irGenerator->resolveValue(lhsSTy, lhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();

Expand All @@ -1481,7 +1481,7 @@ LLVMExprResult OpRuleConversionManager::getPrefixMinusInst(const ASTNode *node,
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: -"); // GCOV_EXCL_LINE
}

LLVMExprResult OpRuleConversionManager::getPrefixPlusPlusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) {
LLVMExprResult OpRuleConversionManager::getPrefixPlusPlusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const {
ResolverFct lhsV = [&] { return irGenerator->resolveValue(lhsSTy, lhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();

Expand All @@ -1502,7 +1502,7 @@ LLVMExprResult OpRuleConversionManager::getPrefixPlusPlusInst(const ASTNode *nod
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: ++ (prefix)"); // GCOV_EXCL_LINE
}

LLVMExprResult OpRuleConversionManager::getPrefixMinusMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) {
LLVMExprResult OpRuleConversionManager::getPrefixMinusMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const {
ResolverFct lhsV = [&] { return irGenerator->resolveValue(lhsSTy, lhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();

Expand All @@ -1523,7 +1523,7 @@ LLVMExprResult OpRuleConversionManager::getPrefixMinusMinusInst(const ASTNode *n
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: -- (prefix)"); // GCOV_EXCL_LINE
}

LLVMExprResult OpRuleConversionManager::getPrefixNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) {
LLVMExprResult OpRuleConversionManager::getPrefixNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const {
ResolverFct lhsV = [&] { return irGenerator->resolveValue(lhsSTy, lhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();

Expand All @@ -1536,7 +1536,7 @@ LLVMExprResult OpRuleConversionManager::getPrefixNotInst(const ASTNode *node, LL
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: !"); // GCOV_EXCL_LINE
}

LLVMExprResult OpRuleConversionManager::getPrefixBitwiseNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) {
LLVMExprResult OpRuleConversionManager::getPrefixBitwiseNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const {
ResolverFct lhsV = [&] { return irGenerator->resolveValue(lhsSTy, lhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();

Expand Down Expand Up @@ -1605,7 +1605,7 @@ LLVMExprResult OpRuleConversionManager::getPostfixMinusMinusInst(const ASTNode *
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: -- (postfix)"); // GCOV_EXCL_LINE
}

LLVMExprResult OpRuleConversionManager::getCastInst(const ASTNode *node, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy) {
LLVMExprResult OpRuleConversionManager::getCastInst(const ASTNode *node, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy) const {
ResolverFct rhsV = [&] { return irGenerator->resolveValue(rhsSTy, rhs); };
lhsSTy = lhsSTy.removeReferenceWrapper();
rhsSTy = rhsSTy.removeReferenceWrapper();
Expand Down Expand Up @@ -1679,6 +1679,7 @@ bool OpRuleConversionManager::callsOverloadedOpFct(const ASTNode *node, size_t o
template <size_t N>
LLVMExprResult OpRuleConversionManager::callOperatorOverloadFct(const ASTNode *node, const std::array<ResolverFct, N * 2> &opV,
size_t opIdx) {
static_assert(N == 1 || N == 2, "Only unary and binary operators are overloadable");
const size_t manIdx = irGenerator->manIdx;
const std::vector<std::vector<const Function *>> *opFctPointers = node->getOpFctPointers();
assert(!opFctPointers->empty() && opFctPointers->size() > manIdx);
Expand Down
20 changes: 10 additions & 10 deletions src/irgenerator/OpRuleConversionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,20 @@ class OpRuleConversionManager {
size_t opIdx);
LLVMExprResult getDivInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy,
size_t opIdx);
LLVMExprResult getRemInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy);
LLVMExprResult getPrefixMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy);
LLVMExprResult getPrefixPlusPlusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy);
LLVMExprResult getPrefixMinusMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy);
LLVMExprResult getPrefixNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy);
LLVMExprResult getPrefixBitwiseNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy);
LLVMExprResult getRemInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy) const;
LLVMExprResult getPrefixMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const;
LLVMExprResult getPrefixPlusPlusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const;
LLVMExprResult getPrefixMinusMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const;
LLVMExprResult getPrefixNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const;
LLVMExprResult getPrefixBitwiseNotInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy) const;
LLVMExprResult getPostfixPlusPlusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy, size_t opIdx);
LLVMExprResult getPostfixMinusMinusInst(const ASTNode *node, LLVMExprResult &lhs, QualType lhsSTy, size_t opIdx);
LLVMExprResult getCastInst(const ASTNode *node, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy);
LLVMExprResult getCastInst(const ASTNode *node, QualType lhsSTy, LLVMExprResult &rhs, QualType rhsSTy) const;

// Util methods
// Operator overloading
bool callsOverloadedOpFct(const ASTNode *node, size_t opIdx) const;
template <size_t N>
LLVMExprResult callOperatorOverloadFct(const ASTNode *node, const std::array<ResolverFct, N * 2> &opV, size_t opIdx);

private:
// Members
Expand All @@ -94,8 +96,6 @@ class OpRuleConversionManager {
const StdFunctionManager &stdFunctionManager;

// Private methods
template <size_t N>
LLVMExprResult callOperatorOverloadFct(const ASTNode *node, const std::array<ResolverFct, N * 2> &opV, size_t opIdx);
[[nodiscard]] llvm::Value *generateIToFp(const QualType &srcSTy, llvm::Value *srcV, llvm::Type *tgtT) const;
[[nodiscard]] llvm::Value *generateSHR(const QualType &lhsSTy, const QualType &rhsSTy, llvm::Value *lhsV,
llvm::Value *rhsV) const;
Expand Down
1 change: 1 addition & 0 deletions src/typechecker/OpRuleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ QualType OpRuleManager::getCastResultType(const ASTNode *node, QualType lhsType,
template <size_t N>
ExprResult OpRuleManager::isOperatorOverloadingFctAvailable(ASTNode *node, const char *const fctName,
const std::array<ExprResult, N> &op, size_t opIdx) {
static_assert(N == 1 || N == 2, "Only unary and binary operators are overloadable");
Scope *calleeParentScope = nullptr;
const Function *callee = nullptr;
for (const auto &sourceFile : typeChecker->resourceManager.sourceFiles | std::views::values) {
Expand Down
9 changes: 6 additions & 3 deletions src/typechecker/OpRuleManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,11 @@ class OpRuleManager {
ExprResult getPostfixMinusMinusResultType(ASTNode *node, const ExprResult &lhs, size_t opIdx);
QualType getCastResultType(const ASTNode *node, QualType lhsType, const ExprResult &rhs) const;

// Operator overloading
template <size_t N>
ExprResult isOperatorOverloadingFctAvailable(ASTNode *node, const char *fctName, const std::array<ExprResult, N> &op,
size_t opIdx);

private:
// Members
TypeChecker *typeChecker;
Expand All @@ -645,9 +650,7 @@ class OpRuleManager {
// Private methods
static QualType getAssignResultTypeCommon(const ASTNode *node, const ExprResult &lhs, const ExprResult &rhs, bool isDecl,
bool isReturn);
template <size_t N>
ExprResult isOperatorOverloadingFctAvailable(ASTNode *node, const char *fctName, const std::array<ExprResult, N> &op,
size_t opIdx);

static QualType validateUnaryOperation(const ASTNode *node, const UnaryOpRule opRules[], size_t opRulesSize, const char *name,
const QualType &lhs);
static QualType validateBinaryOperation(const ASTNode *node, const BinaryOpRule opRules[], size_t opRulesSize, const char *name,
Expand Down
25 changes: 16 additions & 9 deletions src/typechecker/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,21 +1341,28 @@ std::any TypeChecker::visitPostfixUnaryExpr(PostfixUnaryExprNode *node) {

switch (node->op) {
case PostfixUnaryExprNode::OP_SUBSCRIPT: {
// Visit index assignment
AssignExprNode *indexAssignExpr = node->subscriptIndexExpr;
const auto index = std::any_cast<ExprResult>(visit(indexAssignExpr));
HANDLE_UNRESOLVED_TYPE_ER(index.type)
// Check if the index is of the right type
if (!index.type.isOneOf({TY_INT, TY_LONG}))
SOFT_ERROR_ER(node, ARRAY_INDEX_NOT_INT_OR_LONG, "Array index must be of type int or long")

// Check is there is an overloaded operator function available, if yes accept it
const auto [type, _] = opRuleManager.isOperatorOverloadingFctAvailable<2>(node, OP_FCT_SUBSCRIPT, {operand, index}, 0);
if (!type.is(TY_INVALID)) {
operandType = type;
break;
}

operandType = operandType.removeReferenceWrapper();

// Check if we can apply the subscript operator on the lhs type
if (!operandType.isOneOf({TY_ARRAY, TY_STRING, TY_PTR}))
SOFT_ERROR_ER(node, OPERATOR_WRONG_DATA_TYPE,
"Can only apply subscript operator on array type, got " + operandType.getName(true))

// Visit index assignment
AssignExprNode *indexAssignExpr = node->subscriptIndexExpr;
QualType indexType = std::any_cast<ExprResult>(visit(indexAssignExpr)).type;
HANDLE_UNRESOLVED_TYPE_ER(indexType)
// Check if the index is of the right type
if (!indexType.isOneOf({TY_INT, TY_LONG}))
SOFT_ERROR_ER(node, ARRAY_INDEX_NOT_INT_OR_LONG, "Array index must be of type int or long")

// Check if we have an unsafe operation
if (operandType.isPtr() && !currentScope->doesAllowUnsafeOperations())
SOFT_ERROR_ER(
Expand Down Expand Up @@ -2541,7 +2548,7 @@ std::any TypeChecker::visitFunctionDataType(FunctionDataTypeNode *node) {
}

/**
* Check if the the capture rules for async lambdas are enforced if the async attribute is set
* Check if the capture rules for async lambdas are enforced if the async attribute is set
*
* Only one capture with pointer type, pass-by-val is allowed, since only then we can store it in the second field of the
* fat pointer and can ensure, that no stack variable is referenced inside the lambda.
Expand Down
Loading

0 comments on commit 75d3241

Please sign in to comment.