xref: /aosp_15_r20/external/libprotobuf-mutator/src/weighted_reservoir_sampler.h (revision fd525a9c096e28cf6f8d8719388df0568a611e7b)
// Copyright 2016 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14*fd525a9cSAndroid Build Coastguard Worker 
15*fd525a9cSAndroid Build Coastguard Worker #ifndef SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
16*fd525a9cSAndroid Build Coastguard Worker #define SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
17*fd525a9cSAndroid Build Coastguard Worker 
#include <cassert>
#include <cstdint>
#include <random>
20*fd525a9cSAndroid Build Coastguard Worker 
21*fd525a9cSAndroid Build Coastguard Worker namespace protobuf_mutator {
22*fd525a9cSAndroid Build Coastguard Worker 
// Picks one item from a stream of weighted items in a single pass, using the
// k = 1 case of weighted reservoir sampling (Chao's algorithm):
// https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
//
// Each offered item replaces the current selection with probability
// weight / (sum of all weights offered so far). Items with weight 0 are
// never selected.
//
// Example:
//   std::default_random_engine engine;
//   WeightedReservoirSampler<int> sampler(&engine);
//   for (int i = 0; i < size; ++i)
//     sampler.Try(weight[i], i);
//   if (!sampler.IsEmpty()) return sampler.selected();
//
// The sampler stores `random` as a raw pointer and does not own it; the
// engine must stay valid for as long as Try() is called.
template <class T, class RandomEngine = std::default_random_engine>
class WeightedReservoirSampler {
 public:
  // `random`: engine used for the selection draws (not owned).
  explicit WeightedReservoirSampler(RandomEngine* random) : random_(random) {}

  // Offers `item` with the given `weight`. A weight of 0 is ignored.
  void Try(uint64_t weight, const T& item) {
    if (Pick(weight)) selected_ = item;
  }

  // The currently selected item; a value-initialized T while IsEmpty().
  const T& selected() const { return selected_; }

  // True until an item with a non-zero weight has been offered.
  bool IsEmpty() const { return total_weight_ == 0; }

 private:
  // Adds `weight` to the running total and decides whether the newly offered
  // item should replace the current selection.
  bool Pick(uint64_t weight) {
    if (weight == 0) return false;
    total_weight_ += weight;
    // The running total must not wrap around: a wrapped total would be
    // smaller than later weights and make those items win too often.
    assert(total_weight_ >= weight && "total weight overflowed uint64_t");
    // The first non-zero-weight item always wins; the short-circuit means no
    // random number is drawn from the engine for it.
    return weight == total_weight_ || std::uniform_int_distribution<uint64_t>(
                                          1, total_weight_)(*random_) <= weight;
  }

  T selected_ = {};
  uint64_t total_weight_ = 0;
  RandomEngine* random_;
};
56*fd525a9cSAndroid Build Coastguard Worker 
57*fd525a9cSAndroid Build Coastguard Worker }  // namespace protobuf_mutator
58*fd525a9cSAndroid Build Coastguard Worker 
59*fd525a9cSAndroid Build Coastguard Worker #endif  // SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
60