xref: /aosp_15_r20/external/pytorch/test/distributed/launcher/bin/test_script.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import argparse
11import os
12from pathlib import Path
13
14
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options understood by this test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the arguments from ``sys.argv[1:]``,
            which is what a no-argument call has always done.

    Returns:
        A namespace with two attributes:
        ``fail`` (bool, ``False`` unless ``--fail`` is given) and
        ``touch_file_dir`` (str, or ``None`` when the option is omitted).
    """
    parser = argparse.ArgumentParser(description="test script")

    # ``store_true`` already defaults to False, so no explicit default is needed.
    parser.add_argument(
        "--fail",
        action="store_true",
        help="forces the script to throw a RuntimeError",
    )

    # file is used for assertions; both spellings map to ``touch_file_dir``
    # because argparse derives the dest from the first long option.
    parser.add_argument(
        "--touch-file-dir",
        "--touch_file_dir",
        type=str,
        help="dir to touch a file with global rank as the filename",
    )
    return parser.parse_args(argv)
33
34
def main():
    """Print the expected distributed env vars, then fail or touch a file.

    The script first prints every variable listed in ``env_vars`` as
    ``NAME = value``. Afterwards it either raises (when ``--fail`` is set) or
    creates an empty file named after the ``RANK`` environment variable inside
    the directory given by ``--touch-file-dir``.

    Raises:
        KeyError: if any of the listed environment variables is not set.
        RuntimeError: if the ``--fail`` flag was passed.
        ValueError: if ``--fail`` was not passed and no touch-file directory
            was supplied.
    """
    args = parse_args()
    env_vars = [
        "LOCAL_RANK",
        "RANK",
        "GROUP_RANK",
        "ROLE_RANK",
        "ROLE_NAME",
        "LOCAL_WORLD_SIZE",
        "WORLD_SIZE",
        "ROLE_WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
        "TORCHELASTIC_RESTART_COUNT",
        "TORCHELASTIC_MAX_RESTARTS",
        "TORCHELASTIC_RUN_ID",
        "OMP_NUM_THREADS",
        "TEST_SENTINEL_PARENT",
        "TORCHELASTIC_ERROR_FILE",
    ]

    print("Distributed env vars set by agent:")
    for env_var in env_vars:
        # Direct indexing: a missing variable surfaces as a KeyError.
        value = os.environ[env_var]
        print(f"{env_var} = {value}")

    if args.fail:
        raise RuntimeError("raising exception since --fail flag was set")

    # Without a directory os.path.join(None, ...) would raise an opaque
    # TypeError; report the actual problem instead.
    if args.touch_file_dir is None:
        raise ValueError("--touch-file-dir is required when --fail is not set")

    file = os.path.join(args.touch_file_dir, os.environ["RANK"])
    Path(file).touch()
    print(f"Success, created {file}")
67
68
# Run the script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
71