Skip to content

Commit

Permalink
Fix ternary value copying (#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer authored Jan 17, 2025
1 parent b459e27 commit c739d58
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 39 deletions.
33 changes: 28 additions & 5 deletions media/test-project/test.spice
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ f<LifetimeObject> spawnLO() {

f<int> main() {
// Ignored return value of ctor
printf("Ignored return value of ctor");
printf("Ignored return value of ctor:\n");
{
LifetimeObject();
}
Expand Down Expand Up @@ -76,12 +76,35 @@ f<int> main() {
const LifetimeObject& loConstRef = spawnLO(); // Assigned to loConstRef
}

printf("Ternary\n");
printf("Ternary (true temporary, false temporary)\n");
{
bool cond = false;
bool cond = true;
LifetimeObject loCopy1 = cond ? LifetimeObject() : LifetimeObject(); // ctor calls in both branches
LifetimeObject loCopy2 = !cond ? LifetimeObject() : LifetimeObject(); // ctor calls in both branches
}

printf("Ternary (true temporary, false not temporary)\n");
{
bool cond = true;
LifetimeObject lo = LifetimeObject();
LifetimeObject loCopy1 = cond ? LifetimeObject() : lo; // ctor in true branch, copy ctor in false branch
LifetimeObject loCopy2 = !cond ? LifetimeObject() : lo; // ctor in true branch, copy ctor in false branch
}

printf("Ternary (true not temporary, false temporary)\n");
{
bool cond = true;
LifetimeObject lo = LifetimeObject();
LifetimeObject loCopy1 = cond ? lo : LifetimeObject(); // copy ctor in true branch, ctor in false branch
LifetimeObject loCopy2 = !cond ? lo : LifetimeObject(); // copy ctor in true branch, ctor in false branch
}

printf("Ternary (true not temporary, false not temporary)\n");
{
bool cond = true;
LifetimeObject lo1 = LifetimeObject();
LifetimeObject lo2 = LifetimeObject();
LifetimeObject lo3 = cond ? lo1 : lo2;
// ToDo: The dtor of lo2 is called twice
LifetimeObject loCopy1 = cond ? lo1 : lo2; // copy ctor call in both branches
LifetimeObject loCopy2 = !cond ? lo1 : lo2; // copy ctor call in both branches
}
}
3 changes: 3 additions & 0 deletions src/ast/ASTNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,9 @@ class TernaryExprNode final : public ExprNode {
LogicalOrExprNode *condition = nullptr;
LogicalOrExprNode *trueExpr = nullptr;
LogicalOrExprNode *falseExpr = nullptr;
Function *calledCopyCtor = nullptr;
bool trueSideCallsCopyCtor = false;
bool falseSideCallsCopyCtor = false;
bool isShortened = false;
};

Expand Down
65 changes: 50 additions & 15 deletions src/irgenerator/GenExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ std::any IRGenerator::visitTernaryExpr(const TernaryExprNode *node) {
const LogicalOrExprNode *trueNode = node->isShortened ? node->condition : node->trueExpr;
const LogicalOrExprNode *falseNode = node->falseExpr;

llvm::Value* resultValue;
llvm::Value* resultValue = nullptr;
llvm::Value* resultPtr = nullptr;
SymbolTableEntry *anonymousSymbol = nullptr;
if (trueNode->hasCompileTimeValue() && falseNode->hasCompileTimeValue()) {
// If both are constants, we can simply emit a selection instruction
llvm::Value *trueValue = resolveValue(trueNode);
Expand All @@ -116,29 +118,62 @@ std::any IRGenerator::visitTernaryExpr(const TernaryExprNode *node) {

// Fill true block
switchToBlock(condTrue);
llvm::Value *trueValue = resolveValue(trueNode);
llvm::Value *trueValue = nullptr;
llvm::Value* truePtr = nullptr;
if (node->falseSideCallsCopyCtor) { // both sides or only the false side needs copy ctor call
truePtr = resolveAddress(trueNode);
} else if (node->trueSideCallsCopyCtor) { // only true side needs copy ctor call
llvm::Value* originalPtr = resolveAddress(trueNode);
truePtr = insertAlloca(trueNode->getEvaluatedSymbolType(manIdx).toLLVMType(sourceFile));
generateCtorOrDtorCall(truePtr, node->calledCopyCtor, {originalPtr});
} else { // neither true nor false side need copy ctor call
trueValue = resolveValue(trueNode);
}
insertJump(condExit);

// Fill false block
switchToBlock(condFalse);
llvm::Value *falseValue = resolveValue(falseNode);
llvm::Value *falseValue = nullptr;
llvm::Value *falsePtr = nullptr;
if (node->trueSideCallsCopyCtor) { // both sides or only the true side needs copy ctor call
falsePtr = resolveAddress(falseNode);
} else if (node->falseSideCallsCopyCtor) { // only false side needs copy ctor call
llvm::Value* originalPtr = resolveAddress(falseNode);
falsePtr = insertAlloca(falseNode->getEvaluatedSymbolType(manIdx).toLLVMType(sourceFile));
generateCtorOrDtorCall(falsePtr, node->calledCopyCtor, {originalPtr});
} else { // neither true nor false side need copy ctor call
falseValue = resolveValue(falseNode);
}
insertJump(condExit);

// Fill the exit block
switchToBlock(condExit);
llvm::PHINode* phiInst = builder.CreatePHI(trueValue->getType(), 2, "cond.result");
phiInst->addIncoming(trueValue, condTrue);
phiInst->addIncoming(falseValue, condFalse);
resultValue = phiInst;
}
llvm::Type *resultType = node->getEvaluatedSymbolType(manIdx).toLLVMType(sourceFile);
if (node->trueSideCallsCopyCtor || node->falseSideCallsCopyCtor) { // at least one side needs copy ctor call
llvm::PHINode* phiInst = builder.CreatePHI(builder.getPtrTy(), 2, "cond.result");
phiInst->addIncoming(truePtr, condTrue);
phiInst->addIncoming(falsePtr, condFalse);
if (node->trueSideCallsCopyCtor && node->falseSideCallsCopyCtor) { // both sides need copy ctor call
resultPtr = insertAlloca(resultType);
generateCtorOrDtorCall(resultPtr, node->calledCopyCtor, {phiInst});
} else {
resultPtr = phiInst;
}
} else { // neither true nor false side need copy ctor call
assert(trueValue != nullptr);
llvm::PHINode* phiInst = builder.CreatePHI(resultType, 2, "cond.result");
phiInst->addIncoming(trueValue, condTrue);
phiInst->addIncoming(falseValue, condFalse);
resultValue = phiInst;
}

// If we have an anonymous symbol for this ternary expr, make sure that it has an address to reference.
SymbolTableEntry *anonymousSymbol = currentScope->symbolTable.lookupAnonymous(node->codeLoc);
llvm::Value *resultPtr = nullptr;
if (anonymousSymbol != nullptr) {
resultPtr = insertAlloca(anonymousSymbol->getQualType().toLLVMType(sourceFile));
insertStore(resultValue, resultPtr);
anonymousSymbol->updateAddress(resultPtr);
// If we have an anonymous symbol for this ternary expr, make sure that it has an address to reference.
anonymousSymbol = currentScope->symbolTable.lookupAnonymous(node->codeLoc);
if (anonymousSymbol != nullptr) {
resultPtr = insertAlloca(anonymousSymbol->getQualType().toLLVMType(sourceFile));
insertStore(resultValue, resultPtr);
anonymousSymbol->updateAddress(resultPtr);
}
}

return LLVMExprResult{.value = resultValue, .ptr = resultPtr, .entry = anonymousSymbol};
Expand Down
2 changes: 1 addition & 1 deletion src/irgenerator/GenImplicit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void IRGenerator::generateCtorOrDtorCall(const SymbolTableEntry *entry, const Fu
void IRGenerator::generateCtorOrDtorCall(llvm::Value *structAddr, const Function *ctorOrDtor,
const std::vector<llvm::Value *> &args) const {
// Build parameter list
std::vector<llvm::Value *> argValues = {structAddr};
std::vector argValues = {structAddr};
argValues.insert(argValues.end(), args.begin(), args.end());

// Generate function call
Expand Down
3 changes: 2 additions & 1 deletion src/symboltablebuilder/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ SymbolTableEntry *SymbolTable::insert(const std::string &name, ASTNode *declNode

// Check if shadowed
if (parent != nullptr && parent->lookup(name) != nullptr && !declNode->isParam()) {
CompilerWarning warning(declNode->codeLoc, SHADOWED_VARIABLE, "Variable '" + name + "' shadows a variable in a parent scope");
const std::string warningMsg = "Variable '" + name + "' shadows a variable in a parent scope";
const CompilerWarning warning(declNode->codeLoc, SHADOWED_VARIABLE, warningMsg);
scope->sourceFile->compilerOutput.warnings.push_back(warning);
}

Expand Down
2 changes: 1 addition & 1 deletion src/symboltablebuilder/SymbolTableEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SymbolTableEntry final {
void updateType(const QualType &newType, bool overwriteExistingType);
void updateState(const LifecycleState &newState, const ASTNode *node, bool force = false);
[[nodiscard]] const CodeLoc &getDeclCodeLoc() const;
[[nodiscard]] virtual llvm::Value *getAddress() const;
[[nodiscard]] llvm::Value *getAddress() const;
void updateAddress(llvm::Value *address);
void pushAddress(llvm::Value *address);
void popAddress();
Expand Down
47 changes: 31 additions & 16 deletions src/typechecker/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,14 +564,8 @@ std::any TypeChecker::visitDeclStmt(DeclStmtNode *node) {
localVarType = opRuleManager.getAssignResultType(node, lhsResult, rhs, true);

// Call copy ctor if required
if (localVarType.is(TY_STRUCT) && !node->isFctParam && !rhs.isTemporary()) {
Scope *matchScope = localVarType.getBodyScope();
assert(matchScope != nullptr);
// Check if we have a no-args ctor to call
const QualType &thisType = localVarType;
const ArgList args = {{thisType.toConstRef(node), false}};
node->calledCopyCtor = FunctionManager::match(this, matchScope, CTOR_FUNCTION_NAME, thisType, args, {}, true, node);
}
if (localVarType.is(TY_STRUCT) && !node->isFctParam && !rhs.isTemporary())
node->calledCopyCtor = matchCopyCtor(localVarType, node);

// If this is a struct type, check if the type is known. If not, error out
if (localVarType.isBase(TY_STRUCT) && !sourceFile->getNameRegistryEntry(localVarType.getBase().getSubType())) {
Expand Down Expand Up @@ -977,9 +971,11 @@ std::any TypeChecker::visitTernaryExpr(TernaryExprNode *node) {
// Visit condition
const auto condition = std::any_cast<ExprResult>(visit(node->condition));
HANDLE_UNRESOLVED_TYPE_ER(condition.type)
const auto [trueType, trueEntry] = node->isShortened ? condition : std::any_cast<ExprResult>(visit(node->trueExpr));
const auto trueExpr = node->isShortened ? condition : std::any_cast<ExprResult>(visit(node->trueExpr));
const auto [trueType, trueEntry] = trueExpr;
HANDLE_UNRESOLVED_TYPE_ER(trueType)
const auto [falseType, falseEntry] = std::any_cast<ExprResult>(visit(node->falseExpr));
const auto falseExpr = std::any_cast<ExprResult>(visit(node->falseExpr));
const auto [falseType, falseEntry] = falseExpr;
HANDLE_UNRESOLVED_TYPE_ER(falseType)

// Check if the condition evaluates to bool
Expand All @@ -997,18 +993,28 @@ std::any TypeChecker::visitTernaryExpr(TernaryExprNode *node) {
// If there is an anonymous symbol attached to left or right, remove it,
// since the result takes over the ownership of any destructible object.
const bool removeAnonymousSymbolTrueSide = trueEntry && trueEntry->anonymous;
if (removeAnonymousSymbolTrueSide)
if (removeAnonymousSymbolTrueSide) {
currentScope->symbolTable.deleteAnonymous(trueEntry->name);
} else if (trueEntry && !trueEntry->anonymous && !trueTypeModified.isTriviallyCopyable(node)) {
node->trueSideCallsCopyCtor = true;
}
const bool removeAnonymousSymbolFalseSide = falseEntry && falseEntry->anonymous;
if (removeAnonymousSymbolFalseSide)
if (removeAnonymousSymbolFalseSide) {
currentScope->symbolTable.deleteAnonymous(falseEntry->name);
} else if (falseEntry && !falseEntry->anonymous && !falseTypeModified.isTriviallyCopyable(node)) {
node->falseSideCallsCopyCtor = true;
}

// Create a new anonymous symbol for the result if required
const QualType& returnType = trueType;
const QualType &returnType = trueType;
SymbolTableEntry *anonymousSymbol = nullptr;
if (removeAnonymousSymbolTrueSide || removeAnonymousSymbolFalseSide)
anonymousSymbol = currentScope->symbolTable.insertAnonymous(returnType, node);

// Lookup copy ctor if at least one side needs it
if (node->trueSideCallsCopyCtor || node->falseSideCallsCopyCtor)
node->calledCopyCtor = matchCopyCtor(trueTypeModified, node);

return ExprResult{node->setEvaluatedSymbolType(trueType, manIdx), anonymousSymbol};
}

Expand Down Expand Up @@ -1951,9 +1957,11 @@ bool TypeChecker::visitMethodCall(FctCallNode *node, Scope *structScope, QualTyp
// Retrieve field entry
SymbolTableEntry *fieldEntry = structScope->lookupStrict(identifier);
if (!fieldEntry) {
const std::string name = thisType.getBase().getName(false, true);
SOFT_ERROR_BOOL(node, ACCESS_TO_NON_EXISTING_MEMBER,
"The type '" + name + "' does not have a member with the name '" + identifier + "'")
std::stringstream errorMsg;
errorMsg << "The type '";
errorMsg << thisType.getBase().getName(false, true);
errorMsg << "' does not have a member with the name '" << identifier << "'";
SOFT_ERROR_BOOL(node, ACCESS_TO_NON_EXISTING_MEMBER, errorMsg.str())
}
if (!fieldEntry->getQualType().getBase().isOneOf({TY_STRUCT, TY_INTERFACE}))
SOFT_ERROR_BOOL(node, INVALID_MEMBER_ACCESS,
Expand Down Expand Up @@ -2592,6 +2600,13 @@ bool TypeChecker::checkAsyncLambdaCaptureRules(const LambdaBaseNode *node, const
return false; // Violated
}

Function *TypeChecker::matchCopyCtor(const QualType &thisType, const ASTNode* node) {
Scope *matchScope = thisType.getBodyScope();
assert(matchScope != nullptr);
const ArgList args = {{thisType.toConstRef(node), false}};
return FunctionManager::match(this, matchScope, CTOR_FUNCTION_NAME, thisType, args, {}, true, node);
}

QualType TypeChecker::mapLocalTypeToImportedScopeType(const Scope *targetScope, const QualType &symbolType) const {
// Skip all types, except structs
if (!symbolType.isBase(TY_STRUCT))
Expand Down
1 change: 1 addition & 0 deletions src/typechecker/TypeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class TypeChecker final : CompilerPass, public ASTVisitor {
bool visitFctPtrCall(const FctCallNode *node, const QualType &functionType) const;
bool visitMethodCall(FctCallNode *node, Scope *structScope, QualTypeList &templateTypes);
bool checkAsyncLambdaCaptureRules(const LambdaBaseNode *node, const LambdaAttrNode *attrs) const;
[[nodiscard]] Function *matchCopyCtor(const QualType& thisType, const ASTNode* node);
[[nodiscard]] QualType mapLocalTypeToImportedScopeType(const Scope *targetScope, const QualType &symbolType) const;
[[nodiscard]] QualType mapImportedScopeTypeToLocalType(const Scope *sourceScope, const QualType &symbolType) const;
static void autoDeReference(QualType &symbolType);
Expand Down
28 changes: 28 additions & 0 deletions test/test-files/std/test/lifetime-object/cout.out
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,31 @@ Return from function as value:
-- LifetimeObject 11 (dtor)
-- LifetimeObject 10 (dtor)
-- LifetimeObject 9 (dtor)
Ternary (true temporary, false temporary)
-- LifetimeObject 12 was created (ctor)
-- LifetimeObject 13 was created (ctor)
-- LifetimeObject 13 (dtor)
-- LifetimeObject 12 (dtor)
Ternary (true temporary, false not temporary)
-- LifetimeObject 14 was created (ctor)
-- LifetimeObject 15 was created (ctor)
-- LifetimeObject 14 was copied to LifetimeObject 16 (copy ctor)
-- LifetimeObject 16 (dtor)
-- LifetimeObject 15 (dtor)
-- LifetimeObject 14 (dtor)
Ternary (true not temporary, false temporary)
-- LifetimeObject 17 was created (ctor)
-- LifetimeObject 17 was copied to LifetimeObject 18 (copy ctor)
-- LifetimeObject 19 was created (ctor)
-- LifetimeObject 19 (dtor)
-- LifetimeObject 18 (dtor)
-- LifetimeObject 17 (dtor)
Ternary (true not temporary, false not temporary)
-- LifetimeObject 20 was created (ctor)
-- LifetimeObject 21 was created (ctor)
-- LifetimeObject 20 was copied to LifetimeObject 22 (copy ctor)
-- LifetimeObject 21 was copied to LifetimeObject 23 (copy ctor)
-- LifetimeObject 23 (dtor)
-- LifetimeObject 22 (dtor)
-- LifetimeObject 21 (dtor)
-- LifetimeObject 20 (dtor)
32 changes: 32 additions & 0 deletions test/test-files/std/test/lifetime-object/source.spice
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,36 @@ f<int> main() {
// Return value receiver - const ref
const LifetimeObject& loConstRef = spawnLO(); // Assigned to loConstRef
}

printf("Ternary (true temporary, false temporary)\n");
{
bool cond = true;
LifetimeObject loCopy1 = cond ? LifetimeObject() : LifetimeObject(); // ctor calls in both branches
LifetimeObject loCopy2 = !cond ? LifetimeObject() : LifetimeObject(); // ctor calls in both branches
}

printf("Ternary (true temporary, false not temporary)\n");
{
bool cond = true;
LifetimeObject lo = LifetimeObject();
LifetimeObject loCopy1 = cond ? LifetimeObject() : lo; // ctor in true branch, copy ctor in false branch
LifetimeObject loCopy2 = !cond ? LifetimeObject() : lo; // ctor in true branch, copy ctor in false branch
}

printf("Ternary (true not temporary, false temporary)\n");
{
bool cond = true;
LifetimeObject lo = LifetimeObject();
LifetimeObject loCopy1 = cond ? lo : LifetimeObject(); // copy ctor in true branch, ctor in false branch
LifetimeObject loCopy2 = !cond ? lo : LifetimeObject(); // copy ctor in true branch, ctor in false branch
}

printf("Ternary (true not temporary, false not temporary)\n");
{
bool cond = true;
LifetimeObject lo1 = LifetimeObject();
LifetimeObject lo2 = LifetimeObject();
LifetimeObject loCopy1 = cond ? lo1 : lo2; // copy ctor call in both branches
LifetimeObject loCopy2 = !cond ? lo1 : lo2; // copy ctor call in both branches
}
}

0 comments on commit c739d58

Please sign in to comment.