MatMul Tutorial: MXFP8 Inference
A C++ API example demonstrating how to use the MatMul primitive with the MXFP8 data type for inference.
Concepts:

- Dynamic quantization compliant with the MX specification
- Create primitive once, use multiple times
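Before the full listing, it helps to see the MX scaling rule in isolation: each block of 32 elements shares one power-of-two (e8m0) scale derived from the block's largest magnitude. The standalone sketch below is illustrative only (the names block, amax, and shared_exp are not part of the example); quantize_input in the listing that follows applies the same rule to every block of A and B.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    // Hypothetical block of 32 values (any f32 data would do).
    float block[32];
    for (int i = 0; i < 32; ++i)
        block[i] = 0.01f * (i + 1);

    // amax: the largest magnitude in the block.
    float amax = 0.f;
    for (float v : block)
        amax = std::max(amax, std::fabs(v));

    // Shared scale per the OCP MX spec: 2^(floor(log2(amax)) - emax_elem),
    // where emax_elem = 8 for e4m3 (largest e4m3 value is 448 = 1.75 * 2^8).
    int shared_exp = static_cast<int>(std::floor(std::log2(amax))) - 8;
    float scale = std::ldexp(1.f, shared_exp); // 2^shared_exp

    // Elements are divided by the scale before conversion to e4m3, so the
    // largest one lands near the top of the e4m3 range; dequantization
    // multiplies by the scale again.
    std::printf("amax = %g, scale = 2^%d, amax / scale = %g\n", amax,
            shared_exp, amax / scale);
    return 0;
}

The complete example source follows.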
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring> // for std::memcpy
#include <iostream>
#include <limits> // for std::numeric_limits
#include <random>
#include <stdexcept>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

namespace {

void init_vector(std::vector<float> &v) {
    std::mt19937 gen;
    std::uniform_real_distribution<float> u(0, 1);
    for (auto &e : v)
        e = u(gen);
}

// Extracts the biased exponent of an f32 value, which is its e8m0 encoding.
uint8_t f32_to_e8m0(const float &a) {
    // Note: memcpy can be replaced with bit_cast in C++20
    uint32_t a_s32;
    std::memcpy(&a_s32, &a, sizeof(float));
    uint8_t a_e8m0 = (a_s32 >> 23) & 0xff;
    return a_e8m0;
}

// Converts an e8m0-encoded scale back to f32 (2^(a - 127)).
float e8m0_to_f32(const uint8_t a) {
    float r_f32;
    uint32_t r_s32;
    if (a == 0xff) return std::numeric_limits<float>::quiet_NaN();
    // Note: memcpy can be replaced with bit_cast in C++20
    if (a == 0x00)
        r_s32 = uint32_t(0x00400000); // 2^-127 encoding in float
    else
        r_s32 = uint32_t(a) << 23;
    std::memcpy(&r_f32, &r_s32, sizeof(float));
    return r_f32;
}

} // namespace

const memory::dim mx_block_size = 32;
int number_of_runs = 1;

// Create a MatMul primitive descriptor where:
//
// - Matrices A and C are non-transposed, B is transposed
// - All matrices use the MXFP8 format with e4m3 elements, e8m0 scales,
//   and blocks of size 32
// - The scale values for A and B are precomputed, as those matrices are
//   already quantized
// - The scale values for C are computed according to the MX spec by the
//   MatMul primitive and written to the DNNL_ARG_ATTR_SCALES |
//   DNNL_ARG_DST memory argument during execution
matmul::primitive_desc matmul_pd_create(
        int64_t M, int64_t N, int64_t K, const engine &eng) {

    memory::desc a_md(
            {M, K}, memory::data_type::f8_e4m3, {K, 1}); // M x K layout
    memory::desc b_md(
            {K, N}, memory::data_type::f8_e4m3, {1, K}); // transposed layout
    memory::desc c_md(
            {M, N}, memory::data_type::f8_e4m3, {N, 1}); // M x N layout

    // Create scales attributes to indicate the scales data type, the group
    // shapes, and how the scales are computed:
    // - user-provided for the DNNL_ARG_SRC and DNNL_ARG_WEIGHTS memory
    //   arguments
    // - library-computed according to the MX spec for the DNNL_ARG_DST
    //   memory argument
    // The mask sets both dimension bits, so there is one scale per group;
    // the group shapes give blocks of 32 along the K dimension for SRC and
    // WEIGHTS, and along the N dimension for DST.
    int mask = (1 << 1) | (1 << 0);
    primitive_attr attr;
    attr.set_scales(DNNL_ARG_SRC, mask, {1, 32}, memory::data_type::e8m0);
    attr.set_scales(DNNL_ARG_WEIGHTS, mask, {32, 1}, memory::data_type::e8m0);
    // Specifying the dynamic_mx quantization mode signals the compute
    // primitive to compute MX-compliant scales and write them to the
    // DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST buffer.
    attr.set_scales(DNNL_ARG_DST, mask, {1, 32}, memory::data_type::e8m0,
            false, quantization_mode::dynamic_mx);

    // Create a MatMul primitive descriptor
    try {
        return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
    } catch (error &e) {
        if (e.status == dnnl_unimplemented)
            throw example_allows_unimplemented {
                    "No mxfp8 matmul implementation is available for this "
                    "platform.\n"
                    "Please refer to the developer guide for details."};

        // on any other error just re-throw
        throw;
    }
}

// Fills the memory object with random f32 values.
void prepare_input(memory &in_mem) {
    auto dims = in_mem.get_desc().get_dims();
    int64_t rows = dims[dims.size() - 2];
    int64_t cols = dims[dims.size() - 1];

    std::vector<float> buff(rows * cols);
    init_vector(buff);

    write_to_dnnl_memory(buff.data(), in_mem);
}

// Takes an f32 matrix as input and fills in its quantized e4m3 elements and
// e8m0 scales. The data is assumed contiguous, with MX blocks taken along
// the innermost (unit-stride) dimension.
void quantize_input(memory &in_e4m3, memory &in_scales_e8m0, memory &in_f32) {
    // This is the conversion routine defined by the OCP MX spec v1.
    // Note: direct pointer access assumes host-accessible memory (e.g. a
    // CPU engine).
    const auto dims = in_f32.get_desc().get_dims();
    const auto nelems = product(dims);
    float *ptr = static_cast<float *>(in_f32.get_data_handle());
    uint8_t *ptr_scales
            = static_cast<uint8_t *>(in_scales_e8m0.get_data_handle());
    assert((dims[dims.size() - 1] % mx_block_size) == 0);

    // We compute the e8m0 scaling factor of each MX block in the following
    // loop.
    memory::dim nblocks = nelems / mx_block_size;
    for (memory::dim i = 0; i < nblocks; ++i) {
        // We first compute the scale value for the block:
        // scale = 2^(floor(log2(amax)) - emax_e4m3), where emax_e4m3 = 8 is
        // the exponent of the largest e4m3 value (448 = 1.75 * 2^8).
        float block_amax = 0.0f;
        for (memory::dim j = 0; j < mx_block_size; ++j)
            block_amax = std::max(
                    block_amax, std::abs(ptr[i * mx_block_size + j]));
        const float max_e4m3 = 448.f;
        uint8_t e8m0_scale
                = f32_to_e8m0(block_amax) - (f32_to_e8m0(max_e4m3) - 127);
        ptr_scales[i] = e8m0_scale;

        // We then divide the elements of the block by that scale so that
        // they fit in the e4m3 range. We do that in place, as the f32
        // buffer is not reused.
        float f32_scale = e8m0_to_f32(e8m0_scale);
        for (memory::dim j = 0; j < mx_block_size; ++j)
            ptr[i * mx_block_size + j] /= f32_scale;
    }

    // We now downconvert to e4m3 with a reorder, executed on a stream so
    // that the conversion completes before the memory is used.
    stream s(in_f32.get_engine());
    reorder(in_f32, in_e4m3).execute(s, in_f32, in_e4m3);
    s.wait();
}

void mxfp_matmul(engine::kind engine_kind) {
    engine eng(engine_kind, 0);

    const int64_t K = 128;
    const int64_t N = 64;
    const int64_t M = 96;

    auto matmul_pd = matmul_pd_create(M, N, K, eng);
    matmul matmul_p(matmul_pd);

    // The following code initializes the inputs. In a real application they
    // would typically already be quantized:
    // - activations are quantized by the previous layer
    // - weights can be quantized offline, ahead of time
    auto a_desc = matmul_pd.src_desc();
    memory A_e4m3_elems_mem(a_desc, eng);
    memory A_e8m0_scales_mem( // one e8m0 scale per 1x32 block
            {{M * (K / mx_block_size)}, memory::data_type::e8m0, {1}}, eng);
    {
        memory A_f32({a_desc.get_dims(), memory::data_type::f32,
                             a_desc.get_strides()},
                eng);
        prepare_input(A_f32);
        quantize_input(A_e4m3_elems_mem, A_e8m0_scales_mem, A_f32);
    }

    auto b_desc = matmul_pd.weights_desc();
    memory B_e4m3_elems_mem(b_desc, eng);
    memory B_e8m0_scales_mem( // one e8m0 scale per 32x1 block
            {{(K / mx_block_size) * N}, memory::data_type::e8m0, {1}}, eng);
    {
        memory B_f32({b_desc.get_dims(), memory::data_type::f32,
                             b_desc.get_strides()},
                eng);
        prepare_input(B_f32);
        quantize_input(B_e4m3_elems_mem, B_e8m0_scales_mem, B_f32);
    }

    // For C, we only allocate memory; both the elements and the scales are
    // populated by the MatMul execute call.
    memory C_e4m3_elems_mem(matmul_pd.dst_desc(), eng);
    memory C_e8m0_scales_mem(
            {{M * (N / mx_block_size)}, memory::data_type::e8m0, {1}}, eng);

    // Now the MatMul primitive is run on a stream. For SRC, WEIGHTS, and
    // DST we provide both the elements and the associated scales as
    // separate buffers.
    stream s(eng);
    for (int run = 0; run < number_of_runs; ++run)
        matmul_p.execute(s,
                {{DNNL_ARG_SRC, A_e4m3_elems_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC,
                                A_e8m0_scales_mem},
                        {DNNL_ARG_WEIGHTS, B_e4m3_elems_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
                                B_e8m0_scales_mem},
                        {DNNL_ARG_DST, C_e4m3_elems_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST,
                                C_e8m0_scales_mem}});
    s.wait();
}

int main(int argc, char **argv) {
    engine::kind engine_kind = parse_engine_kind(argc, argv);
    return handle_example_errors(mxfp_matmul, engine_kind);
}
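As a possible follow-up (not part of the original example), the fragment below sketches how the MXFP8 result could be inspected on the host at the end of mxfp_matmul: the e4m3 elements are upconverted to f32 with a reorder, the e8m0 scales are read back (read_from_dnnl_memory comes from example_utils.hpp), and each element is multiplied by the scale of its block. It assumes the DST scales are laid out row-major over (M, N / mx_block_size) and reuses the stream s and the helpers defined above.

    // Hypothetical host-side dequantization of the MXFP8 output.
    memory C_f32({{M, N}, memory::data_type::f32, {N, 1}}, eng);
    reorder(C_e4m3_elems_mem, C_f32).execute(s, C_e4m3_elems_mem, C_f32);
    s.wait();

    std::vector<float> c_elems(M * N);
    std::vector<uint8_t> c_scales(M * (N / mx_block_size));
    read_from_dnnl_memory(c_elems.data(), C_f32);
    read_from_dnnl_memory(c_scales.data(), C_e8m0_scales_mem);

    // Dequantize: multiply each element by the e8m0 scale of its 1x32 block
    // (assuming scales are stored row-major over (M, N / mx_block_size)).
    for (int64_t m = 0; m < M; ++m)
        for (int64_t n = 0; n < N; ++n)
            c_elems[m * N + n] *= e8m0_to_f32(
                    c_scales[m * (N / mx_block_size) + n / mx_block_size]);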