# Owner(s): ["module: onnx"] """Unit tests on `torch.onnx.symbolic_helper`.""" import torch from torch.onnx import symbolic_helper from torch.onnx._globals import GLOBALS from torch.testing._internal import common_utils class TestHelperFunctions(common_utils.TestCase): def setUp(self): super().setUp() self._initial_training_mode = GLOBALS.training_mode def tearDown(self): GLOBALS.training_mode = self._initial_training_mode @common_utils.parametrize( "op_train_mode,export_mode", [ common_utils.subtest( [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve" ), common_utils.subtest( [0, torch.onnx.TrainingMode.EVAL], name="modes_match_op_train_mode_0_export_mode_eval", ), common_utils.subtest( [1, torch.onnx.TrainingMode.TRAINING], name="modes_match_op_train_mode_1_export_mode_training", ), ], ) def test_check_training_mode_does_not_warn_when( self, op_train_mode: int, export_mode: torch.onnx.TrainingMode ): GLOBALS.training_mode = export_mode self.assertNotWarn( lambda: symbolic_helper.check_training_mode(op_train_mode, "testop") ) @common_utils.parametrize( "op_train_mode,export_mode", [ common_utils.subtest( [0, torch.onnx.TrainingMode.TRAINING], name="modes_do_not_match_op_train_mode_0_export_mode_training", ), common_utils.subtest( [1, torch.onnx.TrainingMode.EVAL], name="modes_do_not_match_op_train_mode_1_export_mode_eval", ), ], ) def test_check_training_mode_warns_when( self, op_train_mode: int, export_mode: torch.onnx.TrainingMode, ): with self.assertWarnsRegex( UserWarning, f"ONNX export mode is set to {export_mode}" ): GLOBALS.training_mode = export_mode symbolic_helper.check_training_mode(op_train_mode, "testop") common_utils.instantiate_parametrized_tests(TestHelperFunctions) if __name__ == "__main__": common_utils.run_tests()