xref: /aosp_15_r20/external/protobuf/python/google/protobuf/internal/containers.py (revision 1b3f573f81763fcece89efc2b6a5209149e44ab8)
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Contains container classes to represent different protocol buffer types.

This file defines container classes which represent categories of protocol
buffer field types which need extra maintenance. Currently these categories
are:

-   Repeated scalar fields - These are all repeated fields which aren't
    composite (e.g. they are of simple types like int32, string, etc).
-   Repeated composite fields - Repeated fields which are composite. This
    includes groups and nested messages.
-   Map fields - ScalarMap holds maps with scalar values and MessageMap holds
    maps with message values.

It also defines UnknownFieldSet (and its helper classes) for holding parsed
unknown fields.
"""
42
43import collections.abc
44import copy
45import pickle
46from typing import (
47    Any,
48    Iterable,
49    Iterator,
50    List,
51    MutableMapping,
52    MutableSequence,
53    NoReturn,
54    Optional,
55    Sequence,
56    TypeVar,
57    Union,
58    overload,
59)
60
61
# Generic type parameters used throughout this module: _T for the element type
# of the sequence containers (and the default-value type in the map get()
# overloads), _K and _V for the key and value types of the map containers.
_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')


class BaseContainer(Sequence[_T]):
  """Base container class.

  Stores the elements in a plain list (``self._values``) and implements the
  read-only Sequence protocol plus in-place sort() and reverse(). Neither
  sort() nor reverse() notifies the message listener. Concrete subclasses are
  expected to define __eq__.
  """

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """
    Args:
      message_listener: A MessageListener implementation. Subclasses call
        this object's Modified() method when the container is modified.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves item by the specified key (an index or a slice)."""
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  # Containers are mutable, so they must not be hashable.
  __hash__ = None

  def __repr__(self) -> str:
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the elements in place. Similar to list.sort().

    For backwards compatibility, a Python 2 style comparison function may be
    passed as ``sort_function`` (or ``cmp``). Python 3's list.sort() no longer
    accepts ``cmp``, so the comparator is converted into a ``key`` with
    functools.cmp_to_key. If ``key`` is also given, the comparator is applied
    to the key results, as in Python 2.
    """
    # Continue to support the old sort_function keyword argument.
    # This is expected to be a rare occurrence, so use LBYL to avoid
    # the overhead of actually catching KeyError.
    if 'sort_function' in kwargs:
      kwargs['cmp'] = kwargs.pop('sort_function')
    if 'cmp' in kwargs:
      compare = kwargs.pop('cmp')
      if compare is not None:
        # Rare legacy path: keep the import off the common code path.
        import functools  # pylint: disable=g-import-not-at-top
        cmp_key = functools.cmp_to_key(compare)
        key = kwargs.get('key')
        if key is None:
          kwargs['key'] = cmp_key
        else:
          kwargs['key'] = lambda item: cmp_key(key(item))
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    """Reverses the elements in place. Similar to list.reverse()."""
    self._values.reverse()
119
120
# TODO(slebedev): Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
# Registers BaseContainer as a virtual subclass of MutableSequence, so
# isinstance(x, collections.abc.MutableSequence) is True for every
# BaseContainer (and subclass) instance, even though BaseContainer itself
# only derives from Sequence.
collections.abc.MutableSequence.register(BaseContainer)
124
125
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars.

  Values added through append(), insert(), extend() and item assignment are
  passed through the type checker's CheckValue() before being stored;
  MergeFrom() copies values without checking them. Mutations notify the
  message listener through its Modified() method.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """Initializes an empty container.

    Args:
      message_listener: A MessageListener implementation. The
        RepeatedScalarFieldContainer will call this object's Modified() method
        when it is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    self._values.append(self._type_checker.CheckValue(value))
    # Only notify when the listener is not already marked dirty.
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend().

    ``None`` and falsy non-iterable arguments are silently ignored. Every
    element is checked before any is appended, so a rejected element leaves
    the container unchanged. Modified() is called even if ``elem_seq`` turns
    out to be empty.

    Raises:
      TypeError: If ``elem_seq`` is a truthy object that is not iterable.
    """
    if elem_seq is None:
      return
    try:
      elem_seq_iter = iter(elem_seq)
    except TypeError:
      if not elem_seq:
        # silently ignore falsy inputs :-/.
        # TODO(ptucker): Deprecate this behavior. b/18413862
        return
      raise

    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
    if new_values:
      self._values.extend(new_values)
    self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.
    """
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item at the given index, or replaces a slice.

    Raises:
      ValueError: If ``key`` is a slice with an explicit step.
    """
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      self._values[key] = map(self._type_checker.CheckValue, value)
      self._message_listener.Modified()
    else:
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    # The listener is deep-copied; the type checker is shared with the copy.
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener), self._type_checker)
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")
245
246
# TODO(slebedev): Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields.

  The container owns its elements. add(), append(), insert() and extend()
  always create a new message of the contained type and copy or merge the
  given value into it, so the caller's message objects are never stored
  directly. Item assignment is not supported.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated class
    directly, since at the time we construct a _RepeatedCompositeFieldContainer
    we haven't yet necessarily initialized the type that will be contained in
    the container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """
    new_element = self._message_descriptor._concrete_class(**kwargs)
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()
    return new_element

  def append(self, value: _T) -> None:
    """Appends one element by copying the message."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.insert(key, new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given sequence of elements of the same type
    as this one, copying each individual message.

    Each element is merged into a newly created message that reports to this
    container's listener. Modified() is called once at the end, even when
    ``elem_seq`` is empty.
    """
    message_class = self._message_descriptor._concrete_class
    listener = self._message_listener
    values = self._values
    for message in elem_seq:
      new_element = message_class()
      new_element._SetListener(listener)
      new_element.MergeFrom(message)
      values.append(new_element)
    listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # This method is implemented to make RepeatedCompositeFieldContainer
    # structurally compatible with typing.MutableSequence. It is
    # otherwise unsupported and will always raise an error.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one.

    Raises:
      TypeError: If ``other`` is not an instance of this container's class.
    """
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values
364
365
class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for holding scalar map values.

  Keys and values are validated by their respective checkers on assignment.
  Reading a missing key inserts and returns the value checker's default value
  (defaultdict-like) without notifying the message listener.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    try:
      return self._values[key]
    except KeyError:
      # Missing key: validate it, then store and return the default value.
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    """Copies all entries of ``other`` into this map, overwriting existing keys.

    Keys and values are not re-checked.
    """
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    # Rebinding to a copy keeps this map's contents; the dummy entry changes
    # the old dict's size so live iterators over it fail on their next step.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
466
467
class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container with submessage values.

  Keys are validated by the key checker. Values cannot be assigned directly.
  Reading a missing key creates a new submessage, stores it and notifies the
  message listener.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A Descriptor instance describing the message type
        of the map values. Its _concrete_class is instantiated when a missing
        key is accessed.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._key_checker = key_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      new_element = self._message_descriptor._concrete_class()
      new_element._SetListener(self._message_listener)
      self._values[key] = new_element
      self._message_listener.Modified()
      return new_element

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map.  This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __contains__(self, item: _K) -> bool:
    item = self._key_checker.CheckValue(item)
    return item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    key = self._key_checker.CheckValue(key)
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    """Copies every entry of ``other`` into this map, replacing existing keys."""
    # pylint: disable=protected-access
    for key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if key in self:
        del self[key]
      self[key].CopyFrom(other[key])
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    # Rebinding to a copy keeps this map's contents; the dummy entry changes
    # the old dict's size so live iterators over it fail on their next step.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
587
588
class _UnknownField:
  """A parsed unknown field: its field number, wire type and data.

  Instances order by field number only and compare equal when field number,
  wire type and data all match. Because __eq__ is defined without __hash__,
  instances are unhashable.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    # Let Python report a TypeError for foreign types instead of an
    # AttributeError from reaching into their internals.
    if not isinstance(other, _UnknownField):
      return NotImplemented
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    if self is other:
      return True
    # Foreign types fall back to Python's default (identity) comparison.
    if not isinstance(other, _UnknownField):
      return NotImplemented
    # pylint: disable=protected-access
    return (self._field_number == other._field_number and
            self._wire_type == other._wire_type and
            self._data == other._data)
612
613
class UnknownFieldRef:
  """A reference to one entry of an UnknownFieldSet, addressed by index.

  The reference stores the parent set and a position rather than the entry
  itself, and re-validates on every property read. Reading through a
  reference whose parent is empty, cleared, or too short raises ValueError.
  """

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    # An empty parent (or one whose length check fails) means the entry is
    # gone; so does an index that is past the parent's end.
    stale = (not self._parent) or self._index >= len(self._parent)
    if stale:
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  def _resolve(self):
    """Validates the reference and returns the underlying stored entry."""
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)

  @property
  def field_number(self):
    """The field number of the referenced unknown field."""
    entry = self._resolve()
    return entry._field_number  # pylint: disable=protected-access

  @property
  def wire_type(self):
    """The wire type of the referenced unknown field."""
    entry = self._resolve()
    return entry._wire_type  # pylint: disable=protected-access

  @property
  def data(self):
    """The data of the referenced unknown field."""
    entry = self._resolve()
    return entry._data  # pylint: disable=protected-access
645
646
class UnknownFieldSet:
  """UnknownField container.

  Stores parsed unknown fields in insertion order. Indexing and iteration
  yield UnknownFieldRef objects rather than the stored entries. After
  _clear(), the set is invalidated: indexing and len() raise ValueError.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the field at ``index``.

    Negative indices count from the end, as for lists.

    Raises:
      ValueError: If the set has been cleared.
      IndexError: If ``index`` is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    position = index + size if index < 0 else index
    if position < 0 or position >= size:
      # Report the index exactly as the caller passed it.
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, position)

  def _internal_get(self, index):
    return self._values[index]

  def __len__(self):
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    """Appends a new unknown field and returns the stored entry."""
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    """Appends all entries of another UnknownFieldSet; None is ignored."""
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    """Invalidates this set, first clearing any nested UnknownFieldSet data."""
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    self._values = None
711