#include <torch/csrc/Device.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>

#include <ATen/Device.h>
#include <c10/util/Exception.h>

#include <structmember.h>
#include <limits>
#include <sstream>

// Module object that THPDevice_init registers torch.device on. It is stored as
// a borrowed reference (no Py_INCREF is taken). THPDevice_pynew passes it to
// handle_torch_function when the constructor call has a __torch_function__
// override to dispatch.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPUpperModuleOfDevice = nullptr;

// Allocates a new torch.device Python object holding a copy of `device`.
// Returns a new (owned) reference. Throws python_error if tp_alloc fails
// (the Python error indicator is already set in that case).
PyObject* THPDevice_New(const at::Device& device) {
  PyTypeObject* const device_type = &THPDeviceType;
  THPObjectPtr instance{device_type->tp_alloc(device_type, 0)};
  if (!instance) {
    throw python_error();
  }
  reinterpret_cast<THPDevice*>(instance.get())->device = device;
  // Hand ownership of the reference to the caller.
  return instance.release();
}

// tp_repr for torch.device: produces "device(type='<type>')" or, when the
// device carries an index, "device(type='<type>', index=<n>)".
// C++ exceptions (e.g. allocation failure while formatting) are converted to
// Python exceptions instead of escaping into the interpreter.
PyObject* THPDevice_repr(THPDevice* self) {
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << "device(type=\'" << self->device.type() << "\'";
  if (self->device.has_index()) {
    // `self->device.index()` is a narrow integer type (c10::DeviceIndex), which
    // iostreams would print as a character rather than a number, so widen it
    // to a plain int first.
    // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
    oss << ", index=" << static_cast<int>(self->device.index());
  }
  oss << ")";
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}

// tp_str for torch.device: the text produced by at::Device's stream operator.
// C++ exceptions raised while formatting are converted to Python exceptions
// instead of escaping into the interpreter.
PyObject* THPDevice_str(THPDevice* self) {
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << self->device;
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}

// tp_new for torch.device. Two call forms are accepted:
//   device(Device device)                        -- an existing device / string
//   device(str type, int? index=-1)              -- type string + explicit index
// In the second form the type string must not itself contain an index, and a
// given index must be non-negative and representable as c10::DeviceIndex.
// Calls with a __torch_function__ override are dispatched through
// handle_torch_function.
PyObject* THPDevice_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"device(Device device)",
       "device(c10::string_view type, int64_t? index=-1)"});
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch");
  }
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0); // this works, because device can take strings
    if (as_device.has_index()) {
      auto device_type = r.string(0);
      throw std::runtime_error(
          "type (string) must not include an index because index "
          "was passed explicitly: " +
          device_type);
    }
    int64_t device_index = -1;
    if (!r.isNone(1)) {
      device_index = r.toInt64(1);
      // -1 is allowed in ATen/C++, to mean the default device, but not in
      // Python.
      TORCH_CHECK(device_index >= 0, "Device index must not be negative");
      // c10::DeviceIndex is narrower than int64_t. Without this check the
      // cast below would silently wrap large values (possibly into -1, i.e.
      // "no index", or into an invalid negative index).
      constexpr auto kMaxDeviceIndex =
          std::numeric_limits<c10::DeviceIndex>::max();
      TORCH_CHECK(
          device_index <= kMaxDeviceIndex,
          "Device index must not exceed ",
          static_cast<int64_t>(kMaxDeviceIndex),
          ", got ",
          device_index);
    }
    at::Device device(
        as_device.type(), static_cast<c10::DeviceIndex>(device_index));
    return THPDevice_New(device);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Getter for the read-only `torch.device.type` attribute: the device type
// rendered as a Python str via the DeviceType stream operator.
PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << self->device.type();
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}

// Getter for the read-only `torch.device.index` attribute: the device index as
// a Python int, or None when the device has no index.
PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (!self->device.has_index()) {
    Py_RETURN_NONE;
  }
  return THPUtils_packInt64(self->device.index());
  END_HANDLE_TH_ERRORS
}

// tp_hash for torch.device. Folds std::hash<at::Device> into the range
// [0, PY_SSIZE_T_MAX). The result is therefore never -1, which CPython
// reserves as the error return of a hash function.
static Py_ssize_t THPDevice_hash(THPDevice* self) {
  HANDLE_TH_ERRORS
  const size_t device_hash = std::hash<at::Device>{}(self->device);
  const auto modulus =
      static_cast<size_t>(std::numeric_limits<Py_ssize_t>::max());
  return static_cast<Py_ssize_t>(device_hash % modulus);
  END_HANDLE_TH_ERRORS_RET(-1)
}

// tp_richcompare for torch.device.
// - If either operand is not a torch.device, returns NotImplemented so Python
//   can try the reflected operation / fall back to its default behavior.
// - == and != compare the wrapped at::Device values.
// - Ordering comparisons (<, <=, >, >=) raise TypeError.
PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
  HANDLE_TH_ERRORS
  if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const at::Device& lhs = reinterpret_cast<THPDevice*>(a)->device;
  const at::Device& rhs = reinterpret_cast<THPDevice*>(b)->device;

  switch (op) {
    case Py_EQ:
      if (lhs == rhs) {
        Py_RETURN_TRUE;
      }
      Py_RETURN_FALSE;
    case Py_NE:
      if (lhs == rhs) {
        Py_RETURN_FALSE;
      }
      Py_RETURN_TRUE;
    case Py_LT:
    case Py_LE:
    case Py_GT:
    case Py_GE:
      throw torch::TypeError("comparison not implemented");
    default:
      throw torch::TypeError("unexpected comparison op");
  }
  END_HANDLE_TH_ERRORS
}

// __reduce__ for torch.device (used by pickle/copy). Returns the 2-tuple
//   (torch.device, (type_str,))         when the device has no index, or
//   (torch.device, (type_str, index))   when it does.
// Throws python_error if any CPython allocation fails.
PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto* dev = reinterpret_cast<THPDevice*>(_self);
  THPObjectPtr result{PyTuple_New(2)};
  if (!result) {
    throw python_error();
  }

  // Element 0: the callable that rebuilds the object.
  py::object rebuild = py::module::import("torch").attr("device");
  PyTuple_SET_ITEM(result.get(), 0, rebuild.release().ptr());

  // Element 1: the arguments passed to that callable.
  std::ostringstream type_stream;
  type_stream << dev->device.type();
  const std::string type_name = type_stream.str();
  THPObjectPtr rebuild_args{
      dev->device.has_index()
          ? Py_BuildValue(
                "(si)",
                type_name.c_str(),
                static_cast<int>(dev->device.index()))
          : Py_BuildValue("(s)", type_name.c_str())};
  if (!rebuild_args) {
    throw python_error();
  }
  PyTuple_SET_ITEM(result.get(), 1, rebuild_args.release());

  return result.release();
  END_HANDLE_TH_ERRORS
}

// __enter__ for torch.device. Constructs torch.utils._device.DeviceContext
// around this device and pushes it onto the PythonTorchFunctionTLS stack.
// Returns the device itself (new reference) so that
// `with torch.device('cuda') as dev:` binds `dev` to the device.
PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  py::object context_cls =
      py::module::import("torch.utils._device").attr("DeviceContext");
  py::object mode = context_cls(py::handle(self));
  // release() gives up pybind's reference; the SafePyObject takes it over.
  auto stack_entry = std::make_shared<c10::SafePyObject>(
      mode.release().ptr(), getPyInterpreter());
  at::impl::PythonTorchFunctionTLS::push_onto_stack(std::move(stack_entry));
  Py_INCREF(self);
  return self;
  END_HANDLE_TH_ERRORS
}

// __exit__ for torch.device. Pops the top entry of the PythonTorchFunctionTLS
// stack (expected to be the one pushed by __enter__). The exception arguments
// are ignored. Returning None means any in-flight exception is not suppressed.
PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
  HANDLE_TH_ERRORS
  // The popped stack entry is intentionally discarded.
  (void)at::impl::PythonTorchFunctionTLS::pop_stack();
  Py_INCREF(Py_None);
  return Py_None;
  END_HANDLE_TH_ERRORS
}

// Candidate tp_call for torch.device: forwards (self, *args, **kwargs) to
// torch.utils._device.device_decorator and returns its result (new reference).
// Not currently installed; THPDeviceType's tp_call slot is nullptr.
PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS
  py::object decorator_factory =
      py::module::import("torch.utils._device").attr("device_decorator");
  py::object decorated = decorator_factory(
      py::handle(self), *py::handle(args), **py::handle(kwargs));
  return decorated.release().ptr();
  END_HANDLE_TH_ERRORS
}

// Function-pointer type of a PyGetSetDef getter (same signature as CPython's
// own `getter`); used to cast the attribute getters in the table below.
using getter = PyObject* (*)(PyObject*, void*);

// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPDevice_properties[] = {
    {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
    {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
    {nullptr}};

// Python-visible methods of torch.device: pickling support (__reduce__) and
// the context-manager protocol (__enter__/__exit__).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static PyMethodDef THPDevice_methods[] = {
    // {name, implementation, calling convention, doc}
    {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
    {"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
    // __exit__ receives (exc_type, exc_value, traceback) as a tuple.
    {"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr} /* Sentinel */
};

// Static CPython type object implementing the Python class `torch.device`.
// The initializer is positional and follows PyTypeObject's field order (each
// entry is labelled with its slot name), so entries must not be reordered.
// THPDevice_init readies this type and adds it to the module.
// Slots left as nullptr/0 here (e.g. tp_dealloc, tp_alloc) are presumably
// filled with inherited defaults by PyType_Ready. tp_flags does not include
// Py_TPFLAGS_BASETYPE, so the type cannot be subclassed from Python.
PyTypeObject THPDeviceType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
    sizeof(THPDevice), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPDevice_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    (hashfunc)THPDevice_hash, /* tp_hash  */
    // TODO: We're not sure if this is a good idea or not, because making
    // torch.device callable means that it will start returning true
    // for callable() queries, and that is unexpected. We can always add
    // this later, so for now, don't actually implement this
    // THPDevice_call, /* tp_call */
    nullptr, /* tp_call */
    (reprfunc)THPDevice_str, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    // Only == / != are meaningful; see THPDevice_rc.
    (richcmpfunc)THPDevice_rc, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDevice_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDevice_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    // All construction from Python goes through the argument parser here.
    THPDevice_pynew, /* tp_new */
};

// Readies THPDeviceType and registers it on `module` as the attribute
// "device". It also records `module` (borrowed) in THPUpperModuleOfDevice,
// which THPDevice_pynew uses for __torch_function__ dispatch.
// Throws python_error (with the Python error indicator set) on failure.
void THPDevice_init(PyObject* module) {
  if (PyType_Ready(&THPDeviceType) < 0) {
    throw python_error();
  }
  // PyModule_AddObject steals a reference only on success, so take one up
  // front and hand it back if registration fails to avoid leaking it.
  Py_INCREF(&THPDeviceType);
  THPUpperModuleOfDevice = module;
  if (PyModule_AddObject(
          module, "device", reinterpret_cast<PyObject*>(&THPDeviceType)) !=
      0) {
    Py_DECREF(&THPDeviceType);
    throw python_error();
  }
}