matmul_f8_quantization.cpp
Annotated version: Matrix Multiplication with f8 Quantization
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "example_utils.hpp"

using namespace dnnl;

float decode_f8_e4m3(uint8_t f8_val) {
    if (f8_val == 0) return 0.0f;

    // Extract bit components: f8_e4m3 format is S EEEE MMM (bit 7 to 0)
    const uint8_t sign = (f8_val >> 7) & 0x1; // Bit 7: sign
    const uint8_t exp = (f8_val >> 3) & 0xF; // Bits 6-3: 4-bit exponent
    const uint8_t mant = f8_val & 0x7; // Bits 2-0: 3-bit mantissa

    // Only exp=15, mant=7 is NaN (no infinity)
    if (exp == 15 && mant == 7) {
        return std::numeric_limits<float>::quiet_NaN();
    }

    float result;
    if (exp == 0) {
        // Denormal: 0.mant * 2^(-6)
        result = (float)mant / 8.0f * powf(2.0f, -6);
    } else {
        // Normal: (1 + mant/2^(3)) * 2^(exp-7)
        result = (1.0f + (float)mant / 8.0f) * powf(2.0f, (int)exp - 7);
    }

    return sign ? -result : result;
}

float decode_f8_e5m2(uint8_t f8_val) {
    if (f8_val == 0) return 0.0f;

    // Extract bit components: f8_e5m2 format is S EEEEE MM (bit 7 to 0)
    const uint8_t sign = (f8_val >> 7) & 0x1; // Bit 7: sign
    const uint8_t exp = (f8_val >> 2) & 0x1F; // Bits 6-2: 5-bit exponent
    const uint8_t mant = f8_val & 0x3; // Bits 1-0: 2-bit mantissa

    // Handle special cases (infinity and NaN)
    if (exp == 31) {
        if (mant == 0) {
            return (sign ? -1.0f : 1.0f) * INFINITY; // Infinity
        } else {
            return std::numeric_limits<float>::quiet_NaN(); // NaN
        }
    }

    float result;
    if (exp == 0) {
        // Denormal: 0.mant * 2^(-14)
        result = (float)mant / 4.0f * powf(2.0f, -14);
    } else {
        // Normal: (1 + mant/2^(2)) * 2^(exp-15)
        result = (1.0f + (float)mant / 4.0f) * powf(2.0f, (int)exp - 15);
    }

    return sign ? -result : result;
}
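// For reference, a few hand-decoded values following the bit layouts above:
//   f8_e4m3: 0x48 = 0 1001 000  -> (1 + 0/8) * 2^(9-7)   = 4.0
//   f8_e5m2: 0x44 = 0 10001 00  -> (1 + 0/4) * 2^(17-15) = 4.0
//   f8_e4m3: 0x01 = 0 0000 001  -> denormal: 1/8 * 2^(-6) = 2^(-9)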
std::string get_f8_type_name(memory::data_type dt) {
    switch (dt) {
        case memory::data_type::f8_e5m2: return "f8_e5m2";
        case memory::data_type::f8_e4m3: return "f8_e4m3";
        default: return "Unsupported data type";
    }
}

float return_max_value(memory::data_type dt) {
    switch (dt) {
        case memory::data_type::f8_e5m2:
            // f8_e5m2: 1 sign bit + 5 bit exponent (bias=15) + 2 bit mantissa
            // Per OCP f8 spec: infinity = 11111.00, NaN = 11111.{01, 10, 11}
            // Max: exponent=30, mantissa=11 (in binary) -> 1.75 × 2^(30-15) = 57344
            return 57344.0f;
        case memory::data_type::f8_e4m3:
            // f8_e4m3: 1 sign bit + 4 bit exponent (bias=7) + 3 bit mantissa
            // Per OCP f8 spec: no infinity, NaN = 1111.111
            // Max: exponent=15, mantissa=110 (in binary) -> 1.75 × 2^(15-7) = 448
            return 448.0f;
        default: throw std::invalid_argument("Unsupported data type");
    }
}

float compute_naive_quantization(const float *data, size_t size,
        memory::data_type dst_type, const std::string &label) {
    if (dst_type != memory::data_type::f8_e5m2
            && dst_type != memory::data_type::f8_e4m3) {
        throw std::invalid_argument("Unsupported data type");
    }

    // Find the maximum absolute value in the data
    float max_abs = 0.0f;
    for (size_t i = 0; i < size; ++i) {
        max_abs = std::max(max_abs, std::abs(data[i]));
    }

    // Get theoretical maximum value for the target f8 format
    float f8_max = return_max_value(dst_type);

    // Only apply scaling if values exceed the f8 range
    float scale;
    if (max_abs <= f8_max) {
        scale = 1.0f;
        std::cout << " " << label << " fits in " << get_f8_type_name(dst_type)
                  << " (max=" << max_abs << ", f8_max=" << f8_max << ")"
                  << std::endl;
    } else {
        scale = max_abs / f8_max;
        std::cout << " " << label << " max (" << max_abs << ") > "
                  << get_f8_type_name(dst_type) << " max (" << f8_max
                  << "), scaling: " << scale << std::endl;
    }

    return scale;
}
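// Usage sketch for compute_naive_quantization() (the values below are purely
// illustrative): the returned scale s is chosen so that x_f32 ≈ s * x_f8,
// i.e. the quantizing reorder stores roughly x_f32 / s and the matmul later
// multiplies s back in through its scale attributes.
//
//     std::vector<float> v = {100.0f, 900.0f};
//     float s = compute_naive_quantization(v.data(), v.size(),
//             memory::data_type::f8_e4m3, "demo"); // s = 900 / 448 ≈ 2.01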
void perform_matmul_with_f8_quantization(engine::kind engine_kind,
        memory::data_type f8_type = memory::data_type::f8_e5m2) {
    if (f8_type != memory::data_type::f8_e5m2
            && f8_type != memory::data_type::f8_e4m3) {
        throw std::invalid_argument("Unsupported data type");
    }

    // Create execution dnnl::engine
    engine eng(engine_kind, 0);

    // Create dnnl::stream
    stream s(eng);

    // Matrix dimensions for A * B = C
    const int M = 4, K = 8, N = 4;

    std::cout << get_f8_type_name(f8_type) << " Quantization Example:"
              << std::endl;
    std::cout << " Matrix dimensions: A(" << M << "x" << K << ") * B(" << K
              << "x" << N << ") = C(" << M << "x" << N << ")" << std::endl;

    // Initialize input data with float values, and fill matrices with
    // sample data to demonstrate scaling behavior.
    // Source: values within f8_e4m3 range (< 448) - should not need scaling for E4M3.
    // Weights: values exceeding f8_e4m3 range (> 448) - will need scaling for E4M3.
    std::vector<float> src_f32(M * K);
    std::vector<float> weights_f32(K * N);
    std::iota(src_f32.begin(), src_f32.end(),
            100.0f); // Each value is 100+ (fits in both formats)
    std::iota(weights_f32.begin(), weights_f32.end(),
            450.0f); // Each value is 450+ (exceeds f8_e4m3 max of 448)

    // Create memory for inputs and outputs in f32 format
    auto src_md = memory::desc(
            {M, K}, memory::data_type::f32, memory::format_tag::ab);
    auto weights_md = memory::desc(
            {K, N}, memory::data_type::f32, memory::format_tag::ab);
    auto dst_md = memory::desc(
            {M, N}, memory::data_type::f32, memory::format_tag::ab);

    auto src_mem = memory(src_md, eng);
    write_to_dnnl_memory(src_f32.data(), src_mem);
    auto weights_mem = memory(weights_md, eng);
    write_to_dnnl_memory(weights_f32.data(), weights_mem);
    auto dst_mem = memory(dst_md, eng);

    // Create f8 memory descriptors for quantized data
    auto src_f8_md = memory::desc({M, K}, f8_type, memory::format_tag::ab);
    auto weights_f8_md = memory::desc({K, N}, f8_type, memory::format_tag::ab);

    auto src_f8_mem = memory(src_f8_md, eng);
    auto weights_f8_mem = memory(weights_f8_md, eng);

    // Step 1: Compute scaling factors for quantization
    std::cout << "\nStep 1: Computing scaling factors for f32 to "
              << get_f8_type_name(f8_type) << " quantization" << std::endl;

    float src_scale = compute_naive_quantization(
            src_f32.data(), src_f32.size(), f8_type, "Source");
    float weights_scale = compute_naive_quantization(
            weights_f32.data(), weights_f32.size(), f8_type, "Weights");

    // Step 2: Quantize f32 to f8 format with scaling
    std::cout << "\nStep 2: Quantizing f32 data to "
              << get_f8_type_name(f8_type) << " format with scaling"
              << std::endl;

    // Create memory for scales
    auto src_scale_mem = memory(
            {{1}, memory::data_type::f32, memory::format_tag::x}, eng);
    write_to_dnnl_memory(&src_scale, src_scale_mem);
    auto weights_scale_mem = memory(
            {{1}, memory::data_type::f32, memory::format_tag::x}, eng);
    write_to_dnnl_memory(&weights_scale, weights_scale_mem);

    // Create reorder primitives with scaling attributes
    primitive_attr src_attr, weights_attr;
    src_attr.set_scales_mask(DNNL_ARG_DST, 0);
    weights_attr.set_scales_mask(DNNL_ARG_DST, 0);

    // Check if f8 reorders are supported on this platform
    try {
        reorder::primitive_desc(eng, src_md, eng, src_f8_md, src_attr);
        reorder::primitive_desc(
                eng, weights_md, eng, weights_f8_md, weights_attr);
    } catch (error &e) {
        if (e.status == dnnl_unimplemented)
            throw example_allows_unimplemented {
                    "No f8 reorder implementation is available for this "
                    "platform.\n"
                    "Please refer to the developer guide for details."};

        // on any other error just re-throw
        throw;
    }

    auto reorder_src_pd
            = reorder::primitive_desc(eng, src_md, eng, src_f8_md, src_attr);
    auto reorder_weights_pd = reorder::primitive_desc(
            eng, weights_md, eng, weights_f8_md, weights_attr);

    auto reorder_src = reorder(reorder_src_pd);
    auto reorder_weights = reorder(reorder_weights_pd);

    // Execute reorders with scaling
    reorder_src.execute(s,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, src_f8_mem},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, src_scale_mem}});
    reorder_weights.execute(s,
            {{DNNL_ARG_SRC, weights_mem}, {DNNL_ARG_DST, weights_f8_mem},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, weights_scale_mem}});
    s.wait();

    // Show key quantization results
    std::cout << " Quantization summary:" << std::endl;
    std::cout << " Scaling factors: src=" << src_scale
              << ", weights=" << weights_scale << std::endl;

    // Read a few f8 values to demonstrate quantization
    std::vector<uint8_t> weights_f8_data(K * N);
    read_from_dnnl_memory(weights_f8_data.data(), weights_f8_mem);

    auto decode_f8 = (f8_type == memory::data_type::f8_e4m3) ? decode_f8_e4m3
                                                             : decode_f8_e5m2;

    std::cout << " Sample: f32=" << weights_f32[0]
              << " -> f8=" << (int)weights_f8_data[0]
              << " -> decoded=" << decode_f8(weights_f8_data[0])
              << " (f8 as float)"
              << " -> final=" << decode_f8(weights_f8_data[0]) * weights_scale
              << " (dequantized)" << std::endl;

    std::cout << " Successfully quantized inputs to "
              << get_f8_type_name(f8_type) << " format with scaling"
              << std::endl;
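    // What the next step computes, conceptually: with per-tensor scales
    // attached through DNNL_ARG_ATTR_SCALES, the matmul produces
    //     C = (src_scale * A_f8) * (weights_scale * B_f8)
    // so the output is already back in the f32 domain and no separate
    // dequantization pass is needed.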
    // Step 3: Matrix multiplication with f8
    std::cout << "\nStep 3: Performing matrix multiplication with "
              << get_f8_type_name(f8_type) << " inputs" << std::endl;

    // Create matmul with dequantization attributes
    primitive_attr matmul_attr;
    matmul_attr.set_scales_mask(DNNL_ARG_SRC, 0);
    matmul_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);

    // Check if f8 matmul is supported on this platform
    try {
        matmul::primitive_desc(
                eng, src_f8_md, weights_f8_md, dst_md, matmul_attr);
    } catch (error &e) {
        if (e.status == dnnl_unimplemented)
            throw example_allows_unimplemented {
                    "No f8 matmul implementation is available for this "
                    "platform.\n"
                    "Please refer to the developer guide for details."};

        // on any other error just re-throw
        throw;
    }

    auto matmul_pd = matmul::primitive_desc(
            eng, src_f8_md, weights_f8_md, dst_md, matmul_attr);
    auto matmul_prim = matmul(matmul_pd);

    // Execute matmul with dequantization
    matmul_prim.execute(s,
            {{DNNL_ARG_SRC, src_f8_mem}, {DNNL_ARG_WEIGHTS, weights_f8_mem},
                    {DNNL_ARG_DST, dst_mem},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale_mem},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
                            weights_scale_mem}});
    s.wait();

    std::cout << " Matrix multiplication completed successfully" << std::endl;

    // Read result for validation
    std::vector<float> dst_result(M * N);
    read_from_dnnl_memory(dst_result.data(), dst_mem);

    // Step 4: Validate results
    std::cout << "\nStep 4: Validating results against f32 reference"
              << std::endl;

    // Compute reference result with f32 precision
    std::vector<float> ref_result(M * N, 0.0f);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            for (int k = 0; k < K; ++k) {
                ref_result[m * N + n]
                        += src_f32[m * K + k] * weights_f32[k * N + n];
            }
        }
    }

    // Calculate relative error between f8 and f32 results
    float max_rel_error = 0.0f;
    // Use the dst_result vector that we already read instead of direct memory access
    // This ensures compatibility with GPU where get_data_handle() may not work
    for (int i = 0; i < M * N; ++i) {
        if (std::abs(ref_result[i]) > 1e-6f) {
            float rel_error = std::abs(dst_result[i] - ref_result[i])
                    / std::abs(ref_result[i]);
            max_rel_error = std::max(max_rel_error, rel_error);
        }
    }
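    // Rough error budget: with round-to-nearest, a 2-bit mantissa (f8_e5m2)
    // gives up to ~12.5% relative rounding error per value and a 3-bit
    // mantissa (f8_e4m3) up to ~6.25%. Both operands are quantized, so
    // individual products can be off by more, but the errors largely average
    // out over the K accumulation, so the loose 15% tolerance below is
    // expected to be ample for this small example.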
"PASSED" : "FAILED") << " (max relative error: " << max_rel_error * 100.0f << "%, tolerance: " << tolerance * 100.0f << "%)" << std::endl; if (!validation_passed) { throw std::runtime_error( " Validation failed: results exceed expected tolerance"); } } void run_f8_tutorials(engine::kind engine_kind) { // Sample 1: f8_e5m2 std::cout << "Sample 1: f8_e5m2 Format" << std::endl; std::cout << "==========================" << std::endl; perform_matmul_with_f8_quantization( engine_kind, memory::data_type::f8_e5m2); std::cout << "f8_e5m2 tutorial completed successfully" << std::endl << std::endl; // Sample 2: f8_e4m3 std::cout << "Sample 2: f8_e4m3 Format" << std::endl; std::cout << "==========================" << std::endl; perform_matmul_with_f8_quantization( engine_kind, memory::data_type::f8_e4m3); std::cout << "f8_e4m3 tutorial completed successfully" << std::endl << std::endl; } int main(int argc, char **argv) { engine::kind engine_kind = parse_engine_kind(argc, argv); return handle_example_errors(run_f8_tutorials, engine_kind); }