[MLIR][Shape] Simplify shape lowering

Reuse the result type of `shape_of` for the final `index_cast` instead of
constructing an equivalent extent tensor type by hand, and use explicit
`Value`/`Type` variables in the pattern.

Differential Revision: https://reviews.llvm.org/D84161
Frederik Gossen 2020-07-24 08:43:43 +00:00
parent d4e4d5d780
commit a85ca6be2a


@@ -172,39 +172,37 @@ LogicalResult
 ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
   ShapeOfOp::Adaptor transformed(operands);
-  auto tensorVal = transformed.arg();
-  auto tensorTy = tensorVal.getType();
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
 
   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
-  if (tensorTy.isa<RankedTensorType>())
+  if (argTy.isa<RankedTensorType>())
     return failure();
 
   // Allocate stack memory.
   auto loc = op.getLoc();
-  auto rankVal = rewriter.create<mlir::RankOp>(loc, tensorVal);
-  auto i64Ty = rewriter.getI64Type();
-  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
-  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+  Value rank = rewriter.create<mlir::RankOp>(loc, arg);
+  Type i64Ty = rewriter.getI64Type();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
 
   // Copy shape extents to stack-allocated memory.
-  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
-  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   rewriter.create<scf::ForOp>(
-      loc, zeroVal, rankVal, oneVal, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
-        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, iVal);
-        auto dimIntVal = rewriter.create<IndexCastOp>(loc, dimVal, i64Ty);
-        rewriter.create<StoreOp>(loc, dimIntVal, memVal, ValueRange{iVal});
+      loc, zero, rank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value dim = rewriter.create<DimOp>(loc, arg, iv);
+        Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
+        rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
         rewriter.create<scf::YieldOp>(loc);
       });
 
   // Load extents to tensor value.
-  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
-  auto indexTy = rewriter.getIndexType();
-  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
-                                           shapeTy);
+  Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
+                                           op.getType());
   return success();
 }
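
For illustration, here is a rough sketch of the IR this pattern emits for an
unranked input. The function name, element type, and exact op spellings are
assumptions (std/scf dialect syntax of this era), not taken from the commit:

  // Hypothetical input: shape.shape_of on an unranked tensor.
  func @shape_of_unranked(%arg : tensor<*xf32>) -> tensor<?xindex> {
    // Number of dimensions of the unranked operand.
    %rank = rank %arg : tensor<*xf32>
    // Stack-allocated buffer, one i64 extent per dimension.
    %mem = alloca(%rank) : memref<?xi64>
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    scf.for %i = %c0 to %rank step %c1 {
      %dim = dim %arg, %i : tensor<*xf32>
      %dim_i64 = index_cast %dim : index to i64
      store %dim_i64, %mem[%i] : memref<?xi64>
    }
    %extents_i64 = tensor_load %mem : memref<?xi64>
    // After this change, the cast targets op.getType() directly; no
    // RankedTensorType is constructed by hand anymore.
    %shape = index_cast %extents_i64 : tensor<?xi64> to tensor<?xindex>
    return %shape : tensor<?xindex>
  }

The emitted IR is unchanged by this commit; the simplification is confined to
the final step of the pattern, where the result type of the original
`shape_of` op is reused for the `index_cast` instead of rebuilding an
equivalent `tensor<?xindex>` type.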