1 // Copyright 2021 gRPC authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GRPC_SRC_CORE_LIB_AVL_AVL_H 16 #define GRPC_SRC_CORE_LIB_AVL_AVL_H 17 18 #include <grpc/support/port_platform.h> 19 20 #include <stdlib.h> 21 22 #include <algorithm> // IWYU pragma: keep 23 #include <memory> 24 #include <utility> 25 26 #include "src/core/lib/gpr/useful.h" 27 28 namespace grpc_core { 29 30 template <class K, class V = void> 31 class AVL { 32 public: AVL()33 AVL() {} 34 Add(K key,V value)35 AVL Add(K key, V value) const { 36 return AVL(AddKey(root_, std::move(key), std::move(value))); 37 } 38 template <typename SomethingLikeK> Remove(const SomethingLikeK & key)39 AVL Remove(const SomethingLikeK& key) const { 40 return AVL(RemoveKey(root_, key)); 41 } 42 template <typename SomethingLikeK> Lookup(const SomethingLikeK & key)43 const V* Lookup(const SomethingLikeK& key) const { 44 NodePtr n = Get(root_, key); 45 return n ? &n->kv.second : nullptr; 46 } 47 LookupBelow(const K & key)48 const std::pair<K, V>* LookupBelow(const K& key) const { 49 NodePtr n = GetBelow(root_, *key); 50 return n ? &n->kv : nullptr; 51 } 52 Empty()53 bool Empty() const { return root_ == nullptr; } 54 55 template <class F> ForEach(F && f)56 void ForEach(F&& f) const { 57 ForEachImpl(root_.get(), std::forward<F>(f)); 58 } 59 SameIdentity(const AVL & avl)60 bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; } 61 QsortCompare(const AVL & left,const AVL & right)62 friend int QsortCompare(const AVL& left, const AVL& right) { 63 if (left.root_.get() == right.root_.get()) return 0; 64 Iterator a(left.root_); 65 Iterator b(right.root_); 66 for (;;) { 67 Node* p = a.current(); 68 Node* q = b.current(); 69 if (p != q) { 70 if (p == nullptr) return -1; 71 if (q == nullptr) return 1; 72 const int kv = QsortCompare(p->kv, q->kv); 73 if (kv != 0) return kv; 74 } else if (p == nullptr) { 75 return 0; 76 } 77 a.MoveNext(); 78 b.MoveNext(); 79 } 80 } 81 82 bool operator==(const AVL& other) const { 83 return QsortCompare(*this, other) == 0; 84 } 85 86 bool operator<(const AVL& other) const { 87 return QsortCompare(*this, other) < 0; 88 } 89 Height()90 size_t Height() const { 91 if (root_ == nullptr) return 0; 92 return root_->height; 93 } 94 95 private: 96 struct Node; 97 98 typedef std::shared_ptr<Node> NodePtr; 99 struct Node : public std::enable_shared_from_this<Node> { NodeNode100 Node(K k, V v, NodePtr l, NodePtr r, long h) 101 : kv(std::move(k), std::move(v)), 102 left(std::move(l)), 103 right(std::move(r)), 104 height(h) {} 105 const std::pair<K, V> kv; 106 const NodePtr left; 107 const NodePtr right; 108 const long height; 109 }; 110 NodePtr root_; 111 112 class IteratorStack { 113 public: Push(Node * n)114 void Push(Node* n) { 115 nodes_[depth_] = n; 116 ++depth_; 117 } 118 Pop()119 Node* Pop() { 120 --depth_; 121 return nodes_[depth_]; 122 } 123 Back()124 Node* Back() const { return nodes_[depth_ - 1]; } 125 Empty()126 bool Empty() const { return depth_ == 0; } 127 128 private: 129 size_t depth_{0}; 130 // 32 is the maximum depth we can accept, and corresponds to ~4billion nodes 131 // - which ought to suffice our use cases. 132 Node* nodes_[32]; 133 }; 134 135 class Iterator { 136 public: Iterator(const NodePtr & root)137 explicit Iterator(const NodePtr& root) { 138 auto* n = root.get(); 139 while (n != nullptr) { 140 stack_.Push(n); 141 n = n->left.get(); 142 } 143 } current()144 Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); } MoveNext()145 void MoveNext() { 146 auto* n = stack_.Pop(); 147 if (n->right != nullptr) { 148 n = n->right.get(); 149 while (n != nullptr) { 150 stack_.Push(n); 151 n = n->left.get(); 152 } 153 } 154 } 155 156 private: 157 IteratorStack stack_; 158 }; 159 AVL(NodePtr root)160 explicit AVL(NodePtr root) : root_(std::move(root)) {} 161 162 template <class F> ForEachImpl(const Node * n,F && f)163 static void ForEachImpl(const Node* n, F&& f) { 164 if (n == nullptr) return; 165 ForEachImpl(n->left.get(), std::forward<F>(f)); 166 f(const_cast<const K&>(n->kv.first), const_cast<const V&>(n->kv.second)); 167 ForEachImpl(n->right.get(), std::forward<F>(f)); 168 } 169 Height(const NodePtr & n)170 static long Height(const NodePtr& n) { return n ? n->height : 0; } 171 MakeNode(K key,V value,const NodePtr & left,const NodePtr & right)172 static NodePtr MakeNode(K key, V value, const NodePtr& left, 173 const NodePtr& right) { 174 return std::make_shared<Node>(std::move(key), std::move(value), left, right, 175 1 + std::max(Height(left), Height(right))); 176 } 177 178 template <typename SomethingLikeK> Get(const NodePtr & node,const SomethingLikeK & key)179 static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) { 180 if (node == nullptr) { 181 return nullptr; 182 } 183 184 if (node->kv.first > key) { 185 return Get(node->left, key); 186 } else if (node->kv.first < key) { 187 return Get(node->right, key); 188 } else { 189 return node; 190 } 191 } 192 GetBelow(const NodePtr & node,const K & key)193 static NodePtr GetBelow(const NodePtr& node, const K& key) { 194 if (!node) return nullptr; 195 if (node->kv.first > key) { 196 return GetBelow(node->left, key); 197 } else if (node->kv.first < key) { 198 NodePtr n = GetBelow(node->right, key); 199 if (n == nullptr) n = node; 200 return n; 201 } else { 202 return node; 203 } 204 } 205 RotateLeft(K key,V value,const NodePtr & left,const NodePtr & right)206 static NodePtr RotateLeft(K key, V value, const NodePtr& left, 207 const NodePtr& right) { 208 return MakeNode( 209 right->kv.first, right->kv.second, 210 MakeNode(std::move(key), std::move(value), left, right->left), 211 right->right); 212 } 213 RotateRight(K key,V value,const NodePtr & left,const NodePtr & right)214 static NodePtr RotateRight(K key, V value, const NodePtr& left, 215 const NodePtr& right) { 216 return MakeNode( 217 left->kv.first, left->kv.second, left->left, 218 MakeNode(std::move(key), std::move(value), left->right, right)); 219 } 220 RotateLeftRight(K key,V value,const NodePtr & left,const NodePtr & right)221 static NodePtr RotateLeftRight(K key, V value, const NodePtr& left, 222 const NodePtr& right) { 223 // rotate_right(..., rotate_left(left), right) 224 return MakeNode( 225 left->right->kv.first, left->right->kv.second, 226 MakeNode(left->kv.first, left->kv.second, left->left, 227 left->right->left), 228 MakeNode(std::move(key), std::move(value), left->right->right, right)); 229 } 230 RotateRightLeft(K key,V value,const NodePtr & left,const NodePtr & right)231 static NodePtr RotateRightLeft(K key, V value, const NodePtr& left, 232 const NodePtr& right) { 233 // rotate_left(..., left, rotate_right(right)) 234 return MakeNode( 235 right->left->kv.first, right->left->kv.second, 236 MakeNode(std::move(key), std::move(value), left, right->left->left), 237 MakeNode(right->kv.first, right->kv.second, right->left->right, 238 right->right)); 239 } 240 Rebalance(K key,V value,const NodePtr & left,const NodePtr & right)241 static NodePtr Rebalance(K key, V value, const NodePtr& left, 242 const NodePtr& right) { 243 switch (Height(left) - Height(right)) { 244 case 2: 245 if (Height(left->left) - Height(left->right) == -1) { 246 return RotateLeftRight(std::move(key), std::move(value), left, right); 247 } else { 248 return RotateRight(std::move(key), std::move(value), left, right); 249 } 250 case -2: 251 if (Height(right->left) - Height(right->right) == 1) { 252 return RotateRightLeft(std::move(key), std::move(value), left, right); 253 } else { 254 return RotateLeft(std::move(key), std::move(value), left, right); 255 } 256 default: 257 return MakeNode(key, value, left, right); 258 } 259 } 260 AddKey(const NodePtr & node,K key,V value)261 static NodePtr AddKey(const NodePtr& node, K key, V value) { 262 if (!node) { 263 return MakeNode(std::move(key), std::move(value), nullptr, nullptr); 264 } 265 if (node->kv.first < key) { 266 return Rebalance(node->kv.first, node->kv.second, node->left, 267 AddKey(node->right, std::move(key), std::move(value))); 268 } 269 if (key < node->kv.first) { 270 return Rebalance(node->kv.first, node->kv.second, 271 AddKey(node->left, std::move(key), std::move(value)), 272 node->right); 273 } 274 return MakeNode(std::move(key), std::move(value), node->left, node->right); 275 } 276 InOrderHead(NodePtr node)277 static NodePtr InOrderHead(NodePtr node) { 278 while (node->left != nullptr) { 279 node = node->left; 280 } 281 return node; 282 } 283 InOrderTail(NodePtr node)284 static NodePtr InOrderTail(NodePtr node) { 285 while (node->right != nullptr) { 286 node = node->right; 287 } 288 return node; 289 } 290 291 template <typename SomethingLikeK> RemoveKey(const NodePtr & node,const SomethingLikeK & key)292 static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) { 293 if (node == nullptr) { 294 return nullptr; 295 } 296 if (key < node->kv.first) { 297 return Rebalance(node->kv.first, node->kv.second, 298 RemoveKey(node->left, key), node->right); 299 } else if (node->kv.first < key) { 300 return Rebalance(node->kv.first, node->kv.second, node->left, 301 RemoveKey(node->right, key)); 302 } else { 303 if (node->left == nullptr) { 304 return node->right; 305 } else if (node->right == nullptr) { 306 return node->left; 307 } else if (node->left->height < node->right->height) { 308 NodePtr h = InOrderHead(node->right); 309 return Rebalance(h->kv.first, h->kv.second, node->left, 310 RemoveKey(node->right, h->kv.first)); 311 } else { 312 NodePtr h = InOrderTail(node->left); 313 return Rebalance(h->kv.first, h->kv.second, 314 RemoveKey(node->left, h->kv.first), node->right); 315 } 316 } 317 abort(); 318 } 319 }; 320 321 template <class K> 322 class AVL<K, void> { 323 public: AVL()324 AVL() {} 325 Add(K key)326 AVL Add(K key) const { return AVL(AddKey(root_, std::move(key))); } Remove(const K & key)327 AVL Remove(const K& key) const { return AVL(RemoveKey(root_, key)); } Lookup(const K & key)328 bool Lookup(const K& key) const { return Get(root_, key) != nullptr; } Empty()329 bool Empty() const { return root_ == nullptr; } 330 331 template <class F> ForEach(F && f)332 void ForEach(F&& f) const { 333 ForEachImpl(root_.get(), std::forward<F>(f)); 334 } 335 SameIdentity(AVL avl)336 bool SameIdentity(AVL avl) const { return root_ == avl.root_; } 337 338 private: 339 struct Node; 340 341 typedef std::shared_ptr<Node> NodePtr; 342 struct Node : public std::enable_shared_from_this<Node> { NodeNode343 Node(K k, NodePtr l, NodePtr r, long h) 344 : key(std::move(k)), 345 left(std::move(l)), 346 right(std::move(r)), 347 height(h) {} 348 const K key; 349 const NodePtr left; 350 const NodePtr right; 351 const long height; 352 }; 353 NodePtr root_; 354 AVL(NodePtr root)355 explicit AVL(NodePtr root) : root_(std::move(root)) {} 356 357 template <class F> ForEachImpl(const Node * n,F && f)358 static void ForEachImpl(const Node* n, F&& f) { 359 if (n == nullptr) return; 360 ForEachImpl(n->left.get(), std::forward<F>(f)); 361 f(const_cast<const K&>(n->key)); 362 ForEachImpl(n->right.get(), std::forward<F>(f)); 363 } 364 Height(const NodePtr & n)365 static long Height(const NodePtr& n) { return n ? n->height : 0; } 366 MakeNode(K key,const NodePtr & left,const NodePtr & right)367 static NodePtr MakeNode(K key, const NodePtr& left, const NodePtr& right) { 368 return std::make_shared<Node>(std::move(key), left, right, 369 1 + std::max(Height(left), Height(right))); 370 } 371 Get(const NodePtr & node,const K & key)372 static NodePtr Get(const NodePtr& node, const K& key) { 373 if (node == nullptr) { 374 return nullptr; 375 } 376 377 if (node->key > key) { 378 return Get(node->left, key); 379 } else if (node->key < key) { 380 return Get(node->right, key); 381 } else { 382 return node; 383 } 384 } 385 RotateLeft(K key,const NodePtr & left,const NodePtr & right)386 static NodePtr RotateLeft(K key, const NodePtr& left, const NodePtr& right) { 387 return MakeNode(right->key, MakeNode(std::move(key), left, right->left), 388 right->right); 389 } 390 RotateRight(K key,const NodePtr & left,const NodePtr & right)391 static NodePtr RotateRight(K key, const NodePtr& left, const NodePtr& right) { 392 return MakeNode(left->key, left->left, 393 MakeNode(std::move(key), left->right, right)); 394 } 395 RotateLeftRight(K key,const NodePtr & left,const NodePtr & right)396 static NodePtr RotateLeftRight(K key, const NodePtr& left, 397 const NodePtr& right) { 398 // rotate_right(..., rotate_left(left), right) 399 return MakeNode(left->right->key, 400 MakeNode(left->key, left->left, left->right->left), 401 MakeNode(std::move(key), left->right->right, right)); 402 } 403 RotateRightLeft(K key,const NodePtr & left,const NodePtr & right)404 static NodePtr RotateRightLeft(K key, const NodePtr& left, 405 const NodePtr& right) { 406 // rotate_left(..., left, rotate_right(right)) 407 return MakeNode(right->left->key, 408 MakeNode(std::move(key), left, right->left->left), 409 MakeNode(right->key, right->left->right, right->right)); 410 } 411 Rebalance(K key,const NodePtr & left,const NodePtr & right)412 static NodePtr Rebalance(K key, const NodePtr& left, const NodePtr& right) { 413 switch (Height(left) - Height(right)) { 414 case 2: 415 if (Height(left->left) - Height(left->right) == -1) { 416 return RotateLeftRight(std::move(key), left, right); 417 } else { 418 return RotateRight(std::move(key), left, right); 419 } 420 case -2: 421 if (Height(right->left) - Height(right->right) == 1) { 422 return RotateRightLeft(std::move(key), left, right); 423 } else { 424 return RotateLeft(std::move(key), left, right); 425 } 426 default: 427 return MakeNode(key, left, right); 428 } 429 } 430 AddKey(const NodePtr & node,K key)431 static NodePtr AddKey(const NodePtr& node, K key) { 432 if (!node) { 433 return MakeNode(std::move(key), nullptr, nullptr); 434 } 435 if (node->key < key) { 436 return Rebalance(node->key, node->left, 437 AddKey(node->right, std::move(key))); 438 } 439 if (key < node->key) { 440 return Rebalance(node->key, AddKey(node->left, std::move(key)), 441 node->right); 442 } 443 return MakeNode(std::move(key), node->left, node->right); 444 } 445 InOrderHead(NodePtr node)446 static NodePtr InOrderHead(NodePtr node) { 447 while (node->left != nullptr) { 448 node = node->left; 449 } 450 return node; 451 } 452 InOrderTail(NodePtr node)453 static NodePtr InOrderTail(NodePtr node) { 454 while (node->right != nullptr) { 455 node = node->right; 456 } 457 return node; 458 } 459 RemoveKey(const NodePtr & node,const K & key)460 static NodePtr RemoveKey(const NodePtr& node, const K& key) { 461 if (node == nullptr) { 462 return nullptr; 463 } 464 if (key < node->key) { 465 return Rebalance(node->key, RemoveKey(node->left, key), node->right); 466 } else if (node->key < key) { 467 return Rebalance(node->key, node->left, RemoveKey(node->right, key)); 468 } else { 469 if (node->left == nullptr) { 470 return node->right; 471 } else if (node->right == nullptr) { 472 return node->left; 473 } else if (node->left->height < node->right->height) { 474 NodePtr h = InOrderHead(node->right); 475 return Rebalance(h->key, node->left, RemoveKey(node->right, h->key)); 476 } else { 477 NodePtr h = InOrderTail(node->left); 478 return Rebalance(h->key, RemoveKey(node->left, h->key), node->right); 479 } 480 } 481 abort(); 482 } 483 }; 484 485 } // namespace grpc_core 486 487 #endif // GRPC_SRC_CORE_LIB_AVL_AVL_H 488