.. index:: pair: example; cpu_matmul_weights_compression.cpp .. _doxid-cpu_matmul_weights_compression_8cpp-example: cpu_matmul_weights_compression.cpp ================================== Annotated version: :ref:`MatMul Primitive Example ` This C++ API example demonstrates how to create and execute a :ref:`MatMul ` primitive that uses a weights tensor encoded with the packed sparse encoding. .. ref-code-block:: cpp /******************************************************************************* * Copyright 2023-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #include <algorithm> #include <cmath> #include <iostream> #include <random> #include <string> #include <vector> #include "example_utils.hpp" #include "oneapi/dnnl/dnnl.hpp" using namespace :ref:`dnnl `; void matmul_example(:ref:`dnnl::engine::kind ` engine_kind) { // Create execution dnnl::engine. :ref:`dnnl::engine ` :ref:`engine `(engine_kind, 0); // Create dnnl::stream. :ref:`dnnl::stream ` engine_stream(:ref:`engine `); // Tensor dimensions. const :ref:`memory::dim ` M = 512, K = 512, N = 512; // Source (src), weights, and destination (dst) tensors dimensions. :ref:`memory::dims ` src_dims = {M, K}; :ref:`memory::dims ` weights_dims = {K, N}; :ref:`memory::dims ` dst_dims = {M, N}; // Allocate buffers. 
std::vector<float> src_data(product(src_dims)); std::vector<float> weights_data(product(weights_dims)); std::vector<float> dst_data(product(dst_dims)); // Initialize src, weights. std::generate(src_data.begin(), src_data.end(), []() { static int i = 0; return std::cos(i++ / 10.f); }); std::generate(weights_data.begin(), weights_data.end(), [&]() { static const float density = 0.1f; static std::default_random_engine def_gen; static std::bernoulli_distribution b_dist(density); const auto is_one = b_dist(def_gen); static int i = 1; return std::sin(i++ * 2.f) * is_one; }); const :ref:`memory::dim ` nnz = std::count_if(weights_data.begin(), weights_data.end(), [](float v) { return v != 0.0f; }); auto :ref:`src_md ` = :ref:`memory::desc `( src_dims, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::ab `); auto :ref:`dst_md ` = :ref:`memory::desc `( dst_dims, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::ab `); auto src_mem = :ref:`memory `(src_md, :ref:`engine `); auto dst_mem = :ref:`memory `(dst_md, :ref:`engine `); auto user_src_mem = :ref:`memory `( {src_dims, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::ab `}, :ref:`engine `); auto user_weights_mem = :ref:`memory `( {weights_dims, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::ab `}, :ref:`engine `); auto user_dst_mem = :ref:`memory `( {dst_dims, :ref:`memory::data_type::f32 `, :ref:`memory::format_tag::ab `}, :ref:`engine `); write_to_dnnl_memory(src_data.data(), src_mem); write_to_dnnl_memory(weights_data.data(), user_weights_mem); auto matmul_src_md = :ref:`memory::desc `( src_dims, :ref:`memory::data_type::u8 `, :ref:`memory::format_tag::any `); auto matmul_weights_md = :ref:`memory::desc::packed `(weights_dims, :ref:`memory::data_type::s8 `, nnz); auto matmul_dst_md = :ref:`memory::desc `( dst_dims, :ref:`memory::data_type::u8 `, :ref:`memory::format_tag::any `); :ref:`matmul::primitive_desc ` matmul_pd; try { matmul_pd = :ref:`matmul::primitive_desc `( :ref:`engine `, matmul_src_md, 
matmul_weights_md, matmul_dst_md); } catch (:ref:`error ` &e) { if (e.status == :ref:`dnnl_unimplemented `) throw example_allows_unimplemented { "No matmul implementation with packed encoding support is " "available for this platform.\nPlease refer to the " "developer guide for details."}; // on any other error just re-throw throw; } auto matmul_src_mem = user_src_mem; auto matmul_weights_mem = user_weights_mem; auto matmul_dst_mem = user_dst_mem; auto matmul_prim = :ref:`matmul `(matmul_pd); if (matmul_pd.:ref:`src_desc `() != user_src_mem.get_desc()) { matmul_src_mem = :ref:`memory `(matmul_pd.:ref:`src_desc `(), :ref:`engine `); :ref:`reorder `(user_src_mem, matmul_src_mem) .:ref:`execute `(engine_stream, user_src_mem, matmul_src_mem); } // Use reorder to pack the weights. auto wei_packed_md = matmul_pd.:ref:`weights_desc `(); const int nhandles = wei_packed_md.:ref:`get_num_handles `(); std::vector<void *> wei_handles(nhandles); std::vector<std::vector<uint8_t>> wei_buffers(nhandles); for (int h = 0; h < nhandles; h++) { const size_t buf_sz = wei_packed_md.get_size(h); wei_buffers[h].resize(buf_sz); wei_handles[h] = wei_buffers[h].data(); } if (wei_packed_md != user_weights_mem.:ref:`get_desc `()) { matmul_weights_mem = :ref:`memory `(wei_packed_md, :ref:`engine `, std::move(wei_handles)); :ref:`reorder `(user_weights_mem, matmul_weights_mem) .:ref:`execute `(engine_stream, user_weights_mem, matmul_weights_mem); } if (matmul_pd.:ref:`dst_desc `() != user_dst_mem.:ref:`get_desc `()) { matmul_dst_mem = :ref:`memory `(matmul_pd.:ref:`dst_desc `(), :ref:`engine `); :ref:`reorder `(user_dst_mem, matmul_dst_mem) .:ref:`execute `(engine_stream, user_dst_mem, matmul_dst_mem); } // Primitive arguments. std::unordered_map<int, memory> matmul_args; matmul_args.insert({:ref:`DNNL_ARG_SRC `, matmul_src_mem}); matmul_args.insert({:ref:`DNNL_ARG_WEIGHTS `, matmul_weights_mem}); matmul_args.insert({:ref:`DNNL_ARG_DST `, matmul_dst_mem}); // Primitive execution: matrix multiplication. 
matmul_prim.execute(engine_stream, matmul_args); // Wait for the computation to finalize. engine_stream.wait(); // Read data from memory object's handle. read_from_dnnl_memory(dst_data.data(), dst_mem); } int main(int argc, char **argv) { return handle_example_errors(matmul_example, parse_engine_kind(argc, argv)); }