#include <torch/csrc/jit/serialization/import_source.h>

#include <ATen/core/ivalue_inl.h>
#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/custom_class.h>

#include <regex>

namespace torch::jit {

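// Resolves the `ops` name in serialized TorchScript source. For example
// (illustrative), an expression such as `ops.quantized.conv2d_prepack(...)`
// first resolves `ops` to this value; attr() below then returns a
// BuiltinModule for the `quantized` namespace at the stored operator-set
// version.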
struct OpsValue : public SugaredValue {
  OpsValue(size_t version) : version_(version) {}
  std::string kind() const override {
    return "ops";
  }
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override {
    return std::make_shared<BuiltinModule>(field, version_);
  }
  size_t version_;
};

// Represents nested namespaces, like `foo.bar.Baz`.
// Right now these namespaces can only contain other namespaces or NamedTypes
struct TORCH_API ClassNamespaceValue : public SugaredValue {
  /**
   * @param  name  The fully qualified path, which can resolve either to a
   *               namespace or a NamedType
   * @param  si    The source importer that searches for and loads
   *               classes/functions.
   */
  explicit ClassNamespaceValue(
      c10::QualifiedName name,
      std::shared_ptr<SourceImporterImpl> si)
      : basename_(std::move(name)), si_(std::move(si)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& name) override;
  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
  std::shared_ptr<SourceImporterImpl> si_;
};

// This value maps attributes CONSTANTS.c0, CONSTANTS.c1, etc. to entries
// in the 'constants' vector. This table will be stored in a container format
// and given to the import_method when restoring the code.
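// For example (illustrative), serialized code such as
//   _0 = torch.add(x, CONSTANTS.c0)
// resolves CONSTANTS.c0 to constants_[0] via attr() below.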
struct ConstantTableValue : public SugaredValue {
  explicit ConstantTableValue(const std::vector<at::IValue>* constants)
      : constants_(constants) {}
  std::string kind() const override {
    return "CONSTANTS";
  }
  // select an attribute on it, e.g. `this.field`
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override {
    const char* field_s = field.c_str();
    char* end = nullptr;
    int64_t offset = strtoll(field_s + 1, &end, 10);
    if (field.size() < 2 || *end != 0)
      throw(ErrorReport(loc) << "invalid constant specifier: " << field);
    if (offset < 0 || size_t(offset) >= constants_->size()) {
      throw(
          ErrorReport(loc) << "constant index " << offset
                           << " is out of bounds (constant table has "
                           << constants_->size() << " entries)");
    }
    auto ivalue = constants_->at(offset);
    Value* value = nullptr;

    // see [Constant Object Weak CompilationUnit Reference]
    if (ivalue.isObject() && !ivalue.toObject()->is_weak_compilation_ref()) {
      auto obj = ivalue.toObject();
      if (!non_holding_object_cache.count(obj)) {
        non_holding_object_cache[obj] = obj->copy_to_weak_compilation_ref();
      }
      value = m.graph()->insertConstant(non_holding_object_cache[obj], loc);
    } else {
      value = m.graph()->insertConstant(constants_->at(offset), loc);
    }

    // specializing tensor type on compilation messes up typing relations
    value->setType(unshapedType(value->type()));

    return std::make_shared<SimpleValue>(value);
  }

 private:
  std::unordered_map<
      c10::intrusive_ptr<at::ivalue::Object>,
      c10::intrusive_ptr<at::ivalue::Object>>
      non_holding_object_cache;
  const std::vector<at::IValue>* constants_;
};

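// Builds the base environment used to resolve free names while compiling
// imported source. For example (illustrative), a serialized call like
//   torch.add(x, y)
// resolves `torch` through env_ to a BuiltinModule for the `aten` namespace.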
SourceImporterImpl::SourceImporterImpl(
    std::shared_ptr<CompilationUnit> cu,
    const std::vector<at::IValue>* constant_table,
    SourceLoader source_loader,
    size_t version)
    : cu_(std::move(cu)),
      source_loader_(std::move(source_loader)),
      version_(version) {
  env_ = {
      {"torch", std::make_shared<BuiltinModule>("aten", version)},
      {"ops", std::make_shared<OpsValue>(version)},
      // Constants present in the model. Used to resolve "CONSTANTS.n" to the
      // actual value
      {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
      {"fork", SpecialFormValue::create(prim::fork)},
      {"awaitable", SpecialFormValue::create(prim::awaitable)},
      {"annotate", SpecialFormValue::create(prim::annotate)},
      {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)},
      {"uninitialized", SpecialFormValue::create(prim::Uninitialized)},
  };
}

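// Looks up a NamedType, compiling its source on demand. For example
// (illustrative), a lookup of `__torch__.foo.Bar` first loads and parses the
// source for the qualifier `__torch__.foo`, then imports the pending
// ClassDef for `Bar` if one was found there.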
TypePtr SourceImporterImpl::findNamedType(const QualifiedName& name) {
  if (auto custom_class = getCustomClass(name.qualifiedName())) {
    return custom_class;
  }
  parseSourceIfNeeded(name.prefix());
  auto it = to_be_defined_.find(name);
  if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) {
    ClassDef cd(std::move(it->second));
    to_be_defined_.erase(it);
    importNamedType(name.prefix(), cd);
  }
  return cu_->get_type(name);
}

Function* SourceImporterImpl::findFunction(const QualifiedName& name) {
  parseSourceIfNeeded(name.prefix());
  auto it = to_be_defined_.find(name);
  if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) {
    Def d(it->second);
    to_be_defined_.erase(it);
    importFunction(name.prefix(), d);
  }
  return cu_->find_function(name);
}

void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
  // The qualifier may be blank, e.g. when checking if __torch__ is a class.
  if (qualifier.empty() || loaded_sources_.count(qualifier)) {
    return;
  }
  loaded_sources_.insert(qualifier);
  std::shared_ptr<Source> src = source_loader_(qualifier);

  // When looking for classes/functions, the importer doesn't know whether
  // 'foo' contains definitions or is merely a prefix of 'foo.bar'; we only
  // figure that out by testing whether `foo.py` exists in the source loader.
  // If it doesn't, there is nothing to load here.
  if (!src) {
    return;
  }
  Parser p(src);
  parsePossibleVersionNumber(p.lexer());

  auto& L = p.lexer();

  while (L.cur().kind != TK_EOF) {
    parseImports(L);
    auto tk = L.cur();
    auto kind = tk.kind;
    switch (kind) {
      case TK_CLASS_DEF: {
        auto parsed_treeref = ClassDef(p.parseClass());
        to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] =
            parsed_treeref;
      } break;
      case TK_DEF: {
        auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false));
        to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] =
            parsed_treeref;
      } break;
      default:
        throw(
            ErrorReport(L.cur().range)
            << "Unexpected token in code import: " << kindToString(kind));
    }
  }
}

void SourceImporterImpl::LEGACY_import_methods(
    const Module& mod,
    const std::shared_ptr<Source>& src) {
  auto self = SimpleSelf(mod.type());
  c10::QualifiedName prefix = *mod.type()->name();
  Parser p(src);

  parsePossibleVersionNumber(p.lexer());

  parseImports(p.lexer());

  std::vector<Def> definitions;
  std::vector<ResolverPtr> resolvers;
  while (p.lexer().cur().kind != TK_EOF) {
    auto def = Def(p.parseFunction(/*is_method=*/true));
    definitions.emplace_back(def);
    resolvers.emplace_back(shared_from_this());
  }
  cu_->define(
      prefix,
      /*properties=*/{},
      /*propResolvers=*/{},
      definitions,
      resolvers,
      &self);
}

std::shared_ptr<SugaredValue> SourceImporterImpl::resolveValue(
    const std::string& name,
    GraphFunction& m,
    const SourceRange& loc) {
  auto it = env_.find(name);
  if (it != env_.end()) {
    return it->second;
  }
  auto graph = m.graph();
  if (name == "inf") {
    return std::make_shared<SimpleValue>(
        graph->insertConstant(std::numeric_limits<double>::infinity(), loc));
  }
  if (name == "nan") {
    return std::make_shared<SimpleValue>(
        graph->insertConstant(std::numeric_limits<double>::quiet_NaN(), loc));
  }
  if (name == "infj") {
    return std::make_shared<SimpleValue>(graph->insertConstant(
        c10::complex<double>(0, std::numeric_limits<double>::infinity()), loc));
  }
  if (name == "nanj") {
    return std::make_shared<SimpleValue>(graph->insertConstant(
        c10::complex<double>(0, std::numeric_limits<double>::quiet_NaN()),
        loc));
  }
  if (name == "__torch__") {
    return std::make_shared<ClassNamespaceValue>(
        c10::QualifiedName(name), shared_from_this());
  }
  return nullptr;
}

TypePtr SourceImporterImpl::resolveType(
    const std::string& name,
    const SourceRange& loc) {
  return findNamedType(QualifiedName(name));
}

void SourceImporterImpl::importFunction(
    const std::string& qualifier,
    const Def& def) {
  std::vector<Def> definitions{def};
  std::vector<ResolverPtr> resolvers{shared_from_this()};
  cu_->define(
      qualifier,
      /*properties=*/{},
      /*propResolvers=*/{},
      definitions,
      resolvers,
      nullptr);
}

void SourceImporterImpl::importNamedType(
    const std::string& qualifier,
    const ClassDef& class_def) {
  const auto qualified_name =
      QualifiedName(QualifiedName(qualifier), class_def.name().name());
  if (!class_def.superclass().present()) {
    return importClass(qualified_name, class_def, /*is_module=*/false);
  }
  const auto& superclass_name = Var(class_def.superclass().get()).name().name();
  if (superclass_name == "Module") {
    importClass(qualified_name, class_def, /*is_module=*/true);
  } else if (superclass_name == "NamedTuple") {
    // NamedTuples have special rules (since they are TupleTypes and not
    // ClassTypes)
    return importNamedTuple(qualified_name, class_def);
  } else if (superclass_name == "Interface") {
    cu_->define_interface(
        qualified_name, class_def, shared_from_this(), /*is_module=*/false);
  } else if (superclass_name == "ModuleInterface") {
    cu_->define_interface(
        qualified_name, class_def, shared_from_this(), /*is_module=*/true);
  } else if (superclass_name == "Enum") {
    importEnum(qualified_name, class_def);
  } else {
    throw(
        ErrorReport(class_def.range())
        << "Torchscript does not support class inheritance.");
  }
}

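// Rewrites the declared type of a few known quantized-module attributes.
// For example (illustrative), an attribute serialized as
//   _packed_params : Tensor
// on __torch__.torch.ao.nn.quantized.modules.linear.Linear is rewritten to
//   _packed_params : __torch__.torch.classes.quantized.LinearPackedParamsBase
// so that it refers to the TorchBind packed-params class instead.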
std::optional<Assign> SourceImporterImpl::
    attributeAssignmentSpecialHandlingHack(
        const QualifiedName& qualified_classname,
        const Assign& assign) {
  struct AttrTypeReplacementDescr {
    std::string attr_name;
    std::string expected_type;
    std::string replacement_type;
  };

  // module demangled qualname -> ReplacementDescr
  static std::unordered_map<std::string, AttrTypeReplacementDescr> replacements{
      {"__torch__.torch.ao.nn.quantized.modules.linear.LinearPackedParams",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
      {"__torch__.torch.ao.nn.quantized.modules.linear.Linear",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
      {"__torch__.torch.ao.nn.quantized.dynamic.modules.linear.Linear",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
      {"__torch__.torch.ao.nn.quantized.modules.conv.Conv2d",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
      {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
      {"__torch__.torch.ao.nn.quantized.modules.conv.Conv3d",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
      {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
      // BC Stuff
      {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
      {"__torch__.torch.nn.quantized.modules.linear.Linear",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.LinearPackedParamsBase"}},
      {"__torch__.torch.nn.quantized.modules.conv.Conv2d",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}},
      {"__torch__.torch.nn.quantized.modules.conv.Conv3d",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}},
      {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear",
       {"_packed_params",
        "Tensor",
        "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}};
  // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
  static std::regex mangle_re("\\.___torch_mangle_\\d+");
  auto demangled_classname =
      std::regex_replace(qualified_classname.qualifiedName(), mangle_re, "");
  if (replacements.count(demangled_classname)) {
    auto lhs = Var(assign.lhs());
    if (!assign.type().present() || assign.type().get().kind() != TK_VAR) {
      return std::nullopt;
    }
    auto type = Var(assign.type().get());

    auto& attr_name = replacements.at(demangled_classname).attr_name;
    auto& expected_type = replacements.at(demangled_classname).expected_type;
    auto& replacement_type =
        replacements.at(demangled_classname).replacement_type;
    if (lhs.name().name() == attr_name && type.name().name() == expected_type) {
      Parser p(std::make_shared<Source>(replacement_type));
      auto typename_expr = p.parseExp();
      auto maybe_typename =
          Maybe<Expr>::create(typename_expr.range(), typename_expr);
      return Assign::create(
          assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename);
    }
  }
  return std::nullopt;
}

void SourceImporterImpl::importClass(
    const QualifiedName& qualified_classname,
    const ClassDef& class_def,
    bool is_module) {
  // BC for TorchBind classes
  //
  // Previously we would serialize TorchBind classes as actual
  // classes with methods that delegate to things in the
  // torch.ops.* namespace. We've switched away from this and
  // now just rely on those classes being present in the binary
  // and emit code for them based on the ClassType in memory.
  //
  // TODO: remove this once we no longer have old TorchBind code
  // in production models
  {
    static QualifiedName torch_classes_qualname("__torch__.torch.classes");
    if (torch_classes_qualname.isPrefixOf(qualified_classname)) {
      return;
    }
  }
  auto class_type = ClassType::create(
      c10::QualifiedName(qualified_classname), cu_, is_module);

  std::vector<Def> methods;
  std::vector<ResolverPtr> method_resolvers;
  std::map<std::string, Def> pre_hook_def_map;
  std::map<std::string, Def> hook_def_map;
  std::map<std::string, ResolverPtr> pre_hook_resolver_map;
  std::map<std::string, ResolverPtr> hook_resolver_map;
  std::vector<Assign> attributes;
  std::vector<Assign> constants;

  // Module-specific: which attrs are parameters?
  std::unordered_set<std::string> parameter_names;
  std::unordered_set<std::string> buffer_names;
  std::unordered_set<std::string> pre_hook_names;
  std::unordered_set<std::string> hook_names;
  // used to keep track of original ordering of hooks and prehooks
  // in case any are called more than once
  std::vector<std::string> pre_hooks_order;
  std::vector<std::string> hooks_order;
  // Process statements, splitting things into attribute and method
  // definitions.
  for (const auto& statement : class_def.body()) {
    switch (statement.kind()) {
      case TK_ASSIGN: {
        const auto assign = Assign(statement);
        auto check_assign_values = [&assign](const std::string& name) {
          TORCH_CHECK(
              assign.rhs().present(),
              "Malformed assignment statement: missing values to assign in ",
              name);
        };
        switch (assign.lhs().kind()) {
          case TK_VAR: {
            const auto name = Var(assign.lhs()).name().name();
            if (name == "__parameters__") {
              // Populate the module parameter list. This is a field that
              // looks like:
              //   __parameters__ = ["foo", "bar", "baz"]
              // which tells us which attributes are module parameters.
              TORCH_INTERNAL_ASSERT(
                  is_module,
                  "Assignments in class body only "
                  "supported on modules right now");
              check_assign_values(name);
              const auto param_list = ListLiteral(assign.rhs().get()).inputs();
              for (const auto& param : param_list) {
                parameter_names.insert(StringLiteral(param).text());
              }
            } else if (name == "__annotations__") {
              // This is to initialize the annotations dict, just ignore.
              continue;
            } else if (name == "__buffers__") {
              TORCH_INTERNAL_ASSERT(
                  is_module, "Buffers only exist on modules at the moment");
              check_assign_values(name);
              const auto buffer_list = ListLiteral(assign.rhs().get()).inputs();
              for (const auto& buffer : buffer_list) {
                buffer_names.insert(StringLiteral(buffer).text());
              }
            } else if (name == "__forward_pre_hooks__") {
              TORCH_INTERNAL_ASSERT(
                  is_module,
                  "Forward pre hooks only exist on modules at the moment");
              check_assign_values(name);
              const auto pre_hook_list =
                  ListLiteral(assign.rhs().get()).inputs();
              for (const auto& pre_hook : pre_hook_list) {
                std::string pre_hook_name = StringLiteral(pre_hook).text();
                pre_hook_names.insert(pre_hook_name);
                pre_hooks_order.emplace_back(pre_hook_name);
              }
            } else if (name == "__forward_hooks__") {
              TORCH_INTERNAL_ASSERT(
                  is_module,
                  "Forward hooks only exist on modules at the moment");
              check_assign_values(name);
              const auto hook_list = ListLiteral(assign.rhs().get()).inputs();
              for (const auto& hook : hook_list) {
                std::string hook_name = StringLiteral(hook).text();
                hook_names.insert(hook_name);
                hooks_order.emplace_back(hook_name);
              }
            } else {
              if (auto fixed_up = attributeAssignmentSpecialHandlingHack(
                      qualified_classname, assign)) {
                attributes.push_back(std::move(*fixed_up));
              } else if (assign.rhs().present()) {
                // This is a constant assignment, of the form:
                //   foo : Final[int] = 3
                constants.push_back(assign);
              } else {
                // This is a regular attribute assignment, of the form:
                //   foo : Tensor
                attributes.push_back(assign);
              }
            }
          } break;
          case TK_SUBSCRIPT: {
            // This is a special attribute assignment where the attribute
            // is not a valid Python identifier. It looks like:
            //    __annotations__["0"] = Tensor
            const auto lhs = Subscript(assign.lhs());
            TORCH_INTERNAL_ASSERT(
                Var(lhs.value()).name().name() == "__annotations__");
            TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1);
            attributes.push_back(assign);
          } break;
          default: {
            TORCH_INTERNAL_ASSERT(
                false,
                "Unexpected statement kind in module metadata: ",
                kindToString(statement.kind()));
          }
        }
      } break;
      case TK_DEF: {
        Def def = Def(statement);
        const auto def_name = def.name().name();
        if (pre_hook_names.find(def_name) != pre_hook_names.end()) {
          pre_hook_def_map.emplace(def_name, def);
          pre_hook_resolver_map.emplace(def_name, shared_from_this());
        } else if (hook_names.find(def_name) != hook_names.end()) {
          hook_def_map.emplace(def_name, def);
          hook_resolver_map.emplace(def_name, shared_from_this());
        } else {
          methods.emplace_back(def);
          method_resolvers.push_back(shared_from_this());
        }
      } break;
      default: {
        TORCH_INTERNAL_ASSERT(
            false,
            "Unexpected statement kind in class body: ",
            kindToString(statement.kind()));
      }
    }
  }

  // Populate class attributes
  ScriptTypeParser type_parser(shared_from_this());
  for (const auto& assign : attributes) {
    // NOLINTNEXTLINE(bugprone-switch-missing-default-case)
    switch (assign.lhs().kind()) {
      case TK_VAR: {
        const auto name = Var(assign.lhs()).name().name();
        TORCH_INTERNAL_ASSERT(name != "__parameters__");
        const auto type = assign.type().present()
            ? type_parser.parseTypeFromExpr(assign.type().get())
            : type_parser.parseTypeFromExpr(assign.rhs().get());
        const bool is_parameter = parameter_names.count(name);
        const bool is_buffer = buffer_names.count(name);
        class_type->addAttribute(name, type, is_parameter, is_buffer);
      } break;
      case TK_SUBSCRIPT: {
        const auto name =
            StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]).text();
        const auto type = assign.type().present()
            ? type_parser.parseTypeFromExpr(assign.type().get())
            : type_parser.parseTypeFromExpr(assign.rhs().get());
        const bool is_parameter = parameter_names.count(name);
        const bool is_buffer = buffer_names.count(name);
        class_type->addAttribute(name, type, is_parameter, is_buffer);
      }
    }
  }

  // Populate class constants
  for (const auto& assign : constants) {
    auto const_val = type_parser.parseClassConstant(assign);
    const auto name = Var(assign.lhs()).name().name();
    class_type->addConstant(name, const_val);
  }

  // Build pre-hook and hook def/resolver pairs.
  // Pairs are deduped in ir_emitter.cpp's CompilationUnit::define_hooks();
  // the ordering here is the call order for hooks.
  std::vector<Def> hooks;
  std::vector<ResolverPtr> hook_resolvers;
  for (const std::string& hook_name : hooks_order) {
    hooks.emplace_back(hook_def_map.find(hook_name)->second);
    hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second);
  }
  std::vector<Def> pre_hooks;
  std::vector<ResolverPtr> pre_hook_resolvers;
  for (const std::string& pre_hook_name : pre_hooks_order) {
    pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second);
    pre_hook_resolvers.push_back(
        pre_hook_resolver_map.find(pre_hook_name)->second);
  }

  cu_->register_type(class_type);
  const auto self = SimpleSelf(class_type);
  // TODO (this will include the version number later)
  cu_->define(
      qualified_classname,
      /*properties=*/{},
      /*propResolvers=*/{},
      methods,
      method_resolvers,
      &self,
      /*shouldMangle=*/false,
      /*operator_set_version=*/version_);
  cu_->define_hooks(
      qualified_classname,
      hooks,
      hook_resolvers,
      pre_hooks,
      pre_hook_resolvers,
      &self);
}

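// Imports an Enum type. A serialized enum class body (illustrative) looks
// like:
//   class Color(Enum):
//     RED = 1
//     GREEN = 2
// Each assignment contributes a (name, value) pair, and all values must
// share one supported type (int, float, or string).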
void SourceImporterImpl::importEnum(
    const QualifiedName& qualified_name,
    const ClassDef& enum_def) {
  std::vector<at::EnumNameValue> names_values;

  TypePtr value_type = nullptr;
  auto set_or_check_type =
      [&value_type](const TypePtr& t, const SourceRange& loc) {
        if (!value_type) {
          value_type = t;
        } else if (value_type != t) {
          throw(
              ErrorReport(loc)
              << "Enum class with varying value types are not supported.");
        }
      };

  for (const auto& statement : enum_def.body()) {
    if (statement.kind() != TK_ASSIGN) {
      throw(
          ErrorReport(statement.range())
          << "Unexpected statement in Enum class body: "
             "only enum attribute definitions are currently supported.");
    }

    const auto assign = Assign(statement);
    const auto name = Var(assign.lhs()).name().name();

    IValue ivalue;
    auto rhs = assign.rhs().get();
    switch (rhs.kind()) {
      case TK_STRINGLITERAL:
        ivalue = IValue(StringLiteral(rhs).text());
        set_or_check_type(StringType::get(), statement.range());
        break;
      case TK_CONST: {
        auto numeric_const = Const(rhs);
        if (numeric_const.isFloatingPoint()) {
          ivalue = IValue(numeric_const.asFloatingPoint());
          set_or_check_type(FloatType::get(), statement.range());
        } else if (numeric_const.isIntegral()) {
          ivalue = IValue(numeric_const.asIntegral());
          set_or_check_type(IntType::get(), statement.range());
        }
        break;
      }
      default:
        throw(
            ErrorReport(rhs.range())
            << "Unsupported enum value type: " << rhs.kind()
            << ". Only Integers, Floats and Strings are supported.");
    }

    names_values.emplace_back(name, ivalue);
  }

  if (!value_type) {
    throw(
        ErrorReport(enum_def.range())
        << "No enum values defined for " << qualified_name.qualifiedName());
  }

  auto enum_type = EnumType::create(
      qualified_name, std::move(value_type), std::move(names_values), cu_);
  cu_->register_type(enum_type);
}

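// Imports a NamedTuple type. A serialized NamedTuple body (illustrative)
// looks like:
//   class Point(NamedTuple):
//     x : Tensor
//     y : int = 0
// Each field must carry a type annotation; a right-hand side, if present,
// is evaluated as the field's default value.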
void SourceImporterImpl::importNamedTuple(
    const QualifiedName& qualified_name,
    const ClassDef& named_tuple_def) {
  ScriptTypeParser type_parser(shared_from_this());
  std::vector<std::string> field_names;
  std::vector<TypePtr> field_types;
  std::vector<IValue> field_defaults;
  for (const auto& statement : named_tuple_def.body()) {
    if (statement.kind() != TK_ASSIGN) {
      throw(
          ErrorReport(statement.range())
          << "Unexpected statement in NamedTuple body: "
             "only attribute annotations are currently supported.");
    }
    const auto assign = Assign(statement);
    TORCH_INTERNAL_ASSERT(assign.type().present());

    auto name = Var(Assign(statement).lhs()).name().name();
    std::optional<IValue> default_val;
    if (assign.rhs().present()) {
      std::vector<IValue> parsed = type_parser.evaluateDefaults(
          assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()});
      TORCH_INTERNAL_ASSERT(parsed.size() == 1);
      default_val = parsed[0];
    }

    auto type = type_parser.parseTypeFromExpr(assign.type().get());

    field_names.emplace_back(std::move(name));
    field_types.emplace_back(std::move(type));
    if (default_val) {
      field_defaults.emplace_back(std::move(*default_val));
    }
  }

  auto tt = TupleType::createNamed(
      qualified_name, field_names, field_types, field_defaults);
  cu_->register_type(tt);
}

void SourceImporterImpl::parsePossibleVersionNumber(Lexer& L) {
  // Older versions of serialization produced an op_version_set string
  // per file. We now just use a single version, which is handled by
  // PyTorchStreamReader. We used to check if op_version_set was _newer_ for
  // forward compatibility reasons, but now that it doesn't exist there can't
  // be a newer one, so we just discard this.
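  // For example (illustrative), an older file might begin with a line like
  //   op_version_set = 6
  // which is consumed and ignored here.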
  if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") {
    auto range = L.cur().range;
    L.next();
    L.expect('=');
    L.expect(TK_NUMBER);
    L.expect(TK_NEWLINE);
  }
}

// Older versions of serialization required import statements and defined
// classes file-at-a-time in import order. The problem is that in Python
// it is possible to construct cyclic dependencies between files even
// when there are none between individual classes. New versions of loading
// just compile class-at-a-time, so we no longer need to follow the import
// order. Future serialization may stop producing the import code.
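// For example (illustrative), a legacy file may start with lines like
//   import __torch__.foo.bar
// which parseImports consumes and discards.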
void SourceImporterImpl::parseImports(Lexer& L) {
  while (L.nextIf(TK_IMPORT)) {
    std::ostringstream s;
    while (L.cur().kind != TK_NEWLINE) {
      s << L.cur().text();
      L.next();
    }
    L.expect(TK_NEWLINE);
  }
}

std::shared_ptr<SugaredValue> ClassNamespaceValue::attr(
    const SourceRange& loc,
    GraphFunction& m,
    const std::string& name) {
  auto fullName = c10::QualifiedName(basename_, name);
  // Could be a ClassType or NamedTuple constructor
  if (auto serializable_type = si_->findNamedType(fullName)) {
    if (auto classType = serializable_type->cast<ClassType>()) {
      return std::make_shared<ClassValue>(classType);
    } else if (auto tupleType = serializable_type->cast<TupleType>()) {
      return std::make_shared<NamedTupleConstructor>(tupleType);
    } else if (auto enumType = serializable_type->cast<EnumType>()) {
      return std::make_shared<SugaredEnumClass>(enumType);
    }
  }

  // Or it could be a free function
  if (auto fn = si_->findFunction(fullName)) {
    return std::make_shared<FunctionValue>(fn);
  }

  // If it's none of those things, assume it's another namespace
  return std::make_shared<ClassNamespaceValue>(std::move(fullName), si_);
}

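// Public wrapper around SourceImporterImpl. Usage sketch (illustrative,
// assuming a CompilationUnit `cu`, constant table `constants`, SourceLoader
// `loader`, and operator-set `version` are already available):
//   SourceImporter importer(cu, &constants, loader, version);
//   TypePtr t = importer.loadType(c10::QualifiedName("__torch__.MyModule"));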
SourceImporter::SourceImporter(
    // The compilation unit that will own the imported source
    std::shared_ptr<CompilationUnit> cu,
    const std::vector<IValue>* constant_table,
    SourceLoader loader,
    size_t version)
    : pImpl(std::make_shared<SourceImporterImpl>(
          std::move(cu),
          constant_table,
          std::move(loader),
          version)) {}

TypePtr SourceImporter::loadType(const QualifiedName& name) const {
  ScriptTypeParser type_parser(pImpl);
  TypePtr t = type_parser.parseType(name.qualifiedName());
  return t;
}

void SourceImporter::LEGACY_import_methods(
    const Module& mod,
    const std::shared_ptr<Source>& src) {
  pImpl->LEGACY_import_methods(mod, src);
}
SourceImporter::~SourceImporter() = default;

} // namespace torch::jit