#include <onnx/onnx_pb.h>
#include <torch/csrc/onnx/back_compat.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/onnx/onnx.h>
#include <torch/version.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
#include <torch/csrc/jit/passes/onnx/naming.h>
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
#include <torch/csrc/jit/serialization/export.h>

namespace torch::onnx {

using namespace torch::jit;

void initONNXBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // ONNX specific passes
  m.def("_jit_pass_onnx_remove_print", RemovePrintOps)
      .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
      .def("_jit_pass_onnx", ToONNX)
      .def(
          "_jit_pass_onnx_assign_output_shape",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 const std::vector<at::Tensor>& tensors,
                 const python::IODescriptor& desc,
                 bool onnx_shape_inference,
                 bool is_script,
                 int opset_version) {
                ONNXAssignOutputShape(
                    graph,
                    tensors,
                    desc,
                    onnx_shape_inference,
                    is_script,
                    opset_version);
              }))
      .def(
          "_jit_pass_onnx_function_substitution",
          wrap_pybind_function(ONNXFunctionCallSubstitution))
      .def(
          "_jit_pass_onnx_autograd_function_process",
          wrap_pybind_function(ONNXAutogradFunctionProcess))
      .def(
          "_jit_pass_onnx_peephole",
          ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
                                           int opset_version,
                                           bool fixed_batch_size) {
            return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
          }))
      .def(
          "_jit_pass_onnx_preprocess",
          ::torch::wrap_pybind_function(PreprocessForONNX))
      .def(
          "_jit_pass_onnx_eval_peephole",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                EvalPeepholeONNX(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_cast_all_constant_to_floating",
          ::torch::wrap_pybind_function(CastAllConstantToFloating))
      .def(
          "_jit_pass_onnx_constant_fold",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict,
                 int opset_version) {
                ConstantFoldONNX(
                    graph, paramsDict, opset_version); // overload resolution
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_eliminate_unused_items",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                EliminateUnusedItemsONNX(
                    graph->block(), paramsDict); // overload resolution
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
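      // The passes above that take `paramsDict` mutate the exporter's
      // initializer map in place and hand it back with
      // `return_value_policy::move`, so the Python caller observes the
      // updated map. Illustrative call from Python (a sketch of the calling
      // convention, not a stable public API):
      //
      //   params_dict = torch._C._jit_pass_onnx_constant_fold(
      //       graph, params_dict, opset_version)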
      .def(
          "_jit_pass_onnx_scalar_type_analysis",
          ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
                                           bool lowprecision_cast,
                                           int opset_version) {
            return ScalarTypeAnalysisForONNX(
                graph, lowprecision_cast, opset_version);
          }),
          py::arg("graph"),
          py::arg("lowprecision_cast") = true,
          py::arg("opset_version"))
      .def(
          "_jit_pass_onnx_remove_inplace_ops_for_onnx",
          ::torch::wrap_pybind_function(RemoveInplaceOpsForONNX))
      .def(
          "_jit_pass_onnx_node_shape_type_inference",
          ::torch::wrap_pybind_function(
              [](Node* n,
                 std::map<std::string, IValue>& params_dict,
                 int opset_version) {
                ONNXShapeTypeInference(n, params_dict, opset_version);
              }))
      .def(
          "_jit_pass_onnx_graph_shape_type_inference",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& params_dict,
                 int opset_version) {
                ONNXShapeTypeInference(graph, params_dict, opset_version);
              }),
          py::arg("graph"),
          py::arg("params_dict"),
          py::arg("opset_version"))
      .def(
          "_jit_pass_onnx_set_dynamic_input_shape",
          ::torch::wrap_pybind_function(ONNXSetDynamicInputShape))
      .def("_jit_pass_onnx_lint", torch::wrap_pybind_function(ONNXLintGraph))
      .def(
          "_jit_pass_onnx_function_extraction",
          ::torch::wrap_pybind_function(
              torch::jit::onnx::ONNXFunctionExtraction))
      .def("_jit_pass_onnx_block", torch::wrap_pybind_function(BlockToONNX))
      .def(
          "_jit_pass_onnx_unpack_quantized_weights",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                UnpackQuantizedWeights(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_quantization_insert_permutes",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue>& paramsDict) {
                insertPermutes(graph, paramsDict);
                return paramsDict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_onnx_list_model_parameters",
          ::torch::wrap_pybind_function(
              [](Module& module) { return list_module_parameters(module); }))
      .def(
          "_jit_pass_prepare_division_for_onnx",
          ::torch::wrap_pybind_function(PrepareDivisionForONNX))
      .def(
          "_jit_onnx_convert_pattern_from_subblock",
          ::torch::wrap_pybind_function(ConvertPatternFromSubblock))
      .def(
          "_jit_pass_fixup_onnx_controlflow_node",
          ::torch::wrap_pybind_function(FixupONNXControlflowNode))
      .def(
          "_jit_pass_onnx_deduplicate_initializers",
          ::torch::wrap_pybind_function(
              [](std::shared_ptr<Graph>& graph,
                 std::map<std::string, IValue> params_dict,
                 bool is_train) {
                DeduplicateInitializers(graph, params_dict, is_train);
                return params_dict;
              }),
          pybind11::return_value_policy::move)
      .def(
          "_jit_pass_onnx_clear_scope_records",
          &torch::jit::onnx::ONNXClearScopeRecords)
      .def(
          "_jit_pass_onnx_track_scope_attributes",
          &torch::jit::onnx::ONNXTrackScopeAttributes)
      .def(
          "_jit_is_onnx_log_enabled",
          ::torch::jit::onnx::is_log_enabled,
          "Returns whether ONNX logging is enabled or disabled.")
      .def(
          "_jit_set_onnx_log_enabled",
          ::torch::jit::onnx::set_log_enabled,
          "Enables or disables ONNX logging.")
      .def(
          "_jit_set_onnx_log_output_stream",
          [](const std::string& stream_name = "stdout") -> void {
            std::shared_ptr<std::ostream> out;
            if (stream_name == "stdout") {
              out = std::shared_ptr<std::ostream>(
                  &std::cout, [](std::ostream*) {});
            } else if (stream_name == "stderr") {
              out = std::shared_ptr<std::ostream>(
                  &std::cerr, [](std::ostream*) {});
            } else {
              std::cerr << "ERROR: only `stdout` and `stderr` "
                        << "are supported as `stream_name`" << std::endl;
            }
            ::torch::jit::onnx::set_log_output_stream(out);
          },
          "Set specific file stream for ONNX logging.")
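      // The logging bindings around this point are meant to be used
      // together; a minimal illustrative sequence from Python (a sketch, not
      // a stable public API):
      //
      //   torch._C._jit_set_onnx_log_enabled(True)
      //   torch._C._jit_set_onnx_log_output_stream("stderr")
      //   torch._C._jit_onnx_log("export started")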
      .def(
          "_jit_onnx_log",
          [](const py::args& args) -> void {
            if (::torch::jit::onnx::is_log_enabled()) {
              auto& out = ::torch::jit::onnx::_get_log_output_stream();
              for (auto arg : args) {
                out << ::c10::str(arg);
              }
              out << std::endl;
            }
          },
          "Write `args` to the previously specified ONNX log stream.")
      .def(
          "_jit_pass_onnx_assign_scoped_names_for_node_and_value",
          ::torch::wrap_pybind_function(
              ::torch::jit::onnx::AssignScopedNamesForNodeAndValue),
          "Assign informative scoped names for nodes and values.")
      .def(
          "_jit_onnx_create_full_scope_name",
          ::torch::wrap_pybind_function(
              ::torch::jit::onnx::ONNXScopeName::createFullScopeName),
          "Create a full scope name from class name and variable name.");

  m.def(
      "_check_onnx_proto",
      ::torch::wrap_pybind_function([](const std::string& proto_string) {
        check_onnx_proto(proto_string);
      }),
      py::arg("proto_string"));

  auto onnx = m.def_submodule("_onnx");
  py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
      .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
      .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
      .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
      .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
      .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
      .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
      .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
      .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
      .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
      .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
      .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
      .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
      .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
      .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
      .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
      .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
      .value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN)
      .value(
          "FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ)
      .value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2)
      .value(
          "FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ);

  py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
      .value("ONNX", OperatorExportTypes::ONNX)
      .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
      .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
      .value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH);

  py::enum_<TrainingMode>(onnx, "TrainingMode")
      .value("EVAL", TrainingMode::EVAL)
      .value("PRESERVE", TrainingMode::PRESERVE)
      .value("TRAINING", TrainingMode::TRAINING);

  onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);
}
} // namespace torch::onnx
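// Illustrative Python view of the `_onnx` submodule registered above. The
// exporter reaches it as `torch._C._onnx`; the names below are the ones
// defined in this file, while the usage itself is a sketch for illustration:
//
//   from torch._C import _onnx as _C_onnx
//   _C_onnx.TensorProtoDataType.FLOAT
//   _C_onnx.OperatorExportTypes.ONNX
//   _C_onnx.TrainingMode.EVAL
//   _C_onnx.PRODUCER_VERSION  # mirrors TORCH_VERSION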