xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/insert_observers.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/passes/quantization/insert_observers.h>
3 
4 #include <torch/csrc/jit/frontend/schema_matching.h>
5 #include <torch/csrc/jit/ir/subgraph_matcher.h>
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/constant_pooling.h>
8 #include <torch/csrc/jit/passes/constant_propagation.h>
9 #include <torch/csrc/jit/passes/fuse_linear.h>
10 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
11 #include <torch/csrc/jit/passes/inline_fork_wait.h>
12 #include <torch/csrc/jit/passes/quantization/helper.h>
13 #include <torch/csrc/jit/passes/remove_mutation.h>
14 
15 #include <memory>
16 #include <stack>
17 #include <string>
18 #include <utility>
19 
20 namespace torch {
21 namespace jit {
22 
23 using ModuleQConfigMap = std::unordered_map<ModulePtr, std::optional<QConfig>>;
24 
25 namespace {
26 
27 struct OptionalQConfigHash {
28   inline size_t operator()(const std::optional<QConfig>& qconfig_opt) const {
29     if (qconfig_opt.has_value()) {
30       const auto& m1 = std::get<0>(*qconfig_opt);
31       const auto& m2 = std::get<1>(*qconfig_opt);
32       constexpr int CONST = 7;
33       return std::hash<Module>()(m1) + CONST * std::hash<Module>()(m2);
34     }
35     return 0;
36   }
37 };
38 using QConfigTypePtrMap =
39     std::unordered_map<std::optional<QConfig>, TypePtr, OptionalQConfigHash>;
40 using NameModuleVector = std::vector<std::pair<std::string, Module>>;
41 using OptionalModuleVector = std::vector<std::optional<Module>>;
42 using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
43 using graph_rewrite_helper::PatternInfo;
44 using graph_rewrite_helper::replaceConvolutionWithAtenConv;
45 
46 // helper functions
47 void fillQConfigMap(
48     const Module& module,
49     const QConfigDict& qconfig_dict,
50     ModuleQConfigMap& map,
51     const std::string& key = "",
52     const std::optional<QConfig>& parent_qconfig = std::nullopt) {
53   std::optional<QConfig> qconfig;
54   if (qconfig_dict.find(key) != qconfig_dict.end()) {
55     GRAPH_DEBUG("Got module config for key:", key);
56     qconfig = qconfig_dict.at(key);
57   } else {
58     GRAPH_DEBUG("Inheriting qconfig from parent module:", key);
59     qconfig = parent_qconfig;
60   }
61   map[module._ivalue()] = qconfig;
62 
63   for (const NameModule& s : module.named_children()) {
64     std::string child_key;
65     if (key.empty()) {
66       child_key = s.name;
67     } else {
68       child_key = key + "." + s.name;
69     }
70     fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
71   }
72 }
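// Example of the map built by fillQConfigMap above (illustrative; module
// names and qconfigs are hypothetical). Given
//   qconfig_dict = { "": global_qconfig, "sub.fc": fc_qconfig }
// and a module hierarchy top -> sub -> {fc, bn}, the resulting map is:
//   top        -> global_qconfig   (exact match for key "")
//   top.sub    -> global_qconfig   (inherited from parent "")
//   top.sub.fc -> fc_qconfig       (exact match for key "sub.fc")
//   top.sub.bn -> global_qconfig   (inherited from parent "sub")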
73 
74 Module getObserverModuleFor(Value* v, const QConfig& qconfig) {
75   return isWeight(v) ? std::get<1>(qconfig) : std::get<0>(qconfig);
76 }
77 
78 // helper classes
79 class ModuleCloneHelper {
80  public:
81   /** Clone according to the module qconfig map; this handles the case
82    *  where we have two module instances sharing the same ClassType
83    *  but configured with different QConfigs.
84    *  The code is copied and modified from
85    * https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp
86    * The inplace option controls whether Tensors are deep-copied:
87    * if inplace is true, the cloned module will share the tensors with the
88    * original model instead of deep-copying them.
89    */
90   Module clone(
91       const Module& module,
92       const ModuleQConfigMap& module_qconfig_map,
93       bool inplace = false) {
94     std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
95     IValue::HashIdentityIValueMap memo;
96     return clone_impl(
97         module, module_qconfig_map, type_remap, inplace, std::move(memo));
98   }
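  // Example (illustrative): if two submodules share one ClassType but are
  // mapped to different QConfigs in module_qconfig_map, clone() creates two
  // distinct cloned ClassTypes, one per (original type, qconfig) pair, so
  // each instance can later get its own observer configuration. Instances
  // with the same type and the same qconfig keep sharing a single cloned type.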
99 
100  private:
101   Module clone_impl(
102       const Module& module,
103       const ModuleQConfigMap& module_qconfig_map,
104       std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
105       bool inplace,
106       IValue::HashIdentityIValueMap memo) {
107     auto qconfig = module_qconfig_map.at(module._ivalue());
108     auto type = module.type();
109     // Create a new _ivalue in the same compilation unit.
110     // Since we now have shared ClassTypes, we need to preserve the sharing
111     // during cloning, so we first use the type and qconfig to check whether
112     // the type has already been cloned; if so, we create a new module with the
113     // cloned ClassType, and if not, we create a new module and a new ClassType.
114     bool type_already_cloned = type_remap.find(type) != type_remap.end() &&
115         type_remap.at(type).find(qconfig) != type_remap.at(type).end();
116     Module r;
117     if (type_already_cloned) {
118       // if we cloned the class type before, we'll reuse it
119       Module new_module(
120           module._ivalue()->compilation_unit(),
121           type_remap.at(type).at(qconfig)->cast<ClassType>());
122       r = new_module;
123     } else {
124       Module new_module(
125           *type->name(), module._ivalue()->compilation_unit(), true);
126       r = new_module;
127       type_remap[type][module_qconfig_map.at(module._ivalue())] = r.type();
128     }
129     // Copy slots. If a slot is a module - recursively clone it.
130     size_t N = type->numAttributes();
131     for (const auto i : c10::irange(N)) {
132       IValue s = module._ivalue()->getSlot(i);
133       std::string attr_name = type->getAttributeName(i);
134       TypePtr attr_type = type->getAttribute(i);
135       if (attr_type->is_module()) {
136         const Module& orig = Module(s.toObject());
137         Module cloned =
138             clone_impl(orig, module_qconfig_map, type_remap, inplace, memo);
139 
140         // NOTE: why do we need to manually setattr on the object instead of using
141         // register_module here? Because the attr can be a module interface
142         // type and still hold a Module object. register_module will not let us
143         // correctly set up the type for this attr, so we had to do this
144         // manually. In case it's an interface type, the type will be shared
145         // by the new cloned instance in the same compilation unit because it only
146         // contains a list of FunctionSchemas.
147         r.type()->addOrCheckAttribute(
148             attr_name,
149             attr_type->cast<ClassType>() ? cloned.type() : attr_type);
150         r._ivalue()->setAttr(attr_name, cloned._ivalue());
151       } else {
152         // we'll deepcopy the IValue unless the inplace option is set
153         r.register_attribute(
154             type->getAttributeName(i),
155             type->getAttribute(i),
156             inplace ? s : s.deepcopy(memo),
157             type->is_parameter(i),
158             type->is_buffer(i));
159       }
160     }
161 
162     // only clone the methods and constants if the ClassType is not cloned
163     // before
164     if (!type_already_cloned) {
165       for (size_t i = 0; i < type->numConstants(); ++i) {
166         r.type()->addConstant(type->getConstantName(i), type->getConstant(i));
167       }
168       // Clone methods remapping the types to the cloned ones.
169       for (auto& fn : type->methods()) {
170         clone_method(module, r, *fn, module_qconfig_map, type_remap);
171       }
172       // Execute __setstate__(__getstate__()) to initialize custom class
173       // members.
174       if (auto setstate_method = r.find_method("__setstate__")) {
175         auto getstate_method = r.find_method("__getstate__");
176         TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
177         auto state = (*getstate_method)(Stack{});
178         (*setstate_method)(Stack{std::move(state)});
179       }
180     }
181     return r;
182   }
183 
184   void remapTypes(
185       Block* block,
186       Value* self,
187       const Module& source,
188       Module& target,
189       const ModuleQConfigMap& module_qconfig_map,
190       const std::function<TypePtr(TypePtr, std::optional<QConfig>)>&
191           type_remap_fn) {
192     // the remap of %self will be done outside of this function,
193     // and we don't support the case where people pass in a
194     // module as an argument of the method, because in that case
195     // we would need to do more comprehensive analysis to decide the
196     // QConfig for the module
197     for (size_t i = 1; i < block->inputs().size(); ++i) {
198       TORCH_CHECK(
199           !block->inputs()[i]->type()->cast<ClassType>(),
200           "We don't support quantizing methods that has Object as arguments");
201     }
202     for (Node* node : block->nodes()) {
203       // remapping type for module instance
204       if (node->kind() == prim::CallMethod || node->kind() == prim::GetAttr) {
205         Value* instance = node->inputs()[0];
206         auto child_opt = getInvokedModuleOpt(source, node, self);
207         if (child_opt.has_value()) {
208           auto qconfig = module_qconfig_map.at(child_opt->_ivalue());
209           instance->setType(type_remap_fn(instance->type(), qconfig));
210         }
211       }
212       // We don't remap the output; the remapping of the module type
213       // will be done in CallMethod. We don't support type remapping
214       // for modules returned from methods or functions
215       for (Block* sub_block : node->blocks()) {
216         remapTypes(
217             sub_block, self, source, target, module_qconfig_map, type_remap_fn);
218       }
219       for (Symbol name : node->attributeNames()) {
220         if (node->kindOf(name) == AttributeKind::g) {
221           remapTypes(
222               node->g(name).get(),
223               source,
224               target,
225               module_qconfig_map,
226               type_remap_fn);
227         } else if (node->kindOf(name) == AttributeKind::gs) {
228           for (const auto& g : node->gs(name)) {
229             remapTypes(
230                 g.get(), source, target, module_qconfig_map, type_remap_fn);
231           }
232         }
233       }
234     }
235   }
236 
237   void remapTypes(
238       Graph* graph,
239       const Module& source,
240       Module& target,
241       const ModuleQConfigMap& module_qconfig_map,
242       const std::function<TypePtr(TypePtr, std::optional<QConfig>)>&
243           type_remap_fn) {
244     remapTypes(
245         graph->block(),
246         graph->inputs()[0],
247         source,
248         target,
249         module_qconfig_map,
250         type_remap_fn);
251   }
252 
253   void clone_method(
254       const Module& source,
255       Module& target,
256       const Function& method,
257       const ModuleQConfigMap& module_qconfig_map,
258       const std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
259     auto type_remap_fn = [&](TypePtr type_ptr,
260                              const std::optional<QConfig>& qconfig) {
261       if (type_remap.find(type_ptr) != type_remap.end()) {
262         const auto& qconfig_map = type_remap.at(type_ptr);
263         if (qconfig_map.find(qconfig) != qconfig_map.end()) {
264           return qconfig_map.at(qconfig);
265         }
266       }
267       return type_ptr;
268     };
269     auto graph = toGraphFunction(method).graph()->copy();
270     remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
271     // remap self
272     graph->inputs()[0]->setType(target.type());
273     // we only support %self being Module in the arguments of function
274     auto schema_type_remap_fn = [&](TypePtr type_ptr) {
275       return type_remap_fn(
276           std::move(type_ptr), module_qconfig_map.at(source._ivalue()));
277     };
278     auto schema =
279         method.getSchema().cloneWithRemappedTypes(schema_type_remap_fn);
280     const auto this_method_name =
281         c10::QualifiedName(*target.type()->name(), method.name());
282     auto copied = target._ivalue()->compilation_unit()->create_function(
283         this_method_name, std::move(graph));
284     target.type()->addMethod(copied);
285     copied->setSchema(std::move(schema));
286   }
287 };
288 
289 class InsertObserversHelper {
290  public:
291   explicit InsertObserversHelper(
292       const ModuleQConfigMap& map,
293       QuantType quant_type)
294       : module_qconfig_map_(map), quant_type_(quant_type) {}
295 
296   // TODO: replace (module, method_name) with graph?
297   // preprocess to clean up the graph from tracing
298   void preprocess(Module& module, const std::string& method_name);
299 
300   // Fill the map from the caller's inputs/outputs to the inputs/outputs
301   // of the called graph; this is used to navigate through the graph
302   // to find the observer for a given value
303   void fillBoundaryValueMap(Module& module, const std::string& method_name);
304 
305   // analyze the graph and record necessary information that can
306   // be used when inserting observers
307   void analyze(Module& module, const std::string& method_name);
308 
309   void removeActivationObservers();
310 
311   /**
312    * Recursively insert observers for the method. We also process
313    * the nodes in the graph in the order in which they execute,
314    * since we need the context information to decide whether we want to
315    * observe/quantize a value or not, and we don't want to observe a value multiple
316    * times.
317    *
318    * argument: is_entry_point means whether the current method is the forward
319    * method of the top level module.
320    *
321    * Since we want to insert observers at the call site instead of in the called
322    * graph, we'll postpone inserting observers to the caller as much as possible. If
323    * we know the current method is the outermost method, then
324    * we will insert all observers in the graph instead of postponing this to the
325    * parent; note that this assumes we don't have recursive method
326    * calls.
327    *
328    * Returns a tuple of vectors of observer modules for inputs and outputs; these
329    * are used for inserting observers for the input/output values,
330    * since we need to insert observers for these values at the call site.
331    * It also returns a vector of output indexes that indicates whether an output
332    * value is already observed or not; this is used for propagating the observed
333    * property of a value through CallMethods, because we should skip inserting
334    * observers for ops that don't require observation.
335    */
336   std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
337   insertObservers(
338       Module& module,
339       const std::string& method_name,
340       bool is_entry_point = false,
341       std::unordered_set<Value*> graph_observed_values =
342           std::unordered_set<Value*>());
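  // Example of the returned tuple (illustrative; the shapes are hypothetical):
  // for a method with two tensor inputs and one output, where only the first
  // input needs an observer at the call site and the output is already
  // observed inside the callee, the result would look like
  //   ({input0_observer, std::nullopt}, {std::nullopt}, {0})
  // i.e. per-input observers, per-output observers, and the indexes of the
  // outputs that are already observed.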
343 
344   void setInsertResetObserverMethod(
345       bool insert_reset_observer_method,
346       const std::string& method_name) {
347     insert_reset_observer_method_ = insert_reset_observer_method;
348     reset_observer_method_name_ = "reset_observers_" + method_name;
349   }
350 
351  private:
352   std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
353   insertObserversFor(
354       Block* block,
355       script::Module& module,
356       // this is a reference because when we insert an observer for a value
357       // in one block it is also observed in other blocks, and we don't want to
358       // insert multiple observers for the same value
359       std::unordered_set<Value*>& block_observed_values,
360       bool is_entry_point = false,
361       bool is_user_defined_function = false);
362 
363   // Record v as "ready for observation" by storing it in values_to_observe.
364   // If v is a part of a delayed observation pattern, record v's descendant
365   // (per delay rules) instead. The observers are inserted at a later stage
366   // by reading the state created by this function.
367   void recordObserved(
368       Value* v,
369       const Module& observer_module,
370       std::unordered_map<Value*, Module>& values_to_observe,
371       std::unordered_set<Value*>& block_observed_values);
372 
373   ModuleMethodVector getInvokedMethods(
374       Module& module,
375       const std::string& method_name);
376 
377   bool valueNeedsToBeQuantized(Value* v, const QConfig& qconfig);
378 
379   bool isObserved(
380       Value* v,
381       const std::unordered_set<Value*>& block_observed_values) {
382     return block_observed_values.count(v) || observed_values_.count(v);
383   }
384 
385   // Fill the map from value to the corresponding observer module
386   // this map is used in insertObservers to actually insert
387   // observers to the module
388   void fillValueObserverMap(Module& module, const std::string& method_name);
389 
390   // Clone observer module and add it to the original module,
391   // and insert a call to observer forward function
392   void insertObserverFor(
393       Value* v,
394       Module& module,
395       const Module& observer_module,
396       NameModuleVector& observer_name_and_modules);
397 
398   void insertObserverResetMinMax(
399       Module& module,
400       const NameModuleVector& observer_name_and_modules);
401 
402   // Uses the state created by fillBoundaryValueMap and fillValueObserverMap
403   // to return an observer configured for a value, if it is needed.
404   std::optional<Module> getObserverFor(Value* v);
405 
406   // Uses the state created by fillPassThroughValueMap to propagate the observed
407   // property, which should pass through from inputs to outputs.
408   void propagateObservedProperty(
409       Value* output,
410       std::unordered_set<Value*>& block_observed_values);
411 
412   // for cat/add/mul we will only observe their output if their inputs
413   // are observed
414   bool shouldObserve(
415       Node* n,
416       const std::unordered_set<Value*>& block_observed_values,
417       QuantType quant_type) {
418     // Check whether node output uses can be quantized, eg cat followed by
419     // linear op
420     for (Value* v : n->outputs()) {
421       for (const auto& use : v->uses()) {
422         if (useQuantizable(use, quant_type)) {
423           return true;
424         }
425       }
426     }
427     if (isPropagateQuantSingleInputOp(n)) {
428       return isObserved(n->input(0), block_observed_values);
429     } else if (isPropagateQuantBinaryOp(n)) {
430       // This checks that both of the inputs are Tensors and observed.
431       // There is one check that we didn't do here, which is
432       // !isScalar(isObserved(n->input(1), block_observed_values))
433       // to make sure input 1 is not a scalar; because a scalar tensor input
434       // for add/mul won't be observed under the current rules, we can omit
435       // this check here.
436       return isObserved(n->input(0), block_observed_values) &&
437           isObserved(n->input(1), block_observed_values);
438     }
439     return true;
440   }
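  // Example (illustrative): the output of an aten::cat that feeds aten::linear
  // is observed because that use is quantizable; an aten::add output with no
  // quantizable uses is observed only if both of its inputs are already
  // observed, since add propagates quantization from its inputs.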
441 
442   void delayObservingValuesInPattern(Graph& graph, const PatternInfo& pattern);
443 
444   // Find and mark known patterns such as conv-relu (and others) where
445   // we should not insert observers in the middle of the pattern.
446   void addValuesToDelayObservation(
447       const Module& module,
448       const std::string& method_name);
449 
450   // Fill the map from values to the list of values that can pass the observed
451   // property to it
452   void fillPassThroughValueMap(const std::shared_ptr<Graph>& graph);
453 
454   bool insertResetObserverMethod() {
455     return insert_reset_observer_method_;
456   }
457 
458   const ModuleQConfigMap& module_qconfig_map_;
459 
460   // Values for which we want to delay observation, used to delay observing
461   // values in the middle of ops that are supposed to be fused, e.g.
462   // the output value of conv in the conv - relu pattern.
463   // The key is the intermediate output, e.g. the output of conv;
464   // the value is the value we want to observe, e.g. the output of relu.
465   //
466   // example, assuming we want to delay conv-relu:
467   //   %x1 = conv(%x0)
468   //   %x2 = relu(%x1)
469   //
470   // delay_observation_map_ = {
471   //   %x1: %x2,
472   // }
473   std::unordered_map<Value*, Value*> delay_observation_map_;
474 
475   std::unordered_set<Graph*> visited_graph_of_observer_map_;
476 
477   // Map of value to observer module configured for that value.
478   std::unordered_map<Value*, Module> observer_for_value_;
479 
480   // Map from values at the call site to the values in the CallMethod graph.
481   // The key of the map is the value from the caller graph, and the value of the map
482   // is the set of values in the callee graph (the graph
483   // corresponding to the called method);
484   // the reason it is a set is that a value in the caller graph
485   // can correspond both to the output of one callee graph and to the input of another
486   // callee graph.
487   //
488   // example:
489   //   // top level module
490   //   %x1 = conv(%x0)
491   //   %x2 = prim::CallFunction(%foo, %x1)
492   //
493   //   // graph of %foo
494   //   %y2 = conv(%y1)
495   //   return %y2
496   //
497   // boundary_value_map = {
498   //   // current module's output value to the return value of the subgraph
499   //   %x2: %y2,
500   //   // current module's input value to the input value of the subgraph
501   //   %x1: %y1,
502   // }
503   std::unordered_map<Value*, std::unordered_set<Value*>> boundary_value_map_;
504 
505   std::unordered_set<Value*> observed_values_;
506 
507   // This is used for observed values to pass through ops like flatten,
508   // so that the output value of flatten does not need to be observed.
509   // The key is the output of the op; the value is a vector of values that need
510   // to be observed in order to pass the observed property to the output
511   //
512   // example:
513   //   %x1 = flatten(%x0) // pass_through
514   //   %x2 = conv(%x1) // not pass_through
515   //
516   // pass_through_value_map_ = {
517   //   %x1: [%x0],
518   // }
519   std::unordered_map<Value*, std::vector<Value*>> pass_through_value_map_;
520 
521   // Unique id generator for observer modules, used for generating
522   // unique observer names when we insert an observer module. We
523   // record the current unique id used, to avoid incrementing from 0
524   // every time to find a unique id.
525   int uid_ = 0;
526   // Set of observer forward call nodes
527   std::unordered_set<Node*> observer_nodes_;
528   // Map from block to a vector of observer name and observer modules we
529   // want to add to the module instance that has the block
530   std::unordered_map<Block*, NameModuleVector> block_observer_map_;
531 
532   // Type of quantization for this pass.
533   QuantType quant_type_ = QuantType::STATIC;
534   // These are the IR patterns we match to skip inserting observers.
535   // They are compiled once on construction and used repeatedly within
536   // the pass.
537 
538   // nn.Linear + nn.ReLU
539   const PatternInfo nn_linear_nn_relu = PatternInfo::parse_from_str(
540       R"(
541 graph(%input, %linear, %relu):
542     %first_output = prim::CallMethod[name="forward"](%linear, %input)
543     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
544     return (%second_output) )",
545       {is_linear_module, is_relu_module});
546 
547   // nn.Linear + F.relu
548   const PatternInfo nn_linear_f_relu = PatternInfo::parse_from_str(
549       R"(
550 graph(%input, %linear, %relu, %inplace):
551     %first_output = prim::CallMethod[name="forward"](%linear, %input)
552     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
553     return (%second_output) )",
554       {is_linear_module, is_functional_relu});
555 
556   // nn.Linear + aten::relu
557   const PatternInfo nn_linear_aten_relu = PatternInfo::parse_from_str(
558       R"(
559 graph(%input, %linear, %relu):
560     %first_output = prim::CallMethod[name="forward"](%linear, %input)
561     %second_output = aten::relu(%first_output)
562     return (%second_output) )",
563       {is_linear_module});
564 
565   // nn.Linear + aten::relu_
566   const PatternInfo nn_linear_aten_relu_ = PatternInfo::parse_from_str(
567       R"(
568 graph(%input, %linear, %relu):
569     %first_output = prim::CallMethod[name="forward"](%linear, %input)
570     %second_output = aten::relu_(%first_output)
571     return (%second_output) )",
572       {is_linear_module});
573 
574   // aten::linear + nn.ReLU
575   const PatternInfo aten_linear_nn_relu = PatternInfo::parse_from_str(
576       R"(
577 graph(%input, %weight, %bias, %relu):
578     %first_output = aten::linear(%input, %weight, %bias)
579     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
580     return (%second_output) )",
581       {is_relu_module});
582 
583   // aten::linear + F.relu
584   const PatternInfo aten_linear_f_relu = PatternInfo::parse_from_str(
585       R"(
586 graph(%input, %weight, %bias, %relu, %inplace):
587     %first_output = aten::linear(%input, %weight, %bias)
588     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
589     return (%second_output) )",
590       {is_functional_relu});
591 
592   // aten::linear + aten::relu
593   const PatternInfo aten_linear_aten_relu = PatternInfo::parse_from_str(
594       R"(
595 graph(%input, %weight, %bias):
596     %first_output = aten::linear(%input, %weight, %bias)
597     %second_output = aten::relu(%first_output)
598     return (%second_output) )");
599 
600   // aten::linear + aten::relu_
601   const PatternInfo aten_linear_aten_relu_ = PatternInfo::parse_from_str(
602       R"(
603 graph(%input, %weight, %bias):
604     %first_output = aten::linear(%input, %weight, %bias)
605     %second_output = aten::relu_(%first_output)
606     return (%second_output) )");
607 
608   const PatternInfo nn_conv1d_f_relu = PatternInfo::parse_from_str(
609       R"(
610 graph(%self, %input, %conv, %relu, %inplace):
611     %first_output = prim::CallMethod[name="forward"](%conv, %input)
612     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
613     return (%second_output) )",
614       {is_conv1d_module, is_functional_relu});
615 
616   const PatternInfo nn_conv1d_nn_relu = PatternInfo::parse_from_str(
617       R"(
618 graph(%self, %input, %conv, %relu):
619     %first_output = prim::CallMethod[name="forward"](%conv, %input)
620     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
621     return (%second_output) )",
622       {is_conv1d_module, is_relu_module});
623 
624   const PatternInfo nn_conv1d_aten_relu = PatternInfo::parse_from_str(
625       R"(
626 graph(%self, %input, %conv):
627     %first_output = prim::CallMethod[name="forward"](%conv, %input)
628     %second_output = aten::relu(%first_output)
629     return (%second_output) )",
630       {is_conv1d_module});
631 
632   const PatternInfo nn_conv1d_aten_relu_ = PatternInfo::parse_from_str(
633       R"(
634 graph(%self, %input, %conv):
635     %first_output = prim::CallMethod[name="forward"](%conv, %input)
636     %second_output = aten::relu_(%first_output)
637     return (%second_output) )",
638       {is_conv1d_module});
639 
640   const PatternInfo nn_conv2d_f_relu = PatternInfo::parse_from_str(
641       R"(
642 graph(%self, %input, %conv, %relu, %inplace):
643     %first_output = prim::CallMethod[name="forward"](%conv, %input)
644     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
645     return (%second_output) )",
646       {is_conv2d_module, is_functional_relu});
647 
648   const PatternInfo nn_conv2d_nn_relu = PatternInfo::parse_from_str(
649       R"(
650 graph(%self, %input, %conv, %relu):
651     %first_output = prim::CallMethod[name="forward"](%conv, %input)
652     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
653     return (%second_output) )",
654       {is_conv2d_module, is_relu_module});
655 
656   const PatternInfo nn_conv2d_aten_relu = PatternInfo::parse_from_str(
657       R"(
658 graph(%self, %input, %conv):
659     %first_output = prim::CallMethod[name="forward"](%conv, %input)
660     %second_output = aten::relu(%first_output)
661     return (%second_output) )",
662       {is_conv2d_module});
663 
664   const PatternInfo nn_conv2d_aten_relu_ = PatternInfo::parse_from_str(
665       R"(
666 graph(%self, %input, %conv):
667     %first_output = prim::CallMethod[name="forward"](%conv, %input)
668     %second_output = aten::relu_(%first_output)
669     return (%second_output) )",
670       {is_conv2d_module});
671 
672   const PatternInfo nn_conv3d_f_relu = PatternInfo::parse_from_str(
673       R"(
674 graph(%self, %input, %conv, %relu, %inplace):
675     %first_output = prim::CallMethod[name="forward"](%conv, %input)
676     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
677     return (%second_output) )",
678       {is_conv3d_module, is_functional_relu});
679 
680   const PatternInfo nn_conv3d_nn_relu = PatternInfo::parse_from_str(
681       R"(
682 graph(%self, %input, %conv, %relu):
683     %first_output = prim::CallMethod[name="forward"](%conv, %input)
684     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
685     return (%second_output) )",
686       {is_conv3d_module, is_relu_module});
687 
688   const PatternInfo nn_conv3d_aten_relu = PatternInfo::parse_from_str(
689       R"(
690 graph(%self, %conv, %input):
691     %first_output = prim::CallMethod[name="forward"](%conv, %input)
692     %second_output = aten::relu(%first_output)
693     return (%second_output) )",
694       {is_conv3d_module});
695 
696   const PatternInfo nn_conv3d_aten_relu_ = PatternInfo::parse_from_str(
697       R"(
698 graph(%self, %input, %conv):
699     %first_output = prim::CallMethod[name="forward"](%conv, %input)
700     %second_output = aten::relu_(%first_output)
701     return (%second_output) )",
702       {is_conv3d_module});
703 
704   const PatternInfo add_nn_relu = PatternInfo::parse_from_str(
705       R"(
706 graph(%self, %a, %b, %alpha, %relu):
707      %first_output = aten::add(%a, %b, %alpha)
708      %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
709      return (%second_output) )",
710       {aten_add_alpha_is_one, is_relu_module});
711 
712   const PatternInfo add_f_relu = PatternInfo::parse_from_str(
713       R"(
714 graph(%self, %a, %b, %alpha, %relu, %inplace):
715      %first_output = aten::add(%a, %b, %alpha)
716      %second_output = prim::CallFunction(%relu, %first_output, %inplace)
717      return (%second_output) )",
718       {aten_add_alpha_is_one, is_functional_relu});
719 
720   const PatternInfo inplace_add_nn_relu = PatternInfo::parse_from_str(
721       R"(
722 graph(%self, %a, %b, %alpha, %relu):
723      %first_output = aten::add_(%a, %b, %alpha)
724      %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
725      return (%second_output) )",
726       {aten_add_alpha_is_one, is_relu_module});
727 
728   const PatternInfo inplace_add_f_relu = PatternInfo::parse_from_str(
729       R"(
730 graph(%self, %a, %b, %alpha, %relu, %inplace):
731      %first_output = aten::add_(%a, %b, %alpha)
732      %second_output = prim::CallFunction(%relu, %first_output, %inplace)
733      return (%second_output) )",
734       {aten_add_alpha_is_one, is_functional_relu});
735 
736   const PatternInfo add_aten_relu = PatternInfo::parse_from_str(R"(
737 graph(%self, %a, %b, %alpha):
738      %first_output = aten::add(%a, %b, %alpha)
739      %second_output = aten::relu(%first_output)
740      return (%second_output) )");
741 
742   const PatternInfo add_aten_relu_ = PatternInfo::parse_from_str(R"(
743 graph(%self, %a, %b, %alpha):
744      %first_output = aten::add(%a, %b, %alpha)
745      %second_output = aten::relu_(%first_output)
746      return (%second_output) )");
747 
748   const PatternInfo inplace_add_aten_relu = PatternInfo::parse_from_str(R"(
749 graph(%self, %a, %b, %alpha):
750      %first_output = aten::add_(%a, %b, %alpha)
751      %second_output = aten::relu(%first_output)
752      return (%second_output) )");
753 
754   const PatternInfo inplace_add_aten_relu_ = PatternInfo::parse_from_str(R"(
755 graph(%self, %a, %b, %alpha):
756      %first_output = aten::add_(%a, %b, %alpha)
757      %second_output = aten::relu_(%first_output)
758      return (%second_output) )");
759 
760   const PatternInfo nn_bn2d_nn_relu = PatternInfo::parse_from_str(
761       R"(
762 graph(%self, %input, %batchnorm, %relu):
763     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
764     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
765     return (%second_output) )",
766       {is_batchnorm2d_module, is_relu_module});
767 
768   const PatternInfo nn_bn2d_f_relu = PatternInfo::parse_from_str(
769       R"(
770 graph(%self, %input, %batchnorm, %relu, %inplace):
771     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
772     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
773     return (%second_output) )",
774       {is_batchnorm2d_module, is_functional_relu});
775 
776   const PatternInfo nn_bn2d_aten_relu = PatternInfo::parse_from_str(
777       R"(
778 graph(%self, %input, %batchnorm):
779     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
780     %second_output = aten::relu(%first_output)
781     return (%second_output) )",
782       {is_batchnorm2d_module});
783 
784   const PatternInfo nn_bn2d_aten_relu_ = PatternInfo::parse_from_str(
785       R"(
786 graph(%self, %input, %batchnorm):
787     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
788     %second_output = aten::relu_(%first_output)
789     return (%second_output) )",
790       {is_batchnorm2d_module});
791 
792   const PatternInfo nn_bn3d_nn_relu = PatternInfo::parse_from_str(
793       R"(
794 graph(%self, %input, %batchnorm, %relu):
795     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
796     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
797     return (%second_output) )",
798       {is_batchnorm3d_module, is_relu_module});
799 
800   const PatternInfo nn_bn3d_f_relu = PatternInfo::parse_from_str(
801       R"(
802 graph(%self, %input, %batchnorm, %relu, %inplace):
803     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
804     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
805     return (%second_output) )",
806       {is_batchnorm3d_module, is_functional_relu});
807 
808   const PatternInfo nn_bn3d_aten_relu = PatternInfo::parse_from_str(
809       R"(
810 graph(%self, %input, %batchnorm):
811     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
812     %second_output = aten::relu(%first_output)
813     return (%second_output) )",
814       {is_batchnorm3d_module});
815 
816   const PatternInfo nn_bn3d_aten_relu_ = PatternInfo::parse_from_str(
817       R"(
818 graph(%self, %input, %batchnorm):
819     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
820     %second_output = aten::relu_(%first_output)
821     return (%second_output) )",
822       {is_batchnorm3d_module});
823 
824   const PatternInfo mul_nn_relu = PatternInfo::parse_from_str(
825       R"(
826 graph(%self, %a, %b, %relu):
827      %first_output = aten::mul(%a, %b)
828      %second_output = prim::CallMethod[name="forward"](%relu, %first_output)
829      return (%second_output) )",
830       {is_relu_module});
831 
832   const PatternInfo mul_f_relu = PatternInfo::parse_from_str(
833       R"(
834 graph(%self, %a, %b, %relu, %inplace):
835      %first_output = aten::mul(%a, %b)
836      %second_output = prim::CallFunction(%relu, %first_output, %inplace)
837      return (%second_output) )",
838       {is_functional_relu});
839 
840   const PatternInfo inplace_mul_nn_relu = PatternInfo::parse_from_str(
841       R"(
842 graph(%self, %a, %b, %relu):
843      %first_output = aten::mul_(%a, %b)
844      %second_output = prim::CallMethod[name="forward"](%relu, %first_output)
845      return (%second_output) )",
846       {is_relu_module});
847 
848   const PatternInfo inplace_mul_f_relu = PatternInfo::parse_from_str(
849       R"(
850 graph(%self, %a, %b, %relu, %inplace):
851      %first_output = aten::mul_(%a, %b)
852      %second_output = prim::CallFunction(%relu, %first_output, %inplace)
853      return (%second_output) )",
854       {is_functional_relu});
855 
856   const PatternInfo mul_aten_relu = PatternInfo::parse_from_str(R"(
857 graph(%self, %a, %b):
858      %first_output = aten::mul(%a, %b)
859      %second_output = aten::relu(%first_output)
860      return (%second_output) )");
861 
862   const PatternInfo mul_aten_relu_ = PatternInfo::parse_from_str(R"(
863 graph(%self, %a, %b):
864      %first_output = aten::mul(%a, %b)
865      %second_output = aten::relu_(%first_output)
866      return (%second_output) )");
867 
868   const PatternInfo inplace_mul_aten_relu = PatternInfo::parse_from_str(R"(
869 graph(%self, %a, %b):
870      %first_output = aten::mul_(%a, %b)
871      %second_output = aten::relu(%first_output)
872      return (%second_output) )");
873 
874   const PatternInfo inplace_mul_aten_relu_ = PatternInfo::parse_from_str(R"(
875 graph(%self, %a, %b):
876      %first_output = aten::mul_(%a, %b)
877      %second_output = aten::relu_(%first_output)
878      return (%second_output) )");
879 
880   const std::vector<std::reference_wrapper<const PatternInfo>> delay_patterns =
881       {
882           nn_linear_f_relu,      nn_linear_nn_relu,
883           nn_linear_aten_relu,   nn_linear_aten_relu_,
884           aten_linear_f_relu,    aten_linear_nn_relu,
885           aten_linear_aten_relu, aten_linear_aten_relu_,
886 
887           nn_conv1d_f_relu,      nn_conv1d_nn_relu,
888           nn_conv1d_aten_relu,   nn_conv1d_aten_relu_,
889           nn_conv2d_f_relu,      nn_conv2d_nn_relu,
890           nn_conv2d_aten_relu,   nn_conv2d_aten_relu_,
891           nn_conv3d_f_relu,      nn_conv3d_nn_relu,
892           nn_conv3d_aten_relu,   nn_conv3d_aten_relu_,
893 
894           add_nn_relu,           add_f_relu,
895           inplace_add_nn_relu,   inplace_add_f_relu,
896           add_aten_relu,         add_aten_relu_,
897           inplace_add_aten_relu, inplace_add_aten_relu_,
898 
899           nn_bn2d_nn_relu,       nn_bn2d_f_relu,
900           nn_bn2d_aten_relu,     nn_bn2d_aten_relu_,
901           nn_bn3d_nn_relu,       nn_bn3d_f_relu,
902           nn_bn3d_aten_relu,     nn_bn3d_aten_relu_,
903 
904           mul_nn_relu,           mul_f_relu,
905           inplace_mul_nn_relu,   inplace_mul_f_relu,
906           mul_aten_relu,         mul_aten_relu_,
907           inplace_mul_aten_relu, inplace_mul_aten_relu_,
908   };
909 
910   bool insert_reset_observer_method_{false};
911   std::string reset_observer_method_name_;
912 };
913 
914 ModuleMethodVector InsertObserversHelper::getInvokedMethods(
915     Module& module,
916     const std::string& method_name) {
917   ModuleMethodVector invoked_methods;
918   Method method = module.get_method(method_name);
919   auto graph = method.graph();
920 
921   std::stack<Block*> blocks_to_visit;
922   blocks_to_visit.push(graph->block());
923   while (!blocks_to_visit.empty()) {
924     Block* b = blocks_to_visit.top();
925     blocks_to_visit.pop();
926     for (Node* n : b->nodes()) {
927       // Skip observer nodes
928       if (observer_nodes_.count(n)) {
929         continue;
930       }
931       if (n->kind() == prim::CallMethod) {
932         auto m_opt = getInvokedModuleOpt(module, n, graph->inputs()[0]);
933         if (m_opt.has_value()) {
934           invoked_methods.emplace_back(*m_opt, n->s(attr::name));
935         }
936       }
937 
938       for (Block* subblock : n->blocks()) {
939         blocks_to_visit.push(subblock);
940       }
941     }
942   }
943   return invoked_methods;
944 }
945 
946 void InsertObserversHelper::insertObserverFor(
947     Value* v,
948     Module& module,
949     const Module& observer_module,
950     NameModuleVector& observer_name_and_modules) {
951   if (observed_values_.count(v)) {
952     return;
953   }
954   GRAPH_DEBUG("Inserting observer for:", v->debugName());
955   Module observer = observer_module.deepcopy();
956   std::string observer_name = "_observer_" + std::to_string(uid_++);
957   while (module.hasattr(observer_name)) {
958     observer_name = "_observer_" + std::to_string(uid_++);
959   }
960   module.register_module(observer_name, observer);
961   observer_name_and_modules.emplace_back(observer_name, observer);
962 
963   auto* g = v->owningGraph();
964   // Get handle of observer module
965   Node* observer_instance =
966       g->createGetAttr(g->inputs()[0], observer_name)->insertAfter(v->node());
967   observer_instance->output()->setDebugName(observer_name);
968 
969   {
970     WithInsertPoint guard(observer_instance->next());
971     // Match arguments to types of observer's arguments
972     MatchedSchema forward_matched_schema = matchSchema(
973         observer.get_method("forward").function().getSchema(),
974         v->node()->sourceRange(),
975         *g,
976         {observer_instance->output(), v},
977         {});
978     // Insert call to observer's forward
979     Node* call = g->insertMethodCall("forward", forward_matched_schema)->node();
980     call->output()->copyMetadata(v);
981 
982     // Replace v with the output of observer
983     v->replaceAllUsesWith(call->output());
984     // The above also replaced the input to `call`, so switch it back to
985     // the correct value
986     call->replaceInput(1, v);
987     observer_nodes_.emplace(call);
988     observed_values_.insert(call->output());
989   }
990 }
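// Illustrative before/after sketch of the rewrite performed above (value and
// attribute names are hypothetical):
//
//   before:  %x = aten::conv2d(...)
//            %y = aten::relu(%x)
//
//   after:   %x = aten::conv2d(...)
//            %obs = prim::GetAttr[name="_observer_0"](%self)
//            %x_observed = prim::CallMethod[name="forward"](%obs, %x)
//            %y = aten::relu(%x_observed)
//
// All uses of %x are redirected to the observer's output, while the observer
// call itself keeps the original %x as its input.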
991 
992 void InsertObserversHelper::insertObserverResetMinMax(
993     Module& module,
994     const NameModuleVector& observer_name_and_modules) {
995   if (observer_name_and_modules.empty()) {
996     return;
997   }
998   auto reset_min_max_opt = module.find_method(reset_observer_method_name_);
999   if (!reset_min_max_opt.has_value()) {
1000     std::shared_ptr<Graph> reset_observer_graph = std::make_shared<Graph>();
1001     Value* module_value = reset_observer_graph->addInput("self");
1002     Node* output_node = reset_observer_graph->createNone();
1003     reset_observer_graph->insertNode(output_node);
1004     reset_observer_graph->registerOutput(output_node->output());
1005     module_value->setType(module._ivalue()->type());
1006     const auto method_name = c10::QualifiedName(
1007         *(module.type()->name()), reset_observer_method_name_);
1008     auto reset_observer_fn =
1009         module._ivalue()->compilation_unit()->create_function(
1010             method_name, std::move(reset_observer_graph));
1011     auto self_arg = c10::Argument("self", module.type());
1012     auto output_arg = c10::Argument("none", output_node->output()->type());
1013     auto schema = c10::FunctionSchema(
1014         reset_observer_method_name_,
1015         "",
1016         {std::move(self_arg)},
1017         {std::move(output_arg)});
1018     reset_observer_fn->setSchema(std::move(schema));
1019     module.type()->addMethod(reset_observer_fn);
1020   }
1021   auto reset_min_max_graph =
1022       module.get_method(reset_observer_method_name_).graph();
1023   Value* self = reset_min_max_graph->inputs()[0];
1024 
1025   for (const auto& pair : observer_name_and_modules) {
1026     const auto& observer_name = pair.first;
1027     const auto& observer = pair.second;
1028     Value* observer_value =
1029         reset_min_max_graph->insertGetAttr(self, observer_name);
1030     MatchedSchema reset_minmax_schema = matchSchema(
1031         observer.get_method("reset_min_max_vals").function().getSchema(),
1032         observer_value->node()->sourceRange(),
1033         *reset_min_max_graph,
1034         {observer_value},
1035         {});
1036     reset_min_max_graph->insertMethodCall(
1037         "reset_min_max_vals", reset_minmax_schema);
1038   }
1039 }
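// Sketch of the generated reset method (TorchScript-like pseudocode; observer
// names are hypothetical) after two observers have been registered for the
// "forward" method:
//
//   def reset_observers_forward(self):
//       self._observer_0.reset_min_max_vals()
//       self._observer_1.reset_min_max_vals()
//       return None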
1040 
1041 void InsertObserversHelper::delayObservingValuesInPattern(
1042     Graph& graph,
1043     const PatternInfo& pattern) {
1044   const Graph& pattern_graph = *pattern.pattern_graph;
1045   const std::unordered_map<std::string, Value*>& vmap = pattern.vmap;
1046 
1047   const auto& matches = findPatternMatches(pattern_graph, graph);
1048   for (const auto& match : matches) {
1049     if (!std::all_of(
1050             pattern.filters.begin(),
1051             pattern.filters.end(),
1052             [&](const MatchFilter& f) { return f(match, vmap); })) {
1053       continue;
1054     }
1055     auto first_output = match.values_map.at(vmap.at("first_output"));
1056     auto second_output = match.values_map.at(vmap.at("second_output"));
1057     GRAPH_DEBUG(
1058         "Delay observation for value in function pattern:",
1059         first_output->debugName(),
1060         " to ",
1061         second_output->debugName());
1062     delay_observation_map_[first_output] = second_output;
1063   }
1064 }
1065 
1066 void InsertObserversHelper::addValuesToDelayObservation(
1067     const Module& module,
1068     const std::string& method_name) {
1069   Method method = module.get_method(method_name);
1070   auto graph = method.graph();
1071 
1072   for (const auto& pattern : delay_patterns) {
1073     delayObservingValuesInPattern(*graph, pattern);
1074   }
1075 }
1076 
1077 void InsertObserversHelper::fillPassThroughValueMap(
1078     const std::shared_ptr<Graph>& graph) {
1079   std::stack<Block*> blocks_to_visit;
1080   blocks_to_visit.push(graph->block());
1081   while (!blocks_to_visit.empty()) {
1082     Block* b = blocks_to_visit.top();
1083     blocks_to_visit.pop();
1084     for (Node* n : b->nodes()) {
1085       if (userDefinedCallFunction(n)) {
1086         auto g = getCallFunctionGraph(n);
1087         blocks_to_visit.push(g->block());
1088       }
1089       for (auto* output : n->outputs()) {
1090         for (auto* input : getPassThroughInputs(output)) {
1091           pass_through_value_map_[output].push_back(input);
1092         }
1093       }
1094       for (Block* subblock : n->blocks()) {
1095         blocks_to_visit.push(subblock);
1096       }
1097     }
1098   }
1099 }
1100 
1101 void InsertObserversHelper::fillBoundaryValueMap(
1102     Module& module,
1103     const std::string& method_name) {
1104   for (auto& invoked_method : getInvokedMethods(module, method_name)) {
1105     auto& invoked_module = std::get<0>(invoked_method);
1106     const auto& invoked_method_name = std::get<1>(invoked_method);
1107     fillBoundaryValueMap(invoked_module, invoked_method_name);
1108   }
1109 
1110   auto graph = module.get_method(method_name).graph();
1111   std::stack<Block*> blocks_to_visit;
1112   blocks_to_visit.push(graph->block());
1113   auto* self = graph->inputs()[0];
1114   while (!blocks_to_visit.empty()) {
1115     Block* b = blocks_to_visit.top();
1116     blocks_to_visit.pop();
1117     for (Node* n : b->nodes()) {
1118       if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
1119         std::shared_ptr<Graph> g;
1120         // offset of inputs for the caller node, since the first
1121         // input of CallFunction is the function node while the graph
1122         // for CallFunction starts with the actual inputs
1123         size_t input_offset = 0;
1124         if (n->kind() == prim::CallMethod) {
1125           auto m_opt = getInvokedModuleOpt(module, n, self);
1126           if (!m_opt.has_value()) {
1127             continue;
1128           }
1129           auto m = *m_opt;
1130           g = m.get_method(n->s(attr::name)).graph();
1131           input_offset = 0;
1132         } else {
1133           g = getCallFunctionGraph(n);
1134           input_offset = 1;
1135         }
1136         // add mapping from callsite value to value in called graph
1137         for (auto i = 0U; i < g->outputs().size(); ++i) {
1138           auto* return_val = g->outputs()[i];
1139           GRAPH_DEBUG(
1140               "Boundary Map[return]:",
1141               n->output(i)->debugName(),
1142               " -> ",
1143               return_val->debugName());
1144           boundary_value_map_[n->output(i)].insert(return_val);
1145         }
1146         for (auto i = 0U; i < g->inputs().size(); ++i) {
1147           auto caller_input_index = i + input_offset;
1148           auto* caller_input = n->input(caller_input_index);
1149           auto* input_val = g->inputs()[i];
1150           GRAPH_DEBUG(
1151               "Boundary Map[input]:",
1152               caller_input->debugName(),
1153               " -> ",
1154               input_val->debugName());
1155           boundary_value_map_[caller_input].insert(input_val);
1156         }
1157       } else if (n->kind() == prim::If) {
1158         for (Block* subblock : n->blocks()) {
1159           blocks_to_visit.push(subblock);
1160           for (Value* v : n->outputs()) {
1161             Value* subblock_output = subblock->outputs()[v->offset()];
1162             GRAPH_DEBUG(
1163                 "Boundary Map[if_output]:",
1164                 v->debugName(),
1165                 " -> ",
1166                 subblock_output->debugName());
1167             boundary_value_map_[v].insert(subblock_output);
1168           }
1169         }
1170       } else {
1171         for (Block* subblock : n->blocks()) {
1172           blocks_to_visit.push(subblock);
1173         }
1174       }
1175     }
1176   }
1177 }
1178 
1179 void InsertObserversHelper::preprocess(
1180     Module& module,
1181     const std::string& method_name) {
1182   // run preprocess for child module before parent, since preprocess
1183   // mutates the graph and it might affect passes like fillBoundaryValueMap
1184   for (auto& invoked_method : getInvokedMethods(module, method_name)) {
1185     auto& invoked_module = std::get<0>(invoked_method);
1186     const auto& invoked_method_name = std::get<1>(invoked_method);
1187     preprocess(invoked_module, invoked_method_name);
1188   }
1189 
1190   Method method = module.get_method(method_name);
1191   auto graph = method.graph();
1192   // Inline fork-wait calls
1193   InlineForkWait(graph);
1194   // fuse decomposed linear into aten::linear
1195   FuseLinear(graph);
1196   replaceConvolutionWithAtenConv(graph);
1197   RemoveListMutation(graph);
1198 }
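// Note: FuseLinear rewrites decomposed addmm/matmul + add patterns back into
// aten::linear, and replaceConvolutionWithAtenConv lowers aten::_convolution
// into specific aten::conv ops (e.g. aten::conv2d), so the quantization
// patterns defined above only need to match the canonical forms.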
1199 
1200 void InsertObserversHelper::analyze(
1201     Module& module,
1202     const std::string& method_name) {
1203   for (auto& invoked_method : getInvokedMethods(module, method_name)) {
1204     auto& invoked_module = std::get<0>(invoked_method);
1205     const auto& invoked_method_name = std::get<1>(invoked_method);
1206     analyze(invoked_module, invoked_method_name);
1207   }
1208 
1209   // fill out various internal state which will be later used in
1210   // insertObservers to insert the correct observer
1211   addValuesToDelayObservation(module, method_name);
1212   fillValueObserverMap(module, method_name);
1213   Method method = module.get_method(method_name);
1214   auto graph = method.graph();
1215   fillPassThroughValueMap(graph);
1216 }
1217 
1218 bool InsertObserversHelper::valueNeedsToBeQuantized(
1219     Value* v,
1220     const QConfig& qconfig) {
1221   if (isBiasOfConvOrLinear(v) ||
1222       !(v->type()->isSubtypeOf(*TensorType::get()) ||
1223         v->type()->isSubtypeOf(*ListType::ofTensors())) ||
1224       isEmbeddingBagNonInput(v)) {
1225     return false;
1226   }
1227   // For dynamic quantization we only insert observers at the input
1228   // of the quantizable function.
1229   if (quant_type_ == QuantType::STATIC) {
1230     // Check whether producer is quantizable
1231     if (!isWeightOnlyStaticQuantOp(v->node()) &&
1232         (nodeQuantizable(v->node()) || isPropagateQuantOp(v->node()))) {
1233       return true;
1234     }
1235   }
1236   if (quant_type_ == QuantType::DYNAMIC) {
1237     // Check the dtype of the observer module.
1238     Module observer_module = getObserverModuleFor(v, qconfig);
1239     auto scalar_type = observer_module.attr("dtype");
1240     // For inputs with Fp16 type that are not weights, we don't observe them for
1241     // dynamic quantization.
1242     if (scalar_type == at::ScalarType::Half && !isWeight(v)) {
1243       return false;
1244     }
1245   }
1246   // Check whether node input value is quantizable
1247   for (const auto& use : v->uses()) {
1248     if (useQuantizable(use, quant_type_)) {
1249       return true;
1250     }
1251   }
1252   return false;
1253 }
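// Example decisions (illustrative): the bias of a conv/linear is never
// observed; under static quantization the output of a quantizable producer
// such as aten::conv2d is observed; under dynamic quantization a value is
// observed only when one of its uses is dynamically quantizable (e.g. the
// input of aten::linear), and fp16 non-weight inputs are skipped.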
1254 
1255 void InsertObserversHelper::removeActivationObservers() {
1256   std::vector<std::unordered_map<Value*, Module>::iterator>
1257       values_to_be_removed;
1258   for (auto it = observer_for_value_.begin(); it != observer_for_value_.end();
1259        it++) {
1260     if (!isWeight(it->first)) {
1261       values_to_be_removed.push_back(it);
1262     }
1263   }
1264   for (auto it : values_to_be_removed) {
1265     observer_for_value_.erase(it);
1266   }
1267 }
1268 
1269 void InsertObserversHelper::fillValueObserverMap(
1270     Module& module,
1271     const std::string& method_name) {
1272   Method method = module.get_method(method_name);
1273   auto graph = method.graph();
1274 
1275   if (visited_graph_of_observer_map_.count(graph.get())) {
1276     return;
1277   }
1278   visited_graph_of_observer_map_.insert(graph.get());
1279 
1280   std::stack<Block*> blocks_to_visit;
1281   auto qconfig_opt = module_qconfig_map_.at(module._ivalue());
1282   if (!qconfig_opt) {
1283     return;
1284   }
1285   auto qconfig = *qconfig_opt;
1286   for (auto* v : graph->inputs()) {
1287     if (valueNeedsToBeQuantized(v, qconfig)) {
1288       GRAPH_DEBUG("Recording observer for ", v->debugName());
1289       GRAPH_DUMP("In graph:", v->owningGraph());
1290       observer_for_value_[v] = getObserverModuleFor(v, qconfig);
1291     }
1292   }
1293 
1294   blocks_to_visit.push(graph->block());
1295   while (!blocks_to_visit.empty()) {
1296     Block* b = blocks_to_visit.top();
1297     blocks_to_visit.pop();
1298     for (Node* n : b->nodes()) {
1299       for (Value* v : n->outputs()) {
1300         if (valueNeedsToBeQuantized(v, qconfig)) {
1301           GRAPH_DEBUG("Recording observer for ", v->debugName());
1302           GRAPH_DUMP("In graph:", v->owningGraph());
1303           observer_for_value_[v] = getObserverModuleFor(v, qconfig);
1304         }
1305       }
1306 
1307       for (Block* subblock : n->blocks()) {
1308         blocks_to_visit.push(subblock);
1309       }
1310     }
1311   }
1312 }
1313 
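// Returns the observer recorded for `v` in fillValueObserverMap, if any.
// Otherwise follows boundary_value_map_ (which links values across call sites,
// callee graphs and subblocks) and checks that all connected values agree on a
// single observer configuration.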
1314 std::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
1315   if (observer_for_value_.count(v)) {
1316     auto observer = observer_for_value_.at(v);
1317     GRAPH_DEBUG("Got observer module config for:", v->debugName());
1318     return observer;
1319   }
1320   std::optional<Module> result;
1321   if (boundary_value_map_.count(v)) {
1322     for (Value* next : boundary_value_map_.at(v)) {
1323       GRAPH_DEBUG(
1324           "Going through boundary map:",
1325           v->debugName(),
1326           " --> ",
1327           next->debugName());
1328       GRAPH_DUMP("From graph:", v->owningGraph());
1329       GRAPH_DUMP("To graph:", next->owningGraph());
1330       auto observer_opt = getObserverFor(next);
1331       if (observer_opt) {
1332         // Need to make sure all values are
1333         // configured with the same observer
1334         if (result) {
1335           TORCH_CHECK(
1336               *observer_opt == *result,
1337               "Expecting all values in the graph only configured with one observer");
1338         } else {
1339           result = observer_opt;
1340         }
1341       }
1342     }
1343   }
1344   GRAPH_DEBUG(
1345       "Observer module config for ", v->debugName(), ":", result.has_value());
1346   return result;
1347 }
1348 
1349 std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
1350 InsertObserversHelper::insertObservers(
1351     Module& module,
1352     const std::string& method_name,
1353     bool is_entry_point,
1354     std::unordered_set<Value*> graph_observed_values) {
1355   auto graph = module.get_method(method_name).graph();
1356   return insertObserversFor(
1357       graph->block(), module, graph_observed_values, is_entry_point);
1358 }
1359 
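// Marks `v` (or, if observation of `v` is delayed via delay_observation_map_,
// the value it is delayed to) as observed in the current block, and records
// which observer module should be attached to it in `values_to_observe`.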
1360 void InsertObserversHelper::recordObserved(
1361     Value* v,
1362     const Module& observer_module,
1363     std::unordered_map<Value*, Module>& values_to_observe,
1364     std::unordered_set<Value*>& block_observed_values) {
1365   Value* to_observe = v;
1366   if (delay_observation_map_.count(v)) {
1367     to_observe = delay_observation_map_.at(v);
1368   }
1369   values_to_observe[to_observe] = observer_module;
1370   block_observed_values.insert(to_observe);
1371 }
1372 
1373 std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
1374 InsertObserversHelper::insertObserversFor(
1375     Block* block,
1376     script::Module& module,
1377     std::unordered_set<Value*>& block_observed_values,
1378     bool is_entry_point,
1379     bool is_user_defined_function) {
1380   // Input/output values of the block and the owning graph; we use this
1381   // set to skip inserting observers for those values here.
1382   // We have to insert the observers at the call site instead, because
1383   // the graph itself can be shared
1384   std::unordered_set<Value*> inputs_outputs;
1385   // list of observer modules for input values
1386   std::vector<std::optional<Module>> block_input_observers;
1387   // list of observer modules for output values
1388   std::vector<std::optional<Module>> block_output_observers;
1389 
1390   // if the current block is the block of the entry point graph (the forward
1391   // graph of the top level module), we can insert observers in the block directly
1392   if (!is_entry_point) {
1393     auto* graph = block->owningGraph();
1394     // graph inputs/outputs
1395     for (auto list : {graph->inputs(), graph->outputs()}) {
1396       for (auto* v : list) {
1397         inputs_outputs.insert(v);
1398       }
1399     }
1400     // block outputs
1401     for (auto* v : block->outputs()) {
1402       inputs_outputs.insert(v);
1403     }
1404 
1405     for (auto* v : block->inputs()) {
1406       block_input_observers.emplace_back(getObserverFor(v));
1407     }
1408 
1409     for (auto* v : block->outputs()) {
1410       // we need to explicitly skip the values that are already observed;
1411       // this might happen in subblocks of `if`, since
1412       // these subblocks have access to all values before the `if` node
1413       if (!isObserved(v, block_observed_values)) {
1414         block_output_observers.emplace_back(getObserverFor(v));
1415       } else {
1416         block_output_observers.emplace_back(std::nullopt);
1417       }
1418     }
1419   }
1420 
1421   // This means the block has been processed before; we just
1422   // need to attach observer modules and construct the information
1423   // needed by the call site here
1424   bool visited = block_observer_map_.count(block);
1425   if (visited) {
1426     // clone each observer module instance and attach it to this module via setAttr
1427     for (const auto& observer_attrs : block_observer_map_.at(block)) {
1428       const auto& name = std::get<0>(observer_attrs);
1429       const auto& observer = std::get<1>(observer_attrs);
1430       module._ivalue()->setAttr(name, observer.deepcopy()._ivalue());
1431     }
1432   }
1433   // NB: Why do we need to process the graph even if it has been visited?
1434   // The reason is that `block_observed_values` can
1435   // change depending on where the method is called, and the
1436   // outputs that have been observed (the third item of the returned result)
1437   // can change depending on that. So for each graph we'll need to go through
1438   // the whole process of inserting observers; the observers inserted in this
1439   // block won't change, but the information we return to the caller will change
1440   // based on `block_observed_values`
1441 
1442   std::stack<Block*> blocks_to_visit;
1443   blocks_to_visit.push(block);
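  // For a method graph the first graph input is the module's `self`; it is
  // used below to resolve the module invoked by prim::CallMethod nodes.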
1444   auto* self = block->owningGraph()->inputs()[0];
1445   // We first construct a map from value to the observer module, then
1446   // insert observers for them later. This is to avoid interference
1447   // of the inserted observers with the analysis that decides where
1448   // to insert observers. Also, we only insert observers for
1449   // "intermediate values", i.e. values that are not the input/output of the
1450   // graph
1451   std::unordered_map<Value*, Module> values_to_observe;
1452 
1453   for (auto* v : block->inputs()) {
1454     if (!inputs_outputs.count(v) && !values_to_observe.count(v)) {
1455       if (auto observer_opt = getObserverFor(v)) {
1456         recordObserved(
1457             v, *observer_opt, values_to_observe, block_observed_values);
1458       }
1459     }
1460   }
1461   while (!blocks_to_visit.empty()) {
1462     Block* b = blocks_to_visit.top();
1463     blocks_to_visit.pop();
1464     for (Node* n : b->nodes()) {
1465       if (observer_nodes_.count(n)) {
1466         continue;
1467       }
1468       if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
1469         script::Module m;
1470         std::shared_ptr<Graph> g;
1471         size_t input_offset = 0;
1472         bool is_udf_for_subblock = is_user_defined_function;
1473         if (n->kind() == prim::CallMethod) {
1474           auto m_opt = getInvokedModuleOpt(module, n, self);
1475           if (!m_opt.has_value()) {
1476             continue;
1477           }
1478           m = *m_opt;
1479           g = m.get_method(n->s(attr::name)).graph();
1480           input_offset = 0;
1481         } else { // CallFunction
1482           m = module;
1483           g = getCallFunctionGraph(n);
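          // for prim::CallFunction the first input is the function itself,
          // so the actual arguments start at offset 1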
1484           input_offset = 1;
1485           is_udf_for_subblock = true;
1486         }
1487 
1488         std::unordered_set<Value*> callee_observed_inputs;
1489         for (auto i = 0U; i < g->inputs().size(); ++i) {
1490           auto* node_input = n->input(i + input_offset);
1491           if (isObserved(node_input, block_observed_values)) {
1492             callee_observed_inputs.insert(g->inputs()[i]);
1493           }
1494         }
1495         auto* subblock = g->block();
1496         auto info_from_callee = insertObserversFor(
1497             subblock, m, callee_observed_inputs, false, is_udf_for_subblock);
1498         auto input_observers = std::get<0>(info_from_callee);
1499         auto output_observers = std::get<1>(info_from_callee);
1500         auto callee_observed_outputs = std::get<2>(info_from_callee);
1501         for (auto idx : callee_observed_outputs) {
1502           block_observed_values.insert(n->outputs()[idx]);
1503         }
1504         for (auto i = 0U; i < g->inputs().size(); ++i) {
1505           auto* node_input = n->input(i + input_offset);
1506           if (input_observers[i] && !inputs_outputs.count(node_input) &&
1507               !isObserved(node_input, block_observed_values)) {
1508             recordObserved(
1509                 node_input,
1510                 *input_observers[i],
1511                 values_to_observe,
1512                 block_observed_values);
1513           }
1514         }
1515         for (auto i = 0U; i < n->outputs().size(); ++i) {
1516           if (output_observers[i] && !inputs_outputs.count(n->output(i)) &&
1517               !isObserved(n->output(i), block_observed_values)) {
1518             recordObserved(
1519                 n->output(i),
1520                 *output_observers[i],
1521                 values_to_observe,
1522                 block_observed_values);
1523           }
1524         }
1525       } else if (n->kind() == prim::If) {
1526         // a vector recording whether each output is observed or not
1527         std::vector<bool> aggregated_output_observe_state;
1528         for (Block* subblock : n->blocks()) {
1529           if (alwaysRaisesException(subblock)) {
1530             continue;
1531           }
1532           // subblock has access to all the values in the scope of prim::If,
1533           // so subblock_observed_values == block_observed_values
1534           auto info_from_subblock =
1535               insertObserversFor(subblock, module, block_observed_values);
1536           // subblocks of prim::If don't have inputs
1537           auto output_observers = std::get<1>(info_from_subblock);
1538           auto subblock_observed_outputs = std::get<2>(info_from_subblock);
1539 
1540           // We'll insert an output observer for each subblock, and in the end
1541           // we will check that the outputs of the subblocks are observed consistently
1542           for (size_t i = 0; i < subblock->outputs().size(); ++i) {
1543             Value* output = subblock->outputs()[i];
1544             if (output_observers[i] && !inputs_outputs.count(output) &&
1545                 !isObserved(output, block_observed_values)) {
1546               recordObserved(
1547                   output,
1548                   *output_observers[i],
1549                   values_to_observe,
1550                   block_observed_values);
1551             }
1552           }
1553           for (auto idx : subblock_observed_outputs) {
1554             block_observed_values.insert(subblock->outputs()[idx]);
1555           }
1556           std::vector<bool> subblock_output_observe_state;
1557           for (size_t i = 0; i < subblock->outputs().size(); ++i) {
1558             Value* output = subblock->outputs()[i];
1559             subblock_output_observe_state.push_back(
1560                 isObserved(output, block_observed_values));
1561           }
1562           if (!aggregated_output_observe_state.empty()) {
1563             TORCH_CHECK(
1564                 aggregated_output_observe_state ==
1565                     subblock_output_observe_state,
1566                 "branches for `if` should return values that are observed "
1567                 "consistently, if node:",
1568                 *n);
1569           } else {
1570             aggregated_output_observe_state = subblock_output_observe_state;
1571           }
1572         }
1573         // mark the outputs of the `if` node as observed
1574         for (size_t i = 0; i < n->outputs().size(); ++i) {
1575           if (aggregated_output_observe_state[i]) {
1576             block_observed_values.insert(n->output(i));
1577           }
1578         }
1579       } else if (n->kind() == prim::Loop) {
1580         TORCH_WARN_ONCE(
1581             "prim::Loop is not yet supported in quantization, "
1582             "please make sure nothing needs to be quantized in the "
1583             "loop");
1584       }
1585       for (Value* v : n->outputs()) {
1586         propagateObservedProperty(v, block_observed_values);
1587         if (!inputs_outputs.count(v) && !isObserved(v, block_observed_values)) {
1588           auto observer_opt = getObserverFor(v);
1589           // If the node is one of the propagate-quant nodes, e.g.
1590           // aten::cat, we should observe its output only
1591           // if the inputs of the node are observed
1592           if (observer_opt &&
1593               shouldObserve(n, block_observed_values, quant_type_)) {
1594             recordObserved(
1595                 v, *observer_opt, values_to_observe, block_observed_values);
1596           }
1597         }
1598       }
1599     }
1600   }
1601   std::vector<size_t> output_idxs;
1602   for (auto i = 0U; i < block->outputs().size(); ++i) {
1603     if (isObserved(block->outputs()[i], block_observed_values)) {
1604       output_idxs.push_back(i);
1605     }
1606   }
1607   if (!visited) {
1608     NameModuleVector observer_name_and_modules;
1609     for (const auto& item : values_to_observe) {
1610       auto* v = item.first;
1611       auto observer = item.second;
1612       TORCH_CHECK(
1613           !is_user_defined_function,
1614           "Inserting observers for user defined functions is not "
1615           "supported right now");
1616       insertObserverFor(v, module, observer, observer_name_and_modules);
1617     }
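    // When enabled via setInsertResetObserverMethod (used by the on-device PTQ
    // entry point below), also add a method that resets the min/max state of
    // the observers that were just inserted into this module.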
1618     if (insertResetObserverMethod()) {
1619       insertObserverResetMinMax(module, observer_name_and_modules);
1620     }
1621     block_observer_map_[block] = observer_name_and_modules;
1622   }
1623   return std::make_tuple(
1624       block_input_observers, block_output_observers, output_idxs);
1625 }
1626 
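// If `output` comes from an op recorded in pass_through_value_map_ (an op that
// just passes values through and doesn't need its own observer), mark it as
// observed once all of the mapped input values are observed.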
1627 void InsertObserversHelper::propagateObservedProperty(
1628     Value* output,
1629     std::unordered_set<Value*>& block_observed_values) {
1630   if (pass_through_value_map_.count(output)) {
1631     // since the vector is always non-empty, `all_observed` will
1632     // not simply keep its initial value
1633     bool all_observed = true;
1634     for (Value* v : pass_through_value_map_.at(output)) {
1635       all_observed &=
1636           observed_values_.count(v) || block_observed_values.count(v);
1637     }
1638     if (all_observed) {
1639       GRAPH_DEBUG("Pass through observed property in node:", *output->node());
1640       // This is to propagate observed property through
1641       // all ops that don't require observation
1642       block_observed_values.insert(output);
1643     }
1644   }
1645 }
1646 
1647 } // namespace
1648 
1649 Module InsertObservers(
1650     Module& input_module,
1651     const std::string& method_name,
1652     const QConfigDict& qconfig_dict,
1653     bool inplace,
1654     QuantType quant_type) {
1655   ModuleQConfigMap map_before_clone;
1656   fillQConfigMap(input_module, qconfig_dict, map_before_clone);
1657   ModuleCloneHelper mh;
1658   Module module = mh.clone(input_module, map_before_clone, inplace);
1659   SwapFunctionalLinear(module);
1660   ModuleQConfigMap module_qconfig_map;
1661   // Since the types are changed after clone, we need to fill
1662   // the qconfig map again
1663   fillQConfigMap(module, qconfig_dict, module_qconfig_map);
1664   GRAPH_DEBUG("Quant type:", quant_type);
1665   InsertObserversHelper helper(module_qconfig_map, quant_type);
1666   helper.preprocess(module, method_name);
1667   helper.fillBoundaryValueMap(module, method_name);
1668   // analyze needs to run after fillBoundaryValueMap
1669   // since we need to know the boundary value mapping to trace
1670   // through the calls
1671   helper.analyze(module, method_name);
1672   helper.insertObservers(module, method_name, /* is_entry_point */ true);
1673   return module;
1674 }
1675 
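// A minimal calling sketch for the InsertObservers entry point above, assuming
// QConfig is the (activation observer, weight observer) pair used throughout
// this pass, `m` is a scripted module and the observer modules are constructed
// elsewhere:
//
//   Module activation_observer = ...;  // scripted activation observer module
//   Module weight_observer = ...;      // scripted weight observer module
//   QConfigDict qconfig_dict;
//   qconfig_dict[""] = std::make_tuple(activation_observer, weight_observer);
//   Module observed = InsertObservers(
//       m, "forward", qconfig_dict, /* inplace */ false, QuantType::STATIC);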
1676 Module InsertObserversForOnDevicePTQ(
1677     Module& input_module,
1678     const std::string& method_name,
1679     const QConfigDict& qconfig_dict,
1680     bool inplace,
1681     QuantType quant_type) {
1682   ModuleQConfigMap map_before_clone;
1683   fillQConfigMap(input_module, qconfig_dict, map_before_clone);
1684   ModuleCloneHelper mh;
1685   Module cloned_module = mh.clone(input_module, map_before_clone, inplace);
1686   std::shared_ptr<Graph> g = cloned_module.get_method(method_name).graph();
1687   SwapFunctionalLinear(g);
1688   std::string observer_method_name = "observe_" + method_name;
1689   cloneMethod(cloned_module, method_name, observer_method_name);
1690   ModuleQConfigMap module_qconfig_map;
1691   // Since the types are changed after clone, we need to fill
1692   // the qconfig map again
1693   fillQConfigMap(cloned_module, qconfig_dict, module_qconfig_map);
1694   GRAPH_DEBUG("Quant type:", quant_type);
1695   InsertObserversHelper helper(module_qconfig_map, quant_type);
1696   // It is not clear whether the list-mutation removal part of preprocess is needed here.
1697   helper.preprocess(cloned_module, observer_method_name);
1698   // Since we expect the graph to be inlined, this should not have any use.
1699   // However, this function does handle if blocks,
1700   // although as far as I understand if blocks are not really handled
1701   // in JIT quantization. Should we just protect against this, i.e. error out
1702   // if we find an observable value inside an if block? Also, a side effect of
1703   // inlining is that you will have multiple getattrs for the same attribute and
1704   // thus potentially multiple observers observing the same value. This will also
1705   // lead to an increased size of the packed param struct. I don't expect this to
1706   // be a common pattern, but it is something to be aware of. Note that the
1707   // current quant workflow does not prevent this anyway, since during insert
1708   // quant dequant things are inlined anyway.
1709   helper.fillBoundaryValueMap(cloned_module, observer_method_name);
1710   // analyze needs to run after fillBoundaryValueMap
1711   // since we need to know the boundary value mapping to trace
1712   // through the calls
1713   helper.analyze(cloned_module, observer_method_name);
1714   // Remove activation observer if quant_type is dynamic
1715   if (quant_type == QuantType::DYNAMIC) {
1716     helper.removeActivationObservers();
1717   }
1718   helper.setInsertResetObserverMethod(true, method_name);
1719   helper.insertObservers(
1720       cloned_module, observer_method_name, /* is_entry_point */ true);
1721   return cloned_module;
1722 }
1723 } // namespace jit
1724 } // namespace torch
1725