1 #include <torch/csrc/jit/serialization/import_source.h> 2 3 #include <ATen/core/ivalue_inl.h> 4 #include <ATen/core/qualified_name.h> 5 #include <torch/csrc/jit/frontend/parser.h> 6 #include <torch/csrc/jit/frontend/resolver.h> 7 #include <torch/csrc/jit/frontend/script_type_parser.h> 8 #include <torch/custom_class.h> 9 10 #include <regex> 11 12 namespace torch::jit { 13 14 struct OpsValue : public SugaredValue { OpsValuetorch::jit::OpsValue15 OpsValue(size_t version) : version_(version) {} kindtorch::jit::OpsValue16 std::string kind() const override { 17 return "ops"; 18 } attrtorch::jit::OpsValue19 std::shared_ptr<SugaredValue> attr( 20 const SourceRange& loc, 21 GraphFunction& m, 22 const std::string& field) override { 23 return std::make_shared<BuiltinModule>(field, version_); 24 } 25 size_t version_; 26 }; 27 28 // Represents nested namespaces, like `foo.bar.Baz`. 29 // Right now these namespaces can only contain other namespaces or NamedTypes 30 struct TORCH_API ClassNamespaceValue : public SugaredValue { 31 /** 32 * @param name The fully qualified path, which can resolve either to a 33 * namespace or a NamedType 34 * @param si The source importer that searches for and loads 35 * classes/functions. 36 */ ClassNamespaceValuetorch::jit::ClassNamespaceValue37 explicit ClassNamespaceValue( 38 c10::QualifiedName name, 39 std::shared_ptr<SourceImporterImpl> si) 40 : basename_(std::move(name)), si_(std::move(si)) {} 41 42 std::shared_ptr<SugaredValue> attr( 43 const SourceRange& loc, 44 GraphFunction& m, 45 const std::string& name) override; kindtorch::jit::ClassNamespaceValue46 std::string kind() const override { 47 return "Class Namespace"; 48 } 49 50 private: 51 c10::QualifiedName basename_; 52 std::shared_ptr<SourceImporterImpl> si_; 53 }; 54 55 // This value maps attributes CONSTANTS.c0 CONSTANTS.c1 to entries 56 // in the 'constants' vector. This table is will be stored in a container format 57 // and given to the import_method when restoring the code. 58 struct ConstantTableValue : public SugaredValue { ConstantTableValuetorch::jit::ConstantTableValue59 explicit ConstantTableValue(const std::vector<at::IValue>* constants) 60 : constants_(constants) {} kindtorch::jit::ConstantTableValue61 std::string kind() const override { 62 return "CONSTANTS"; 63 } 64 // select an attribute on it, e.g. `this.field` attrtorch::jit::ConstantTableValue65 std::shared_ptr<SugaredValue> attr( 66 const SourceRange& loc, 67 GraphFunction& m, 68 const std::string& field) override { 69 const char* field_s = field.c_str(); 70 char* end = nullptr; 71 int64_t offset = strtoll(field_s + 1, &end, 10); 72 if (field.size() < 2 || *end != 0) 73 throw(ErrorReport(loc) << "invalid constant specifier: " << field); 74 if (offset < 0 || size_t(offset) >= constants_->size()) { 75 throw( 76 ErrorReport(loc) << "constant index " << offset 77 << " is out of bounds (constant table has " 78 << constants_->size() << " entries)"); 79 } 80 auto ivalue = constants_->at(offset); 81 Value* value = nullptr; 82 83 // see [Constant Object Weak CompilationUnit Reference] 84 if (ivalue.isObject() && !ivalue.toObject()->is_weak_compilation_ref()) { 85 auto obj = ivalue.toObject(); 86 if (!non_holding_object_cache.count(obj)) { 87 non_holding_object_cache[obj] = obj->copy_to_weak_compilation_ref(); 88 } 89 value = m.graph()->insertConstant(non_holding_object_cache[obj], loc); 90 } else { 91 value = m.graph()->insertConstant(constants_->at(offset), loc); 92 } 93 94 // specializing tensor type on compilation messes up typing relations 95 value->setType(unshapedType(value->type())); 96 97 return std::make_shared<SimpleValue>(value); 98 } 99 100 private: 101 std::unordered_map< 102 c10::intrusive_ptr<at::ivalue::Object>, 103 c10::intrusive_ptr<at::ivalue::Object>> 104 non_holding_object_cache; 105 const std::vector<at::IValue>* constants_; 106 }; 107 SourceImporterImpl(std::shared_ptr<CompilationUnit> cu,const std::vector<at::IValue> * constant_table,SourceLoader source_loader,size_t version)108 SourceImporterImpl::SourceImporterImpl( 109 std::shared_ptr<CompilationUnit> cu, 110 const std::vector<at::IValue>* constant_table, 111 SourceLoader source_loader, 112 size_t version) 113 : cu_(std::move(cu)), 114 source_loader_(std::move(source_loader)), 115 version_(version) { 116 env_ = { 117 {"torch", std::make_shared<BuiltinModule>("aten", version)}, 118 {"ops", std::make_shared<OpsValue>(version)}, 119 // Constants present in the model. Used to resolve "CONSTANTS.n" to the 120 // actual value 121 {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)}, 122 {"fork", SpecialFormValue::create(prim::fork)}, 123 {"awaitable", SpecialFormValue::create(prim::awaitable)}, 124 {"annotate", SpecialFormValue::create(prim::annotate)}, 125 {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)}, 126 {"uninitialized", SpecialFormValue::create(prim::Uninitialized)}, 127 }; 128 } 129 findNamedType(const QualifiedName & name)130 TypePtr SourceImporterImpl::findNamedType(const QualifiedName& name) { 131 if (auto custom_class = getCustomClass(name.qualifiedName())) { 132 return custom_class; 133 } 134 parseSourceIfNeeded(name.prefix()); 135 auto it = to_be_defined_.find(name); 136 if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) { 137 ClassDef cd(std::move(it->second)); 138 to_be_defined_.erase(it); 139 importNamedType(name.prefix(), cd); 140 } 141 return cu_->get_type(name); 142 } 143 findFunction(const QualifiedName & name)144 Function* SourceImporterImpl::findFunction(const QualifiedName& name) { 145 parseSourceIfNeeded(name.prefix()); 146 auto it = to_be_defined_.find(name); 147 if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) { 148 Def d(it->second); 149 to_be_defined_.erase(it); 150 importFunction(name.prefix(), d); 151 } 152 return cu_->find_function(name); 153 } 154 parseSourceIfNeeded(const std::string & qualifier)155 void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) { 156 // qualifier may be blank, for instance checking if __torch__ is a class. 157 if (qualifier.empty() || loaded_sources_.count(qualifier)) { 158 return; 159 } 160 loaded_sources_.insert(qualifier); 161 std::shared_ptr<Source> src = source_loader_(qualifier); 162 163 // The importer, when looking for classes/functions doesn't know if 'foo' 164 // contains definitions or if it is a prefix of 'foo.bar', we only figure it 165 // out by testing if `foo.py` exists in the source loader. If it doesn't 166 // then there is nothing to load here 167 if (!src) { 168 return; 169 } 170 Parser p(src); 171 parsePossibleVersionNumber(p.lexer()); 172 173 auto& L = p.lexer(); 174 175 while (L.cur().kind != TK_EOF) { 176 parseImports(L); 177 auto tk = L.cur(); 178 auto kind = tk.kind; 179 switch (kind) { 180 case TK_CLASS_DEF: { 181 auto parsed_treeref = ClassDef(p.parseClass()); 182 to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] = 183 parsed_treeref; 184 } break; 185 case TK_DEF: { 186 auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false)); 187 to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] = 188 parsed_treeref; 189 } break; 190 default: 191 throw( 192 ErrorReport(L.cur().range) 193 << "Unexpected token in code import: " << kindToString(kind)); 194 } 195 } 196 } 197 LEGACY_import_methods(const Module & mod,const std::shared_ptr<Source> & src)198 void SourceImporterImpl::LEGACY_import_methods( 199 const Module& mod, 200 const std::shared_ptr<Source>& src) { 201 auto self = SimpleSelf(mod.type()); 202 c10::QualifiedName prefix = *mod.type()->name(); 203 Parser p(src); 204 205 parsePossibleVersionNumber(p.lexer()); 206 207 parseImports(p.lexer()); 208 209 std::vector<Def> definitions; 210 std::vector<ResolverPtr> resolvers; 211 while (p.lexer().cur().kind != TK_EOF) { 212 auto def = Def(p.parseFunction(/*is_method=*/true)); 213 definitions.emplace_back(def); 214 resolvers.emplace_back(shared_from_this()); 215 } 216 cu_->define( 217 prefix, 218 /*properties=*/{}, 219 /*propResolvers=*/{}, 220 definitions, 221 resolvers, 222 &self); 223 } 224 resolveValue(const std::string & name,GraphFunction & m,const SourceRange & loc)225 std::shared_ptr<SugaredValue> SourceImporterImpl::resolveValue( 226 const std::string& name, 227 GraphFunction& m, 228 const SourceRange& loc) { 229 auto it = env_.find(name); 230 if (it != env_.end()) { 231 return it->second; 232 } 233 auto graph = m.graph(); 234 if (name == "inf") { 235 return std::make_shared<SimpleValue>( 236 graph->insertConstant(std::numeric_limits<double>::infinity(), loc)); 237 } 238 if (name == "nan") { 239 return std::make_shared<SimpleValue>( 240 graph->insertConstant(std::numeric_limits<double>::quiet_NaN(), loc)); 241 } 242 if (name == "infj") { 243 return std::make_shared<SimpleValue>(graph->insertConstant( 244 c10::complex<double>(0, std::numeric_limits<double>::infinity()), loc)); 245 } 246 if (name == "nanj") { 247 return std::make_shared<SimpleValue>(graph->insertConstant( 248 c10::complex<double>(0, std::numeric_limits<double>::quiet_NaN()), 249 loc)); 250 } 251 if (name == "__torch__") { 252 return std::make_shared<ClassNamespaceValue>( 253 c10::QualifiedName(name), shared_from_this()); 254 } 255 return nullptr; 256 } 257 resolveType(const std::string & name,const SourceRange & loc)258 TypePtr SourceImporterImpl::resolveType( 259 const std::string& name, 260 const SourceRange& loc) { 261 return findNamedType(QualifiedName(name)); 262 } 263 importFunction(const std::string & qualifier,const Def & def)264 void SourceImporterImpl::importFunction( 265 const std::string& qualifier, 266 const Def& def) { 267 std::vector<Def> definitions{def}; 268 std::vector<ResolverPtr> resolvers{shared_from_this()}; 269 cu_->define( 270 qualifier, 271 /*properties=*/{}, 272 /*propResolvers=*/{}, 273 definitions, 274 resolvers, 275 nullptr); 276 } 277 importNamedType(const std::string & qualifier,const ClassDef & class_def)278 void SourceImporterImpl::importNamedType( 279 const std::string& qualifier, 280 const ClassDef& class_def) { 281 const auto qualified_name = 282 QualifiedName(QualifiedName(qualifier), class_def.name().name()); 283 if (!class_def.superclass().present()) { 284 return importClass(qualified_name, class_def, /*is_module=*/false); 285 } 286 const auto& superclass_name = Var(class_def.superclass().get()).name().name(); 287 if (superclass_name == "Module") { 288 importClass(qualified_name, class_def, /*is_module=*/true); 289 } else if (superclass_name == "NamedTuple") { 290 // NamedTuples have special rules (since they are TupleTypes and not 291 // ClassTypes) 292 return importNamedTuple(qualified_name, class_def); 293 } else if (superclass_name == "Interface") { 294 cu_->define_interface( 295 qualified_name, class_def, shared_from_this(), /*is_module=*/false); 296 } else if (superclass_name == "ModuleInterface") { 297 cu_->define_interface( 298 qualified_name, class_def, shared_from_this(), /*is_module=*/true); 299 } else if (superclass_name == "Enum") { 300 importEnum(qualified_name, class_def); 301 } else { 302 throw( 303 ErrorReport(class_def.range()) 304 << "Torchscript does not support class inheritance."); 305 } 306 } 307 308 std::optional<Assign> SourceImporterImpl:: attributeAssignmentSpecialHandlingHack(const QualifiedName & qualified_classname,const Assign & assign)309 attributeAssignmentSpecialHandlingHack( 310 const QualifiedName& qualified_classname, 311 const Assign& assign) { 312 struct AttrTypeReplacementDescr { 313 std::string attr_name; 314 std::string expected_type; 315 std::string replacement_type; 316 }; 317 318 // module demangled qualname -> ReplacementDescr 319 static std::unordered_map<std::string, AttrTypeReplacementDescr> replacements{ 320 {"__torch__.torch.ao.nn.quantized.modules.linear.LinearPackedParams", 321 {"_packed_params", 322 "Tensor", 323 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, 324 {"__torch__.torch.ao.nn.quantized.modules.linear.Linear", 325 {"_packed_params", 326 "Tensor", 327 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, 328 {"__torch__.torch.ao.nn.quantized.dynamic.modules.linear.Linear", 329 {"_packed_params", 330 "Tensor", 331 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, 332 {"__torch__.torch.ao.nn.quantized.modules.conv.Conv2d", 333 {"_packed_params", 334 "Tensor", 335 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, 336 {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d", 337 {"_packed_params", 338 "Tensor", 339 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, 340 {"__torch__.torch.ao.nn.quantized.modules.conv.Conv3d", 341 {"_packed_params", 342 "Tensor", 343 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}, 344 {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d", 345 {"_packed_params", 346 "Tensor", 347 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}, 348 // BC Stuff 349 {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams", 350 {"_packed_params", 351 "Tensor", 352 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, 353 {"__torch__.torch.nn.quantized.modules.linear.Linear", 354 {"_packed_params", 355 "Tensor", 356 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, 357 {"__torch__.torch.nn.quantized.modules.conv.Conv2d", 358 {"_packed_params", 359 "Tensor", 360 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, 361 {"__torch__.torch.nn.quantized.modules.conv.Conv3d", 362 {"_packed_params", 363 "Tensor", 364 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}, 365 {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear", 366 {"_packed_params", 367 "Tensor", 368 "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}}; 369 // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful 370 static std::regex mangle_re("\\.___torch_mangle_\\d+"); 371 auto demangled_classname = 372 std::regex_replace(qualified_classname.qualifiedName(), mangle_re, ""); 373 if (replacements.count(demangled_classname)) { 374 auto lhs = Var(assign.lhs()); 375 if (!assign.type().present() || assign.type().get().kind() != TK_VAR) { 376 return std::nullopt; 377 } 378 auto type = Var(assign.type().get()); 379 380 auto& attr_name = replacements.at(demangled_classname).attr_name; 381 auto& expected_type = replacements.at(demangled_classname).expected_type; 382 auto& replacement_type = 383 replacements.at(demangled_classname).replacement_type; 384 if (lhs.name().name() == attr_name && type.name().name() == expected_type) { 385 Parser p(std::make_shared<Source>(replacement_type)); 386 auto typename_expr = p.parseExp(); 387 auto maybe_typename = 388 Maybe<Expr>::create(typename_expr.range(), typename_expr); 389 return Assign::create( 390 assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename); 391 } 392 } 393 return std::nullopt; 394 } 395 importClass(const QualifiedName & qualified_classname,const ClassDef & class_def,bool is_module)396 void SourceImporterImpl::importClass( 397 const QualifiedName& qualified_classname, 398 const ClassDef& class_def, 399 bool is_module) { 400 // BC for TorchBind classes 401 // 402 // Previously we would serialize TorchBind classes as actual 403 // classes with methods that delegate to things in the 404 // torch.ops.* namespace. We've switched away from this and 405 // now just rely on those classes being present in the binary 406 // and emit code for them based on the ClassType in memory. 407 // 408 // TODO: remove this once we no longer have old TorchBind code 409 // in production models 410 { 411 static QualifiedName torch_classes_qualname("__torch__.torch.classes"); 412 if (torch_classes_qualname.isPrefixOf(qualified_classname)) { 413 return; 414 } 415 } 416 auto class_type = ClassType::create( 417 c10::QualifiedName(qualified_classname), cu_, is_module); 418 419 std::vector<Def> methods; 420 std::vector<ResolverPtr> method_resolvers; 421 std::map<std::string, Def> pre_hook_def_map; 422 std::map<std::string, Def> hook_def_map; 423 std::map<std::string, ResolverPtr> pre_hook_resolver_map; 424 std::map<std::string, ResolverPtr> hook_resolver_map; 425 std::vector<Assign> attributes; 426 std::vector<Assign> constants; 427 428 // Module-specific: which attrs are parameters? 429 std::unordered_set<std::string> parameter_names; 430 std::unordered_set<std::string> buffer_names; 431 std::unordered_set<std::string> pre_hook_names; 432 std::unordered_set<std::string> hook_names; 433 // used to keep track of original ordering of hooks and prehooks 434 // in case any are called more than once 435 std::vector<std::string> pre_hooks_order; 436 std::vector<std::string> hooks_order; 437 // Process statements, splitting things into attribute and method 438 // definitions. 439 for (const auto& statement : class_def.body()) { 440 switch (statement.kind()) { 441 case TK_ASSIGN: { 442 const auto assign = Assign(statement); 443 auto check_assign_values = [&assign](const std::string& name) { 444 TORCH_CHECK( 445 assign.rhs().present(), 446 "Malformed assignment statement: missing values to assign in ", 447 name); 448 }; 449 switch (assign.lhs().kind()) { 450 case TK_VAR: { 451 const auto name = Var(assign.lhs()).name().name(); 452 if (name == "__parameters__") { 453 // Populate the module parameter list. This is a field that 454 // looks like: 455 // __parameters__ = ["foo", "bar", "baz"] 456 // which tells us which attributes are module parameters. 457 TORCH_INTERNAL_ASSERT( 458 is_module, 459 "Assignments in class body only " 460 "supported on modules right now"); 461 check_assign_values(name); 462 const auto param_list = ListLiteral(assign.rhs().get()).inputs(); 463 for (const auto& param : param_list) { 464 parameter_names.insert(StringLiteral(param).text()); 465 } 466 } else if (name == "__annotations__") { 467 // This is to initialize the annotations dict, just ignore. 468 continue; 469 } else if (name == "__buffers__") { 470 TORCH_INTERNAL_ASSERT( 471 is_module, "Buffers only exist on modules at the moment"); 472 check_assign_values(name); 473 const auto buffer_list = ListLiteral(assign.rhs().get()).inputs(); 474 for (const auto& buffer : buffer_list) { 475 buffer_names.insert(StringLiteral(buffer).text()); 476 } 477 } else if (name == "__forward_pre_hooks__") { 478 TORCH_INTERNAL_ASSERT( 479 is_module, 480 "Forward pre hooks only exist on modules at the moment"); 481 check_assign_values(name); 482 const auto pre_hook_list = 483 ListLiteral(assign.rhs().get()).inputs(); 484 for (const auto& pre_hook : pre_hook_list) { 485 std::string pre_hook_name = StringLiteral(pre_hook).text(); 486 pre_hook_names.insert(pre_hook_name); 487 pre_hooks_order.emplace_back(pre_hook_name); 488 } 489 } else if (name == "__forward_hooks__") { 490 TORCH_INTERNAL_ASSERT( 491 is_module, 492 "Forward hooks only exist on modules at the moment"); 493 check_assign_values(name); 494 const auto hook_list = ListLiteral(assign.rhs().get()).inputs(); 495 for (const auto& hook : hook_list) { 496 std::string hook_name = StringLiteral(hook).text(); 497 hook_names.insert(hook_name); 498 hooks_order.emplace_back(hook_name); 499 } 500 } else { 501 if (auto fixed_up = attributeAssignmentSpecialHandlingHack( 502 qualified_classname, assign)) { 503 attributes.push_back(std::move(*fixed_up)); 504 } else if (assign.rhs().present()) { 505 // This is a constant assignment, of the form: 506 // foo : Final[int] = 3 507 constants.push_back(assign); 508 } else { 509 // This is a regular attribute assignment, of the form: 510 // foo : Tensor 511 attributes.push_back(assign); 512 } 513 } 514 } break; 515 case TK_SUBSCRIPT: { 516 // This is a special attribute assignment where the attribute 517 // is not a valid python, identifier. Looks like: 518 // __annotations__["0"] = Tensor 519 const auto lhs = Subscript(assign.lhs()); 520 TORCH_INTERNAL_ASSERT( 521 Var(lhs.value()).name().name() == "__annotations__"); 522 TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1); 523 attributes.push_back(assign); 524 } break; 525 default: { 526 TORCH_INTERNAL_ASSERT( 527 false, 528 "Unexpected statement kind in module metadata: ", 529 kindToString(statement.kind())); 530 } 531 } 532 } break; 533 case TK_DEF: { 534 Def def = Def(statement); 535 const auto def_name = def.name().name(); 536 if (pre_hook_names.find(def_name) != pre_hook_names.end()) { 537 pre_hook_def_map.emplace(def_name, def); 538 pre_hook_resolver_map.emplace(def_name, shared_from_this()); 539 } else if (hook_names.find(def_name) != hook_names.end()) { 540 hook_def_map.emplace(def_name, def); 541 hook_resolver_map.emplace(def_name, shared_from_this()); 542 } else { 543 methods.emplace_back(def); 544 method_resolvers.push_back(shared_from_this()); 545 } 546 } break; 547 default: { 548 TORCH_INTERNAL_ASSERT( 549 false, 550 "Unexpected statement kind in class body: ", 551 kindToString(statement.kind())); 552 } 553 } 554 } 555 556 // Populate class attributes 557 ScriptTypeParser type_parser(shared_from_this()); 558 for (const auto& assign : attributes) { 559 // NOLINTNEXTLINE(bugprone-switch-missing-default-case) 560 switch (assign.lhs().kind()) { 561 case TK_VAR: { 562 const auto name = Var(assign.lhs()).name().name(); 563 TORCH_INTERNAL_ASSERT(name != "__parameters__"); 564 const auto type = assign.type().present() 565 ? type_parser.parseTypeFromExpr(assign.type().get()) 566 : type_parser.parseTypeFromExpr(assign.rhs().get()); 567 const bool is_parameter = parameter_names.count(name); 568 const bool is_buffer = buffer_names.count(name); 569 class_type->addAttribute(name, type, is_parameter, is_buffer); 570 } break; 571 case TK_SUBSCRIPT: { 572 const auto name = 573 StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]).text(); 574 const auto type = assign.type().present() 575 ? type_parser.parseTypeFromExpr(assign.type().get()) 576 : type_parser.parseTypeFromExpr(assign.rhs().get()); 577 const bool is_parameter = parameter_names.count(name); 578 const bool is_buffer = buffer_names.count(name); 579 class_type->addAttribute(name, type, is_parameter, is_buffer); 580 } 581 } 582 } 583 584 // Populate class constants 585 for (const auto& assign : constants) { 586 auto const_val = type_parser.parseClassConstant(assign); 587 const auto name = Var(assign.lhs()).name().name(); 588 class_type->addConstant(name, const_val); 589 } 590 591 // build pre hook and hook def/resolver pairs 592 // pairs are dedupped in ir_emitter.cpp's CompilationUnit::define_hooks() 593 // ordering here is call order for hooks 594 std::vector<Def> hooks; 595 std::vector<ResolverPtr> hook_resolvers; 596 for (const std::string& hook_name : hooks_order) { 597 hooks.emplace_back(hook_def_map.find(hook_name)->second); 598 hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second); 599 } 600 std::vector<Def> pre_hooks; 601 std::vector<ResolverPtr> pre_hook_resolvers; 602 for (const std::string& pre_hook_name : pre_hooks_order) { 603 pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second); 604 pre_hook_resolvers.push_back( 605 pre_hook_resolver_map.find(pre_hook_name)->second); 606 } 607 608 cu_->register_type(class_type); 609 const auto self = SimpleSelf(class_type); 610 // TODO (this will include the version number later) 611 cu_->define( 612 qualified_classname, 613 /*properties=*/{}, 614 /*propResolvers=*/{}, 615 methods, 616 method_resolvers, 617 &self, 618 /*shouldMangle=*/false, 619 /*operator_set_version=*/version_); 620 cu_->define_hooks( 621 qualified_classname, 622 hooks, 623 hook_resolvers, 624 pre_hooks, 625 pre_hook_resolvers, 626 &self); 627 } 628 importEnum(const QualifiedName & qualified_name,const ClassDef & enum_def)629 void SourceImporterImpl::importEnum( 630 const QualifiedName& qualified_name, 631 const ClassDef& enum_def) { 632 std::vector<at::EnumNameValue> names_values; 633 634 TypePtr value_type = nullptr; 635 auto set_or_check_type = 636 [&value_type](const TypePtr& t, const SourceRange& loc) { 637 if (!value_type) { 638 value_type = t; 639 } else if (value_type != t) { 640 throw( 641 ErrorReport(loc) 642 << "Enum class with varying value types are not supported."); 643 } 644 }; 645 646 for (const auto& statement : enum_def.body()) { 647 if (statement.kind() != TK_ASSIGN) { 648 throw( 649 ErrorReport(statement.range()) 650 << "Unexpected statement in Enum class body: " 651 "only enum attribute definitions are currently supported."); 652 } 653 654 const auto assign = Assign(statement); 655 const auto name = Var(assign.lhs()).name().name(); 656 657 IValue ivalue; 658 auto rhs = assign.rhs().get(); 659 switch (rhs.kind()) { 660 case TK_STRINGLITERAL: 661 ivalue = IValue(StringLiteral(rhs).text()); 662 set_or_check_type(StringType::get(), statement.range()); 663 break; 664 case TK_CONST: { 665 auto numeric_const = Const(rhs); 666 if (numeric_const.isFloatingPoint()) { 667 ivalue = IValue(numeric_const.asFloatingPoint()); 668 set_or_check_type(FloatType::get(), statement.range()); 669 } else if (numeric_const.isIntegral()) { 670 ivalue = IValue(numeric_const.asIntegral()); 671 set_or_check_type(IntType::get(), statement.range()); 672 } 673 break; 674 } 675 default: 676 throw( 677 ErrorReport(rhs.range()) 678 << "Unsupported enum value type: " << rhs.kind() 679 << ". Only Integers, Floats and Strings are supported."); 680 } 681 682 names_values.emplace_back(name, ivalue); 683 } 684 685 if (!value_type) { 686 throw( 687 ErrorReport(enum_def.range()) 688 << "No enum values defined for " << qualified_name.qualifiedName()); 689 } 690 691 auto enum_type = EnumType::create( 692 qualified_name, std::move(value_type), std::move(names_values), cu_); 693 cu_->register_type(enum_type); 694 } 695 importNamedTuple(const QualifiedName & qualified_name,const ClassDef & named_tuple_def)696 void SourceImporterImpl::importNamedTuple( 697 const QualifiedName& qualified_name, 698 const ClassDef& named_tuple_def) { 699 ScriptTypeParser type_parser(shared_from_this()); 700 std::vector<std::string> field_names; 701 std::vector<TypePtr> field_types; 702 std::vector<IValue> field_defaults; 703 for (const auto& statement : named_tuple_def.body()) { 704 if (statement.kind() != TK_ASSIGN) { 705 throw( 706 ErrorReport(statement.range()) 707 << "Unexpected statement in NamedTuple body: " 708 "only attribute annotations are currently supported."); 709 } 710 const auto assign = Assign(statement); 711 TORCH_INTERNAL_ASSERT(assign.type().present()); 712 713 auto name = Var(Assign(statement).lhs()).name().name(); 714 std::optional<IValue> default_val; 715 if (assign.rhs().present()) { 716 std::vector<IValue> parsed = type_parser.evaluateDefaults( 717 assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()}); 718 TORCH_INTERNAL_ASSERT(parsed.size() == 1); 719 default_val = parsed[0]; 720 } 721 722 auto type = type_parser.parseTypeFromExpr(assign.type().get()); 723 724 field_names.emplace_back(std::move(name)); 725 field_types.emplace_back(std::move(type)); 726 if (default_val) { 727 field_defaults.emplace_back(std::move(*default_val)); 728 } 729 } 730 731 auto tt = TupleType::createNamed( 732 qualified_name, field_names, field_types, field_defaults); 733 cu_->register_type(tt); 734 } 735 parsePossibleVersionNumber(Lexer & L)736 void SourceImporterImpl::parsePossibleVersionNumber(Lexer& L) { 737 // Older versions of serialization produced an op_version_set string 738 // per-file We now just use a single version which is handled by 739 // PyTorchStreamReader. We used to check if op_version_set was _newer_ for 740 // forward compatibility reasons but now that it doesn't exist there can't 741 // be a newer one, so we just discard this. 742 if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") { 743 auto range = L.cur().range; 744 L.next(); 745 L.expect('='); 746 L.expect(TK_NUMBER); 747 L.expect(TK_NEWLINE); 748 } 749 } 750 751 // older versions of serialization required import statements, 752 // and defined classes file-at-a-time in import order. 753 // The problem is that in Python 754 // it is possible to construct cyclic dependencies between files even 755 // when there are none between individual classes. New versions of loading 756 // just compile class-at-a-time, so we no longer need to follow the import 757 // order. Future serialization may stop producing the import code. parseImports(Lexer & L)758 void SourceImporterImpl::parseImports(Lexer& L) { 759 while (L.nextIf(TK_IMPORT)) { 760 std::ostringstream s; 761 while (L.cur().kind != TK_NEWLINE) { 762 s << L.cur().text(); 763 L.next(); 764 } 765 L.expect(TK_NEWLINE); 766 } 767 } 768 attr(const SourceRange & loc,GraphFunction & m,const std::string & name)769 std::shared_ptr<SugaredValue> ClassNamespaceValue::attr( 770 const SourceRange& loc, 771 GraphFunction& m, 772 const std::string& name) { 773 auto fullName = c10::QualifiedName(basename_, name); 774 // Could be a ClassType or NamedTuple constructor 775 if (auto serializable_type = si_->findNamedType(fullName)) { 776 if (auto classType = serializable_type->cast<ClassType>()) { 777 return std::make_shared<ClassValue>(classType); 778 } else if (auto tupleType = serializable_type->cast<TupleType>()) { 779 return std::make_shared<NamedTupleConstructor>(tupleType); 780 } else if (auto enumType = serializable_type->cast<EnumType>()) { 781 return std::make_shared<SugaredEnumClass>(enumType); 782 } 783 } 784 785 // Or it could be a free function 786 if (auto fn = si_->findFunction(fullName)) { 787 return std::make_shared<FunctionValue>(fn); 788 } 789 790 // If it's none of those things, assume it's another namespace 791 return std::make_shared<ClassNamespaceValue>(std::move(fullName), si_); 792 } 793 SourceImporter(std::shared_ptr<CompilationUnit> cu,const std::vector<IValue> * constant_table,SourceLoader loader,size_t version)794 SourceImporter::SourceImporter( 795 // The compilation unit that will own the imported source 796 std::shared_ptr<CompilationUnit> cu, 797 const std::vector<IValue>* constant_table, 798 SourceLoader loader, 799 size_t version) 800 : pImpl(std::make_shared<SourceImporterImpl>( 801 std::move(cu), 802 constant_table, 803 std::move(loader), 804 version)) {} 805 loadType(const QualifiedName & name) const806 TypePtr SourceImporter::loadType(const QualifiedName& name) const { 807 ScriptTypeParser type_parser(pImpl); 808 TypePtr t = type_parser.parseType(name.qualifiedName()); 809 return t; 810 } 811 LEGACY_import_methods(const Module & mod,const std::shared_ptr<Source> & src)812 void SourceImporter::LEGACY_import_methods( 813 const Module& mod, 814 const std::shared_ptr<Source>& src) { 815 pImpl->LEGACY_import_methods(mod, src); 816 } 817 SourceImporter::~SourceImporter() = default; 818 819 } // namespace torch::jit 820