# Owner(s): ["oncall: package/deploy"]

# `import importlib` alone does not guarantee that the `machinery` and `util`
# submodules are loaded; test_broken_dependency uses both, so import them
# explicitly (this still binds the name `importlib`).
import importlib.machinery
import importlib.util
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf

import torch.nn
from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    - mock()
    - extern()
    - deny()
    - intern()
    """

    def test_extern(self):
        """Externed modules are shared with the host environment.

        `module_a` and `package_a.subpackage` are externed, so importing them
        through the PackageImporter yields the host's module objects. Their
        parent `package_a` is not externed, so the importer's `package_a` is a
        different object from the host's.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.save_source_string("foo", "import package_a.subpackage; import module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Externing via glob patterns shares matched modules with the host.

        `package_a` itself is saved (interned) by `save_module`, while the
        modules matched by the globs are externed.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()

        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Attributes of a mocked module raise NotImplementedError when called."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Mocking via glob patterns stubs out the matched modules."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Pickling an object whose class lives in a mocked module is a PackagingError."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as he:
                he.mock(include="package_a.subpackage")
                he.intern("**")
                he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked_all(self):
        """Interning the pickled objects' package and mocking everything else exports cleanly."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(include="package_a.**")
            he.mock("**")
            he.save_pickle("obj", "obj.pkl", obj2)

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer) as pe:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                pe.extern("foo", allow_empty=False)
                pe.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                pe.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as he:
                he.save_pickle("obj", "obj.pkl", obj2)

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

        # Interning all dependencies should work. Use a fresh buffer so this
        # export does not append to the partial output left by the failed one.
        buffer2 = BytesIO()
        with PackageExporter(buffer2) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """An unpackageable dependency should raise a PackagingError."""

        def create_module(name):
            # Build a module whose __file__ ends in ".so" so the exporter
            # treats it as a C extension module.
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            def __init__(self) -> None:
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()

    def test_externing_c_extension(self):
        """Externing c extensions modules should allow us to still access them especially those found in torch._C."""

        buffer = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as e:
            e.extern("torch.**")
            e.intern("**")

            e.save_pickle("model", "model.pkl", model)

        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")


if __name__ == "__main__":
    run_tests()