.. index:: pair: page; Group Normalization

.. _doxid-dev_guide_group_normalization:

Group Normalization
===================

:ref:`API Reference `

General
~~~~~~~

The group normalization primitive performs a forward or backward group
normalization operation on tensors with three or more dimensions.

Forward
-------

The group normalization operation is defined by the following formulas. We
show formulas only for 2D spatial data, which are straightforward to
generalize to cases of higher and lower dimensions. Variable names follow the
standard :ref:`Naming Conventions `.

.. math::

   \dst(n, g \cdot C_G + c_g, h, w) =
       \gamma(g \cdot C_G + c_g) \cdot
       \frac{\src(n, g \cdot C_G + c_g, h, w) - \mu(n, g)}
            {\sqrt{\sigma^2(n, g) + \varepsilon}}
       + \beta(g \cdot C_G + c_g),

where

* :math:`G` is the number of groups and :math:`C_G = \frac{C}{G}` is the
  number of channels in each group,
* :math:`c_g \in [0, C_G)` is the index of a channel within its group,
* :math:`\gamma(c), \beta(c)` are optional scale and shift for a channel (see
  :ref:`dnnl_use_scale ` and :ref:`dnnl_use_shift ` flags),
* :math:`\mu(n, g), \sigma^2(n, g)` are the mean and variance of group
  :math:`g` of sample :math:`n` in the batch (see
  :ref:`dnnl_use_global_stats ` flag), and
* :math:`\varepsilon` is a constant to improve numerical stability.

Mean and variance are computed at runtime or provided by a user. When mean and
variance are computed at runtime, the following formulas are used:

* :math:`\mu(n, g) = \frac{1}{C_G H W} \sum\limits_{c_g, h, w} \src(n, g \cdot C_G + c_g, h, w)`,
* :math:`\sigma^2(n, g) = \frac{1}{C_G H W} \sum\limits_{c_g, h, w} (\src(n, g \cdot C_G + c_g, h, w) - \mu(n, g))^2`.

The :math:`\gamma(c)` and :math:`\beta(c)` tensors are considered learnable.

.. note::

   * The group normalization primitive computes population mean and variance
     and not the sample or unbiased versions that are typically used to
     compute running mean and variance.

   * Using the mean and variance computed by the group normalization
     primitive, running mean and variance :math:`\hat\mu` and
     :math:`\hat\sigma^2` can be computed as

     .. math::

        \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\
        \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2,

     where :math:`\alpha` is a user-chosen momentum factor between 0 and 1.

Difference Between Forward Training and Forward Inference
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++

* If mean and variance are computed at runtime (i.e.,
  :ref:`dnnl_use_global_stats ` is not set), they become outputs for the
  propagation kind :ref:`dnnl_forward_training ` (because they are required
  during the backward propagation) and are not exposed for the propagation
  kind :ref:`dnnl_forward_inference `.

Backward
--------

The backward propagation computes :math:`\diffsrc(n, c, h, w)`,
:math:`\diffgamma(c)^*`, and :math:`\diffbeta(c)^*` based on
:math:`\diffdst(n, c, h, w)`, :math:`\src(n, c, h, w)`, :math:`\mu(n, g)`,
:math:`\sigma^2(n, g)`, :math:`\gamma(c)^*`, and :math:`\beta(c)^*`.

The tensors marked with an asterisk are used only when the primitive is
configured to use :math:`\gamma(c)` and :math:`\beta(c)` (i.e.,
:ref:`dnnl_use_scale ` or :ref:`dnnl_use_shift ` is set).

Execution Arguments
~~~~~~~~~~~~~~~~~~~

Depending on the :ref:`flags ` and :ref:`propagation kind `, the group
normalization primitive requires different inputs and outputs. For clarity,
a summary is shown below.
.. list-table::
   :header-rows: 1

   * - Flags
     - :ref:`dnnl_forward_inference `
     - :ref:`dnnl_forward_training `
     - :ref:`dnnl_backward `
     - :ref:`dnnl_backward_data `
   * - :ref:`dnnl_normalization_flags_none `
     - *Inputs*: :math:`\src`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`

       *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\diffsrc`
     - Same as for :ref:`dnnl_backward `
   * - :ref:`dnnl_use_global_stats `
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\diffsrc`
     - Same as for :ref:`dnnl_backward `
   * - :ref:`dnnl_use_scale `
     - *Inputs*: :math:`\src`, :math:`\gamma`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\gamma`

       *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`

       *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`
     - Not supported
   * - :ref:`dnnl_use_shift `
     - *Inputs*: :math:`\src`, :math:`\beta`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\beta`

       *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\beta`

       *Outputs*: :math:`\diffsrc`, :math:`\diffbeta`
     - Not supported
   * - :ref:`dnnl_use_global_stats ` | :ref:`dnnl_use_scale ` | :ref:`dnnl_use_shift `
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`

       *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`, :math:`\diffbeta`
     - Not supported

When executed, the inputs and outputs should be mapped to an execution
argument index as specified by the following table.
=============================== ================================================================================
Primitive Input/Output            Execution Argument Index
=============================== ================================================================================
:math:`\src`                      DNNL_ARG_SRC
:math:`\gamma`                    DNNL_ARG_SCALE
:math:`\beta`                     DNNL_ARG_SHIFT
mean (:math:`\mu`)                DNNL_ARG_MEAN
variance (:math:`\sigma^2`)       DNNL_ARG_VARIANCE
:math:`\dst`                      DNNL_ARG_DST
:math:`\diffdst`                  DNNL_ARG_DIFF_DST
:math:`\diffsrc`                  DNNL_ARG_DIFF_SRC
:math:`\diffgamma`                DNNL_ARG_DIFF_SCALE
:math:`\diffbeta`                 DNNL_ARG_DIFF_SHIFT
:math:`\text{binary post-op}`     :ref:`DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) ` | DNNL_ARG_SRC_1
=============================== ================================================================================

Implementation Details
~~~~~~~~~~~~~~~~~~~~~~

General Notes
-------------

#. The different flavors of the primitive are partially controlled by the
   ``flags`` parameter that is passed to the primitive descriptor creation
   function (e.g., :ref:`dnnl::group_normalization_forward::primitive_desc() `).
   Multiple flags can be set using the bitwise OR operator (``|``).

#. For forward propagation, the mean and variance can be either computed at
   runtime (in which case they are outputs of the primitive) or provided by a
   user (in which case they are inputs). In the latter case, a user must set
   the :ref:`dnnl_use_global_stats ` flag. For the backward propagation, the
   mean and variance are always input parameters.

#. Both forward and backward propagation support in-place operations, meaning
   that :math:`\src` and :math:`\dst` can share the same memory for forward
   propagation, and :math:`\diffdst` and :math:`\diffsrc` can share the same
   memory for backward propagation. In case of an in-place operation, the
   original data will be overwritten. Note, however, that backward
   propagation requires the original :math:`\src`, hence the corresponding
   forward propagation should not be performed in-place.

Data Type Support
-----------------

The operation supports the following combinations of data types:

==================== ====================== ==============================
Propagation           Source / Destination   Mean / Variance / ScaleShift
==================== ====================== ==============================
forward / backward    f32, bf16, f16         f32
forward               s8                     f32
==================== ====================== ==============================

.. warning::

   There might be hardware- or implementation-specific restrictions. Check the
   :ref:`Implementation Limitations ` section below.

Data Representation
-------------------

Mean and Variance
+++++++++++++++++

The mean (:math:`\mu`) and variance (:math:`\sigma^2`) are separate 2D
tensors of size :math:`N \times G`.

The format of the corresponding memory object must be :ref:`dnnl_nc `
(:ref:`dnnl_ab `).

Scale and Shift
+++++++++++++++

If :ref:`dnnl_use_scale ` or :ref:`dnnl_use_shift ` is used, the scale
(:math:`\gamma`) and shift (:math:`\beta`) are separate 1D tensors of shape
:math:`C`.

The format of the corresponding memory object must be :ref:`dnnl_x `
(:ref:`dnnl_a `).

Source, Destination, and Their Gradients
++++++++++++++++++++++++++++++++++++++++

The group normalization primitive expects data to be an
:math:`N \times C \times SP_n \times \cdots \times SP_0` tensor.
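The forward formulas and the tensor shapes described above can be checked
against a plain, single-threaded reference. The following is a minimal sketch,
not the oneDNN implementation. It assumes f32 data in a dense plain layout
(channels outer, all spatial dimensions flattened into ``SP``), a channel
count ``C`` divisible by ``G``, and optional per-channel scale and shift
vectors. The names ``group_norm_forward`` and ``GroupNormResult`` are
hypothetical.

.. code-block:: cpp

   #include <cassert>
   #include <cmath>
   #include <cstddef>
   #include <vector>

   // Outputs of the hypothetical reference forward pass.
   struct GroupNormResult {
       std::vector<float> dst;      // N x C x SP, same plain layout as src
       std::vector<float> mean;     // N x G, mu(n, g)
       std::vector<float> variance; // N x G, population sigma^2(n, g)
   };

   // Hypothetical reference (not the oneDNN kernel). src is a dense
   // N x C x SP tensor; scale and shift are optional vectors of size C that
   // play the roles of gamma(c) and beta(c).
   GroupNormResult group_norm_forward(const std::vector<float> &src,
           std::size_t N, std::size_t C, std::size_t SP, std::size_t G,
           float eps, const std::vector<float> *scale = nullptr,
           const std::vector<float> *shift = nullptr) {
       assert(G > 0 && C % G == 0); // assumption: C is divisible by G
       assert(src.size() == N * C * SP);
       assert(!scale || scale->size() == C);
       assert(!shift || shift->size() == C);

       const std::size_t CG = C / G;  // C_G, channels per group
       const std::size_t m = CG * SP; // elements reduced for each (n, g)

       GroupNormResult r;
       r.dst.resize(src.size());
       r.mean.resize(N * G);
       r.variance.resize(N * G);

       for (std::size_t n = 0; n < N; ++n) {
           for (std::size_t g = 0; g < G; ++g) {
               // In this layout, the m elements of group g of sample n are contiguous.
               const std::size_t base = (n * C + g * CG) * SP;

               double sum = 0.0;
               for (std::size_t i = 0; i < m; ++i) sum += src[base + i];
               const double mu = sum / m;

               double sq = 0.0;
               for (std::size_t i = 0; i < m; ++i) {
                   const double d = src[base + i] - mu;
                   sq += d * d;
               }
               const double var = sq / m; // population variance: divide by m, not m - 1

               r.mean[n * G + g] = static_cast<float>(mu);
               r.variance[n * G + g] = static_cast<float>(var);

               const double inv_std = 1.0 / std::sqrt(var + eps);
               for (std::size_t cg = 0; cg < CG; ++cg) {
                   const std::size_t c = g * CG + cg;
                   const double g_c = scale ? (*scale)[c] : 1.0; // gamma(c)
                   const double b_c = shift ? (*shift)[c] : 0.0; // beta(c)
                   for (std::size_t s = 0; s < SP; ++s) {
                       const std::size_t idx = base + cg * SP + s;
                       r.dst[idx] = static_cast<float>(
                               g_c * (src[idx] - mu) * inv_std + b_c);
                   }
               }
           }
       }
       return r;
   }

Passing ``nullptr`` for ``scale`` or ``shift`` corresponds to leaving the
scale or shift flag unset. Accumulating in ``double`` is a choice made only to
keep this reference numerically tight; it says nothing about how optimized
implementations accumulate.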
The group normalization primitive is optimized for the following memory
formats:

========== ================ ====================================================================================================
Spatial     Logical tensor   Implementations optimized for memory formats
========== ================ ====================================================================================================
1D          NCW              :ref:`dnnl_ncw ` (:ref:`dnnl_abc `), :ref:`dnnl_nwc ` (:ref:`dnnl_acb `)
2D          NCHW             :ref:`dnnl_nchw ` (:ref:`dnnl_abcd `), :ref:`dnnl_nhwc ` (:ref:`dnnl_acdb `)
3D          NCDHW            :ref:`dnnl_ncdhw ` (:ref:`dnnl_abcde `), :ref:`dnnl_ndhwc ` (:ref:`dnnl_acdeb `)
========== ================ ====================================================================================================

Post-Ops and Attributes
-----------------------

Attributes enable you to modify the behavior of the group normalization
primitive. The following attributes are supported by the group normalization
primitive:

.. list-table::
   :header-rows: 1

   * - Propagation
     - Type
     - Operation
     - Description
     - Restrictions
   * - forward
     - attribute
     - :ref:`Scales `
     - Scales the corresponding tensor by the given scale factor(s).
     - Supported only for int8 group normalization; only one scale per tensor
       is supported.
   * - forward
     - Post-op
     - :ref:`Binary `
     - Applies a :ref:`Binary ` operation to the result.
     - General binary post-op restrictions.
   * - forward
     - Post-op
     - :ref:`Eltwise `
     - Applies an :ref:`Eltwise ` operation to the result.
     -

:target:`doxid-dev_guide_group_normalization_1dg_gnorm_impl_limits`

Implementation Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~

#. Refer to :ref:`Data Types ` for limitations related to data types
   support.

Performance Tips
~~~~~~~~~~~~~~~~

#. Mixing different formats for inputs and outputs is functionally supported
   but leads to highly suboptimal performance.

#. Use in-place operations whenever possible (see caveats in General Notes).
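The format names used in the tables above describe how the logical
:math:`N \times C \times H \times W` tensor is laid out in memory. The
following minimal sketch is independent of the oneDNN API and covers only the
2D case. It computes the physical offset of a logical element
:math:`(n, c, h, w)` in the ``nchw`` (``abcd``) and ``nhwc`` (``acdb``)
formats. The helpers ``Dims4``, ``offset_nchw``, and ``offset_nhwc`` are
hypothetical and assume a dense tensor without padding.

.. code-block:: cpp

   #include <cassert>
   #include <cstddef>

   // Hypothetical logical 4D shape (N, C, H, W) of a dense, unpadded tensor.
   struct Dims4 {
       std::size_t N, C, H, W;
   };

   // Offset of logical element (n, c, h, w) in the plain nchw (abcd) format:
   // w varies fastest, then h, then c, then n.
   std::size_t offset_nchw(const Dims4 &d, std::size_t n, std::size_t c,
           std::size_t h, std::size_t w) {
       assert(n < d.N && c < d.C && h < d.H && w < d.W);
       return ((n * d.C + c) * d.H + h) * d.W + w;
   }

   // Offset of the same logical element in the channels-last nhwc (acdb)
   // format: c varies fastest, so all channels of one spatial point are adjacent.
   std::size_t offset_nhwc(const Dims4 &d, std::size_t n, std::size_t c,
           std::size_t h, std::size_t w) {
       assert(n < d.N && c < d.C && h < d.H && w < d.W);
       return ((n * d.H + h) * d.W + w) * d.C + c;
   }

With ``nchw``, the :math:`C_G H W` elements that form one :math:`(n, g)` group
are contiguous. With ``nhwc``, each spatial point contributes a contiguous run
of :math:`C_G` channels, and consecutive points are :math:`C` elements apart.
Accessing the same logical element therefore follows a different memory
pattern in the two formats.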
Examples
~~~~~~~~

:ref:`Group Normalization Primitive Example `

This C++ API example demonstrates how to create and execute a
:ref:`Group Normalization ` primitive in forward training propagation mode.

Key optimizations included in this example:

* In-place primitive execution;

* Source memory format for an optimized primitive implementation.
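The running-statistics note in the Forward section can also be sketched in
plain C++, separately from the oneDNN example above. This minimal sketch
rests on several assumptions. The helper names are hypothetical, the momentum
``alpha`` is chosen by the caller, and the running buffers have the same
:math:`N \times G` shape as the primitive's mean and variance outputs. The
optional Bessel correction rescales the population variance by
:math:`m / (m - 1)`, where :math:`m = C_G \cdot SP` is the number of reduced
elements. It reflects what a framework might want and is not something the
primitive does.

.. code-block:: cpp

   #include <cassert>
   #include <cstddef>
   #include <vector>

   // Hypothetical helper: exponential moving average of the statistics
   // returned by a forward training pass, following
   //   running_mean := alpha * running_mean + (1 - alpha) * mu
   //   running_var  := alpha * running_var  + (1 - alpha) * sigma^2
   void update_running_stats(std::vector<float> &running_mean,
           std::vector<float> &running_var, const std::vector<float> &mean,
           const std::vector<float> &variance, float alpha) {
       assert(running_mean.size() == mean.size());
       assert(running_var.size() == variance.size());
       for (std::size_t i = 0; i < mean.size(); ++i) {
           running_mean[i] = alpha * running_mean[i] + (1.0f - alpha) * mean[i];
           running_var[i] = alpha * running_var[i] + (1.0f - alpha) * variance[i];
       }
   }

   // Hypothetical helper: the primitive reports the population variance
   // (sum divided by m). If an unbiased estimate is preferred for the running
   // statistics, it can be rescaled by m / (m - 1) before the update.
   float unbiased_from_population(float population_var, std::size_t m) {
       assert(m > 1);
       return population_var * static_cast<float>(m) / static_cast<float>(m - 1);
   }

Keeping the correction in a separate function leaves the moving-average update
identical to the formula in the note. Callers that want the population
statistics as reported by the primitive can then skip the correction.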