gemm_batch
Computes a group of gemm operations.
Description
The gemm_batch routines are batched versions of gemm, performing
multiple gemm operations in a single call. Each gemm
operation performs a matrix-matrix product with general matrices.
gemm_batch supports the following precisions.
 Ta (A matrix)          Tb (B matrix)          Tc (C matrix)          Ts (alpha/beta)
 std::int8_t            std::int8_t            std::int32_t           float
 std::int8_t            std::int8_t            float                  float
 half                   half                   float                  float
 half                   half                   half                   half
 bfloat16               bfloat16               float                  float
 bfloat16               bfloat16               bfloat16               float
 float                  float                  float                  float
 double                 double                 double                 double
 std::complex<float>    std::complex<float>    std::complex<float>    std::complex<float>
 std::complex<double>   std::complex<double>   std::complex<double>   std::complex<double>
gemm_batch (Buffer Version)
Description
The buffer version of gemm_batch supports only the strided API.
The strided API operation is defined as:
for i = 0 … batch_size – 1
    A, B and C are matrices at offset i * stridea, i * strideb, i * stridec in a, b and c.
    C := alpha * op(A) * op(B) + beta * C
end for
where:
op(X) is one of op(X) = X, or op(X) = X^T, or op(X) = X^H,
alpha and beta are scalars,
A, B, and C are matrices,
op(A) is m x k, op(B) is
k x n, and C is m x n.
The a, b and c buffers contain all the input matrices. The stride
between matrices is given by the stride parameter. The total number
of matrices in a, b and c buffers is given by the batch_size parameter.
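The strided loop above can be sketched on the host as follows. This is a minimal reference sketch and not the oneMKL implementation: it assumes column major layout, float data, and real-valued transposition only, and it uses std::vector in place of sycl::buffer. The names Trans, op_at and ref_gemm_batch_strided are hypothetical helpers introduced only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for oneapi::mkl::transpose (real data, so no conjugation).
enum class Trans { N, T };

// Element (r, c) of op(X) for a column-major matrix X with leading dimension ld.
inline float op_at(const float* x, std::int64_t ld, Trans t,
                   std::int64_t r, std::int64_t c) {
    return t == Trans::N ? x[r + c * ld] : x[c + r * ld];
}

// Reference loop: C_i := alpha * op(A_i) * op(B_i) + beta * C_i for each i in the batch.
void ref_gemm_batch_strided(Trans ta, Trans tb,
                            std::int64_t m, std::int64_t n, std::int64_t k,
                            float alpha,
                            const std::vector<float>& a, std::int64_t lda, std::int64_t stridea,
                            const std::vector<float>& b, std::int64_t ldb, std::int64_t strideb,
                            float beta,
                            std::vector<float>& c, std::int64_t ldc, std::int64_t stridec,
                            std::int64_t batch_size) {
    assert(m >= 0 && n >= 0 && k >= 0 && batch_size >= 0);
    for (std::int64_t i = 0; i < batch_size; ++i) {
        // Matrix i starts at offset i * stride in its array.
        const float* A = a.data() + i * stridea;
        const float* B = b.data() + i * strideb;
        float* C = c.data() + i * stridec;
        for (std::int64_t col = 0; col < n; ++col) {
            for (std::int64_t row = 0; row < m; ++row) {
                float acc = 0.0f;
                for (std::int64_t p = 0; p < k; ++p)
                    acc += op_at(A, lda, ta, row, p) * op_at(B, ldb, tb, p, col);
                float& out = C[row + col * ldc];
                // With beta == 0 the old value of C is never read.
                out = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * out;
            }
        }
    }
}
```

A loop like this is mainly useful for checking results of small batches against a device run; it makes no attempt at performance.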
Strided API
Syntax
namespace oneapi::mkl::blas::column_major {
    void gemm_batch(sycl::queue &queue,
                    oneapi::mkl::transpose transa,
                    oneapi::mkl::transpose transb,
                    std::int64_t m,
                    std::int64_t n,
                    std::int64_t k,
                    Ts alpha,
                    sycl::buffer<Ta,1> &a,
                    std::int64_t lda,
                    std::int64_t stridea,
                    sycl::buffer<Tb,1> &b,
                    std::int64_t ldb,
                    std::int64_t strideb,
                    Ts beta,
                    sycl::buffer<Tc,1> &c,
                    std::int64_t ldc,
                    std::int64_t stridec,
                    std::int64_t batch_size)
}
namespace oneapi::mkl::blas::row_major {
    void gemm_batch(sycl::queue &queue,
                    oneapi::mkl::transpose transa,
                    oneapi::mkl::transpose transb,
                    std::int64_t m,
                    std::int64_t n,
                    std::int64_t k,
                    Ts alpha,
                    sycl::buffer<Ta,1> &a,
                    std::int64_t lda,
                    std::int64_t stridea,
                    sycl::buffer<Tb,1> &b,
                    std::int64_t ldb,
                    std::int64_t strideb,
                    Ts beta,
                    sycl::buffer<Tc,1> &c,
                    std::int64_t ldc,
                    std::int64_t stridec,
                    std::int64_t batch_size)
}
Input Parameters
- queue
 The queue where the routine should be executed.
- transa
 Specifies op(A), the transposition operation applied to the matrices A. See oneMKL defined datatypes for more details.
- transb
 Specifies op(B), the transposition operation applied to the matrices B. See oneMKL defined datatypes for more details.
- m
 Number of rows of op(A) and C. Must be at least zero.
- n
 Number of columns of op(B) and C. Must be at least zero.
- k
 Number of columns of op(A) and rows of op(B). Must be at least zero.
- alpha
 Scaling factor for the matrix-matrix products.
- a
 Buffer holding the input matrices A with size stridea * batch_size.
- lda
 The leading dimension of the matrices A. It must be positive.
                A not transposed           A transposed
 Column major   lda must be at least m.    lda must be at least k.
 Row major      lda must be at least k.    lda must be at least m.
- stridea
 Stride between different A matrices.
- b
 Buffer holding the input matrices B with size strideb * batch_size.
- ldb
 The leading dimension of the matrices B. It must be positive.
                B not transposed           B transposed
 Column major   ldb must be at least k.    ldb must be at least n.
 Row major      ldb must be at least n.    ldb must be at least k.
- strideb
 Stride between different B matrices.
- beta
 Scaling factor for the matrices C.
- c
 Buffer holding input/output matrices C with size stridec * batch_size.
- ldc
 The leading dimension of the matrices C. It must be positive and at least m if column major layout is used to store matrices or at least n if row major layout is used to store matrices.
- stridec
 Stride between different C matrices. Must be at least ldc * n.
- batch_size
 Specifies the number of matrix multiply operations to perform.
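The leading-dimension rules above can be collected into a few small helpers. This is a sketch for illustration only: the Layout enum and the function names min_lda, min_ldb and min_ldc are hypothetical and are not part of oneMKL. Each helper returns the smallest value that satisfies both the "must be positive" rule and the layout-dependent lower bound.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper enum; oneMKL selects the layout through the
// column_major / row_major namespaces instead.
enum class Layout { ColMajor, RowMajor };

// Smallest valid lda: m for (column major, not transposed) or (row major, transposed),
// otherwise k. The result is clamped to 1 because lda must be positive.
std::int64_t min_lda(Layout layout, bool transposed, std::int64_t m, std::int64_t k) {
    assert(m >= 0 && k >= 0);
    const bool use_m = (layout == Layout::ColMajor) != transposed;
    return std::max<std::int64_t>(1, use_m ? m : k);
}

// Smallest valid ldb: k for (column major, not transposed) or (row major, transposed),
// otherwise n.
std::int64_t min_ldb(Layout layout, bool transposed, std::int64_t k, std::int64_t n) {
    assert(k >= 0 && n >= 0);
    const bool use_k = (layout == Layout::ColMajor) != transposed;
    return std::max<std::int64_t>(1, use_k ? k : n);
}

// Smallest valid ldc: m for column major, n for row major.
std::int64_t min_ldc(Layout layout, std::int64_t m, std::int64_t n) {
    assert(m >= 0 && n >= 0);
    return std::max<std::int64_t>(1, layout == Layout::ColMajor ? m : n);
}
```

Checking user-supplied leading dimensions against these bounds before the call can turn a hard-to-diagnose device error into an early host-side failure.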
Output Parameters
- c
 Output buffer, overwritten by batch_size matrix multiply operations of the form alpha * op(A) * op(B) + beta * C.
Notes
If beta = 0, matrix C does not need to be initialized before
calling gemm_batch.
Throws
This routine shall throw the following exceptions if the associated condition is detected. An implementation may throw additional implementation-specific exception(s) in case of error conditions not covered here.
gemm_batch (USM Version)
Description
The USM version of gemm_batch supports the group API and the strided API.
The group API supports pointer and span inputs.
The group API operation is defined as:
idx = 0
for i = 0 … group_count – 1
    for j = 0 … group_size – 1
        A, B, and C are matrices in a[idx], b[idx] and c[idx]
        C := alpha[i] * op(A) * op(B) + beta[i] * C
        idx = idx + 1
    end for
end for
The advantage of using a span instead of a pointer is that the array sizes can vary and the size of each span can be queried at runtime. For each GEMM parameter except the output matrices, the span can have size 1, size group_count, or size equal to the total batch size. For the output matrices, the span size must equal the total batch size so that all computations are independent.
Depending on the size of the span, each parameter of the GEMM computation is used as follows:
If the span has size 1, the parameter is reused for all GEMM computations.
If the span has size group_count, the parameter is reused for all GEMM computations within a group, but each group has its own value for this parameter. This matches the gemm_batch group API with pointers.
If the span has size equal to the total batch size, each GEMM computation uses a different value for this parameter.
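The three span-size cases can be illustrated by expanding a parameter to one value per matrix while walking the group loop. This is a host-side sketch only: it uses std::vector in place of sycl::span, and expand_per_matrix is a hypothetical helper, not a oneMKL function. When two of the sizes coincide (for example, group_count equals the total batch size), the sketch checks size 1 first, then group_count, which is an assumption of this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Expand a parameter of size 1, group_count, or total batch size into one
// value per matrix, following the idx / i / j loop of the group API.
template <typename T>
std::vector<T> expand_per_matrix(const std::vector<T>& param,
                                 const std::vector<std::size_t>& group_sizes) {
    const std::size_t group_count = group_sizes.size();
    std::size_t total = 0;
    for (std::size_t g : group_sizes) total += g;
    assert(param.size() == 1 || param.size() == group_count || param.size() == total);

    std::vector<T> out;
    out.reserve(total);
    std::size_t idx = 0;
    for (std::size_t i = 0; i < group_count; ++i) {
        for (std::size_t j = 0; j < group_sizes[i]; ++j) {
            if (param.size() == 1)
                out.push_back(param[0]);      // one value shared by every GEMM
            else if (param.size() == group_count)
                out.push_back(param[i]);      // one value per group
            else
                out.push_back(param[idx]);    // one value per GEMM
            ++idx;
        }
    }
    return out;
}
```

For two groups of sizes 2 and 1, a parameter {7} expands to {7, 7, 7}, a parameter {1, 2} expands to {1, 1, 2}, and a parameter {4, 5, 6} is used unchanged.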
The strided API operation is defined as:
for i = 0 … batch_size – 1
    A, B and C are matrices at offset i * stridea, i * strideb, i * stridec in a, b and c.
    C := alpha * op(A) * op(B) + beta * C
end for
where:
op(X) is one of op(X) = X, or op(X) = X^T, or op(X) = X^H,
alpha and beta are scalars,
A, B, and C are matrices,
op(A) is m x k, op(B) is k x n, and C is m x n.
For the group API, the a, b and c arrays contain the pointers to all the input matrices.
The total number of matrices in a, b and c is given by:
total_batch_count = group_size[0] + group_size[1] + … + group_size[group_count – 1]
For the strided API, the a, b and c arrays contain all the input matrices. The total number of matrices
in a, b and c is given by the batch_size parameter.
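For the group API, the total number of matrices and the position where each group's pointers start in a, b and c follow directly from the group sizes. The sketch below assumes the group sizes are available in a host-side std::vector; total_batch_count and group_offsets are hypothetical helper names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Total number of matrices across all groups: the sum of the group sizes.
std::int64_t total_batch_count(const std::vector<std::int64_t>& group_size) {
    return std::accumulate(group_size.begin(), group_size.end(), std::int64_t{0});
}

// Index into a, b and c of the first matrix of each group (exclusive prefix sum).
std::vector<std::int64_t> group_offsets(const std::vector<std::int64_t>& group_size) {
    std::vector<std::int64_t> offsets(group_size.size());
    std::int64_t running = 0;
    for (std::size_t i = 0; i < group_size.size(); ++i) {
        assert(group_size[i] >= 0);  // all entries must be at least 0
        offsets[i] = running;
        running += group_size[i];
    }
    return offsets;
}
```

The pointer arrays passed as a, b and c would then need total_batch_count(group_size) entries each, with the matrices of group i occupying the index range starting at group_offsets(group_size)[i].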
Group API
Syntax
namespace oneapi::mkl::blas::column_major {
    sycl::event gemm_batch(sycl::queue &queue,
                           const oneapi::mkl::transpose *transa,
                           const oneapi::mkl::transpose *transb,
                           const std::int64_t *m,
                           const std::int64_t *n,
                           const std::int64_t *k,
                           const Ts *alpha,
                           const Ta **a,
                           const std::int64_t *lda,
                           const Tb **b,
                           const std::int64_t *ldb,
                           const Ts *beta,
                           Tc **c,
                           const std::int64_t *ldc,
                           std::int64_t group_count,
                           const std::int64_t *group_size,
                           const std::vector<sycl::event> &dependencies = {})
    sycl::event gemm_batch(sycl::queue &queue,
                           const sycl::span<oneapi::mkl::transpose> &transa,
                           const sycl::span<oneapi::mkl::transpose> &transb,
                           const sycl::span<std::int64_t> &m,
                           const sycl::span<std::int64_t> &n,
                           const sycl::span<std::int64_t> &k,
                           const sycl::span<Ts> &alpha,
                           const sycl::span<const Ta*> &a,
                           const sycl::span<std::int64_t> &lda,
                           const sycl::span<const Tb*> &b,
                           const sycl::span<std::int64_t> &ldb,
                           const sycl::span<Ts> &beta,
                           sycl::span<Tc*> &c,
                           const sycl::span<std::int64_t> &ldc,
                           size_t group_count,
                           const sycl::span<size_t> &group_sizes,
                           const std::vector<sycl::event> &dependencies = {})
}
namespace oneapi::mkl::blas::row_major {
    sycl::event gemm_batch(sycl::queue &queue,
                           const oneapi::mkl::transpose *transa,
                           const oneapi::mkl::transpose *transb,
                           const std::int64_t *m,
                           const std::int64_t *n,
                           const std::int64_t *k,
                           const Ts *alpha,
                           const Ta **a,
                           const std::int64_t *lda,
                           const Tb **b,
                           const std::int64_t *ldb,
                           const Ts *beta,
                           Tc **c,
                           const std::int64_t *ldc,
                           std::int64_t group_count,
                           const std::int64_t *group_size,
                           const std::vector<sycl::event> &dependencies = {})
    sycl::event gemm_batch(sycl::queue &queue,
                           const sycl::span<oneapi::mkl::transpose> &transa,
                           const sycl::span<oneapi::mkl::transpose> &transb,
                           const sycl::span<std::int64_t> &m,
                           const sycl::span<std::int64_t> &n,
                           const sycl::span<std::int64_t> &k,
                           const sycl::span<Ts> &alpha,
                           const sycl::span<const Ta*> &a,
                           const sycl::span<std::int64_t> &lda,
                           const sycl::span<const Tb*> &b,
                           const sycl::span<std::int64_t> &ldb,
                           const sycl::span<Ts> &beta,
                           sycl::span<Tc*> &c,
                           const sycl::span<std::int64_t> &ldc,
                           size_t group_count,
                           const sycl::span<size_t> &group_sizes,
                           const std::vector<sycl::event> &dependencies = {})
}
Input Parameters
- queue
 The queue where the routine should be executed.
- transa
 Array or span of group_count oneapi::mkl::transpose values. transa[i] specifies the form of op(A) used in the matrix multiplication in group i. See oneMKL defined datatypes for more details.
- transb
 Array or span of group_count oneapi::mkl::transpose values. transb[i] specifies the form of op(B) used in the matrix multiplication in group i. See oneMKL defined datatypes for more details.
- m
 Array or span of group_count integers. m[i] specifies the number of rows of op(A) and C for every matrix in group i. All entries must be at least zero.
- n
 Array or span of group_count integers. n[i] specifies the number of columns of op(B) and C for every matrix in group i. All entries must be at least zero.
- k
 Array or span of group_count integers. k[i] specifies the number of columns of op(A) and rows of op(B) for every matrix in group i. All entries must be at least zero.
- alpha
 Array or span of group_count scalar elements. alpha[i] specifies the scaling factor for every matrix-matrix product in group i.
- a
 Array of pointers or span of input matrices A with size total_batch_count.
 See Matrix Storage for more details.
- lda
 Array or span of group_count integers. lda[i] specifies the leading dimension of A for every matrix in group i. All entries must be positive.
                A not transposed                 A transposed
 Column major   lda[i] must be at least m[i].    lda[i] must be at least k[i].
 Row major      lda[i] must be at least k[i].    lda[i] must be at least m[i].
- b
 Array of pointers or span of input matrices B with size total_batch_count.
 See Matrix Storage for more details.
- ldb
 Array or span of group_count integers. ldb[i] specifies the leading dimension of B for every matrix in group i. All entries must be positive.
                B not transposed                 B transposed
 Column major   ldb[i] must be at least k[i].    ldb[i] must be at least n[i].
 Row major      ldb[i] must be at least n[i].    ldb[i] must be at least k[i].
- beta
 Array or span of group_count scalar elements. beta[i] specifies the scaling factor for matrix C for every matrix in group i.
- c
 Array of pointers or span of input/output matrices C with size total_batch_count.
 See Matrix Storage for more details.
- ldc
 Array or span of group_count integers. ldc[i] specifies the leading dimension of C for every matrix in group i. All entries must be positive, and ldc[i] must be at least m[i] if column major layout is used to store matrices or at least n[i] if row major layout is used to store matrices.
- group_count
 Specifies the number of groups. Must be at least 0.
- group_size
 Array or span of group_count integers. group_size[i] specifies the number of matrix multiply products in group i. All entries must be at least 0.
- dependencies
 List of events to wait for before starting computation, if any. If omitted, defaults to no dependencies.
Output Parameters
- c
 Overwritten by the m[i]-by-n[i] matrix calculated by (alpha[i] * op(A) * op(B) + beta[i] * C) for group i.
Notes
If beta = 0, matrix C does not need to be initialized
before calling gemm_batch.
Return Values
Output event to wait on to ensure computation is complete.
Strided API
Syntax
namespace oneapi::mkl::blas::column_major {
    sycl::event gemm_batch(sycl::queue &queue,
                           oneapi::mkl::transpose transa,
                           oneapi::mkl::transpose transb,
                           std::int64_t m,
                           std::int64_t n,
                           std::int64_t k,
                           value_or_pointer<Ts> alpha,
                           const Ta *a,
                           std::int64_t lda,
                           std::int64_t stridea,
                           const Tb *b,
                           std::int64_t ldb,
                           std::int64_t strideb,
                           value_or_pointer<Ts> beta,
                           Tc *c,
                           std::int64_t ldc,
                           std::int64_t stridec,
                           std::int64_t batch_size,
                           const std::vector<sycl::event> &dependencies = {})
}
namespace oneapi::mkl::blas::row_major {
    sycl::event gemm_batch(sycl::queue &queue,
                           oneapi::mkl::transpose transa,
                           oneapi::mkl::transpose transb,
                           std::int64_t m,
                           std::int64_t n,
                           std::int64_t k,
                           value_or_pointer<Ts> alpha,
                           const Ta *a,
                           std::int64_t lda,
                           std::int64_t stridea,
                           const Tb *b,
                           std::int64_t ldb,
                           std::int64_t strideb,
                           value_or_pointer<Ts> beta,
                           Tc *c,
                           std::int64_t ldc,
                           std::int64_t stridec,
                           std::int64_t batch_size,
                           const std::vector<sycl::event> &dependencies = {})
}
Input Parameters
- queue
 The queue where the routine should be executed.
- transa
 Specifies op(A), the transposition operation applied to the matrices A. See oneMKL defined datatypes for more details.
- transb
 Specifies op(B), the transposition operation applied to the matrices B. See oneMKL defined datatypes for more details.
- m
 Number of rows of op(A) and C. Must be at least zero.
- n
 Number of columns of op(B) and C. Must be at least zero.
- k
 Number of columns of op(A) and rows of op(B). Must be at least zero.
- alpha
 Scaling factor for the matrix-matrix products. See Scalar Arguments in BLAS for more details.
- a
 Pointer to input matrices A with size stridea * batch_size.
- lda
 The leading dimension of the matrices A. It must be positive.
                A not transposed           A transposed
 Column major   lda must be at least m.    lda must be at least k.
 Row major      lda must be at least k.    lda must be at least m.
- stridea
 Stride between different A matrices.
- b
 Pointer to input matrices B with size strideb * batch_size.
- ldb
 The leading dimension of the matrices B. It must be positive.
                B not transposed           B transposed
 Column major   ldb must be at least k.    ldb must be at least n.
 Row major      ldb must be at least n.    ldb must be at least k.
- strideb
 Stride between different B matrices.
- beta
 Scaling factor for the matrices C. See Scalar Arguments in BLAS for more details.
- c
 Pointer to input/output matrices C with size stridec * batch_size.
- ldc
 The leading dimension of the matrices C. It must be positive and at least m if column major layout is used to store matrices or at least n if row major layout is used to store matrices.
- stridec
 Stride between different C matrices.
- batch_size
 Specifies the number of matrix multiply operations to perform.
- dependencies
 List of events to wait for before starting computation, if any. If omitted, defaults to no dependencies.
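The relationship between the stride, the leading dimension and the position of an individual element can be sketched as follows. This is an illustration under stated assumptions, not part of the oneMKL API: Layout, dense_stride and element_offset are hypothetical names, and dense_stride simply computes the storage footprint of one rows x cols matrix with leading dimension ld, which is the smallest stride that keeps consecutive matrices from overlapping.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper enum; oneMKL selects the layout through the
// column_major / row_major namespaces instead.
enum class Layout { ColMajor, RowMajor };

// Footprint, in elements, of one rows x cols matrix stored with leading dimension ld.
std::int64_t dense_stride(Layout layout, std::int64_t ld,
                          std::int64_t rows, std::int64_t cols) {
    assert(ld >= (layout == Layout::ColMajor ? rows : cols));
    return ld * (layout == Layout::ColMajor ? cols : rows);
}

// Offset of element (row, col) of matrix i from the start of the array.
std::int64_t element_offset(Layout layout, std::int64_t i, std::int64_t stride,
                            std::int64_t ld, std::int64_t row, std::int64_t col) {
    const std::int64_t within =
        layout == Layout::ColMajor ? row + col * ld : row * ld + col;
    return i * stride + within;
}
```

With helpers like these, host code can fill or inspect a strided batch without repeating the index arithmetic at each call site.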
Output Parameters
- c
 Output matrices, overwritten by batch_size matrix multiply operations of the form alpha * op(A) * op(B) + beta * C.
Notes
If beta = 0, matrix C does not need to be initialized before
calling gemm_batch.
Return Values
Output event to wait on to ensure computation is complete.
Throws
This routine shall throw the following exceptions if the associated condition is detected. An implementation may throw additional implementation-specific exception(s) in case of error conditions not covered here.
oneapi::mkl::unsupported_device