xref: /aosp_15_r20/external/pytorch/torch/onnx/_flags.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Internal feature flags for torch.onnx.
2
NOTE: These flags are experimental. Any flag here may be removed at any
time without notice.
5"""
6
7import logging
8import os
9
10
# Module-level logger; _load_boolean_flag reports enabled and deprecated
# experimental flags through it.
logger = logging.getLogger(name=__name__)
12
13
14def _load_boolean_flag(
15    name: str,
16    *,
17    this_will: str,
18    deprecated: bool = False,
19    default: bool = False,
20) -> bool:
21    """Load a boolean flag from environment variable.
22
23    Args:
24        name: The name of the environment variable.
25        this_will: A string that describes what this flag will do.
26        deprecated: Whether this flag is deprecated.
27        default: The default value if envvar not defined.
28    """
29    undefined = os.getenv(name) is None
30    state = os.getenv(name) == "1"
31    if state:
32        if deprecated:
33            logger.error(
34                "Experimental flag %s is deprecated. Please remove it from your environment.",
35                name,
36            )
37        else:
38            logger.warning(
39                "Experimental flag %s is enabled. This will %s.", name, this_will
40            )
41    if undefined:
42        state = default
43    return state
44
45
# Opt-in switch for the ExportedProgram-based torch.onnx export path.
# It is read once at import time and is off unless
# TORCH_ONNX_USE_EXPERIMENTAL_LOGIC is set to exactly "1".
USE_EXPERIMENTAL_LOGIC: bool = _load_boolean_flag(
    "TORCH_ONNX_USE_EXPERIMENTAL_LOGIC",
    default=False,
    deprecated=False,
    this_will="use ExportedProgram and the new torch.onnx export logic",
)
50