[mlir][Vector] Simplify code a bit. NFCI.

This commit is contained in:
Benjamin Kramer 2020-08-01 14:48:42 +02:00
parent 04b99a4d18
commit eb41f9edde
2 changed files with 19 additions and 25 deletions

View file

@ -184,9 +184,9 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
auto lhsType = types[0].cast<VectorType>();
auto rhsType = types[1].cast<VectorType>();
auto maskElementType = parser.getBuilder().getI1Type();
SmallVector<Type, 2> maskTypes;
maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
std::array<Type, 2> maskTypes = {
VectorType::get(lhsType.getShape(), maskElementType),
VectorType::get(rhsType.getShape(), maskElementType)};
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
return failure();
return success();
@ -462,12 +462,10 @@ std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
}
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
SmallVector<AffineMap, 4> res;
auto mapAttrs = indexing_maps().getValue();
res.reserve(mapAttrs.size());
for (auto mapAttr : mapAttrs)
res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
return res;
return llvm::to_vector<4>(
llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
return mapAttr.cast<AffineMapAttr>().getValue();
}));
}
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
@ -1854,8 +1852,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
}
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
auto s = getVectorType().getShape();
return SmallVector<int64_t, 4>{s.begin(), s.end()};
return llvm::to_vector<4>(getVectorType().getShape());
}
//===----------------------------------------------------------------------===//
@ -2014,11 +2011,8 @@ static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
memRefType.getShape().end());
if (vectorType) {
res.reserve(memRefType.getRank() + vectorType.getRank());
for (auto s : vectorType.getShape())
res.push_back(s);
}
if (vectorType)
res.append(vectorType.getShape().begin(), vectorType.getShape().end());
return res;
}

View file

@ -1707,7 +1707,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
SmallVector<int64_t, 2> perm{1, 0};
static constexpr std::array<int64_t, 2> perm = {1, 0};
auto iteratorTypes = op.iterator_types().getValue();
SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
if (isParallelIterator(iteratorTypes[0]) &&
@ -1911,10 +1911,10 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
assert(lookup.hasValue() && "parallel index not listed in reduction");
int64_t resIndex = lookup.getValue();
// Construct new iterator types and affine map array attribute.
SmallVector<AffineMap, 4> lowIndexingMaps;
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
adjustMap(iMap[1], iterIndex, rewriter),
adjustMap(iMap[2], iterIndex, rewriter)};
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
auto lowIter =
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
@ -1962,10 +1962,10 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
op.acc());
}
// Construct new iterator types and affine map array attribute.
SmallVector<AffineMap, 4> lowIndexingMaps;
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
adjustMap(iMap[1], iterIndex, rewriter),
adjustMap(iMap[2], iterIndex, rewriter)};
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
auto lowIter =
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));