// Imported from: /aosp_15_r20/external/pytorch/benchmarks/static_runtime/test_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5  
6  #pragma once
7  
#include <memory>
#include <string>
#include <vector>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>
13  
// Forward declaration of c10::IValue, which the declarations below take by
// (const) reference and in std::vector parameters.
namespace c10 { struct IValue; } // namespace c10
17  
18  namespace torch {
19  namespace jit {
20  
21  struct Node;
22  class StaticModule;
23  
24  namespace test {
25  
26  // Given a model/function in jit or IR script, run the model/function
27  // with the jit interpreter and static runtime, and compare the results
28  void testStaticRuntime(
29      const std::string& source,
30      const std::vector<c10::IValue>& args,
31      const std::vector<c10::IValue>& args2 = {},
32      const bool use_allclose = false,
33      const bool use_equalnan = false,
34      const bool check_resize = true);
35  
36  std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script);
37  
38  std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);
39  
40  bool hasProcessedNodeWithName(
41      torch::jit::StaticModule& smodule,
42      const char* name);
43  
44  at::Tensor getTensor(const at::IValue& ival);
45  
46  Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind);
47  Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind);
48  
49  bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind);
50  bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind);
51  
52  void compareResultsWithJIT(
53      StaticRuntime& runtime,
54      const std::shared_ptr<Graph>& graph,
55      const std::vector<c10::IValue>& args,
56      const bool use_allclose = false,
57      const bool use_equalnan = false);
58  
59  void compareResults(
60      const IValue& expect,
61      const IValue& actual,
62      const bool use_allclose = false,
63      const bool use_equalnan = false);
64  
65  } // namespace test
66  } // namespace jit
67  } // namespace torch
68