1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "keymint_benchmark"
18
19 #include <iostream>
20
21 #include <base/command_line.h>
22 #include <benchmark/benchmark.h>
23
24 #include <aidl/Vintf.h>
25 #include <aidl/android/hardware/security/keymint/ErrorCode.h>
26 #include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
27 #include <android/binder_manager.h>
28 #include <binder/IServiceManager.h>
29
30 #include <keymint_support/authorization_set.h>
31 #include <keymint_support/openssl_utils.h>
32 #include <openssl/curve25519.h>
33 #include <openssl/x509.h>
34
35 #define SMALL_MESSAGE_SIZE 64
36 #define MEDIUM_MESSAGE_SIZE 1024
37 #define LARGE_MESSAGE_SIZE 131072
38
39 namespace aidl::android::hardware::security::keymint::test {
40
41 ::std::ostream& operator<<(::std::ostream& os, const keymint::AuthorizationSet& set);
42
43 using ::android::sp;
44 using Status = ::ndk::ScopedAStatus;
45 using ::std::optional;
46 using ::std::shared_ptr;
47 using ::std::string;
48 using ::std::vector;
49
50 class KeyMintBenchmarkTest {
51 public:
KeyMintBenchmarkTest()52 KeyMintBenchmarkTest() {
53 message_cache_.push_back(string(SMALL_MESSAGE_SIZE, 'x'));
54 message_cache_.push_back(string(MEDIUM_MESSAGE_SIZE, 'x'));
55 message_cache_.push_back(string(LARGE_MESSAGE_SIZE, 'x'));
56 }
57
newInstance(const char * instanceName)58 static KeyMintBenchmarkTest* newInstance(const char* instanceName) {
59 if (AServiceManager_isDeclared(instanceName)) {
60 ::ndk::SpAIBinder binder(AServiceManager_waitForService(instanceName));
61 KeyMintBenchmarkTest* test = new KeyMintBenchmarkTest();
62 test->InitializeKeyMint(IKeyMintDevice::fromBinder(binder));
63 return test;
64 } else {
65 return nullptr;
66 }
67 }
68
getError()69 int getError() { return static_cast<int>(error_); }
70
GenerateMessage(int size)71 const string GenerateMessage(int size) {
72 for (const string& message : message_cache_) {
73 if (message.size() == size) {
74 return message;
75 }
76 }
77 string message = string(size, 'x');
78 message_cache_.push_back(message);
79 return message;
80 }
81
getBlockMode(string transform)82 optional<BlockMode> getBlockMode(string transform) {
83 if (transform.find("/ECB") != string::npos) {
84 return BlockMode::ECB;
85 } else if (transform.find("/CBC") != string::npos) {
86 return BlockMode::CBC;
87 } else if (transform.find("/CTR") != string::npos) {
88 return BlockMode::CTR;
89 } else if (transform.find("/GCM") != string::npos) {
90 return BlockMode::GCM;
91 }
92 return {};
93 }
94
getPadding(string transform,bool sign)95 PaddingMode getPadding(string transform, bool sign) {
96 if (transform.find("/PKCS7") != string::npos) {
97 return PaddingMode::PKCS7;
98 } else if (transform.find("/PSS") != string::npos) {
99 return PaddingMode::RSA_PSS;
100 } else if (transform.find("/OAEP") != string::npos) {
101 return PaddingMode::RSA_OAEP;
102 } else if (transform.find("/PKCS1") != string::npos) {
103 return sign ? PaddingMode::RSA_PKCS1_1_5_SIGN : PaddingMode::RSA_PKCS1_1_5_ENCRYPT;
104 } else if (sign && transform.find("RSA") != string::npos) {
105 // RSA defaults to PKCS1 for sign
106 return PaddingMode::RSA_PKCS1_1_5_SIGN;
107 }
108 return PaddingMode::NONE;
109 }
110
getAlgorithm(string transform)111 optional<Algorithm> getAlgorithm(string transform) {
112 if (transform.find("AES") != string::npos) {
113 return Algorithm::AES;
114 } else if (transform.find("Hmac") != string::npos) {
115 return Algorithm::HMAC;
116 } else if (transform.find("DESede") != string::npos) {
117 return Algorithm::TRIPLE_DES;
118 } else if (transform.find("RSA") != string::npos) {
119 return Algorithm::RSA;
120 } else if (transform.find("EC") != string::npos) {
121 return Algorithm::EC;
122 }
123 std::cerr << "Can't find algorithm for " << transform << std::endl;
124 return {};
125 }
126
getAlgorithmString(string transform)127 string getAlgorithmString(string transform) {
128 if (transform.find("AES") != string::npos) {
129 return "AES";
130 } else if (transform.find("Hmac") != string::npos) {
131 return "HMAC";
132 } else if (transform.find("DESede") != string::npos) {
133 return "TRIPLE_DES";
134 } else if (transform.find("RSA") != string::npos) {
135 return "RSA";
136 } else if (transform.find("EC") != string::npos) {
137 return "EC";
138 }
139 std::cerr << "Can't find algorithm for " << transform << std::endl;
140 return "";
141 }
142
getDigest(string transform)143 Digest getDigest(string transform) {
144 if (transform.find("MD5") != string::npos) {
145 return Digest::MD5;
146 } else if (transform.find("SHA1") != string::npos ||
147 transform.find("SHA-1") != string::npos) {
148 return Digest::SHA1;
149 } else if (transform.find("SHA224") != string::npos) {
150 return Digest::SHA_2_224;
151 } else if (transform.find("SHA256") != string::npos) {
152 return Digest::SHA_2_256;
153 } else if (transform.find("SHA384") != string::npos) {
154 return Digest::SHA_2_384;
155 } else if (transform.find("SHA512") != string::npos) {
156 return Digest::SHA_2_512;
157 } else if (transform.find("RSA") != string::npos &&
158 transform.find("OAEP") != string::npos) {
159 if (securityLevel_ == SecurityLevel::STRONGBOX) {
160 return Digest::SHA_2_256;
161 } else {
162 return Digest::SHA1;
163 }
164 } else if (transform.find("Hmac") != string::npos) {
165 return Digest::SHA_2_256;
166 }
167 return Digest::NONE;
168 }
169
getDigestString(string transform)170 string getDigestString(string transform) {
171 if (transform.find("MD5") != string::npos) {
172 return "MD5";
173 } else if (transform.find("SHA1") != string::npos ||
174 transform.find("SHA-1") != string::npos) {
175 return "SHA1";
176 } else if (transform.find("SHA224") != string::npos) {
177 return "SHA_2_224";
178 } else if (transform.find("SHA256") != string::npos) {
179 return "SHA_2_256";
180 } else if (transform.find("SHA384") != string::npos) {
181 return "SHA_2_384";
182 } else if (transform.find("SHA512") != string::npos) {
183 return "SHA_2_512";
184 } else if (transform.find("RSA") != string::npos &&
185 transform.find("OAEP") != string::npos) {
186 if (securityLevel_ == SecurityLevel::STRONGBOX) {
187 return "SHA_2_256";
188 } else {
189 return "SHA1";
190 }
191 } else if (transform.find("Hmac") != string::npos) {
192 return "SHA_2_256";
193 }
194 return "";
195 }
196
getCurveFromLength(int keySize)197 optional<EcCurve> getCurveFromLength(int keySize) {
198 switch (keySize) {
199 case 224:
200 return EcCurve::P_224;
201 case 256:
202 return EcCurve::P_256;
203 case 384:
204 return EcCurve::P_384;
205 case 521:
206 return EcCurve::P_521;
207 default:
208 return std::nullopt;
209 }
210 }
211
GenerateKey(string transform,int keySize,bool sign=false)212 bool GenerateKey(string transform, int keySize, bool sign = false) {
213 if (transform == key_transform_) {
214 return true;
215 } else if (key_transform_ != "") {
216 // Deleting old key first
217 key_transform_ = "";
218 if (DeleteKey() != ErrorCode::OK) {
219 return false;
220 }
221 }
222 std::optional<Algorithm> algorithm = getAlgorithm(transform);
223 if (!algorithm) {
224 std::cerr << "Error: invalid algorithm " << transform << std::endl;
225 return false;
226 }
227 key_transform_ = transform;
228 AuthorizationSetBuilder authSet = AuthorizationSetBuilder()
229 .Authorization(TAG_NO_AUTH_REQUIRED)
230 .Authorization(TAG_PURPOSE, KeyPurpose::ENCRYPT)
231 .Authorization(TAG_PURPOSE, KeyPurpose::DECRYPT)
232 .Authorization(TAG_PURPOSE, KeyPurpose::SIGN)
233 .Authorization(TAG_PURPOSE, KeyPurpose::VERIFY)
234 .Authorization(TAG_KEY_SIZE, keySize)
235 .Authorization(TAG_ALGORITHM, algorithm.value())
236 .Digest(getDigest(transform))
237 .Padding(getPadding(transform, sign));
238 std::optional<BlockMode> blockMode = getBlockMode(transform);
239 if (blockMode) {
240 authSet.BlockMode(blockMode.value());
241 if (blockMode == BlockMode::GCM) {
242 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
243 }
244 }
245 if (algorithm == Algorithm::HMAC) {
246 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
247 }
248 if (algorithm == Algorithm::RSA) {
249 authSet.Authorization(TAG_RSA_PUBLIC_EXPONENT, 65537U);
250 authSet.SetDefaultValidity();
251 }
252 if (algorithm == Algorithm::EC) {
253 authSet.SetDefaultValidity();
254 std::optional<EcCurve> curve = getCurveFromLength(keySize);
255 if (!curve) {
256 std::cerr << "Error: invalid EC-Curve from size " << keySize << std::endl;
257 return false;
258 }
259 authSet.Authorization(TAG_EC_CURVE, curve.value());
260 }
261 error_ = GenerateKey(authSet);
262 return error_ == ErrorCode::OK;
263 }
264
getOperationParams(string transform,bool sign=false)265 AuthorizationSet getOperationParams(string transform, bool sign = false) {
266 AuthorizationSetBuilder builder = AuthorizationSetBuilder()
267 .Padding(getPadding(transform, sign))
268 .Digest(getDigest(transform));
269 std::optional<BlockMode> blockMode = getBlockMode(transform);
270 if (sign && (transform.find("Hmac") != string::npos)) {
271 builder.Authorization(TAG_MAC_LENGTH, 128);
272 }
273 if (blockMode) {
274 builder.BlockMode(*blockMode);
275 if (blockMode == BlockMode::GCM) {
276 builder.Authorization(TAG_MAC_LENGTH, 128);
277 }
278 }
279 return std::move(builder);
280 }
281
Process(const string & message,const string & signature="")282 optional<string> Process(const string& message, const string& signature = "") {
283 ErrorCode result;
284
285 string output;
286 result = Finish(message, signature, &output);
287 if (result != ErrorCode::OK) {
288 error_ = result;
289 return {};
290 }
291 return output;
292 }
293
DeleteKey()294 ErrorCode DeleteKey() {
295 Status result = keymint_->deleteKey(key_blob_);
296 key_blob_ = vector<uint8_t>();
297 key_transform_ = "";
298 return GetReturnErrorCode(result);
299 }
300
Begin(KeyPurpose purpose,const AuthorizationSet & in_params,AuthorizationSet * out_params)301 ErrorCode Begin(KeyPurpose purpose, const AuthorizationSet& in_params,
302 AuthorizationSet* out_params) {
303 Status result;
304 BeginResult out;
305 result = keymint_->begin(purpose, key_blob_, in_params.vector_data(), std::nullopt, &out);
306 if (result.isOk()) {
307 *out_params = out.params;
308 op_ = out.operation;
309 }
310 return GetReturnErrorCode(result);
311 }
312
313 /* Copied the function LocalRsaEncryptMessage from
314 * hardware/interfaces/security/keymint/aidl/vts/functional/KeyMintAidlTestBase.cpp in VTS.
315 * Replaced asserts with the condition check and return false in case of failure condition.
316 * Require return value to skip the benchmark test case from further execution in case
317 * LocalRsaEncryptMessage fails.
318 */
LocalRsaEncryptMessage(const string & message,const AuthorizationSet & params)319 optional<string> LocalRsaEncryptMessage(const string& message, const AuthorizationSet& params) {
320 // Retrieve the public key from the leaf certificate.
321 if (cert_chain_.empty()) {
322 std::cerr << "Local RSA encrypt Error: invalid cert_chain_" << std::endl;
323 return "Failure";
324 }
325 X509_Ptr key_cert(parse_cert_blob(cert_chain_[0].encodedCertificate));
326 EVP_PKEY_Ptr pub_key(X509_get_pubkey(key_cert.get()));
327 RSA_Ptr rsa(EVP_PKEY_get1_RSA(const_cast<EVP_PKEY*>(pub_key.get())));
328
329 // Retrieve relevant tags.
330 Digest digest = Digest::NONE;
331 Digest mgf_digest = Digest::SHA1;
332 PaddingMode padding = PaddingMode::NONE;
333
334 auto digest_tag = params.GetTagValue(TAG_DIGEST);
335 if (digest_tag.has_value()) digest = digest_tag.value();
336 auto pad_tag = params.GetTagValue(TAG_PADDING);
337 if (pad_tag.has_value()) padding = pad_tag.value();
338 auto mgf_tag = params.GetTagValue(TAG_RSA_OAEP_MGF_DIGEST);
339 if (mgf_tag.has_value()) mgf_digest = mgf_tag.value();
340
341 const EVP_MD* md = openssl_digest(digest);
342 const EVP_MD* mgf_md = openssl_digest(mgf_digest);
343
344 // Set up encryption context.
345 EVP_PKEY_CTX_Ptr ctx(EVP_PKEY_CTX_new(pub_key.get(), /* engine= */ nullptr));
346 if (EVP_PKEY_encrypt_init(ctx.get()) <= 0) {
347 std::cerr << "Local RSA encrypt Error: Encryption init failed" << std::endl;
348 return "Failure";
349 }
350
351 int rc = -1;
352 switch (padding) {
353 case PaddingMode::NONE:
354 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_NO_PADDING);
355 break;
356 case PaddingMode::RSA_PKCS1_1_5_ENCRYPT:
357 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING);
358 break;
359 case PaddingMode::RSA_OAEP:
360 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING);
361 break;
362 default:
363 break;
364 }
365 if (rc <= 0) {
366 std::cerr << "Local RSA encrypt Error: Set padding failed" << std::endl;
367 return "Failure";
368 }
369 if (padding == PaddingMode::RSA_OAEP) {
370 if (!EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), md)) {
371 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
372 << std::endl;
373 return "Failure";
374 }
375 if (!EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf_md)) {
376 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
377 << std::endl;
378 return "Failure";
379 }
380 }
381
382 // Determine output size.
383 size_t outlen;
384 if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen,
385 reinterpret_cast<const uint8_t*>(message.data()),
386 message.size()) <= 0) {
387 std::cerr << "Local RSA encrypt Error: Determine output size failed: "
388 << ERR_peek_last_error() << std::endl;
389 return "Failure";
390 }
391
392 // Left-zero-pad the input if necessary.
393 const uint8_t* to_encrypt = reinterpret_cast<const uint8_t*>(message.data());
394 size_t to_encrypt_len = message.size();
395
396 std::unique_ptr<string> zero_padded_message;
397 if (padding == PaddingMode::NONE && to_encrypt_len < outlen) {
398 zero_padded_message.reset(new string(outlen, '\0'));
399 memcpy(zero_padded_message->data() + (outlen - to_encrypt_len), message.data(),
400 message.size());
401 to_encrypt = reinterpret_cast<const uint8_t*>(zero_padded_message->data());
402 to_encrypt_len = outlen;
403 }
404
405 // Do the encryption.
406 string output(outlen, '\0');
407 if (EVP_PKEY_encrypt(ctx.get(), reinterpret_cast<uint8_t*>(output.data()), &outlen,
408 to_encrypt, to_encrypt_len) <= 0) {
409 std::cerr << "Local RSA encrypt Error: Encryption failed: " << ERR_peek_last_error()
410 << std::endl;
411 return "Failure";
412 }
413 return output;
414 }
415
416 SecurityLevel securityLevel_;
417 string name_;
418
419 private:
GenerateKey(const AuthorizationSet & key_desc,const optional<AttestationKey> & attest_key=std::nullopt)420 ErrorCode GenerateKey(const AuthorizationSet& key_desc,
421 const optional<AttestationKey>& attest_key = std::nullopt) {
422 key_blob_.clear();
423 cert_chain_.clear();
424 KeyCreationResult creationResult;
425 Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
426 if (result.isOk()) {
427 key_blob_ = std::move(creationResult.keyBlob);
428 cert_chain_ = std::move(creationResult.certificateChain);
429 creationResult.keyCharacteristics.clear();
430 }
431 return GetReturnErrorCode(result);
432 }
433
InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint)434 void InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint) {
435 if (!keyMint) {
436 std::cerr << "Trying initialize nullptr in InitializeKeyMint" << std::endl;
437 return;
438 }
439 keymint_ = std::move(keyMint);
440 KeyMintHardwareInfo info;
441 Status result = keymint_->getHardwareInfo(&info);
442 if (!result.isOk()) {
443 std::cerr << "InitializeKeyMint: getHardwareInfo failed with "
444 << result.getServiceSpecificError() << std::endl;
445 }
446 securityLevel_ = info.securityLevel;
447 name_.assign(info.keyMintName.begin(), info.keyMintName.end());
448 }
449
Finish(const string & input,const string & signature,string * output)450 ErrorCode Finish(const string& input, const string& signature, string* output) {
451 if (!op_) {
452 std::cerr << "Finish: Operation is nullptr" << std::endl;
453 return ErrorCode::UNEXPECTED_NULL_POINTER;
454 }
455
456 vector<uint8_t> oPut;
457 Status result =
458 op_->finish(vector<uint8_t>(input.begin(), input.end()),
459 vector<uint8_t>(signature.begin(), signature.end()), {} /* authToken */,
460 {} /* timestampToken */, {} /* confirmationToken */, &oPut);
461
462 if (result.isOk()) output->append(oPut.begin(), oPut.end());
463
464 op_.reset();
465 return GetReturnErrorCode(result);
466 }
467
Update(const string & input,string * output)468 ErrorCode Update(const string& input, string* output) {
469 Status result;
470 if (!op_) {
471 std::cerr << "Update: Operation is nullptr" << std::endl;
472 return ErrorCode::UNEXPECTED_NULL_POINTER;
473 }
474
475 std::vector<uint8_t> o_put;
476 result = op_->update(vector<uint8_t>(input.begin(), input.end()), {} /* authToken */,
477 {} /* timestampToken */, &o_put);
478
479 if (result.isOk() && output) *output = {o_put.begin(), o_put.end()};
480 return GetReturnErrorCode(result);
481 }
482
GetReturnErrorCode(const Status & result)483 ErrorCode GetReturnErrorCode(const Status& result) {
484 error_ = static_cast<ErrorCode>(result.getServiceSpecificError());
485 if (result.isOk()) return ErrorCode::OK;
486
487 if (result.getExceptionCode() == EX_SERVICE_SPECIFIC) {
488 return static_cast<ErrorCode>(result.getServiceSpecificError());
489 }
490
491 return ErrorCode::UNKNOWN_ERROR;
492 }
493
parse_cert_blob(const vector<uint8_t> & blob)494 X509_Ptr parse_cert_blob(const vector<uint8_t>& blob) {
495 const uint8_t* p = blob.data();
496 return X509_Ptr(d2i_X509(nullptr /* allocate new */, &p, blob.size()));
497 }
498
499 std::shared_ptr<IKeyMintOperation> op_;
500 vector<Certificate> cert_chain_;
501 vector<uint8_t> key_blob_;
502 vector<KeyCharacteristics> key_characteristics_;
503 std::shared_ptr<IKeyMintDevice> keymint_;
504 std::vector<string> message_cache_;
505 std::string key_transform_;
506 ErrorCode error_;
507 };
508
509 KeyMintBenchmarkTest* keymintTest;
510
settings(benchmark::internal::Benchmark * benchmark)511 static void settings(benchmark::internal::Benchmark* benchmark) {
512 benchmark->Unit(benchmark::kMillisecond);
513 }
514
addDefaultLabel(benchmark::State & state)515 static void addDefaultLabel(benchmark::State& state) {
516 std::string secLevel;
517 switch (keymintTest->securityLevel_) {
518 case SecurityLevel::STRONGBOX:
519 secLevel = "STRONGBOX";
520 break;
521 case SecurityLevel::SOFTWARE:
522 secLevel = "SOFTWARE";
523 break;
524 case SecurityLevel::TRUSTED_ENVIRONMENT:
525 secLevel = "TEE";
526 break;
527 case SecurityLevel::KEYSTORE:
528 secLevel = "KEYSTORE";
529 break;
530 }
531 state.SetLabel("hardware_name:" + keymintTest->name_ + " sec_level:" + secLevel);
532 }
533
// clang-format off
// Registers `func` through BENCHMARK_CAPTURE, passing the stringified "<transform>/<keySize>"
// as the transform argument followed by keySize.
#define BENCHMARK_KM(func, transform, keySize) \
    BENCHMARK_CAPTURE(func, transform/keySize, #transform "/" #keySize, keySize)->Apply(settings);
// Same as BENCHMARK_KM, additionally passing the message size.
#define BENCHMARK_KM_MSG(func, transform, keySize, msgSize) \
    BENCHMARK_CAPTURE(func, transform/keySize/msgSize, #transform "/" #keySize "/" #msgSize, \
                      keySize, msgSize) \
    ->Apply(settings);

// Registers `func` for the small, medium and large message sizes.
#define BENCHMARK_KM_ALL_MSGS(func, transform, keySize) \
    BENCHMARK_KM_MSG(func, transform, keySize, SMALL_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, MEDIUM_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, LARGE_MESSAGE_SIZE)

// Registers both encrypt and decrypt for one message size.
#define BENCHMARK_KM_CIPHER(transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Skip public key operations as they are not supported in KeyMint.
#define BENCHMARK_KM_ASYM_CIPHER(transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Registers encrypt and decrypt for every message size.
#define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)

// Registers sign and verify for every message size.
#define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)

// Skip public key operations as they are not supported in KeyMint.
// (No trailing backslash: it used to splice the following clang-format directive into the macro.)
#define BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)
// clang-format on
567
568 /*
569 * ============= KeyGen TESTS ==================
570 */
571
isValidSBKeySize(string transform,int keySize)572 static bool isValidSBKeySize(string transform, int keySize) {
573 std::optional<Algorithm> algorithm = keymintTest->getAlgorithm(transform);
574 switch (algorithm.value()) {
575 case Algorithm::AES:
576 return (keySize == 128 || keySize == 256);
577 case Algorithm::HMAC:
578 return (keySize % 8 == 0 && keySize >= 64 && keySize <= 512);
579 case Algorithm::TRIPLE_DES:
580 return (keySize == 168);
581 case Algorithm::RSA:
582 return (keySize == 2048);
583 case Algorithm::EC:
584 return (keySize == 256);
585 }
586 return false;
587 }
588
keygen(benchmark::State & state,string transform,int keySize)589 static void keygen(benchmark::State& state, string transform, int keySize) {
590 // Skip the test for unsupported key size in StrongBox
591 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
592 !isValidSBKeySize(transform, keySize)) {
593 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
594 " is not supported in StrongBox for algorithm: " +
595 keymintTest->getAlgorithmString(transform))
596 .c_str());
597 return;
598 }
599 addDefaultLabel(state);
600 for (auto _ : state) {
601 if (!keymintTest->GenerateKey(transform, keySize)) {
602 state.SkipWithError(
603 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
604 }
605 state.PauseTiming();
606
607 keymintTest->DeleteKey();
608 state.ResumeTiming();
609 }
610 }
611
612 BENCHMARK_KM(keygen, AES, 128);
613 BENCHMARK_KM(keygen, AES, 256);
614
615 BENCHMARK_KM(keygen, RSA, 2048);
616 BENCHMARK_KM(keygen, RSA, 3072);
617 BENCHMARK_KM(keygen, RSA, 4096);
618
619 BENCHMARK_KM(keygen, EC, 224);
620 BENCHMARK_KM(keygen, EC, 256);
621 BENCHMARK_KM(keygen, EC, 384);
622 BENCHMARK_KM(keygen, EC, 521);
623
624 BENCHMARK_KM(keygen, DESede, 168);
625
626 BENCHMARK_KM(keygen, Hmac, 64);
627 BENCHMARK_KM(keygen, Hmac, 128);
628 BENCHMARK_KM(keygen, Hmac, 256);
629 BENCHMARK_KM(keygen, Hmac, 512);
630
631 /*
632 * ============= SIGNATURE TESTS ==================
633 */
sign(benchmark::State & state,string transform,int keySize,int msgSize)634 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
635 // Skip the test for unsupported key size or unsupported digest in StrongBox
636 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
637 if (!isValidSBKeySize(transform, keySize)) {
638 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
639 " is not supported in StrongBox for algorithm: " +
640 keymintTest->getAlgorithmString(transform))
641 .c_str());
642 return;
643 }
644 if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
645 state.SkipWithError(
646 ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
647 " is not supported in StrongBox")
648 .c_str());
649 return;
650 }
651 }
652 addDefaultLabel(state);
653 if (!keymintTest->GenerateKey(transform, keySize, true)) {
654 state.SkipWithError(
655 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
656 return;
657 }
658
659 auto in_params = keymintTest->getOperationParams(transform, true);
660 AuthorizationSet out_params;
661 string message = keymintTest->GenerateMessage(msgSize);
662
663 for (auto _ : state) {
664 state.PauseTiming();
665 ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
666 if (error != ErrorCode::OK) {
667 state.SkipWithError(
668 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
669 return;
670 }
671 state.ResumeTiming();
672 out_params.Clear();
673 if (!keymintTest->Process(message)) {
674 state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
675 break;
676 }
677 }
678 }
679
verify(benchmark::State & state,string transform,int keySize,int msgSize)680 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
681 // Skip the test for unsupported key size or unsupported digest in StrongBox
682 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
683 if (!isValidSBKeySize(transform, keySize)) {
684 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
685 " is not supported in StrongBox for algorithm: " +
686 keymintTest->getAlgorithmString(transform))
687 .c_str());
688 return;
689 }
690 if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
691 state.SkipWithError(
692 ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
693 " is not supported in StrongBox")
694 .c_str());
695 return;
696 }
697 }
698 addDefaultLabel(state);
699 if (!keymintTest->GenerateKey(transform, keySize, true)) {
700 state.SkipWithError(
701 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
702 return;
703 }
704 AuthorizationSet out_params;
705 auto in_params = keymintTest->getOperationParams(transform, true);
706 string message = keymintTest->GenerateMessage(msgSize);
707 ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
708 if (error != ErrorCode::OK) {
709 state.SkipWithError(
710 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
711 return;
712 }
713 std::optional<string> signature = keymintTest->Process(message);
714 if (!signature) {
715 state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
716 return;
717 }
718 out_params.Clear();
719 if (transform.find("Hmac") != string::npos) {
720 in_params = keymintTest->getOperationParams(transform, false);
721 }
722 for (auto _ : state) {
723 state.PauseTiming();
724 error = keymintTest->Begin(KeyPurpose::VERIFY, in_params, &out_params);
725 if (error != ErrorCode::OK) {
726 state.SkipWithError(
727 ("Verify begin error, " + std::to_string(keymintTest->getError())).c_str());
728 return;
729 }
730 state.ResumeTiming();
731 if (!keymintTest->Process(message, *signature)) {
732 state.SkipWithError(
733 ("Verify error, " + std::to_string(keymintTest->getError())).c_str());
734 break;
735 }
736 }
737 }
738
739 // clang-format off
740 #define BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(transform) \
741 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 64) \
742 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 128) \
743 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256) \
744 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 512)
745
746 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA1)
747 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
748 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA224)
749 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
750 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA384)
751 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
752
753 #define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
754 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 224) \
755 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 256) \
756 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 384) \
757 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 521)
758
759 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
760 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
761 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA224withECDSA);
762 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA256withECDSA);
763 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA384withECDSA);
764 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
765
766 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
767 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 2048) \
768 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 3072) \
769 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 4096)
770
771 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
772 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
773 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
774 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA256withRSA);
775 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
776 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
777
778 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA/PSS);
779 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA/PSS);
780 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
781 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
782 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
783
784 // clang-format on
785
786 /*
787 * ============= CIPHER TESTS ==================
788 */
789
encrypt(benchmark::State & state,string transform,int keySize,int msgSize)790 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
791 // Skip the test for unsupported key size in StrongBox
792 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
793 (!isValidSBKeySize(transform, keySize))) {
794 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
795 " is not supported in StrongBox for algorithm: " +
796 keymintTest->getAlgorithmString(transform))
797 .c_str());
798 return;
799 }
800 addDefaultLabel(state);
801 if (!keymintTest->GenerateKey(transform, keySize)) {
802 state.SkipWithError(
803 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
804 return;
805 }
806 auto in_params = keymintTest->getOperationParams(transform);
807 AuthorizationSet out_params;
808 string message = keymintTest->GenerateMessage(msgSize);
809
810 for (auto _ : state) {
811 state.PauseTiming();
812 auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
813 if (error != ErrorCode::OK) {
814 state.SkipWithError(
815 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
816 return;
817 }
818 out_params.Clear();
819 state.ResumeTiming();
820 if (!keymintTest->Process(message)) {
821 state.SkipWithError(
822 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
823 break;
824 }
825 }
826 }
827
decrypt(benchmark::State & state,string transform,int keySize,int msgSize)828 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
829 // Skip the test for unsupported key size in StrongBox
830 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
831 (!isValidSBKeySize(transform, keySize))) {
832 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
833 " is not supported in StrongBox for algorithm: " +
834 keymintTest->getAlgorithmString(transform))
835 .c_str());
836 return;
837 }
838 addDefaultLabel(state);
839 if (!keymintTest->GenerateKey(transform, keySize)) {
840 state.SkipWithError(
841 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
842 return;
843 }
844 AuthorizationSet out_params;
845 AuthorizationSet in_params = keymintTest->getOperationParams(transform);
846 string message = keymintTest->GenerateMessage(msgSize);
847 optional<string> encryptedMessage;
848
849 if (keymintTest->getAlgorithm(transform).value() == Algorithm::RSA) {
850 // Public key operation not supported, doing local Encryption
851 encryptedMessage = keymintTest->LocalRsaEncryptMessage(message, in_params);
852 if ((keySize / 8) != (*encryptedMessage).size()) {
853 state.SkipWithError("Local Encryption falied");
854 return;
855 }
856 } else {
857 auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
858 if (error != ErrorCode::OK) {
859 state.SkipWithError(
860 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
861 return;
862 }
863 encryptedMessage = keymintTest->Process(message);
864 if (!encryptedMessage) {
865 state.SkipWithError(
866 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
867 return;
868 }
869 in_params.push_back(out_params);
870 out_params.Clear();
871 }
872 for (auto _ : state) {
873 state.PauseTiming();
874 auto error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
875 if (error != ErrorCode::OK) {
876 state.SkipWithError(
877 ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
878 return;
879 }
880 state.ResumeTiming();
881 if (!keymintTest->Process(*encryptedMessage)) {
882 state.SkipWithError(
883 ("Decryption error, " + std::to_string(keymintTest->getError())).c_str());
884 break;
885 }
886 }
887 }
888
889 // clang-format off
890 // AES
891 #define BENCHMARK_KM_CIPHER_ALL_AES_KEYS(transform) \
892 BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 128) \
893 BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 256)
894
895 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/NoPadding);
896 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/PKCS7Padding);
897 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CTR/NoPadding);
898 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/NoPadding);
899 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/PKCS7Padding);
900 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/GCM/NoPadding);
901
902 // Triple DES
903 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/NoPadding, 168);
904 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/PKCS7Padding, 168);
905 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/NoPadding, 168);
906 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);
907
908 #define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
909 BENCHMARK_KM_ASYM_CIPHER(transform, 2048, msgSize) \
910 BENCHMARK_KM_ASYM_CIPHER(transform, 3072, msgSize) \
911 BENCHMARK_KM_ASYM_CIPHER(transform, 4096, msgSize)
912
913 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
914 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);
915 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/OAEPPadding, SMALL_MESSAGE_SIZE);
916
917 // clang-format on
918 } // namespace aidl::android::hardware::security::keymint::test
919
main(int argc,char ** argv)920 int main(int argc, char** argv) {
921 ::benchmark::Initialize(&argc, argv);
922 base::CommandLine::Init(argc, argv);
923 base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
924 auto service_name = command_line->GetSwitchValueASCII("service_name");
925 if (service_name.empty()) {
926 service_name =
927 std::string(
928 aidl::android::hardware::security::keymint::IKeyMintDevice::descriptor) +
929 "/default";
930 }
931 std::cerr << service_name << std::endl;
932 aidl::android::hardware::security::keymint::test::keymintTest =
933 aidl::android::hardware::security::keymint::test::KeyMintBenchmarkTest::newInstance(
934 service_name.c_str());
935 if (!aidl::android::hardware::security::keymint::test::keymintTest) {
936 return 1;
937 }
938 ::benchmark::RunSpecifiedBenchmarks();
939 }
940