Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AST child node fetching into vector #717

Merged
merged 1 commit into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 40 additions & 61 deletions src/ast/ASTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ std::any ASTBuilder::visitInterfaceDef(SpiceParser::InterfaceDefContext *ctx) {

// Check if a custom type id was set
if (interfaceDefNode->attrs && interfaceDefNode->attrs->attrLst->hasAttr(ATTR_CORE_COMPILER_FIXED_TYPE_ID))
interfaceDefNode->typeId =
interfaceDefNode->attrs->attrLst->getAttrValueByName(ATTR_CORE_COMPILER_FIXED_TYPE_ID)->intValue;
interfaceDefNode->typeId = interfaceDefNode->attrs->attrLst->getAttrValueByName(ATTR_CORE_COMPILER_FIXED_TYPE_ID)->intValue;
}
if (ctx->specifierLst())
interfaceDefNode->specifierLst = std::any_cast<SpecifierLstNode *>(visit(ctx->specifierLst()));
Expand Down Expand Up @@ -429,8 +428,7 @@ std::any ASTBuilder::visitSwitchStmt(SpiceParser::SwitchStmtContext *ctx) {

// Visit children
switchStmtNode->assignExpr = std::any_cast<AssignExprNode *>(visit(ctx->assignExpr()));
for (SpiceParser::CaseBranchContext *caseBranch : ctx->caseBranch())
switchStmtNode->caseBranches.push_back(std::any_cast<CaseBranchNode *>(visit(caseBranch)));
fetchChildrenIntoVector(switchStmtNode->caseBranches, ctx->caseBranch());
if (ctx->defaultBranch()) {
switchStmtNode->hasDefaultBranch = true;
switchStmtNode->defaultBranch = std::any_cast<DefaultBranchNode *>(visit(ctx->defaultBranch()));
Expand All @@ -443,8 +441,7 @@ std::any ASTBuilder::visitCaseBranch(SpiceParser::CaseBranchContext *ctx) {
const auto caseBranchNode = createNode<CaseBranchNode>(ctx);

// Visit children
for (SpiceParser::CaseConstantContext *caseConstant : ctx->caseConstant())
caseBranchNode->caseConstants.push_back(std::any_cast<CaseConstantNode *>(visit(caseConstant)));
fetchChildrenIntoVector(caseBranchNode->caseConstants, ctx->caseConstant());
caseBranchNode->body = std::any_cast<StmtLstNode *>(visit(ctx->stmtLst()));

return concludeNode(caseBranchNode);
Expand Down Expand Up @@ -507,8 +504,7 @@ std::any ASTBuilder::visitTypeLst(SpiceParser::TypeLstContext *ctx) {
const auto typeLstNode = createNode<TypeLstNode>(ctx);

// Visit children
for (SpiceParser::DataTypeContext *dataType : ctx->dataType())
typeLstNode->dataTypes.push_back(std::any_cast<DataTypeNode *>(visit(dataType)));
fetchChildrenIntoVector(typeLstNode->dataTypes, ctx->dataType());

return concludeNode(typeLstNode);
}
Expand All @@ -517,8 +513,7 @@ std::any ASTBuilder::visitTypeAltsLst(SpiceParser::TypeAltsLstContext *ctx) {
const auto typeAltsLstNode = createNode<TypeAltsLstNode>(ctx);

// Visit children
for (SpiceParser::DataTypeContext *dataType : ctx->dataType())
typeAltsLstNode->dataTypes.push_back(std::any_cast<DataTypeNode *>(visit(dataType)));
fetchChildrenIntoVector(typeAltsLstNode->dataTypes, ctx->dataType());

return concludeNode(typeAltsLstNode);
}
Expand All @@ -527,11 +522,12 @@ std::any ASTBuilder::visitParamLst(SpiceParser::ParamLstContext *ctx) {
const auto paramLstNode = createNode<ParamLstNode>(ctx);

// Visit children
for (SpiceParser::DeclStmtContext *declStmt : ctx->declStmt()) {
auto param = std::any_cast<DeclStmtNode *>(visit(declStmt));
param->isFctParam = true;
param->dataType->isParamType = true;
paramLstNode->params.push_back(param);
fetchChildrenIntoVector(paramLstNode->params, ctx->declStmt());

// Set some flags to later detect that the decl statements are parameters
for (DeclStmtNode *declStmt : paramLstNode->params) {
declStmt->isFctParam = true;
declStmt->dataType->isParamType = true;
}

return concludeNode(paramLstNode);
Expand All @@ -541,8 +537,7 @@ std::any ASTBuilder::visitArgLst(SpiceParser::ArgLstContext *ctx) {
const auto argLstNode = createNode<ArgLstNode>(ctx);

// Visit children
for (SpiceParser::AssignExprContext *assignExpr : ctx->assignExpr())
argLstNode->args.push_back(std::any_cast<AssignExprNode *>(visit(assignExpr)));
fetchChildrenIntoVector(argLstNode->args, ctx->assignExpr());
argLstNode->argInfos.reserve(argLstNode->args.size());

return concludeNode(argLstNode);
Expand All @@ -552,8 +547,7 @@ std::any ASTBuilder::visitEnumItemLst(SpiceParser::EnumItemLstContext *ctx) {
const auto enumItemLstNode = createNode<EnumItemLstNode>(ctx);

// Visit children
for (SpiceParser::EnumItemContext *enumItem : ctx->enumItem())
enumItemLstNode->items.push_back(std::any_cast<EnumItemNode *>(visit(enumItem)));
fetchChildrenIntoVector(enumItemLstNode->items, ctx->enumItem());

return concludeNode(enumItemLstNode);
}
Expand Down Expand Up @@ -663,11 +657,11 @@ std::any ASTBuilder::visitSpecifierLst(SpiceParser::SpecifierLstContext *ctx) {
const auto specifierLstNode = createNode<SpecifierLstNode>(ctx);

// Visit children
bool seenSignedOrUnsigned = false;
for (SpiceParser::SpecifierContext *specifierCtx : ctx->specifier()) {
auto specifier = std::any_cast<SpecifierNode *>(visit(specifierCtx));
specifierLstNode->specifiers.push_back(specifier);
fetchChildrenIntoVector(specifierLstNode->specifiers, ctx->specifier());

// Check if specifier combination is invalid
bool seenSignedOrUnsigned = false;
for (const SpecifierNode *specifier : specifierLstNode->specifiers) {
// Check if we have both, signed and unsigned specifier
if (specifier->type != SpecifierNode::TY_SIGNED && specifier->type != SpecifierNode::TY_UNSIGNED)
continue;
Expand Down Expand Up @@ -745,8 +739,7 @@ std::any ASTBuilder::visitAttrLst(SpiceParser::AttrLstContext *ctx) {
const auto attrLstNode = createNode<AttrLstNode>(ctx);

// Visit children
for (SpiceParser::AttrContext *attr : ctx->attr())
attrLstNode->attributes.push_back(std::any_cast<AttrNode *>(visit(attr)));
fetchChildrenIntoVector(attrLstNode->attributes, ctx->attr());

return concludeNode(attrLstNode);
}
Expand Down Expand Up @@ -906,8 +899,7 @@ std::any ASTBuilder::visitPrintfCall(SpiceParser::PrintfCallContext *ctx) {
printfCallNode->templatedString = templatedString;

// Visit children
for (SpiceParser::AssignExprContext *assignExprContext : ctx->assignExpr())
printfCallNode->args.push_back(std::any_cast<AssignExprNode *>(visit(assignExprContext)));
fetchChildrenIntoVector(printfCallNode->args, ctx->assignExpr());

return concludeNode(printfCallNode);
}
Expand Down Expand Up @@ -962,8 +954,7 @@ std::any ASTBuilder::visitSysCall(SpiceParser::SysCallContext *ctx) {
const auto sysCallNode = createNode<SysCallNode>(ctx);

// Visit children
for (SpiceParser::AssignExprContext *assignExprContext : ctx->assignExpr())
sysCallNode->args.push_back(std::any_cast<AssignExprNode *>(visit(assignExprContext)));
fetchChildrenIntoVector(sysCallNode->args, ctx->assignExpr());

return concludeNode(sysCallNode);
}
Expand Down Expand Up @@ -1004,8 +995,7 @@ std::any ASTBuilder::visitLogicalOrExpr(SpiceParser::LogicalOrExprContext *ctx)
const auto logicalOrExprNode = createNode<LogicalOrExprNode>(ctx);

// Visit children
for (SpiceParser::LogicalAndExprContext *logicalAndExpr : ctx->logicalAndExpr())
logicalOrExprNode->operands.push_back(std::any_cast<LogicalAndExprNode *>(visit(logicalAndExpr)));
fetchChildrenIntoVector(logicalOrExprNode->operands, ctx->logicalAndExpr());

return concludeNode(logicalOrExprNode);
}
Expand All @@ -1014,8 +1004,7 @@ std::any ASTBuilder::visitLogicalAndExpr(SpiceParser::LogicalAndExprContext *ctx
const auto logicalAndExprNode = createNode<LogicalAndExprNode>(ctx);

// Visit children
for (SpiceParser::BitwiseOrExprContext *bitwiseOrExpr : ctx->bitwiseOrExpr())
logicalAndExprNode->operands.push_back(std::any_cast<BitwiseOrExprNode *>(visit(bitwiseOrExpr)));
fetchChildrenIntoVector(logicalAndExprNode->operands, ctx->bitwiseOrExpr());

return concludeNode(logicalAndExprNode);
}
Expand All @@ -1024,8 +1013,7 @@ std::any ASTBuilder::visitBitwiseOrExpr(SpiceParser::BitwiseOrExprContext *ctx)
const auto bitwiseOrExprNode = createNode<BitwiseOrExprNode>(ctx);

// Visit children
for (SpiceParser::BitwiseXorExprContext *bitwiseXorExpr : ctx->bitwiseXorExpr())
bitwiseOrExprNode->operands.push_back(std::any_cast<BitwiseXorExprNode *>(visit(bitwiseXorExpr)));
fetchChildrenIntoVector(bitwiseOrExprNode->operands, ctx->bitwiseXorExpr());

return concludeNode(bitwiseOrExprNode);
}
Expand All @@ -1034,8 +1022,7 @@ std::any ASTBuilder::visitBitwiseXorExpr(SpiceParser::BitwiseXorExprContext *ctx
const auto bitwiseXorExprNode = createNode<BitwiseXorExprNode>(ctx);

// Visit children
for (SpiceParser::BitwiseAndExprContext *bitwiseAndExpr : ctx->bitwiseAndExpr())
bitwiseXorExprNode->operands.push_back(std::any_cast<BitwiseAndExprNode *>(visit(bitwiseAndExpr)));
fetchChildrenIntoVector(bitwiseXorExprNode->operands, ctx->bitwiseAndExpr());

return concludeNode(bitwiseXorExprNode);
}
Expand All @@ -1044,8 +1031,7 @@ std::any ASTBuilder::visitBitwiseAndExpr(SpiceParser::BitwiseAndExprContext *ctx
const auto bitwiseAndExprNode = createNode<BitwiseAndExprNode>(ctx);

// Visit children
for (SpiceParser::EqualityExprContext *equalityExpr : ctx->equalityExpr())
bitwiseAndExprNode->operands.push_back(std::any_cast<EqualityExprNode *>(visit(equalityExpr)));
fetchChildrenIntoVector(bitwiseAndExprNode->operands, ctx->equalityExpr());

return concludeNode(bitwiseAndExprNode);
}
Expand All @@ -1054,8 +1040,7 @@ std::any ASTBuilder::visitEqualityExpr(SpiceParser::EqualityExprContext *ctx) {
const auto equalityExprNode = createNode<EqualityExprNode>(ctx);

// Visit children
for (SpiceParser::RelationalExprContext *relationalExpr : ctx->relationalExpr())
equalityExprNode->operands.push_back(std::any_cast<RelationalExprNode *>(visit(relationalExpr)));
fetchChildrenIntoVector(equalityExprNode->operands, ctx->relationalExpr());

// Extract operator
if (ctx->EQUAL())
Expand All @@ -1070,8 +1055,7 @@ std::any ASTBuilder::visitRelationalExpr(SpiceParser::RelationalExprContext *ctx
const auto relationalExprNode = createNode<RelationalExprNode>(ctx);

// Visit children
for (SpiceParser::ShiftExprContext *shiftExpr : ctx->shiftExpr())
relationalExprNode->operands.push_back(std::any_cast<ShiftExprNode *>(visit(shiftExpr)));
fetchChildrenIntoVector(relationalExprNode->operands, ctx->shiftExpr());

// Extract operator
if (ctx->LESS())
Expand All @@ -1090,8 +1074,7 @@ std::any ASTBuilder::visitShiftExpr(SpiceParser::ShiftExprContext *ctx) {
const auto shiftExprNode = createNode<ShiftExprNode>(ctx);

// Visit children
for (SpiceParser::AdditiveExprContext *additiveExpr : ctx->additiveExpr())
shiftExprNode->operands.push_back(std::any_cast<AdditiveExprNode *>(visit(additiveExpr)));
fetchChildrenIntoVector(shiftExprNode->operands, ctx->additiveExpr());

// Extract operator
if (!ctx->LESS().empty())
Expand All @@ -1106,8 +1089,7 @@ std::any ASTBuilder::visitAdditiveExpr(SpiceParser::AdditiveExprContext *ctx) {
const auto additiveExprNode = createNode<AdditiveExprNode>(ctx);

// Visit children
for (SpiceParser::MultiplicativeExprContext *multiplicativeExpr : ctx->multiplicativeExpr())
additiveExprNode->operands.push_back(std::any_cast<MultiplicativeExprNode *>(visit(multiplicativeExpr)));
fetchChildrenIntoVector(additiveExprNode->operands, ctx->multiplicativeExpr());

for (ParserRuleContext::ParseTree *subTree : ctx->children) {
const auto terminal = dynamic_cast<TerminalNode *>(subTree);
Expand All @@ -1129,8 +1111,7 @@ std::any ASTBuilder::visitMultiplicativeExpr(SpiceParser::MultiplicativeExprCont
const auto multiplicativeExprNode = createNode<MultiplicativeExprNode>(ctx);

// Visit children
for (SpiceParser::CastExprContext *castExpr : ctx->castExpr())
multiplicativeExprNode->operands.push_back(std::any_cast<CastExprNode *>(visit(castExpr)));
fetchChildrenIntoVector(multiplicativeExprNode->operands, ctx->castExpr());

for (ParserRuleContext::ParseTree *subTree : ctx->children) {
const auto terminal = dynamic_cast<TerminalNode *>(subTree);
Expand Down Expand Up @@ -1668,19 +1649,17 @@ int16_t ASTBuilder::parseShort(TerminalNode *terminal) {

int64_t ASTBuilder::parseLong(TerminalNode *terminal) {
const NumericParserCallback<int64_t> cb = [](const std::string &substr, short base, bool isSigned) -> int64_t {
if (isSigned)
return std::stoll(substr, nullptr, base);
else
return static_cast<int64_t>(std::stoull(substr, nullptr, base));
return isSigned ? std::stoll(substr, nullptr, base) : std::stoull(substr, nullptr, base);
};
return parseNumeric(terminal, cb);
}

int8_t ASTBuilder::parseChar(TerminalNode *terminal) const {
const std::string input = terminal->toString();
if (input.length() == 3) { // Normal char literals
if (input.length() == 3) // Normal char literals
return input[1];
} else if (input.length() == 4 && input[1] == '\\') { // Char literals with escape sequence

if (input.length() == 4 && input[1] == '\\') { // Char literals with escape sequence
switch (input[2]) {
case '\'':
return '\'';
Expand All @@ -1706,10 +1685,10 @@ int8_t ASTBuilder::parseChar(TerminalNode *terminal) const {
const CodeLoc codeLoc(terminal->getSymbol(), sourceFile);
throw ParserError(codeLoc, INVALID_CHAR_LITERAL, "Invalid escape sequence " + input);
}
} else {
const CodeLoc codeLoc(terminal->getSymbol(), sourceFile);
throw ParserError(codeLoc, INVALID_CHAR_LITERAL, "Invalid char literal " + input);
}

const CodeLoc codeLoc(terminal->getSymbol(), sourceFile);
throw ParserError(codeLoc, INVALID_CHAR_LITERAL, "Invalid char literal " + input);
}

std::string ASTBuilder::parseString(std::string input) {
Expand Down Expand Up @@ -1756,18 +1735,18 @@ template <typename T> T ASTBuilder::parseNumeric(TerminalNode *terminal, const N
}

void ASTBuilder::replaceEscapeChars(std::string &input) {
std::unordered_map<char, char> escapeMap = {
const std::unordered_map<char, char> escapeMap = {
{'a', '\a'}, {'b', '\b'}, {'f', '\f'}, {'n', '\n'}, {'r', '\r'}, {'t', '\t'},
{'v', '\v'}, {'\\', '\\'}, {'?', '\?'}, {'\'', '\''}, {'"', '\"'},
};

size_t writeIndex = 0; // Index where the next character should be written
for (size_t readIndex = 0; readIndex < input.length(); ++readIndex, ++writeIndex) {
if (input[readIndex] == '\\' && readIndex + 1 < input.length()) {
char nextChar = input[readIndex + 1];
const char nextChar = input[readIndex + 1];
if (escapeMap.contains(nextChar)) {
// If the next character forms a valid escape sequence, replace it
input[writeIndex] = escapeMap[nextChar];
input[writeIndex] = escapeMap.at(nextChar);
readIndex++; // Skip the next character as it's part of the escape sequence
} else {
// If it's not a valid escape sequence, just copy the backslash
Expand Down
26 changes: 23 additions & 3 deletions src/ast/ASTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,19 @@ class ASTBuilder final : CompilerPass, public SpiceVisitor {
std::stack<ASTNode *> parentStack;

// Private methods
template <typename T> ALWAYS_INLINE T *createNode(const ParserRuleContext *ctx) {
template <typename SrcTy, typename TgtTy>
void fetchChildrenIntoVector(std::vector<TgtTy> &tgt, const std::vector<SrcTy> &src)
requires(std::is_pointer_v<SrcTy> && std::is_pointer_v<TgtTy>)
{
tgt.reserve(src.size());
for (SrcTy shiftExpr : src)
tgt.push_back(std::any_cast<TgtTy>(visit(shiftExpr)));
}

template <typename T>
T *createNode(const ParserRuleContext *ctx)
requires std::is_base_of_v<ASTNode, T>
{
// Create the new node
T *node = resourceManager.astNodeAlloc.allocate<T>(getCodeLoc(ctx));
if constexpr (!std::is_same_v<T, EntryNode>)
Expand All @@ -143,9 +155,17 @@ class ASTBuilder final : CompilerPass, public SpiceVisitor {
return node;
}

template <typename T> ALWAYS_INLINE T *resumeForExpansion() const { return spice_pointer_cast<T *>(parentStack.top()); }
template <typename T>
ALWAYS_INLINE T *resumeForExpansion() const
requires std::is_base_of_v<ASTNode, T>
{
return spice_pointer_cast<T *>(parentStack.top());
}

template <typename T> ALWAYS_INLINE T *concludeNode(T *node) {
template <typename T>
ALWAYS_INLINE T *concludeNode(T *node)
requires std::is_base_of_v<ASTNode, T>
{
// This node is no longer the parent for its children
assert(parentStack.top() == node);
parentStack.pop();
Expand Down
Loading