//===-- runtime/matmul.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implements all forms of MATMUL (Fortran 2018 16.9.124)
//
// There are two main entry points; one establishes a descriptor for the
// result and allocates it, and the other expects a result descriptor that
// points to existing storage.
//
// This implementation must handle all combinations of numeric types and
// kinds (100 - 165 cases depending on the target), plus all combinations
// of logical kinds (16).  A single template undergoes many instantiations
// to cover all of the valid possibilities.
//
// Places where BLAS routines could be called are marked as TODO items.

#include "flang/Runtime/matmul.h"
#include "terminator.h"
#include "tools.h"
#include "flang/Runtime/c-or-cpp.h"
#include "flang/Runtime/cpp-type.h"
#include "flang/Runtime/descriptor.h"
#include <complex>
#include <cstdint>
#include <cstring>
#include <type_traits>

namespace Fortran::runtime {

// General accumulator for any type and stride; this is not used for
// contiguous numeric cases.
template class Accumulator { public: using Result = AccumulationType; Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { if constexpr (RCAT == TypeCategory::Logical) { sum_ = sum_ || (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); } else { sum_ += static_cast(*x_.Element(xAt)) * static_cast(*y_.Element(yAt)); } } Result GetResult() const { return sum_; } private: const Descriptor &x_, &y_; Result sum_{}; }; // Contiguous numeric matrix*matrix multiplication // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) // Straightforward algorithm: // DO 1 I = 1, NROWS // DO 1 J = 1, NCOLS // RES(I,J) = 0 // DO 1 K = 1, N // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) // With loop distribution and transposition to avoid the inner sum // reduction and to avoid non-unit strides: // DO 1 I = 1, NROWS // DO 1 J = 1, NCOLS // 1 RES(I,J) = 0 // DO 2 K = 1, N // DO 2 J = 1, NCOLS // DO 2 I = 1, NROWS // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! 
loop-invariant last term template inline void MatrixTimesMatrix(CppTypeFor *RESTRICT product, SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y, SubscriptValue n) { using ResultType = CppTypeFor; std::memset(product, 0, rows * cols * sizeof *product); const XT *RESTRICT xp0{x}; for (SubscriptValue k{0}; k < n; ++k) { ResultType *RESTRICT p{product}; for (SubscriptValue j{0}; j < cols; ++j) { const XT *RESTRICT xp{xp0}; auto yv{static_cast(y[k + j * n])}; for (SubscriptValue i{0}; i < rows; ++i) { *p++ += static_cast(*xp++) * yv; } } xp0 += rows; } } // Contiguous numeric matrix*vector multiplication // matrix(rows,n) * column vector(n) -> column vector(rows) // Straightforward algorithm: // DO 1 J = 1, NROWS // RES(J) = 0 // DO 1 K = 1, N // 1 RES(J) = RES(J) + X(J,K)*Y(K) // With loop distribution and transposition to avoid the inner // sum reduction and to avoid non-unit strides: // DO 1 J = 1, NROWS // 1 RES(J) = 0 // DO 2 K = 1, N // DO 2 J = 1, NROWS // 2 RES(J) = RES(J) + X(J,K)*Y(K) template inline void MatrixTimesVector(CppTypeFor *RESTRICT product, SubscriptValue rows, SubscriptValue n, const XT *RESTRICT x, const YT *RESTRICT y) { using ResultType = CppTypeFor; std::memset(product, 0, rows * sizeof *product); for (SubscriptValue k{0}; k < n; ++k) { ResultType *RESTRICT p{product}; auto yv{static_cast(*y++)}; for (SubscriptValue j{0}; j < rows; ++j) { *p++ += static_cast(*x++) * yv; } } } // Contiguous numeric vector*matrix multiplication // row vector(n) * matrix(n,cols) -> row vector(cols) // Straightforward algorithm: // DO 1 J = 1, NCOLS // RES(J) = 0 // DO 1 K = 1, N // 1 RES(J) = RES(J) + X(K)*Y(K,J) // With loop distribution and transposition to avoid the inner // sum reduction and one non-unit stride (the other remains): // DO 1 J = 1, NCOLS // 1 RES(J) = 0 // DO 2 K = 1, N // DO 2 J = 1, NCOLS // 2 RES(J) = RES(J) + X(K)*Y(K,J) template inline void VectorTimesMatrix(CppTypeFor *RESTRICT product, SubscriptValue n, 
SubscriptValue cols, const XT *RESTRICT x, const YT *RESTRICT y) { using ResultType = CppTypeFor; std::memset(product, 0, cols * sizeof *product); for (SubscriptValue k{0}; k < n; ++k) { ResultType *RESTRICT p{product}; auto xv{static_cast(*x++)}; const YT *RESTRICT yp{&y[k]}; for (SubscriptValue j{0}; j < cols; ++j) { *p++ += xv * static_cast(*yp); yp += n; } } } // Implements an instance of MATMUL for given argument types. template static inline void DoMatmul( std::conditional_t &result, const Descriptor &x, const Descriptor &y, Terminator &terminator) { int xRank{x.rank()}; int yRank{y.rank()}; int resRank{xRank + yRank - 2}; if (xRank * yRank != 2 * resRank) { terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); } SubscriptValue extent[2]{ xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), resRank == 2 ? y.GetDimension(1).Extent() : 0}; if constexpr (IS_ALLOCATING) { result.Establish( RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable); for (int j{0}; j < resRank; ++j) { result.GetDimension(j).SetBounds(1, extent[j]); } if (int stat{result.Allocate()}) { terminator.Crash( "MATMUL: could not allocate memory for result; STAT=%d", stat); } } else { RUNTIME_CHECK(terminator, resRank == result.rank()); RUNTIME_CHECK( terminator, result.ElementBytes() == static_cast(RKIND)); RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0]); RUNTIME_CHECK(terminator, resRank == 1 || result.GetDimension(1).Extent() == extent[1]); } SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; if (n != y.GetDimension(0).Extent()) { terminator.Crash("MATMUL: arrays do not conform (%jd != %jd)", static_cast(n), static_cast(y.GetDimension(0).Extent())); } using WriteResult = CppTypeFor; if constexpr (RCAT != TypeCategory::Logical) { if (x.IsContiguous() && y.IsContiguous() && (IS_ALLOCATING || result.IsContiguous())) { // Contiguous numeric matrices if (resRank == 2) { // M*M -> M if (std::is_same_v) { if constexpr 
(std::is_same_v) { // TODO: call BLAS-3 SGEMM } else if constexpr (std::is_same_v) { // TODO: call BLAS-3 DGEMM } else if constexpr (std::is_same_v>) { // TODO: call BLAS-3 CGEMM } else if constexpr (std::is_same_v>) { // TODO: call BLAS-3 ZGEMM } } MatrixTimesMatrix( result.template OffsetElement(), extent[0], extent[1], x.OffsetElement(), y.OffsetElement(), n); return; } else if (xRank == 2) { // M*V -> V if (std::is_same_v) { if constexpr (std::is_same_v) { // TODO: call BLAS-2 SGEMV(x,y) } else if constexpr (std::is_same_v) { // TODO: call BLAS-2 DGEMV(x,y) } else if constexpr (std::is_same_v>) { // TODO: call BLAS-2 CGEMV(x,y) } else if constexpr (std::is_same_v>) { // TODO: call BLAS-2 ZGEMV(x,y) } } MatrixTimesVector( result.template OffsetElement(), extent[0], n, x.OffsetElement(), y.OffsetElement()); return; } else { // V*M -> V if (std::is_same_v) { if constexpr (std::is_same_v) { // TODO: call BLAS-2 SGEMV(y,x) } else if constexpr (std::is_same_v) { // TODO: call BLAS-2 DGEMV(y,x) } else if constexpr (std::is_same_v>) { // TODO: call BLAS-2 CGEMV(y,x) } else if constexpr (std::is_same_v>) { // TODO: call BLAS-2 ZGEMV(y,x) } } VectorTimesMatrix( result.template OffsetElement(), n, extent[0], x.OffsetElement(), y.OffsetElement()); return; } } } // General algorithms for LOGICAL and noncontiguity SubscriptValue xAt[2], yAt[2], resAt[2]; x.GetLowerBounds(xAt); y.GetLowerBounds(yAt); result.GetLowerBounds(resAt); if (resRank == 2) { // M*M -> M SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; for (SubscriptValue i{0}; i < extent[0]; ++i) { for (SubscriptValue j{0}; j < extent[1]; ++j) { Accumulator accumulator{x, y}; yAt[1] = y1 + j; for (SubscriptValue k{0}; k < n; ++k) { xAt[1] = x1 + k; yAt[0] = y0 + k; accumulator.Accumulate(xAt, yAt); } resAt[1] = res1 + j; *result.template Element(resAt) = accumulator.GetResult(); } ++resAt[0]; ++xAt[0]; } } else if (xRank == 2) { // M*V -> V SubscriptValue x1{xAt[1]}, y0{yAt[0]}; for (SubscriptValue 
j{0}; j < extent[0]; ++j) { Accumulator accumulator{x, y}; for (SubscriptValue k{0}; k < n; ++k) { xAt[1] = x1 + k; yAt[0] = y0 + k; accumulator.Accumulate(xAt, yAt); } *result.template Element(resAt) = accumulator.GetResult(); ++resAt[0]; ++xAt[0]; } } else { // V*M -> V SubscriptValue x0{xAt[0]}, y0{yAt[0]}; for (SubscriptValue j{0}; j < extent[0]; ++j) { Accumulator accumulator{x, y}; for (SubscriptValue k{0}; k < n; ++k) { xAt[0] = x0 + k; yAt[0] = y0 + k; accumulator.Accumulate(xAt, yAt); } *result.template Element(resAt) = accumulator.GetResult(); ++resAt[0]; ++yAt[1]; } } } // Maps the dynamic type information from the arguments' descriptors // to the right instantiation of DoMatmul() for valid combinations of // types. template struct Matmul { using ResultDescriptor = std::conditional_t; template struct MM1 { template struct MM2 { void operator()(ResultDescriptor &result, const Descriptor &x, const Descriptor &y, Terminator &terminator) const { if constexpr (constexpr auto resultType{ GetResultType(XCAT, XKIND, YCAT, YKIND)}) { if constexpr (common::IsNumericTypeCategory(resultType->first) || resultType->first == TypeCategory::Logical) { return DoMatmulfirst, resultType->second, CppTypeFor, CppTypeFor>(result, x, y, terminator); } } terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", static_cast(XCAT), XKIND, static_cast(YCAT), YKIND); } }; void operator()(ResultDescriptor &result, const Descriptor &x, const Descriptor &y, Terminator &terminator, TypeCategory yCat, int yKind) const { ApplyType(yCat, yKind, terminator, result, x, y, terminator); } }; void operator()(ResultDescriptor &result, const Descriptor &x, const Descriptor &y, const char *sourceFile, int line) const { Terminator terminator{sourceFile, line}; auto xCatKind{x.type().GetCategoryAndKind()}; auto yCatKind{y.type().GetCategoryAndKind()}; RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); ApplyType(xCatKind->first, xCatKind->second, terminator, result, x, y, 
terminator, yCatKind->first, yCatKind->second); } }; extern "C" { void RTNAME(Matmul)(Descriptor &result, const Descriptor &x, const Descriptor &y, const char *sourceFile, int line) { Matmul{}(result, x, y, sourceFile, line); } void RTNAME(MatmulDirect)(const Descriptor &result, const Descriptor &x, const Descriptor &y, const char *sourceFile, int line) { Matmul{}(result, x, y, sourceFile, line); } } // extern "C" } // namespace Fortran::runtime