/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C API

#ifndef ONEAPI_DNNL_DNNL_H
#define ONEAPI_DNNL_DNNL_H

#include "oneapi/dnnl/dnnl_common.h"
#include "oneapi/dnnl/dnnl_config.h"
#include "oneapi/dnnl/dnnl_types.h"
#include "oneapi/dnnl/dnnl_version.h"

#ifdef __cplusplus
extern "C" {
#endif

/// @addtogroup dnnl_api
/// @{

/// @addtogroup dnnl_api_primitives
/// @{

/// @addtogroup dnnl_api_primitives_common
/// @{

/// Changes the primitive descriptor to point to the next available
/// implementation.
///
/// @param primitive_desc A primitive descriptor to change.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_last_impl_reached if no more implementations available,
/// in which case the primitive descriptor itself is kept unchanged.
dnnl_status_t DNNL_API dnnl_primitive_desc_next_impl(
        dnnl_primitive_desc_t primitive_desc);

/// Clones a primitive descriptor. The resulting primitive descriptor must be
/// destroyed separately.
///
/// @param primitive_desc Output primitive descriptor.
/// @param existing_primitive_desc Primitive descriptor to clone.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
        dnnl_primitive_desc_t *primitive_desc,
        const_dnnl_primitive_desc_t existing_primitive_desc);

/// Returns a constant reference to the attributes of a primitive descriptor.
///
/// @warning
///     It is an error to destroy the resulting @p attr.
///
/// @warning
///     The lifetime of an @p attr is the same as that of a @p
///     primitive_desc, so it is an error to use the @p attr once the @p
///     primitive_desc has been destroyed.
///
/// @param primitive_desc Primitive descriptor.
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
        const_dnnl_primitive_desc_t primitive_desc,
        const_dnnl_primitive_attr_t *attr);

/// Destroys a primitive descriptor.
///
/// @param primitive_desc Primitive descriptor to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
        dnnl_primitive_desc_t primitive_desc);

/// Queries a primitive descriptor for various pieces of information.
///
/// The most common use case is to query a primitive descriptor, created with
/// source, weights, and destination memory descriptors with format tags set
/// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
/// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
/// #dnnl_query_dst_md respectively) so that it is possible to create memory
/// objects and reorder primitives if necessary.
///
/// Another typical use case is to query a primitive descriptor for workspace
/// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
/// query returns #dnnl_not_required status, then workspace memory is not
/// required.
///
/// @note
///     When querying for a memory descriptor for a scratchpad, a workspace,
///     or an optional parameter, the query will return a pointer to a zero
///     memory descriptor if the parameter is not needed.
///
/// A few other use cases:
///  - query a primitive descriptor for the implementation information string
///    (#dnnl_query_impl_info_str)
///  - query a primitive descriptor for the number of inputs and outputs
///    (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
///    respectively)
///
/// @sa dnnl_query_t for more options
///
/// @param primitive_desc Primitive descriptor.
/// @param what Parameter to query.
/// @param index Index of the parameter to query for.
/// @param result Output result. The type depends on the query. For example,
///     it must be a @c dnnl_memory_desc_t* if querying for a memory
///     descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_query(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index, void *result);

/// Queries primitive descriptor for a memory descriptor.
///
/// @note
///     This function is a convenience version of
///     #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of memory descriptor parameter to query for.
/// @param index Index of the parameter to query.
/// @returns A pointer to the requested memory descriptor.
/// @returns A pointer to a zero memory descriptor if the parameter is not
///          needed.
/// @returns NULL in case of any error.
///
const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);

/// Queries primitive descriptor for a signed 32bit int.
///
/// @note
///     This function is a convenience version of
///     #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of the value to query for.
/// @param index Index of the parameter to query.
/// @returns The requested value.
/// @returns 0 in case of any error (in particular if the queried entity is
///     not of type int32_t). Note that 0 may also be the actual returned
///     value.
int DNNL_API dnnl_primitive_desc_query_s32(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);

/// Creates a primitive.
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
        const_dnnl_primitive_desc_t primitive_desc);

/// Creates a primitive from a cache blob.
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @param size Size of the cache blob in bytes.
/// @param cache_blob Cache blob of size @p size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
        dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc,
        size_t size, const uint8_t *cache_blob);

/// Executes a primitive.
///
/// @param primitive Primitive to execute.
/// @param stream Stream to use.
/// @param nargs Number of arguments.
/// @param args Array of arguments. Each argument is an
///     <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
///     values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
///     #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
///     descriptor as that returned by
///     #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.

/// @note If any argument in @p args is padded (padded_dims >
/// dims), the primitive execution will assume properly zero-padded
/// input arguments, and produce zero-padded output arguments.
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
        dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);

/// Retrieves a constant reference to the primitive descriptor of a given
/// primitive.
///
/// @warning
///     It is an error to destroy the returned object. It is owned by the
///     primitive. The @c const qualifier of the returned object prevents
///     such attempts.
///
/// @param primitive Primitive to query for the primitive descriptor.
/// @param primitive_desc Output primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
        const_dnnl_primitive_t primitive,
        const_dnnl_primitive_desc_t *primitive_desc);

/// Retrieves a cache blob associated with the given primitive.
///
/// @param primitive Primitive to query for the cache blob.
/// @param size Size of the cache blob in bytes.
/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is
///     nullptr then the size of the cache blob is returned in @p size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
///
/// @note The cache blob can be empty. It's the user's responsibility to check
///     whether it's empty prior to passing it to
///     #dnnl_primitive_create_from_cache_blob().
dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
        const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob);

/// Destroys a primitive.
///
/// @param primitive The primitive to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);

/// @} dnnl_api_primitives_common

/// @addtogroup dnnl_api_attributes
/// @{

/// Creates an empty (default) primitive attributes with all the parameters
/// set to their default values.
///
/// Empty attributes are implied whenever the respective argument is NULL.
///
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);

/// Clones primitive attributes.
///
/// @param attr Output primitive attributes.
/// @param existing_attr Primitive attributes to clone.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
        dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);

/// Destroys primitive attributes.
///
/// @param attr Primitive attributes to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);

/// Returns probability for output dropout primitive attribute.
///
/// @param attr Primitive attributes.
/// @param dropout_desc Output dropout memory descriptor
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
        const_dnnl_primitive_attr_t attr,
        const_dnnl_memory_desc_t *dropout_desc);

/// Sets probability for output dropout primitive attribute.
///
/// @param attr Primitive attributes.
/// @param dropout_desc Output dropout memory descriptor
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
        dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t dropout_desc);

/// Returns the floating-point math mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Output FP math mode.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
        const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);

/// Sets the floating-point math mode primitive attributes.
///
/// @param attr Primitive attributes.
/// @param mode FP math mode. The possible values are:
///     #dnnl_fpmath_mode_strict (default),
///     #dnnl_fpmath_mode_bf16,
///     #dnnl_fpmath_mode_f16,
///     #dnnl_fpmath_mode_tf32,
///     #dnnl_fpmath_mode_any.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
        dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);

/// Returns the floating-point math mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Output FP math mode.
/// @param apply_to_int Output use floating-point arithmetic for integer primitives.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode_v2(
        const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode,
        int *apply_to_int);

/// Sets the floating-point math mode primitive attributes.
///
/// @param attr Primitive attributes.
/// @param mode FP math mode. The possible values are:
///     #dnnl_fpmath_mode_strict (default),
///     #dnnl_fpmath_mode_bf16,
///     #dnnl_fpmath_mode_f16,
///     #dnnl_fpmath_mode_tf32,
///     #dnnl_fpmath_mode_any.
/// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode_v2(
        dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode, int apply_to_int);

/// Returns the deterministic primitive attribute value.
///
/// @param attr Primitive attributes.
/// @param value Output deterministic attribute value
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_deterministic(
        const_dnnl_primitive_attr_t attr, int *value);

/// Sets the deterministic primitive attribute value.
///
/// @param attr Primitive attributes.
/// @param value Boolean value to set deterministic attribute.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_deterministic(
        dnnl_primitive_attr_t attr, int value);

/// Returns the accumulation mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Output accumulation mode.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_accumulation_mode(
        const_dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t *mode);

/// Sets the accumulation mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Accumulation mode. The possible values are:
///     #dnnl_accumulation_mode_strict (default), which is s32 for quantized primitives, f32/f64 otherwise
///     #dnnl_accumulation_mode_relaxed, which is same as strict but allows intermediate accumulators to be in src/dst datatype
///     #dnnl_accumulation_mode_any, which allows accumulators to be src/dst datatype or any wider type.
///     #dnnl_accumulation_mode_f32,
///     #dnnl_accumulation_mode_s32,
///     #dnnl_accumulation_mode_f16.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_accumulation_mode(
        dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t mode);

/// Returns the primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Output scratchpad mode.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
        const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);

/// Sets primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Scratchpad mode. The possible values are:
///     #dnnl_scratchpad_mode_library (default) and
///     #dnnl_scratchpad_mode_user.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
        dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);

/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument. The scaling factors must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
///
/// @sa dnnl_primitive_attr_set_scales_mask
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
///     dnnl_primitive_execute() call.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the tensor dimensions and the @p scales array.
///     The set i-th bit indicates that a dedicated scaling factor is used for
///     each index along that dimension. Set the mask to 0 to use a common
///     scaling factor for the whole output tensor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
        dnnl_primitive_attr_t attr, int arg, int mask);

/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument. The scaling factors must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
///
/// @sa dnnl_primitive_attr_set_scales
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
///     dnnl_primitive_execute() call.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the tensor dimensions and the @p scales array.
///     The set i-th bit indicates that a dedicated scaling factor is used for
///     each index along that dimension. Set the mask to 0 to use a common
///     scaling factor for the whole output tensor.
/// @param ndims Number of group dimensions.
/// @param group_dims Scaling factors correspondence groups that define the
///     correspondence between the tensor dimensions and the scales array.
///     The group dimensions should only be provided for each logical dimension
///     that has correspondence mask @p mask set.
/// @param data_type Scaling factors data_type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
        dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
        const dnnl_dims_t group_dims, dnnl_data_type_t data_type);

/// Sets primitive attributes zero points for primitive operations for a given
/// memory argument. The zero points must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
///
/// @sa dnnl_primitive_attr_set_zero_points_mask
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
///     dnnl_primitive_execute() call.
/// @param mask Zero point correspondence mask that defines the
///     correspondence between the tensor dimensions and the @p
///     zero_points array. The set i-th bit indicates that a dedicated
///     zero point is used for each index along that dimension. Set the
///     mask to 0 to use a common zero point for the whole output tensor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
        dnnl_primitive_attr_t attr, int arg, int mask);

/// Sets primitive attributes zero points for primitive operations for a given
/// memory argument. The zero points must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
///
/// @sa dnnl_primitive_attr_set_zero_points
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
///     dnnl_primitive_execute() call.
/// @param mask Zero point correspondence mask that defines the
///     correspondence between the tensor dimensions and the @p
///     zero_points array. The set i-th bit indicates that a dedicated
///     zero point is used for each index along that dimension. Set the
///     mask to 0 to use a common zero point for the whole output tensor.
/// @param ndims Number of group dimensions.
/// @param group_dims Zero point factors correspondence groups that define the
///     correspondence between the tensor dimensions and the zero_points array.
///     The group dimensions should be only provided for each logical dimension
///     that has the bit set correspondence mask @p mask set.
/// @param data_type Zero points factors data_type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
        dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
        const dnnl_dims_t group_dims, dnnl_data_type_t data_type);

/// Sets the rounding mode attribute value for a given argument
///
/// @param attr Primitive attributes.
/// @param arg Argument for which rounding mode should be set.
/// @param mode Rounding mode to apply to the argument.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding(
        dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t mode);

/// Returns the rounding mode attribute value for a given argument
///
/// @param attr Primitive attributes.
/// @param arg Argument for which rounding mode query applies.
/// @param mode Output rounding mode.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding(
        dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t *mode);

/// Returns primitive attributes post-ops.
///
/// @warning
///     The output @p post_ops points to the internal @p attr field, so it is
///     an error to modify or destroy them. The lifetime of @p post_ops is
///     the same as that of the @p attr it belongs to, so it is an error to
///     use @p post_ops after @p attr has been destroyed.
///
/// @param attr Primitive attributes.
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
        const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);

/// Sets primitive attributes post-ops.
///
/// @note
///     There is no way to check whether the post-ops would be supported by
///     the target primitive. Any error will be reported by the
///     dnnl_<primitive name>_[propagation kind]_primitive_desc_create() function call.
///
/// @param attr Primitive attributes.
/// @param post_ops Post-ops to set.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
        dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);

/// Creates empty post-ops sequence.
///
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);

/// Clones post-ops primitive attribute.
///
/// @param post_ops Output post-ops primitive attribute.
/// @param existing_post_ops Post-ops primitive attribute to clone.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_clone(
        dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops);

/// Destroys post-ops.
///
/// @param post_ops Post-ops to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);

/// Returns the length of post-ops.
///
/// @param post_ops Post-ops.
/// @returns The number of post-ops entries.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);

/// Returns the kind of a post-op entry.
///
/// @param post_ops Post-ops.
/// @param index Post-op entry index.
/// @returns The kind of the post-op with the specified index.
/// @returns #dnnl_undefined_primitive if there is no post-op at the specified
///     index.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
        const_dnnl_post_ops_t post_ops, int index);

/// Appends an accumulation v3 (sum) to post-ops. Prior to accumulating the
/// result, a zero point is subtracted from the previous value and is
/// multiplied by the scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like dequantize the
/// asymmetrically quantized sum's src1 tensor to f32 domain before performing
/// the sum operation by subtracting the @p zero_point before the scaling.
///
/// In the simplest case where accumulation is the only post-op, the
/// computations will be:
///
///     dst[:] <- scale * (dst[:] - zero_point) + op(...)
///                                             // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, original dst tensor will be reinterpreted
/// as a tensor with provided data type. Since it is reinterpretation,
/// data_type and dst data type should have the same size.
/// As a result, computations will be:
///
///     dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
///                                        // instead of dst[:] <- op(...)
/// @note
///     This post-op executes in-place and does not change the
///     destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param zero_point Single scalar int32_t value of zero point.
/// @param data_type Accumulation data_type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops,
        float scale, int32_t zero_point, dnnl_data_type_t data_type);

/// Returns the parameters of an accumulation (sum) post-op with
/// zero point and data type parameter.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @param zero_point Zero point.
/// @param data_type Data type for accumulation.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        int32_t *zero_point, dnnl_data_type_t *data_type);

/// Appends an elementwise post-op.
///
/// The kind of this post operation is #dnnl_eltwise.
///
/// In the simplest case when the elementwise is the only post operation, the
/// computations would be:
///
///     dst[:] <- eltwise_op (op(...)) // instead of dst[:] <- op(...)
///
/// where eltwise_op is configured with the given parameters.
///
/// @param post_ops Post-ops.
/// @param alg_kind Elementwise algorithm for the post-op.
/// @param alpha Alpha parameter for the elementwise algorithm.
/// @param beta Beta parameter for the elementwise algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
        dnnl_alg_kind_t alg_kind, float alpha, float beta);

/// Returns the parameters of an elementwise post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the elementwise post-op.
/// @param alg_kind Output elementwise algorithm kind.
/// @param alpha Output alpha parameter for the elementwise algorithm.
/// @param beta Output beta parameter for the elementwise algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to an
///     elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
        const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
        float *alpha, float *beta);

/// Appends a depthwise post-op convolution.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimensions equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for primitive with fusion is one. The output spatial
/// size can be derived as below:
///
/// output_height = ceil(output_height_1x1_convolution, stride)
/// output_width = ceil(output_width_1x1_convolution, stride)
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param kernel_size Size of kernel of depthwise post-op
/// @param stride_size Size of stride of depthwise post-op
/// @param padding_l_size Size of left and top paddings of depthwise post-op
/// @returns #dnnl_success on success and a status describing the error
///     otherwise
dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
        dnnl_dim_t stride_size, dnnl_dim_t padding_l_size);

/// Returns the parameters of an depthwise post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the elementwise post-op.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param kernel_size Size of kernel of depthwise post-op
/// @param stride_size Size of stride of depthwise post-op
/// @param padding_l_size Size of left and top paddings of depthwise post-op
/// @returns #dnnl_success on success and a status describing the error
///     otherwise
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size,
        dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size);

/// Appends a binary post-op.
///
/// The kind of this post operation is #dnnl_binary.
///
/// In the simplest case when the binary is the only post operation, the
/// computations would be:
///
///     dst[:] <- binary_op (dst[:], another_input[:])
///
/// where binary_op is configured with the given parameters. binary_op supports
/// broadcast semantics for a second operand.
///
/// @param post_ops Post-ops.
/// @param alg_kind Binary algorithm for the post-op.
/// @param src1_desc Memory descriptor of a second operand.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc);

/// Returns the parameters of a binary post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the binary post-op.
/// @param alg_kind Output binary algorithm kind.
/// @param src1_desc Output memory descriptor of a second operand.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
///     post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
        const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
        const_dnnl_memory_desc_t *src1_desc);

/// Appends a prelu forward post-op.
///
/// The kind of this post-op is #dnnl::primitive::kind::prelu.
///
/// The post-op can be defined as:
///
///      dst[:] <- prelu(dst[:], weights[:])
///      prelu:
///      dst[:] <- dst[:] if dst[:] > 0
///      dst[:] <- dst[:] * weights[:] if dst[:] <= 0
///
///
/// @note
///     The order of dimensions does not depend on how elements are laid
///     out in memory. For example:
///     - for a 2D CNN activations tensor the order is always (n, c)
///     - for a 4D CNN activations tensor the order is always (n, c, h, w)
///     - for a 5D CNN weights tensor the order is always
///        (g, oc, ic, kh, kw)
///
///    Prelu weights tensor is passed in runtime execution phase. Prelu
///    weights tensor data type is implicitly assumed as f32 using plain
///    layout (a, ab, acb, acdb, acdeb)
///
/// @param post_ops Post-ops.
/// @param mask Defines the correspondence between the output tensor
///     dimensions and the prelu weights tensor. The set i-th bit indicates
///     that a dedicated weights value is used for each index along that
///     dimension. Set the mask to 0 to use a common weights value
///     for the whole output tensor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
        dnnl_post_ops_t post_ops, int mask);

/// Returns the parameters of a prelu post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the prelu post-op.
/// @param mask Mask of the prelu post-op.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
        const_dnnl_post_ops_t post_ops, int index, int *mask);

/// @} dnnl_api_attributes

/// @} dnnl_api_primitives

/// @addtogroup dnnl_api_memory
/// @{

/// Destroys a memory descriptor.
///
/// @param memory_desc Memory descriptor to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_destroy(dnnl_memory_desc_t memory_desc);

/// Clones a memory descriptor. The resulting memory descriptor must be
/// destroyed separately.
///
/// @param memory_desc Output memory descriptor.
/// @param existing_memory_desc Memory descriptor to clone.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_clone(dnnl_memory_desc_t *memory_desc,
        const_dnnl_memory_desc_t existing_memory_desc);

/// Retrieves a binary blob associated with the given memory descriptor
///
/// @param blob Output pointer to binary blob.
///     If not nullptr, size bytes of the memory descriptor blob are written.
/// @param size Output pointer to the size of the binary blob in bytes.
///     Size is written if blob is nullptr.
/// @param memory_desc input memory descriptor to serialize
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_get_blob(
        uint8_t *blob, size_t *size, const_dnnl_memory_desc_t memory_desc);

/// Creates a memory descriptor from a memory descriptor binary blob.
///
/// @param memory_desc Output pointer to a newly allocated memory descriptor.
/// @param blob Pointer to a memory descriptor binary blob.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_blob(
        dnnl_memory_desc_t *memory_desc, const uint8_t *blob);

/// Creates a memory descriptor using dimensions and strides.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param strides Strides in each dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_strides(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, const dnnl_dims_t strides);

/// Creates a memory descriptor using dimensions and memory format tag.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
///     allow a primitive to chose the final memory format. In this case the
///     format_kind field of the memory descriptor would be set to
///     #dnnl_format_kind_any.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_tag(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_format_tag_t tag);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Creates a memory descriptor for CSR encoding.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param nnz Number of non-zero entries.
/// @param indices_dt Data type of indices.
/// @param pointers_dt Data type of pointers.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_csr_encoding(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_dim_t nnz, dnnl_data_type_t indices_dt,
        dnnl_data_type_t pointers_dt);

/// Creates a memory descriptor for COO encoding.
///
/// The created memory descriptor will describe a memory object that
/// contains n+1 buffers for an n-dimensional tensor.
/// The buffers have the following meaning and assigned numbers (index):
///  - 0: values
///  - 1: indices for dimension 0
///  - 2: indices for dimension 1 ...
///  - n: indices for dimension n-1
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param nnz Number of non-zero entries.
/// @param indices_dt Data type of indices.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_coo_encoding(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_dim_t nnz,
        dnnl_data_type_t indices_dt);

/// Creates a memory descriptor for packed sparse encoding.
///
/// The created memory descriptor cannot be used to create a memory
/// object. It can only be used to create a primitive descriptor to
/// query the actual memory descriptor (similar to the format tag
/// `any`).
///
/// @warning
///     The meaning and content of the handles of the memory object that
///     is created using the queried memory descriptor are unspecified
///     therefore using the content is an undefined behavior.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param nnz Number of non-zero entries.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_dim_t nnz);
#endif

/// Creates a memory descriptor for a region inside an area
/// described by an existing memory descriptor.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or dims may
///     result in a failure to create a submemory.
//
/// @param memory_desc Output memory descriptor.
/// @param parent_memory_desc An existing memory descriptor.
/// @param dims Sizes of the region.
/// @param offsets Offsets to the region from the encompassing
///     memory object in each dimension
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_submemory(
        dnnl_memory_desc_t *memory_desc,
        const_dnnl_memory_desc_t parent_memory_desc, const dnnl_dims_t dims,
        const dnnl_dims_t offsets);

/// Creates a memory descriptor by reshaping an existing one. The new
/// memory descriptor inherits the data type. This operation is valid only for
/// memory descriptors that have format_kind #dnnl_blocked or
/// #dnnl_format_kind_any.
///
/// The resulting memory descriptor must be destroyed separately.
///
/// The operation ensures the transformation of the physical memory format
/// corresponds to the transformation of the logical dimensions. If such
/// transformation is impossible, the function returns #dnnl_invalid_arguments.
///
/// The reshape operation can be described as a combination of the following
/// basic operations:
/// 1. Add a dimension of size `1`. This is always possible.
/// 2. Remove a dimension of size `1`. This is possible only if the dimension
///    has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
/// 3. Split a dimension into multiple ones. This is possible only if the size
///    of the dimension is exactly equal to the product of the split ones and
///    the dimension does not have padding (i.e.
///    `padded_dims[dim] = dims[dim]`).
/// 4. Joining multiple consecutive dimensions into a single one. As in the
///    cases above, this requires that the dimensions do not have padding and
///    that the memory format is such that in physical memory these dimensions
///    are dense and have the same order as their logical counterparts. This
///    also assumes that these dimensions are not blocked.
///    - Here, dense means:
///      `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
///    - And same order means:
///      `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or
///     dimensions may result in a failure to make a reshape.
///
/// @param out_memory_desc Output memory descriptor.
/// @param in_memory_desc An existing memory descriptor. Must have format_kind
///     set to #dnnl_blocked or #dnnl_format_kind_any.
/// @param ndims Number of dimensions for the output memory descriptor.
/// @param dims Dimensions for the output memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
        dnnl_memory_desc_t *out_memory_desc,
        const_dnnl_memory_desc_t in_memory_desc, int ndims,
        const dnnl_dims_t dims);

/// Creates a memory descriptor by permuting axes in an existing one.
///
/// The physical memory layout representation is adjusted accordingly to
/// maintain the consistency between the logical and physical parts of the
/// memory descriptor.
///
/// The resulting memory descriptor must be destroyed separately.
///
/// The new memory descriptor inherits the data type. This operation is valid
/// only for memory descriptors that have format_kind set to #dnnl_blocked or
/// #dnnl_format_kind_any.
///
/// The logical axes will be permuted in the following manner:
/// ```
/// for (i: 0 .. in_memory_desc->ndims)
///     out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
/// ```
///
/// Example:
/// @code
///     dnnl_memory_desc_t in_md, out_md, expect_out_md;
///
///     const int permutation[] = {1, 0}; // swap the first and the second axes
///
///     dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
///     dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
///
///     dnnl_memory_desc_create_with_tag(
///             &in_md, 2, in_dims, data_type, in_tag);
///     dnnl_memory_desc_create_with_tag(
///             &expect_out_md, 2, out_dims, data_type, out_tag);
///
///     dnnl_memory_desc_permute_axes(&out_md, in_md, permutation);
///     assert(dnnl_memory_desc_equal(out_md, expect_out_md));
///
///     dnnl_memory_desc_destroy(in_md);
///     dnnl_memory_desc_destroy(out_md);
///     dnnl_memory_desc_destroy(expect_out_md);
/// @endcode
///
/// @param out_memory_desc Output memory descriptor.
/// @param in_memory_desc An existing memory descriptor. Must have format_kind
///     set to #dnnl_blocked or #dnnl_format_kind_any.
/// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
        dnnl_memory_desc_t *out_memory_desc,
        const_dnnl_memory_desc_t in_memory_desc, const int *permutation);

/// Queries a memory descriptor for various pieces of information.
///
/// The following information can be queried:
///  - Number of dimensions (#dnnl_query_ndims_s32)
///  - Dimensions (#dnnl_query_dims) in the following order:
///    - CNN data tensors: mini-batch, channel, spatial
///      (<code>{N, C, [[D,] H,] W}</code>)
///    - CNN weight tensors: group (optional), output channel, input channel,
///      spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
///    - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
///      or layers, directions, states, mini-batch, channels
///      (<code>{L, D, S, N, C}</code>)
///    - RNN weight tensor: layers, directions, input channel, gates, output
///      channels (<code>{L, D, I, G, O}</code>)
///  - Data type of the tensor elements (#dnnl_query_data_type)
///  - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
///    padding in each dimension
///  - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
///    the padding to actual data, the top-level tensor with offsets applied
///    must lie within the padding area.
///  - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
///    origin to the current block, non-zero only in a description of a memory
///    sub-block.
///  - Format kind (#dnnl_query_format_kind) - memory format kind
///
/// @note
///    The order of dimensions does not depend on the memory format, so
///    whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
///    the dims for 4D CN data tensor would be <code>{N, C, H, W}</code>.
///
/// The following queries are applicable only to format kind #dnnl_blocked.
///  - Strides (#dnnl_query_strides) between the outermost blocks or in case
///    of plain (non-blocked) formats the strides between dimensions
///  - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g.
///    `{4, 16, 4}` in case of `OIhw_4i16o4i`
///  - Size of the innermost blocks (#dnnl_query_inner_blks), e.g. 3 in case
///    of `OIhw_4i16o4i_`
///  - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
///    in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim
///
/// @param memory_desc Memory descriptor.
/// @param what Parameter to query.
/// @param result Output result. The type depends on the query. For example,
///     it must be a @c dnnl_dims_t** if querying for a strides.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_query(
        const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, void *result);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Queries a memory descriptor for various pieces of information. This version
/// support additional queries #dnnl_query_sparse_encoding, #dnnl_query_nnz_s64
/// #dnnl_query_num_handles_s32 and #dnnl_query_data_type for a particular
/// buffer.
///
/// The following information can be queried:
///  - Number of dimensions (#dnnl_query_ndims_s32)
///  - Dimensions (#dnnl_query_dims) in the following order:
///    - CNN data tensors: mini-batch, channel, spatial
///      (<code>{N, C, [[D,] H,] W}</code>)
///    - CNN weight tensors: group (optional), output channel, input channel,
///      spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
///    - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
///      or layers, directions, states, mini-batch, channels
///      (<code>{L, D, S, N, C}</code>)
///    - RNN weight tensor: layers, directions, input channel, gates, output
///      channels (<code>{L, D, I, G, O}</code>)
///  - Data type of the tensor elements (#dnnl_query_data_type)
///  - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
///    padding in each dimension
///  - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
///    the padding to actual data, the top-level tensor with offsets applied
///    must lie within the padding area.
///  - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
///    origin to the current block, non-zero only in a description of a memory
///    sub-block.
///  - Format kind (#dnnl_query_format_kind) - memory format kind
///
/// @note
///    The order of dimensions does not depend on the memory format, so
///    whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
///    the dims for 4D CN data tensor would be <code>{N, C, H, W}</code>.
///
/// The following queries are applicable only to format kind #dnnl_blocked.
///  - Strides (#dnnl_query_strides) between the outermost blocks or in case
///    of plain (non-blocked) formats the strides between dimensions
///  - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g.
///    `{4, 16, 4}` in case of `OIhw_4i16o4i`
///  - Size of the innermost blocks (#dnnl_query_inner_blks), e.g. 3 in case
///    of `OIhw_4i16o4i_`
///  - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
///    in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim
///
/// @param memory_desc Memory descriptor.
/// @param what Parameter to query.
/// @param index Index of the parameter to query for. It is mostly used with
///     #dnnl_query_data_type to specify which data type is being queried.
///     The main data type (data type of values) has always index 0. For other
///     indices please refer to the API for creating a memory descriptor for
///     sparse encoding.
/// @param result Output result. The type depends on the query. For example,
///     it must be a @c dnnl_dims_t** if querying for a strides.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_query_v2(
        const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, int index,
        void *result);
#endif

/// Compares two memory descriptors.
///
/// Use this function to identify whether a reorder is required between the
/// two memories
///
/// @param lhs Left-hand side of the comparison.
/// @param rhs Right-hand side of the comparison.
/// @returns 1 if the descriptors are the same.
/// @returns 0 if the descriptors are different.
int DNNL_API dnnl_memory_desc_equal(
        const_dnnl_memory_desc_t lhs, const_dnnl_memory_desc_t rhs);

/// Returns the size of a memory descriptor.
///
/// @param memory_desc Memory descriptor.
/// @returns The number of bytes required for memory described by a memory
///     descriptor.
size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Returns the size of the data that corresponds to the given index.
///
/// @param memory_desc Memory descriptor.
/// @param index Index of the buffer.
///
/// @returns The number of bytes required for the requested data.
size_t DNNL_API dnnl_memory_desc_get_size_v2(
        const_dnnl_memory_desc_t memory_desc, int index);
#endif

/// Returns the size of data type.
///
/// @param data_type Data type.
/// @returns The number of bytes occupied by data type.
size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);

/// Creates a memory object.
///
/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
/// object will have the underlying buffer set. In this case, the buffer will
/// be initialized as if dnnl_memory_set_data_handle() had been called.
///
/// @sa dnnl_memory_set_data_handle()
///
/// @param memory Output memory object.
/// @param memory_desc Memory descriptor.
/// @param engine Engine to use.
/// @param handle Handle of the memory buffer to use as an underlying storage.
///     - A pointer to the user-allocated buffer. In this case the library
///       doesn't own the buffer.
///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
///       allocate the buffer for the memory object. In this case the library
///       owns the buffer.
///     - DNNL_MEMORY_NONE to create dnnl_memory without an underlying buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
        const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
        void *handle);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Creates a memory object with multiple handles.
///
/// @param memory Output memory object.
/// @param memory_desc Memory descriptor.
/// @param engine Engine to use.
/// @param nhandles Number of handles.
/// @param handles Handles of the memory buffers to use as underlying storages.
///     For each element of the @p handles array the following applies:
///     - A pointer to the user-allocated buffer. In this case the library
///       doesn't own the buffer.
///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
///       allocate the buffer for the memory object. In this case the library
///       owns the buffer.
///     - DNNL_MEMORY_NONE Instructs the library to skip allocation of the
///       memory buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_create_v2(dnnl_memory_t *memory,
        const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
        int nhandles, void **handles);
#endif

/// Returns the memory descriptor for a memory object.
///
/// @param memory Memory object.
/// @param memory_desc Output memory descriptor (a copy).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
        const_dnnl_memory_t memory, const_dnnl_memory_desc_t *memory_desc);

/// Returns the engine of a memory object.
///
/// @param memory Memory object.
/// @param engine Output engine on which the memory is located.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_engine(
        const_dnnl_memory_t memory, dnnl_engine_t *engine);

/// Maps a memory object and returns a host-side pointer to a memory buffer
/// with a copy of its contents.
///
/// Mapping enables explicit direct access to memory contents for the engines
/// that do not support it implicitly.
///
/// Mapping is an exclusive operation - a memory object cannot be used in
/// other operations until this memory object is unmapped.
///
/// @note
///     Any primitives working with @p memory should be completed before
///     the memory is mapped. Use dnnl_stream_wait to synchronize the
///     corresponding execution stream.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Output pointer to the mapped buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_map_data(
        const_dnnl_memory_t memory, void **mapped_ptr);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Maps a memory object and returns a host-side pointer to a memory buffer
/// with a copy of its contents. The memory buffer corresponds to the given
/// index.
///
/// Mapping enables explicit direct access to memory contents for the engines
/// that do not support it implicitly.
///
/// Mapping is an exclusive operation - a memory object cannot be used in
/// other operations until this memory object is unmapped.
///
/// @note
///     Any primitives working with @p memory should be completed before
///     the memory is mapped. Use dnnl_stream_wait to synchronize the
///     corresponding execution stream.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Output pointer to the mapped buffer.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_map_data_v2(
        const_dnnl_memory_t memory, void **mapped_ptr, int index);
#endif

/// Unmaps a memory object and writes back any changes made to the previously
/// mapped memory buffer. The pointer to the mapped buffer must be obtained
/// via the dnnl_memory_map_data() call.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Pointer to the mapped buffer that must have been
///     obtained using the dnnl_memory_map_data() function.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(
        const_dnnl_memory_t memory, void *mapped_ptr);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Unmaps a memory object and writes back any changes made to the previously
/// mapped memory buffer. The pointer to the mapped buffer must be obtained
/// via the dnnl_memory_map_data() call. The buffer corresponds to the given
/// index.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Pointer to the mapped buffer that must have been
///     obtained using the dnnl_memory_map_data() function.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2(
        const_dnnl_memory_t memory, void *mapped_ptr, int index);
#endif

/// Returns memory object's data handle.
///
/// @param memory Memory object.
/// @param handle Output data handle. For the CPU engine, the data handle is a
///     pointer to the actual data. For OpenCL it is a cl_mem.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
        const_dnnl_memory_t memory, void **handle);

/// Sets the underlying memory buffer.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine or when USM is used, the
///     memory buffer is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
        dnnl_memory_t memory, void *handle);

#ifdef DNNL_EXPERIMENTAL_SPARSE
/// Returns an underlying memory buffer that corresponds to the given index.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine or when USM is used, the
///     memory buffer is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle_v2(
        const_dnnl_memory_t memory, void **handle, int index);

/// Sets an underlying memory buffer that corresponds to the given index.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine or when USM is used, the
///     memory buffer is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
        dnnl_memory_t memory, void *handle, int index);
#endif

/// Destroys a memory object.
///
/// @param memory Memory object to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);

/// @} dnnl_api_memory

/// @addtogroup dnnl_api_primitives
/// @{

/// @addtogroup dnnl_api_reorder
/// @{

/// Creates a primitive descriptor for a reorder primitive.
///
/// @param reorder_primitive_desc Output primitive descriptor.
/// @param src_desc Source memory descriptor.
/// @param src_engine Engine on which the source memory object will be
///     located.
/// @param dst_desc Destination memory descriptor.
/// @param dst_engine Engine on which the destination memory object
///     will be located.
/// @param attr Primitive attributes to use (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
        dnnl_primitive_desc_t *reorder_primitive_desc,
        const_dnnl_memory_desc_t src_desc, dnnl_engine_t src_engine,
        const_dnnl_memory_desc_t dst_desc, dnnl_engine_t dst_engine,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_reorder

/// @addtogroup dnnl_api_concat
/// @{

/// Creates a primitive descriptor for an out-of-place concatenation
/// primitive.
///
/// @param concat_primitive_desc Output primitive descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source parameters.
/// @param concat_dimension Source tensors will be concatenated over
///     dimension with this index. Note that order of dimensions does
///     not depend on memory format.
/// @param src_descs Array of source memory descriptors with @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @param engine Engine to use.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
        dnnl_primitive_desc_t *concat_primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t dst_desc, int n, int concat_dimension,
        const_dnnl_memory_desc_t const *src_descs,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_concat

/// @addtogroup dnnl_api_sum
/// @{

/// Creates a primitive descriptor for an (out-of-place) sum primitive.
///
/// @param sum_primitive_desc Output primitive descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source parameters.
/// @param scales Vector of scales to multiply data in each source
///     memory by.
/// @param src_descs Array of source memory descriptors having @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @param engine Engine to use.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
        dnnl_primitive_desc_t *sum_primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t dst_desc, int n, const float *scales,
        const_dnnl_memory_desc_t const *src_descs,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_sum

/// @addtogroup dnnl_api_binary
/// @{

/// Creates a primitive descriptor for a binary primitive.
///
/// @note
///     Memory descriptors @p src1_desc and @p dst_desc are allowed to be
///     initialized with #dnnl_format_tag_any or with format_kind set to
///     #dnnl_format_kind_any.
///
/// @note
///     Both memory descriptors must have the same number of dimensions.
///     Element broadcasting is supported for memory descriptor @p src1_desc
///     and are applied to @p src1_desc dimensions that have size equal to 1.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
///     #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
///     #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
///     #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
/// @param src0_desc Source 0 memory descriptor.
/// @param src1_desc Source 1 memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
        const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a binary primitive with support of
/// ternary operators.
///
/// @note
///     Memory descriptors @p src1_desc, @p src2_desc and @p dst_desc are
///     allowed to be initialized with #dnnl_format_tag_any or with format_kind
///     set to #dnnl_format_kind_any.
///
/// @note
///     All memory descriptors must have the same number of dimensions.
///     Element broadcasting is supported for memory descriptor @p src1_desc
///     and is applied to @p src1_desc dimensions that have a size equal to 1.
///     There is no broadcasting support for @p src2_desc.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Algorithm kind.
/// @param src0_desc Source 0 memory descriptor.
/// @param src1_desc Source 1 memory descriptor.
/// @param src2_desc Source memory descriptor for ternary operations. Might
///     be empty.
/// @param dst_desc Destination memory descriptor.
/// @param attr Primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create_v2(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
        const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t src2_desc,
        const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_binary

/// @addtogroup dnnl_api_convolution
/// @{

/// Creates a primitive descriptor for a convolution forward propagation
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a convolution backward propagation
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a convolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t diff_weights_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_convolution

/// @addtogroup dnnl_api_deconvolution
/// @{

/// Creates a primitive descriptor for a deconvolution forward propagation
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a deconvolution backward propagation
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a deconvolution weights gradient
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param dilates Array of dilations for spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API
dnnl_deconvolution_backward_weights_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t diff_weights_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_deconvolution

/// @addtogroup dnnl_api_shuffle
/// @{

/// Creates a primitive descriptor for a shuffle forward propagation primitive
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, int axis, dnnl_dim_t group_size,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a shuffle backward propagation primitive
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc, int axis, dnnl_dim_t group_size,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_shuffle

/// @addtogroup dnnl_api_eltwise
/// @{

/// Creates a primitive descriptor for an eltwise forward propagation primitive.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Elementwise algorithm kind.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
        float alpha, float beta, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for an eltwise backward propagation
///     primitive.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Elementwise algorithm kind.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param data_desc Destination memory descriptor if one of the
///     "use_dst_for_bwd" algorithms are used (such as
///     #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor otherwise.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t data_desc, float alpha, float beta,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_eltwise

/// @addtogroup dnnl_api_softmax
/// @{

/// Creates a primitive descriptor for a softmax forward propagation primitive.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
///     #dnnl_softmax_log.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
        int softmax_axis, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a softmax backward propagation primitive.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
///     #dnnl_softmax_log.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t dst_desc, int softmax_axis,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_softmax

/// @addtogroup dnnl_api_pooling
/// @{

/// Creates a primitive descriptor for a pooling forward propagation
///     primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t kernel,
        const dnnl_dims_t dilation, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a pooling backward propagation
///     primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_pooling

/// @addtogroup dnnl_api_prelu
/// @{

/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
///     alpha parameter) forward propagation primitive.
///
/// @note
///     weights descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
///     alpha parameter) backward propagation primitive.
///
/// @note
///     weights descriptor and diff_weights descriptor are allowed
///     to be initialized with #dnnl_format_tag_any or with format_kind
///     set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_weights_desc Diff alpha parameters memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_weights_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_prelu

/// @addtogroup dnnl_api_lrn
/// @{

/// Creates a primitive descriptor for an LRN forward propagation primitive.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
        dnnl_dim_t local_size, float alpha, float beta, float k,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for an LRN backward propagation primitive.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param src_desc Source memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_lrn

/// @addtogroup dnnl_api_batch_normalization
/// @{

/// Creates a primitive descriptor for a batch normalization forward propagation
///     primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, float epsilon, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a batch normalization backward
///     propagation primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param src_desc Source memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, float epsilon, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_batch_normalization

/// @addtogroup dnnl_api_group_normalization
/// @{

/// Creates a primitive descriptor for a group normalization forward propagation
///     primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param groups Group normalization groups parameter.
/// @param epsilon Group normalization epsilon parameter.
/// @param flags Group normalization flags (@ref dnnl_normalization_flags_t).
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_group_normalization_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, dnnl_dim_t groups, float epsilon,
        unsigned flags, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a group normalization backward
///     propagation primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param src_desc Source memory descriptor.
/// @param groups Group normalization groups parameter.
/// @param epsilon Group normalization epsilon parameter.
/// @param flags Group normalization flags (@ref dnnl_normalization_flags_t).
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_group_normalization_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, dnnl_dim_t groups, float epsilon,
        unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_group_normalization

/// @addtogroup dnnl_api_layer_normalization
/// @{

/// Creates a primitive descriptor for a layer normalization forward propagation
///     primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
        float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a layer normalization backward
///     propagation primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param src_desc Source memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
        float epsilon, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a layer normalization forward propagation
///     primitive with a user-provided data type for the scale and shift
///     memory objects.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
/// @param scale_shift_data_type Data type of scale and shift memory. If neither scale
///     nor shift flag are specified the parameter is ignored.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API
dnnl_layer_normalization_forward_primitive_desc_create_v2(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
        dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a layer normalization backward
///     propagation primitive with a user-provided data type for the
///     scale and shift memory objects.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param src_desc Source memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p src_desc by removing the last
///     dimension.
/// @param diff_scale_shift_data_type Data type of diff scale and shift memory. If neither scale
///     nor shift flag are specified the parameter is ignored.
/// @param scale_shift_data_type Data type of scale and shift memory. If neither scale
///     nor shift flag are specified the parameter is ignored.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API
dnnl_layer_normalization_backward_primitive_desc_create_v2(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
        dnnl_data_type_t diff_scale_shift_data_type,
        dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_layer_normalization

/// @addtogroup dnnl_api_inner_product
/// @{

/// Creates a primitive descriptor for an inner product forward propagation
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for an inner product backward propagation
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for an inner product  weights gradient
///     primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive_descriptor.
/// @param engine Engine to use.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API
dnnl_inner_product_backward_weights_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t diff_weights_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_inner_product

/// @addtogroup dnnl_api_attributes
/// @{

/// Set quantization scale and shift parameters for RNN data tensors.
///
/// For performance reasons, the low-precision configuration of the RNN
/// primitives expects input activations to have the unsigned 8-bit integer
/// data type. The scale and shift parameters are used to quantize
/// floating-point data to unsigned integer and must be passed to the RNN
/// primitive using attributes.
///
/// The quantization formula is `scale * data + shift`.
///
/// @note
///     Quantization scale and shift are common for src_layer, src_iter,
///     dst_iter, and dst_layer.
///
/// Example usage:
/// @code
///     // RNN parameters
///     int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
///     // Activations quantization parameters
///     float scale = 63.f, shift = 64.f;
///
///     dnnl_primitive_attr_t rnn_attr;
///     // Create default attributes
///     dnnl_primitive_attr_create(&rnn_attr);
///
///     // Set scale and shift for int8 quantization of activation
///     dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
///
///     // Create an RNN primitive descriptor.
///     dnnl_primitive_desc_t rnn_pd;
///     dnnl_vanilla_rnn_forward_primitive_desc_create(&rnn_pd,
///             engine, /* arguments */, attr);
/// @endcode
///
/// @param attr Primitive attributes.
/// @param scale The value to scale the data by.
/// @param shift The value to shift the data by.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
        dnnl_primitive_attr_t attr, const float scale, const float shift);

/// Returns the quantization scale and shift parameters for RNN data tensors.
///
/// @note
///     Quantization scale and shift are common for src_layer, src_iter,
///     dst_iter, and dst_layer.
///
/// @param attr Primitive attributes.
/// @param scale The value to scale the data by.
/// @param shift The value to shift the data by.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
        const_dnnl_primitive_attr_t attr, float *scale, float *shift);

/// Sets quantization scaling factors for RNN weights tensors. The
/// low-precision configuration of the RNN primitives expects input weights to
/// use the signed 8-bit integer data type. The scaling factors are used to
/// quantize floating-point data to signed integer and must be passed to RNN
/// primitives using attributes.
///
/// @note
///     The dimension order is always native and does not depend on the actual
///     layout used. For example, five-dimensional weights always have (l, d,
///     i, g, o) logical dimension ordering.
///
/// @note
///     Quantization scales are common for weights_layer and weights_iteration
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. Set the
///     mask to 0 to use a common scaling factor for the whole output
///     tensor.
/// @param scales Array of output scaling factors that must contain @p count
///     values and the following equality must hold:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
///     Violations can only be detected when the attributes are used to create
///     a primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);

/// Returns the quantization scaling factors for RNN weights tensors.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. Set the
///     mask to 0 to use a common scaling factor for the whole output
///     tensor.
/// @param scales Array of output scaling factors that contain @p count
///     values and the following equality must hold:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);

/// Sets quantization scaling factors for RNN projection weights tensors. The
/// low-precision configuration of the RNN primitives expects input weights to
/// use the signed 8-bit integer data type. The scaling factors are used to
/// quantize floating-point data to signed integer and must be passed to RNN
/// primitives using attributes.
///
/// @note
///     The dimension order is always native and does not depend on the actual
///     layout used. For example, five-dimensional weights always have (l, d,
///     i, g, o) logical dimension ordering.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. Set the
///     mask to 0 to use a common scaling factor for the whole output
///     tensor.
/// @param scales Array of output scaling factors that must contain @p count
///     values and the following equality must hold:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
///     Violations can only be detected when the attributes are used to create
///     a primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);

/// Returns the quantization scaling factors for RNN projection weights tensors.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. Set the
///     mask to 0 to use a common scaling factor for the whole output
///     tensor.
/// @param scales Array of output scaling factors that contain @p count
///     values and the following equality must hold:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);

/// @} dnnl_api_attributes

/// @addtogroup dnnl_api_rnn
/// @{

/// Creates a primitive descriptor for vanilla RNN forward propagation
///     primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the RNN forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
///     #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
/// @param beta Unused.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
        const dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha,
        float beta, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for vanilla RNN backward propagation
///     primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the RNN backward propagation primitive should
/// not use the respective data and should use zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
///     #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
/// @param beta Unused.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
        const dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t diff_src_layer_desc,
        const_dnnl_memory_desc_t diff_src_iter_desc,
        const_dnnl_memory_desc_t diff_weights_layer_desc,
        const_dnnl_memory_desc_t diff_weights_iter_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_layer_desc,
        const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
        float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for an LSTM forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p weights_peephole_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// The @p weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM doesn't have
/// recurrent projection layer.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
///     the hidden states to get the recurrent projection (according to the
///     Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param flags Unused.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t src_iter_c_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t weights_peephole_desc,
        const_dnnl_memory_desc_t weights_projection_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for an LSTM backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// The @p weights_projection_desc together with @p
/// diff_weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM doesn't have
/// recurrent projection layer.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
///     the hidden states to get the recurrent projection (according to the
///     Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
/// recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
///     applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_weights_projection_desc Memory descriptor for the diff of
///     weights applied to the hidden states to get the recurrent projection
///     (according to the Projection LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t src_iter_c_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t weights_peephole_desc,
        const_dnnl_memory_desc_t weights_projection_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t dst_iter_c_desc,
        const_dnnl_memory_desc_t diff_src_layer_desc,
        const_dnnl_memory_desc_t diff_src_iter_desc,
        const_dnnl_memory_desc_t diff_src_iter_c_desc,
        const_dnnl_memory_desc_t diff_weights_layer_desc,
        const_dnnl_memory_desc_t diff_weights_iter_desc,
        const_dnnl_memory_desc_t diff_weights_peephole_desc,
        const_dnnl_memory_desc_t diff_weights_projection_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_layer_desc,
        const_dnnl_memory_desc_t diff_dst_iter_desc,
        const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the GRU forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t diff_src_layer_desc,
        const_dnnl_memory_desc_t diff_src_iter_desc,
        const_dnnl_memory_desc_t diff_weights_layer_desc,
        const_dnnl_memory_desc_t diff_weights_iter_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_layer_desc,
        const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a descriptor for LBR GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the LBR GRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for LBR GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the LBR GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t diff_src_layer_desc,
        const_dnnl_memory_desc_t diff_src_iter_desc,
        const_dnnl_memory_desc_t diff_weights_layer_desc,
        const_dnnl_memory_desc_t diff_weights_iter_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_layer_desc,
        const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for AUGRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the AUGRU forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t attention_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for AUGRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the AUGRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_attention_desc Memory descriptor for the diff of attention vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t attention_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t diff_src_layer_desc,
        const_dnnl_memory_desc_t diff_src_iter_desc,
        const_dnnl_memory_desc_t diff_attention_desc,
        const_dnnl_memory_desc_t diff_weights_layer_desc,
        const_dnnl_memory_desc_t diff_weights_iter_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_layer_desc,
        const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for LBR AUGRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the LBR AUGRU forward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t attention_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
        const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for LBR AUGRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the LBR AUGRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param attention_desc Memory descriptor for the attention vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_attention_desc Memory descriptor for the diff of attention vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const_dnnl_memory_desc_t src_layer_desc,
        const_dnnl_memory_desc_t src_iter_desc,
        const_dnnl_memory_desc_t attention_desc,
        const_dnnl_memory_desc_t weights_layer_desc,
        const_dnnl_memory_desc_t weights_iter_desc,
        const_dnnl_memory_desc_t bias_desc,
        const_dnnl_memory_desc_t dst_layer_desc,
        const_dnnl_memory_desc_t dst_iter_desc,
        const_dnnl_memory_desc_t diff_src_layer_desc,
        const_dnnl_memory_desc_t diff_src_iter_desc,
        const_dnnl_memory_desc_t diff_attention_desc,
        const_dnnl_memory_desc_t diff_weights_layer_desc,
        const_dnnl_memory_desc_t diff_weights_iter_desc,
        const_dnnl_memory_desc_t diff_bias_desc,
        const_dnnl_memory_desc_t diff_dst_layer_desc,
        const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_rnn

/// @addtogroup dnnl_api_matmul
/// @{

/// Creates a primitive descriptor for a matrix multiplication primitive.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param src_desc Source memory descriptor (matrix A)
/// @param weights_desc Weights memory descriptor (matrix B)
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor (matrix C).
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t weights_desc,
        const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_matmul

/// @addtogroup dnnl_api_resampling Resampling
/// @{

/// Creates a primitive descriptor for a resampling forward propagation
///     primitive.
///
/// @note
///     Destination memory descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind resampling algorithm kind: either #dnnl_resampling_nearest,
///     or #dnnl_resampling_linear.
/// @param factors Array of scaling factors for spatial dimension.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const float *factors, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);

/// Creates a primitive descriptor for a resampling backward propagation
///     primitive.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind resamplinging algorithm kind: either
///     #dnnl_resampling_nearest, or #dnnl_resampling_linear.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param factors Array of scaling factors for spatial dimension.
/// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
///     primitive.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
///
dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const float *factors,
        const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_resampling

/// @addtogroup dnnl_api_reduction Reduction
/// @{

/// Creates a primitive descriptor for a reduction primitive.
///
/// @note
///     Destination memory descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param alg_kind reduction algorithm kind. Possible values:
///     #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
///     #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
///     #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
///     #dnnl_reduction_norm_lp_power_p_sum.
/// @param p Algorithm specific parameter.
/// @param eps Algorithm specific parameter.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, float p, float eps,
        const_dnnl_primitive_attr_t attr);

/// @} dnnl_api_reduction

/// @} dnnl_api_primitives

/// @addtogroup dnnl_api_primitive_cache
/// @{

/// Returns the number of primitives that can be held in the primitive cache
/// at the same time.
///
/// @param capacity Primitive cache capacity to query. Concurrently
/// accessing @p capacity is safe.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
///     success.
dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);

/// Sets a number of primitives that can be held in the primitive cache
/// at a time.
///
/// @param capacity Primitive cache capacity to set. If a new @p capacity is
/// less than a number of primitives that the primitive cache already has
/// then the excess entries will be evicted. Setting the @p capacity to 0
/// clears the primitive cache and disables it. Concurrently modifying
/// @p capacity is safe.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
///     success.
dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);

/// @} dnnl_api_primitive_cache

/// @addtogroup dnnl_api_service
/// @{

/// Configures dumping of JIT-generated code.
///
/// @note
///     This setting overrides the DNNL_JIT_DUMP environment variable.
///
/// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p flag value is invalid, and #dnnl_success/#dnnl::status::success on
///     success.
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);

/// Sets library profiling flags. The flags define which profilers are
/// supported.
///
/// @note
///     This setting overrides DNNL_JIT_PROFILE environment variable.
///
/// @sa @ref dev_guide_profilers
///
/// @param flags Profiling flags that can contain the following bits:
///     - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Profiler
///         (on by default)
///     - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific
///         jit-pid.dump output (off by default). The location of the output
///         is controlled via JITDUMPDIR environment variable or via
///         dnnl_set_jit_profiling_jitdumpdir() function.
///     - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific
///         perf-pid.map output (off by default). The output is always placed
///         into /tmp.
///
///     Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely.
///
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p flags value is invalid, and #dnnl_success/#dnnl::status::success on
///     success.
dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);

/// Sets JIT dump output path. Only applicable to Linux and is only
/// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set.
///
/// After the first JIT kernel is generated, the jitdump output will be placed
/// into temporary directory created using the mkdtemp template
/// 'dir/.debug/jit/dnnl.XXXXXX'.
///
/// @sa @ref dev_guide_profilers
///
/// @note
///     This setting overrides JITDUMPDIR environment variable.  If
///     JITDUMPDIR is not set, and this function is never called, the path
///     defaults to HOME. Passing NULL reverts the value to default.
///
/// @note
///     The directory is accessed only when the first JIT kernel is being
///     created. JIT profiling will be disabled in case of any errors
///     accessing or creating this directory.
///
/// @param dir JIT dump output path.
/// @returns #dnnl_success/#dnnl::status::success if the
///     output directory was set correctly and an error status otherwise.
/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows.
dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir);

/// Sets the maximal ISA the library can dispatch to on the CPU. See
/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by
/// the C and C++ API functions respectively.
///
/// This function has effect only once, and returns an error on subsequent
/// calls. It should also be invoked before any other oneDNN API call, otherwise
/// it may return an error.
///
/// This function overrides the DNNL_MAX_CPU_ISA environment variable. The
/// environment variable can be set to the desired maximal ISA name in upper
/// case and with dnnl_cpu_isa prefix removed. For example:
/// `DNNL_MAX_CPU_ISA=AVX2`.
///
/// @note
///     The ISAs are only partially ordered:
///         - SSE41 < AVX < AVX2 < AVX2_VNNI < AVX2_VNNI_2,
///         - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16
///           < AVX10_1_512 < AVX10_1_512_AMX < AVX10_1_512_AMX_FP16,
///         - AVX2_VNNI < AVX10_1_512.
///     Aliases:
///         - AVX512_CORE_FP16 = AVX10_1_512
///         - AVX512_CORE_AMX = AVX10_1_512_AMX
///         - AVX512_CORE_AMX_FP16 = AVX10_1_512_AMX_FP16
///
/// @sa @ref dev_guide_cpu_dispatcher_control for more details
///
/// @param isa Maximal ISA the library should dispatch to. Pass
///     #dnnl_cpu_isa_default/#dnnl::cpu_isa::isa_default to remove ISA restrictions
///     (except for ISAs with initial support in the library).
/// @returns #dnnl_success/#dnnl::status::success on success and a
///     #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa
///     parameter is invalid or the ISA cannot be changed at this time.
/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
///     was disabled at build time (see @ref dev_guide_build_options for more
///     details).
dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);

/// Gets the maximal ISA the library can dispatch to on the CPU. See
/// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by
/// the C and C++ API functions respectively.
///
/// @sa @ref dev_guide_cpu_dispatcher_control for more details
///
/// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may
///     dispatch to.
dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);

/// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and
/// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++
/// API functions respectively.
///
/// This function has effect only once, and returns an error on subsequent
/// calls. It should also be invoked before any other oneDNN API call, otherwise
/// it may return an error.
///
/// This function overrides the DNNL_CPU_ISA_HINTS environment variable.
/// @sa @ref dev_guide_cpu_isa_hints for more details
///
/// @param isa_hints CPU ISA hints to be passed over to the implementation.
///     Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use
///     default features i.e. no hints.
/// @returns #dnnl_success/#dnnl::status::success on success and a
///     #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot
///     be specified at the current time.
/// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
///     was disabled at build time (see @ref dev_guide_build_options for more
///     details).
dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);

/// Gets the ISA specific hints that library can follow. See
/// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values
///  returned by the C and C++ API functions respectively.
///
/// @sa @ref dev_guide_cpu_isa_hints for more details
///
/// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the
/// library can follow.
dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);

/// @} dnnl_api_service

#ifdef DNNL_EXPERIMENTAL_PROFILING

/// @addtogroup dnnl_api_profiling Profiling
/// @{

/// Resets a profiler's state.
///
/// @param stream Stream associated with the profiler.
///
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_reset_profiling(dnnl_stream_t stream);

/// Queries profiling data. The profiling data accumulates for each primitive
/// execution. The @p num_entries will be equal to the number of executions
/// since the last `dnnl_reset_profiling` call. In order to query the
/// @p num_entries the @p data parameter should be NULL. When @p data is NULL
/// then the @p data_kind parameter is ignored.
///
/// The profiling data can be reset by calling #dnnl_reset_profiling.
///
/// @note
///     It is required to wait for all submitted primitives to complete
///     using #dnnl_stream_wait prior to querying profiling data.
///
/// @param stream Stream that was used for executing a primitive that
/// is being profiled.
/// @param data_kind Profiling data kind to query.
/// @param num_entries Number of profiling data entries.
/// @param data Profiling data.
///
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_query_profiling_data(dnnl_stream_t stream,
        dnnl_profiling_data_kind_t data_kind, int *num_entries, uint64_t *data);

/// @} dnnl_api_profiling
#endif

/// @addtogroup dnnl_api_blas
/// @{

/// Performs single-precision matrix-matrix multiply.
///
/// The operation is defined as:
///
/// `C := alpha * op( A ) * op( B ) + beta * C`
///
/// where
///  - `op( X ) = X` or `op( X ) = X**T`,
///  - `alpha` and `beta` are scalars, and
///  - `A`, `B`, and `C` are matrices:
///     - `op( A )` is an `MxK` matrix,
///     - `op( B )` is an `KxN` matrix,
///     - `C` is an `MxN` matrix.
///
/// The matrices are assumed to be stored in row-major order (the elements in
/// each of the matrix rows are contiguous in memory).
///
/// @note
///     This API does not support XERBLA. Instead, unlike the standard BLAS
///     functions, this one returns a dnnl_status_t value to allow error
///     handling.
///
/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
///     transposed, and 'T' or 't' means that A is transposed.
/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
///     transposed, and 'T' or 't' means that B is transposed.
/// @param M The M dimension.
/// @param N The N dimension.
/// @param K The K dimension.
/// @param alpha The alpha parameter that is used to scale the product of
///     matrices A and B.
/// @param A A pointer to the A matrix data.
/// @param lda The leading dimension for the matrix A.
/// @param B A pointer to the B matrix data.
/// @param ldb The leading dimension for the matrix B.
/// @param beta The beta parameter that is used to scale the matrix C.
/// @param C A pointer to the C matrix data.
/// @param ldc The leading dimension for the matrix C.
/// @returns #dnnl_success/#dnnl::status::success on success and a status
///     describing the error otherwise.
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
        const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc);

/// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
/// signed matrix B, and 32-bit signed resulting matrix C.
///
/// The operation is defined as:
///
/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
///
/// where
///  - `op( X ) = X` or `op( X ) = X**T`,
///  - `alpha` and `beta` are scalars, and
///  - `A`, `B`, and `C` are matrices:
///     - `op( A )` is an `MxK` matrix,
///     - `op( B )` is an `KxN` matrix,
///     - `C` is an `MxN` matrix.
///  - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
///  - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
///  - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
///    - if `offsetc = F`: the `len` must be at least `1`,
///    - if `offsetc = C`: the `len` must be at least `max(1, m)`,
///    - if `offsetc = R`: the `len` must be at least `max(1, n)`,
///
/// The matrices are assumed to be stored in row-major order (the elements in
/// each of the matrix rows are contiguous in memory).
///
/// @note
///     This API does not support XERBLA. Instead, unlike the standard BLAS
///     functions, this one returns a dnnl_status_t value to allow error
///     handling.
///
/// @warning
///     On some architectures saturation may happen during intermediate
///     computations, which would lead to unexpected results. For more
///     details, refer to @ref dev_guide_int8_computations.
///
/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
///     transposed, and 'T' or 't' means that A is transposed.
/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
///     transposed, and 'T' or 't' means that B is transposed.
/// @param offsetc Flag specifying how offsets should be applied to matrix C:
///     - 'F' means that the same offset will be applied to each element of
///         the matrix C,
///     - 'C' means that individual offset will be applied to each element
///         within each column,
///     - 'R' means that individual offset will be applied to each element
///         within each row.
/// @param M The M dimension.
/// @param N The N dimension.
/// @param K The K dimension.
/// @param alpha The alpha parameter that is used to scale the product of
///     matrices A and B.
/// @param A A pointer to the A matrix data.
/// @param lda The leading dimension for the matrix A.
/// @param ao The offset value for the matrix A.
/// @param B A pointer to the B matrix data.
/// @param ldb The leading dimension for the matrix B.
/// @param bo The offset value for the matrix B.
/// @param beta The beta parameter that is used to scale the matrix C.
/// @param C A pointer to the C matrix data.
/// @param ldc The leading dimension for the matrix C.
/// @param co An array of offset values for the matrix C. The number of
///     elements in the array depends on the value of @p offsetc.
/// @returns #dnnl_success/#dnnl::status::success on success and a status
///     describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
        dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
        dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);

/// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
/// signed matrix B, and 32-bit signed resulting matrix C.
///
/// The operation is defined as:
///
/// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
///
/// where
///  - `op( X ) = X` or `op( X ) = X**T`,
///  - `alpha` and `beta` are scalars, and
///  - `A`, `B`, and `C` are matrices:
///     - `op( A )` is an `MxK` matrix,
///     - `op( B )` is an `KxN` matrix,
///     - `C` is an `MxN` matrix.
///  - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
///  - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
///  - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
///    - if `offsetc = F`: the `len` must be at least `1`,
///    - if `offsetc = C`: the `len` must be at least `max(1, m)`,
///    - if `offsetc = R`: the `len` must be at least `max(1, n)`,
///
/// The matrices are assumed to be stored in row-major order (the elements in
/// each of the matrix rows are contiguous in memory).
///
/// @note
///     This API does not support XERBLA. Instead, unlike the standard BLAS
///     functions, this one returns a dnnl_status_t value to allow error
///     handling.
///
/// @warning
///     On some architectures saturation may happen during intermediate
///     computations, which would lead to unexpected results. For more
///     details, refer to @ref dev_guide_int8_computations.
///
/// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
///     transposed, and 'T' or 't' means that A is transposed.
/// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
///     transposed, and 'T' or 't' means that B is transposed.
/// @param offsetc Flag specifying how offsets should be applied to matrix C:
///     - 'F' means that the same offset will be applied to each element of
///         the matrix C,
///     - 'C' means that individual offset will be applied to each element
///         within each column,
///     - 'R' means that individual offset will be applied to each element
///         within each row.
/// @param M The M dimension.
/// @param N The N dimension.
/// @param K The K dimension.
/// @param alpha The alpha parameter that is used to scale the product of
///     matrices A and B.
/// @param A A pointer to the A matrix data.
/// @param lda The leading dimension for the matrix A.
/// @param ao The offset value for the matrix A.
/// @param B A pointer to the B matrix data.
/// @param ldb The leading dimension for the matrix B.
/// @param bo The offset value for the matrix B.
/// @param beta The beta parameter that is used to scale the matrix C.
/// @param C A pointer to the C matrix data.
/// @param ldc The leading dimension for the matrix C.
/// @param co An array of offset values for the matrix C. The number of
///     elements in the array depends on the value of @p offsetc.
/// @returns #dnnl_success/#dnnl::status::success on success and a status
///     describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
        dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
        dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);

/// @} dnnl_api_blas

/// @} dnnl_api

#ifdef __cplusplus
}
#endif

#endif /* ONEAPI_DNNL_DNNL_H */
