[mlir][core] Add IndexElementsAttr helpers.
Summary: In a follow-up, I'll update the Shape dialect to use this instead of I64ElementsAttr. Differential Revision: https://reviews.llvm.org/D80601
This commit is contained in:
parent
98ef93eabd
commit
9546d8b108
|
@ -128,6 +128,7 @@ public:
|
|||
/// as attributes.
|
||||
DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
|
||||
DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
|
||||
DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);
|
||||
|
||||
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
|
||||
ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
|
||||
|
|
|
@ -1218,6 +1218,13 @@ class IntElementsAttrBase<Pred condition, string description> :
|
|||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
def IndexElementsAttr
|
||||
: IntElementsAttrBase<CPred<[{$_self.cast<DenseIntElementsAttr>()
|
||||
.getType()
|
||||
.getElementType()
|
||||
.isIndex()}]>,
|
||||
"index elements attribute">;
|
||||
|
||||
class AnyIntElementsAttr<int width> : IntElementsAttrBase<
|
||||
CPred<"$_self.cast<DenseIntElementsAttr>().getType()."
|
||||
"getElementType().isInteger(" # width # ")">,
|
||||
|
|
|
@ -624,6 +624,8 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
|
|||
owner.getContext());
|
||||
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
|
||||
}
|
||||
if (eltTy.isa<IndexType>())
|
||||
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
|
||||
if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
|
||||
IntElementIterator intIt(owner, index);
|
||||
FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
|
||||
|
|
|
@ -130,6 +130,13 @@ DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
|
|||
values);
|
||||
}
|
||||
|
||||
DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
|
||||
return DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(static_cast<int64_t>(values.size()),
|
||||
getIndexType()),
|
||||
values);
|
||||
}
|
||||
|
||||
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
|
||||
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
|
||||
}
|
||||
|
|
|
@ -454,6 +454,10 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
|
|||
let arguments = (ins I32ElementsAttr:$attr);
|
||||
}
|
||||
|
||||
def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
|
||||
let arguments = (ins IndexElementsAttr:$attr);
|
||||
}
|
||||
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let arguments = (ins AnyTensor, AnyTensor);
|
||||
|
|
|
@ -489,3 +489,18 @@ func @elements_attr_i32(%arg0: tensor<1x2xi32>) {
|
|||
"test.i32ElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @elements_attr_index() {
|
||||
"test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xindex>} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @elements_attr_not_index() {
|
||||
// expected-error@+1 {{index elements attribute}}
|
||||
"test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue