Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 86 additions & 8 deletions src/passes/GlobalEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,87 @@ struct FuncInfo {
std::unordered_set<HeapType> indirectCalledTypes;
};

// Only funcs that are referenced may be the target of an indirect call. A
// function is referenced if:
// - It appears in a ref.func expression (this includes `elem` statements
// because of how our IR is represented).
// - It's exported, because it may flow back to us as a reference.
//
// If a function doesn't meet any of these criteria, it can't be the target of
// an indirect call and we don't need to include its effects in indirect calls.
std::unordered_set<Function*> getReferencedFuncs(Module& module,
PassRunner& passRunner) {
struct AddressedFuncsWalker : WalkerPass<PostWalker<AddressedFuncsWalker>> {
// For each function, which functions are referenced in its body.
// The key for `nullptr` contains references that are not in a function
// (e.g. `elem` segments).
std::unordered_map<Function*, std::unordered_set<Function*>>&
allReferencedFuncs;
// Points to `allReferencedFuncs`.
std::unordered_set<Function*>* referencedFuncs = nullptr;

AddressedFuncsWalker(
std::unordered_map<Function*, std::unordered_set<Function*>>&
allReferencedFuncs)
: allReferencedFuncs(allReferencedFuncs),
referencedFuncs(&allReferencedFuncs[nullptr]) {}

std::unique_ptr<Pass> create() override {
return std::make_unique<AddressedFuncsWalker>(allReferencedFuncs);
}

bool isFunctionParallel() override { return true; }

bool modifiesBinaryenIR() override { return false; }

void doWalkFunction(Function* func) {
referencedFuncs = &allReferencedFuncs.at(func);
walk(func->body);
}

void visitRefFunc(RefFunc* refFunc) {
referencedFuncs->insert(getModule()->getFunction(refFunc->func));
}
};

std::unordered_map<Function*, std::unordered_set<Function*>>
allReferencedFuncs;
for (auto& func : module.functions) {
allReferencedFuncs[func.get()];
}

AddressedFuncsWalker walker(allReferencedFuncs);
walker.run(&passRunner, &module);
walker.runOnModuleCode(&passRunner, &module);

std::unordered_set<Function*> mergedReferencedFuncs;
for (auto& [_, referencedFuncs] : allReferencedFuncs) {
mergedReferencedFuncs.merge(referencedFuncs);
}

for (const auto& export_ : module.exports) {
if (export_->kind != ExternalKind::Function) {
continue;
}

mergedReferencedFuncs.insert(
module.getFunction(*export_->getInternalName()));
}

return mergedReferencedFuncs;
}

std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
const PassOptions& passOptions) {
ModuleUtils::ParallelFunctionAnalysis<FuncInfo> analysis(
module, [&](Function* func, FuncInfo& funcInfo) {
if (func->imported()) {
// Imports can do anything, so we need to assume the worst anyhow,
// which is the same as not specifying any effects for them in the
// map (which we do by not setting funcInfo.effects).
// Imports can do almost anything, so we need to assume the worst
// anyhow, which is the same as not specifying any effects for them in
// the map (which we do by not setting funcInfo.effects).
//
// TODO: We can be more precise here since imports can't mutate
// globals/tables/memories that aren't imported or exported.
return;
}

Expand Down Expand Up @@ -144,6 +217,7 @@ using CallGraph =

CallGraph buildCallGraph(const Module& module,
const std::map<Function*, FuncInfo>& funcInfos,
const std::unordered_set<Function*>& referencedFuncs,
WorldMode worldMode) {
CallGraph callGraph;
if (worldMode == WorldMode::Open) {
Expand Down Expand Up @@ -179,16 +253,18 @@ CallGraph buildCallGraph(const Module& module,
}

// Type -> Function
callGraph[caller->type.getHeapType()].insert(caller);
if (referencedFuncs.contains(caller)) {
callGraph[caller->type.getHeapType()].insert(caller);
}
}

// Type -> Type
// Do a DFS up the type heirarchy for all function implementations.
// Do a DFS up the type hierarchy for all function implementations.
// We are essentially walking up each supertype chain and adding edges from
// super -> subtype, but doing it via DFS to avoid repeated work.
Graph superTypeGraph(allFunctionTypes.begin(),
allFunctionTypes.end(),
[&callGraph](auto&& push, HeapType t) {
[&callGraph](const auto& push, HeapType t) {
// Not needed except that during lookup we expect the
// key to exist.
callGraph[t];
Expand Down Expand Up @@ -350,8 +426,10 @@ struct GenerateGlobalEffects : public Pass {
std::map<Function*, FuncInfo> funcInfos =
analyzeFuncs(*module, getPassOptions());

auto callGraph =
buildCallGraph(*module, funcInfos, getPassOptions().worldMode);
auto referencedFuncs = getReferencedFuncs(*module, *getPassRunner());

auto callGraph = buildCallGraph(
*module, funcInfos, referencedFuncs, getPassOptions().worldMode);

propagateEffects(*module,
getPassOptions(),
Expand Down
Loading
Loading