#pragma once #include <torch/csrc/jit/ir/irparser.h> #include <torch/csrc/jit/runtime/autodiff.h> #include <torch/csrc/jit/runtime/interpreter.h> #include <torch/csrc/jit/testing/file_check.h> namespace { static inline void trim(std::string& s) { s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); s.erase( std::find_if( s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }) .base(), s.end()); for (size_t i = 0; i < s.size(); ++i) { while (i < s.size() && s[i] == '\n') { s.erase(i, 1); } } for (size_t i = 0; i < s.size(); ++i) { if (s[i] == ' ') { while (i + 1 < s.size() && s[i + 1] == ' ') { s.erase(i + 1, 1); } } } } } // namespace #define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \ try { \ (void)statement; \ FAIL(); \ } catch (const std::exception& e) { \ std::string substring_s(substring); \ trim(substring_s); \ auto exception_string = std::string(e.what()); \ trim(exception_string); \ ASSERT_NE(exception_string.find(substring_s), std::string::npos) \ << " Error was: \n" \ << exception_string; \ } namespace torch { namespace jit { using tensor_list = std::vector<at::Tensor>; using namespace torch::autograd; // work around the fact that variable_tensor_list doesn't duplicate all // of std::vector's constructors. // most constructors are never used in the implementation, just in our tests. 
Stack createStack(std::vector<at::Tensor>&& list);

// NOTE(review): everything in this section is declaration-only; the
// definitions are not visible from this header. The descriptions below are
// inferred from names and signatures and should be confirmed against the
// implementation file.

// Presumably asserts that the two tensor lists match within some tolerance;
// the tolerance is chosen by the definition - TODO confirm.
void assertAllClose(const tensor_list& a, const tensor_list& b);

// Presumably executes `interp` on `inputs` and returns the resulting tensors.
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs);

// Presumably evaluates the differentiated graph described by `grad_spec` and
// returns a pair of tensor lists (likely outputs and input gradients - verify
// the order against the definition).
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in);

// Graph factories used by the tests. Each returns a shared Graph; the graph
// contents are determined entirely by the out-of-view definitions.
std::shared_ptr<Graph> build_lstm();
std::shared_ptr<Graph> build_mobile_export_analysis_graph();
std::shared_ptr<Graph> build_mobile_export_with_out();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();

// Single-tensor helpers taking and returning a tensor by value; their
// semantics cannot be determined from this header - see the definitions.
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);

// Given the difference between an output tensor and the expected tensor,
// check whether that difference is within a relative tolerance range. This is
// a standard way of matching tensor values up to a certain precision.
// NOTE(review): `inputs` is taken by value (const), so the vector is copied on
// every call; the signature has to match the out-of-line definition, so it is
// left unchanged here.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);

// Presumably an approximate comparison of `a` and `b` - TODO confirm the
// tolerance used.
bool almostEqual(const at::Tensor& a, const at::Tensor& b);

// Presumably an exact comparison of two tensors.
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);

// Overload of exactlyEqual for two lists of tensors.
bool exactlyEqual(
    const std::vector<at::Tensor>& a,
    const std::vector<at::Tensor>& b);

// Presumably runs `graph` on `inputs` and returns the output tensors.
std::vector<at::Tensor> runGraph(
    std::shared_ptr<Graph> graph,
    const std::vector<at::Tensor>& inputs);

// Presumably a reference LSTM cell computation; the returned pair is likely
// the new hidden and cell states - verify against the definition.
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh);

} // namespace jit
} // namespace torch