#pragma once

#include <gtest/gtest.h>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <algorithm>
#include <cctype>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace {
// Normalize a string for fuzzy error-message matching: strip leading and
// trailing whitespace, drop embedded newlines, and collapse runs of spaces
// into a single space.
inline void trim(std::string& s) {
  // Erase leading whitespace.
  s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
            return !std::isspace(ch);
          }));
  // Erase trailing whitespace.
  s.erase(
      std::find_if(
          s.rbegin(),
          s.rend(),
          [](unsigned char ch) { return !std::isspace(ch); })
          .base(),
      s.end());
  // Remove all embedded newlines.
  for (size_t i = 0; i < s.size(); ++i) {
    while (i < s.size() && s[i] == '\n') {
      s.erase(i, 1);
    }
  }
  // Collapse consecutive spaces into a single space.
  for (size_t i = 0; i < s.size(); ++i) {
    if (s[i] == ' ') {
      while (i + 1 < s.size() && s[i + 1] == ' ') {
        s.erase(i + 1, 1);
      }
    }
  }
}
} // namespace
} // namespace
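
// Example of the normalization trim() performs (values are illustrative):
//
//   std::string s = "  foo\n   bar  ";
//   trim(s); // s == "foo bar"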

// Assert that `statement` throws a std::exception whose message contains
// `substring`. Both strings are whitespace-normalized with trim() above
// before comparison, so formatting differences don't cause spurious failures.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)             \
  try {                                                              \
    (void)statement;                                                 \
    FAIL();                                                          \
  } catch (const std::exception& e) {                                \
    std::string substring_s(substring);                              \
    trim(substring_s);                                               \
    auto exception_string = std::string(e.what());                   \
    trim(exception_string);                                          \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
        << " Error was: \n"                                          \
        << exception_string;                                         \
  }
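
// Example usage (a sketch; the statement and expected substring are
// illustrative, not an actual error message from the library):
//
//   ASSERT_THROWS_WITH_MESSAGE(
//       runGraph(graph, mismatched_inputs), "expected a tensor");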

namespace torch {
namespace jit {

using tensor_list = std::vector<at::Tensor>;
using namespace torch::autograd;

// Work around the fact that variable_tensor_list doesn't duplicate all of
// std::vector's constructors; most of those constructors are never used in
// the implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list);
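
// Example (a sketch; tensor shapes are illustrative):
//
//   Stack stack = createStack({at::rand({2, 3}), at::rand({3, 4})});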

void assertAllClose(const tensor_list& a, const tensor_list& b);

std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs);
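
// Example (a sketch, assuming the Code/InterpreterState API from
// torch/csrc/jit/runtime/interpreter.h, included above; tensor names are
// illustrative):
//
//   auto graph = build_lstm();
//   Code code(graph, "lstm");
//   InterpreterState interp(code);
//   auto outputs = run(interp, {input, hx, cx, w_ih, w_hh});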

std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in);
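
// Example (a sketch, assuming differentiate() from
// torch/csrc/jit/runtime/autodiff.h, included above; the input tensor
// lists are placeholders):
//
//   auto graph = build_lstm();
//   Gradient grad_spec = differentiate(graph);
//   auto outs_and_grads = runGradient(grad_spec, tensors_in, tensor_grads_in);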

std::shared_ptr<Graph> build_lstm();
std::shared_ptr<Graph> build_mobile_export_analysis_graph();
std::shared_ptr<Graph> build_mobile_export_with_out();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();

// Transpose helpers used when building the lstm() test graph below: they
// mark whether a weight matrix is transposed at its definition (t_def) or
// left as-is at its use (t_use). The definitions live in the .cpp.
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);

// Given the elementwise difference between an output tensor and the expected
// tensor, check whether the difference is within a relative tolerance range.
// This is a standard way of matching tensor values up to a certain precision.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
bool almostEqual(const at::Tensor& a, const at::Tensor& b);
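
// A plausible relationship between the two helpers (a sketch; the real
// definitions live in the .cpp, and the exact tolerance is an assumption):
//
//   almostEqual(a, b) ~= checkRtol(a - b, {a, b})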

bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(
    const std::vector<at::Tensor>& a,
    const std::vector<at::Tensor>& b);

std::vector<at::Tensor> runGraph(
    std::shared_ptr<Graph> graph,
    const std::vector<at::Tensor>& inputs);
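
// Example (a sketch; the input order follows the lstm() helper below and
// the tensor names are illustrative):
//
//   auto graph = build_lstm();
//   auto outputs = runGraph(graph, {input, hx, cx, w_ih, w_hh});
//   ASSERT_TRUE(almostEqual(outputs[0], expected_hy));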

std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh);
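
// The helper above is expected to compute the standard LSTM cell (a sketch
// of the math only; the gate order inside the fused weight matrices is an
// assumption):
//
//   gates = input.mm(w_ih) + hx.mm(w_hh)
//   i, f, g, o = gates.chunk(4, dim=1)
//   cy = sigmoid(f) * cx + sigmoid(i) * tanh(g)
//   hy = sigmoid(o) * tanh(cy)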

} // namespace jit
} // namespace torch