.. index:: pair: page; MatMul Primitive Example .. _doxid-cpu_matmul_csr_cpp: MatMul Primitive Example ======================== .. ref-code-block:: cpp /******************************************************************************* * Copyright 2023-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #include #include #include #include #include #include "dnnl.hpp" #include "example_utils.hpp" using namespace :ref:`dnnl `; bool check_result(:ref:`dnnl::memory ` dst_mem) { // clang-format off const std::vector expected_result = {8.750000, 11.250000, 2.500000, 6.000000, 2.250000, 3.750000, 19.000000, 15.500000, 5.250000, 4.000000, 7.000000, 3.000000}; // clang-format on std::vector dst_data(expected_result.size()); read_from_dnnl_memory(dst_data.data(), dst_mem); return expected_result == dst_data; } void sparse_matmul() { :ref:`dnnl::engine ` :ref:`engine `(:ref:`engine::kind::cpu `, 0); const :ref:`memory::dim ` M = 4; const :ref:`memory::dim ` N = 3; const :ref:`memory::dim ` K = 6; // A sparse matrix represented in the CSR format. std::vector src_csr_values = {2.5f, 1.5f, 1.5f, 2.5f, 2.0f}; std::vector src_csr_indices = {0, 2, 0, 5, 1}; std::vector src_csr_pointers = {0, 1, 2, 4, 5, 5}; // clang-format off std::vector weights_data = {3.5f, 4.5f, 1.0f, 2.0f, 3.5f, 1.5f, 4.0f, 1.5f, 2.5f, 3.5f, 5.5f, 4.5f, 1.5f, 2.5f, 5.5f, 5.5f, 3.5f, 1.5f}; // clang-format on const int nnz = static_cast(src_csr_values.size()); // Create a memory descriptor for CSR format by providing information // about number of non-zero entries and data types of metadata. const auto src_csr_md = :ref:`memory::desc::csr `({M, K}, :ref:`memory::data_type::f32 `, nnz, :ref:`memory::data_type::s32 `, :ref:`memory::data_type::s32 `); const auto wei_md = :ref:`memory::desc `( {K, N}, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::oi `); const auto :ref:`dst_md ` = :ref:`memory::desc `( {M, N}, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::nc `); // This memory is created for the given values and metadata of CSR format. :ref:`memory ` src_csr_mem(src_csr_md, :ref:`engine `, {src_csr_values.data(), src_csr_indices.data(), src_csr_pointers.data()}); :ref:`memory ` wei_mem(wei_md, :ref:`engine `, weights_data.data()); :ref:`memory ` dst_mem(dst_md, :ref:`engine `); :ref:`dnnl::stream ` :ref:`stream `(:ref:`engine `); auto sparse_matmul_pd = :ref:`matmul::primitive_desc `(:ref:`engine `, src_csr_md, wei_md, dst_md); auto sparse_matmul_prim = :ref:`matmul `(sparse_matmul_pd); std::unordered_map sparse_matmul_args; sparse_matmul_args.insert({:ref:`DNNL_ARG_SRC `, src_csr_mem}); sparse_matmul_args.insert({:ref:`DNNL_ARG_WEIGHTS `, wei_mem}); sparse_matmul_args.insert({:ref:`DNNL_ARG_DST `, dst_mem}); sparse_matmul_prim.execute(:ref:`stream `, sparse_matmul_args); :ref:`stream `.:ref:`wait `(); if (!check_result(dst_mem)) throw :ref:`std::runtime_error `("Unexpected output."); } int main(int argc, char **argv) { return handle_example_errors({:ref:`engine::kind::cpu `}, sparse_matmul); }