Start TFLite legalizer pass

Start of TFLite legalizer pass. Currently focussed on macro expanding ops, limited to what is registered directly in a separate pass (this should instead be a general pass), no querying of what gets produced, the matching is string based instead of using the ops proper (the matching TF ops should be defined) etc. This is a step to enable prototyping. In addition to the above shortcomings, the legalizer is very verbose in this form and should instead be driven by autogenerated patterns (same is true for the op builders too). But this starts from the explicit form and extracting out commonality in follow up.

Add definition for tfl.relu for basic selection of fused relu add.

PiperOrigin-RevId: 220287087
This commit is contained in:
Jacques Pienaar 2018-11-06 08:33:10 -08:00 committed by jpienaar
parent 4269a01863
commit 5e01000d46

View file

@ -32,6 +32,7 @@
#include "mlir/Pass.h"
#include "mlir/TensorFlow/ControlFlowOps.h"
#include "mlir/TensorFlow/Passes.h"
#include "mlir/TensorFlowLite/Passes.h"
#include "mlir/Transforms/CFGFunctionViewGraph.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/XLA/Passes.h"
@ -71,11 +72,12 @@ enum Passes {
ComposeAffineMaps,
ConstantFold,
ConvertToCFG,
MemRefBoundCheck,
MemRefDependenceCheck,
TFLiteLegaize,
LoopFusion,
LoopUnroll,
LoopUnrollAndJam,
MemRefBoundCheck,
MemRefDependenceCheck,
PipelineDataTransfer,
PrintCFGGraph,
SimplifyAffineStructures,
@ -94,13 +96,13 @@ static cl::list<Passes> passList(
"Constant fold operations in functions"),
clEnumValN(ConvertToCFG, "convert-to-cfg",
"Convert all ML functions in the module to CFG ones"),
clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
clEnumValN(MemRefBoundCheck, "memref-bound-check",
"Convert all ML functions in the module to CFG ones"),
clEnumValN(MemRefDependenceCheck, "memref-dependence-check",
"Checks dependences between all pairs of memref accesses."),
clEnumValN(LoopFusion, "loop-fusion", "Fuse loop nests"),
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam", "Unroll and jam loops"),
clEnumValN(PipelineDataTransfer, "pipeline-data-transfer",
"Pipeline non-blocking data transfers between"
"explicitly managed levels of the memory hierarchy"),
@ -108,6 +110,8 @@ static cl::list<Passes> passList(
"Print CFG graph per function"),
clEnumValN(SimplifyAffineStructures, "simplify-affine-structures",
"Simplify affine expressions"),
clEnumValN(TFLiteLegaize, "tfl-legalize",
"Legalize operations to TensorFlow Lite dialect"),
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
"Dynamic TensorFlow Switch/Match nodes to a CFG"),
clEnumValN(Vectorize, "vectorize",
@ -200,12 +204,6 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
case ConvertToCFG:
pass = createConvertToCFGPass();
break;
case MemRefBoundCheck:
pass = createMemRefBoundCheckPass();
break;
case MemRefDependenceCheck:
pass = createMemRefDependenceCheckPass();
break;
case LoopFusion:
pass = createLoopFusionPass();
break;
@ -215,6 +213,12 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
case LoopUnrollAndJam:
pass = createLoopUnrollAndJamPass();
break;
case MemRefBoundCheck:
pass = createMemRefBoundCheckPass();
break;
case MemRefDependenceCheck:
pass = createMemRefDependenceCheckPass();
break;
case PipelineDataTransfer:
pass = createPipelineDataTransferPass();
break;
@ -224,6 +228,9 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
case SimplifyAffineStructures:
pass = createSimplifyAffineStructuresPass();
break;
case TFLiteLegaize:
pass = tfl::createLegalizer();
break;
case TFRaiseControlFlow:
pass = createRaiseTFControlFlowPass();
break;