1  // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2  // Copyright by contributors to this project.
3  // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4  
5  use alloc::vec::Vec;
6  use core::fmt::{self, Debug};
7  use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
8  use mls_rs_core::extension::{ExtensionType, MlsCodecExtension};
9  
10  use mls_rs_core::{group::ProposalType, identity::CredentialType};
11  
12  #[cfg(feature = "by_ref_proposal")]
13  use mls_rs_core::{
14      extension::ExtensionList,
15      identity::{IdentityProvider, SigningIdentity},
16      time::MlsTime,
17  };
18  
19  use crate::group::ExportedTree;
20  
21  use mls_rs_core::crypto::HpkePublicKey;
22  
23  /// Application specific identifier.
24  ///
25  /// A custom application level identifier that can be optionally stored
26  /// within the `leaf_node_extensions` of a group [Member](crate::group::Member).
27  #[cfg_attr(
28      all(feature = "ffi", not(test)),
29      safer_ffi_gen::ffi_type(clone, opaque)
30  )]
31  #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
32  pub struct ApplicationIdExt {
33      /// Application level identifier presented by this extension.
34      #[mls_codec(with = "mls_rs_codec::byte_vec")]
35      pub identifier: Vec<u8>,
36  }
37  
38  impl Debug for ApplicationIdExt {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result39      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40          f.debug_struct("ApplicationIdExt")
41              .field(
42                  "identifier",
43                  &mls_rs_core::debug::pretty_bytes(&self.identifier),
44              )
45              .finish()
46      }
47  }
48  
49  #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
50  impl ApplicationIdExt {
51      /// Create a new application level identifier extension.
new(identifier: Vec<u8>) -> Self52      pub fn new(identifier: Vec<u8>) -> Self {
53          ApplicationIdExt { identifier }
54      }
55  
56      /// Get the application level identifier presented by this extension.
57      #[cfg(feature = "ffi")]
identifier(&self) -> &[u8]58      pub fn identifier(&self) -> &[u8] {
59          &self.identifier
60      }
61  }
62  
63  impl MlsCodecExtension for ApplicationIdExt {
extension_type() -> ExtensionType64      fn extension_type() -> ExtensionType {
65          ExtensionType::APPLICATION_ID
66      }
67  }
68  
69  /// Representation of an MLS ratchet tree.
70  ///
71  /// Used to provide new members
72  /// a copy of the current group state in-band.
73  #[cfg_attr(
74      all(feature = "ffi", not(test)),
75      safer_ffi_gen::ffi_type(clone, opaque)
76  )]
77  #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
78  pub struct RatchetTreeExt {
79      pub tree_data: ExportedTree<'static>,
80  }
81  
82  #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
83  impl RatchetTreeExt {
84      /// Required custom extension types.
85      #[cfg(feature = "ffi")]
tree_data(&self) -> &ExportedTree<'static>86      pub fn tree_data(&self) -> &ExportedTree<'static> {
87          &self.tree_data
88      }
89  }
90  
91  impl MlsCodecExtension for RatchetTreeExt {
extension_type() -> ExtensionType92      fn extension_type() -> ExtensionType {
93          ExtensionType::RATCHET_TREE
94      }
95  }
96  
97  /// Require members to have certain capabilities.
98  ///
99  /// Used within a
100  /// [Group Context Extensions Proposal](crate::group::proposal::Proposal)
101  /// in order to require that all current and future members of a group MUST
102  /// support specific extensions, proposals, or credentials.
103  ///
104  /// # Warning
105  ///
106  /// Extension, proposal, and credential types defined by the MLS RFC and
107  /// provided are considered required by default and should NOT be used
108  /// within this extension.
109  #[cfg_attr(
110      all(feature = "ffi", not(test)),
111      safer_ffi_gen::ffi_type(clone, opaque)
112  )]
113  #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
114  pub struct RequiredCapabilitiesExt {
115      pub extensions: Vec<ExtensionType>,
116      pub proposals: Vec<ProposalType>,
117      pub credentials: Vec<CredentialType>,
118  }
119  
120  #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
121  impl RequiredCapabilitiesExt {
122      /// Create a required capabilities extension.
new( extensions: Vec<ExtensionType>, proposals: Vec<ProposalType>, credentials: Vec<CredentialType>, ) -> Self123      pub fn new(
124          extensions: Vec<ExtensionType>,
125          proposals: Vec<ProposalType>,
126          credentials: Vec<CredentialType>,
127      ) -> Self {
128          Self {
129              extensions,
130              proposals,
131              credentials,
132          }
133      }
134  
135      /// Required custom extension types.
136      #[cfg(feature = "ffi")]
extensions(&self) -> &[ExtensionType]137      pub fn extensions(&self) -> &[ExtensionType] {
138          &self.extensions
139      }
140  
141      /// Required custom proposal types.
142      #[cfg(feature = "ffi")]
proposals(&self) -> &[ProposalType]143      pub fn proposals(&self) -> &[ProposalType] {
144          &self.proposals
145      }
146  
147      /// Required custom credential types.
148      #[cfg(feature = "ffi")]
credentials(&self) -> &[CredentialType]149      pub fn credentials(&self) -> &[CredentialType] {
150          &self.credentials
151      }
152  }
153  
154  impl MlsCodecExtension for RequiredCapabilitiesExt {
extension_type() -> ExtensionType155      fn extension_type() -> ExtensionType {
156          ExtensionType::REQUIRED_CAPABILITIES
157      }
158  }
159  
160  /// External public key used for [External Commits](crate::Client::commit_external).
161  ///
162  /// This proposal type is optionally provided as part of a
163  /// [Group Info](crate::group::Group::group_info_message).
164  #[cfg_attr(
165      all(feature = "ffi", not(test)),
166      safer_ffi_gen::ffi_type(clone, opaque)
167  )]
168  #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
169  pub struct ExternalPubExt {
170      /// Public key to be used for an external commit.
171      #[mls_codec(with = "mls_rs_codec::byte_vec")]
172      pub external_pub: HpkePublicKey,
173  }
174  
175  #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
176  impl ExternalPubExt {
177      /// Get the public key to be used for an external commit.
178      #[cfg(feature = "ffi")]
external_pub(&self) -> &HpkePublicKey179      pub fn external_pub(&self) -> &HpkePublicKey {
180          &self.external_pub
181      }
182  }
183  
184  impl MlsCodecExtension for ExternalPubExt {
extension_type() -> ExtensionType185      fn extension_type() -> ExtensionType {
186          ExtensionType::EXTERNAL_PUB
187      }
188  }
189  
/// Enable proposals by an [ExternalClient](crate::external_client::ExternalClient).
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
pub struct ExternalSendersExt {
    /// Signing identities that are permitted to act as external senders.
    pub allowed_senders: Vec<SigningIdentity>,
}
201  
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl ExternalSendersExt {
    /// Create an external senders extension from the list of allowed
    /// signing identities.
    pub fn new(allowed_senders: Vec<SigningIdentity>) -> Self {
        ExternalSendersExt { allowed_senders }
    }

    /// Borrow the signing identities allowed to act as external senders.
    ///
    /// Only compiled when the `ffi` feature is enabled.
    #[cfg(feature = "ffi")]
    pub fn allowed_senders(&self) -> &[SigningIdentity] {
        self.allowed_senders.as_slice()
    }

    /// Validate every allowed sender with `provider`.
    ///
    /// Each identity is passed, in order, to
    /// `IdentityProvider::validate_external_sender` together with `timestamp`
    /// and the group context extensions. Validation stops at the first
    /// failure, and that provider error is returned unchanged.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn verify_all<I: IdentityProvider>(
        &self,
        provider: &I,
        timestamp: Option<MlsTime>,
        group_context_extensions: &ExtensionList,
    ) -> Result<(), I::Error> {
        // The same context is handed to the provider for every sender.
        let context = Some(group_context_extensions);

        for sender in &self.allowed_senders {
            provider
                .validate_external_sender(sender, timestamp, context)
                .await?;
        }

        Ok(())
    }
}
230  
/// Associates [`ExternalSendersExt`] with the `EXTERNAL_SENDERS` extension type.
#[cfg(feature = "by_ref_proposal")]
impl MlsCodecExtension for ExternalSendersExt {
    fn extension_type() -> ExtensionType {
        ExtensionType::EXTERNAL_SENDERS
    }
}
237  
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tree_kem::node::NodeVec;
    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE, identity::test_utils::get_test_signing_identity,
    };

    use mls_rs_core::extension::MlsExtension;

    use mls_rs_core::identity::BasicCredential;

    use alloc::vec;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    // Every test follows the same shape: convert a typed extension into a
    // generic extension, check the advertised extension type, then decode it
    // back and compare against the input.

    #[test]
    fn test_application_id_extension() {
        let expected_id = vec![0u8; 32];

        let encoded = ApplicationIdExt::new(expected_id.clone())
            .into_extension()
            .unwrap();
        assert_eq!(encoded.extension_type, ExtensionType::APPLICATION_ID);

        let decoded = ApplicationIdExt::from_extension(&encoded).unwrap();
        assert_eq!(decoded.identifier, expected_id);
    }

    #[test]
    fn test_ratchet_tree() {
        let original = RatchetTreeExt {
            tree_data: ExportedTree::new(NodeVec::from(vec![None, None])),
        };

        let encoded = original.clone().into_extension().unwrap();
        assert_eq!(encoded.extension_type, ExtensionType::RATCHET_TREE);

        let decoded = RatchetTreeExt::from_extension(&encoded).unwrap();
        assert_eq!(original, decoded)
    }

    #[test]
    fn test_required_capabilities() {
        let original = RequiredCapabilitiesExt::new(
            vec![0.into(), 1.into()],
            vec![42.into(), 43.into()],
            vec![BasicCredential::credential_type()],
        );

        let encoded = original.clone().into_extension().unwrap();
        assert_eq!(encoded.extension_type, ExtensionType::REQUIRED_CAPABILITIES);

        let decoded = RequiredCapabilitiesExt::from_extension(&encoded).unwrap();
        assert_eq!(original, decoded)
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_external_senders() {
        let signer = get_test_signing_identity(TEST_CIPHER_SUITE, &[1]).await.0;
        let original = ExternalSendersExt::new(vec![signer]);

        let encoded = original.clone().into_extension().unwrap();
        assert_eq!(encoded.extension_type, ExtensionType::EXTERNAL_SENDERS);

        let decoded = ExternalSendersExt::from_extension(&encoded).unwrap();
        assert_eq!(original, decoded)
    }

    #[test]
    fn test_external_pub() {
        let original = ExternalPubExt {
            external_pub: vec![0, 1, 2, 3].into(),
        };

        let encoded = original.clone().into_extension().unwrap();
        assert_eq!(encoded.extension_type, ExtensionType::EXTERNAL_PUB);

        let decoded = ExternalPubExt::from_extension(&encoded).unwrap();
        assert_eq!(original, decoded)
    }
}
331