1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <gmock/gmock.h>
18 #include <gtest/gtest.h>
19 #include <netdutils/NetNativeTestBase.h>
20 #include <resolv_stats_test_utils.h>
21
22 #include "PrivateDnsConfiguration.h"
23 #include "resolv_cache.h"
24 #include "tests/dns_responder/dns_responder.h"
25 #include "tests/dns_responder/dns_tls_frontend.h"
26 #include "tests/resolv_test_utils.h"
27
28 namespace android::net {
29
30 using namespace std::chrono_literals;
31
32 class PrivateDnsConfigurationTest : public NetNativeTestBase {
33 public:
34 using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;
35
36 class WrappedPrivateDnsConfiguration : public PrivateDnsConfiguration {
37 public:
set(int32_t netId,uint32_t mark,const std::vector<std::string> & unencryptedServers,const std::vector<std::string> & encryptedServers)38 int set(int32_t netId, uint32_t mark, const std::vector<std::string>& unencryptedServers,
39 const std::vector<std::string>& encryptedServers) {
40 // TODO(b/240259333): Add test coverage for dohParamsParcel.
41 return PrivateDnsConfiguration::set(netId, mark, unencryptedServers, encryptedServers,
42 {} /* name */, {} /* caCert */,
43 std::nullopt /* dohParamsParcel */);
44 }
45 };
46
SetUpTestSuite()47 static void SetUpTestSuite() {
48 // stopServer() will be called in their destructor.
49 ASSERT_TRUE(tls1.startServer());
50 ASSERT_TRUE(tls2.startServer());
51 ASSERT_TRUE(backend.startServer());
52 ASSERT_TRUE(backend1ForUdpProbe.startServer());
53 ASSERT_TRUE(backend2ForUdpProbe.startServer());
54 }
55
SetUp()56 void SetUp() {
57 mPdc.setObserver(&mObserver);
58 mPdc.mBackoffBuilder.withInitialRetransmissionTime(std::chrono::seconds(1))
59 .withMaximumRetransmissionTime(std::chrono::seconds(1));
60
61 // The default and sole action when the observer is notified of onValidationStateUpdate.
62 // Don't override the action. In other words, don't use WillOnce() or WillRepeatedly()
63 // when mObserver.onValidationStateUpdate is expected to be called, like:
64 //
65 // EXPECT_CALL(mObserver, onValidationStateUpdate).WillOnce(Return());
66 //
67 // This is to ensure that tests can monitor how many validation threads are running. Tests
68 // must wait until every validation thread finishes.
69 ON_CALL(mObserver, onValidationStateUpdate)
70 .WillByDefault([&](const std::string& server, Validation validation, uint32_t) {
71 std::lock_guard guard(mObserver.lock);
72 if (validation == Validation::in_process) {
73 auto it = mObserver.serverStateMap.find(server);
74 if (it == mObserver.serverStateMap.end() ||
75 it->second != Validation::in_process) {
76 // Increment runningThreads only when receive the first in_process
77 // notification. The rest of the continuous in_process notifications
78 // are due to probe retry which runs on the same thread.
79 // TODO: consider adding onValidationThreadStart() and
80 // onValidationThreadEnd() callbacks.
81 mObserver.runningThreads++;
82 }
83 } else if (validation == Validation::success ||
84 validation == Validation::fail) {
85 mObserver.runningThreads--;
86 }
87 mObserver.serverStateMap[server] = validation;
88 });
89
90 // Create a NetConfig for stats.
91 EXPECT_EQ(0, resolv_create_cache_for_net(kNetId));
92 }
93
TearDown()94 void TearDown() {
95 // Reset the state for the next test.
96 resolv_delete_cache_for_net(kNetId);
97 mPdc.set(kNetId, kMark, {}, {});
98 }
99
100 protected:
101 class MockObserver : public PrivateDnsValidationObserver {
102 public:
103 MOCK_METHOD(void, onValidationStateUpdate,
104 (const std::string& serverIp, Validation validation, uint32_t netId),
105 (override));
106
getServerStateMap() const107 std::map<std::string, Validation> getServerStateMap() const {
108 std::lock_guard guard(lock);
109 return serverStateMap;
110 }
111
removeFromServerStateMap(const std::string & server)112 void removeFromServerStateMap(const std::string& server) {
113 std::lock_guard guard(lock);
114 if (const auto it = serverStateMap.find(server); it != serverStateMap.end())
115 serverStateMap.erase(it);
116 }
117
118 // The current number of validation threads running.
119 std::atomic<int> runningThreads = 0;
120
121 mutable std::mutex lock;
122 std::map<std::string, Validation> serverStateMap GUARDED_BY(lock);
123 };
124
expectPrivateDnsStatus(PrivateDnsMode mode)125 void expectPrivateDnsStatus(PrivateDnsMode mode) {
126 // Use PollForCondition because mObserver is notified asynchronously.
127 EXPECT_TRUE(PollForCondition([&]() { return checkPrivateDnsStatus(mode); }));
128 }
129
checkPrivateDnsStatus(PrivateDnsMode mode)130 bool checkPrivateDnsStatus(PrivateDnsMode mode) {
131 const PrivateDnsStatus status = mPdc.getStatus(kNetId);
132 if (status.mode != mode) return false;
133
134 std::map<std::string, Validation> serverStateMap;
135 for (const auto& [server, validation] : status.dotServersMap) {
136 serverStateMap[ToString(&server.ss)] = validation;
137 }
138 return (serverStateMap == mObserver.getServerStateMap());
139 }
140
hasPrivateDnsServer(const ServerIdentity & identity,unsigned netId)141 bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
142 return mPdc.getDotServer(identity, netId).ok();
143 }
144
145 static constexpr uint32_t kNetId = 30;
146 static constexpr uint32_t kMark = 30;
147 static constexpr char kBackend[] = "127.0.2.1";
148 static constexpr char kServer1[] = "127.0.2.2";
149 static constexpr char kServer2[] = "127.0.2.3";
150
151 MockObserver mObserver;
152 inline static WrappedPrivateDnsConfiguration mPdc;
153
154 // TODO: Because incorrect CAs result in validation failed in strict mode, have
155 // PrivateDnsConfiguration run mocked code rather than DnsTlsTransport::validate().
156 inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
157 inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
158 inline static test::DNSResponder backend{kBackend, "53"};
159 inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
160 inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
161 };
162
// A reachable DoT server is reported as in_process and then success, in that order.
TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
    testing::InSequence inOrder;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Do not leave the test while the validation thread is still alive.
    const auto noValidationRunning = [&] { return mObserver.runningThreads == 0; };
    ASSERT_TRUE(PollForCondition(noValidationRunning));
}
173
// With the backend down, validating kServer1 ends in fail while the mode stays opportunistic.
TEST_F(PrivateDnsConfigurationTest, ValidationFail_Opportunistic) {
    // Nothing answers behind the DoT frontend, so every probe goes unanswered.
    ASSERT_TRUE(backend.stopServer());

    testing::InSequence inOrder;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Wait for every validation thread to finish; leaving one behind has been seen to crash
    // the test.
    const auto allValidationsDone = [&] { return mObserver.runningThreads == 0; };
    ASSERT_TRUE(PollForCondition(allValidationsDone));

    // Bring the shared backend back for the tests that follow.
    ASSERT_TRUE(backend.startServer());
}
188
// A requested re-validation retries after a failed probe and succeeds once the backend is back.
TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));

    // Step 1: Set up and wait for validation complete.
    testing::InSequence seq;
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // Step 2: Simulate the DNS is temporarily broken, and then request a validation.
    // Expect the validation to run as follows:
    //   1. DnsResolver notifies of Validation::in_process when the validation is about to run.
    //   2. The first probing fails. DnsResolver notifies of Validation::in_process.
    //   3. One second later, the second probing begins and succeeds. DnsResolver notifies of
    //      Validation::success.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId))
            .Times(2);
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));

    // The helper thread restarts the backend after one second. Its result is only checked after
    // join(): no fatal assertion may run while |t| is joinable, because returning from the test
    // with a joinable std::thread calls std::terminate().
    bool restarted = false;
    std::thread t([&restarted] {
        std::this_thread::sleep_for(1000ms);
        restarted = backend.startServer();
    });
    EXPECT_TRUE(backend.stopServer());
    EXPECT_TRUE(mPdc.requestDotValidation(kNetId, ServerIdentity(server), kMark).ok());

    t.join();
    // A backend that failed to come back would break this and every following test.
    EXPECT_TRUE(restarted);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}
222
// While validations are held back, reconfiguring servers neither duplicates validations nor
// disturbs the status; once released, the dropped server fails and the kept server succeeds.
TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    // Keep every validation parked in in_process until the backend is told to answer.
    backend.setDeferredResp(true);

    // Produces a predicate that is true once exactly |n| validation threads are running.
    const auto runningThreadsAre = [&](int n) {
        return [&, n] { return mObserver.runningThreads == n; };
    };

    // Inside this scope, the onValidationStateUpdate() calls must arrive in the declared order.
    {
        testing::InSequence inOrder;

        // kServer1 starts validating and stays blocked.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
        ASSERT_TRUE(PollForCondition(runningThreadsAre(1)));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Switching to kServer2 starts a second validation while kServer1's is still blocked.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}), 0);
        ASSERT_TRUE(PollForCondition(runningThreadsAre(2)));
        // kServer1 is no longer configured, so it must not appear in the compared state.
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // No duplicate validation as long as not in OFF mode; otherwise, an unexpected
        // onValidationStateUpdate() will be caught.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1, kServer2}), 0);
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer2}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Invalid arguments are rejected and leave the status as it was.
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}), -EINVAL);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
    }

    // Releasing the backend completes both validations. kServer1 is reported as fail because it
    // is not an expected server for the network anymore.
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
    EXPECT_CALL(mObserver, onValidationStateUpdate(kServer2, Validation::success, kNetId));
    backend.setDeferredResp(false);

    ASSERT_TRUE(PollForCondition(runningThreadsAre(0)));

    // getStatus() only lists configured servers, so drop kServer1 from the observer's view too.
    mObserver.removeFromServerStateMap(kServer1);

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}
266
// A validation that is in flight when private DNS is turned off, or when the network is
// destroyed, is reported as fail once it completes, and the mode ends up OFF.
TEST_F(PrivateDnsConfigurationTest, Validation_NetworkDestroyedOrOffMode) {
    for (const std::string_view config : {"OFF", "NETWORK_DESTROYED"}) {
        SCOPED_TRACE(config);
        const bool turnOffMode = (config == "OFF");

        // Keep kServer1's validation blocked in in_process.
        backend.setDeferredResp(true);

        testing::InSequence inOrder;
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
        ASSERT_TRUE(PollForCondition([&] { return mObserver.runningThreads == 1; }));
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Tear private DNS down while the validation is still pending.
        if (turnOffMode) {
            EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}), 0);
        } else {
            mPdc.clear(kNetId);
        }

        // Let the pending validation complete.
        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        backend.setDeferredResp(false);

        ASSERT_TRUE(PollForCondition([&] { return mObserver.runningThreads == 0; }));
        mObserver.removeFromServerStateMap(kServer1);
        expectPrivateDnsStatus(PrivateDnsMode::OFF);
    }
}
292
// Configurations that leave no usable DoT server must not start any validation.
TEST_F(PrivateDnsConfigurationTest, NoValidation) {
    // No EXPECT_CALL is set on purpose: per the original intent, any onValidationStateUpdate()
    // call should make the test fail as an uninteresting mock function call.

    // Private DNS is OFF and no DoT server is listed.
    const auto expectStatus = [&] {
        const PrivateDnsStatus status = mPdc.getStatus(kNetId);
        EXPECT_EQ(status.mode, PrivateDnsMode::OFF);
        EXPECT_THAT(status.dotServersMap, testing::IsEmpty());
    };

    // An unparsable server address is rejected.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {"invalid_addr"}), -EINVAL);
    expectStatus();

    // An empty server list is accepted.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {}), 0);
    expectStatus();
}
309
// Two ServerIdentity values are equal only when both the socket address and the provider
// hostname match.
TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";

    // Socket address: an identical copy matches; another port or another IP does not.
    DnsTlsServer other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));

    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    // Provider hostname: starting from an identical copy, another name or an empty name
    // does not match.
    other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));

    other.name = "other.example.com";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    other.name = "";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
}
330
// requestDotValidation() is only accepted for a server whose last validation succeeded.
TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(server);

    testing::InSequence seq;

    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
        SCOPED_TRACE(config);

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        } else if (config == "IN_PROGRESS") {
            // Hold back backend answers so the validation stays in_process.
            backend.setDeferredResp(true);
        } else {
            // config = "FAIL": with the backend down, the validation fails.
            ASSERT_TRUE(backend.stopServer());
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        }
        EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Wait until the validation state has transitioned.
        const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0;
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; }));

        if (config == "SUCCESS") {
            // A validated server can be re-validated.
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
            // The pending validation completes with success once the backend answers (below).
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request is denied, and so is requesting a validation with a mark
        // or network that has no such server.
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state. Only the FAIL case stopped the backend, so only that case
        // restarts it. Its result is checked so that a backend that failed to come back is
        // reported here rather than as confusing failures in later iterations or tests.
        backend.setDeferredResp(false);
        if (config == "FAIL") {
            ASSERT_TRUE(backend.startServer());
        }

        // Ensure the status of mObserver is synced.
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        mPdc.clear(kNetId);
    }
}
385
// getDotServer() only finds servers configured on the queried network.
TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));
    const ServerIdentity identity1(server1);
    const ServerIdentity identity2(server2);

    // Nothing is configured yet.
    EXPECT_FALSE(hasPrivateDnsServer(identity1, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(identity2, kNetId));

    // Accept the two notifications from kServer1's validation without constraining their
    // arguments.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    EXPECT_EQ(mPdc.set(kNetId, kMark, {}, {kServer1}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    // Only kServer1 on kNetId is found now.
    EXPECT_TRUE(hasPrivateDnsServer(identity1, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(identity2, kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(identity1, kNetId + 1));

    const auto validationFinished = [&] { return mObserver.runningThreads == 0; };
    ASSERT_TRUE(PollForCondition(validationFinished));
}
405
406 // Tests that getStatusForMetrics() returns the correct data.
// Tests that getStatusForMetrics() returns the correct data.
TEST_F(PrivateDnsConfigurationTest, GetStatusForMetrics) {
    // With tls2 down, DoT validation of kServer2 fails. Check the result like every other
    // start/stop in this file so a broken setup is reported right here.
    ASSERT_TRUE(tls2.stopServer());
    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));

    // Expect exactly four notifications (two per DoT server) without constraining their
    // arguments; this also keeps gMock from warning about uninteresting calls.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(4);

    // Set 1 unencrypted server and 2 encrypted servers: one passes DoT validation and the other
    // fails. Neither of them supports DoH.
    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer2}, {kServer1, kServer2}), 0);
    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // Get the metric before calling clear().
    NetworkDnsServerSupportReported event = mPdc.getStatusForMetrics(kNetId);
    NetworkDnsServerSupportReported expectedEvent;
    // It's NT_UNKNOWN because this test didn't call resolv_set_nameservers() to set
    // the network type.
    expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
    expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_OPPORTUNISTIC);
    Server* server = expectedEvent.mutable_servers()->add_server();
    server->set_protocol(PROTO_UDP);  // kServer2
    server->set_index(0);
    server->set_validated(false);
    server = expectedEvent.mutable_servers()->add_server();
    server->set_protocol(PROTO_DOT);  // kServer1
    server->set_index(0);
    server->set_validated(true);
    server = expectedEvent.mutable_servers()->add_server();
    server->set_protocol(PROTO_DOT);  // kServer2
    server->set_index(1);
    server->set_validated(false);
    server = expectedEvent.mutable_servers()->add_server();
    server->set_protocol(PROTO_DOH);  // kServer1
    server->set_index(0);
    server->set_validated(false);
    server = expectedEvent.mutable_servers()->add_server();
    server->set_protocol(PROTO_DOH);  // kServer2
    server->set_index(1);
    server->set_validated(false);
    EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));

    // Get the metric after calling clear().
    mPdc.clear(kNetId);
    event = mPdc.getStatusForMetrics(kNetId);
    expectedEvent.Clear();
    expectedEvent.set_network_type(NetworkType::NT_UNKNOWN);
    expectedEvent.set_private_dns_modes(PrivateDnsModes::PDM_UNKNOWN);
    EXPECT_THAT(event, NetworkDnsServerSupportEq(expectedEvent));

    // Restore the shared frontend for the tests that follow.
    ASSERT_TRUE(tls2.startServer());
}
459
460 // TODO: add ValidationFail_Strict test.
461
462 } // namespace android::net
463