#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <optional>
#include <stack>
#include <string>

#include <utility>

namespace at {

namespace impl {

// State bundle for the default saved-tensor hooks. This is the type exchanged
// through SavedTensorDefaultHooks::get_tls_state() / set_tls_state(); the
// "TLS" suffix suggests it is kept per-thread (the storage itself is not
// visible from this header).
struct TORCH_API SavedTensorDefaultHooksTLS {
  // Stack of registered hook pairs, each hook held as a c10::SafePyObject.
  // Judging by push_hooks(pack_hook, unpack_hook), each pair is presumably
  // ordered (pack_hook, unpack_hook) -- confirm against the definition.
  // (The raw PyObject type is declared in c10/util/python_stub.h.)
  std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;

  // See NOTE: [Disabling SavedTensorDefaultHooks] for context
  // NOTE: [disabled_error_message invariant]
  // disabled_error_message is nullopt IFF saved tensor hooks are enabled.
  // This was done for efficiency, so that a separate "enabled" bool does not
  // have to be kept around (and kept in sync with the message).
  std::optional<std::string> disabled_error_message;

  // See NOTE: [Deferring tensor pack/unpack hooks until runtime]
  // Defaults to false, i.e. not tracing.
  bool is_tracing = false;
};

} // namespace impl

// Static-only interface for the default saved-tensor pack/unpack hooks.
// Every member is a static function; the associated state is exposed as an
// impl::SavedTensorDefaultHooksTLS value via get_tls_state()/set_tls_state().
struct TORCH_API SavedTensorDefaultHooks {
  // Registers a (pack_hook, unpack_hook) pair. Both arguments are taken by
  // value. Per NOTE: [Disabling SavedTensorDefaultHooks] below, attempting to
  // push hooks while they are disabled raises an error.
  static void push_hooks(
      c10::SafePyObject pack_hook,
      c10::SafePyObject unpack_hook);

  // Removes a hook pair and returns it to the caller.
  // NOTE(review): presumably the most recently pushed pair, ordered
  // (pack_hook, unpack_hook); behavior with no hooks registered is not
  // visible from this header -- confirm against the definition.
  static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();

  // Returns a hook pair, or std::nullopt.
  // NOTE(review): presumably the current (top-of-stack) pair, with nullopt
  // when there are no hooks or when tracing is active; `ignore_is_tracing`
  // presumably bypasses the tracing check (see NOTE: [Deferring tensor
  // pack/unpack hooks until runtime]) -- confirm against the definition.
  static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  get_hooks(bool ignore_is_tracing = false);

  // NOTE(review): what is initialized here is not visible from this header.
  static void lazy_initialize();

  // Accessors for the whole hook state as a single value. The getter returns
  // a const reference, so callers that need to keep the state must copy it.
  static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
  static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);

  // NOTE: [Disabling SavedTensorDefaultHooks]
  // A developer of a PyTorch feature may choose to disable SavedTensorDefault
  // hooks, especially if their feature does not work with it. If they are
  // disabled, then the following will raise an error:
  // - Attempting to push_hooks
  // - calling disable(message) with a non-zero stack (hooks) size
  //   (presumably only when `fail_if_non_empty` is true, which is the
  //   default -- confirm against the definition)
  //
  // `error_message` is the message associated with the disabled state; see
  // NOTE: [disabled_error_message invariant] on
  // impl::SavedTensorDefaultHooksTLS.
  static void disable(
      const std::string& error_message,
      bool fail_if_non_empty = true);
  static void enable();
  static bool is_enabled();
  // Returns the stored disable message; by the invariant referenced above it
  // is nullopt IFF the hooks are enabled.
  static const std::optional<std::string>& get_disabled_error_message();

  // NOTE: [Deferring tensor pack/unpack hooks until runtime]
  // To preserve eager semantics of pack/unpack hooks firing only once per saved
  // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
  // disable() would loud error at trace time, and pushing a no-op hook would
  // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
  // To do so, we disable these hooks during tracing. See
  // https://github.com/pytorch/pytorch/issues/113263.
  //
  // NOTE(review): the bool return value is presumably the previous tracing
  // state, so callers can restore it afterwards -- confirm against the
  // definition.
  static bool set_tracing(bool is_tracing);
};

} // namespace at
