//===- Inliner.cpp - Pass to inline function calls ------------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements a basic inlining algorithm that operates bottom up over // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more // incremental propagation of inlining decisions from the leafs to the roots of // the callgraph. // //===----------------------------------------------------------------------===// #include "mlir/Analysis/CallGraph.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Module.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SCCIterator.h" using namespace mlir; //===----------------------------------------------------------------------===// // CallGraph traversal //===----------------------------------------------------------------------===// /// Run a given transformation over the SCCs of the callgraph in a bottom up /// traversal. static void runTransformOnCGSCCs( const CallGraph &cg, function_ref)> sccTransformer) { for (auto cgi = llvm::scc_begin(&cg); !cgi.isAtEnd(); ++cgi) sccTransformer(*cgi); } namespace { /// This struct represents a resolved call to a given callgraph node. Given that /// the call does not actually contain a direct reference to the /// Region(CallGraphNode) that it is dispatching to, we need to resolve them /// explicitly. struct ResolvedCall { ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) : call(call), targetNode(targetNode) {} CallOpInterface call; CallGraphNode *targetNode; }; } // end anonymous namespace /// Collect all of the callable operations within the given range of blocks. If /// `traverseNestedCGNodes` is true, this will also collect call operations /// inside of nested callgraph nodes. static void collectCallOps(llvm::iterator_range blocks, CallGraph &cg, SmallVectorImpl &calls, bool traverseNestedCGNodes) { SmallVector worklist; auto addToWorklist = [&](llvm::iterator_range blocks) { for (Block &block : blocks) worklist.push_back(&block); }; addToWorklist(blocks); while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast(op)) { CallGraphNode *node = cg.resolveCallable(call.getCallableForCallee(), &op); if (!node->isExternal()) calls.emplace_back(call, node); continue; } // If this is not a call, traverse the nested regions. If // `traverseNestedCGNodes` is false, then don't traverse nested call graph // regions. for (auto &nestedRegion : op.getRegions()) if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) addToWorklist(nestedRegion); } } } //===----------------------------------------------------------------------===// // Inliner //===----------------------------------------------------------------------===// namespace { /// This class provides a specialization of the main inlining interface. struct Inliner : public InlinerInterface { Inliner(MLIRContext *context, CallGraph &cg) : InlinerInterface(context), cg(cg) {} /// Process a set of blocks that have been inlined. This callback is invoked /// *before* inlined terminator operations have been processed. void processInlinedBlocks( llvm::iterator_range inlinedBlocks) final { collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); } /// The current set of call instructions to consider for inlining. SmallVector calls; /// The callgraph being operated on. CallGraph &cg; }; } // namespace /// Returns true if the given call should be inlined. static bool shouldInline(ResolvedCall &resolvedCall) { // Don't allow inlining terminator calls. We currently don't support this // case. if (resolvedCall.call.getOperation()->isKnownTerminator()) return false; // Don't allow inlining if the target is an ancestor of the call. This // prevents inlining recursively. if (resolvedCall.targetNode->getCallableRegion()->isAncestor( resolvedCall.call.getParentRegion())) return false; // Otherwise, inline. return true; } /// Attempt to inline calls within the given scc. static void inlineCallsInSCC(Inliner &inliner, ArrayRef currentSCC) { CallGraph &cg = inliner.cg; auto &calls = inliner.calls; // Collect all of the direct calls within the nodes of the current SCC. We // don't traverse nested callgraph nodes, because they are handled separately // likely within a different SCC. for (auto *node : currentSCC) { if (!node->isExternal()) collectCallOps(*node->getCallableRegion(), cg, calls, /*traverseNestedCGNodes=*/false); } if (calls.empty()) return; // Try to inline each of the call operations. Don't cache the end iterator // here as more calls may be added during inlining. for (unsigned i = 0; i != calls.size(); ++i) { ResolvedCall &it = calls[i]; if (!shouldInline(it)) continue; CallOpInterface call = it.call; Region *targetRegion = it.targetNode->getCallableRegion(); LogicalResult inlineResult = inlineCall( inliner, call, cast(targetRegion->getParentOp()), targetRegion); if (failed(inlineResult)) continue; // If the inlining was successful, then erase the call. call.erase(); } calls.clear(); } //===----------------------------------------------------------------------===// // InlinerPass //===----------------------------------------------------------------------===// // TODO(riverriddle) This pass should currently only be used for basic testing // of inlining functionality. namespace { struct InlinerPass : public OperationPass { void runOnOperation() override { CallGraph &cg = getAnalysis(); Inliner inliner(&getContext(), cg); // Run the inline transform in post-order over the SCCs in the callgraph. runTransformOnCGSCCs(cg, [&](ArrayRef scc) { inlineCallsInSCC(inliner, scc); }); } }; } // end anonymous namespace static PassRegistration pass("inline", "Inline function calls");