1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_AVL_AVL_H
16 #define GRPC_SRC_CORE_LIB_AVL_AVL_H
17 
18 #include <grpc/support/port_platform.h>
19 
20 #include <stdlib.h>
21 
22 #include <algorithm>  // IWYU pragma: keep
23 #include <memory>
24 #include <utility>
25 
26 #include "src/core/lib/gpr/useful.h"
27 
28 namespace grpc_core {
29 
// A persistent (immutable) ordered map from K to V implemented as an AVL tree.
// Every "mutating" operation returns a new tree and leaves the receiver
// untouched; nodes are never modified after construction, so unchanged
// subtrees are shared between the old and new trees via shared_ptr.
// Keys are ordered with operator< (and operator> for lookups).
template <class K, class V = void>
class AVL {
 public:
  AVL() {}

  // Returns a tree holding every entry of this tree plus {key, value}. If an
  // equivalent key (neither `key < k` nor `k < key`) is already present, that
  // entry is replaced in the result. This tree is not modified.
  AVL Add(K key, V value) const {
    return AVL(AddKey(root_, std::move(key), std::move(value)));
  }

  // Returns a tree with the entry equivalent to `key` removed. If no entry
  // matches, the result holds the same entries as this tree.
  // Requires `key < k` and `k < key` to be valid for keys `k` of type K.
  template <typename SomethingLikeK>
  AVL Remove(const SomethingLikeK& key) const {
    return AVL(RemoveKey(root_, key));
  }

  // Returns a pointer to the value stored under `key`, or nullptr if absent.
  // Requires `k > key` and `k < key` to be valid for keys `k` of type K.
  // The pointee lives in a tree node and stays valid only while some tree
  // that shares that node is alive.
  template <typename SomethingLikeK>
  const V* Lookup(const SomethingLikeK& key) const {
    const Node* n = Get(root_, key);
    return n != nullptr ? &n->kv.second : nullptr;
  }

  // Returns the entry with the greatest key that is not greater than `key`
  // (the exact match if one exists), or nullptr if every key is greater.
  const std::pair<K, V>* LookupBelow(const K& key) const {
    const Node* n = GetBelow(root_, key);
    return n != nullptr ? &n->kv : nullptr;
  }

  // True if the tree holds no entries.
  bool Empty() const { return root_ == nullptr; }

  // Invokes f(key, value) for every entry, in ascending key order.
  template <class F>
  void ForEach(F&& f) const {
    ForEachImpl(root_.get(), std::forward<F>(f));
  }

  // True if both trees share the same root node. This is a cheap identity
  // test, not a content comparison: independently built trees with equal
  // contents return false.
  bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; }

  // Three-way comparison of the entry sequences in ascending key order.
  // Entries are compared pairwise with QsortCompare; if one sequence is a
  // prefix of the other, the shorter one orders first. Returns a negative,
  // zero, or positive value.
  friend int QsortCompare(const AVL& left, const AVL& right) {
    if (left.root_.get() == right.root_.get()) return 0;
    Iterator a(left.root_);
    Iterator b(right.root_);
    for (;;) {
      Node* p = a.current();
      Node* q = b.current();
      // A node shared by both trees trivially holds equal entries, so only
      // distinct nodes need an actual comparison.
      if (p != q) {
        if (p == nullptr) return -1;
        if (q == nullptr) return 1;
        const int kv = QsortCompare(p->kv, q->kv);
        if (kv != 0) return kv;
      } else if (p == nullptr) {
        return 0;
      }
      a.MoveNext();
      b.MoveNext();
    }
  }

  // Content equality (same entries in the same order).
  bool operator==(const AVL& other) const {
    return QsortCompare(*this, other) == 0;
  }

  // Lexicographic ordering of the entry sequences.
  bool operator<(const AVL& other) const {
    return QsortCompare(*this, other) < 0;
  }

  // Number of levels in the tree; 0 when empty.
  size_t Height() const {
    if (root_ == nullptr) return 0;
    return root_->height;
  }

 private:
  struct Node;

  using NodePtr = std::shared_ptr<Node>;

  // Immutable tree node. Because no field changes after construction, a node
  // (and the subtree below it) can be shared by any number of trees.
  struct Node {
    Node(K k, V v, NodePtr l, NodePtr r, long h)
        : kv(std::move(k), std::move(v)),
          left(std::move(l)),
          right(std::move(r)),
          height(h) {}
    const std::pair<K, V> kv;
    const NodePtr left;
    const NodePtr right;
    // Levels in the subtree rooted here; a leaf has height 1.
    const long height;
  };
  NodePtr root_;

  // Fixed-capacity stack of raw node pointers used by Iterator.
  class IteratorStack {
   public:
    void Push(Node* n) {
      nodes_[depth_] = n;
      ++depth_;
    }

    Node* Pop() {
      --depth_;
      return nodes_[depth_];
    }

    Node* Back() const { return nodes_[depth_ - 1]; }

    bool Empty() const { return depth_ == 0; }

   private:
    // The iterator only ever holds one root-to-node path, so it never needs
    // more slots than the tree's height. A balanced AVL tree of height h has
    // at least Fib(h + 2) - 1 nodes, so 48 slots cover every tree with fewer
    // than ~2 * 10^10 nodes (32 slots would only be safe below ~9.2 million).
    static constexpr size_t kMaxDepth = 48;

    size_t depth_{0};
    Node* nodes_[kMaxDepth];
  };

  // In-order cursor over a tree. current() is nullptr once exhausted;
  // MoveNext() must only be called while current() is non-null.
  class Iterator {
   public:
    explicit Iterator(const NodePtr& root) {
      auto* n = root.get();
      while (n != nullptr) {
        stack_.Push(n);
        n = n->left.get();
      }
    }
    Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); }
    void MoveNext() {
      auto* n = stack_.Pop();
      if (n->right != nullptr) {
        n = n->right.get();
        while (n != nullptr) {
          stack_.Push(n);
          n = n->left.get();
        }
      }
    }

   private:
    IteratorStack stack_;
  };

  explicit AVL(NodePtr root) : root_(std::move(root)) {}

  // In-order traversal calling f(key, value) on each node of the subtree.
  template <class F>
  static void ForEachImpl(const Node* n, F&& f) {
    if (n == nullptr) return;
    ForEachImpl(n->left.get(), std::forward<F>(f));
    f(n->kv.first, n->kv.second);
    ForEachImpl(n->right.get(), std::forward<F>(f));
  }

  static long Height(const NodePtr& n) { return n ? n->height : 0; }

  // Allocates a node whose height is derived from its children.
  static NodePtr MakeNode(K key, V value, const NodePtr& left,
                          const NodePtr& right) {
    return std::make_shared<Node>(std::move(key), std::move(value), left, right,
                                  1 + std::max(Height(left), Height(right)));
  }

  // Finds the node whose key is equivalent to `key`, or nullptr. Iterative,
  // and returns a raw pointer so a lookup does not touch any refcounts.
  template <typename SomethingLikeK>
  static const Node* Get(const NodePtr& node, const SomethingLikeK& key) {
    const Node* n = node.get();
    while (n != nullptr) {
      if (n->kv.first > key) {
        n = n->left.get();
      } else if (n->kv.first < key) {
        n = n->right.get();
      } else {
        return n;
      }
    }
    return nullptr;
  }

  // Finds the node with the greatest key not greater than `key`, or nullptr.
  static const Node* GetBelow(const NodePtr& node, const K& key) {
    // Last node seen with a key smaller than `key`; the answer if no closer
    // candidate turns up further down its right subtree.
    const Node* best = nullptr;
    const Node* n = node.get();
    while (n != nullptr) {
      if (n->kv.first > key) {
        n = n->left.get();
      } else if (n->kv.first < key) {
        best = n;
        n = n->right.get();
      } else {
        return n;
      }
    }
    return best;
  }

  static NodePtr RotateLeft(K key, V value, const NodePtr& left,
                            const NodePtr& right) {
    return MakeNode(
        right->kv.first, right->kv.second,
        MakeNode(std::move(key), std::move(value), left, right->left),
        right->right);
  }

  static NodePtr RotateRight(K key, V value, const NodePtr& left,
                             const NodePtr& right) {
    return MakeNode(
        left->kv.first, left->kv.second, left->left,
        MakeNode(std::move(key), std::move(value), left->right, right));
  }

  static NodePtr RotateLeftRight(K key, V value, const NodePtr& left,
                                 const NodePtr& right) {
    // rotate_right(..., rotate_left(left), right)
    return MakeNode(
        left->right->kv.first, left->right->kv.second,
        MakeNode(left->kv.first, left->kv.second, left->left,
                 left->right->left),
        MakeNode(std::move(key), std::move(value), left->right->right, right));
  }

  static NodePtr RotateRightLeft(K key, V value, const NodePtr& left,
                                 const NodePtr& right) {
    // rotate_left(..., left, rotate_right(right))
    return MakeNode(
        right->left->kv.first, right->left->kv.second,
        MakeNode(std::move(key), std::move(value), left, right->left->left),
        MakeNode(right->kv.first, right->kv.second, right->left->right,
                 right->right));
  }

  // Builds a node from (key, value, left, right). A height difference of
  // exactly +/-2 between the subtrees triggers a single or double rotation;
  // any other difference yields a plain node.
  static NodePtr Rebalance(K key, V value, const NodePtr& left,
                           const NodePtr& right) {
    switch (Height(left) - Height(right)) {
      case 2:
        if (Height(left->left) - Height(left->right) == -1) {
          return RotateLeftRight(std::move(key), std::move(value), left, right);
        } else {
          return RotateRight(std::move(key), std::move(value), left, right);
        }
      case -2:
        if (Height(right->left) - Height(right->right) == 1) {
          return RotateRightLeft(std::move(key), std::move(value), left, right);
        } else {
          return RotateLeft(std::move(key), std::move(value), left, right);
        }
      default:
        return MakeNode(std::move(key), std::move(value), left, right);
    }
  }

  // Returns a copy of `node`'s subtree with {key, value} inserted/replaced.
  // Only nodes on the search path are rebuilt; siblings are shared.
  static NodePtr AddKey(const NodePtr& node, K key, V value) {
    if (!node) {
      return MakeNode(std::move(key), std::move(value), nullptr, nullptr);
    }
    if (node->kv.first < key) {
      return Rebalance(node->kv.first, node->kv.second, node->left,
                       AddKey(node->right, std::move(key), std::move(value)));
    }
    if (key < node->kv.first) {
      return Rebalance(node->kv.first, node->kv.second,
                       AddKey(node->left, std::move(key), std::move(value)),
                       node->right);
    }
    return MakeNode(std::move(key), std::move(value), node->left, node->right);
  }

  // Leftmost (smallest-key) node of a non-null subtree.
  static NodePtr InOrderHead(NodePtr node) {
    while (node->left != nullptr) {
      node = node->left;
    }
    return node;
  }

  // Rightmost (largest-key) node of a non-null subtree.
  static NodePtr InOrderTail(NodePtr node) {
    while (node->right != nullptr) {
      node = node->right;
    }
    return node;
  }

  // Returns a copy of `node`'s subtree without the entry equivalent to `key`.
  // A node with two children is replaced by its in-order neighbour taken from
  // the taller side, which keeps the result as balanced as possible.
  template <typename SomethingLikeK>
  static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) {
    if (node == nullptr) {
      return nullptr;
    }
    if (key < node->kv.first) {
      return Rebalance(node->kv.first, node->kv.second,
                       RemoveKey(node->left, key), node->right);
    } else if (node->kv.first < key) {
      return Rebalance(node->kv.first, node->kv.second, node->left,
                       RemoveKey(node->right, key));
    } else {
      if (node->left == nullptr) {
        return node->right;
      } else if (node->right == nullptr) {
        return node->left;
      } else if (node->left->height < node->right->height) {
        NodePtr h = InOrderHead(node->right);
        return Rebalance(h->kv.first, h->kv.second, node->left,
                         RemoveKey(node->right, h->kv.first));
      } else {
        NodePtr h = InOrderTail(node->left);
        return Rebalance(h->kv.first, h->kv.second,
                         RemoveKey(node->left, h->kv.first), node->right);
      }
    }
    abort();
  }
};
320 
321 template <class K>
322 class AVL<K, void> {
323  public:
AVL()324   AVL() {}
325 
Add(K key)326   AVL Add(K key) const { return AVL(AddKey(root_, std::move(key))); }
Remove(const K & key)327   AVL Remove(const K& key) const { return AVL(RemoveKey(root_, key)); }
Lookup(const K & key)328   bool Lookup(const K& key) const { return Get(root_, key) != nullptr; }
Empty()329   bool Empty() const { return root_ == nullptr; }
330 
331   template <class F>
ForEach(F && f)332   void ForEach(F&& f) const {
333     ForEachImpl(root_.get(), std::forward<F>(f));
334   }
335 
SameIdentity(AVL avl)336   bool SameIdentity(AVL avl) const { return root_ == avl.root_; }
337 
338  private:
339   struct Node;
340 
341   typedef std::shared_ptr<Node> NodePtr;
342   struct Node : public std::enable_shared_from_this<Node> {
NodeNode343     Node(K k, NodePtr l, NodePtr r, long h)
344         : key(std::move(k)),
345           left(std::move(l)),
346           right(std::move(r)),
347           height(h) {}
348     const K key;
349     const NodePtr left;
350     const NodePtr right;
351     const long height;
352   };
353   NodePtr root_;
354 
AVL(NodePtr root)355   explicit AVL(NodePtr root) : root_(std::move(root)) {}
356 
357   template <class F>
ForEachImpl(const Node * n,F && f)358   static void ForEachImpl(const Node* n, F&& f) {
359     if (n == nullptr) return;
360     ForEachImpl(n->left.get(), std::forward<F>(f));
361     f(const_cast<const K&>(n->key));
362     ForEachImpl(n->right.get(), std::forward<F>(f));
363   }
364 
Height(const NodePtr & n)365   static long Height(const NodePtr& n) { return n ? n->height : 0; }
366 
MakeNode(K key,const NodePtr & left,const NodePtr & right)367   static NodePtr MakeNode(K key, const NodePtr& left, const NodePtr& right) {
368     return std::make_shared<Node>(std::move(key), left, right,
369                                   1 + std::max(Height(left), Height(right)));
370   }
371 
Get(const NodePtr & node,const K & key)372   static NodePtr Get(const NodePtr& node, const K& key) {
373     if (node == nullptr) {
374       return nullptr;
375     }
376 
377     if (node->key > key) {
378       return Get(node->left, key);
379     } else if (node->key < key) {
380       return Get(node->right, key);
381     } else {
382       return node;
383     }
384   }
385 
RotateLeft(K key,const NodePtr & left,const NodePtr & right)386   static NodePtr RotateLeft(K key, const NodePtr& left, const NodePtr& right) {
387     return MakeNode(right->key, MakeNode(std::move(key), left, right->left),
388                     right->right);
389   }
390 
RotateRight(K key,const NodePtr & left,const NodePtr & right)391   static NodePtr RotateRight(K key, const NodePtr& left, const NodePtr& right) {
392     return MakeNode(left->key, left->left,
393                     MakeNode(std::move(key), left->right, right));
394   }
395 
RotateLeftRight(K key,const NodePtr & left,const NodePtr & right)396   static NodePtr RotateLeftRight(K key, const NodePtr& left,
397                                  const NodePtr& right) {
398     // rotate_right(..., rotate_left(left), right)
399     return MakeNode(left->right->key,
400                     MakeNode(left->key, left->left, left->right->left),
401                     MakeNode(std::move(key), left->right->right, right));
402   }
403 
RotateRightLeft(K key,const NodePtr & left,const NodePtr & right)404   static NodePtr RotateRightLeft(K key, const NodePtr& left,
405                                  const NodePtr& right) {
406     // rotate_left(..., left, rotate_right(right))
407     return MakeNode(right->left->key,
408                     MakeNode(std::move(key), left, right->left->left),
409                     MakeNode(right->key, right->left->right, right->right));
410   }
411 
Rebalance(K key,const NodePtr & left,const NodePtr & right)412   static NodePtr Rebalance(K key, const NodePtr& left, const NodePtr& right) {
413     switch (Height(left) - Height(right)) {
414       case 2:
415         if (Height(left->left) - Height(left->right) == -1) {
416           return RotateLeftRight(std::move(key), left, right);
417         } else {
418           return RotateRight(std::move(key), left, right);
419         }
420       case -2:
421         if (Height(right->left) - Height(right->right) == 1) {
422           return RotateRightLeft(std::move(key), left, right);
423         } else {
424           return RotateLeft(std::move(key), left, right);
425         }
426       default:
427         return MakeNode(key, left, right);
428     }
429   }
430 
AddKey(const NodePtr & node,K key)431   static NodePtr AddKey(const NodePtr& node, K key) {
432     if (!node) {
433       return MakeNode(std::move(key), nullptr, nullptr);
434     }
435     if (node->key < key) {
436       return Rebalance(node->key, node->left,
437                        AddKey(node->right, std::move(key)));
438     }
439     if (key < node->key) {
440       return Rebalance(node->key, AddKey(node->left, std::move(key)),
441                        node->right);
442     }
443     return MakeNode(std::move(key), node->left, node->right);
444   }
445 
InOrderHead(NodePtr node)446   static NodePtr InOrderHead(NodePtr node) {
447     while (node->left != nullptr) {
448       node = node->left;
449     }
450     return node;
451   }
452 
InOrderTail(NodePtr node)453   static NodePtr InOrderTail(NodePtr node) {
454     while (node->right != nullptr) {
455       node = node->right;
456     }
457     return node;
458   }
459 
RemoveKey(const NodePtr & node,const K & key)460   static NodePtr RemoveKey(const NodePtr& node, const K& key) {
461     if (node == nullptr) {
462       return nullptr;
463     }
464     if (key < node->key) {
465       return Rebalance(node->key, RemoveKey(node->left, key), node->right);
466     } else if (node->key < key) {
467       return Rebalance(node->key, node->left, RemoveKey(node->right, key));
468     } else {
469       if (node->left == nullptr) {
470         return node->right;
471       } else if (node->right == nullptr) {
472         return node->left;
473       } else if (node->left->height < node->right->height) {
474         NodePtr h = InOrderHead(node->right);
475         return Rebalance(h->key, node->left, RemoveKey(node->right, h->key));
476       } else {
477         NodePtr h = InOrderTail(node->left);
478         return Rebalance(h->key, RemoveKey(node->left, h->key), node->right);
479       }
480     }
481     abort();
482   }
483 };
484 
485 }  // namespace grpc_core
486 
487 #endif  // GRPC_SRC_CORE_LIB_AVL_AVL_H
488