# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Contains container classes to represent different protocol buffer types.

This file defines container classes which represent categories of protocol
buffer field types which need extra maintenance. Currently these categories
are:

- Repeated scalar fields - These are all repeated fields which aren't
  composite (e.g. they are of simple types like int32, string, etc).
- Repeated composite fields - Repeated fields which are composite. This
  includes groups and nested messages.
- Map fields - ScalarMap (scalar values) and MessageMap (submessage values).
- Unknown fields - UnknownFieldSet and the per-field UnknownFieldRef view.
"""

import collections.abc
import copy
import functools
import pickle
from typing import (
    Any,
    Iterable,
    Iterator,
    List,
    MutableMapping,
    MutableSequence,
    NoReturn,
    Optional,
    Sequence,
    TypeVar,
    Union,
    overload,
)


_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')


class BaseContainer(Sequence[_T]):
  """Base container class shared by the repeated-field containers."""

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """
    Args:
      message_listener: A MessageListener implementation. Concrete
        containers call this object's Modified() method when they are
        modified.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves item by the specified key."""
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  __hash__ = None

  def __repr__(self) -> str:
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the container in place. Similar to list.sort().

    Accepts the same arguments as list.sort(). For backward compatibility an
    old-style two-argument comparison function may be passed as
    ``sort_function`` (or ``cmp``). Python 3's list.sort() rejects ``cmp``,
    so the function is converted to a ``key`` with functools.cmp_to_key.

    Raises:
      TypeError: if a comparison function is combined with ``key``, or if
        list.sort() rejects the remaining arguments.
    """
    # Legacy keywords are expected to be rare, so use LBYL (dict.pop with a
    # default) rather than catching KeyError.
    cmp_func = kwargs.pop('sort_function', None)
    if cmp_func is None:
      cmp_func = kwargs.pop('cmp', None)
    if cmp_func is not None:
      if 'key' in kwargs:
        raise TypeError(
            'sort_function/cmp cannot be combined with the key argument')
      kwargs['key'] = functools.cmp_to_key(cmp_func)
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    self._values.reverse()


# TODO(slebedev): Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
collections.abc.MutableSequence.register(BaseContainer)


class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation. The
        RepeatedScalarFieldContainer will call this object's Modified() method
        when it is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    self._values.append(self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend()."""
    if elem_seq is None:
      return
    try:
      elem_seq_iter = iter(elem_seq)
    except TypeError:
      if not elem_seq:
        # silently ignore falsy inputs :-/.
        # TODO(ptucker): Deprecate this behavior. b/18413862
        return
      raise

    # Check every element before mutating, so a bad element leaves the
    # container unchanged.
    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
    if new_values:
      self._values.extend(new_values)
    self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.
    """
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item on the specified position."""
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      self._values[key] = map(self._type_checker.CheckValue, value)
      self._message_listener.Modified()
    else:
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener), self._type_checker)
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")


# TODO(slebedev): Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated class directly,
    since at the time we construct a _RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """
    new_element = self._message_descriptor._concrete_class(**kwargs)
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()
    return new_element

  def append(self, value: _T) -> None:
    """Appends one element by copying the message."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.insert(key, new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given sequence of elements of the same type
    as this one, copying each individual message.
    """
    message_class = self._message_descriptor._concrete_class
    listener = self._message_listener
    values = self._values
    for message in elem_seq:
      new_element = message_class()
      new_element._SetListener(listener)
      new_element.MergeFrom(message)
      values.append(new_element)
    listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # This method is implemented to make RepeatedCompositeFieldContainer
    # structurally compatible with typing.MutableSequence. It is
    # otherwise unsupported and will always raise an error.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values


class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    # defaultdict-like: a missing (type-checked) key is inserted with the
    # value checker's default value.
    try:
      return self._values[key]
    except KeyError:
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class


class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for with submessage values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A Descriptor instance describing the submessage
        value type. Its _concrete_class is instantiated for missing keys.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._key_checker = key_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      new_element = self._message_descriptor._concrete_class()
      new_element._SetListener(self._message_listener)
      self._values[key] = new_element
      self._message_listener.Modified()
      return new_element

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map.  This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __contains__(self, item: _K) -> bool:
    item = self._key_checker.CheckValue(item)
    return item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    key = self._key_checker.CheckValue(key)
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    # pylint: disable=protected-access
    for key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if key in self:
        del self[key]
      self[key].CopyFrom(other[key])
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class


class _UnknownField:
  """A parsed unknown field."""

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    # Orders by field number only; used when sorting for equality checks.
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    if self is other:
      return True
    # pylint: disable=protected-access
    return (self._field_number == other._field_number and
            self._wire_type == other._wire_type and
            self._data == other._data)


class UnknownFieldRef:
  """A lazy view of one entry of an UnknownFieldSet, looked up by index.

  Every property access re-validates that the parent set still exists and
  still holds an entry at this index; otherwise ValueError is raised.
  """

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    # Truthiness calls len(parent), which itself raises ValueError once the
    # parent set has been cleared.
    if not self._parent:
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')
    if self._index >= len(self._parent):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  @property
  def field_number(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._field_number

  @property
  def wire_type(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._wire_type

  @property
  def data(self):
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)._data


class UnknownFieldSet:
  """UnknownField container"""

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for ``index`` (negative indices allowed).

    Raises:
      ValueError: if the set has been cleared.
      IndexError: if ``index`` is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    if index < 0:
      index += size
    if index < 0 or index >= size:
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, index)

  def _internal_get(self, index):
    return self._values[index]

  def __len__(self):
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    # Recursively clears nested (group) sets, then marks this set as gone so
    # outstanding UnknownFieldRef views raise ValueError.
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    self._values = None