xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/HashStore.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/HashStore.hpp>
2 
3 #include <unistd.h>
4 #include <cstdint>
5 
6 #include <chrono>
7 
8 #include <c10/util/Exception.h>
9 
10 namespace c10d {
11 
set(const std::string & key,const std::vector<uint8_t> & data)12 void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) {
13   std::unique_lock<std::mutex> lock(m_);
14   map_[key] = data;
15   cv_.notify_all();
16 }
17 
compareSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & desiredValue)18 std::vector<uint8_t> HashStore::compareSet(
19     const std::string& key,
20     const std::vector<uint8_t>& expectedValue,
21     const std::vector<uint8_t>& desiredValue) {
22   std::unique_lock<std::mutex> lock(m_);
23   auto it = map_.find(key);
24   if ((it == map_.end() && expectedValue.empty()) ||
25       (it != map_.end() && it->second == expectedValue)) {
26     // if the key does not exist and currentValue arg is empty or
27     // the key does exist and current value is what is expected, then set it
28     map_[key] = desiredValue;
29     cv_.notify_all();
30     return desiredValue;
31   } else if (it == map_.end()) {
32     // if the key does not exist
33     return expectedValue;
34   }
35   // key exists but current value is not expected
36   return it->second;
37 }
38 
get(const std::string & key)39 std::vector<uint8_t> HashStore::get(const std::string& key) {
40   std::unique_lock<std::mutex> lock(m_);
41   auto it = map_.find(key);
42   if (it != map_.end()) {
43     return it->second;
44   }
45   // Slow path: wait up to any timeout_.
46   auto pred = [&]() { return map_.find(key) != map_.end(); };
47   if (timeout_ == kNoTimeout) {
48     cv_.wait(lock, pred);
49   } else {
50     if (!cv_.wait_for(lock, timeout_, pred)) {
51       C10_THROW_ERROR(DistStoreError, "Wait timeout");
52     }
53   }
54   return map_[key];
55 }
56 
wait(const std::vector<std::string> & keys,const std::chrono::milliseconds & timeout)57 void HashStore::wait(
58     const std::vector<std::string>& keys,
59     const std::chrono::milliseconds& timeout) {
60   const auto end = std::chrono::steady_clock::now() + timeout;
61   auto pred = [&]() {
62     auto done = true;
63     for (const auto& key : keys) {
64       if (map_.find(key) == map_.end()) {
65         done = false;
66         break;
67       }
68     }
69     return done;
70   };
71 
72   std::unique_lock<std::mutex> lock(m_);
73   if (timeout == kNoTimeout) {
74     cv_.wait(lock, pred);
75   } else {
76     if (!cv_.wait_until(lock, end, pred)) {
77       C10_THROW_ERROR(DistStoreError, "Wait timeout");
78     }
79   }
80 }
81 
add(const std::string & key,int64_t i)82 int64_t HashStore::add(const std::string& key, int64_t i) {
83   std::unique_lock<std::mutex> lock(m_);
84   const auto& value = map_[key];
85   int64_t ti = i;
86   if (!value.empty()) {
87     auto buf = reinterpret_cast<const char*>(value.data());
88     auto len = value.size();
89     ti += std::stoll(std::string(buf, len));
90   }
91 
92   auto str = std::to_string(ti);
93   const uint8_t* strB = reinterpret_cast<const uint8_t*>(str.c_str());
94   map_[key] = std::vector<uint8_t>(strB, strB + str.size());
95   return ti;
96 }
97 
getNumKeys()98 int64_t HashStore::getNumKeys() {
99   std::unique_lock<std::mutex> lock(m_);
100   return static_cast<int64_t>(map_.size());
101 }
102 
deleteKey(const std::string & key)103 bool HashStore::deleteKey(const std::string& key) {
104   std::unique_lock<std::mutex> lock(m_);
105   auto numDeleted = map_.erase(key);
106   return (numDeleted == 1);
107 }
108 
check(const std::vector<std::string> & keys)109 bool HashStore::check(const std::vector<std::string>& keys) {
110   std::unique_lock<std::mutex> lock(m_);
111   for (const auto& key : keys) {
112     if (map_.find(key) == map_.end()) {
113       return false;
114     }
115   }
116   return true;
117 }
118 
append(const std::string & key,const std::vector<uint8_t> & value)119 void HashStore::append(
120     const std::string& key,
121     const std::vector<uint8_t>& value) {
122   std::unique_lock<std::mutex> lock(m_);
123   auto it = map_.find(key);
124   if (it == map_.end()) {
125     map_[key] = value;
126   } else {
127     it->second.insert(it->second.end(), value.begin(), value.end());
128   }
129   cv_.notify_all();
130 }
131 
multiGet(const std::vector<std::string> & keys)132 std::vector<std::vector<uint8_t>> HashStore::multiGet(
133     const std::vector<std::string>& keys) {
134   std::unique_lock<std::mutex> lock(m_);
135   auto deadline = std::chrono::steady_clock::now() + timeout_;
136   std::vector<std::vector<uint8_t>> res;
137   res.reserve(keys.size());
138 
139   for (auto& key : keys) {
140     auto it = map_.find(key);
141     if (it != map_.end()) {
142       res.emplace_back(it->second);
143     } else {
144       auto pred = [&]() { return map_.find(key) != map_.end(); };
145       if (timeout_ == kNoTimeout) {
146         cv_.wait(lock, pred);
147       } else {
148         if (!cv_.wait_until(lock, deadline, pred)) {
149           C10_THROW_ERROR(DistStoreError, "Wait timeout");
150         }
151       }
152       res.emplace_back(map_[key]);
153     }
154   }
155   return res;
156 }
157 
multiSet(const std::vector<std::string> & keys,const std::vector<std::vector<uint8_t>> & values)158 void HashStore::multiSet(
159     const std::vector<std::string>& keys,
160     const std::vector<std::vector<uint8_t>>& values) {
161   std::unique_lock<std::mutex> lock(m_);
162 
163   for (auto i : ::c10::irange(keys.size())) {
164     map_[keys[i]] = values[i];
165   }
166   cv_.notify_all();
167 }
168 
hasExtendedApi() const169 bool HashStore::hasExtendedApi() const {
170   return true;
171 }
172 
173 } // namespace c10d
174