Skip to content

Commit

Permalink
[SandboxIR] Implement GlobalIFunc (#108622)
Browse files Browse the repository at this point in the history
This patch implements sandboxir::GlobalIFunc mirroring
llvm::GlobalIFunc.
  • Loading branch information
vporpo authored Sep 13, 2024
1 parent aca226c commit ae3e825
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 17 deletions.
90 changes: 88 additions & 2 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class DSOLocalEquivalent;
class ConstantTokenNone;
class GlobalValue;
class GlobalObject;
class GlobalIFunc;
class Context;
class Function;
class Instruction;
Expand Down Expand Up @@ -332,6 +333,7 @@ class Value {
friend class GlobalValue; // For `Val`.
friend class DSOLocalEquivalent; // For `Val`.
friend class GlobalObject; // For `Val`.
friend class GlobalIFunc; // For `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -1128,6 +1130,7 @@ class GlobalValue : public Constant {
friend class Context; // For constructor.

public:
using LinkageTypes = llvm::GlobalValue::LinkageTypes;
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
switch (From->getSubclassID()) {
Expand Down Expand Up @@ -1285,6 +1288,88 @@ class GlobalObject : public GlobalValue {
}
};

/// Provides API functions, like getIterator() and getReverseIterator() to
/// GlobalIFunc, Function, GlobalVariable and GlobalAlias. In LLVM IR these are
/// provided by ilist_node.
template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
typename LLVMParentT>
class GlobalWithNodeAPI : public ParentT {
/// Helper for mapped_iterator.
struct LLVMGVToGV {
Context &Ctx;
LLVMGVToGV(Context &Ctx) : Ctx(Ctx) {}
GlobalT &operator()(LLVMGlobalT &LLVMGV) const;
};

public:
GlobalWithNodeAPI(Value::ClassID ID, LLVMParentT *C, Context &Ctx)
: ParentT(ID, C, Ctx) {}

// TODO: Missing getParent(). Should be added once Module is available.

using iterator = mapped_iterator<
decltype(static_cast<LLVMGlobalT *>(nullptr)->getIterator()), LLVMGVToGV>;
using reverse_iterator = mapped_iterator<
decltype(static_cast<LLVMGlobalT *>(nullptr)->getReverseIterator()),
LLVMGVToGV>;
iterator getIterator() const {
auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
LLVMGVToGV ToGV(this->Ctx);
return map_iterator(LLVMGV->getIterator(), ToGV);
}
reverse_iterator getReverseIterator() const {
auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
LLVMGVToGV ToGV(this->Ctx);
return map_iterator(LLVMGV->getReverseIterator(), ToGV);
}
};

class GlobalIFunc final
: public GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
llvm::GlobalObject> {
GlobalIFunc(llvm::GlobalObject *C, Context &Ctx)
: GlobalWithNodeAPI(ClassID::GlobalIFunc, C, Ctx) {}
friend class Context; // For constructor.

public:
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
return From->getSubclassID() == ClassID::GlobalIFunc;
}

// TODO: Missing create() because we don't have a sandboxir::Module yet.

// TODO: Missing functions: copyAttributesFrom(), removeFromParent(),
// eraseFromParent()

void setResolver(Constant *Resolver);

Constant *getResolver() const;

// Return the resolver function after peeling off potential ConstantExpr
// indirection.
Function *getResolverFunction();
const Function *getResolverFunction() const {
return const_cast<GlobalIFunc *>(this)->getResolverFunction();
}

static bool isValidLinkage(LinkageTypes L) {
return llvm::GlobalIFunc::isValidLinkage(L);
}

// TODO: Missing applyAlongResolverPath().

#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::GlobalIFunc>(Val) && "Expected a GlobalIFunc!");
}
void dumpOS(raw_ostream &OS) const override {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}
#endif
};

class BlockAddress final : public Constant {
BlockAddress(llvm::BlockAddress *C, Context &Ctx)
: Constant(ClassID::BlockAddress, C, Ctx) {}
Expand Down Expand Up @@ -4219,7 +4304,8 @@ class Context {
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
};

class Function : public GlobalObject {
class Function : public GlobalWithNodeAPI<Function, llvm::Function,
GlobalObject, llvm::GlobalObject> {
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
Expand All @@ -4230,7 +4316,7 @@ class Function : public GlobalObject {
};
/// Use Context::createFunction() instead.
Function(llvm::Function *F, sandboxir::Context &Ctx)
: GlobalObject(ClassID::Function, F, Ctx) {}
: GlobalWithNodeAPI(ClassID::Function, F, Ctx) {}
friend class Context; // For constructor.

public:
Expand Down
37 changes: 37 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2519,6 +2519,39 @@ void GlobalObject::setSection(StringRef S) {
cast<llvm::GlobalObject>(Val)->setSection(S);
}

template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
typename LLVMParentT>
GlobalT &GlobalWithNodeAPI<GlobalT, LLVMGlobalT, ParentT, LLVMParentT>::
LLVMGVToGV::operator()(LLVMGlobalT &LLVMGV) const {
return cast<GlobalT>(*Ctx.getValue(&LLVMGV));
}

namespace llvm::sandboxir {
// Explicit instantiations.
template class GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
llvm::GlobalObject>;
template class GlobalWithNodeAPI<Function, llvm::Function, GlobalObject,
llvm::GlobalObject>;
} // namespace llvm::sandboxir

void GlobalIFunc::setResolver(Constant *Resolver) {
Ctx.getTracker()
.emplaceIfTracking<
GenericSetter<&GlobalIFunc::getResolver, &GlobalIFunc::setResolver>>(
this);
cast<llvm::GlobalIFunc>(Val)->setResolver(
cast<llvm::Constant>(Resolver->Val));
}

Constant *GlobalIFunc::getResolver() const {
return Ctx.getOrCreateConstant(cast<llvm::GlobalIFunc>(Val)->getResolver());
}

Function *GlobalIFunc::getResolverFunction() {
return cast<Function>(Ctx.getOrCreateConstant(
cast<llvm::GlobalIFunc>(Val)->getResolverFunction()));
}

void GlobalValue::setUnnamedAddr(UnnamedAddr V) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&GlobalValue::getUnnamedAddr,
Expand Down Expand Up @@ -2727,6 +2760,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<Function>(
new Function(cast<llvm::Function>(C), *this));
break;
case llvm::Value::GlobalIFuncVal:
It->second = std::unique_ptr<GlobalIFunc>(
new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this));
break;
default:
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
break;
Expand Down
137 changes: 122 additions & 15 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,84 @@ define void @foo() {
EXPECT_EQ(GO->canIncreaseAlignment(), LLVMGO->canIncreaseAlignment());
}

TEST_F(SandboxIRTest, GlobalIFunc) {
parseIR(C, R"IR(
declare external void @bar()
@ifunc0 = ifunc void(), ptr @foo
@ifunc1 = ifunc void(), ptr @foo
define void @foo() {
call void @ifunc0()
call void @ifunc1()
call void @bar()
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
auto *LLVMBB = &*LLVMF.begin();
auto LLVMIt = LLVMBB->begin();
auto *LLVMCall0 = cast<llvm::CallInst>(&*LLVMIt++);
auto *LLVMIFunc0 = cast<llvm::GlobalIFunc>(LLVMCall0->getCalledOperand());

sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Call0 = cast<sandboxir::CallInst>(&*It++);
auto *Call1 = cast<sandboxir::CallInst>(&*It++);
auto *CallBar = cast<sandboxir::CallInst>(&*It++);
// Check classof(), creation.
auto *IFunc0 = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
auto *IFunc1 = cast<sandboxir::GlobalIFunc>(Call1->getCalledOperand());
auto *Bar = cast<sandboxir::Function>(CallBar->getCalledOperand());

// Check getIterator().
{
auto It0 = IFunc0->getIterator();
auto It1 = IFunc1->getIterator();
EXPECT_EQ(&*It0, IFunc0);
EXPECT_EQ(&*It1, IFunc1);
EXPECT_EQ(std::next(It0), It1);
EXPECT_EQ(std::prev(It1), It0);
EXPECT_EQ(&*std::next(It0), IFunc1);
EXPECT_EQ(&*std::prev(It1), IFunc0);
}
// Check getReverseIterator().
{
auto RevIt0 = IFunc0->getReverseIterator();
auto RevIt1 = IFunc1->getReverseIterator();
EXPECT_EQ(&*RevIt0, IFunc0);
EXPECT_EQ(&*RevIt1, IFunc1);
EXPECT_EQ(std::prev(RevIt0), RevIt1);
EXPECT_EQ(std::next(RevIt1), RevIt0);
EXPECT_EQ(&*std::prev(RevIt0), IFunc1);
EXPECT_EQ(&*std::next(RevIt1), IFunc0);
}

// Check setResolver(), getResolver().
EXPECT_EQ(IFunc0->getResolver(), Ctx.getValue(LLVMIFunc0->getResolver()));
auto *OrigResolver = IFunc0->getResolver();
auto *NewResolver = Bar;
EXPECT_NE(NewResolver, OrigResolver);
IFunc0->setResolver(NewResolver);
EXPECT_EQ(IFunc0->getResolver(), NewResolver);
IFunc0->setResolver(OrigResolver);
EXPECT_EQ(IFunc0->getResolver(), OrigResolver);
// Check getResolverFunction().
EXPECT_EQ(IFunc0->getResolverFunction(),
Ctx.getValue(LLVMIFunc0->getResolverFunction()));
// Check isValidLinkage().
for (auto L :
{GlobalValue::ExternalLinkage, GlobalValue::AvailableExternallyLinkage,
GlobalValue::LinkOnceAnyLinkage, GlobalValue::LinkOnceODRLinkage,
GlobalValue::WeakAnyLinkage, GlobalValue::WeakODRLinkage,
GlobalValue::AppendingLinkage, GlobalValue::InternalLinkage,
GlobalValue::PrivateLinkage, GlobalValue::ExternalWeakLinkage,
GlobalValue::CommonLinkage}) {
EXPECT_EQ(IFunc0->isValidLinkage(L), LLVMIFunc0->isValidLinkage(L));
}
}

TEST_F(SandboxIRTest, BlockAddress) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
Expand Down Expand Up @@ -1200,29 +1278,58 @@ define void @foo(i8 %v) {

TEST_F(SandboxIRTest, Function) {
parseIR(C, R"IR(
define void @foo(i32 %arg0, i32 %arg1) {
define void @foo0(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1
bb1:
ret void
}
define void @foo1() {
ret void
}

)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
llvm::Argument *LLVMArg0 = LLVMF->getArg(0);
llvm::Argument *LLVMArg1 = LLVMF->getArg(1);
llvm::Function *LLVMF0 = &*M->getFunction("foo0");
llvm::Function *LLVMF1 = &*M->getFunction("foo1");
llvm::Argument *LLVMArg0 = LLVMF0->getArg(0);
llvm::Argument *LLVMArg1 = LLVMF0->getArg(1);

sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(LLVMF);
sandboxir::Function *F0 = Ctx.createFunction(LLVMF0);
sandboxir::Function *F1 = Ctx.createFunction(LLVMF1);

// Check getIterator().
{
auto It0 = F0->getIterator();
auto It1 = F1->getIterator();
EXPECT_EQ(&*It0, F0);
EXPECT_EQ(&*It1, F1);
EXPECT_EQ(std::next(It0), It1);
EXPECT_EQ(std::prev(It1), It0);
EXPECT_EQ(&*std::next(It0), F1);
EXPECT_EQ(&*std::prev(It1), F0);
}
// Check getReverseIterator().
{
auto RevIt0 = F0->getReverseIterator();
auto RevIt1 = F1->getReverseIterator();
EXPECT_EQ(&*RevIt0, F0);
EXPECT_EQ(&*RevIt1, F1);
EXPECT_EQ(std::prev(RevIt0), RevIt1);
EXPECT_EQ(std::next(RevIt1), RevIt0);
EXPECT_EQ(&*std::prev(RevIt0), F1);
EXPECT_EQ(&*std::next(RevIt1), F0);
}

// Check F arguments
EXPECT_EQ(F->arg_size(), 2u);
EXPECT_FALSE(F->arg_empty());
EXPECT_EQ(F->getArg(0), Ctx.getValue(LLVMArg0));
EXPECT_EQ(F->getArg(1), Ctx.getValue(LLVMArg1));
EXPECT_EQ(F0->arg_size(), 2u);
EXPECT_FALSE(F0->arg_empty());
EXPECT_EQ(F0->getArg(0), Ctx.getValue(LLVMArg0));
EXPECT_EQ(F0->getArg(1), Ctx.getValue(LLVMArg1));

// Check F.begin(), F.end(), Function::iterator
llvm::BasicBlock *LLVMBB = &*LLVMF->begin();
for (sandboxir::BasicBlock &BB : *F) {
llvm::BasicBlock *LLVMBB = &*LLVMF0->begin();
for (sandboxir::BasicBlock &BB : *F0) {
EXPECT_EQ(&BB, Ctx.getValue(LLVMBB));
LLVMBB = LLVMBB->getNextNode();
}
Expand All @@ -1232,17 +1339,17 @@ define void @foo(i32 %arg0, i32 %arg1) {
// Check F.dumpNameAndArgs()
std::string Buff;
raw_string_ostream BS(Buff);
F->dumpNameAndArgs(BS);
EXPECT_EQ(Buff, "void @foo(i32 %arg0, i32 %arg1)");
F0->dumpNameAndArgs(BS);
EXPECT_EQ(Buff, "void @foo0(i32 %arg0, i32 %arg1)");
}
{
// Check F.dump()
std::string Buff;
raw_string_ostream BS(Buff);
BS << "\n";
F->dumpOS(BS);
F0->dumpOS(BS);
EXPECT_EQ(Buff, R"IR(
void @foo(i32 %arg0, i32 %arg1) {
void @foo0(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1 ; SB4. (Br)

Expand Down
32 changes: 32 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,38 @@ define void @foo() {
EXPECT_EQ(GV->getVisibility(), OrigVisibility);
}

TEST_F(TrackerTest, GlobalIFuncSetters) {
parseIR(C, R"IR(
declare external void @bar()
@ifunc = ifunc void(), ptr @foo
define void @foo() {
call void @ifunc()
call void @bar()
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Call0 = cast<sandboxir::CallInst>(&*It++);
auto *Call1 = cast<sandboxir::CallInst>(&*It++);
// Check classof(), creation.
auto *IFunc = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
auto *Bar = cast<sandboxir::Function>(Call1->getCalledOperand());
// Check setResolver().
auto *OrigResolver = IFunc->getResolver();
auto *NewResolver = Bar;
EXPECT_NE(NewResolver, OrigResolver);
Ctx.save();
IFunc->setResolver(NewResolver);
EXPECT_EQ(IFunc->getResolver(), NewResolver);
Ctx.revert();
EXPECT_EQ(IFunc->getResolver(), OrigResolver);
}

TEST_F(TrackerTest, SetVolatile) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, i8 %val) {
Expand Down

0 comments on commit ae3e825

Please sign in to comment.