[MLIR][Shape] Simplify shape lowering

Differential Revision: https://reviews.llvm.org/D84161
This commit is contained in:
Frederik Gossen 2020-07-24 08:43:43 +00:00
parent d4e4d5d780
commit a85ca6be2a

View file

@ -172,39 +172,37 @@ LogicalResult
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
ShapeOfOp::Adaptor transformed(operands);
auto tensorVal = transformed.arg();
auto tensorTy = tensorVal.getType();
Value arg = transformed.arg();
Type argTy = arg.getType();
// For ranked tensors `shape_of` lowers to `std` and the pattern can be
// found in the corresponding pass.
if (tensorTy.isa<RankedTensorType>())
if (argTy.isa<RankedTensorType>())
return failure();
// Allocate stack memory.
auto loc = op.getLoc();
auto rankVal = rewriter.create<mlir::RankOp>(loc, tensorVal);
auto i64Ty = rewriter.getI64Type();
auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
Value rank = rewriter.create<mlir::RankOp>(loc, arg);
Type i64Ty = rewriter.getI64Type();
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
// Copy shape extents to stack-allocated memory.
auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
rewriter.create<scf::ForOp>(
loc, zeroVal, rankVal, oneVal, llvm::None,
[&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
auto dimVal = rewriter.create<DimOp>(loc, tensorVal, iVal);
auto dimIntVal = rewriter.create<IndexCastOp>(loc, dimVal, i64Ty);
rewriter.create<StoreOp>(loc, dimIntVal, memVal, ValueRange{iVal});
loc, zero, rank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
Value dim = rewriter.create<DimOp>(loc, arg, iv);
Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
rewriter.create<scf::YieldOp>(loc);
});
// Load extents to tensor value.
auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
auto indexTy = rewriter.getIndexType();
auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
shapeTy);
Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
op.getType());
return success();
}