xref: /aosp_15_r20/external/emboss/compiler/util/ir_data_utils.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1*99e0aae7SDavid Rees# Copyright 2024 Google LLC
2*99e0aae7SDavid Rees#
3*99e0aae7SDavid Rees# Licensed under the Apache License, Version 2.0 (the "License");
4*99e0aae7SDavid Rees# you may not use this file except in compliance with the License.
5*99e0aae7SDavid Rees# You may obtain a copy of the License at
6*99e0aae7SDavid Rees#
7*99e0aae7SDavid Rees#     https://www.apache.org/licenses/LICENSE-2.0
8*99e0aae7SDavid Rees#
9*99e0aae7SDavid Rees# Unless required by applicable law or agreed to in writing, software
10*99e0aae7SDavid Rees# distributed under the License is distributed on an "AS IS" BASIS,
11*99e0aae7SDavid Rees# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*99e0aae7SDavid Rees# See the License for the specific language governing permissions and
13*99e0aae7SDavid Rees# limitations under the License.
14*99e0aae7SDavid Rees
15*99e0aae7SDavid Rees"""Provides helpers for working with IR data elements.
16*99e0aae7SDavid Rees
17*99e0aae7SDavid ReesHistorical note: At one point protocol buffers were used for IR data. The
18*99e0aae7SDavid Reescodebase still expects the IR data classes to behave similarly, particularly
19*99e0aae7SDavid Reeswith respect to "autovivification" where accessing an undefined field will
20*99e0aae7SDavid Reescreate it temporarily and add it if assigned to. Though, perhaps not fully
21*99e0aae7SDavid Reesfollowing the Pythonic ethos, we provide this behavior via the `builder` and
22*99e0aae7SDavid Rees`reader` helpers to remain compatible with the rest of the codebase.
23*99e0aae7SDavid Rees
24*99e0aae7SDavid Reesbuilder
25*99e0aae7SDavid Rees-------
26*99e0aae7SDavid ReesInstead of:
27*99e0aae7SDavid Rees```
28*99e0aae7SDavid Reesdef set_function_name_end(function: Function):
29*99e0aae7SDavid Rees  if not function.function_name:
30*99e0aae7SDavid Rees    function.function_name = Word()
31*99e0aae7SDavid Rees  if not function.function_name.source_location:
32*99e0aae7SDavid Rees    function.function_name.source_location = Location()
33*99e0aae7SDavid Rees  function.function_name.source_location.end = Position(line=1, column=2)
34*99e0aae7SDavid Rees```
35*99e0aae7SDavid Rees
36*99e0aae7SDavid ReesWe can do:
37*99e0aae7SDavid Rees```
38*99e0aae7SDavid Reesdef set_function_name_end(function: Function):
39*99e0aae7SDavid Rees  builder(function).function_name.source_location.end = Position(line=1,
40*99e0aae7SDavid Rees  column=2)
41*99e0aae7SDavid Rees```
42*99e0aae7SDavid Rees
43*99e0aae7SDavid Reesreader
44*99e0aae7SDavid Rees------
45*99e0aae7SDavid ReesInstead of:
46*99e0aae7SDavid Rees```
47*99e0aae7SDavid Reesdef is_leaf_synthetic(data):
48*99e0aae7SDavid Rees  if data:
49*99e0aae7SDavid Rees    if data.attribute:
50*99e0aae7SDavid Rees      if data.attribute.value:
51*99e0aae7SDavid Rees        if data.attribute.value.is_synthetic is not None:
52*99e0aae7SDavid Rees          return data.attribute.value.is_synthetic
53*99e0aae7SDavid Rees  return False
54*99e0aae7SDavid Rees```
55*99e0aae7SDavid ReesWe can do:
56*99e0aae7SDavid Rees```
57*99e0aae7SDavid Reesdef is_leaf_synthetic(data):
58*99e0aae7SDavid Rees  return reader(data).attribute.value.is_synthetic
59*99e0aae7SDavid Rees```
60*99e0aae7SDavid Rees
61*99e0aae7SDavid ReesIrDataSerializer
62*99e0aae7SDavid Rees----------------
63*99e0aae7SDavid ReesProvides methods for serializing and deserializing an IR data object.
64*99e0aae7SDavid Rees"""
65*99e0aae7SDavid Reesimport enum
66*99e0aae7SDavid Reesimport json
67*99e0aae7SDavid Reesfrom typing import (
68*99e0aae7SDavid Rees    Any,
69*99e0aae7SDavid Rees    Callable,
70*99e0aae7SDavid Rees    Generic,
71*99e0aae7SDavid Rees    MutableMapping,
72*99e0aae7SDavid Rees    MutableSequence,
73*99e0aae7SDavid Rees    Optional,
74*99e0aae7SDavid Rees    Tuple,
75*99e0aae7SDavid Rees    TypeVar,
76*99e0aae7SDavid Rees    Union,
77*99e0aae7SDavid Rees    cast,
78*99e0aae7SDavid Rees)
79*99e0aae7SDavid Rees
80*99e0aae7SDavid Reesfrom compiler.util import ir_data
81*99e0aae7SDavid Reesfrom compiler.util import ir_data_fields
82*99e0aae7SDavid Rees
83*99e0aae7SDavid Rees
# Type variable bound to `ir_data.Message`; used to type the helpers in this
# module generically over IR message classes.
MessageT = TypeVar("MessageT", bound=ir_data.Message)
85*99e0aae7SDavid Rees
86*99e0aae7SDavid Rees
def field_specs(ir: Union[MessageT, type[MessageT]]):
  """Returns all field specs registered for an IR data class.

  Args:
    ir: Either an IR data class itself or an instance of one; for an instance
      the specs of its type are returned.
  """
  if isinstance(ir, type):
    data_type = ir
  else:
    data_type = type(ir)
  class_specs = ir_data_fields.IrDataclassSpecs.get_specs(data_type)
  return class_specs.all_field_specs
91*99e0aae7SDavid Rees
92*99e0aae7SDavid Rees
class IrDataSerializer:
  """Converts IR data objects to and from dicts / JSON strings."""

  def __init__(self, ir: MessageT):
    assert ir is not None
    self.ir = ir

  def _to_dict(
      self,
      ir: MessageT,
      field_func: Callable[
          [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]
      ],
  ) -> MutableMapping[str, Any]:
    """Recursively converts `ir` into a dict.

    Args:
      ir: The IR data object to convert; must not be None.
      field_func: Returns the (FieldSpec, value) pairs to emit for an object.
        It is applied again at every nesting level.

    Returns:
      A dict keyed by field name. Non-None message fields become nested
      dicts (or lists of nested dicts for repeated fields); all other values
      are stored unchanged.
    """
    assert ir is not None

    def convert(spec, value):
      # Only non-None message values need recursion.
      if value is None or not spec.is_dataclass:
        return value
      if spec.is_sequence:
        return [self._to_dict(item, field_func) for item in value]
      return self._to_dict(value, field_func)

    return {spec.name: convert(spec, value) for spec, value in field_func(ir)}

  def to_dict(self, exclude_none: bool = False):
    """Converts the wrapped IR data object to a (nested) dictionary.

    Args:
      exclude_none: If True, fields whose value is None or an empty list are
        omitted at every nesting level; otherwise every field is emitted.
    """

    def select_fields(ir):
      if not exclude_none:
        return fields_and_values(ir)
      return fields_and_values(
          ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
      )

    # `dataclasses.asdict` would also work, but it deep-copies every value,
    # which is more than needed: the result is mostly an intermediate for
    # `to_json` and `repr`.
    return self._to_dict(self.ir, select_fields)

  def to_json(self, *args, **kwargs):
    """Serializes the wrapped IR data object (minus empty fields) to JSON.

    Extra arguments are forwarded to `json.dumps`.
    """
    as_dict = self.to_dict(exclude_none=True)
    return json.dumps(as_dict, *args, **kwargs)

  @staticmethod
  def from_json(data_cls, data):
    """Parses the JSON string `data` into a new `data_cls` instance."""
    return IrDataSerializer.from_dict(data_cls, json.loads(data))

  def copy_from_dict(self, data):
    """Replaces every field of the wrapped IR object with values from `data`.

    Fields absent from `data` end up with whatever value a freshly
    constructed instance has for them.
    """
    data_cls = type(self.ir)
    fresh = IrDataSerializer.from_dict(data_cls, data)
    for field_name in field_specs(data_cls):
      setattr(self.ir, field_name, getattr(fresh, field_name))

  @staticmethod
  def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
    """Maps a string to the member of that name, anything else by value."""
    if not isinstance(val, str):
      return enum_cls(val)
    return getattr(enum_cls, val)

  @staticmethod
  def _enum_type_hook(enum_cls: type[enum.Enum]):
    """Returns a one-argument converter bound to `enum_cls`."""

    def hook(val):
      return IrDataSerializer._enum_type_converter(enum_cls, val)

    return hook

  @staticmethod
  def _from_dict(data_cls: type[MessageT], data):
    """Builds a `data_cls` instance from a dict of field values.

    Keys missing from `data` or mapped to None are left unset. Nested message
    fields are rebuilt recursively. `FunctionMapping` and `AddressableUnit`
    values may be a member name (string) or a member value. Repeated scalar
    fields are passed through, and other scalars are coerced by calling their
    declared type.
    """
    init_kwargs: MutableMapping[str, Any] = {}
    for name, spec in ir_data_fields.field_specs(data_cls).items():
      raw = data.get(name)
      if raw is None:
        continue
      field_type = spec.data_type
      if spec.is_dataclass:
        if spec.is_sequence:
          init_kwargs[name] = [
              IrDataSerializer._from_dict(field_type, item) for item in raw
          ]
        else:
          init_kwargs[name] = IrDataSerializer._from_dict(field_type, raw)
      elif field_type in (ir_data.FunctionMapping, ir_data.AddressableUnit):
        init_kwargs[name] = IrDataSerializer._enum_type_converter(
            field_type, raw
        )
      elif spec.is_sequence:
        init_kwargs[name] = raw
      else:
        init_kwargs[name] = field_type(raw)
    return data_cls(**init_kwargs)

  @staticmethod
  def from_dict(data_cls: type[MessageT], data):
    """Returns a new `data_cls` instance deserialized from the dict `data`."""
    return IrDataSerializer._from_dict(data_cls, data)
194*99e0aae7SDavid Rees
195*99e0aae7SDavid Rees
class _IrDataSequenceBuilder(MutableSequence[MessageT]):
  """Wrapper for a list of IR elements.

  Elements read through indexed access or iteration are wrapped in
  `_IrDataBuilder`s so chained field accesses autovivify missing fields.
  Elements written through this wrapper are unwrapped first, so the
  underlying list only ever holds raw IR objects.
  """

  def __init__(self, target: MutableSequence[MessageT]):
    self._target = target

  @staticmethod
  def _unwrap(value):
    """Returns the raw IR object if `value` is an `_IrDataBuilder`."""
    if isinstance(value, _IrDataBuilder):
      return value.ir
    return value

  def __delitem__(self, key):
    del self._target[key]

  def __getitem__(self, key):
    if isinstance(key, slice):
      # A slice yields a new list; wrap each element rather than the list
      # itself (a builder around a plain list has no field specs to use).
      return [_IrDataBuilder(item) for item in self._target[key]]
    return _IrDataBuilder(self._target[key])

  def __setitem__(self, key, value):
    # Unwrap so that e.g. the inherited `reverse()`, which assigns values read
    # via `__getitem__`, never stores builder proxies in the IR list.
    if isinstance(key, slice):
      self._target[key] = [self._unwrap(item) for item in value]
    else:
      self._target[key] = self._unwrap(value)

  def __iter__(self):
    for item in self._target:
      yield _IrDataBuilder(item)

  def __contains__(self, value):
    # Compare raw IR objects; the inherited implementation would compare the
    # wrapping builders instead.
    return self._unwrap(value) in self._target

  def __repr__(self):
    return repr(self._target)

  def __len__(self):
    return len(self._target)

  def __eq__(self, other):
    return self._target == other

  def __ne__(self, other):
    return self._target != other

  def index(self, value, start=0, stop=None):
    """Returns the first index of `value` (raw or wrapped) in the list."""
    raw = self._unwrap(value)
    if stop is None:
      return self._target.index(raw, start)
    return self._target.index(raw, start, stop)

  def count(self, value):
    """Returns how many times `value` (raw or wrapped) occurs in the list."""
    return self._target.count(self._unwrap(value))

  def insert(self, index, value):
    self._target.insert(index, self._unwrap(value))

  def extend(self, values):
    # Materialize first: `values` may be this very wrapper, and lazily
    # iterating the list while extending it would never terminate.
    self._target.extend([self._unwrap(item) for item in values])
237*99e0aae7SDavid Rees
238*99e0aae7SDavid Rees
class _IrDataBuilder(Generic[MessageT]):
  """Autovivifying proxy around a single IR element.

  Reading a field through the proxy stores a default value for it on the
  wrapped object when it is unset; assignments are forwarded to the wrapped
  object.
  """

  def __init__(self, ir: MessageT) -> None:
    assert ir is not None
    self.ir: MessageT = ir

  def __setattr__(self, __name: str, __value: Any) -> None:
    if __name != "ir":
      # Everything except the proxy's own `ir` slot goes to the wrapped object.
      target: MessageT = object.__getattribute__(self, "ir")
      setattr(target, __name, __value)
      return
    object.__setattr__(self, __name, __value)

  def __getattribute__(self, name: str) -> Any:
    """Returns the named field of the wrapped object, creating it if unset.

    Message fields come back wrapped (`_IrDataBuilder` for a single message,
    `_IrDataSequenceBuilder` for a repeated one) so that longer access chains
    keep autovivifying; other fields are returned as raw values.

    Raises:
      AttributeError: if `name` is not a field of the wrapped IR class.
    """
    # The proxy's own attributes bypass the field logic entirely.
    if name == "CopyFrom" or name == "ir":
      return object.__getattribute__(self, name)

    ir: MessageT = object.__getattribute__(self, "ir")
    if ir is None:
      return object.__getattribute__(self, name)

    # Protobuf-style helpers are forwarded to the wrapped object unchanged.
    if name == "HasField" or name == "WhichOneof":
      return getattr(ir, name)

    spec = field_specs(ir).get(name)
    if spec is None:
      raise AttributeError(
          f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
      )

    value = getattr(ir, name, None)
    if value is None:
      # Materialize a default value and attach it to the wrapped object.
      value = ir_data_fields.build_default(spec)
      setattr(ir, name, value)

    if not spec.is_dataclass:
      return value
    if spec.is_sequence:
      return _IrDataSequenceBuilder(value)
    return _IrDataBuilder(value)

  def CopyFrom(self, template: MessageT):  # pylint:disable=invalid-name
    """Merges every set field of `template` into the wrapped object."""
    # NOTE(review): the proxy itself (not the raw `ir`) is handed to
    # `update`; the cast only affects type checking. Confirm that
    # `ir_data_fields.update` is meant to operate through the proxy.
    update(cast(type[MessageT], self), template)
300*99e0aae7SDavid Rees
301*99e0aae7SDavid Rees
def builder(target: MessageT) -> MessageT:
  """Returns an autovivifying builder wrapped around `target`.

  Targets that are already builders are returned unchanged.

  Raises:
    TypeError: if `target` is neither a builder nor an instance of a class
      marked with `IR_DATACLASS`.
  """
  already_wrapped = isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder))
  if already_wrapped:
    return target

  is_ir_dataclass = hasattr(type(target), "IR_DATACLASS")
  if not is_ir_dataclass:
    raise TypeError(f"Builder target {type(target)} is not an ir_data.message")

  # The cast makes type checkers treat the proxy as the wrapped type.
  wrapper = _IrDataBuilder(target)
  return cast(MessageT, wrapper)
315*99e0aae7SDavid Rees
316*99e0aae7SDavid Rees
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
  """Returns the stand-in value a reader yields for an unset field.

  Repeated fields yield an empty list, message fields yield a placeholder
  `_ReadOnlyFieldChecker` wrapping the spec, and any other field yields
  `ir_data_fields.build_default(spec)`.
  """
  if not spec.is_sequence:
    if spec.is_dataclass:
      return _ReadOnlyFieldChecker(spec)
    return ir_data_fields.build_default(spec)
  return []
324*99e0aae7SDavid Rees
325*99e0aae7SDavid Rees
def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type:
  """Returns the type described by `ir_or_spec`.

  For a FieldSpec this is its `data_type`; for an IR object it is the
  object's own type.
  """
  is_spec = isinstance(ir_or_spec, ir_data_fields.FieldSpec)
  return ir_or_spec.data_type if is_spec else type(ir_or_spec)
330*99e0aae7SDavid Rees
331*99e0aae7SDavid Rees
class _ReadOnlyFieldChecker:
  """Read-only proxy used to chain reads through possibly-unset fields.

  Wraps either a real IR object or, for a message field that is unset, that
  field's `FieldSpec` as a placeholder. Assigning any attribute other than
  the proxy's own `ir_or_spec` slot raises AttributeError.
  """

  def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
    self.ir_or_spec = ir_or_spec

  def __setattr__(self, name: str, value: Any) -> None:
    if name != "ir_or_spec":
      raise AttributeError(f"Cannot set {name} on read-only wrapper")
    return object.__setattr__(self, name, value)

  def __getattribute__(self, name: str) -> Any:  # pylint:disable=too-many-return-statements
    wrapped = object.__getattribute__(self, "ir_or_spec")
    if name == "ir_or_spec":
      return wrapped

    is_placeholder = isinstance(wrapped, ir_data_fields.FieldSpec)
    spec = field_specs(_field_type(wrapped)).get(name)

    if not spec:
      # Not an IR field. Placeholders emulate the protobuf helpers; any
      # other name is looked up directly on the wrapped object.
      if is_placeholder and name == "HasField":
        return lambda x: False
      if is_placeholder and name == "WhichOneof":
        return lambda x: None
      return object.__getattribute__(wrapped, name)

    if is_placeholder:
      # The parent field is unset, so this child is unset as well.
      return _field_checker_from_spec(spec)

    value = getattr(wrapped, name)
    if value is None:
      return _field_checker_from_spec(spec)
    if not spec.is_dataclass:
      return value
    if spec.is_sequence:
      return [_ReadOnlyFieldChecker(item) for item in value]
    return _ReadOnlyFieldChecker(value)

  def __eq__(self, other):
    rhs = other.ir_or_spec if isinstance(other, _ReadOnlyFieldChecker) else other
    return self.ir_or_spec == rhs

  def __ne__(self, other):
    return not self == other
381*99e0aae7SDavid Rees
382*99e0aae7SDavid Rees
def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT:
  """Returns a read-only wrapper for safely reading chains of fields.

  The wrapper never modifies `obj`; reads through unset fields yield
  stand-in values instead of raising. An `obj` that is already a read-only
  wrapper is returned unchanged.

  For example, a `reader` lets you do:
  ```
  def get_function_name_end_column(function: ir_data.Function):
    return reader(function).function_name.source_location.end.column
  ```

  Instead of:
  ```
  def get_function_name_end_column(function: ir_data.Function):
    if function.function_name:
      if function.function_name.source_location:
        if function.function_name.source_location.end:
          return function.function_name.source_location.end.column
    return 0
  ```
  """
  if isinstance(obj, _ReadOnlyFieldChecker):
    wrapped = obj
  else:
    wrapped = _ReadOnlyFieldChecker(obj)
  # The cast lets type checkers treat the wrapper as the original type.
  return cast(MessageT, wrapped)
412*99e0aae7SDavid Rees
413*99e0aae7SDavid Rees
def _extract_ir(
    ir_or_wrapper: Union[MessageT, _ReadOnlyFieldChecker, _IrDataBuilder, None],
) -> Optional[ir_data_fields.IrDataclassInstance]:
  """Returns the raw IR object behind `ir_or_wrapper`.

  Read-only wrappers and builders are unwrapped, and any other value
  (including None) is returned as-is. A read-only placeholder for an unset
  field (one wrapping a FieldSpec) yields None.
  """
  if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
    inner = ir_or_wrapper.ir_or_spec
    if isinstance(inner, ir_data_fields.FieldSpec):
      return None
    return cast(ir_data_fields.IrDataclassInstance, inner)
  if isinstance(ir_or_wrapper, _IrDataBuilder):
    return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper.ir)
  return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
426*99e0aae7SDavid Rees
427*99e0aae7SDavid Rees
def fields_and_values(
    ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker],
    value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
  """Lists the (FieldSpec, value) pairs of an IR data object.

  Args:
    ir_wrapper: An IR data object, or a read-only wrapper / builder around
      one.
    value_filt: Optional predicate; passed through to
      `ir_data_fields.fields_and_values` to filter out values.

  Returns:
    The pairs from `ir_data_fields.fields_and_values`, or an empty list when
    `ir_wrapper` is a placeholder for an unset field.
  """
  ir = _extract_ir(ir_wrapper)
  if ir is None:
    return []
  return ir_data_fields.fields_and_values(ir, value_filt)
442*99e0aae7SDavid Rees
443*99e0aae7SDavid Rees
def get_set_fields(ir: MessageT):
  """Returns the (FieldSpec, value) pairs for the fields of `ir` that are set.

  A field counts as set whenever its value is not None.
  """

  def is_set(value):
    return value is not None

  return fields_and_values(ir, is_set)
450*99e0aae7SDavid Rees
451*99e0aae7SDavid Rees
def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
  """Returns a copy of the given IR data object (unwrapping proxies first).

  Returns None when there is nothing to copy: `ir_wrapper` is None or a
  read-only placeholder for an unset field.
  """
  ir = _extract_ir(ir_wrapper)
  if ir is None:
    return None
  return cast(MessageT, ir_data_fields.copy(ir))
458*99e0aae7SDavid Rees
459*99e0aae7SDavid Rees
def update(ir: MessageT, template: MessageT):
  """Copies every set field of `template` onto `ir`.

  Does nothing if `template` unwraps to a falsy value (for example None, or
  a read-only placeholder for an unset field).
  """
  template_ir = _extract_ir(template)
  if not template_ir:
    return
  target = cast(ir_data_fields.IrDataclassInstance, ir)
  ir_data_fields.update(target, template_ir)
468