1  // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2  // Copyright by contributors to this project.
3  // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4  
5  use alloc::vec;
6  use alloc::vec::Vec;
7  use core::{
8      fmt::{self, Debug},
9      ops::Deref,
10  };
11  use mls_rs_core::crypto::CipherSuiteProvider;
12  use zeroize::Zeroizing;
13  
14  #[cfg(feature = "psk")]
15  use mls_rs_codec::MlsEncode;
16  
17  #[cfg(feature = "psk")]
18  use mls_rs_core::{error::IntoAnyError, psk::PreSharedKey};
19  
20  #[cfg(feature = "psk")]
21  use crate::{
22      client::MlsError,
23      group::key_schedule::kdf_expand_with_label,
24      psk::{PSKLabel, PreSharedKeyID},
25  };
26  
/// A single pre-shared key paired with the identifier that names it.
///
/// `PskSecret::calculate` consumes an ordered slice of these: `id` is encoded into the
/// per-PSK `PSKLabel`, and `psk` is the key material passed to `kdf_extract`.
#[cfg(feature = "psk")]
#[derive(Clone)]
pub(crate) struct PskSecretInput {
    /// Identifier of this PSK; placed in the `PSKLabel` used during derivation.
    pub id: PreSharedKeyID,
    /// The pre-shared key value itself (input keying material for `kdf_extract`).
    pub psk: PreSharedKey,
}
33  
/// Secret produced by folding an ordered list of pre-shared keys together.
///
/// `PskSecret::new` yields the all-zero starting value of length `kdf_extract_size()`.
/// The bytes are held in `Zeroizing` so they are wiped when the value is dropped.
///
/// NOTE(review): the derived `PartialEq` compares bytes with ordinary slice equality,
/// which is not constant-time; avoid relying on it in timing-sensitive paths.
#[derive(PartialEq, Eq, Clone)]
pub(crate) struct PskSecret(Zeroizing<Vec<u8>>);
36  
37  impl Debug for PskSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result38      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39          mls_rs_core::debug::pretty_bytes(&self.0)
40              .named("PskSecret")
41              .fmt(f)
42      }
43  }
44  
#[cfg(test)]
impl From<Vec<u8>> for PskSecret {
    /// Test-only convenience: takes ownership of raw bytes and wraps them as a secret.
    fn from(bytes: Vec<u8>) -> Self {
        Self(Zeroizing::new(bytes))
    }
}
51  
52  impl Deref for PskSecret {
53      type Target = [u8];
54  
deref(&self) -> &Self::Target55      fn deref(&self) -> &Self::Target {
56          &self.0
57      }
58  }
59  
60  impl PskSecret {
new<P: CipherSuiteProvider>(provider: &P) -> PskSecret61      pub(crate) fn new<P: CipherSuiteProvider>(provider: &P) -> PskSecret {
62          PskSecret(Zeroizing::new(vec![0u8; provider.kdf_extract_size()]))
63      }
64  
65      #[cfg(feature = "psk")]
66      #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
calculate<P: CipherSuiteProvider>( input: &[PskSecretInput], cipher_suite_provider: &P, ) -> Result<PskSecret, MlsError>67      pub(crate) async fn calculate<P: CipherSuiteProvider>(
68          input: &[PskSecretInput],
69          cipher_suite_provider: &P,
70      ) -> Result<PskSecret, MlsError> {
71          let len = u16::try_from(input.len()).map_err(|_| MlsError::TooManyPskIds)?;
72          let mut psk_secret = PskSecret::new(cipher_suite_provider);
73  
74          for (index, psk_secret_input) in input.iter().enumerate() {
75              let index = index as u16;
76  
77              let label = PSKLabel {
78                  id: &psk_secret_input.id,
79                  index,
80                  count: len,
81              };
82  
83              let psk_extracted = cipher_suite_provider
84                  .kdf_extract(
85                      &vec![0; cipher_suite_provider.kdf_extract_size()],
86                      &psk_secret_input.psk,
87                  )
88                  .await
89                  .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
90  
91              let psk_input = kdf_expand_with_label(
92                  cipher_suite_provider,
93                  &psk_extracted,
94                  b"derived psk",
95                  &label.mls_encode_to_vec()?,
96                  None,
97              )
98              .await?;
99  
100              psk_secret = cipher_suite_provider
101                  .kdf_extract(&psk_input, &psk_secret)
102                  .await
103                  .map(PskSecret)
104                  .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
105          }
106  
107          Ok(psk_secret)
108      }
109  }
110  
// Known-answer tests for `PskSecret::calculate`.
//
// Vectors are obtained through `load_test_case_json!(psk_secret, TestScenario::generate())`.
// NOTE(review): `generate()` is presumably used to (re)create the vectors when the JSON file
// is missing; confirm against the macro definition.
#[cfg(feature = "psk")]
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    #[cfg(not(mls_build_async))]
    use core::iter;
    use serde::{Deserialize, Serialize};

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        psk::ExternalPskId,
        psk::{JustPreSharedKeyID, PreSharedKeyID, PskNonce},
        CipherSuiteProvider,
    };

    #[cfg(not(mls_build_async))]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider, psk::test_utils::make_external_psk_id,
        CipherSuite,
    };

    use super::{PskSecret, PskSecretInput};

    // One PSK entry of a test vector; every field is hex-encoded when (de)serialized.
    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct PskInfo {
        #[serde(with = "hex::serde")]
        id: Vec<u8>,
        #[serde(with = "hex::serde")]
        psk: Vec<u8>,
        #[serde(with = "hex::serde")]
        nonce: Vec<u8>,
    }

    // Converts a test-vector entry into the calculation input, treating `id` as an
    // external PSK id and `nonce` as its PSK nonce.
    impl From<PskInfo> for PskSecretInput {
        fn from(info: PskInfo) -> Self {
            let id = PreSharedKeyID {
                key_id: JustPreSharedKeyID::External(ExternalPskId::new(info.id)),
                psk_nonce: PskNonce(info.nonce),
            };

            PskSecretInput {
                id,
                psk: info.psk.into(),
            }
        }
    }

    // A single known-answer case: numeric cipher suite id, the ordered PSK list, and the
    // expected combined secret (hex-encoded when (de)serialized).
    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct TestScenario {
        cipher_suite: u16,
        psks: Vec<PskInfo>,
        #[serde(with = "hex::serde")]
        psk_secret: Vec<u8>,
    }

    impl TestScenario {
        // Builds `n` PSK entries, each with an id from `make_external_psk_id`, a
        // `kdf_extract_size()`-byte key from `random_bytes_vec`, and a nonce from `make_nonce`.
        #[cfg_attr(coverage_nightly, coverage(off))]
        #[cfg(not(mls_build_async))]
        fn make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo> {
            iter::repeat_with(
                #[cfg_attr(coverage_nightly, coverage(off))]
                || PskInfo {
                    id: make_external_psk_id(cs).to_vec(),
                    psk: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
                    nonce: crate::psk::test_utils::make_nonce(cs.cipher_suite()).0,
                },
            )
            .take(n)
            .collect::<Vec<_>>()
        }

        // Produces scenarios for every suite in `CipherSuite::all()` with 1 through 10 PSKs
        // each, recording the secret computed by the current `PskSecret::calculate`.
        #[cfg(not(mls_build_async))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn generate() -> Vec<TestScenario> {
            CipherSuite::all()
                .flat_map(
                    #[cfg_attr(coverage_nightly, coverage(off))]
                    |cs| (1..=10).map(move |n| (cs, n)),
                )
                .map(
                    #[cfg_attr(coverage_nightly, coverage(off))]
                    |(cs, n)| {
                        let provider = test_cipher_suite_provider(cs);
                        let psks = Self::make_psk_list(&provider, n);
                        let psk_secret = Self::compute_psk_secret(&provider, psks.clone());
                        TestScenario {
                            cipher_suite: cs.into(),
                            psks: psks.to_vec(),
                            psk_secret: psk_secret.to_vec(),
                        }
                    },
                )
                .collect()
        }

        // Vector generation is only supported by the sync build; the async build panics.
        #[cfg(mls_build_async)]
        fn generate() -> Vec<TestScenario> {
            panic!("Tests cannot be generated in async mode");
        }

        // Converts `psks` into `PskSecretInput`s and runs `PskSecret::calculate`,
        // panicking if the calculation returns an error.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        async fn compute_psk_secret<P: CipherSuiteProvider>(
            provider: &P,
            psks: Vec<PskInfo>,
        ) -> PskSecret {
            let input = psks
                .into_iter()
                .map(PskSecretInput::from)
                .collect::<Vec<_>>();

            PskSecret::calculate(&input, provider).await.unwrap()
        }
    }

    // Recomputes each scenario's secret and asserts it matches the recorded value.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn expected_psk_secret_is_produced() {
        let scenarios: Vec<TestScenario> =
            load_test_case_json!(psk_secret, TestScenario::generate());

        for scenario in scenarios {
            // Scenarios whose cipher suite yields no test provider are skipped.
            if let Some(provider) = try_test_cipher_suite_provider(scenario.cipher_suite) {
                let computed =
                    TestScenario::compute_psk_secret(&provider, scenario.psks.clone()).await;

                assert_eq!(scenario.psk_secret, computed.to_vec());
            }
        }
    }
}
240