1 #include <torch/csrc/distributed/c10d/HashStore.hpp>
2
3 #include <unistd.h>
4 #include <cstdint>
5
6 #include <chrono>
7
8 #include <c10/util/Exception.h>
9
10 namespace c10d {
11
set(const std::string & key,const std::vector<uint8_t> & data)12 void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) {
13 std::unique_lock<std::mutex> lock(m_);
14 map_[key] = data;
15 cv_.notify_all();
16 }
17
compareSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & desiredValue)18 std::vector<uint8_t> HashStore::compareSet(
19 const std::string& key,
20 const std::vector<uint8_t>& expectedValue,
21 const std::vector<uint8_t>& desiredValue) {
22 std::unique_lock<std::mutex> lock(m_);
23 auto it = map_.find(key);
24 if ((it == map_.end() && expectedValue.empty()) ||
25 (it != map_.end() && it->second == expectedValue)) {
26 // if the key does not exist and currentValue arg is empty or
27 // the key does exist and current value is what is expected, then set it
28 map_[key] = desiredValue;
29 cv_.notify_all();
30 return desiredValue;
31 } else if (it == map_.end()) {
32 // if the key does not exist
33 return expectedValue;
34 }
35 // key exists but current value is not expected
36 return it->second;
37 }
38
get(const std::string & key)39 std::vector<uint8_t> HashStore::get(const std::string& key) {
40 std::unique_lock<std::mutex> lock(m_);
41 auto it = map_.find(key);
42 if (it != map_.end()) {
43 return it->second;
44 }
45 // Slow path: wait up to any timeout_.
46 auto pred = [&]() { return map_.find(key) != map_.end(); };
47 if (timeout_ == kNoTimeout) {
48 cv_.wait(lock, pred);
49 } else {
50 if (!cv_.wait_for(lock, timeout_, pred)) {
51 C10_THROW_ERROR(DistStoreError, "Wait timeout");
52 }
53 }
54 return map_[key];
55 }
56
wait(const std::vector<std::string> & keys,const std::chrono::milliseconds & timeout)57 void HashStore::wait(
58 const std::vector<std::string>& keys,
59 const std::chrono::milliseconds& timeout) {
60 const auto end = std::chrono::steady_clock::now() + timeout;
61 auto pred = [&]() {
62 auto done = true;
63 for (const auto& key : keys) {
64 if (map_.find(key) == map_.end()) {
65 done = false;
66 break;
67 }
68 }
69 return done;
70 };
71
72 std::unique_lock<std::mutex> lock(m_);
73 if (timeout == kNoTimeout) {
74 cv_.wait(lock, pred);
75 } else {
76 if (!cv_.wait_until(lock, end, pred)) {
77 C10_THROW_ERROR(DistStoreError, "Wait timeout");
78 }
79 }
80 }
81
add(const std::string & key,int64_t i)82 int64_t HashStore::add(const std::string& key, int64_t i) {
83 std::unique_lock<std::mutex> lock(m_);
84 const auto& value = map_[key];
85 int64_t ti = i;
86 if (!value.empty()) {
87 auto buf = reinterpret_cast<const char*>(value.data());
88 auto len = value.size();
89 ti += std::stoll(std::string(buf, len));
90 }
91
92 auto str = std::to_string(ti);
93 const uint8_t* strB = reinterpret_cast<const uint8_t*>(str.c_str());
94 map_[key] = std::vector<uint8_t>(strB, strB + str.size());
95 return ti;
96 }
97
getNumKeys()98 int64_t HashStore::getNumKeys() {
99 std::unique_lock<std::mutex> lock(m_);
100 return static_cast<int64_t>(map_.size());
101 }
102
deleteKey(const std::string & key)103 bool HashStore::deleteKey(const std::string& key) {
104 std::unique_lock<std::mutex> lock(m_);
105 auto numDeleted = map_.erase(key);
106 return (numDeleted == 1);
107 }
108
check(const std::vector<std::string> & keys)109 bool HashStore::check(const std::vector<std::string>& keys) {
110 std::unique_lock<std::mutex> lock(m_);
111 for (const auto& key : keys) {
112 if (map_.find(key) == map_.end()) {
113 return false;
114 }
115 }
116 return true;
117 }
118
append(const std::string & key,const std::vector<uint8_t> & value)119 void HashStore::append(
120 const std::string& key,
121 const std::vector<uint8_t>& value) {
122 std::unique_lock<std::mutex> lock(m_);
123 auto it = map_.find(key);
124 if (it == map_.end()) {
125 map_[key] = value;
126 } else {
127 it->second.insert(it->second.end(), value.begin(), value.end());
128 }
129 cv_.notify_all();
130 }
131
multiGet(const std::vector<std::string> & keys)132 std::vector<std::vector<uint8_t>> HashStore::multiGet(
133 const std::vector<std::string>& keys) {
134 std::unique_lock<std::mutex> lock(m_);
135 auto deadline = std::chrono::steady_clock::now() + timeout_;
136 std::vector<std::vector<uint8_t>> res;
137 res.reserve(keys.size());
138
139 for (auto& key : keys) {
140 auto it = map_.find(key);
141 if (it != map_.end()) {
142 res.emplace_back(it->second);
143 } else {
144 auto pred = [&]() { return map_.find(key) != map_.end(); };
145 if (timeout_ == kNoTimeout) {
146 cv_.wait(lock, pred);
147 } else {
148 if (!cv_.wait_until(lock, deadline, pred)) {
149 C10_THROW_ERROR(DistStoreError, "Wait timeout");
150 }
151 }
152 res.emplace_back(map_[key]);
153 }
154 }
155 return res;
156 }
157
multiSet(const std::vector<std::string> & keys,const std::vector<std::vector<uint8_t>> & values)158 void HashStore::multiSet(
159 const std::vector<std::string>& keys,
160 const std::vector<std::vector<uint8_t>>& values) {
161 std::unique_lock<std::mutex> lock(m_);
162
163 for (auto i : ::c10::irange(keys.size())) {
164 map_[keys[i]] = values[i];
165 }
166 cv_.notify_all();
167 }
168
hasExtendedApi() const169 bool HashStore::hasExtendedApi() const {
170 return true;
171 }
172
173 } // namespace c10d
174