#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <c10/util/irange.h>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <onnx/onnx_pb.h>

#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

class FuserTest : public ::testing::Test {
  void SetUp() override {
    old_nvfuser_value_ = fuser::cuda::setEnabled(false);
  }
  void TearDown() override {
    fuser::cuda::setEnabled(old_nvfuser_value_);
  }

 private:
  bool old_nvfuser_value_;
};

TEST_F(FuserTest, TestSimple_CUDA) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Runs a single elementwise multiply through debugLaunchGraph, with the
  // second operand being a transposed (non-contiguous) view. The result must
  // match eager ATen exactly, in both shape and values.
  const auto graph_string = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor = aten::mul(%0, %1)
        return (%2))IR";
  Graph graph;
  torch::jit::parseIR(graph_string, &graph);

  auto a = at::rand({3, 4}, at::kCUDA);
  // Transposing a 4x3 tensor gives a 3x4 view with swapped strides.
  auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);

  auto outputs = debugLaunchGraph(graph, {a, b});
  ASSERT_EQ(outputs.size(), 1u);

  const auto expected = a * b;
  // Check the shape first: the subtraction below would silently broadcast a
  // wrongly shaped output.
  ASSERT_TRUE(expected.is_same_size(outputs[0]));
  const double max_diff = (expected - outputs[0]).abs().max().item<double>();
  ASSERT_EQ(max_diff, 0.0);
}

TEST_F(FuserTest, TestOne_CUDA) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Builds a pointwise sigmoid/tanh/mul/add graph and runs it through
  // debugLaunchGraph. Its structure mirrors the elementwise tail of an LSTM
  // cell. Inputs are laid out with dims `ti` and `tj` swapped in memory.
  // BOTH graph outputs (%14 and %12) are compared against an eager ATen
  // reference.
  auto testOne = [](int ti, int tj) {
    const auto graph_string = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor,
            %2 : Tensor,
            %3 : Tensor,
            %4 : Tensor):
        %5 : Tensor = aten::sigmoid(%4)
        %6 : Tensor = aten::sigmoid(%3)
        %7 : Tensor = aten::tanh(%2)
        %8 : Tensor = aten::sigmoid(%1)
        %9 : Tensor = aten::mul(%6, %0)
        %10 : Tensor = aten::mul(%5, %7)
        %11 : int = prim::Constant[value=1]()
        %12 : Tensor = aten::add(%9, %10, %11)
        %13 : Tensor = aten::tanh(%12)
        %14 : Tensor = aten::mul(%8, %13)
        return (%14, %12))IR";
    Graph graph;
    torch::jit::parseIR(graph_string, &graph);

    graph.lint();

    std::vector<at::Tensor> inputs;
    inputs.reserve(graph.inputs().size());
    // We want to generate input/output tensors with dimension 128x128x32, but
    // with different internal strides.  To do this, we generate a tensor
    // with the "wrong" dimensions, and then use transpose to get an
    // appropriately sized view.
    std::generate_n(
        std::back_inserter(inputs), graph.inputs().size(), [ti, tj] {
          std::array<int64_t, 3> dims = {128, 128, 32};
          std::swap(dims[ti], dims[tj]);
          return at::rand(dims, at::kCUDA).transpose(ti, tj);
        });

    // Eager reference, node for node with the IR above.
    auto t22 = inputs[4].sigmoid(); // %5
    auto t20 = inputs[3].sigmoid(); // %6
    auto t18 = inputs[2].tanh(); // %7
    auto t16 = inputs[1].sigmoid(); // %8
    auto t14 = t20 * inputs[0]; // %9
    auto t11 = t22 * t18; // %10
    auto out1 = t14 + t11; // %12, second graph output
    auto t5 = out1.tanh(); // %13
    auto out0 = t16 * t5; // %14, first graph output

    auto outputs = debugLaunchGraph(graph, inputs);
    ASSERT_EQ(outputs.size(), graph.outputs().size());

    ASSERT_TRUE(out0.is_same_size(outputs[0]));
    const double max_diff0 = (outputs[0] - out0).abs().max().item<double>();
    ASSERT_LT(max_diff0, 1e-6);

    // The second output is the pre-tanh sum; verify it too so a bug confined
    // to that output is not missed.
    ASSERT_TRUE(out1.is_same_size(outputs[1]));
    const double max_diff1 = (outputs[1] - out1).abs().max().item<double>();
    ASSERT_LT(max_diff1, 1e-6);
  };
  testOne(0, 0);
  testOne(0, 1);
  testOne(1, 2);
  testOne(0, 2);
}

TEST_F(FuserTest, FusedConcat_CUDA) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Three variants of the same graph: graph_string<d> concatenates the input
  // %0 with mul(%0, %1) along dim d. Each variant is run through
  // debugLaunchGraph and checked exactly against eager ATen.
  const auto graph_string0 = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = prim::FusedConcat[dim=0](%0, %2)
      return (%2, %3))IR";
  const auto graph_string1 = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = prim::FusedConcat[dim=1](%0, %2)
      return (%2, %3))IR";
  const auto graph_string2 = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = prim::FusedConcat[dim=2](%0, %2)
      return (%2, %3))IR";

  auto a = at::rand({3, 4, 5}, at::kCUDA);
  // Non-contiguous 3x4x5 view obtained by transposing a 4x3x5 tensor.
  auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1);
  const auto o_r = a * b;

  // Index i in this vector is also the concat dim used by that graph.
  const std::vector<std::string> graph_strings{
      graph_string0, graph_string1, graph_string2};
  for (const auto i : c10::irange(graph_strings.size())) {
    // Tag any failure below with the concat dim being tested.
    SCOPED_TRACE("concat dim " + std::to_string(i));

    Graph g;
    torch::jit::parseIR(graph_strings[i], &g);

    auto outputs = debugLaunchGraph(g, {a, b});
    ASSERT_EQ(outputs.size(), 2u);

    ASSERT_TRUE(o_r.is_same_size(outputs[0]));
    const double max_diff = (o_r - outputs[0]).abs().max().item<double>();
    ASSERT_EQ(max_diff, 0.0);

    const auto o2_r = at::cat({a, o_r}, static_cast<int64_t>(i));
    ASSERT_TRUE(o2_r.is_same_size(outputs[1]));
    const double max_diff2 = (o2_r - outputs[1]).abs().max().item<double>();
    ASSERT_EQ(max_diff2, 0.0);
  }
}

TEST_F(FuserTest, FusionAliasing) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // %2 is mutated in place by aten::add_ and then read again afterwards.
  // After FuseGraph, the expected node order is FusionGroup_0, then the
  // in-place aten::add_, then FusionGroup_1: the fuser must not build a single
  // group that spans the mutation.
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Tensor = aten::mul(%0, %1)
      %2 : Tensor = aten::mul(%2.1, %1)
      %3 : Tensor = aten::add_(%2, %1, %12)
      %4 : Tensor = aten::mul(%2, %1)
      %5 : Tensor = aten::add(%2, %4, %12)
      return (%5))IR";
  auto fused_graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, fused_graph.get());
  fused_graph->lint();

  FuseGraph(fused_graph);

  auto checker = testing::FileCheck();
  checker.check("prim::FusionGroup_0")
      ->check("aten::add_")
      ->check("prim::FusionGroup_1")
      ->run(*fused_graph);
}

TEST_F(FuserTest, KernelCaching) {
#if defined(FBCODE_CAFFE2)
  return;
#endif

  // Constructs two functionally equivalent graphs that differ only in the
  // names of their intermediate values (%c0/%d0 vs %c1/%d1).
  const auto graph0_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %c0 : Float(2, 3, 4) = aten::mul(%0, %1)
      %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
      return (%d0))IR";
  auto g0 = std::make_shared<Graph>();
  torch::jit::parseIR(graph0_string, g0.get());

  const auto graph1_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %c1 : Float(2, 3, 4) = aten::mul(%0, %1)
      %d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
      return (%d1))IR";
  auto g1 = std::make_shared<Graph>();
  torch::jit::parseIR(graph1_string, g1.get());

  // Returns the first prim::FusionGroup node among graph->nodes(); raises via
  // TORCH_CHECK when there is none.
  auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
    const auto& nodes = graph->nodes();
    auto maybe_fusion_group =
        std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
          return node->kind() == prim::FusionGroup;
        });
    TORCH_CHECK(
        maybe_fusion_group != nodes.end(),
        "testRegisterFusionCachesKernel: could not create FusionGroup");
    return *maybe_fusion_group;
  };

  // Creates two alpha-equivalent fusion groups. The CPU fuser is
  // force-enabled only for these two FuseGraph calls. The scope guard resets
  // the override to false on every exit path, including an exception thrown
  // by FuseGraph, so the global override cannot leak into later tests.
  {
    struct ResetCpuFuserOverride {
      ~ResetCpuFuserOverride() {
        torch::jit::overrideCanFuseOnCPU(false);
      }
    } reset_cpu_fuser_override;
    torch::jit::overrideCanFuseOnCPU(true);
    FuseGraph(g0);
    FuseGraph(g1);
  }
  auto fg0 = getFusionGroup(g0);
  auto fg1 = getFusionGroup(g1);

  // Registers both with the fusion compiler.
  auto expected_key = registerFusion(fg0);
  auto second_key = registerFusion(fg1);

  // Because the graphs are alpha-equivalent, they should return the same key
  // and therefore share a KernelSpec to share kernels for specializations
  ASSERT_EQ(second_key, expected_key);
}
} // namespace jit
} // namespace torch