1 #pragma once 2 // This file is temporary until native_functions.yaml and derivatives.yaml are 3 // merged. Ideally this should all go into native_functions.yaml 4 5 #include <c10/util/StringUtil.h> 6 #include <torch/csrc/jit/api/module.h> 7 #include <optional> 8 9 namespace torch::jit { 10 struct GradientPair { 11 std::shared_ptr<Graph> forward; 12 std::shared_ptr<Graph> backward; 13 }; 14 15 TORCH_API std::optional<GradientPair> gradientInfoForSchema( 16 const FunctionSchema& schema); 17 TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema); 18 } // namespace torch::jit 19