xref: /aosp_15_r20/hardware/interfaces/security/keymint/aidl/vts/performance/KeyMintBenchmark.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "keymint_benchmark"
18 
19 #include <iostream>
20 
21 #include <base/command_line.h>
22 #include <benchmark/benchmark.h>
23 
24 #include <aidl/Vintf.h>
25 #include <aidl/android/hardware/security/keymint/ErrorCode.h>
26 #include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
27 #include <android/binder_manager.h>
28 #include <binder/IServiceManager.h>
29 
30 #include <keymint_support/authorization_set.h>
31 #include <keymint_support/openssl_utils.h>
32 #include <openssl/curve25519.h>
33 #include <openssl/x509.h>
34 
35 #define SMALL_MESSAGE_SIZE 64
36 #define MEDIUM_MESSAGE_SIZE 1024
37 #define LARGE_MESSAGE_SIZE 131072
38 
39 namespace aidl::android::hardware::security::keymint::test {
40 
41 ::std::ostream& operator<<(::std::ostream& os, const keymint::AuthorizationSet& set);
42 
43 using ::android::sp;
44 using Status = ::ndk::ScopedAStatus;
45 using ::std::optional;
46 using ::std::shared_ptr;
47 using ::std::string;
48 using ::std::vector;
49 
50 class KeyMintBenchmarkTest {
51   public:
KeyMintBenchmarkTest()52     KeyMintBenchmarkTest() {
53         message_cache_.push_back(string(SMALL_MESSAGE_SIZE, 'x'));
54         message_cache_.push_back(string(MEDIUM_MESSAGE_SIZE, 'x'));
55         message_cache_.push_back(string(LARGE_MESSAGE_SIZE, 'x'));
56     }
57 
newInstance(const char * instanceName)58     static KeyMintBenchmarkTest* newInstance(const char* instanceName) {
59         if (AServiceManager_isDeclared(instanceName)) {
60             ::ndk::SpAIBinder binder(AServiceManager_waitForService(instanceName));
61             KeyMintBenchmarkTest* test = new KeyMintBenchmarkTest();
62             test->InitializeKeyMint(IKeyMintDevice::fromBinder(binder));
63             return test;
64         } else {
65             return nullptr;
66         }
67     }
68 
getError()69     int getError() { return static_cast<int>(error_); }
70 
GenerateMessage(int size)71     const string GenerateMessage(int size) {
72         for (const string& message : message_cache_) {
73             if (message.size() == size) {
74                 return message;
75             }
76         }
77         string message = string(size, 'x');
78         message_cache_.push_back(message);
79         return message;
80     }
81 
getBlockMode(string transform)82     optional<BlockMode> getBlockMode(string transform) {
83         if (transform.find("/ECB") != string::npos) {
84             return BlockMode::ECB;
85         } else if (transform.find("/CBC") != string::npos) {
86             return BlockMode::CBC;
87         } else if (transform.find("/CTR") != string::npos) {
88             return BlockMode::CTR;
89         } else if (transform.find("/GCM") != string::npos) {
90             return BlockMode::GCM;
91         }
92         return {};
93     }
94 
getPadding(string transform,bool sign)95     PaddingMode getPadding(string transform, bool sign) {
96         if (transform.find("/PKCS7") != string::npos) {
97             return PaddingMode::PKCS7;
98         } else if (transform.find("/PSS") != string::npos) {
99             return PaddingMode::RSA_PSS;
100         } else if (transform.find("/OAEP") != string::npos) {
101             return PaddingMode::RSA_OAEP;
102         } else if (transform.find("/PKCS1") != string::npos) {
103             return sign ? PaddingMode::RSA_PKCS1_1_5_SIGN : PaddingMode::RSA_PKCS1_1_5_ENCRYPT;
104         } else if (sign && transform.find("RSA") != string::npos) {
105             // RSA defaults to PKCS1 for sign
106             return PaddingMode::RSA_PKCS1_1_5_SIGN;
107         }
108         return PaddingMode::NONE;
109     }
110 
getAlgorithm(string transform)111     optional<Algorithm> getAlgorithm(string transform) {
112         if (transform.find("AES") != string::npos) {
113             return Algorithm::AES;
114         } else if (transform.find("Hmac") != string::npos) {
115             return Algorithm::HMAC;
116         } else if (transform.find("DESede") != string::npos) {
117             return Algorithm::TRIPLE_DES;
118         } else if (transform.find("RSA") != string::npos) {
119             return Algorithm::RSA;
120         } else if (transform.find("EC") != string::npos) {
121             return Algorithm::EC;
122         }
123         std::cerr << "Can't find algorithm for " << transform << std::endl;
124         return {};
125     }
126 
getAlgorithmString(string transform)127     string getAlgorithmString(string transform) {
128         if (transform.find("AES") != string::npos) {
129             return "AES";
130         } else if (transform.find("Hmac") != string::npos) {
131             return "HMAC";
132         } else if (transform.find("DESede") != string::npos) {
133             return "TRIPLE_DES";
134         } else if (transform.find("RSA") != string::npos) {
135             return "RSA";
136         } else if (transform.find("EC") != string::npos) {
137             return "EC";
138         }
139         std::cerr << "Can't find algorithm for " << transform << std::endl;
140         return "";
141     }
142 
getDigest(string transform)143     Digest getDigest(string transform) {
144         if (transform.find("MD5") != string::npos) {
145             return Digest::MD5;
146         } else if (transform.find("SHA1") != string::npos ||
147                    transform.find("SHA-1") != string::npos) {
148             return Digest::SHA1;
149         } else if (transform.find("SHA224") != string::npos) {
150             return Digest::SHA_2_224;
151         } else if (transform.find("SHA256") != string::npos) {
152             return Digest::SHA_2_256;
153         } else if (transform.find("SHA384") != string::npos) {
154             return Digest::SHA_2_384;
155         } else if (transform.find("SHA512") != string::npos) {
156             return Digest::SHA_2_512;
157         } else if (transform.find("RSA") != string::npos &&
158                    transform.find("OAEP") != string::npos) {
159             if (securityLevel_ == SecurityLevel::STRONGBOX) {
160                 return Digest::SHA_2_256;
161             } else {
162                 return Digest::SHA1;
163             }
164         } else if (transform.find("Hmac") != string::npos) {
165             return Digest::SHA_2_256;
166         }
167         return Digest::NONE;
168     }
169 
getDigestString(string transform)170     string getDigestString(string transform) {
171         if (transform.find("MD5") != string::npos) {
172             return "MD5";
173         } else if (transform.find("SHA1") != string::npos ||
174                    transform.find("SHA-1") != string::npos) {
175             return "SHA1";
176         } else if (transform.find("SHA224") != string::npos) {
177             return "SHA_2_224";
178         } else if (transform.find("SHA256") != string::npos) {
179             return "SHA_2_256";
180         } else if (transform.find("SHA384") != string::npos) {
181             return "SHA_2_384";
182         } else if (transform.find("SHA512") != string::npos) {
183             return "SHA_2_512";
184         } else if (transform.find("RSA") != string::npos &&
185                    transform.find("OAEP") != string::npos) {
186             if (securityLevel_ == SecurityLevel::STRONGBOX) {
187                 return "SHA_2_256";
188             } else {
189                 return "SHA1";
190             }
191         } else if (transform.find("Hmac") != string::npos) {
192             return "SHA_2_256";
193         }
194         return "";
195     }
196 
getCurveFromLength(int keySize)197     optional<EcCurve> getCurveFromLength(int keySize) {
198         switch (keySize) {
199             case 224:
200                 return EcCurve::P_224;
201             case 256:
202                 return EcCurve::P_256;
203             case 384:
204                 return EcCurve::P_384;
205             case 521:
206                 return EcCurve::P_521;
207             default:
208                 return std::nullopt;
209         }
210     }
211 
GenerateKey(string transform,int keySize,bool sign=false)212     bool GenerateKey(string transform, int keySize, bool sign = false) {
213         if (transform == key_transform_) {
214             return true;
215         } else if (key_transform_ != "") {
216             // Deleting old key first
217             key_transform_ = "";
218             if (DeleteKey() != ErrorCode::OK) {
219                 return false;
220             }
221         }
222         std::optional<Algorithm> algorithm = getAlgorithm(transform);
223         if (!algorithm) {
224             std::cerr << "Error: invalid algorithm " << transform << std::endl;
225             return false;
226         }
227         key_transform_ = transform;
228         AuthorizationSetBuilder authSet = AuthorizationSetBuilder()
229                                                   .Authorization(TAG_NO_AUTH_REQUIRED)
230                                                   .Authorization(TAG_PURPOSE, KeyPurpose::ENCRYPT)
231                                                   .Authorization(TAG_PURPOSE, KeyPurpose::DECRYPT)
232                                                   .Authorization(TAG_PURPOSE, KeyPurpose::SIGN)
233                                                   .Authorization(TAG_PURPOSE, KeyPurpose::VERIFY)
234                                                   .Authorization(TAG_KEY_SIZE, keySize)
235                                                   .Authorization(TAG_ALGORITHM, algorithm.value())
236                                                   .Digest(getDigest(transform))
237                                                   .Padding(getPadding(transform, sign));
238         std::optional<BlockMode> blockMode = getBlockMode(transform);
239         if (blockMode) {
240             authSet.BlockMode(blockMode.value());
241             if (blockMode == BlockMode::GCM) {
242                 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
243             }
244         }
245         if (algorithm == Algorithm::HMAC) {
246             authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
247         }
248         if (algorithm == Algorithm::RSA) {
249             authSet.Authorization(TAG_RSA_PUBLIC_EXPONENT, 65537U);
250             authSet.SetDefaultValidity();
251         }
252         if (algorithm == Algorithm::EC) {
253             authSet.SetDefaultValidity();
254             std::optional<EcCurve> curve = getCurveFromLength(keySize);
255             if (!curve) {
256                 std::cerr << "Error: invalid EC-Curve from size " << keySize << std::endl;
257                 return false;
258             }
259             authSet.Authorization(TAG_EC_CURVE, curve.value());
260         }
261         error_ = GenerateKey(authSet);
262         return error_ == ErrorCode::OK;
263     }
264 
getOperationParams(string transform,bool sign=false)265     AuthorizationSet getOperationParams(string transform, bool sign = false) {
266         AuthorizationSetBuilder builder = AuthorizationSetBuilder()
267                                                   .Padding(getPadding(transform, sign))
268                                                   .Digest(getDigest(transform));
269         std::optional<BlockMode> blockMode = getBlockMode(transform);
270         if (sign && (transform.find("Hmac") != string::npos)) {
271             builder.Authorization(TAG_MAC_LENGTH, 128);
272         }
273         if (blockMode) {
274             builder.BlockMode(*blockMode);
275             if (blockMode == BlockMode::GCM) {
276                 builder.Authorization(TAG_MAC_LENGTH, 128);
277             }
278         }
279         return std::move(builder);
280     }
281 
Process(const string & message,const string & signature="")282     optional<string> Process(const string& message, const string& signature = "") {
283         ErrorCode result;
284 
285         string output;
286         result = Finish(message, signature, &output);
287         if (result != ErrorCode::OK) {
288             error_ = result;
289             return {};
290         }
291         return output;
292     }
293 
DeleteKey()294     ErrorCode DeleteKey() {
295         Status result = keymint_->deleteKey(key_blob_);
296         key_blob_ = vector<uint8_t>();
297         key_transform_ = "";
298         return GetReturnErrorCode(result);
299     }
300 
Begin(KeyPurpose purpose,const AuthorizationSet & in_params,AuthorizationSet * out_params)301     ErrorCode Begin(KeyPurpose purpose, const AuthorizationSet& in_params,
302                     AuthorizationSet* out_params) {
303         Status result;
304         BeginResult out;
305         result = keymint_->begin(purpose, key_blob_, in_params.vector_data(), std::nullopt, &out);
306         if (result.isOk()) {
307             *out_params = out.params;
308             op_ = out.operation;
309         }
310         return GetReturnErrorCode(result);
311     }
312 
313     /* Copied the function LocalRsaEncryptMessage from
314      * hardware/interfaces/security/keymint/aidl/vts/functional/KeyMintAidlTestBase.cpp in VTS.
315      * Replaced asserts with the condition check and return false in case of failure condition.
316      * Require return value to skip the benchmark test case from further execution in case
317      * LocalRsaEncryptMessage fails.
318      */
LocalRsaEncryptMessage(const string & message,const AuthorizationSet & params)319     optional<string> LocalRsaEncryptMessage(const string& message, const AuthorizationSet& params) {
320         // Retrieve the public key from the leaf certificate.
321         if (cert_chain_.empty()) {
322             std::cerr << "Local RSA encrypt Error: invalid cert_chain_" << std::endl;
323             return "Failure";
324         }
325         X509_Ptr key_cert(parse_cert_blob(cert_chain_[0].encodedCertificate));
326         EVP_PKEY_Ptr pub_key(X509_get_pubkey(key_cert.get()));
327         RSA_Ptr rsa(EVP_PKEY_get1_RSA(const_cast<EVP_PKEY*>(pub_key.get())));
328 
329         // Retrieve relevant tags.
330         Digest digest = Digest::NONE;
331         Digest mgf_digest = Digest::SHA1;
332         PaddingMode padding = PaddingMode::NONE;
333 
334         auto digest_tag = params.GetTagValue(TAG_DIGEST);
335         if (digest_tag.has_value()) digest = digest_tag.value();
336         auto pad_tag = params.GetTagValue(TAG_PADDING);
337         if (pad_tag.has_value()) padding = pad_tag.value();
338         auto mgf_tag = params.GetTagValue(TAG_RSA_OAEP_MGF_DIGEST);
339         if (mgf_tag.has_value()) mgf_digest = mgf_tag.value();
340 
341         const EVP_MD* md = openssl_digest(digest);
342         const EVP_MD* mgf_md = openssl_digest(mgf_digest);
343 
344         // Set up encryption context.
345         EVP_PKEY_CTX_Ptr ctx(EVP_PKEY_CTX_new(pub_key.get(), /* engine= */ nullptr));
346         if (EVP_PKEY_encrypt_init(ctx.get()) <= 0) {
347             std::cerr << "Local RSA encrypt Error: Encryption init failed" << std::endl;
348             return "Failure";
349         }
350 
351         int rc = -1;
352         switch (padding) {
353             case PaddingMode::NONE:
354                 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_NO_PADDING);
355                 break;
356             case PaddingMode::RSA_PKCS1_1_5_ENCRYPT:
357                 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING);
358                 break;
359             case PaddingMode::RSA_OAEP:
360                 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING);
361                 break;
362             default:
363                 break;
364         }
365         if (rc <= 0) {
366             std::cerr << "Local RSA encrypt Error: Set padding failed" << std::endl;
367             return "Failure";
368         }
369         if (padding == PaddingMode::RSA_OAEP) {
370             if (!EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), md)) {
371                 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
372                           << std::endl;
373                 return "Failure";
374             }
375             if (!EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf_md)) {
376                 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
377                           << std::endl;
378                 return "Failure";
379             }
380         }
381 
382         // Determine output size.
383         size_t outlen;
384         if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen,
385                              reinterpret_cast<const uint8_t*>(message.data()),
386                              message.size()) <= 0) {
387             std::cerr << "Local RSA encrypt Error: Determine output size failed: "
388                       << ERR_peek_last_error() << std::endl;
389             return "Failure";
390         }
391 
392         // Left-zero-pad the input if necessary.
393         const uint8_t* to_encrypt = reinterpret_cast<const uint8_t*>(message.data());
394         size_t to_encrypt_len = message.size();
395 
396         std::unique_ptr<string> zero_padded_message;
397         if (padding == PaddingMode::NONE && to_encrypt_len < outlen) {
398             zero_padded_message.reset(new string(outlen, '\0'));
399             memcpy(zero_padded_message->data() + (outlen - to_encrypt_len), message.data(),
400                    message.size());
401             to_encrypt = reinterpret_cast<const uint8_t*>(zero_padded_message->data());
402             to_encrypt_len = outlen;
403         }
404 
405         // Do the encryption.
406         string output(outlen, '\0');
407         if (EVP_PKEY_encrypt(ctx.get(), reinterpret_cast<uint8_t*>(output.data()), &outlen,
408                              to_encrypt, to_encrypt_len) <= 0) {
409             std::cerr << "Local RSA encrypt Error: Encryption failed: " << ERR_peek_last_error()
410                       << std::endl;
411             return "Failure";
412         }
413         return output;
414     }
415 
416     SecurityLevel securityLevel_;
417     string name_;
418 
419   private:
GenerateKey(const AuthorizationSet & key_desc,const optional<AttestationKey> & attest_key=std::nullopt)420     ErrorCode GenerateKey(const AuthorizationSet& key_desc,
421                           const optional<AttestationKey>& attest_key = std::nullopt) {
422         key_blob_.clear();
423         cert_chain_.clear();
424         KeyCreationResult creationResult;
425         Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
426         if (result.isOk()) {
427             key_blob_ = std::move(creationResult.keyBlob);
428             cert_chain_ = std::move(creationResult.certificateChain);
429             creationResult.keyCharacteristics.clear();
430         }
431         return GetReturnErrorCode(result);
432     }
433 
InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint)434     void InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint) {
435         if (!keyMint) {
436             std::cerr << "Trying initialize nullptr in InitializeKeyMint" << std::endl;
437             return;
438         }
439         keymint_ = std::move(keyMint);
440         KeyMintHardwareInfo info;
441         Status result = keymint_->getHardwareInfo(&info);
442         if (!result.isOk()) {
443             std::cerr << "InitializeKeyMint: getHardwareInfo failed with "
444                       << result.getServiceSpecificError() << std::endl;
445         }
446         securityLevel_ = info.securityLevel;
447         name_.assign(info.keyMintName.begin(), info.keyMintName.end());
448     }
449 
Finish(const string & input,const string & signature,string * output)450     ErrorCode Finish(const string& input, const string& signature, string* output) {
451         if (!op_) {
452             std::cerr << "Finish: Operation is nullptr" << std::endl;
453             return ErrorCode::UNEXPECTED_NULL_POINTER;
454         }
455 
456         vector<uint8_t> oPut;
457         Status result =
458                 op_->finish(vector<uint8_t>(input.begin(), input.end()),
459                             vector<uint8_t>(signature.begin(), signature.end()), {} /* authToken */,
460                             {} /* timestampToken */, {} /* confirmationToken */, &oPut);
461 
462         if (result.isOk()) output->append(oPut.begin(), oPut.end());
463 
464         op_.reset();
465         return GetReturnErrorCode(result);
466     }
467 
Update(const string & input,string * output)468     ErrorCode Update(const string& input, string* output) {
469         Status result;
470         if (!op_) {
471             std::cerr << "Update: Operation is nullptr" << std::endl;
472             return ErrorCode::UNEXPECTED_NULL_POINTER;
473         }
474 
475         std::vector<uint8_t> o_put;
476         result = op_->update(vector<uint8_t>(input.begin(), input.end()), {} /* authToken */,
477                              {} /* timestampToken */, &o_put);
478 
479         if (result.isOk() && output) *output = {o_put.begin(), o_put.end()};
480         return GetReturnErrorCode(result);
481     }
482 
GetReturnErrorCode(const Status & result)483     ErrorCode GetReturnErrorCode(const Status& result) {
484         error_ = static_cast<ErrorCode>(result.getServiceSpecificError());
485         if (result.isOk()) return ErrorCode::OK;
486 
487         if (result.getExceptionCode() == EX_SERVICE_SPECIFIC) {
488             return static_cast<ErrorCode>(result.getServiceSpecificError());
489         }
490 
491         return ErrorCode::UNKNOWN_ERROR;
492     }
493 
parse_cert_blob(const vector<uint8_t> & blob)494     X509_Ptr parse_cert_blob(const vector<uint8_t>& blob) {
495         const uint8_t* p = blob.data();
496         return X509_Ptr(d2i_X509(nullptr /* allocate new */, &p, blob.size()));
497     }
498 
499     std::shared_ptr<IKeyMintOperation> op_;
500     vector<Certificate> cert_chain_;
501     vector<uint8_t> key_blob_;
502     vector<KeyCharacteristics> key_characteristics_;
503     std::shared_ptr<IKeyMintDevice> keymint_;
504     std::vector<string> message_cache_;
505     std::string key_transform_;
506     ErrorCode error_;
507 };
508 
509 KeyMintBenchmarkTest* keymintTest;
510 
settings(benchmark::internal::Benchmark * benchmark)511 static void settings(benchmark::internal::Benchmark* benchmark) {
512     benchmark->Unit(benchmark::kMillisecond);
513 }
514 
addDefaultLabel(benchmark::State & state)515 static void addDefaultLabel(benchmark::State& state) {
516     std::string secLevel;
517     switch (keymintTest->securityLevel_) {
518         case SecurityLevel::STRONGBOX:
519             secLevel = "STRONGBOX";
520             break;
521         case SecurityLevel::SOFTWARE:
522             secLevel = "SOFTWARE";
523             break;
524         case SecurityLevel::TRUSTED_ENVIRONMENT:
525             secLevel = "TEE";
526             break;
527         case SecurityLevel::KEYSTORE:
528             secLevel = "KEYSTORE";
529             break;
530     }
531     state.SetLabel("hardware_name:" + keymintTest->name_ + " sec_level:" + secLevel);
532 }
533 
534 // clang-format off
535 #define BENCHMARK_KM(func, transform, keySize) \
536     BENCHMARK_CAPTURE(func, transform/keySize, #transform "/" #keySize, keySize)->Apply(settings);
537 #define BENCHMARK_KM_MSG(func, transform, keySize, msgSize)                                      \
538     BENCHMARK_CAPTURE(func, transform/keySize/msgSize, #transform "/" #keySize "/" #msgSize, \
539                       keySize, msgSize)                                                          \
540             ->Apply(settings);
541 
542 #define BENCHMARK_KM_ALL_MSGS(func, transform, keySize)             \
543     BENCHMARK_KM_MSG(func, transform, keySize, SMALL_MESSAGE_SIZE)  \
544     BENCHMARK_KM_MSG(func, transform, keySize, MEDIUM_MESSAGE_SIZE) \
545     BENCHMARK_KM_MSG(func, transform, keySize, LARGE_MESSAGE_SIZE)
546 
547 #define BENCHMARK_KM_CIPHER(transform, keySize, msgSize)   \
548     BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
549     BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)
550 
551 // Skip public key operations as they are not supported in KeyMint.
552 #define BENCHMARK_KM_ASYM_CIPHER(transform, keySize, msgSize)   \
553     BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)
554 
555 #define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
556     BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize)   \
557     BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)
558 
559 #define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
560     BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)         \
561     BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)
562 
563 // Skip public key operations as they are not supported in KeyMint.
564 #define BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, keySize) \
565     BENCHMARK_KM_ALL_MSGS(sign, transform, keySize) \
566     // clang-format on
567 
568 /*
569  * ============= KeyGen TESTS ==================
570  */
571 
isValidSBKeySize(string transform,int keySize)572 static bool isValidSBKeySize(string transform, int keySize) {
573     std::optional<Algorithm> algorithm = keymintTest->getAlgorithm(transform);
574     switch (algorithm.value()) {
575         case Algorithm::AES:
576             return (keySize == 128 || keySize == 256);
577         case Algorithm::HMAC:
578             return (keySize % 8 == 0 && keySize >= 64 && keySize <= 512);
579         case Algorithm::TRIPLE_DES:
580             return (keySize == 168);
581         case Algorithm::RSA:
582             return (keySize == 2048);
583         case Algorithm::EC:
584             return (keySize == 256);
585     }
586     return false;
587 }
588 
keygen(benchmark::State & state,string transform,int keySize)589 static void keygen(benchmark::State& state, string transform, int keySize) {
590     // Skip the test for unsupported key size in StrongBox
591     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
592         !isValidSBKeySize(transform, keySize)) {
593         state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
594                              " is not supported in StrongBox for algorithm: " +
595                              keymintTest->getAlgorithmString(transform))
596                                     .c_str());
597         return;
598     }
599     addDefaultLabel(state);
600     for (auto _ : state) {
601         if (!keymintTest->GenerateKey(transform, keySize)) {
602             state.SkipWithError(
603                     ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
604         }
605         state.PauseTiming();
606 
607         keymintTest->DeleteKey();
608         state.ResumeTiming();
609     }
610 }
611 
612 BENCHMARK_KM(keygen, AES, 128);
613 BENCHMARK_KM(keygen, AES, 256);
614 
615 BENCHMARK_KM(keygen, RSA, 2048);
616 BENCHMARK_KM(keygen, RSA, 3072);
617 BENCHMARK_KM(keygen, RSA, 4096);
618 
619 BENCHMARK_KM(keygen, EC, 224);
620 BENCHMARK_KM(keygen, EC, 256);
621 BENCHMARK_KM(keygen, EC, 384);
622 BENCHMARK_KM(keygen, EC, 521);
623 
624 BENCHMARK_KM(keygen, DESede, 168);
625 
626 BENCHMARK_KM(keygen, Hmac, 64);
627 BENCHMARK_KM(keygen, Hmac, 128);
628 BENCHMARK_KM(keygen, Hmac, 256);
629 BENCHMARK_KM(keygen, Hmac, 512);
630 
631 /*
632  * ============= SIGNATURE TESTS ==================
633  */
sign(benchmark::State & state,string transform,int keySize,int msgSize)634 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
635     // Skip the test for unsupported key size or unsupported digest in StrongBox
636     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
637         if (!isValidSBKeySize(transform, keySize)) {
638             state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
639                                  " is not supported in StrongBox for algorithm: " +
640                                  keymintTest->getAlgorithmString(transform))
641                                         .c_str());
642             return;
643         }
644         if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
645             state.SkipWithError(
646                     ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
647                      " is not supported in StrongBox")
648                             .c_str());
649             return;
650         }
651     }
652     addDefaultLabel(state);
653     if (!keymintTest->GenerateKey(transform, keySize, true)) {
654         state.SkipWithError(
655                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
656         return;
657     }
658 
659     auto in_params = keymintTest->getOperationParams(transform, true);
660     AuthorizationSet out_params;
661     string message = keymintTest->GenerateMessage(msgSize);
662 
663     for (auto _ : state) {
664         state.PauseTiming();
665         ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
666         if (error != ErrorCode::OK) {
667             state.SkipWithError(
668                     ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
669             return;
670         }
671         state.ResumeTiming();
672         out_params.Clear();
673         if (!keymintTest->Process(message)) {
674             state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
675             break;
676         }
677     }
678 }
679 
verify(benchmark::State & state,string transform,int keySize,int msgSize)680 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
681     // Skip the test for unsupported key size or unsupported digest in StrongBox
682     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
683         if (!isValidSBKeySize(transform, keySize)) {
684             state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
685                                  " is not supported in StrongBox for algorithm: " +
686                                  keymintTest->getAlgorithmString(transform))
687                                         .c_str());
688             return;
689         }
690         if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
691             state.SkipWithError(
692                     ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
693                      " is not supported in StrongBox")
694                             .c_str());
695             return;
696         }
697     }
698     addDefaultLabel(state);
699     if (!keymintTest->GenerateKey(transform, keySize, true)) {
700         state.SkipWithError(
701                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
702         return;
703     }
704     AuthorizationSet out_params;
705     auto in_params = keymintTest->getOperationParams(transform, true);
706     string message = keymintTest->GenerateMessage(msgSize);
707     ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
708     if (error != ErrorCode::OK) {
709         state.SkipWithError(
710                 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
711         return;
712     }
713     std::optional<string> signature = keymintTest->Process(message);
714     if (!signature) {
715         state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
716         return;
717     }
718     out_params.Clear();
719     if (transform.find("Hmac") != string::npos) {
720         in_params = keymintTest->getOperationParams(transform, false);
721     }
722     for (auto _ : state) {
723         state.PauseTiming();
724         error = keymintTest->Begin(KeyPurpose::VERIFY, in_params, &out_params);
725         if (error != ErrorCode::OK) {
726             state.SkipWithError(
727                     ("Verify begin error, " + std::to_string(keymintTest->getError())).c_str());
728             return;
729         }
730         state.ResumeTiming();
731         if (!keymintTest->Process(message, *signature)) {
732             state.SkipWithError(
733                     ("Verify error, " + std::to_string(keymintTest->getError())).c_str());
734             break;
735         }
736     }
737 }
738 
739 // clang-format off
740 #define BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(transform) \
741     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 64)      \
742     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 128)     \
743     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256)     \
744     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 512)
745 
746 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA1)
747 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
748 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA224)
749 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
750 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA384)
751 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
752 
753 #define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
754     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 224)      \
755     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 256)      \
756     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 384)      \
757     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 521)
758 
759 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
760 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
761 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA224withECDSA);
762 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA256withECDSA);
763 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA384withECDSA);
764 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
765 
766 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
767     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 2048)   \
768     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 3072)   \
769     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 4096)
770 
771 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
772 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
773 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
774 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA256withRSA);
775 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
776 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
777 
778 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA/PSS);
779 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA/PSS);
780 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
781 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
782 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
783 
784 // clang-format on
785 
786 /*
787  * ============= CIPHER TESTS ==================
788  */
789 
encrypt(benchmark::State & state,string transform,int keySize,int msgSize)790 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
791     // Skip the test for unsupported key size in StrongBox
792     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
793         (!isValidSBKeySize(transform, keySize))) {
794         state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
795                              " is not supported in StrongBox for algorithm: " +
796                              keymintTest->getAlgorithmString(transform))
797                                     .c_str());
798         return;
799     }
800     addDefaultLabel(state);
801     if (!keymintTest->GenerateKey(transform, keySize)) {
802         state.SkipWithError(
803                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
804         return;
805     }
806     auto in_params = keymintTest->getOperationParams(transform);
807     AuthorizationSet out_params;
808     string message = keymintTest->GenerateMessage(msgSize);
809 
810     for (auto _ : state) {
811         state.PauseTiming();
812         auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
813         if (error != ErrorCode::OK) {
814             state.SkipWithError(
815                     ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
816             return;
817         }
818         out_params.Clear();
819         state.ResumeTiming();
820         if (!keymintTest->Process(message)) {
821             state.SkipWithError(
822                     ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
823             break;
824         }
825     }
826 }
827 
decrypt(benchmark::State & state,string transform,int keySize,int msgSize)828 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
829     // Skip the test for unsupported key size in StrongBox
830     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
831         (!isValidSBKeySize(transform, keySize))) {
832         state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
833                              " is not supported in StrongBox for algorithm: " +
834                              keymintTest->getAlgorithmString(transform))
835                                     .c_str());
836         return;
837     }
838     addDefaultLabel(state);
839     if (!keymintTest->GenerateKey(transform, keySize)) {
840         state.SkipWithError(
841                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
842         return;
843     }
844     AuthorizationSet out_params;
845     AuthorizationSet in_params = keymintTest->getOperationParams(transform);
846     string message = keymintTest->GenerateMessage(msgSize);
847     optional<string> encryptedMessage;
848 
849     if (keymintTest->getAlgorithm(transform).value() == Algorithm::RSA) {
850         // Public key operation not supported, doing local Encryption
851         encryptedMessage = keymintTest->LocalRsaEncryptMessage(message, in_params);
852         if ((keySize / 8) != (*encryptedMessage).size()) {
853             state.SkipWithError("Local Encryption falied");
854             return;
855         }
856     } else {
857         auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
858         if (error != ErrorCode::OK) {
859             state.SkipWithError(
860                     ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
861             return;
862         }
863         encryptedMessage = keymintTest->Process(message);
864         if (!encryptedMessage) {
865             state.SkipWithError(
866                     ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
867             return;
868         }
869         in_params.push_back(out_params);
870         out_params.Clear();
871     }
872     for (auto _ : state) {
873         state.PauseTiming();
874         auto error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
875         if (error != ErrorCode::OK) {
876             state.SkipWithError(
877                     ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
878             return;
879         }
880         state.ResumeTiming();
881         if (!keymintTest->Process(*encryptedMessage)) {
882             state.SkipWithError(
883                     ("Decryption error, " + std::to_string(keymintTest->getError())).c_str());
884             break;
885         }
886     }
887 }
888 
// clang-format off
// Cipher benchmark registrations. Symmetric transforms register one set per key
// size via BENCHMARK_KM_CIPHER_ALL_MSGS (defined above). RSA registers via
// BENCHMARK_KM_ASYM_CIPHER with a caller-supplied message size.

// AES: 128- and 256-bit keys for each mode/padding combination.
#define BENCHMARK_KM_CIPHER_ALL_AES_KEYS(transform) \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 128)    \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 256)

BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CTR/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/GCM/NoPadding);

// Triple DES: only the 168-bit key size is registered.
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/PKCS7Padding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);

// RSA: 2048/3072/4096-bit keys with a single message size. SMALL_MESSAGE_SIZE
// (64 bytes) is used, presumably so one message fits in a single RSA block
// under every padding mode -- confirm.
#define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
    BENCHMARK_KM_ASYM_CIPHER(transform, 2048, msgSize)            \
    BENCHMARK_KM_ASYM_CIPHER(transform, 3072, msgSize)            \
    BENCHMARK_KM_ASYM_CIPHER(transform, 4096, msgSize)

BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/OAEPPadding, SMALL_MESSAGE_SIZE);

// clang-format on
918 }  // namespace aidl::android::hardware::security::keymint::test
919 
main(int argc,char ** argv)920 int main(int argc, char** argv) {
921     ::benchmark::Initialize(&argc, argv);
922     base::CommandLine::Init(argc, argv);
923     base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
924     auto service_name = command_line->GetSwitchValueASCII("service_name");
925     if (service_name.empty()) {
926         service_name =
927                 std::string(
928                         aidl::android::hardware::security::keymint::IKeyMintDevice::descriptor) +
929                 "/default";
930     }
931     std::cerr << service_name << std::endl;
932     aidl::android::hardware::security::keymint::test::keymintTest =
933             aidl::android::hardware::security::keymint::test::KeyMintBenchmarkTest::newInstance(
934                     service_name.c_str());
935     if (!aidl::android::hardware::security::keymint::test::keymintTest) {
936         return 1;
937     }
938     ::benchmark::RunSpecifiedBenchmarks();
939 }
940