#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
#
# Tests for torch.utils.model_dump. They dump info for saved TorchScript
# archives, run its command-line entry point (JSON and HTML styles), and
# optionally check the rendered HTML page in a real browser via Selenium.
#
# NOTE(review): part of this file is missing. See the NOTE inside
# TestModelDump.test_main. Restore that section from version control.

import os
import io
import functools
import tempfile
# NOTE(review): `urllib.parse.quote` is used below, but only the `urllib`
# package is imported. `import urllib` does not by itself load the
# `urllib.parse` submodule, so this works only if another module already
# imported it. Consider `import urllib.parse`.
import urllib
import unittest

import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines


class SimpleModel(torch.nn.Module):
    """Small MLP: Linear(16, 64) -> ReLU -> Linear(64, 8) -> ReLU.

    This is the basic model the tests script or trace and then dump.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        # The last dim of `features` must be 16 (layer1's in_features).
        # The tests below use an input of shape (2, 16).
        act = features
        act = self.layer1(act)
        act = self.relu1(act)
        act = self.layer2(act)
        act = self.relu2(act)
        return act


class QuantModel(torch.nn.Module):
    """SimpleModel wrapped between a QuantStub and a DeQuantStub.

    The stubs mark where tensors enter and leave the region that eager-mode
    quantization converts.
    """

    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x


class ModelWithLists(torch.nn.Module):
    """Module whose attributes are Python lists, one of which contains None.

    ``forward`` adds ``rt[0]`` to its input. It also adds ``ot[0]`` when that
    element is not None.
    """

    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        # Mixed list: a tensor and None.
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        arg = arg + self.rt[0]
        o = self.ot[0]
        if o is not None:
            arg = arg + o
        return arg


def webdriver_test(testfunc):
    """Decorator for tests that drive the HTML dump in a real browser.

    The wrapped test is skipped unless the ``RUN_WEBDRIVER`` environment
    variable is exactly ``"1"``. Otherwise it runs once per browser (Firefox,
    then Chrome). Each run happens in its own ``subTest``, and a fresh
    Selenium WebDriver is passed as the first argument after ``self``.

    Selenium is imported inside the wrapper, so it is only needed when these
    tests actually run.
    """
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()

        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        from selenium import webdriver

        for driver in [
            "Firefox",
            "Chrome",
        ]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                testfunc(self, wd, *args, **kwds)
                # NOTE(review): not in a `finally`, so the browser stays open
                # if `testfunc` raises. Also, Selenium's `close()` only closes
                # the current window, while `quit()` ends the whole session.
                wd.close()

    return wrapper


class TestModelDump(TestCase):
    """Tests for torch.utils.model_dump."""

    def needs_resources(self):
        # Hook called at the start of tests that need extra resources; a
        # no-op here.
        # NOTE(review): presumably overridden in other environments to gate
        # those tests. Confirm.
        pass

    def test_inline_skeleton(self):
        # The inlined HTML skeleton must not reference the unpkg.org CDN and
        # must not contain any `src=` attribute.
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        assert "unpkg.org" not in skel
        assert "src=" not in skel

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` to an in-memory archive and run get_model_info on it.

        Args:
            model: a scripted or traced module accepted by ``torch.jit.save``.
            extra_files: optional dict of extra archive files, forwarded to
                ``torch.jit.save`` as ``_extra_files``.
        """
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        assert info is not None

    def open_html_model(self, wd, model, extra_files=None):
        """Build the HTML dump page for ``model`` and load it in browser ``wd``.

        The page comes from ``get_info_and_burn_skeleton``. It is handed to
        the browser as a percent-encoded ``data:text/html`` URL.
        """
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the collapsible page section titled ``name`` and return its body.

        The section is the ``div`` whose ``data-hider-title`` equals ``name``.
        Its caret is clicked unless ``data-shown`` is already ``"true"``. The
        first descendant ``div`` is returned.
        """
        # NOTE(review): the `find_element_by_*` helpers were removed in
        # Selenium 4.3. Newer Selenium needs `find_element(By.XPATH, ...)`
        # etc. Confirm which Selenium version is used.
        container = wd.find_element_by_xpath(f"//div[@data-hider-title='{name}']")
        caret = container.find_element_by_class_name("caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element_by_tag_name("div")
        return content

    def test_scripted_model(self):
        # Dump a model produced by torch.jit.script.
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        # Dump a model produced by torch.jit.trace with a (2, 16) example input.
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        # Run the command-line entry point on a saved model file and check the
        # beginning of both the JSON and the HTML output.
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI. Just skip it.
            # (Likely cause: on Windows a NamedTemporaryFile cannot be reopened
            # by name while it is still open.)
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            # The leading None stands in for argv[0] (presumably ignored by
            # main; confirm).
            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            # NOTE(review): the file is corrupted from this point on. The rest
            # of the regex literal below is missing, along with an unknown
            # amount of following code. The surviving text appears to include:
            #   - the tail of a class (`... -> None:`, `super().__init__()`,
            #     a `forward`);
            #   - a `check_memory(...)` call;
            #   - the `if __name__ == '__main__':` guard.
            # It is kept verbatim and is not valid Python as-is. Restore this
            # section from version control rather than reconstructing it.
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A None: super().__init__() self.w1 = torch.zeros(1, 2) self.w2 = torch.ones(2, 2) def forward(self, arg): return arg * self.w2 + self.w1 check_memory( torch.jit.freeze( torch.jit.script(ComposedModule()).eval(), preserved_attrs=["w1"]), 4 * (2 + 4)) if __name__ == '__main__': run_tests()