1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/passes/quantization/insert_observers.h>
3
4 #include <torch/csrc/jit/frontend/schema_matching.h>
5 #include <torch/csrc/jit/ir/subgraph_matcher.h>
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/constant_pooling.h>
8 #include <torch/csrc/jit/passes/constant_propagation.h>
9 #include <torch/csrc/jit/passes/fuse_linear.h>
10 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
11 #include <torch/csrc/jit/passes/inline_fork_wait.h>
12 #include <torch/csrc/jit/passes/quantization/helper.h>
13 #include <torch/csrc/jit/passes/remove_mutation.h>
14
15 #include <memory>
16 #include <stack>
17 #include <string>
18 #include <utility>
19
20 namespace torch {
21 namespace jit {
22
23 using ModuleQConfigMap = std::unordered_map<ModulePtr, std::optional<QConfig>>;
24
25 namespace {
26
27 struct OptionalQConfigHash {
operator ()torch::jit::__anond581fb520111::OptionalQConfigHash28 inline size_t operator()(const std::optional<QConfig>& qconfig_opt) const {
29 if (qconfig_opt.has_value()) {
30 const auto& m1 = std::get<0>(*qconfig_opt);
31 const auto& m2 = std::get<1>(*qconfig_opt);
32 constexpr int CONST = 7;
33 return std::hash<Module>()(m1) + CONST * std::hash<Module>()(m2);
34 }
35 return 0;
36 }
37 };
// Map from an (optional) QConfig to the cloned ClassType created for that
// qconfig; lets module instances sharing (type, qconfig) share one ClassType.
using QConfigTypePtrMap =
    std::unordered_map<std::optional<QConfig>, TypePtr, OptionalQConfigHash>;
// (observer attribute name, observer module) pairs registered on a module.
using NameModuleVector = std::vector<std::pair<std::string, Module>>;
// Per-input/output observer modules; std::nullopt means "no observer needed".
using OptionalModuleVector = std::vector<std::optional<Module>>;
// (module, method name) pairs, e.g. the methods invoked from a graph.
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
using graph_rewrite_helper::PatternInfo;
using graph_rewrite_helper::replaceConvolutionWithAtenConv;
45
46 // helper functions
fillQConfigMap(const Module & module,const QConfigDict & qconfig_dict,ModuleQConfigMap & map,const std::string & key="",const std::optional<QConfig> & parent_qconfig=std::nullopt)47 void fillQConfigMap(
48 const Module& module,
49 const QConfigDict& qconfig_dict,
50 ModuleQConfigMap& map,
51 const std::string& key = "",
52 const std::optional<QConfig>& parent_qconfig = std::nullopt) {
53 std::optional<QConfig> qconfig;
54 if (qconfig_dict.find(key) != qconfig_dict.end()) {
55 GRAPH_DEBUG("Got module config for key:", key);
56 qconfig = qconfig_dict.at(key);
57 } else {
58 GRAPH_DEBUG("Inheriting qconfig from parent module:", key);
59 qconfig = parent_qconfig;
60 }
61 map[module._ivalue()] = qconfig;
62
63 for (const NameModule& s : module.named_children()) {
64 std::string child_key;
65 if (key.empty()) {
66 child_key = s.name;
67 } else {
68 child_key = key + "." + s.name;
69 }
70 fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
71 }
72 }
73
getObserverModuleFor(Value * v,const QConfig & qconfig)74 Module getObserverModuleFor(Value* v, const QConfig& qconfig) {
75 return isWeight(v) ? std::get<1>(qconfig) : std::get<0>(qconfig);
76 }
77
78 // helper classes
class ModuleCloneHelper {
 public:
  /** Clone according to module qconfig map, this is for handling the case
   * where we have two module instances sharing the same ClassType
   * but configured with different QConfig
   * code is copied and modified from
   * https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp
   * inplace option means if the copy of the Tensor is deepcopy or not
   * if inplace is true, the cloned module will share the tensors with
   * original model instead of deepcopy them
   */
  Module clone(
      const Module& module,
      const ModuleQConfigMap& module_qconfig_map,
      bool inplace = false) {
    std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
    IValue::HashIdentityIValueMap memo;
    return clone_impl(
        module, module_qconfig_map, type_remap, inplace, std::move(memo));
  }

 private:
  // Recursive worker for clone(). `type_remap` maps (original type, qconfig)
  // to the ClassType already created for that combination, so instances that
  // share a (type, qconfig) pair also share the cloned ClassType. `memo`
  // deduplicates IValues across deepcopy calls.
  Module clone_impl(
      const Module& module,
      const ModuleQConfigMap& module_qconfig_map,
      std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
      bool inplace,
      IValue::HashIdentityIValueMap memo) {
    auto qconfig = module_qconfig_map.at(module._ivalue());
    auto type = module.type();
    // Create a new _ivalue in the same compilation unit.
    // Since now we have shared ClassType, we need to preserve the shared
    // ClassType during cloning, so we first use type and qconfig to check if
    // the type is already cloned, if so, we'll create a new module with the
    // cloned ClassType, if not, we'll create a new module and a new ClassType.
    bool type_already_cloned = type_remap.find(type) != type_remap.end() &&
        type_remap.at(type).find(qconfig) != type_remap.at(type).end();
    Module r;
    if (type_already_cloned) {
      // if we cloned the class type before, we'll reuse it
      Module new_module(
          module._ivalue()->compilation_unit(),
          type_remap.at(type).at(qconfig)->cast<ClassType>());
      r = new_module;
    } else {
      Module new_module(
          *type->name(), module._ivalue()->compilation_unit(), true);
      r = new_module;
      // Record the freshly created ClassType so later instances with the
      // same (type, qconfig) reuse it.
      type_remap[type][module_qconfig_map.at(module._ivalue())] = r.type();
    }
    // Copy slots. If a slot is a module - recursively clone it.
    size_t N = type->numAttributes();
    for (const auto i : c10::irange(N)) {
      IValue s = module._ivalue()->getSlot(i);
      std::string attr_name = type->getAttributeName(i);
      TypePtr attr_type = type->getAttribute(i);
      if (attr_type->is_module()) {
        const Module& orig = Module(s.toObject());
        Module cloned =
            clone_impl(orig, module_qconfig_map, type_remap, inplace, memo);

        // NOTE: why do we need to manually setattr on object instead of using
        // register_module here? because the attr can be a module interface
        // type and hold a Module object still. register_module will not let us
        // correctly set up the type for this attr, so we had to do this
        // manually. In the case it's an interface type, the type will be shared
        // by the new cloned instance in the same compilation unit bc it only
        // contains a list of functionSchema
        r.type()->addOrCheckAttribute(
            attr_name,
            attr_type->cast<ClassType>() ? cloned.type() : attr_type);
        r._ivalue()->setAttr(attr_name, cloned._ivalue());
      } else {
        // we'll deepcopy the IValue in non inplace option
        r.register_attribute(
            type->getAttributeName(i),
            type->getAttribute(i),
            inplace ? s : s.deepcopy(memo),
            type->is_parameter(i),
            type->is_buffer(i));
      }
    }

    // only clone the methods and constants if the ClassType is not cloned
    // before
    if (!type_already_cloned) {
      for (size_t i = 0; i < type->numConstants(); ++i) {
        r.type()->addConstant(type->getConstantName(i), type->getConstant(i));
      }
      // Clone methods remapping the types to the cloned ones.
      for (auto& fn : type->methods()) {
        clone_method(module, r, *fn, module_qconfig_map, type_remap);
      }
      // Execute __setstate__(__getstate__()) to initialize custom class
      // members.
      if (auto setstate_method = r.find_method("__setstate__")) {
        auto getstate_method = r.find_method("__getstate__");
        TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
        auto state = (*getstate_method)(Stack{});
        (*setstate_method)(Stack{std::move(state)});
      }
    }
    return r;
  }

  // Rewrite the types of module values inside `block` so that GetAttr /
  // CallMethod on submodules refer to the cloned ClassTypes produced by
  // `type_remap_fn` (which receives the original type and the invoked
  // submodule's qconfig). Recurses into sub-blocks and graph attributes.
  void remapTypes(
      Block* block,
      Value* self,
      const Module& source,
      Module& target,
      const ModuleQConfigMap& module_qconfig_map,
      const std::function<TypePtr(TypePtr, std::optional<QConfig>)>&
          type_remap_fn) {
    // remap of %self will be done outside of the function
    // and we don't support the case when people pass in
    // module as argument of the method because in that case
    // we need to do more comprehensive analysis to decide the
    // QConfig for the module
    for (size_t i = 1; i < block->inputs().size(); ++i) {
      TORCH_CHECK(
          !block->inputs()[i]->type()->cast<ClassType>(),
          "We don't support quantizing methods that has Object as arguments");
    }
    for (Node* node : block->nodes()) {
      // remapping type for module instance
      if (node->kind() == prim::CallMethod || node->kind() == prim::GetAttr) {
        Value* instance = node->inputs()[0];
        auto child_opt = getInvokedModuleOpt(source, node, self);
        if (child_opt.has_value()) {
          auto qconfig = module_qconfig_map.at(child_opt->_ivalue());
          instance->setType(type_remap_fn(instance->type(), qconfig));
        }
      }
      // We don't remap output and the remapping of module type
      // will be done in CallMethod, we don't support type remapping
      // for modules returned from methods or functions
      for (Block* sub_block : node->blocks()) {
        remapTypes(
            sub_block, self, source, target, module_qconfig_map, type_remap_fn);
      }
      // Graph-valued attributes (e.g. subgraphs) also carry module types.
      for (Symbol name : node->attributeNames()) {
        if (node->kindOf(name) == AttributeKind::g) {
          remapTypes(
              node->g(name).get(),
              source,
              target,
              module_qconfig_map,
              type_remap_fn);
        } else if (node->kindOf(name) == AttributeKind::gs) {
          for (const auto& g : node->gs(name)) {
            remapTypes(
                g.get(), source, target, module_qconfig_map, type_remap_fn);
          }
        }
      }
    }
  }

  // Graph-level overload: remap starting from the graph's top-level block,
  // using the graph's first input (%self) as the module instance.
  void remapTypes(
      Graph* graph,
      const Module& source,
      Module& target,
      const ModuleQConfigMap& module_qconfig_map,
      const std::function<TypePtr(TypePtr, std::optional<QConfig>)>&
          type_remap_fn) {
    remapTypes(
        graph->block(),
        graph->inputs()[0],
        source,
        target,
        module_qconfig_map,
        type_remap_fn);
  }

  // Copy `method` from `source` onto `target`, remapping module types in the
  // graph through `type_remap` and rewriting the schema to match.
  void clone_method(
      const Module& source,
      Module& target,
      const Function& method,
      const ModuleQConfigMap& module_qconfig_map,
      const std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
    // Falls back to the original type when no remapped type exists for the
    // (type, qconfig) pair.
    auto type_remap_fn = [&](TypePtr type_ptr,
                             const std::optional<QConfig>& qconfig) {
      if (type_remap.find(type_ptr) != type_remap.end()) {
        const auto& qconfig_map = type_remap.at(type_ptr);
        if (qconfig_map.find(qconfig) != qconfig_map.end()) {
          return qconfig_map.at(qconfig);
        }
      }
      return type_ptr;
    };
    auto graph = toGraphFunction(method).graph()->copy();
    remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
    // remap self
    graph->inputs()[0]->setType(target.type());
    // we only support %self being Module in the arguments of function
    auto schema_type_remap_fn = [&](TypePtr type_ptr) {
      return type_remap_fn(
          std::move(type_ptr), module_qconfig_map.at(source._ivalue()));
    };
    auto schema =
        method.getSchema().cloneWithRemappedTypes(schema_type_remap_fn);
    const auto this_method_name =
        c10::QualifiedName(*target.type()->name(), method.name());
    auto copied = target._ivalue()->compilation_unit()->create_function(
        this_method_name, std::move(graph));
    target.type()->addMethod(copied);
    copied->setSchema(std::move(schema));
  }
};
288
289 class InsertObserversHelper {
290 public:
InsertObserversHelper(const ModuleQConfigMap & map,QuantType quant_type)291 explicit InsertObserversHelper(
292 const ModuleQConfigMap& map,
293 QuantType quant_type)
294 : module_qconfig_map_(map), quant_type_(quant_type) {}
295
296 // TODO: replace (module, method_name) with graph?
297 // preprocess to clean up the graph from tracing
298 void preprocess(Module& module, const std::string& method_name);
299
300 // Fill the map between the caller input/output to input/output
301 // of called graph, this is used to navigate through the graph
302 // to find the observer for a given value
303 void fillBoundaryValueMap(Module& module, const std::string& method_name);
304
305 // analyze the graph and record necessary information that can
306 // be used in insert observers
307 void analyze(Module& module, const std::string& method_name);
308
309 void removeActivationObservers();
310
311 /**
312 * Recursively insert observers for the method, also we'll process
313 * the nodes in the graph in the order of execution of these nodes
314 * since we need the context information to decide whether we want to
315 * observe/quantize a value a not, we don't want to observe a value multiple
316 * times.
317 *
318 * argument: is_entry_point means whether the current method is the forward
319 * method of the top level module.
320 *
321 * Since we want to insert observers in the call site instead of in the called
322 * graph, we'll postpone inserting observer to caller as much as possible, if
323 * we know the current method is the outer most method, then
324 * we will insert all observers in the graph instead of postpone this to the
325 * parent, note that this assumes we don't have recursive method
326 * calls
327 *
328 * returns a tuple of vectors of observer modules for input and output, these
329 * are used for inserting observers for the input/output values
330 * since we need to insert these values at call site.
331 * And a vector of indexes of outputs that indicates whether the output value
332 * is already observed or not, this is used for propagating the observed
333 * property of a value through CallMethods, because we should skip inserting
334 * observers for ops that don't require observation
335 */
336 std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
337 insertObservers(
338 Module& module,
339 const std::string& method_name,
340 bool is_entry_point = false,
341 std::unordered_set<Value*> graph_observed_values =
342 std::unordered_set<Value*>());
343
setInsertResetObserverMethod(bool insert_reset_observer_method,const std::string & method_name)344 void setInsertResetObserverMethod(
345 bool insert_reset_observer_method,
346 const std::string& method_name) {
347 insert_reset_observer_method_ = insert_reset_observer_method;
348 reset_observer_method_name_ = "reset_observers_" + method_name;
349 }
350
351 private:
352 std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
353 insertObserversFor(
354 Block* block,
355 script::Module& module,
356 // this is a reference because when we insert observer for a value
357 // in one block it is also observed in another block, we don't want to
358 // insert multiple observers for the same value
359 std::unordered_set<Value*>& block_observed_values,
360 bool is_entry_point = false,
361 bool is_user_defined_function = false);
362
363 // Record v as "ready for observation" by storing it in values_to_observe.
364 // If v is a part of a delayed observation pattern, record v's descendant
365 // (per delay rules) instead. The observers are inserted at a later stage
366 // by reading the state created by this function.
367 void recordObserved(
368 Value* v,
369 const Module& observer_module,
370 std::unordered_map<Value*, Module>& values_to_observe,
371 std::unordered_set<Value*>& block_observed_values);
372
373 ModuleMethodVector getInvokedMethods(
374 Module& module,
375 const std::string& method_name);
376
377 bool valueNeedsToBeQuantized(Value* v, const QConfig& qconfig);
378
isObserved(Value * v,const std::unordered_set<Value * > & block_observed_values)379 bool isObserved(
380 Value* v,
381 const std::unordered_set<Value*>& block_observed_values) {
382 return block_observed_values.count(v) || observed_values_.count(v);
383 }
384
385 // Fill the map from value to the corresponding observer module
386 // this map is used in insertObservers to actually insert
387 // observers to the module
388 void fillValueObserverMap(Module& module, const std::string& method_name);
389
390 // Clone observer module and add it to the original module,
391 // and insert a call to observer forward function
392 void insertObserverFor(
393 Value* v,
394 Module& module,
395 const Module& observer_module,
396 NameModuleVector& observer_name_and_modules);
397
398 void insertObserverResetMinMax(
399 Module& module,
400 const NameModuleVector& observer_name_and_modules);
401
402 // Uses the state created by fillBoundaryValueMap and fillValueObserverMap
403 // to return an observer configured for a value, if it is needed.
404 std::optional<Module> getObserverFor(Value* v);
405
  // Uses the state created by fillPassThroughValueMap to propagate observed
  // property which should pass through from inputs to outputs.
408 void propagateObservedProperty(
409 Value* output,
410 std::unordered_set<Value*>& block_observed_values);
411
412 // for cat/add/mul we will only observe their output if their input
413 // are observed
shouldObserve(Node * n,const std::unordered_set<Value * > & block_observed_values,QuantType quant_type)414 bool shouldObserve(
415 Node* n,
416 const std::unordered_set<Value*>& block_observed_values,
417 QuantType quant_type) {
418 // Check whether node output uses can be quantized, eg cat followed by
419 // linear op
420 for (Value* v : n->outputs()) {
421 for (const auto& use : v->uses()) {
422 if (useQuantizable(use, quant_type)) {
423 return true;
424 }
425 }
426 }
427 if (isPropagateQuantSingleInputOp(n)) {
428 return isObserved(n->input(0), block_observed_values);
429 } else if (isPropagateQuantBinaryOp(n)) {
430 // This checks both of the input should be tensor and observed.
431 // There is one check that we didn't do here, which is
432 // !isScalar(isObserved(n->input(1), block_observed_values)
433 // to make sure input 1 is not a scalar, because scalar tensor input
434 // for add/mul won't be observed with current rule, we can omit
435 // this check here
436 return isObserved(n->input(0), block_observed_values) &&
437 isObserved(n->input(1), block_observed_values);
438 }
439 return true;
440 }
441
442 void delayObservingValuesInPattern(Graph& graph, const PatternInfo& pattern);
443
444 // Find and mark known patterns such as conv-relu (and others) where
445 // we should not insert observers in the middle of the pattern.
446 void addValuesToDelayObservation(
447 const Module& module,
448 const std::string& method_name);
449
450 // Fill the map from values to the list of values that can pass the observed
451 // property to it
452 void fillPassThroughValueMap(const std::shared_ptr<Graph>& graph);
453
insertResetObserverMethod()454 bool insertResetObserverMethod() {
455 return insert_reset_observer_method_;
456 }
457
458 const ModuleQConfigMap& module_qconfig_map_;
459
460 // Values we want to delay observation, used to delay the observation for
461 // values in the middle of the ops that are supposed to be fused, e.g.
462 // the output value of conv in the conv - relu pattern
463 // the key is the intermediate output, e.g. output of conv
464 // the value is the value we want to observe, e.g. output of relu
465 //
466 // example, assuming we want to delay conv-relu:
467 // %x1 = conv(%x0)
468 // %x2 = relu(%x1)
469 //
470 // delay_observation_map_ = {
471 // %x1: %x2,
472 // }
473 std::unordered_map<Value*, Value*> delay_observation_map_;
474
475 std::unordered_set<Graph*> visited_graph_of_observer_map_;
476
477 // Map of value to observer module configured for that value.
478 std::unordered_map<Value*, Module> observer_for_value_;
479
480 // Map from values from callsite into the values in the CallMethod graph
481 // key of the map is the value from caller graph, and the value of the map
482 // is the list of values in the callee graph (the graph
483 // corresponding to the called method),
484 // the reason it is a set is that a value in the caller graph
485 // can both correspond to the output of one callee graph and input of another
486 // callee graph.
487 //
488 // example:
489 // // top level module
490 // %x1 = conv(%x0)
491 // %x2 = prim::CallFunction(%foo, %x1)
492 //
493 // // graph of %foo
494 // %y2 = conv(%y1)
495 // return %y2
496 //
497 // boundary_value_map = {
498 // // current module's output values to corresponding return values from
499 // subgraph %x2: %y2,
500 // // current module's input values to corresponding input value to subgraph
501 // %x1: %y1,
502 // }
503 std::unordered_map<Value*, std::unordered_set<Value*>> boundary_value_map_;
504
505 std::unordered_set<Value*> observed_values_;
506
507 // This is used for the observed values to pass through the ops like flatten,
508 // so that output value of flatten does not need to be observed
509 // key is the output of the op, value is a vector of values that need
510 // to be observed in order to pass the observed property to the output
511 //
512 // example:
513 // %x1 = flatten(%x0) // pass_through
514 // %x2 = conv(%x1) // not pass_through
515 //
516 // pass_through_value_map_ = {
517 // %x1: [%x0],
518 // }
519 std::unordered_map<Value*, std::vector<Value*>> pass_through_value_map_;
520
521 // Unique id generator for observer module, used for generating
522 // unique observer names when we insert observer module, we
523 // record the current unique id used to avoid incrementing from 0
524 // every time to find a unique id.
525 int uid_ = 0;
526 // Set of observer forward call nodes
527 std::unordered_set<Node*> observer_nodes_;
528 // Map from block to a vector of observer name and observer modules we
529 // want to add to the module instance that has the block
530 std::unordered_map<Block*, NameModuleVector> block_observer_map_;
531
532 // Type of quantization for this pass.
533 QuantType quant_type_ = QuantType::STATIC;
534 // These are the IR patterns we match to skip inserting observers.
535 // They are compiled once on construction and used repeatedly within
536 // the pass.
537
538 // nn.Linear + nn.ReLU
539 const PatternInfo nn_linear_nn_relu = PatternInfo::parse_from_str(
540 R"(
541 graph(%input, %linear, %relu):
542 %first_output = prim::CallMethod[name="forward"](%linear, %input)
543 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
544 return (%second_output) )",
545 {is_linear_module, is_relu_module});
546
547 // nn.Linear + F.relu
548 const PatternInfo nn_linear_f_relu = PatternInfo::parse_from_str(
549 R"(
550 graph(%input, %linear, %relu, %inplace):
551 %first_output = prim::CallMethod[name="forward"](%linear, %input)
552 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
553 return (%second_output) )",
554 {is_linear_module, is_functional_relu});
555
556 // nn.Linear + aten::relu
557 const PatternInfo nn_linear_aten_relu = PatternInfo::parse_from_str(
558 R"(
559 graph(%input, %linear, %relu):
560 %first_output = prim::CallMethod[name="forward"](%linear, %input)
561 %second_output = aten::relu(%first_output)
562 return (%second_output) )",
563 {is_linear_module});
564
565 // nn.Linear + aten::relu_
566 const PatternInfo nn_linear_aten_relu_ = PatternInfo::parse_from_str(
567 R"(
568 graph(%input, %linear, %relu):
569 %first_output = prim::CallMethod[name="forward"](%linear, %input)
570 %second_output = aten::relu_(%first_output)
571 return (%second_output) )",
572 {is_linear_module});
573
574 // aten::linear + nn.ReLU
575 const PatternInfo aten_linear_nn_relu = PatternInfo::parse_from_str(
576 R"(
577 graph(%input, %weight, %bias, %relu):
578 %first_output = aten::linear(%input, %weight, %bias)
579 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
580 return (%second_output) )",
581 {is_relu_module});
582
583 // aten::linear + F.relu
584 const PatternInfo aten_linear_f_relu = PatternInfo::parse_from_str(
585 R"(
586 graph(%input, %weight, %bias, %relu, %inplace):
587 %first_output = aten::linear(%input, %weight, %bias)
588 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
589 return (%second_output) )",
590 {is_functional_relu});
591
592 // aten::linear + aten::relu
593 const PatternInfo aten_linear_aten_relu = PatternInfo::parse_from_str(
594 R"(
595 graph(%input, %weight, %bias):
596 %first_output = aten::linear(%input, %weight, %bias)
597 %second_output = aten::relu(%first_output)
598 return (%second_output) )");
599
600 // aten::linear + aten::relu_
601 const PatternInfo aten_linear_aten_relu_ = PatternInfo::parse_from_str(
602 R"(
603 graph(%input, %weight, %bias):
604 %first_output = aten::linear(%input, %weight, %bias)
605 %second_output = aten::relu_(%first_output)
606 return (%second_output) )");
607
608 const PatternInfo nn_conv1d_f_relu = PatternInfo::parse_from_str(
609 R"(
610 graph(%self, %input, %conv, %relu, %inplace):
611 %first_output = prim::CallMethod[name="forward"](%conv, %input)
612 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
613 return (%second_output) )",
614 {is_conv1d_module, is_functional_relu});
615
616 const PatternInfo nn_conv1d_nn_relu = PatternInfo::parse_from_str(
617 R"(
618 graph(%self, %input, %conv, %relu):
619 %first_output = prim::CallMethod[name="forward"](%conv, %input)
620 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
621 return (%second_output) )",
622 {is_conv1d_module, is_relu_module});
623
624 const PatternInfo nn_conv1d_aten_relu = PatternInfo::parse_from_str(
625 R"(
626 graph(%self, %input, %conv):
627 %first_output = prim::CallMethod[name="forward"](%conv, %input)
628 %second_output = aten::relu(%first_output)
629 return (%second_output) )",
630 {is_conv1d_module});
631
632 const PatternInfo nn_conv1d_aten_relu_ = PatternInfo::parse_from_str(
633 R"(
634 graph(%self, %input, %conv):
635 %first_output = prim::CallMethod[name="forward"](%conv, %input)
636 %second_output = aten::relu_(%first_output)
637 return (%second_output) )",
638 {is_conv1d_module});
639
640 const PatternInfo nn_conv2d_f_relu = PatternInfo::parse_from_str(
641 R"(
642 graph(%self, %input, %conv, %relu, %inplace):
643 %first_output = prim::CallMethod[name="forward"](%conv, %input)
644 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
645 return (%second_output) )",
646 {is_conv2d_module, is_functional_relu});
647
648 const PatternInfo nn_conv2d_nn_relu = PatternInfo::parse_from_str(
649 R"(
650 graph(%self, %input, %conv, %relu):
651 %first_output = prim::CallMethod[name="forward"](%conv, %input)
652 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
653 return (%second_output) )",
654 {is_conv2d_module, is_relu_module});
655
656 const PatternInfo nn_conv2d_aten_relu = PatternInfo::parse_from_str(
657 R"(
658 graph(%self, %input, %conv):
659 %first_output = prim::CallMethod[name="forward"](%conv, %input)
660 %second_output = aten::relu(%first_output)
661 return (%second_output) )",
662 {is_conv2d_module});
663
664 const PatternInfo nn_conv2d_aten_relu_ = PatternInfo::parse_from_str(
665 R"(
666 graph(%self, %input, %conv):
667 %first_output = prim::CallMethod[name="forward"](%conv, %input)
668 %second_output = aten::relu_(%first_output)
669 return (%second_output) )",
670 {is_conv2d_module});
671
672 const PatternInfo nn_conv3d_f_relu = PatternInfo::parse_from_str(
673 R"(
674 graph(%self, %input, %conv, %relu, %inplace):
675 %first_output = prim::CallMethod[name="forward"](%conv, %input)
676 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
677 return (%second_output) )",
678 {is_conv3d_module, is_functional_relu});
679
680 const PatternInfo nn_conv3d_nn_relu = PatternInfo::parse_from_str(
681 R"(
682 graph(%self, %input, %conv, %relu):
683 %first_output = prim::CallMethod[name="forward"](%conv, %input)
684 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
685 return (%second_output) )",
686 {is_conv3d_module, is_relu_module});
687
688 const PatternInfo nn_conv3d_aten_relu = PatternInfo::parse_from_str(
689 R"(
690 graph(%self, %conv, %input):
691 %first_output = prim::CallMethod[name="forward"](%conv, %input)
692 %second_output = aten::relu(%first_output)
693 return (%second_output) )",
694 {is_conv3d_module});
695
696 const PatternInfo nn_conv3d_aten_relu_ = PatternInfo::parse_from_str(
697 R"(
698 graph(%self, %input, %conv):
699 %first_output = prim::CallMethod[name="forward"](%conv, %input)
700 %second_output = aten::relu_(%first_output)
701 return (%second_output) )",
702 {is_conv3d_module});
703
704 const PatternInfo add_nn_relu = PatternInfo::parse_from_str(
705 R"(
706 graph(%self, %a, %b, %alpha, %relu):
707 %first_output = aten::add(%a, %b, %alpha)
708 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
709 return (%second_output) )",
710 {aten_add_alpha_is_one, is_relu_module});
711
712 const PatternInfo add_f_relu = PatternInfo::parse_from_str(
713 R"(
714 graph(%self, %a, %b, %alpha, %relu, %inplace):
715 %first_output = aten::add(%a, %b, %alpha)
716 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
717 return (%second_output) )",
718 {aten_add_alpha_is_one, is_functional_relu});
719
720 const PatternInfo inplace_add_nn_relu = PatternInfo::parse_from_str(
721 R"(
722 graph(%self, %a, %b, %alpha, %relu):
723 %first_output = aten::add_(%a, %b, %alpha)
724 %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
725 return (%second_output) )",
726 {aten_add_alpha_is_one, is_relu_module});
727
728 const PatternInfo inplace_add_f_relu = PatternInfo::parse_from_str(
729 R"(
730 graph(%self, %a, %b, %alpha, %relu, %inplace):
731 %first_output = aten::add_(%a, %b, %alpha)
732 %second_output = prim::CallFunction(%relu, %first_output, %inplace)
733 return (%second_output) )",
734 {aten_add_alpha_is_one, is_functional_relu});
735
736 const PatternInfo add_aten_relu = PatternInfo::parse_from_str(R"(
737 graph(%self, %a, %b, %alpha):
738 %first_output = aten::add(%a, %b, %alpha)
739 %second_output = aten::relu(%first_output)
740 return (%second_output) )");
741
742 const PatternInfo add_aten_relu_ = PatternInfo::parse_from_str(R"(
743 graph(%self, %a, %b, %alpha):
744 %first_output = aten::add(%a, %b, %alpha)
745 %second_output = aten::relu_(%first_output)
746 return (%second_output) )");
747
748 const PatternInfo inplace_add_aten_relu = PatternInfo::parse_from_str(R"(
749 graph(%self, %a, %b, %alpha):
750 %first_output = aten::add_(%a, %b, %alpha)
751 %second_output = aten::relu(%first_output)
752 return (%second_output) )");
753
754 const PatternInfo inplace_add_aten_relu_ = PatternInfo::parse_from_str(R"(
755 graph(%self, %a, %b, %alpha):
756 %first_output = aten::add_(%a, %b, %alpha)
757 %second_output = aten::relu_(%first_output)
758 return (%second_output) )");
759
  // nn.BatchNorm2d module followed by an nn.ReLU module. The "forward\\d*"
  // regex also matches compiled method name variants like "forward1".
  const PatternInfo nn_bn2d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
     return (%second_output) )",
      {is_batchnorm2d_module, is_relu_module});

  // nn.BatchNorm2d module followed by functional relu (F.relu call).
  const PatternInfo nn_bn2d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu, %inplace):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
     return (%second_output) )",
      {is_batchnorm2d_module, is_functional_relu});

  // nn.BatchNorm2d module followed by aten::relu.
  const PatternInfo nn_bn2d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = aten::relu(%first_output)
     return (%second_output) )",
      {is_batchnorm2d_module});

  // nn.BatchNorm2d module followed by in-place aten::relu_.
  const PatternInfo nn_bn2d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = aten::relu_(%first_output)
     return (%second_output) )",
      {is_batchnorm2d_module});

  // Same four combinations for nn.BatchNorm3d.
  const PatternInfo nn_bn3d_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
     return (%second_output) )",
      {is_batchnorm3d_module, is_relu_module});

  const PatternInfo nn_bn3d_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm, %relu, %inplace):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
     return (%second_output) )",
      {is_batchnorm3d_module, is_functional_relu});

  const PatternInfo nn_bn3d_aten_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = aten::relu(%first_output)
     return (%second_output) )",
      {is_batchnorm3d_module});

  const PatternInfo nn_bn3d_aten_relu_ = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %batchnorm):
     %first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
     %second_output = aten::relu_(%first_output)
     return (%second_output) )",
      {is_batchnorm3d_module});
823
  // aten::mul followed by an nn.ReLU module.
  const PatternInfo mul_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu):
     %first_output = aten::mul(%a, %b)
     %second_output = prim::CallMethod[name="forward"](%relu, %first_output)
     return (%second_output) )",
      {is_relu_module});

  // aten::mul followed by functional relu.
  const PatternInfo mul_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu, %inplace):
     %first_output = aten::mul(%a, %b)
     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
     return (%second_output) )",
      {is_functional_relu});

  // In-place aten::mul_ followed by an nn.ReLU module.
  const PatternInfo inplace_mul_nn_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu):
     %first_output = aten::mul_(%a, %b)
     %second_output = prim::CallMethod[name="forward"](%relu, %first_output)
     return (%second_output) )",
      {is_relu_module});

  // In-place aten::mul_ followed by functional relu.
  const PatternInfo inplace_mul_f_relu = PatternInfo::parse_from_str(
      R"(
graph(%self, %a, %b, %relu, %inplace):
     %first_output = aten::mul_(%a, %b)
     %second_output = prim::CallFunction(%relu, %first_output, %inplace)
     return (%second_output) )",
      {is_functional_relu});

  // mul/mul_ followed by aten::relu/relu_ (all four combinations).
  const PatternInfo mul_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
     %first_output = aten::mul(%a, %b)
     %second_output = aten::relu(%first_output)
     return (%second_output) )");

  const PatternInfo mul_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
     %first_output = aten::mul(%a, %b)
     %second_output = aten::relu_(%first_output)
     return (%second_output) )");

  const PatternInfo inplace_mul_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
     %first_output = aten::mul_(%a, %b)
     %second_output = aten::relu(%first_output)
     return (%second_output) )");

  const PatternInfo inplace_mul_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
     %first_output = aten::mul_(%a, %b)
     %second_output = aten::relu_(%first_output)
     return (%second_output) )");
879
  // All op+relu patterns whose first output should not be observed directly;
  // addValuesToDelayObservation uses these to populate delay_observation_map_
  // so only the relu output of each fused pair gets an observer.
  const std::vector<std::reference_wrapper<const PatternInfo>> delay_patterns =
      {
          nn_linear_f_relu,      nn_linear_nn_relu,
          nn_linear_aten_relu,   nn_linear_aten_relu_,
          aten_linear_f_relu,    aten_linear_nn_relu,
          aten_linear_aten_relu, aten_linear_aten_relu_,

          nn_conv1d_f_relu,      nn_conv1d_nn_relu,
          nn_conv1d_aten_relu,   nn_conv1d_aten_relu_,
          nn_conv2d_f_relu,      nn_conv2d_nn_relu,
          nn_conv2d_aten_relu,   nn_conv2d_aten_relu_,
          nn_conv3d_f_relu,      nn_conv3d_nn_relu,
          nn_conv3d_aten_relu,   nn_conv3d_aten_relu_,

          add_nn_relu,           add_f_relu,
          inplace_add_nn_relu,   inplace_add_f_relu,
          add_aten_relu,         add_aten_relu_,
          inplace_add_aten_relu, inplace_add_aten_relu_,

          nn_bn2d_nn_relu,       nn_bn2d_f_relu,
          nn_bn2d_aten_relu,     nn_bn2d_aten_relu_,
          nn_bn3d_nn_relu,       nn_bn3d_f_relu,
          nn_bn3d_aten_relu,     nn_bn3d_aten_relu_,

          mul_nn_relu,           mul_f_relu,
          inplace_mul_nn_relu,   inplace_mul_f_relu,
          mul_aten_relu,         mul_aten_relu_,
          inplace_mul_aten_relu, inplace_mul_aten_relu_,
  };
909
  // When true, a reset-min-max method (named below) is generated on each
  // observed module so observer statistics can be cleared between runs.
  bool insert_reset_observer_method_{false};
  // Name of the generated reset method; see insertObserverResetMinMax.
  std::string reset_observer_method_name_;
};
913
getInvokedMethods(Module & module,const std::string & method_name)914 ModuleMethodVector InsertObserversHelper::getInvokedMethods(
915 Module& module,
916 const std::string& method_name) {
917 ModuleMethodVector invoked_methods;
918 Method method = module.get_method(method_name);
919 auto graph = method.graph();
920
921 std::stack<Block*> blocks_to_visit;
922 blocks_to_visit.push(graph->block());
923 while (!blocks_to_visit.empty()) {
924 Block* b = blocks_to_visit.top();
925 blocks_to_visit.pop();
926 for (Node* n : b->nodes()) {
927 // Skip observer nodes
928 if (observer_nodes_.count(n)) {
929 continue;
930 }
931 if (n->kind() == prim::CallMethod) {
932 auto m_opt = getInvokedModuleOpt(module, n, graph->inputs()[0]);
933 if (m_opt.has_value()) {
934 invoked_methods.emplace_back(*m_opt, n->s(attr::name));
935 }
936 }
937
938 for (Block* subblock : n->blocks()) {
939 blocks_to_visit.push(subblock);
940 }
941 }
942 }
943 return invoked_methods;
944 }
945
// Inserts an observer for value `v`: registers a fresh deep copy of
// `observer_module` as a uniquely-named child of `module`, then rewrites the
// graph so all users of `v` consume the output of the observer's forward call
// instead of `v` directly. The (name, module) pair is appended to
// `observer_name_and_modules` for later bookkeeping. No-op if `v` is already
// observed.
void InsertObserversHelper::insertObserverFor(
    Value* v,
    Module& module,
    const Module& observer_module,
    NameModuleVector& observer_name_and_modules) {
  if (observed_values_.count(v)) {
    return;
  }
  GRAPH_DEBUG("Inserting observer for:", v->debugName());
  // Deep copy so each observed value gets its own statistics.
  Module observer = observer_module.deepcopy();
  std::string observer_name = "_observer_" + std::to_string(uid_++);
  // Keep bumping the uid until the attribute name is free on this module.
  while (module.hasattr(observer_name)) {
    observer_name = "_observer_" + std::to_string(uid_++);
  }
  module.register_module(observer_name, observer);
  observer_name_and_modules.emplace_back(observer_name, observer);

  auto* g = v->owningGraph();
  // Get handle of observer module
  Node* observer_instance =
      g->createGetAttr(g->inputs()[0], observer_name)->insertAfter(v->node());
  observer_instance->output()->setDebugName(observer_name);

  {
    WithInsertPoint guard(observer_instance->next());
    // Match arguments to types of observer's arguments
    MatchedSchema forward_matched_schema = matchSchema(
        observer.get_method("forward").function().getSchema(),
        v->node()->sourceRange(),
        *g,
        {observer_instance->output(), v},
        {});
    // Insert call to observer's forward
    Node* call = g->insertMethodCall("forward", forward_matched_schema)->node();
    call->output()->copyMetadata(v);

    // Replace v with the output of observer
    v->replaceAllUsesWith(call->output());
    // The above also replaced the input to `call`, so switch it back to
    // the correct value
    call->replaceInput(1, v);
    observer_nodes_.emplace(call);
    observed_values_.insert(call->output());
  }
}
991
// Adds (or extends) a method on `module` — named reset_observer_method_name_ —
// that calls `reset_min_max_vals` on every observer in
// `observer_name_and_modules`. Used for on-device PTQ so observer statistics
// can be cleared between calibration runs. No-op when the list is empty.
void InsertObserversHelper::insertObserverResetMinMax(
    Module& module,
    const NameModuleVector& observer_name_and_modules) {
  if (observer_name_and_modules.empty()) {
    return;
  }
  auto reset_min_max_opt = module.find_method(reset_observer_method_name_);
  if (!reset_min_max_opt.has_value()) {
    // The method does not exist yet: synthesize an empty
    // `def reset_observers(self) -> None` and attach it to the module type.
    std::shared_ptr<Graph> reset_observer_graph = std::make_shared<Graph>();
    Value* module_value = reset_observer_graph->addInput("self");
    Node* output_node = reset_observer_graph->createNone();
    reset_observer_graph->insertNode(output_node);
    reset_observer_graph->registerOutput(output_node->output());
    module_value->setType(module._ivalue()->type());
    const auto method_name = c10::QualifiedName(
        *(module.type()->name()), reset_observer_method_name_);
    auto reset_observer_fn =
        module._ivalue()->compilation_unit()->create_function(
            method_name, std::move(reset_observer_graph));
    auto self_arg = c10::Argument("self", module.type());
    auto output_arg = c10::Argument("none", output_node->output()->type());
    auto schema = c10::FunctionSchema(
        reset_observer_method_name_,
        "",
        {std::move(self_arg)},
        {std::move(output_arg)});
    reset_observer_fn->setSchema(std::move(schema));
    module.type()->addMethod(reset_observer_fn);
  }
  auto reset_min_max_graph =
      module.get_method(reset_observer_method_name_).graph();
  Value* self = reset_min_max_graph->inputs()[0];

  // Append one `self.<observer>.reset_min_max_vals()` call per observer.
  for (const auto& pair : observer_name_and_modules) {
    const auto& observer_name = pair.first;
    const auto& observer = pair.second;
    Value* observer_value =
        reset_min_max_graph->insertGetAttr(self, observer_name);
    MatchedSchema reset_minmax_schema = matchSchema(
        observer.get_method("reset_min_max_vals").function().getSchema(),
        observer_value->node()->sourceRange(),
        *reset_min_max_graph,
        {observer_value},
        {});
    reset_min_max_graph->insertMethodCall(
        "reset_min_max_vals", reset_minmax_schema);
  }
}
1040
delayObservingValuesInPattern(Graph & graph,const PatternInfo & pattern)1041 void InsertObserversHelper::delayObservingValuesInPattern(
1042 Graph& graph,
1043 const PatternInfo& pattern) {
1044 const Graph& pattern_graph = *pattern.pattern_graph;
1045 const std::unordered_map<std::string, Value*>& vmap = pattern.vmap;
1046
1047 const auto& matches = findPatternMatches(pattern_graph, graph);
1048 for (const auto& match : matches) {
1049 if (!std::all_of(
1050 pattern.filters.begin(),
1051 pattern.filters.end(),
1052 [&](const MatchFilter& f) { return f(match, vmap); })) {
1053 continue;
1054 }
1055 auto first_output = match.values_map.at(vmap.at("first_output"));
1056 auto second_output = match.values_map.at(vmap.at("second_output"));
1057 GRAPH_DEBUG(
1058 "Delay observation for value in function pattern:",
1059 first_output->debugName(),
1060 " to ",
1061 second_output->debugName());
1062 delay_observation_map_[first_output] = second_output;
1063 }
1064 }
1065
addValuesToDelayObservation(const Module & module,const std::string & method_name)1066 void InsertObserversHelper::addValuesToDelayObservation(
1067 const Module& module,
1068 const std::string& method_name) {
1069 Method method = module.get_method(method_name);
1070 auto graph = method.graph();
1071
1072 for (const auto& pattern : delay_patterns) {
1073 delayObservingValuesInPattern(*graph, pattern);
1074 }
1075 }
1076
fillPassThroughValueMap(const std::shared_ptr<Graph> & graph)1077 void InsertObserversHelper::fillPassThroughValueMap(
1078 const std::shared_ptr<Graph>& graph) {
1079 std::stack<Block*> blocks_to_visit;
1080 blocks_to_visit.push(graph->block());
1081 while (!blocks_to_visit.empty()) {
1082 Block* b = blocks_to_visit.top();
1083 blocks_to_visit.pop();
1084 for (Node* n : b->nodes()) {
1085 if (userDefinedCallFunction(n)) {
1086 auto g = getCallFunctionGraph(n);
1087 blocks_to_visit.push(g->block());
1088 }
1089 for (auto* output : n->outputs()) {
1090 for (auto* input : getPassThroughInputs(output)) {
1091 pass_through_value_map_[output].push_back(input);
1092 }
1093 }
1094 for (Block* subblock : n->blocks()) {
1095 blocks_to_visit.push(subblock);
1096 }
1097 }
1098 }
1099 }
1100
// Builds boundary_value_map_: for each call-site value (caller side) the set
// of corresponding values inside the callee graph (and for prim::If, maps
// each node output to the matching output of every subblock). Recurses into
// invoked methods first so callee graphs are mapped before their callers.
void InsertObserversHelper::fillBoundaryValueMap(
    Module& module,
    const std::string& method_name) {
  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_method);
    const auto& invoked_method_name = std::get<1>(invoked_method);
    fillBoundaryValueMap(invoked_module, invoked_method_name);
  }

  auto graph = module.get_method(method_name).graph();
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  auto* self = graph->inputs()[0];
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
        std::shared_ptr<Graph> g;
        // offset of input for the caller node, since the first
        // input of CallFunction is the function node and the graph
        // for CallFunction start with actual input
        size_t input_offset = 0;
        if (n->kind() == prim::CallMethod) {
          auto m_opt = getInvokedModuleOpt(module, n, self);
          if (!m_opt.has_value()) {
            continue;
          }
          auto m = *m_opt;
          g = m.get_method(n->s(attr::name)).graph();
          input_offset = 0;
        } else {
          g = getCallFunctionGraph(n);
          input_offset = 1;
        }
        // add mapping from callsite value to value in called graph
        for (auto i = 0U; i < g->outputs().size(); ++i) {
          auto* return_val = g->outputs()[i];
          GRAPH_DEBUG(
              "Boundary Map[return]:",
              n->output(i)->debugName(),
              " -> ",
              return_val->debugName());
          boundary_value_map_[n->output(i)].insert(return_val);
        }
        // and from each caller argument to the callee's formal input
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto caller_input_index = i + input_offset;
          auto* caller_input = n->input(caller_input_index);
          auto* input_val = g->inputs()[i];
          GRAPH_DEBUG(
              "Boundary Map[input]:",
              caller_input->debugName(),
              " -> ",
              input_val->debugName());
          boundary_value_map_[caller_input].insert(input_val);
        }
      } else if (n->kind() == prim::If) {
        // Each if-output maps to the same-position output of every branch.
        for (Block* subblock : n->blocks()) {
          blocks_to_visit.push(subblock);
          for (Value* v : n->outputs()) {
            Value* subblock_output = subblock->outputs()[v->offset()];
            GRAPH_DEBUG(
                "Boundary Map[if_output]:",
                v->debugName(),
                " -> ",
                subblock_output->debugName());
            boundary_value_map_[v].insert(subblock_output);
          }
        }
      } else {
        for (Block* subblock : n->blocks()) {
          blocks_to_visit.push(subblock);
        }
      }
    }
  }
}
1178
preprocess(Module & module,const std::string & method_name)1179 void InsertObserversHelper::preprocess(
1180 Module& module,
1181 const std::string& method_name) {
1182 // run preprocess for child module before parent, since preprocess
1183 // mutates the graph and it might affect passes like fillBoundaryValueMap
1184 for (auto& invoked_method : getInvokedMethods(module, method_name)) {
1185 auto& invoked_module = std::get<0>(invoked_method);
1186 const auto& invoked_method_name = std::get<1>(invoked_method);
1187 preprocess(invoked_module, invoked_method_name);
1188 }
1189
1190 Method method = module.get_method(method_name);
1191 auto graph = method.graph();
1192 // Inline fork-wait calls
1193 InlineForkWait(graph);
1194 // fuse decomposed linear into aten::linear
1195 FuseLinear(graph);
1196 replaceConvolutionWithAtenConv(graph);
1197 RemoveListMutation(graph);
1198 }
1199
analyze(Module & module,const std::string & method_name)1200 void InsertObserversHelper::analyze(
1201 Module& module,
1202 const std::string& method_name) {
1203 for (auto& invoked_method : getInvokedMethods(module, method_name)) {
1204 auto& invoked_module = std::get<0>(invoked_method);
1205 const auto& invoked_method_name = std::get<1>(invoked_method);
1206 analyze(invoked_module, invoked_method_name);
1207 }
1208
1209 // fill out various internal state which will be later used in
1210 // insertObservers to insert the correct observer
1211 addValuesToDelayObservation(module, method_name);
1212 fillValueObserverMap(module, method_name);
1213 Method method = module.get_method(method_name);
1214 auto graph = method.graph();
1215 fillPassThroughValueMap(graph);
1216 }
1217
valueNeedsToBeQuantized(Value * v,const QConfig & qconfig)1218 bool InsertObserversHelper::valueNeedsToBeQuantized(
1219 Value* v,
1220 const QConfig& qconfig) {
1221 if (isBiasOfConvOrLinear(v) ||
1222 !(v->type()->isSubtypeOf(*TensorType::get()) ||
1223 v->type()->isSubtypeOf(*ListType::ofTensors())) ||
1224 isEmbeddingBagNonInput(v)) {
1225 return false;
1226 }
1227 // For dynamic quantization we only insert observers at the input
1228 // of the quantizable function.
1229 if (quant_type_ == QuantType::STATIC) {
1230 // Check whether producer is quantizable
1231 if (!isWeightOnlyStaticQuantOp(v->node()) &&
1232 (nodeQuantizable(v->node()) || isPropagateQuantOp(v->node()))) {
1233 return true;
1234 }
1235 }
1236 if (quant_type_ == QuantType::DYNAMIC) {
1237 // Check the dtype of the observer module.
1238 Module observer_module = getObserverModuleFor(v, qconfig);
1239 auto scalar_type = observer_module.attr("dtype");
1240 // For inputs with Fp16 type that are not-weights we don't observer them for
1241 // dynamic quantization.
1242 if (scalar_type == at::ScalarType::Half && !isWeight(v)) {
1243 return false;
1244 }
1245 }
1246 // Check whether node input value is quantizable
1247 for (const auto& use : v->uses()) {
1248 if (useQuantizable(use, quant_type_)) {
1249 return true;
1250 }
1251 }
1252 return false;
1253 }
1254
removeActivationObservers()1255 void InsertObserversHelper::removeActivationObservers() {
1256 std::vector<std::unordered_map<Value*, Module>::iterator>
1257 values_to_be_removed;
1258 for (auto it = observer_for_value_.begin(); it != observer_for_value_.end();
1259 it++) {
1260 if (!isWeight(it->first)) {
1261 values_to_be_removed.push_back(it);
1262 }
1263 }
1264 for (auto it : values_to_be_removed) {
1265 observer_for_value_.erase(it);
1266 }
1267 }
1268
fillValueObserverMap(Module & module,const std::string & method_name)1269 void InsertObserversHelper::fillValueObserverMap(
1270 Module& module,
1271 const std::string& method_name) {
1272 Method method = module.get_method(method_name);
1273 auto graph = method.graph();
1274
1275 if (visited_graph_of_observer_map_.count(graph.get())) {
1276 return;
1277 }
1278 visited_graph_of_observer_map_.insert(graph.get());
1279
1280 std::stack<Block*> blocks_to_visit;
1281 auto qconfig_opt = module_qconfig_map_.at(module._ivalue());
1282 if (!qconfig_opt) {
1283 return;
1284 }
1285 auto qconfig = *qconfig_opt;
1286 for (auto* v : graph->inputs()) {
1287 if (valueNeedsToBeQuantized(v, qconfig)) {
1288 GRAPH_DEBUG("Recording observer for ", v->debugName());
1289 GRAPH_DUMP("In graph:", v->owningGraph());
1290 observer_for_value_[v] = getObserverModuleFor(v, qconfig);
1291 }
1292 }
1293
1294 blocks_to_visit.push(graph->block());
1295 while (!blocks_to_visit.empty()) {
1296 Block* b = blocks_to_visit.top();
1297 blocks_to_visit.pop();
1298 for (Node* n : b->nodes()) {
1299 for (Value* v : n->outputs()) {
1300 if (valueNeedsToBeQuantized(v, qconfig)) {
1301 GRAPH_DEBUG("Recording observer for ", v->debugName());
1302 GRAPH_DUMP("In graph:", v->owningGraph());
1303 observer_for_value_[v] = getObserverModuleFor(v, qconfig);
1304 }
1305 }
1306
1307 for (Block* subblock : n->blocks()) {
1308 blocks_to_visit.push(subblock);
1309 }
1310 }
1311 }
1312 }
1313
getObserverFor(Value * v)1314 std::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
1315 if (observer_for_value_.count(v)) {
1316 auto observer = observer_for_value_.at(v);
1317 GRAPH_DEBUG("Got observer module config for:", v->debugName());
1318 return observer;
1319 }
1320 std::optional<Module> result;
1321 if (boundary_value_map_.count(v)) {
1322 for (Value* next : boundary_value_map_.at(v)) {
1323 GRAPH_DEBUG(
1324 "Going through boundary map:",
1325 v->debugName(),
1326 " --> ",
1327 next->debugName());
1328 GRAPH_DUMP("From graph:", v->owningGraph());
1329 GRAPH_DUMP("To graph:", next->owningGraph());
1330 auto observer_opt = getObserverFor(next);
1331 if (observer_opt) {
1332 // Need to make sure all values are
1333 // configured with same observer
1334 if (result) {
1335 TORCH_CHECK(
1336 *observer_opt == *result,
1337 "Expecting all values in the graph only configured with one observer");
1338 } else {
1339 result = observer_opt;
1340 }
1341 }
1342 }
1343 }
1344 GRAPH_DEBUG(
1345 "Observer module config for ", v->debugName(), ":", result.has_value());
1346 return result;
1347 }
1348
1349 std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
insertObservers(Module & module,const std::string & method_name,bool is_entry_point,std::unordered_set<Value * > graph_observed_values)1350 InsertObserversHelper::insertObservers(
1351 Module& module,
1352 const std::string& method_name,
1353 bool is_entry_point,
1354 std::unordered_set<Value*> graph_observed_values) {
1355 auto graph = module.get_method(method_name).graph();
1356 return insertObserversFor(
1357 graph->block(), module, graph_observed_values, is_entry_point);
1358 }
1359
recordObserved(Value * v,const Module & observer_module,std::unordered_map<Value *,Module> & values_to_observe,std::unordered_set<Value * > & block_observed_values)1360 void InsertObserversHelper::recordObserved(
1361 Value* v,
1362 const Module& observer_module,
1363 std::unordered_map<Value*, Module>& values_to_observe,
1364 std::unordered_set<Value*>& block_observed_values) {
1365 Value* to_observe = v;
1366 if (delay_observation_map_.count(v)) {
1367 to_observe = delay_observation_map_.at(v);
1368 }
1369 values_to_observe[to_observe] = observer_module;
1370 block_observed_values.insert(to_observe);
1371 }
1372
// Core of the pass: decides where observers go for one block and inserts
// them. Returns (input observers, output observers, indices of observed
// outputs) so the caller can observe boundary values at the call site —
// a non-entry-point graph may be shared, so its own inputs/outputs are
// never observed inside the graph itself.
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
InsertObserversHelper::insertObserversFor(
    Block* block,
    script::Module& module,
    std::unordered_set<Value*>& block_observed_values,
    bool is_entry_point,
    bool is_user_defined_function) {
  // input/output values, used to skip inserting observers
  // for input and output of the block and the owning graph,
  // we have to insert the observers at call site because
  // the graph itself can be shared
  std::unordered_set<Value*> inputs_outputs;
  // list of observer modules for input values
  std::vector<std::optional<Module>> block_input_observers;
  // list of observer modules for output values
  std::vector<std::optional<Module>> block_output_observers;

  // if the current block is the block for entry point graph(the forward graph
  // of the top level module), we can insert observers in the block directly
  if (!is_entry_point) {
    auto* graph = block->owningGraph();
    // graph inputs/outputs
    for (auto list : {graph->inputs(), graph->outputs()}) {
      for (auto* v : list) {
        inputs_outputs.insert(v);
      }
    }
    // block outputs
    for (auto* v : block->outputs()) {
      inputs_outputs.insert(v);
    }

    for (auto* v : block->inputs()) {
      block_input_observers.emplace_back(getObserverFor(v));
    }

    for (auto* v : block->outputs()) {
      // we need explicitly skip the values that are already observed
      // this might happen in subblocks for `if` since
      // these subblock has access to all values before the `if` node
      if (!isObserved(v, block_observed_values)) {
        block_output_observers.emplace_back(getObserverFor(v));
      } else {
        block_output_observers.emplace_back(std::nullopt);
      }
    }
  }

  // This means the block has been processed before, we just
  // need to attach observer modules and construct the information
  // needed by call site here
  bool visited = block_observer_map_.count(block);
  if (visited) {
    // instance clone of observer module and setAttr
    for (const auto& observer_attrs : block_observer_map_.at(block)) {
      const auto& name = std::get<0>(observer_attrs);
      const auto& observer = std::get<1>(observer_attrs);
      module._ivalue()->setAttr(name, observer.deepcopy()._ivalue());
    }
  }
  // NB: Why do we need to process the graph even if it's visited?
  // Reason is `block_observed_values` can
  // change depending on where the method is called, and
  // outputs that's been observed(third item of the returned result)
  // can change depending on that, so for each graph we'll need to go through
  // the whole process of inserting observers, the observers inserted in this
  // block won't change, but the information we return to the caller will change
  // based on `block_observed_values`

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(block);
  auto* self = block->owningGraph()->inputs()[0];
  // We first construct a map from value to the module, then
  // insert observers for them later, this is to avoid interference
  // of the inserted observers with the analysis to decide where
  // to insert observers, also we only insert observers for
  // "intermediate values" that is not the input/output of the
  // graph
  std::unordered_map<Value*, Module> values_to_observe;

  for (auto* v : block->inputs()) {
    if (!inputs_outputs.count(v) && !values_to_observe.count(v)) {
      if (auto observer_opt = getObserverFor(v)) {
        recordObserved(
            v, *observer_opt, values_to_observe, block_observed_values);
      }
    }
  }
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (observer_nodes_.count(n)) {
        continue;
      }
      if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
        script::Module m;
        std::shared_ptr<Graph> g;
        // CallFunction's first input is the function itself, so callee
        // graph input i corresponds to caller input i + 1 in that case.
        size_t input_offset = 0;
        bool is_udf_for_subblock = is_user_defined_function;
        if (n->kind() == prim::CallMethod) {
          auto m_opt = getInvokedModuleOpt(module, n, self);
          if (!m_opt.has_value()) {
            continue;
          }
          m = *m_opt;
          g = m.get_method(n->s(attr::name)).graph();
          input_offset = 0;
        } else { // CallFunction
          m = module;
          g = getCallFunctionGraph(n);
          input_offset = 1;
          is_udf_for_subblock = true;
        }

        // Translate observed caller arguments into observed callee inputs.
        std::unordered_set<Value*> callee_observed_inputs;
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto* node_input = n->input(i + input_offset);
          if (isObserved(node_input, block_observed_values)) {
            callee_observed_inputs.insert(g->inputs()[i]);
          }
        }
        auto* subblock = g->block();
        auto info_from_callee = insertObserversFor(
            subblock, m, callee_observed_inputs, false, is_udf_for_subblock);
        auto input_observers = std::get<0>(info_from_callee);
        auto output_observers = std::get<1>(info_from_callee);
        auto callee_observed_outputs = std::get<2>(info_from_callee);
        for (auto idx : callee_observed_outputs) {
          block_observed_values.insert(n->outputs()[idx]);
        }
        // Observe the callee's boundary values here, at the call site.
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto* node_input = n->input(i + input_offset);
          if (input_observers[i] && !inputs_outputs.count(node_input) &&
              !isObserved(node_input, block_observed_values)) {
            recordObserved(
                node_input,
                *input_observers[i],
                values_to_observe,
                block_observed_values);
          }
        }
        for (auto i = 0U; i < n->outputs().size(); ++i) {
          if (output_observers[i] && !inputs_outputs.count(n->output(i)) &&
              !isObserved(n->output(i), block_observed_values)) {
            recordObserved(
                n->output(i),
                *output_observers[i],
                values_to_observe,
                block_observed_values);
          }
        }
      } else if (n->kind() == prim::If) {
        // a vector recoding whether each output is observed or not
        std::vector<bool> aggregated_output_observe_state;
        for (Block* subblock : n->blocks()) {
          if (alwaysRaisesException(subblock)) {
            continue;
          }
          // subblock has access to all the values in the scope of prim::If,
          // so subblock_observed_values == block_observed_values
          auto info_from_subblock =
              insertObserversFor(subblock, module, block_observed_values);
          // subblock for prim::If doesn't have inputs
          auto output_observers = std::get<1>(info_from_subblock);
          auto subblock_observed_outputs = std::get<2>(info_from_subblock);

          // We'll insert output observer for each subblock, and in the end
          // we will check if output of subblocks are quantized consistently
          for (size_t i = 0; i < subblock->outputs().size(); ++i) {
            Value* output = subblock->outputs()[i];
            if (output_observers[i] && !inputs_outputs.count(output) &&
                !isObserved(output, block_observed_values)) {
              recordObserved(
                  output,
                  *output_observers[i],
                  values_to_observe,
                  block_observed_values);
            }
          }
          for (auto idx : subblock_observed_outputs) {
            block_observed_values.insert(subblock->outputs()[idx]);
          }
          std::vector<bool> subblock_output_observe_state;
          for (size_t i = 0; i < subblock->outputs().size(); ++i) {
            Value* output = subblock->outputs()[i];
            subblock_output_observe_state.push_back(
                isObserved(output, block_observed_values));
          }
          if (!aggregated_output_observe_state.empty()) {
            TORCH_CHECK(
                aggregated_output_observe_state ==
                    subblock_output_observe_state,
                "branches for `if` should return values that are observed "
                "consistently, if node:",
                *n);
          } else {
            aggregated_output_observe_state = subblock_output_observe_state;
          }
        }
        // mark the output of if as observed
        for (size_t i = 0; i < n->outputs().size(); ++i) {
          if (aggregated_output_observe_state[i]) {
            block_observed_values.insert(n->output(i));
          }
        }
      } else if (n->kind() == prim::Loop) {
        TORCH_WARN_ONCE(
            "prim::Loop is not yet supported in quantization, "
            "please make sure nothing needs to be quantized in the "
            "loop");
      }
      for (Value* v : n->outputs()) {
        propagateObservedProperty(v, block_observed_values);
        if (!inputs_outputs.count(v) && !isObserved(v, block_observed_values)) {
          auto observer_opt = getObserverFor(v);
          // If the node is one of the propagate quant node, e.g.
          // aten::cat, we should observe its output only
          // if the input of the node is observed
          if (observer_opt &&
              shouldObserve(n, block_observed_values, quant_type_)) {
            recordObserved(
                v, *observer_opt, values_to_observe, block_observed_values);
          }
        }
      }
    }
  }
  // Collect which block outputs ended up observed for the caller.
  std::vector<size_t> output_idxs;
  for (auto i = 0U; i < block->outputs().size(); ++i) {
    if (isObserved(block->outputs()[i], block_observed_values)) {
      output_idxs.push_back(i);
    }
  }
  if (!visited) {
    NameModuleVector observer_name_and_modules;
    for (const auto& item : values_to_observe) {
      auto* v = item.first;
      auto observer = item.second;
      TORCH_CHECK(
          !is_user_defined_function,
          "Inserting observers for user defined functions is not "
          "supported right now");
      insertObserverFor(v, module, observer, observer_name_and_modules);
    }
    if (insertResetObserverMethod()) {
      insertObserverResetMinMax(module, observer_name_and_modules);
    }
    block_observer_map_[block] = observer_name_and_modules;
  }
  return std::make_tuple(
      block_input_observers, block_output_observers, output_idxs);
}
1626
propagateObservedProperty(Value * output,std::unordered_set<Value * > & block_observed_values)1627 void InsertObserversHelper::propagateObservedProperty(
1628 Value* output,
1629 std::unordered_set<Value*>& block_observed_values) {
1630 if (pass_through_value_map_.count(output)) {
1631 // since the vector is always non-empty, we will
1632 // not return the initial value
1633 bool all_observed = true;
1634 for (Value* v : pass_through_value_map_.at(output)) {
1635 all_observed &=
1636 observed_values_.count(v) || block_observed_values.count(v);
1637 }
1638 if (all_observed) {
1639 GRAPH_DEBUG("Pass through observed property in node:", *output->node());
1640 // This is to propagate observed property through
1641 // all ops that doesn't require observation
1642 block_observed_values.insert(output);
1643 }
1644 }
1645 }
1646
1647 } // namespace
1648
InsertObservers(Module & input_module,const std::string & method_name,const QConfigDict & qconfig_dict,bool inplace,QuantType quant_type)1649 Module InsertObservers(
1650 Module& input_module,
1651 const std::string& method_name,
1652 const QConfigDict& qconfig_dict,
1653 bool inplace,
1654 QuantType quant_type) {
1655 ModuleQConfigMap map_before_clone;
1656 fillQConfigMap(input_module, qconfig_dict, map_before_clone);
1657 ModuleCloneHelper mh;
1658 Module module = mh.clone(input_module, map_before_clone, inplace);
1659 SwapFunctionalLinear(module);
1660 ModuleQConfigMap module_qconfig_map;
1661 // Since the types are changed after clone, we need to fill
1662 // the qconfig map again
1663 fillQConfigMap(module, qconfig_dict, module_qconfig_map);
1664 GRAPH_DEBUG("Quant type:", quant_type);
1665 InsertObserversHelper helper(module_qconfig_map, quant_type);
1666 helper.preprocess(module, method_name);
1667 helper.fillBoundaryValueMap(module, method_name);
1668 // analyze needs to run after fillBoundaryValueMap
1669 // since we need to know the boundary value mapping to trace
1670 // through the calls
1671 helper.analyze(module, method_name);
1672 helper.insertObservers(module, method_name, /* is_entry_point */ true);
1673 return module;
1674 }
1675
InsertObserversForOnDevicePTQ(Module & input_module,const std::string & method_name,const QConfigDict & qconfig_dict,bool inplace,QuantType quant_type)1676 Module InsertObserversForOnDevicePTQ(
1677 Module& input_module,
1678 const std::string& method_name,
1679 const QConfigDict& qconfig_dict,
1680 bool inplace,
1681 QuantType quant_type) {
1682 ModuleQConfigMap map_before_clone;
1683 fillQConfigMap(input_module, qconfig_dict, map_before_clone);
1684 ModuleCloneHelper mh;
1685 Module cloned_module = mh.clone(input_module, map_before_clone, inplace);
1686 std::shared_ptr<Graph> g = cloned_module.get_method(method_name).graph();
1687 SwapFunctionalLinear(g);
1688 std::string observer_method_name = "observe_" + method_name;
1689 cloneMethod(cloned_module, method_name, observer_method_name);
1690 ModuleQConfigMap module_qconfig_map;
1691 // Since the types are changed after clone, we need to fill
1692 // the qconfig map again
1693 fillQConfigMap(cloned_module, qconfig_dict, module_qconfig_map);
1694 GRAPH_DEBUG("Quant type:", quant_type);
1695 InsertObserversHelper helper(module_qconfig_map, quant_type);
1696 // Removes list mutation part is not clear. Is it needed
1697 helper.preprocess(cloned_module, observer_method_name);
1698 // Since we expect the graph to be inlined this should not have any use
1699 // However, this function does handle if blocks
1700 // Although as far as I understood If blocks are not really handled
1701 // in JIT quantization. Should we just protect against this. That is if we
1702 // find observable value inside If block? Also side effect of inlining is that
1703 // you will have multiple getattrs for the same attribute and thus potentially
1704 // multiple observers observing the same value. This will also lead to
1705 // increased size of the packed param struct. I dont expect this to be a
1706 // common pattern but something to be aware fo Note that current quant
1707 // workflow does not prevent this anyway since during inset quant dequant
1708 // things are inlined anyway
1709 helper.fillBoundaryValueMap(cloned_module, observer_method_name);
1710 // analyze needs to run after fillBoundaryValueMap
1711 // since we need to know the boundary value mapping to trace
1712 // through the calls
1713 helper.analyze(cloned_module, observer_method_name);
1714 // Remove activation observer if quant_type is dynamic
1715 if (quant_type == QuantType::DYNAMIC) {
1716 helper.removeActivationObservers();
1717 }
1718 helper.setInsertResetObserverMethod(true, method_name);
1719 helper.insertObservers(
1720 cloned_module, observer_method_name, /* is_entry_point */ true);
1721 return cloned_module;
1722 }
1723 } // namespace jit
1724 } // namespace torch
1725