/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <stdexcept>
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"

namespace fbgemm {

/**
 * @ Transform all entries in a matrix from fp32 to bfloat16: reference
 * implementation.
 *
 */
FBGEMM_API void
FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size);

/**
 * @ Transform all entries in a matrix from bfloat16 to fp32: reference
 * implementation.
 *
 */
FBGEMM_API void
Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size);

/**
 * @ Transform all entries in a matrix from fp32 to bfloat16: simd
 * implementation.
 *
 */
FBGEMM_API void
FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);

/**
 * @ Transform all entries in a matrix from bfloat16 to fp32: simd
 * implementation.
 *
 */
FBGEMM_API void
Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);

/**
 * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
 *
 */
FBGEMM_API void
FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);

/**
 * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
 *
 */
FBGEMM_API void
FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);

/**
 * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
 *
 */
FBGEMM_API void
Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);

/**
 * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
 *
 */
FBGEMM_API void
Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);

/**
 * @ Transform all entries in a matrix from fp32 to float16: reference
 * implementation.
 *
 * @param do_clip if true we saturate to fp16 min and max instead of generating
 *                infinities.
 */
FBGEMM_API void FloatToFloat16_ref(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip = false);

/**
 * @ Transform all entries in a matrix from float16 to fp32: reference
 * implementation.
 *
 */
FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size);

/**
 * @ Transform all entries in a matrix from fp32 to float16: simd
 * implementation.
 *
 * @param do_clip if true we saturate to fp16 min and max instead of generating
 *                infinities.
 */
FBGEMM_API void FloatToFloat16_simd(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip = false);

/**
 * @ Transform all entries in a matrix from float16 to fp32: simd
 * implementation.
 *
 */
FBGEMM_API void
Float16ToFloat_simd(const float16* src, float* dst, size_t size);

/**
 * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
 *
 */
FBGEMM_API void FloatToFloat16_avx2(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip = false);

/**
 * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
 *
 */
FBGEMM_API void FloatToFloat16_avx512(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip = false);

/**
 * @brief SVE2 implementation to convert fp32 numbers to fp16 numbers.
 *
 */
FBGEMM_API void FloatToFloat16_sve2(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip = false);

/**
 * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
 *
 */
FBGEMM_API void
Float16ToFloat_avx2(const float16* src, float* dst, size_t size);

/**
 * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
 *
 */
FBGEMM_API void
Float16ToFloat_avx512(const float16* src, float* dst, size_t size);

/**
 * @brief Transform all entries in a matrix from fp32 to float16 and back to
 * fp32.
 */
FBGEMM_API void RoundToFloat16(
    const float* input,
    float* output,
    size_t size,
    bool clamp = false,
    bool clamp_denorms = false);

/**
 * @brief Quantize float32 to float8. The code is a copy of float_to_hfp8() in
 * fbgemm_gpu/quantize_ops_utils.h
 */
FBGEMM_API void FloatToFloat8_ref(
    const float input,
    uint8_t* output,
    int exponent_bits,
    int exponent_bias);

/**
 * @brief Dequantize float8 to float32. The code is a copy of hf8_to_float() in
 * fbgemm_gpu/quantize_ops_utils.h
 */
FBGEMM_API void Float8ToFloat_ref(
    const uint8_t input,
    float* output,
    int exponent_bits,
    int exponent_bias);

} // namespace fbgemm
