xref: /aosp_15_r20/external/pytorch/test/distributed/argparse_util_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #!/usr/bin/env python3
2 # Owner(s): ["oncall: distributed"]
3 
4 # Copyright (c) Facebook, Inc. and its affiliates.
5 # All rights reserved.
6 #
7 # This source code is licensed under the BSD-style license found in the
8 # LICENSE file in the root directory of this source tree.
9 import os
10 import unittest
11 from argparse import ArgumentParser
12 
13 from torch.distributed.argparse_util import check_env, env
14 
15 
16 class ArgParseUtilTest(unittest.TestCase):
17     def setUp(self):
18         # remove any lingering environment variables
19         for e in os.environ.keys():
20             if e.startswith("PET_"):
21                 del os.environ[e]
22 
23     def test_env_string_arg_no_env(self):
24         parser = ArgumentParser()
25         parser.add_argument("-f", "--foo", action=env, default="bar")
26 
27         self.assertEqual("bar", parser.parse_args([]).foo)
28         self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo)
29         self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo)
30 
31     def test_env_string_arg_env(self):
32         os.environ["PET_FOO"] = "env_baz"
33         parser = ArgumentParser()
34         parser.add_argument("-f", "--foo", action=env, default="bar")
35 
36         self.assertEqual("env_baz", parser.parse_args([]).foo)
37         self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo)
38         self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo)
39 
40     def test_env_int_arg_no_env(self):
41         parser = ArgumentParser()
42         parser.add_argument("-f", "--foo", action=env, default=1, type=int)
43 
44         self.assertEqual(1, parser.parse_args([]).foo)
45         self.assertEqual(2, parser.parse_args(["-f", "2"]).foo)
46         self.assertEqual(2, parser.parse_args(["--foo", "2"]).foo)
47 
48     def test_env_int_arg_env(self):
49         os.environ["PET_FOO"] = "3"
50         parser = ArgumentParser()
51         parser.add_argument("-f", "--foo", action=env, default=1, type=int)
52 
53         self.assertEqual(3, parser.parse_args([]).foo)
54         self.assertEqual(2, parser.parse_args(["-f", "2"]).foo)
55         self.assertEqual(2, parser.parse_args(["--foo", "2"]).foo)
56 
57     def test_env_no_default_no_env(self):
58         parser = ArgumentParser()
59         parser.add_argument("-f", "--foo", action=env)
60 
61         self.assertIsNone(parser.parse_args([]).foo)
62         self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo)
63         self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo)
64 
65     def test_env_no_default_env(self):
66         os.environ["PET_FOO"] = "env_baz"
67         parser = ArgumentParser()
68         parser.add_argument("-f", "--foo", action=env)
69 
70         self.assertEqual("env_baz", parser.parse_args([]).foo)
71         self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo)
72         self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo)
73 
74     def test_env_required_no_env(self):
75         parser = ArgumentParser()
76         parser.add_argument("-f", "--foo", action=env, required=True)
77 
78         self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo)
79         self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo)
80 
81     def test_env_required_env(self):
82         os.environ["PET_FOO"] = "env_baz"
83         parser = ArgumentParser()
84         parser.add_argument("-f", "--foo", action=env, default="bar", required=True)
85 
86         self.assertEqual("env_baz", parser.parse_args([]).foo)
87         self.assertEqual("baz", parser.parse_args(["-f", "baz"]).foo)
88         self.assertEqual("baz", parser.parse_args(["--foo", "baz"]).foo)
89 
90     def test_check_env_no_env(self):
91         parser = ArgumentParser()
92         parser.add_argument("-v", "--verbose", action=check_env)
93 
94         self.assertFalse(parser.parse_args([]).verbose)
95         self.assertTrue(parser.parse_args(["-v"]).verbose)
96         self.assertTrue(parser.parse_args(["--verbose"]).verbose)
97 
98     def test_check_env_default_no_env(self):
99         parser = ArgumentParser()
100         parser.add_argument("-v", "--verbose", action=check_env, default=True)
101 
102         self.assertTrue(parser.parse_args([]).verbose)
103         self.assertTrue(parser.parse_args(["-v"]).verbose)
104         self.assertTrue(parser.parse_args(["--verbose"]).verbose)
105 
106     def test_check_env_env_zero(self):
107         os.environ["PET_VERBOSE"] = "0"
108         parser = ArgumentParser()
109         parser.add_argument("-v", "--verbose", action=check_env)
110 
111         self.assertFalse(parser.parse_args([]).verbose)
112         self.assertTrue(parser.parse_args(["--verbose"]).verbose)
113 
114     def test_check_env_env_one(self):
115         os.environ["PET_VERBOSE"] = "1"
116         parser = ArgumentParser()
117         parser.add_argument("-v", "--verbose", action=check_env)
118 
119         self.assertTrue(parser.parse_args([]).verbose)
120         self.assertTrue(parser.parse_args(["--verbose"]).verbose)
121 
122     def test_check_env_default_env_zero(self):
123         os.environ["PET_VERBOSE"] = "0"
124         parser = ArgumentParser()
125         parser.add_argument("-v", "--verbose", action=check_env, default=True)
126 
127         self.assertFalse(parser.parse_args([]).verbose)
128         self.assertTrue(parser.parse_args(["--verbose"]).verbose)
129 
130     def test_check_env_default_env_one(self):
131         os.environ["PET_VERBOSE"] = "1"
132         parser = ArgumentParser()
133         parser.add_argument("-v", "--verbose", action=check_env, default=True)
134 
135         self.assertTrue(parser.parse_args([]).verbose)
136         self.assertTrue(parser.parse_args(["--verbose"]).verbose)
137