Simple CPU runner

This implements a simple CPU runner based on LLVM Orc JIT.  The base
functionality is provided by the ExecutionEngine class that compiles and links
the module, and provides an interface for obtaining function pointers to the
JIT-compiled MLIR functions and for invoking those functions directly.  Since
function pointers need to be cast to the correct pointer type, the
ExecutionEngine wraps LLVM IR functions obtained from MLIR into a helper
function with the common signature `void (void **)` where the single argument
is interpreted as a list of pointers to the actual arguments passed to the
function, eventually followed by a pointer to the result of the function.
Additionally, the ExecutionEngine is set up to resolve library functions to
those available in the current process, enabling support for, e.g., simple C
library calls.

For integration purposes, this also provides a simplistic runtime for memref
descriptors as expected by the LLVM IR code produced by MLIR translation.  In
particular, memrefs are transformed into LLVM structs (can be mapped to C
structs) with a pointer to the data, followed by dynamic sizes.  This
implementation only supports statically-shaped memrefs of type float, but can
be extended if necessary.

Provide a binary for the runner and a test that exercises it.

PiperOrigin-RevId: 230876363
This commit is contained in:
Alex Zinenko 2019-01-25 03:16:06 -08:00 committed by jpienaar
parent 5c5739d42b
commit 5a4403787f
6 changed files with 744 additions and 0 deletions

View file

@ -0,0 +1,92 @@
//===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file provides a JIT-backed execution engine for MLIR modules.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
#define MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_
#include "mlir/Support/LLVM.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include <memory>
namespace llvm {
template <typename T> class Expected;
}
namespace mlir {
class Module;
namespace impl {
class OrcJIT;
} // end namespace impl
/// JIT-backed execution engine for MLIR modules. Assumes the module can be
/// converted to LLVM IR. For each function, creates a wrapper function with
/// the fixed interface
///
///     void _mlir_funcName(void **)
///
/// where the only argument is interpreted as a list of pointers to the actual
/// arguments of the function, followed by a pointer to the result. This allows
/// the engine to provide the caller with a generic function pointer that can
/// be used to invoke the JIT-compiled function.
class ExecutionEngine {
public:
  ~ExecutionEngine();

  /// Creates an execution engine for the given module.
  static llvm::Expected<std::unique_ptr<ExecutionEngine>> create(Module *m);

  /// Looks up a packed-argument function with the given name and returns a
  /// pointer to it. Propagates errors in case of failure.
  llvm::Expected<void (*)(void **)> lookup(StringRef name) const;

  /// Invokes the function with the given name passing it the list of arguments.
  /// The arguments are accepted by lvalue-reference since the packed function
  /// interface expects a list of non-null pointers.
  template <typename... Args>
  llvm::Error invoke(StringRef name, Args &... args);

private:
  // FIXME: we may want a `unique_ptr` here if impl::OrcJIT decides to provide
  // a default constructor.
  // Initialized to null so that the destructor is well-defined even when
  // `create` fails before assigning the JIT instance.
  impl::OrcJIT *jit = nullptr;
  llvm::LLVMContext llvmContext;
};
template <typename... Args>
llvm::Error ExecutionEngine::invoke(StringRef name, Args &... args) {
  // Resolve the packed-interface wrapper; forward any lookup failure.
  auto expectedFPtr = lookup(name);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  void (*fn)(void **) = *expectedFPtr;

  // Build the type-erased pointer list expected by the wrapper: one entry
  // per argument, each pointing at the caller's lvalue.
  llvm::SmallVector<void *, 8> argPointers{static_cast<void *>(&args)...};
  (*fn)(argPointers.data());
  return llvm::Error::success();
}
} // end namespace mlir
#endif // MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_

View file

@ -0,0 +1,55 @@
//===- MemRefUtils.h - MLIR runtime utilities for memrefs -------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a set of utilities for working with objects of memref type in a JIT
// context using the MLIR execution engine.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
#include "mlir/Support/LLVM.h"
namespace llvm {
template <typename T> class Expected;
}
namespace mlir {
class Function;
/// Simple memref descriptor class compatible with the ABI of functions emitted
/// by MLIR to LLVM IR conversion for statically-shaped memrefs of float type.
struct StaticFloatMemRef {
  // Pointer to the start of the data buffer. Statically-shaped descriptors
  // carry no size fields: the shape is known from the memref type.
  float *data;
};

/// Given an MLIR function that takes only statically-shaped memrefs with
/// element type f32, allocate the memref descriptor and the data storage for
/// each of the arguments, initialize the storage with `initialValue`, and
/// return a list of type-erased descriptor pointers. If the function has a
/// (single) result, one extra descriptor with no data storage is appended.
/// Ownership passes to the caller; release with `freeMemRefArguments`.
llvm::Expected<SmallVector<void *, 8>>
allocateMemRefArguments(const Function *func, float initialValue = 0.0);

/// Free a list of type-erased descriptors to statically-shaped memrefs with
/// element type f32.
void freeMemRefArguments(ArrayRef<void *> args);
} // namespace mlir
#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_

View file

@ -0,0 +1,299 @@
//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the execution engine for MLIR modules based on LLVM Orc
// JIT engine.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"
using namespace mlir;
using llvm::Error;
using llvm::Expected;
namespace {
// Memory manager for the JIT's objectLayer. Its main goal is to fall back to
// resolving functions in the current process if they cannot be resolved in the
// JIT-compiled modules.
class MemoryManager : public llvm::SectionMemoryManager {
public:
  explicit MemoryManager(llvm::orc::ExecutionSession &execSession)
      : session(execSession) {}

  // Resolve the named symbol. First, try looking it up in the main library of
  // the execution session. If there is no such symbol, try looking it up in
  // the current process (for example, if it is a standard library function).
  // Return `nullptr` if lookup fails.
  llvm::JITSymbol findSymbol(const std::string &name) override {
    auto mainLibSymbol = session.lookup({&session.getMainJITDylib()}, name);
    if (mainLibSymbol)
      return mainLibSymbol.get();
    // The lookup failure is expected for symbols that live in the host
    // process; consume the error so the unchecked `Expected` does not abort
    // in assertion-enabled LLVM builds.
    llvm::consumeError(mainLibSymbol.takeError());
    auto address = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
    if (!address) {
      llvm::errs() << "Could not look up: " << name << '\n';
      return nullptr;
    }
    return llvm::JITSymbol(address, llvm::JITSymbolFlags::Exported);
  }

private:
  // The session whose main dylib is consulted before the host process.
  llvm::orc::ExecutionSession &session;
};
} // end anonymous namespace
namespace mlir {
namespace impl {
// Simple layered Orc JIT compilation engine.
class OrcJIT {
public:
  // Construct a JIT engine for the target host defined by `machineBuilder`,
  // using the data layout provided as `dataLayout`.
  // Setup the object layer to use our custom memory manager in order to
  // resolve calls to library functions present in the process.
  OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
         llvm::DataLayout layout)
      : objectLayer(
            session,
            // One MemoryManager per object; it falls back to in-process
            // symbol resolution (see MemoryManager::findSymbol above).
            [this]() { return llvm::make_unique<MemoryManager>(session); }),
        compileLayer(
            session, objectLayer,
            llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))),
        dataLayout(layout), mangler(session, this->dataLayout),
        threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
    // Additionally expose symbols of the current process (e.g. C library
    // functions) to JIT-compiled code via a dylib generator.
    session.getMainJITDylib().setGenerator(
        cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
            layout)));
  }

  // Create a JIT engine for the current host. Detects the host target and its
  // default data layout; propagates any detection failure as an llvm::Error.
  static Expected<std::unique_ptr<OrcJIT>> createDefault() {
    auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
    if (!machineBuilder)
      return machineBuilder.takeError();
    auto dataLayout = machineBuilder->getDefaultDataLayoutForTarget();
    if (!dataLayout)
      return dataLayout.takeError();
    return llvm::make_unique<OrcJIT>(std::move(*machineBuilder),
                                     std::move(*dataLayout));
  }

  // Add an LLVM module to the main library managed by the JIT engine.
  Error addModule(std::unique_ptr<llvm::Module> M) {
    return compileLayer.add(
        session.getMainJITDylib(),
        llvm::orc::ThreadSafeModule(std::move(M), threadSafeCtx));
  }

  // Lookup a symbol in the main library managed by the JIT engine. The name
  // is mangled according to the data layout before the lookup.
  Expected<llvm::JITEvaluatedSymbol> lookup(StringRef Name) {
    return session.lookup({&session.getMainJITDylib()}, mangler(Name.str()));
  }

private:
  llvm::orc::ExecutionSession session;
  llvm::orc::RTDyldObjectLinkingLayer objectLayer;
  llvm::orc::IRCompileLayer compileLayer;
  llvm::DataLayout dataLayout;
  llvm::orc::MangleAndInterner mangler;
  // Single thread-safe context shared by all modules added to this JIT.
  llvm::orc::ThreadSafeContext threadSafeCtx;
};
} // end namespace impl
} // namespace mlir
// Wrap a string into an llvm::StringError carrying the generic
// "inconvertible" error code.
static inline Error make_string_error(const llvm::Twine &message) {
  std::string text = message.str();
  return llvm::make_error<llvm::StringError>(text,
                                             llvm::inconvertibleErrorCode());
}
// Given a list of PassInfo coming from a higher level, creates the passes to
// run as an owning vector and appends the extra required passes to lower to
// LLVM IR. Currently, these extra passes are:
//   - constant folding
//   - CSE
//   - canonicalization
//   - affine lowering
static std::vector<std::unique_ptr<mlir::Pass>>
getDefaultPasses(const std::vector<const mlir::PassInfo *> &mlirPassInfoList) {
  std::vector<std::unique_ptr<mlir::Pass>> passes;
  passes.reserve(mlirPassInfoList.size() + 4);

  // Instantiate each user-selected pass first, in order.
  for (const mlir::PassInfo *info : mlirPassInfoList)
    passes.emplace_back(info->createPass());

  // Then append the fixed lowering pipeline.
  passes.emplace_back(mlir::createConstantFoldPass());
  passes.emplace_back(mlir::createCSEPass());
  passes.emplace_back(mlir::createCanonicalizerPass());
  passes.emplace_back(mlir::createLowerAffinePass());
  return passes;
}
// Run the passes sequentially on the given module.
// Returns true (failure) immediately if any pass fails or if the module no
// longer verifies after a pass; returns false on success.
// (The original comment said "Return `nullptr`", but the function returns
// bool.)
static bool runPasses(const std::vector<std::unique_ptr<mlir::Pass>> &passes,
                      Module *module) {
  for (const auto &pass : passes) {
    mlir::PassResult result = pass->runOnModule(module);
    // A truthy `verify()` is treated as verification failure here.
    if (result == mlir::PassResult::Failure || module->verify()) {
      llvm::errs() << "Pass failed\n";
      return true;
    }
  }
  return false;
}
// Setup LLVM target triple and data layout from the current machine.
// Returns true on failure (no target registered for the host triple).
static bool setupTargetTriple(llvm::Module *llvmModule) {
  // Setup the machine properties from the current architecture.
  auto targetTriple = llvm::sys::getDefaultTargetTriple();
  std::string errorMessage;
  auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
  if (!target) {
    llvm::errs() << "NO target: " << errorMessage << "\n";
    return true;
  }
  // `createTargetMachine` returns an owning raw pointer; take ownership so
  // the temporary machine is released once the data layout is extracted
  // (the previous code leaked it).
  std::unique_ptr<llvm::TargetMachine> machine(
      target->createTargetMachine(targetTriple, "generic", "", {}, {}));
  llvmModule->setDataLayout(machine->createDataLayout());
  llvmModule->setTargetTriple(targetTriple);
  return false;
}
// Name of the packed-interface wrapper for `name`; the prefix keeps wrappers
// from colliding with the original symbols.
static std::string makePackedFunctionName(StringRef name) {
  std::string wrapped("_mlir_");
  wrapped += name.str();
  return wrapped;
}
// For each function in the LLVM module, define an interface function that wraps
// all the arguments of the original function and all its results into an i8**
// pointer to provide a unified invocation interface.
void packFunctionArguments(llvm::Module *module) {
  auto &ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  // Track wrappers created by this loop so that the iteration below does not
  // wrap a wrapper when it reaches the functions it has itself inserted.
  llvm::DenseSet<llvm::Function *> interfaceFunctions;
  for (auto &func : module->getFunctionList()) {
    // Skip external declarations: there is no body to call into.
    if (func.isDeclaration()) {
      continue;
    }
    if (interfaceFunctions.count(&func)) {
      continue;
    }

    // Given a function `foo(<...>)`, define the interface function
    // `mlir_foo(i8**)`.
    auto newType = llvm::FunctionType::get(
        builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
        /*isVarArg=*/false);
    auto newName = makePackedFunctionName(func.getName());
    llvm::Constant *funcCst = module->getOrInsertFunction(newName, newType);
    llvm::Function *interfaceFunc = llvm::cast<llvm::Function>(funcCst);
    interfaceFunctions.insert(interfaceFunc);

    // Extract the arguments from the type-erased argument list and cast them to
    // the proper types.
    auto bb = llvm::BasicBlock::Create(ctx);
    bb->insertInto(interfaceFunc);
    builder.SetInsertPoint(bb);
    llvm::Value *argList = interfaceFunc->arg_begin();
    llvm::SmallVector<llvm::Value *, 8> args;
    args.reserve(llvm::size(func.args()));
    for (auto &indexedArg : llvm::enumerate(func.args())) {
      // argList[i] is an i8*; load it, bitcast to a pointer to the i-th
      // parameter type, then load the actual argument value.
      llvm::Value *argIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
      llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex);
      llvm::Value *argPtr = builder.CreateLoad(argPtrPtr);
      argPtr = builder.CreateBitCast(
          argPtr, indexedArg.value().getType()->getPointerTo());
      llvm::Value *arg = builder.CreateLoad(argPtr);
      args.push_back(arg);
    }

    // Call the implementation function with the extracted arguments.
    llvm::Value *result = builder.CreateCall(&func, args);

    // Assuming the result is one value, potentially of type `void`.
    if (!result->getType()->isVoidTy()) {
      // The result slot is the argList entry right after the last argument.
      llvm::Value *retIndex = llvm::Constant::getIntegerValue(
          builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
      llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex);
      llvm::Value *retPtr = builder.CreateLoad(retPtrPtr);
      retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
      builder.CreateStore(result, retPtr);
    }

    // The interface function returns void.
    builder.CreateRetVoid();
  }
}
ExecutionEngine::~ExecutionEngine() {
  // `delete` on a null pointer is a no-op, so no guard is required.
  delete jit;
}
// Create an execution engine for `m`: run the default lowering passes,
// translate to LLVM IR, attach the host triple/data layout, emit the packed
// `void(void**)` wrappers, and hand the module to the Orc JIT.
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(Module *m) {
  auto engine = llvm::make_unique<ExecutionEngine>();
  // Ensure the destructor sees a well-defined pointer even if we bail out on
  // any of the error paths below before the JIT is assigned; otherwise
  // `delete jit` would act on an indeterminate value.
  engine->jit = nullptr;

  auto expectedJIT = impl::OrcJIT::createDefault();
  if (!expectedJIT)
    return expectedJIT.takeError();

  if (runPasses(getDefaultPasses({}), m))
    return make_string_error("passes failed");

  auto llvmModule = convertModuleToLLVMIR(*m, engine->llvmContext);
  if (!llvmModule)
    return make_string_error("could not convert to LLVM IR");

  // FIXME: the triple should be passed to the translation or dialect conversion
  // instead of this. Currently, the LLVM module created above has no triple
  // associated with it.
  setupTargetTriple(llvmModule.get());
  packFunctionArguments(llvmModule.get());

  engine->jit = std::move(*expectedJIT).release();
  if (auto err = engine->jit->addModule(std::move(llvmModule)))
    return std::move(err);
  return engine;
}
// Resolve the packed-interface wrapper emitted by packFunctionArguments for
// `name` and return it as a `void(void**)` function pointer.
Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
  auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
  if (!expectedSymbol)
    return expectedSymbol.takeError();

  // Reinterpret the raw JIT address as the packed-interface function pointer.
  void (*fn)(void **) =
      reinterpret_cast<void (*)(void **)>(expectedSymbol->getAddress());
  if (!fn)
    return make_string_error("looked up function is null");
  return fn;
}

View file

@ -0,0 +1,106 @@
//===- MemRefUtils.cpp - MLIR runtime utilities for memrefs ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a set of utilities for working with objects of memref type in a JIT
// context using the MLIR execution engine.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/MemRefUtils.h"

#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Error.h"

#include <cstdlib>
#include <functional>
#include <numeric>
using namespace mlir;
// Wrap a string into an llvm::StringError with the generic "inconvertible"
// error code.
// NOTE(review): this helper is duplicated in ExecutionEngine.cpp and
// mlir-cpu-runner.cpp; consider hoisting a single definition into a header.
static inline llvm::Error make_string_error(const llvm::Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}
// Allocate a StaticFloatMemRef descriptor for `type`, which must be a
// statically-shaped memref with f32 elements. If `allocateData` is set, also
// allocate the data buffer and fill it with `initialValue`; otherwise leave
// `data` null. Returns an error for unsupported types or failed allocations.
// Ownership of the descriptor (and its data) passes to the caller; release
// with `free` (see freeMemRefArguments).
static llvm::Expected<StaticFloatMemRef *>
allocMemRefDescriptor(Type type, bool allocateData = true,
                      float initialValue = 0.0) {
  auto memRefType = type.dyn_cast<MemRefType>();
  if (!memRefType)
    return make_string_error("non-memref argument not supported");
  if (memRefType.getNumDynamicDims() != 0)
    return make_string_error("memref with dynamic shapes not supported");

  auto elementType = memRefType.getElementType();
  if (!elementType.isF32())
    return make_string_error(
        "memref with element other than f32 not supported");

  // malloc/free is used deliberately to match freeMemRefArguments.
  auto *descriptor =
      reinterpret_cast<StaticFloatMemRef *>(malloc(sizeof(StaticFloatMemRef)));
  if (!descriptor)
    return make_string_error("failed to allocate memref descriptor");
  if (!allocateData) {
    descriptor->data = nullptr;
    return descriptor;
  }

  // Element count is the product of all (static) dimensions; 1 for rank 0.
  auto shape = memRefType.getShape();
  int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
                                 std::multiplies<int64_t>());
  descriptor->data = reinterpret_cast<float *>(malloc(sizeof(float) * size));
  if (!descriptor->data) {
    free(descriptor);
    return make_string_error("failed to allocate memref data");
  }
  for (int64_t i = 0; i < size; ++i) {
    descriptor->data[i] = initialValue;
  }
  return descriptor;
}
// Allocate one data-backed descriptor per function argument, plus one
// data-less descriptor for the (at most one) result, and return the list of
// type-erased descriptor pointers.
llvm::Expected<SmallVector<void *, 8>>
mlir::allocateMemRefArguments(const Function *func, float initialValue) {
  SmallVector<void *, 8> descriptors;
  descriptors.reserve(func->getNumArguments());

  // Arguments get backing storage initialized to `initialValue`.
  for (const auto &arg : func->getArguments()) {
    auto expected = allocMemRefDescriptor(arg->getType(),
                                          /*allocateData=*/true, initialValue);
    if (!expected)
      return expected.takeError();
    descriptors.push_back(*expected);
  }

  // The result descriptor gets no data storage of its own.
  if (func->getType().getNumResults() > 1)
    return make_string_error("functions with more than 1 result not supported");
  for (Type resType : func->getType().getResults()) {
    auto expected = allocMemRefDescriptor(resType, /*allocateData=*/false);
    if (!expected)
      return expected.takeError();
    descriptors.push_back(*expected);
  }
  return descriptors;
}
// Because the function can return the same descriptor as passed in arguments,
// make sure each underlying data buffer is freed at most once.
void mlir::freeMemRefArguments(ArrayRef<void *> args) {
  llvm::DenseSet<void *> freedData;
  for (void *opaqueDescriptor : args) {
    auto *descriptor = reinterpret_cast<StaticFloatMemRef *>(opaqueDescriptor);
    // insert() reports whether the pointer is newly seen; free only then.
    if (freedData.insert(descriptor->data).second)
      free(descriptor->data);
    free(descriptor);
  }
}

View file

@ -0,0 +1,30 @@
// RUN: mlir-cpu-runner %s | FileCheck %s
// RUN: mlir-cpu-runner -e foo -init-value 1000 %s | FileCheck -check-prefix=NOMAIN %s
// Declaration of the C library function; resolved in the host process by the
// execution engine's fallback symbol lookup.
func @fabsf(f32) -> f32

// Adds -420.0 to the two (runner-initialized) elements of %a and stores the
// absolute value of the sum into %b.
func @main(%a : memref<2xf32>, %b : memref<1xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = constant -420.0 : f32
  %1 = load %a[%c0] : memref<2xf32>
  %2 = load %a[%c1] : memref<2xf32>
  %3 = addf %0, %1 : f32
  %4 = addf %3, %2 : f32
  %5 = call @fabsf(%4) : (f32) -> f32
  store %5, %b[%c0] : memref<1xf32>
  return
}
// CHECK: 0.000000e+00 0.000000e+00
// CHECK-NEXT: 4.200000e+02
// Adds 1234.0 to the single element of %a in place and returns the same
// memref; run with -init-value 1000, so both printed values are 2234.0.
func @foo(%a : memref<1x1xf32>) -> memref<1x1xf32> {
  %c0 = constant 0 : index
  %0 = constant 1234.0 : f32
  %1 = load %a[%c0, %c0] : memref<1x1xf32>
  %2 = addf %1, %0 : f32
  store %2, %a[%c0, %c0] : memref<1x1xf32>
  return %a : memref<1x1xf32>
}
// NOMAIN: 2.234000e+03
// NOMAIN-NEXT: 2.234000e+03

View file

@ -0,0 +1,162 @@
//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a command line utility that executes an MLIR file on the CPU by
// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
using llvm::Error;
// Input MLIR file; "-" (the default) reads from stdin.
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                                llvm::cl::desc("<input file>"),
                                                llvm::cl::init("-"));
// Value used to fill the data of every memref argument (parsed with stof in
// compileAndExecute).
static llvm::cl::opt<std::string>
    initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
              llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
// Name of the function to JIT-compile and invoke; defaults to `main`.
static llvm::cl::opt<std::string>
    mainFuncName("e", llvm::cl::desc("The function to be called"),
                 llvm::cl::value_desc("<function name>"),
                 llvm::cl::init("main"));
// Parse the given file as an MLIR module. Reports errors on stderr and
// returns nullptr if the file cannot be opened or parsed.
static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
                                              MLIRContext *context) {
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
  Module *parsed = parseSourceFile(sourceMgr, context);
  return std::unique_ptr<Module>(parsed);
}
// Initialize the relevant subsystems of LLVM. Only the native target and its
// assembly printer are needed, since the JIT compiles for the host machine.
static void initializeLLVM() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
}
// Wrap a string into an llvm::StringError with the generic "inconvertible"
// error code.
// NOTE(review): duplicated in ExecutionEngine.cpp and MemRefUtils.cpp;
// consider sharing a single definition.
static inline Error make_string_error(const llvm::Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}
// Print all elements of the memref descriptor `val` (whose MLIR type is `t`,
// expected to be a statically-shaped MemRefType) separated by spaces and
// followed by a newline.
static void printOneMemRef(Type t, void *val) {
  auto *descriptor = reinterpret_cast<StaticFloatMemRef *>(val);
  auto shape = t.cast<MemRefType>().getShape();
  // Element count is the product of all dimensions; 1 for rank 0.
  int64_t numElements = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
  for (int64_t i = 0; i < numElements; ++i)
    llvm::outs() << descriptor->data[i] << ' ';
  llvm::outs() << '\n';
}
// Print the memrefs of the function: `args` holds the argument descriptors
// first, followed by the result descriptors.
static void printMemRefArguments(const Function *func, ArrayRef<void *> args) {
  auto numArgs = func->getNumArguments();

  // Argument descriptors occupy the front of `args`.
  for (const auto &pair :
       llvm::zip(func->getArguments(), args.take_front(numArgs)))
    printOneMemRef(std::get<0>(pair)->getType(), std::get<1>(pair));

  // Result descriptors follow the arguments.
  for (const auto &pair :
       llvm::zip(func->getType().getResults(), args.drop_front(numArgs)))
    printOneMemRef(std::get<0>(pair), std::get<1>(pair));
}
// JIT-compile `module` and execute the function named `entryPoint`:
// allocate memref arguments filled with the -init-value option, invoke the
// packed `void(void**)` wrapper, print the argument/result memrefs, and free
// the allocations. Errors are propagated as llvm::Error.
static Error compileAndExecute(Module *module, StringRef entryPoint) {
  Function *mainFunction = module->getNamedFunction(entryPoint);
  // The entry point must be defined (have a body), not merely declared.
  if (!mainFunction || mainFunction->getBlocks().empty()) {
    return make_string_error("entry point not found");
  }
  // NOTE(review): std::stof throws on malformed/out-of-range input, which
  // escapes this Error-based interface — consider validating `initValue`.
  float init = std::stof(initValue.getValue());
  auto expectedArguments = allocateMemRefArguments(mainFunction, init);
  if (!expectedArguments)
    return expectedArguments.takeError();
  auto expectedEngine = mlir::ExecutionEngine::create(module);
  if (!expectedEngine)
    return expectedEngine.takeError();
  auto engine = std::move(*expectedEngine);
  auto expectedFPtr = engine->lookup(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();
  void (*fptr)(void **) = *expectedFPtr;
  // Invoke through the packed interface: one pointer per argument/result.
  (*fptr)(expectedArguments->data());
  printMemRefArguments(mainFunction, *expectedArguments);
  freeMemRefArguments(*expectedArguments);
  return Error::success();
}
// Driver entry point: parse options and the input MLIR file, then compile and
// run the requested function, reporting any error on stderr.
int main(int argc, char **argv) {
  llvm::PrettyStackTraceProgram stackTracePrinter(argc, argv);
  llvm::InitLLVM llvmSetup(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
  initializeLLVM();

  MLIRContext context;
  auto module = parseMLIRInput(inputFilename, &context);
  if (!module) {
    llvm::errs() << "could not parse the input IR\n";
    return 1;
  }

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(
      compileAndExecute(module.get(), mainFuncName.getValue()),
      [&exitCode](const llvm::ErrorInfoBase &info) {
        llvm::errs() << "Error: ";
        info.log(llvm::errs());
        llvm::errs() << '\n';
        exitCode = EXIT_FAILURE;
      });
  return exitCode;
}