1# SPDX-License-Identifier: MIT
2# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
3# Licensed to PSF under a Contributor Agreement.
4
5"""Utilities for tests that are in the "burntsushi" format."""
6
7import datetime
8from typing import Any
9
# Type names that the TOML compliance JSON encoding [1] spells differently
# from the BurntSushi toml-test encoding [2], mapped onto the latter.
# Names absent from this table are spelled identically in both formats.
# [1] https://github.com/toml-lang/compliance/blob/db7c3211fda30ff9ddb10292f4aeda7e2e10abc4/docs/json-encoding.md  # noqa: E501
# [2] https://github.com/BurntSushi/toml-test/blob/4634fdf3a6ecd6aaea5f4cdcd98b2733c2694993/README.md  # noqa: E501
_aliases = dict(
    [
        ("boolean", "bool"),
        ("offset datetime", "datetime"),
        ("local datetime", "datetime-local"),
        ("local date", "date-local"),
        ("local time", "time-local"),
    ]
)
20
21
def convert(obj: Any) -> Any:  # noqa: C901
    """Convert a value produced by a TOML parser into BurntSushi JSON form.

    Scalars become ``{"type": ..., "value": ...}`` dicts whose ``value`` is a
    string passed through the same ``_normalize_*`` helpers that
    ``normalize`` applies to fixture data, so the two can be compared with
    ``==``. Lists and dicts are converted recursively, element by element.

    Raises:
        TypeError: if ``obj``, or anything nested inside it, is not one of
            the types handled below.
    """
    # ``bool`` must be tested before ``int`` because bool subclasses int.
    if isinstance(obj, str):
        return {"type": "string", "value": obj}
    elif isinstance(obj, bool):
        return {"type": "bool", "value": str(obj).lower()}
    elif isinstance(obj, int):
        return {"type": "integer", "value": str(obj)}
    elif isinstance(obj, float):
        return {"type": "float", "value": _normalize_float_str(str(obj))}
    # ``datetime`` must be tested before ``date`` because datetime
    # subclasses date.
    elif isinstance(obj, datetime.datetime):
        val = _normalize_datetime_str(obj.isoformat())
        # A datetime is aware only when utcoffset() is not None. This is
        # also exactly when isoformat() appends an offset, so the type tag
        # and the rendered value always agree. Checking ``obj.tzinfo``
        # truthiness would mislabel a tzinfo whose utcoffset() returns None.
        if obj.utcoffset() is not None:
            return {"type": "datetime", "value": val}
        return {"type": "datetime-local", "value": val}
    elif isinstance(obj, datetime.time):
        return {
            "type": "time-local",
            "value": _normalize_localtime_str(str(obj)),
        }
    elif isinstance(obj, datetime.date):
        return {
            "type": "date-local",
            "value": str(obj),
        }
    elif isinstance(obj, list):
        return [convert(i) for i in obj]
    elif isinstance(obj, dict):
        return {k: convert(v) for k, v in obj.items()}
    # TypeError subclasses Exception, so existing ``except Exception``
    # handlers still catch this.
    raise TypeError(f"unsupported type: {type(obj).__name__}")
51
52
def normalize(obj: Any) -> Any:
    """Return a normalized copy of a decoded test fixture.

    Typed leaves (dicts carrying both a ``"type"`` and a ``"value"`` key)
    have their type name translated from the TOML compliance spelling [1]
    to the BurntSushi spelling [2]. Float, datetime and local-time values
    are rewritten into a canonical string form. Typed ``"array"`` leaves
    are unwrapped into plain lists of normalized items. Any other dict is
    treated as a table and normalized value by value.

    [1] https://github.com/toml-lang/compliance/blob/db7c3211fda30ff9ddb10292f4aeda7e2e10abc4/docs/json-encoding.md  # noqa: E501
    [2] https://github.com/BurntSushi/toml-test/blob/4634fdf3a6ecd6aaea5f4cdcd98b2733c2694993/README.md  # noqa: E501

    Raises:
        AssertionError: if a value that is neither a dict nor a list is
            encountered anywhere in ``obj``.
    """
    if isinstance(obj, list):
        return [normalize(element) for element in obj]
    if not isinstance(obj, dict):
        raise AssertionError("Burntsushi fixtures should be dicts/lists only")
    if "type" not in obj or "value" not in obj:
        # A plain table: recurse into each of its values.
        return {key: normalize(val) for key, val in obj.items()}

    raw_type = obj["type"]
    kind = _aliases.get(raw_type, raw_type)
    payload = obj["value"]
    if kind == "array":
        # Typed arrays are unwrapped into plain lists.
        return [normalize(element) for element in payload]
    if kind == "float":
        payload = _normalize_float_str(payload)
    elif kind in ("datetime", "datetime-local"):
        payload = _normalize_datetime_str(payload)
    elif kind == "time-local":
        payload = _normalize_localtime_str(payload)
    return {"type": kind, "value": payload}
83
84
def _normalize_datetime_str(dt_str: str) -> str:
    """Canonicalize an offset or local datetime string for comparison.

    - A trailing ``Z``/``z`` is rewritten as the equivalent ``+00:00`` offset.
    - The character between date and time (``T``, ``t`` or a space) is
      replaced with ``T``.
    - Trailing zeros of the fractional seconds are dropped, and so is the
      decimal point when no digits remain after it. For example,
      ``07:32:00.500`` becomes ``07:32:00.5`` and ``07:32:00.000`` becomes
      ``07:32:00``. The latter matches ``datetime.isoformat()``, which omits
      the fraction entirely when microseconds are 0.
    """
    if dt_str[-1].lower() == "z":
        dt_str = dt_str[:-1] + "+00:00"

    # Fixed-width "YYYY-MM-DD" prefix; index 10 is the date/time separator.
    date = dt_str[:10]
    rest = dt_str[11:]

    # The date is excluded from ``rest``, so any "+" or "-" found here
    # starts the UTC offset.
    if "+" in rest:
        sign = "+"
    elif "-" in rest:
        sign = "-"
    else:
        sign = ""

    if sign:
        time, _, offset = rest.partition(sign)
    else:
        time = rest
        offset = ""

    if "." in time:
        # Strip the now-empty "." too, otherwise "00.000" would become "00."
        # and never match the fraction-less isoformat() rendering.
        time = time.rstrip("0").rstrip(".")
    return date + "T" + time + sign + offset
107
108
def _normalize_localtime_str(lt_str: str) -> str:
    """Canonicalize a local-time string such as ``07:32:00.999900``.

    Trailing zeros of the fractional seconds are dropped, and so is the
    decimal point when no digits remain after it. For example,
    ``07:32:00.000`` becomes ``07:32:00``, the form ``str(datetime.time)``
    produces when the microsecond field is 0. Strings without a fractional
    part are returned unchanged.
    """
    if "." not in lt_str:
        return lt_str
    return lt_str.rstrip("0").rstrip(".")
111
112
def _normalize_float_str(float_str: str) -> str:
    """Return the canonical ``str`` spelling of the float in *float_str*.

    Every zero collapses to ``"0"``, whatever its sign or spelling (e.g.
    ``"0.0"``, ``"-0.0"``, ``"+0e0"``). Any other value, including
    infinities and NaN, is parsed with ``float()`` and re-rendered with
    ``str()``.
    """
    value = float(float_str)
    return str(value) if value != 0 else "0"
121