1 #!/usr/bin/env python3 2 # Owner(s): ["oncall: distributed"] 3 4 # Copyright (c) Facebook, Inc. and its affiliates. 5 # All rights reserved. 6 # 7 # This source code is licensed under the BSD-style license found in the 8 # LICENSE file in the root directory of this source tree. 9 import os 10 import unittest 11 from argparse import ArgumentParser 12 13 from torch.distributed.argparse_util import check_env, env 14 15 16 class ArgParseUtilTest(unittest.TestCase): 17 def setUp(self): 18 # remove any lingering environment variables 19 for e in os.environ.keys(): 20 if e.startswith("PET_"): 21 del os.environ[e] 22 23 def test_env_string_arg_no_env(self): 24 parser = ArgumentParser() 25 parser.add_argument("-f", "--foo", action=env, default="bar") 26 27 self.assertEqual("bar", parser.parse_args([]).foo) 28 self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo) 29 self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo) 30 31 def test_env_string_arg_env(self): 32 os.environ["PET_FOO"] = "env_baz" 33 parser = ArgumentParser() 34 parser.add_argument("-f", "--foo", action=env, default="bar") 35 36 self.assertEqual("env_baz", parser.parse_args([]).foo) 37 self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo) 38 self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo) 39 40 def test_env_int_arg_no_env(self): 41 parser = ArgumentParser() 42 parser.add_argument("-f", "--foo", action=env, default=1, type=int) 43 44 self.assertEqual(1, parser.parse_args([]).foo) 45 self.assertEqual(2, parser.parse_args(["-f", "2"]).foo) 46 self.assertEqual(2, parser.parse_args(["--foo", "2"]).foo) 47 48 def test_env_int_arg_env(self): 49 os.environ["PET_FOO"] = "3" 50 parser = ArgumentParser() 51 parser.add_argument("-f", "--foo", action=env, default=1, type=int) 52 53 self.assertEqual(3, parser.parse_args([]).foo) 54 self.assertEqual(2, parser.parse_args(["-f", "2"]).foo) 55 self.assertEqual(2, parser.parse_args(["--foo", "2"]).foo) 56 57 def test_env_no_default_no_env(self): 58 parser = ArgumentParser() 59 parser.add_argument("-f", "--foo", action=env) 60 61 self.assertIsNone(parser.parse_args([]).foo) 62 self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo) 63 self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo) 64 65 def test_env_no_default_env(self): 66 os.environ["PET_FOO"] = "env_baz" 67 parser = ArgumentParser() 68 parser.add_argument("-f", "--foo", action=env) 69 70 self.assertEqual("env_baz", parser.parse_args([]).foo) 71 self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo) 72 self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo) 73 74 def test_env_required_no_env(self): 75 parser = ArgumentParser() 76 parser.add_argument("-f", "--foo", action=env, required=True) 77 78 self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo) 79 self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo) 80 81 def test_env_required_env(self): 82 os.environ["PET_FOO"] = "env_baz" 83 parser = ArgumentParser() 84 parser.add_argument("-f", "--foo", action=env, default="bar", required=True) 85 86 self.assertEqual("env_baz", parser.parse_args([]).foo) 87 self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo) 88 self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo) 89 90 def test_check_env_no_env(self): 91 parser = ArgumentParser() 92 parser.add_argument("-v", "--verbose", action=check_env) 93 94 self.assertFalse(parser.parse_args([]).verbose) 95 self.assertTrue(parser.parse_args(["-v"]).verbose) 96 self.assertTrue(parser.parse_args(["--verbose"]).verbose) 97 98 def test_check_env_default_no_env(self): 99 parser = ArgumentParser() 100 parser.add_argument("-v", "--verbose", action=check_env, default=True) 101 102 self.assertTrue(parser.parse_args([]).verbose) 103 self.assertTrue(parser.parse_args(["-v"]).verbose) 104 self.assertTrue(parser.parse_args(["--verbose"]).verbose) 105 106 def test_check_env_env_zero(self): 107 os.environ["PET_VERBOSE"] = "0" 108 parser = ArgumentParser() 109 parser.add_argument("-v", "--verbose", action=check_env) 110 111 self.assertFalse(parser.parse_args([]).verbose) 112 self.assertTrue(parser.parse_args(["--verbose"]).verbose) 113 114 def test_check_env_env_one(self): 115 os.environ["PET_VERBOSE"] = "1" 116 parser = ArgumentParser() 117 parser.add_argument("-v", "--verbose", action=check_env) 118 119 self.assertTrue(parser.parse_args([]).verbose) 120 self.assertTrue(parser.parse_args(["--verbose"]).verbose) 121 122 def test_check_env_default_env_zero(self): 123 os.environ["PET_VERBOSE"] = "0" 124 parser = ArgumentParser() 125 parser.add_argument("-v", "--verbose", action=check_env, default=True) 126 127 self.assertFalse(parser.parse_args([]).verbose) 128 self.assertTrue(parser.parse_args(["--verbose"]).verbose) 129 130 def test_check_env_default_env_one(self): 131 os.environ["PET_VERBOSE"] = "1" 132 parser = ArgumentParser() 133 parser.add_argument("-v", "--verbose", action=check_env, default=True) 134 135 self.assertTrue(parser.parse_args([]).verbose) 136 self.assertTrue(parser.parse_args(["--verbose"]).verbose) 137