.. index:: pair: example; cnn_training_bf16.cpp .. _doxid-cnn_training_bf16_8cpp-example: cnn_training_bf16.cpp ===================== This C++ API example demonstrates how to build AlexNet model training using the bfloat16 data type. Annotated version: :ref:`CNN bf16 training example <doxid-cnn_training_bf16_cpp>` .. ref-code-block:: cpp /******************************************************************************* * Copyright 2019-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*******************************************************************************/ #include <cassert> #include <cmath> #include <iostream> #include <stdexcept> #include "oneapi/dnnl/dnnl.hpp" #include "example_utils.hpp" using namespace :ref:`dnnl <doxid-namespacednnl>`; void simple_net(:ref:`engine::kind <doxid-structdnnl_1_1engine_1a2635da16314dcbdb9bd9ea431316bb1a>` engine_kind) { using :ref:`tag <doxid-structdnnl_1_1memory_1a8e71077ed6a5f7fb7b3e6e1a5a2ecf3f>` = :ref:`memory::format_tag <doxid-structdnnl_1_1memory_1a8e71077ed6a5f7fb7b3e6e1a5a2ecf3f>`; using :ref:`dt <doxid-structdnnl_1_1memory_1a8e83474ec3a50e08e37af76c8c075dce>` = :ref:`memory::data_type <doxid-structdnnl_1_1memory_1a8e83474ec3a50e08e37af76c8c075dce>`; auto eng = :ref:`engine <doxid-structdnnl_1_1engine>`(engine_kind, 0); :ref:`stream <doxid-structdnnl_1_1stream>` s(eng); // Vector of primitives and their execute arguments std::vector<primitive> net_fwd, net_bwd; std::vector<std::unordered_map<int, memory>> net_fwd_args, net_bwd_args; const int batch = 32; // float data type is used for user data std::vector<float> net_src(batch * 3 * 227 * 227); // initializing non-zero values for src for (size_t i = 0; i < net_src.size(); ++i) net_src[i] = sinf((float)i); // AlexNet: conv // {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55} // strides: {4, 4} :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` conv_src_tz = {batch, 3, 227, 227}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` conv_weights_tz = {96, 3, 11, 11}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` conv_bias_tz = {96}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` conv_dst_tz = {batch, 96, 55, 55}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` conv_strides = {4, 4}; :ref:`memory::dims 
<doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` conv_padding = {0, 0}; // float data type is used for user data std::vector<float> conv_weights(product(conv_weights_tz)); std::vector<float> conv_bias(product(conv_bias_tz)); // initializing non-zero values for weights and bias for (size_t i = 0; i < conv_weights.size(); ++i) conv_weights[i] = sinf((float)i); for (size_t i = 0; i < conv_bias.size(); ++i) conv_bias[i] = sinf((float)i); // create memory for user data auto conv_user_src_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{conv_src_tz}, dt::f32, tag::nchw}, eng); write_to_dnnl_memory(net_src.data(), conv_user_src_memory); auto conv_user_weights_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{conv_weights_tz}, dt::f32, tag::oihw}, eng); write_to_dnnl_memory(conv_weights.data(), conv_user_weights_memory); auto conv_user_bias_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{conv_bias_tz}, dt::f32, tag::x}, eng); write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory); // create memory descriptors for bfloat16 convolution data w/ no specified // format tag(`any`) // tag `any` lets a primitive(convolution in this case) // chose the memory format preferred for best performance. 
auto conv_src_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_src_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto conv_weights_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_weights_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto conv_dst_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_dst_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); // here bias data type is set to bf16. // additionally, f32 data type is supported for bf16 convolution. 
auto conv_bias_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_bias_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); // create a convolution primitive descriptor // check if bf16 convolution is supported try { :ref:`convolution_forward::primitive_desc <doxid-structdnnl_1_1convolution__forward_1_1primitive__desc>`(eng, :ref:`prop_kind::forward <doxid-group__dnnl__api__attributes_1ggac7db48f6583aa9903e54c2a39d65438fa965dbaac085fc891bfbbd4f9d145bbc8>`, :ref:`algorithm::convolution_direct <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640a5028ad8f818a45333a8a0eefad35c5c0>`, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_padding, conv_padding); } catch (:ref:`error <doxid-structdnnl_1_1error>` &e) { if (e.status == :ref:`dnnl_unimplemented <doxid-group__dnnl__api__utils_1ggad24f9ded06e34d3ee71e7fc4b408d57aa3a8579e8afc4e23344cd3115b0e81de1>`) throw example_allows_unimplemented { "No bf16 convolution implementation is available for this " "platform.\n" "Please refer to the developer guide for details."}; // on any other error just re-throw throw; } auto conv_pd = :ref:`convolution_forward::primitive_desc <doxid-structdnnl_1_1convolution__forward_1_1primitive__desc>`(eng, :ref:`prop_kind::forward <doxid-group__dnnl__api__attributes_1ggac7db48f6583aa9903e54c2a39d65438fa965dbaac085fc891bfbbd4f9d145bbc8>`, :ref:`algorithm::convolution_direct <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640a5028ad8f818a45333a8a0eefad35c5c0>`, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_padding, conv_padding); // create reorder primitives between user input and conv src if needed auto conv_src_memory = conv_user_src_memory; if (conv_pd.src_desc() != 
conv_user_src_memory.get_desc()) { conv_src_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_pd.src_desc(), eng); net_fwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(conv_user_src_memory, conv_src_memory)); net_fwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, conv_user_src_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, conv_src_memory}}); } auto conv_weights_memory = conv_user_weights_memory; if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) { conv_weights_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_pd.weights_desc(), eng); net_fwd.push_back( :ref:`reorder <doxid-structdnnl_1_1reorder>`(conv_user_weights_memory, conv_weights_memory)); net_fwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, conv_user_weights_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, conv_weights_memory}}); } // convert bias from f32 to bf16 as convolution descriptor is created with // bias data type as bf16. 
auto conv_bias_memory = conv_user_bias_memory; if (conv_pd.bias_desc() != conv_user_bias_memory.get_desc()) { conv_bias_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_pd.bias_desc(), eng); net_fwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(conv_user_bias_memory, conv_bias_memory)); net_fwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, conv_user_bias_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, conv_bias_memory}}); } // create memory for conv dst auto conv_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_pd.dst_desc(), eng); // finally create a convolution primitive net_fwd.push_back(:ref:`convolution_forward <doxid-structdnnl_1_1convolution__forward>`(conv_pd)); net_fwd_args.push_back({{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, conv_src_memory}, {:ref:`DNNL_ARG_WEIGHTS <doxid-group__dnnl__api__primitives__common_1gaf279f28c59a807e71a70c719db56c5b3>`, conv_weights_memory}, {:ref:`DNNL_ARG_BIAS <doxid-group__dnnl__api__primitives__common_1gad0cbc09942aba93fbe3c0c2e09166f0d>`, conv_bias_memory}, {:ref:`DNNL_ARG_DST <doxid-group__dnnl__api__primitives__common_1ga3ca217e4a06d42a0ede3c018383c388f>`, conv_dst_memory}}); // AlexNet: relu // {batch, 96, 55, 55} -> {batch, 96, 55, 55} :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` relu_data_tz = {batch, 96, 55, 55}; const float negative_slope = 0.0f; // create relu primitive desc // keep memory format tag of source same as the format tag of convolution // output in order to avoid reorder auto relu_pd = :ref:`eltwise_forward::primitive_desc <doxid-structdnnl_1_1eltwise__forward_1_1primitive__desc>`(eng, :ref:`prop_kind::forward <doxid-group__dnnl__api__attributes_1ggac7db48f6583aa9903e54c2a39d65438fa965dbaac085fc891bfbbd4f9d145bbc8>`, 
:ref:`algorithm::eltwise_relu <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640aba09bebb742494255b90b43871c01c69>`, conv_pd.dst_desc(), conv_pd.dst_desc(), negative_slope); // create relu dst memory auto relu_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(relu_pd.dst_desc(), eng); // finally create a relu primitive net_fwd.push_back(:ref:`eltwise_forward <doxid-structdnnl_1_1eltwise__forward>`(relu_pd)); net_fwd_args.push_back( {{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, conv_dst_memory}, {:ref:`DNNL_ARG_DST <doxid-group__dnnl__api__primitives__common_1ga3ca217e4a06d42a0ede3c018383c388f>`, relu_dst_memory}}); // AlexNet: lrn // {batch, 96, 55, 55} -> {batch, 96, 55, 55} // local size: 5 // alpha: 0.0001 // beta: 0.75 // k: 1.0 :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` lrn_data_tz = {batch, 96, 55, 55}; const uint32_t local_size = 5; const float alpha = 0.0001f; const float beta = 0.75f; const float k = 1.0f; // create a lrn primitive descriptor auto lrn_pd = :ref:`lrn_forward::primitive_desc <doxid-structdnnl_1_1lrn__forward_1_1primitive__desc>`(eng, :ref:`prop_kind::forward <doxid-group__dnnl__api__attributes_1ggac7db48f6583aa9903e54c2a39d65438fa965dbaac085fc891bfbbd4f9d145bbc8>`, :ref:`algorithm::lrn_across_channels <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640ab9e2d858b551792385a4b5b86672b24b>`, relu_pd.dst_desc(), relu_pd.dst_desc(), local_size, alpha, beta, k); // create lrn dst memory auto lrn_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(lrn_pd.dst_desc(), eng); // create workspace only in training and only for forward primitive // query lrn_pd for workspace, this memory will be shared with forward lrn auto lrn_workspace_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(lrn_pd.workspace_desc(), eng); // finally create a lrn primitive net_fwd.push_back(:ref:`lrn_forward 
<doxid-structdnnl_1_1lrn__forward>`(lrn_pd)); net_fwd_args.push_back( {{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, relu_dst_memory}, {:ref:`DNNL_ARG_DST <doxid-group__dnnl__api__primitives__common_1ga3ca217e4a06d42a0ede3c018383c388f>`, lrn_dst_memory}, {:ref:`DNNL_ARG_WORKSPACE <doxid-group__dnnl__api__primitives__common_1ga550c80e1b9ba4f541202a7ac98be117f>`, lrn_workspace_memory}}); // AlexNet: pool // {batch, 96, 55, 55} -> {batch, 96, 27, 27} // kernel: {3, 3} // strides: {2, 2} :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` pool_dst_tz = {batch, 96, 27, 27}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` pool_kernel = {3, 3}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` pool_strides = {2, 2}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` pool_dilation = {0, 0}; :ref:`memory::dims <doxid-structdnnl_1_1memory_1afdd20764d58c0b517d5a31276672aeb8>` pool_padding = {0, 0}; // create memory for pool dst data in user format auto pool_user_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{pool_dst_tz}, dt::f32, tag::nchw}, eng); // create pool dst memory descriptor in format any for bfloat16 data type auto pool_dst_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({pool_dst_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); // create a pooling primitive descriptor auto pool_pd = :ref:`pooling_forward::primitive_desc <doxid-structdnnl_1_1pooling__forward_1_1primitive__desc>`(eng, :ref:`prop_kind::forward <doxid-group__dnnl__api__attributes_1ggac7db48f6583aa9903e54c2a39d65438fa965dbaac085fc891bfbbd4f9d145bbc8>`, :ref:`algorithm::pooling_max 
<doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640a8c73d4bb88a0497586a74256bb338e88>`, lrn_dst_memory.get_desc(), pool_dst_md, pool_strides, pool_kernel, pool_dilation, pool_padding, pool_padding); // create pooling workspace memory if training auto pool_workspace_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(pool_pd.workspace_desc(), eng); // create a pooling primitive net_fwd.push_back(:ref:`pooling_forward <doxid-structdnnl_1_1pooling__forward>`(pool_pd)); // leave DST unknown for now (see the next reorder) net_fwd_args.push_back({{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, lrn_dst_memory}, // delay putting DST until reorder (if needed) {:ref:`DNNL_ARG_WORKSPACE <doxid-group__dnnl__api__primitives__common_1ga550c80e1b9ba4f541202a7ac98be117f>`, pool_workspace_memory}}); // create reorder primitive between pool dst and user dst format // if needed auto pool_dst_memory = pool_user_dst_memory; if (pool_pd.dst_desc() != pool_user_dst_memory.get_desc()) { pool_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(pool_pd.dst_desc(), eng); net_fwd_args.back().insert({:ref:`DNNL_ARG_DST <doxid-group__dnnl__api__primitives__common_1ga3ca217e4a06d42a0ede3c018383c388f>`, pool_dst_memory}); net_fwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(pool_dst_memory, pool_user_dst_memory)); net_fwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, pool_dst_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, pool_user_dst_memory}}); } else { net_fwd_args.back().insert({:ref:`DNNL_ARG_DST <doxid-group__dnnl__api__primitives__common_1ga3ca217e4a06d42a0ede3c018383c388f>`, pool_dst_memory}); } //----------------------------------------------------------------------- //----------------- Backward Stream ------------------------------------- // ... 
user diff_data in float data type ... std::vector<float> net_diff_dst(batch * 96 * 27 * 27); for (size_t i = 0; i < net_diff_dst.size(); ++i) net_diff_dst[i] = sinf((float)i); // create memory for user diff dst data stored in float data type auto pool_user_diff_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{pool_dst_tz}, dt::f32, tag::nchw}, eng); write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory); // Backward pooling // create memory descriptors for pooling auto pool_diff_src_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({lrn_data_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto pool_diff_dst_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({pool_dst_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); // backward primitive descriptor needs to hint forward descriptor auto pool_bwd_pd = :ref:`pooling_backward::primitive_desc <doxid-structdnnl_1_1pooling__backward_1_1primitive__desc>`(eng, :ref:`algorithm::pooling_max <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640a8c73d4bb88a0497586a74256bb338e88>`, pool_diff_src_md, pool_diff_dst_md, pool_strides, pool_kernel, pool_dilation, pool_padding, pool_padding, pool_pd); // create reorder primitive between user diff dst and pool diff dst // if required auto pool_diff_dst_memory = pool_user_diff_dst_memory; if (pool_dst_memory.get_desc() != pool_user_diff_dst_memory.get_desc()) { pool_diff_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(pool_dst_memory.get_desc(), eng); net_bwd.push_back( :ref:`reorder 
<doxid-structdnnl_1_1reorder>`(pool_user_diff_dst_memory, pool_diff_dst_memory)); net_bwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, pool_user_diff_dst_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, pool_diff_dst_memory}}); } // create memory for pool diff src auto pool_diff_src_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(pool_bwd_pd.diff_src_desc(), eng); // finally create backward pooling primitive net_bwd.push_back(:ref:`pooling_backward <doxid-structdnnl_1_1pooling__backward>`(pool_bwd_pd)); net_bwd_args.push_back({{:ref:`DNNL_ARG_DIFF_DST <doxid-group__dnnl__api__primitives__common_1gac9302f4cbd2668bf9a98ba99d752b971>`, pool_diff_dst_memory}, {:ref:`DNNL_ARG_DIFF_SRC <doxid-group__dnnl__api__primitives__common_1ga18ee0e360399cfe9d3b58a13dfcb9333>`, pool_diff_src_memory}, {:ref:`DNNL_ARG_WORKSPACE <doxid-group__dnnl__api__primitives__common_1ga550c80e1b9ba4f541202a7ac98be117f>`, pool_workspace_memory}}); // Backward lrn auto lrn_diff_dst_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({lrn_data_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); const auto &lrn_diff_src_md = lrn_diff_dst_md; // create backward lrn primitive descriptor auto lrn_bwd_pd = :ref:`lrn_backward::primitive_desc <doxid-structdnnl_1_1lrn__backward_1_1primitive__desc>`(eng, :ref:`algorithm::lrn_across_channels <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640ab9e2d858b551792385a4b5b86672b24b>`, lrn_diff_src_md, lrn_diff_dst_md, lrn_pd.src_desc(), local_size, alpha, beta, k, lrn_pd); // create reorder primitive between pool diff src and lrn diff dst // if required auto lrn_diff_dst_memory = 
pool_diff_src_memory; if (lrn_diff_dst_memory.get_desc() != lrn_bwd_pd.diff_dst_desc()) { lrn_diff_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(lrn_bwd_pd.diff_dst_desc(), eng); net_bwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(pool_diff_src_memory, lrn_diff_dst_memory)); net_bwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, pool_diff_src_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, lrn_diff_dst_memory}}); } // create memory for lrn diff src auto lrn_diff_src_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(lrn_bwd_pd.diff_src_desc(), eng); // finally create a lrn backward primitive // backward lrn needs src: relu dst in this topology net_bwd.push_back(:ref:`lrn_backward <doxid-structdnnl_1_1lrn__backward>`(lrn_bwd_pd)); net_bwd_args.push_back({{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, relu_dst_memory}, {:ref:`DNNL_ARG_DIFF_DST <doxid-group__dnnl__api__primitives__common_1gac9302f4cbd2668bf9a98ba99d752b971>`, lrn_diff_dst_memory}, {:ref:`DNNL_ARG_DIFF_SRC <doxid-group__dnnl__api__primitives__common_1ga18ee0e360399cfe9d3b58a13dfcb9333>`, lrn_diff_src_memory}, {:ref:`DNNL_ARG_WORKSPACE <doxid-group__dnnl__api__primitives__common_1ga550c80e1b9ba4f541202a7ac98be117f>`, lrn_workspace_memory}}); // Backward relu auto relu_diff_src_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({relu_data_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto relu_diff_dst_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({relu_data_tz}, :ref:`dt::bf16 
<doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto relu_src_md = conv_pd.dst_desc(); // create backward relu primitive_descriptor auto relu_bwd_pd = :ref:`eltwise_backward::primitive_desc <doxid-structdnnl_1_1eltwise__backward_1_1primitive__desc>`(eng, :ref:`algorithm::eltwise_relu <doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640aba09bebb742494255b90b43871c01c69>`, relu_diff_src_md, relu_diff_dst_md, relu_src_md, negative_slope, relu_pd); // create reorder primitive between lrn diff src and relu diff dst // if required auto relu_diff_dst_memory = lrn_diff_src_memory; if (relu_diff_dst_memory.get_desc() != relu_bwd_pd.diff_dst_desc()) { relu_diff_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(relu_bwd_pd.diff_dst_desc(), eng); net_bwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(lrn_diff_src_memory, relu_diff_dst_memory)); net_bwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, lrn_diff_src_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, relu_diff_dst_memory}}); } // create memory for relu diff src auto relu_diff_src_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(relu_bwd_pd.diff_src_desc(), eng); // finally create a backward relu primitive net_bwd.push_back(:ref:`eltwise_backward <doxid-structdnnl_1_1eltwise__backward>`(relu_bwd_pd)); net_bwd_args.push_back({{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, conv_dst_memory}, {:ref:`DNNL_ARG_DIFF_DST <doxid-group__dnnl__api__primitives__common_1gac9302f4cbd2668bf9a98ba99d752b971>`, relu_diff_dst_memory}, {:ref:`DNNL_ARG_DIFF_SRC 
<doxid-group__dnnl__api__primitives__common_1ga18ee0e360399cfe9d3b58a13dfcb9333>`, relu_diff_src_memory}}); // Backward convolution with respect to weights // create user format diff weights and diff bias memory for float data type auto conv_user_diff_weights_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{conv_weights_tz}, dt::f32, tag::nchw}, eng); auto conv_diff_bias_memory = :ref:`memory <doxid-structdnnl_1_1memory>`({{conv_bias_tz}, dt::f32, tag::x}, eng); // create memory descriptors for bfloat16 convolution data auto conv_bwd_src_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_src_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto conv_diff_weights_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_weights_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); auto conv_diff_dst_md = :ref:`memory::desc <doxid-structdnnl_1_1memory_1_1desc>`({conv_dst_tz}, :ref:`dt::bf16 <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725afe2904d9fb3b0f4a81c92b03dec11424>`, :ref:`tag::any <doxid-group__dnnl__api__fpmath__mode_1gga0ad94cbef13dce222933422bfdcfa725a100b8cad7cf2a56f6df78f171f97a1ec>`); // use diff bias provided by the user auto conv_diff_bias_md = conv_diff_bias_memory.:ref:`get_desc <doxid-structdnnl_1_1memory_1ad8a1ad28ed7acf9c34c69e4b882c6e92>`(); // create backward convolution primitive descriptor auto conv_bwd_weights_pd = :ref:`convolution_backward_weights::primitive_desc <doxid-structdnnl_1_1convolution__backward__weights_1_1primitive__desc>`(eng, :ref:`algorithm::convolution_direct 
<doxid-group__dnnl__api__attributes_1gga00377dd4982333e42e8ae1d09a309640a5028ad8f818a45333a8a0eefad35c5c0>`, conv_bwd_src_md, conv_diff_weights_md, conv_diff_bias_md, conv_diff_dst_md, conv_strides, conv_padding, conv_padding, conv_pd); // for best performance convolution backward might chose // different memory format for src and diff_dst // than the memory formats preferred by forward convolution // for src and dst respectively // create reorder primitives for src from forward convolution to the // format chosen by backward convolution auto conv_bwd_src_memory = conv_src_memory; if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) { conv_bwd_src_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_bwd_weights_pd.src_desc(), eng); net_bwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(conv_src_memory, conv_bwd_src_memory)); net_bwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, conv_src_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, conv_bwd_src_memory}}); } // create reorder primitives for diff_dst between diff_src from relu_bwd // and format preferred by conv_diff_weights auto conv_diff_dst_memory = relu_diff_src_memory; if (conv_bwd_weights_pd.diff_dst_desc() != relu_diff_src_memory.get_desc()) { conv_diff_dst_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_bwd_weights_pd.diff_dst_desc(), eng); net_bwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`(relu_diff_src_memory, conv_diff_dst_memory)); net_bwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, relu_diff_src_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, conv_diff_dst_memory}}); } // create backward convolution primitive net_bwd.push_back(:ref:`convolution_backward_weights 
<doxid-structdnnl_1_1convolution__backward__weights>`(conv_bwd_weights_pd)); net_bwd_args.push_back({{:ref:`DNNL_ARG_SRC <doxid-group__dnnl__api__primitives__common_1gac37ad67b48edeb9e742af0e50b70fe09>`, conv_bwd_src_memory}, {:ref:`DNNL_ARG_DIFF_DST <doxid-group__dnnl__api__primitives__common_1gac9302f4cbd2668bf9a98ba99d752b971>`, conv_diff_dst_memory}, // delay putting DIFF_WEIGHTS until reorder (if needed) {:ref:`DNNL_ARG_DIFF_BIAS <doxid-group__dnnl__api__primitives__common_1ga1cd79979dda6df65ec45eef32a839901>`, conv_diff_bias_memory}}); // create reorder primitives between conv diff weights and user diff weights // if needed auto conv_diff_weights_memory = conv_user_diff_weights_memory; if (conv_bwd_weights_pd.diff_weights_desc() != conv_user_diff_weights_memory.get_desc()) { conv_diff_weights_memory = :ref:`memory <doxid-structdnnl_1_1memory>`(conv_bwd_weights_pd.diff_weights_desc(), eng); net_bwd_args.back().insert( {:ref:`DNNL_ARG_DIFF_WEIGHTS <doxid-group__dnnl__api__primitives__common_1ga3324092ef421f77aebee83b0117cac60>`, conv_diff_weights_memory}); net_bwd.push_back(:ref:`reorder <doxid-structdnnl_1_1reorder>`( conv_diff_weights_memory, conv_user_diff_weights_memory)); net_bwd_args.push_back({{:ref:`DNNL_ARG_FROM <doxid-group__dnnl__api__primitives__common_1ga953b34f004a8222b04e21851487c611a>`, conv_diff_weights_memory}, {:ref:`DNNL_ARG_TO <doxid-group__dnnl__api__primitives__common_1gaf700c3396987b450413c8df5d78bafd9>`, conv_user_diff_weights_memory}}); } else { net_bwd_args.back().insert( {:ref:`DNNL_ARG_DIFF_WEIGHTS <doxid-group__dnnl__api__primitives__common_1ga3324092ef421f77aebee83b0117cac60>`, conv_diff_weights_memory}); } // didn't we forget anything? 
assert(net_fwd.size() == net_fwd_args.size() && "something is missing"); assert(net_bwd.size() == net_bwd_args.size() && "something is missing"); int n_iter = 1; // number of iterations for training // execute while (n_iter) { // forward for (size_t i = 0; i < net_fwd.size(); ++i) net_fwd.at(i).execute(s, net_fwd_args.at(i)); // update net_diff_dst // auto net_output = pool_user_dst_memory.get_data_handle(); // ..user updates net_diff_dst using net_output... // some user defined func update_diff_dst(net_diff_dst.data(), // net_output) for (size_t i = 0; i < net_bwd.size(); ++i) net_bwd.at(i).execute(s, net_bwd_args.at(i)); // update weights and bias using diff weights and bias // // auto net_diff_weights // = conv_user_diff_weights_memory.get_data_handle(); // auto net_diff_bias = conv_diff_bias_memory.get_data_handle(); // // ...user updates weights and bias using diff weights and bias... // // some user defined func update_weights(conv_weights.data(), // conv_bias.data(), net_diff_weights, net_diff_bias); --n_iter; } s.wait(); } int main(int argc, char **argv) { return handle_example_errors(simple_net, parse_engine_kind(argc, argv)); }