Add custom builder for AffineIfOp
Closes tensorflow/mlir#109 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/109 from nmostafa:nmostafa/AffineIfOp 7dbf2115f0092ffab26381ea8704aa05a0253971 PiperOrigin-RevId: 267633077
This commit is contained in:
parent
1b8eff8fcd
commit
8154370b49
|
@ -212,9 +212,10 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
|
|||
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState *result, "
|
||||
"Value *cond, bool withElseRegion">
|
||||
"IntegerSet set, ArrayRef<Value *> args, bool withElseRegion">
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -198,10 +198,11 @@ template <typename Op> struct OperationBuilder : public OperationHandle {
|
|||
OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
|
||||
};
|
||||
|
||||
using alloc = ValueBuilder<AllocOp>;
|
||||
using affine_apply = ValueBuilder<AffineApplyOp>;
|
||||
using affine_if = OperationBuilder<AffineIfOp>;
|
||||
using affine_load = ValueBuilder<AffineLoadOp>;
|
||||
using affine_store = OperationBuilder<AffineStoreOp>;
|
||||
using alloc = ValueBuilder<AllocOp>;
|
||||
using call = OperationBuilder<mlir::CallOp>;
|
||||
using constant_float = ValueBuilder<ConstantFloatOp>;
|
||||
using constant_index = ValueBuilder<ConstantIndexOp>;
|
||||
|
|
|
@ -1662,6 +1662,17 @@ void AffineIfOp::setConditional(IntegerSet set, ArrayRef<Value *> operands) {
|
|||
getOperation()->setOperands(operands);
|
||||
}
|
||||
|
||||
void AffineIfOp::build(Builder *builder, OperationState *result, IntegerSet set,
|
||||
ArrayRef<Value *> args, bool withElseRegion) {
|
||||
result->addOperands(args);
|
||||
result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
|
||||
Region *thenRegion = result->addRegion();
|
||||
Region *elseRegion = result->addRegion();
|
||||
AffineIfOp::ensureTerminator(*thenRegion, *builder, result->location);
|
||||
if (withElseRegion)
|
||||
AffineIfOp::ensureTerminator(*elseRegion, *builder, result->location);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// This is a pattern to canonicalize an affine if op's conditional (integer
|
||||
// set + operands).
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
@ -746,6 +747,41 @@ TEST_FUNC(empty_map_load_store) {
|
|||
f.erase();
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @affine_if_op
|
||||
// CHECK: affine.if ([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
|
||||
// CHECK-NOT: else
|
||||
// CHECK: affine.if ([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
|
||||
// CHECK-NEXT: } else {
|
||||
TEST_FUNC(affine_if_op) {
|
||||
using namespace edsc;
|
||||
using namespace edsc::intrinsics;
|
||||
using namespace edsc::op;
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
|
||||
auto f = makeFunction("affine_if_op", {}, {memrefType});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
ValueHandle zero = constant_index(0), ten = constant_index(10);
|
||||
|
||||
SmallVector<bool, 4> isEq = {false, false, false, false};
|
||||
SmallVector<AffineExpr, 4> affineExprs = {
|
||||
builder.getAffineDimExpr(0), // d0 >= 0
|
||||
builder.getAffineDimExpr(1), // d1 >= 0
|
||||
builder.getAffineSymbolExpr(0), // s0 >= 0
|
||||
builder.getAffineSymbolExpr(1) // s1 >= 0
|
||||
};
|
||||
auto intSet = builder.getIntegerSet(2, 2, affineExprs, isEq);
|
||||
|
||||
SmallVector<Value *, 4> affineIfArgs = {zero, zero, ten, ten};
|
||||
intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false);
|
||||
intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/true);
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
int main() {
|
||||
RUN_TESTS();
|
||||
return 0;
|
||||
|
|
Loading…
Reference in a new issue