xref: /aosp_15_r20/external/emboss/compiler/front_end/write_inference.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1*99e0aae7SDavid Rees# Copyright 2019 Google LLC
2*99e0aae7SDavid Rees#
3*99e0aae7SDavid Rees# Licensed under the Apache License, Version 2.0 (the "License");
4*99e0aae7SDavid Rees# you may not use this file except in compliance with the License.
5*99e0aae7SDavid Rees# You may obtain a copy of the License at
6*99e0aae7SDavid Rees#
7*99e0aae7SDavid Rees#     https://www.apache.org/licenses/LICENSE-2.0
8*99e0aae7SDavid Rees#
9*99e0aae7SDavid Rees# Unless required by applicable law or agreed to in writing, software
10*99e0aae7SDavid Rees# distributed under the License is distributed on an "AS IS" BASIS,
11*99e0aae7SDavid Rees# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*99e0aae7SDavid Rees# See the License for the specific language governing permissions and
13*99e0aae7SDavid Rees# limitations under the License.
14*99e0aae7SDavid Rees
"""Infers and records a write_method for each field in the IR."""
16*99e0aae7SDavid Rees
17*99e0aae7SDavid Reesfrom compiler.front_end import attributes
18*99e0aae7SDavid Reesfrom compiler.front_end import expression_bounds
19*99e0aae7SDavid Reesfrom compiler.util import ir_data
20*99e0aae7SDavid Reesfrom compiler.util import ir_data_utils
21*99e0aae7SDavid Reesfrom compiler.util import ir_util
22*99e0aae7SDavid Reesfrom compiler.util import traverse_ir
23*99e0aae7SDavid Rees
24*99e0aae7SDavid Rees
def _find_field_reference_path(expression):
  """Finds the index path to the sole field_reference in `expression`.

  When `expression` contains exactly one field_reference, the result is a list
  of argument indexes: starting at `expression` and repeatedly following
  `.function.args[i]` for each `i` in order arrives at that field_reference.
  For instance, given

      5 + (x * 2)

  the result is [1, 0]: argument 1 of the top-level `+` is `x * 2`, and
  argument 0 of that `*` is `x`.  If `expression` is itself the
  field_reference, the result is the empty list.

  Arguments:
    expression: an ir_data.Expression to search.

  Returns:
    The list of argument indexes leading to the field_reference, or None when
    `expression` contains no field_reference or more than one.
  """
  reference_count, index_path = _recursively_find_field_reference_path(
      expression)
  return index_path if reference_count == 1 else None
50*99e0aae7SDavid Rees
51*99e0aae7SDavid Rees
def _recursively_find_field_reference_path(expression):
  """Counts field_references in `expression` and records the path to one.

  This is the recursive worker behind _find_field_reference_path.

  Arguments:
    expression: an ir_data.Expression to walk.

  Returns:
    A (count, path) tuple.  `count` is the number of field_reference nodes
    found by descending through nested function arguments.  `path` is the list
    of argument indexes leading to the field_reference when `count` is exactly
    1, and the empty list otherwise.
  """
  kind = expression.WhichOneof("expression")
  if kind == "field_reference":
    return 1, []
  if kind != "function":
    # Any other kind of expression is a leaf that holds no field_reference.
    return 0, []

  total = 0
  path = []
  for position, argument in enumerate(expression.function.args):
    count, sub_path = _recursively_find_field_reference_path(argument)
    # Remember the path only from the first argument holding a single
    # reference; if any other reference shows up, `total` exceeds 1 and the
    # path is discarded below.
    if count == 1 and total == 0:
      path = [position] + sub_path
    total += count
  return (total, path) if total == 1 else (total, [])
72*99e0aae7SDavid Rees
73*99e0aae7SDavid Rees
def _invert_expression(expression, ir):
  """Tries to solve `$logical_value = expression` for its single field.

  If `expression` contains exactly one `field_reference`, this attempts to
  isolate that field algebraically.  For `x + 1` the equation is rewritten as:

      $logical_value = x + 1
      $logical_value - 1 = x + 1 - 1
      $logical_value - 1 = x

  yielding `x` together with `$logical_value - 1`.

  The point is to obtain an assignment that writes back through certain
  virtual fields.  Given:

      struct Foo:
        0 [+1]  UInt  raw_value
        let actual_value = raw_value + 100

  writing to `actual_value` should set `raw_value` to the matching value.

  Arguments:
    expression: an ir_data.Expression to invert.
    ir: the full IR, used when computing constraints of the inverse.

  Returns:
    A (subexpression, inverse_expression) tuple when `expression` can be
    inverted.  `subexpression` is the ir_data.Expression that holds the
    field_reference.  `inverse_expression` computes that field's value from
    `$logical_value`.  Returns None when `expression` cannot be inverted.
  """
  reference_path = _find_field_reference_path(expression)
  if reference_path is None:
    return None

  def synthetic_location():
    # Each node gets its own freshly built synthetic Location.
    return ir_data.Location(is_synthetic=True)

  def integer_function(operation, left, right):
    # Builds `left <operation> right` as an integer-typed expression.
    return ir_data.Expression(
        function=ir_data.Function(function=operation, args=[left, right]),
        type=ir_data.ExpressionType(integer=ir_data.IntegerType()))

  logical_value_name = "$logical_value"
  # The right-hand side starts as a bare reference to $logical_value, typed
  # like the original expression.
  inverse = ir_data.Expression(
      builtin_reference=ir_data.Reference(
          canonical_name=ir_data.CanonicalName(
              module_file="", object_path=[logical_value_name]),
          source_name=[ir_data.Word(text=logical_value_name,
                                    source_location=synthetic_location())],
          source_location=synthetic_location()),
      type=expression.type,
      source_location=synthetic_location())

  # Peel one operator off the field's side per step while applying its inverse
  # to the other side, turning
  #
  #     f(g(x)) == $logical_value
  #
  # into
  #
  #     x == g_inv(f_inv($logical_value))
  #
  # For example, `2 + ((3 - x) - 10) == $logical_value` proceeds as:
  #
  #     (3 - x) - 10  ==  $logical_value - 2    [subtract 2 from both sides]
  #     (3 - x)  ==  ($logical_value - 2) + 10  [add 10 to both sides]
  #     x  ==  3 - (($logical_value - 2) + 10)  [subtract both sides from 3]
  #
  # This is an intentionally tiny algebraic solver.  Every equation it can
  # solve becomes part of Emboss's contract forever, so be conservative about
  # extending it!
  add = ir_data.FunctionMapping.ADDITION
  subtract = ir_data.FunctionMapping.SUBTRACTION
  current = expression
  for arg_index in reference_path:
    op_kind = current.function.function
    operands = current.function.args
    if op_kind == add:
      # `x + k` or `k + x` equals v, so x == v - k.
      inverse = integer_function(subtract, inverse, operands[1 - arg_index])
    elif op_kind == subtract and arg_index == 0:
      # `x - k` equals v, so x == v + k.
      inverse = integer_function(add, inverse, operands[1])
    elif op_kind == subtract:
      # `k - x` equals v, so x == k - v.
      inverse = integer_function(subtract, operands[0], inverse)
    else:
      # Any other operator is deliberately left unsolved.
      return None
    current = operands[arg_index]

  expression_bounds.compute_constraints_of_expression(inverse, ir)
  return current, inverse
194*99e0aae7SDavid Rees
195*99e0aae7SDavid Rees
def _reference_is_writeable(field_reference, ir):
  """Reports whether `field_reference` resolves to a non-read-only field.

  The last element of `field_reference.path` is resolved with
  ir_util.find_object.  A resolved object that is not an ir_data.Field (for
  example, a parameter) is never writeable.  For a resolved field,
  _add_write_method is invoked first so that its write_method is populated
  before being inspected.

  Arguments:
    field_reference: a field reference whose `path` names the target.
    ir: the IR in which to resolve the target.

  Returns:
    True when the target is a field whose write_method is not read_only.
  """
  target = ir_util.find_object(field_reference.path[-1], ir)
  if not isinstance(target, ir_data.Field):
    return False
  _add_write_method(target, ir)
  return not target.write_method.read_only


def _add_write_method(field, ir):
  """Populates `field.write_method` if it has not already been set.

  The chosen write_method depends on the field:

  * physical fields get "physical";
  * virtual fields of the form `let v = some_field`, with no `[requires]`
    attribute, get "alias" when `some_field` is itself writeable;
  * other virtual fields get "transform" when their expression can be
    inverted onto a single writeable field;
  * everything else gets "read_only".

  Referenced fields are processed recursively so their own write_method is
  known before it is consulted.

  Arguments:
    field: an ir_data.Field to which to add a write_method.
    ir: the IR in which to look up field_references.

  Returns:
    None
  """
  if field.HasField("write_method"):
    # Already computed, e.g. by a recursive call made for another field.
    return

  if not ir_util.field_is_virtual(field):
    # Non-virtual fields are written physically.
    ir_data_utils.builder(field).write_method.physical = True
    return

  field_checker = ir_data_utils.reader(field)
  field_builder = ir_data_utils.builder(field)

  # A `[requires]` attribute adds a constraint, so such a field may not be a
  # plain alias even when its expression is a bare field reference.
  requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
  is_bare_reference = (
      field_checker.read_transform.WhichOneof("expression") ==
      "field_reference")

  if is_bare_reference and requires_attr is None:
    # `let v = some_field`: alias writes straight through, provided the
    # aliased field is writeable.
    alias_target = field.read_transform.field_reference
    if _reference_is_writeable(alias_target, ir):
      field_builder.write_method.alias.CopyFrom(alias_target)
    else:
      field_builder.write_method.read_only = True
    return

  inverse = _invert_expression(field.read_transform, ir)
  if inverse is not None:
    destination_expression, function_body = inverse
    destination = destination_expression.field_reference
    if _reference_is_writeable(destination, ir):
      field_builder.write_method.transform.destination.CopyFrom(destination)
      field_builder.write_method.transform.function_body.CopyFrom(
          function_body)
      return

  # The expression cannot be inverted, or it inverts onto a read-only target.
  field_builder.write_method.read_only = True
272*99e0aae7SDavid Rees
273*99e0aae7SDavid Rees
def set_write_methods(ir):
  """Fills in the write_method member of every ir_data.Field in `ir`.

  Every ir_data.Field is visited with fast_traverse_ir_top_down and handed to
  _add_write_method.

  Arguments:
      ir: The IR to which to add write_methods.

  Returns:
      An empty list; this pass currently reports no errors.
  """
  visited_types = [ir_data.Field]
  traverse_ir.fast_traverse_ir_top_down(ir, visited_types, _add_write_method)
  errors = []
  return errors
285