1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
/// The example shows how to create an MLS extension implementing an access control policy
6 /// based on the concept of users, similar to
7 /// https://bifurcation.github.io/ietf-mimi-protocol/draft-ralston-mimi-protocol.html.
8 ///
/// A user, e.g. "bob@example.com", owns zero or more MLS members, e.g. Bob's tablet and PC.
10 /// Users do not have MLS cryptographic state, while MLS members do. At any point in time,
11 /// the MLS group has a fixed set of users and for each user, zero or more MLS members they
12 /// own. Each user also has a role, e.g. a regular user or moderator (which may possibly change
13 /// over time).
14 ///
15 /// The goal is to implement the following rule:
16 /// 1. Each MLS member belongs to a user in the group.
17 ///
18 /// To this end, we implement the following:
19 /// * A GroupContext extension containing the current list of users. MLS guarantees agreement
20 ///   on the list.
21 /// * An AddUser proposal that modifies the user list.
22 /// * An MLS credential type for MLS members with the owning user's public key and signature.
23 ///   When MLS members join using MLS Add proposals, the signature is verified.
24 /// * Proposal validation rules that enforce 1. above.
25 ///
26 use assert_matches::assert_matches;
27 use mls_rs::{
28     client_builder::{MlsConfig, PaddingMode},
29     error::MlsError,
30     group::{
31         proposal::{MlsCustomProposal, Proposal},
32         Roster, Sender,
33     },
34     mls_rules::{
35         CommitDirection, CommitOptions, CommitSource, EncryptionOptions, ProposalBundle,
36         ProposalSource,
37     },
38     CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, IdentityProvider,
39     MlsRules,
40 };
41 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
42 use mls_rs_core::{
43     crypto::{SignaturePublicKey, SignatureSecretKey},
44     error::IntoAnyError,
45     extension::{ExtensionError, ExtensionType, MlsCodecExtension},
46     group::ProposalType,
47     identity::{Credential, CredentialType, CustomCredential, MlsCredential, SigningIdentity},
48     time::MlsTime,
49 };
50 
51 use std::fmt::Display;
52 
53 const CIPHER_SUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
54 
55 const ROSTER_EXTENSION_V1: ExtensionType = ExtensionType::new(65000);
56 const ADD_USER_PROPOSAL_V1: ProposalType = ProposalType::new(65001);
57 const CREDENTIAL_V1: CredentialType = CredentialType::new(65002);
58 
crypto() -> impl CryptoProvider + Clone59 fn crypto() -> impl CryptoProvider + Clone {
60     mls_rs_crypto_openssl::OpensslCryptoProvider::new()
61 }
62 
cipher_suite() -> impl CipherSuiteProvider63 fn cipher_suite() -> impl CipherSuiteProvider {
64     crypto().cipher_suite_provider(CIPHER_SUITE).unwrap()
65 }
66 
67 #[derive(MlsSize, MlsDecode, MlsEncode)]
68 #[repr(u8)]
69 enum UserRole {
70     Regular = 1u8,
71     Moderator = 2u8,
72 }
73 
74 #[derive(MlsSize, MlsDecode, MlsEncode)]
75 struct UserCredential {
76     name: String,
77     role: UserRole,
78     public_key: SignaturePublicKey,
79 }
80 
81 #[derive(MlsSize, MlsDecode, MlsEncode)]
82 struct MemberCredential {
83     name: String,
84     user_public_key: SignaturePublicKey, // Identifies the user
85     signature: Vec<u8>,
86 }
87 
88 #[derive(MlsSize, MlsEncode)]
89 struct MemberCredentialTBS<'a> {
90     name: &'a str,
91     user_public_key: &'a SignaturePublicKey,
92     public_key: &'a SignaturePublicKey,
93 }
94 
95 /// The roster will be stored in the custom RosterExtension, an extension in the MLS GroupContext
96 #[derive(MlsSize, MlsDecode, MlsEncode)]
97 struct RosterExtension {
98     roster: Vec<UserCredential>,
99 }
100 
101 impl MlsCodecExtension for RosterExtension {
extension_type() -> ExtensionType102     fn extension_type() -> ExtensionType {
103         ROSTER_EXTENSION_V1
104     }
105 }
106 
107 /// The custom AddUser proposal will be used to update the RosterExtension
108 #[derive(MlsSize, MlsDecode, MlsEncode)]
109 struct AddUserProposal {
110     new_user: UserCredential,
111 }
112 
113 impl MlsCustomProposal for AddUserProposal {
proposal_type() -> ProposalType114     fn proposal_type() -> ProposalType {
115         ADD_USER_PROPOSAL_V1
116     }
117 }
118 
/// Stateless `MlsRules` implementation that tells MLS how to handle the custom AddUser
/// proposal.
#[derive(Debug, Clone, Copy)]
struct CustomMlsRules;
122 
123 impl MlsRules for CustomMlsRules {
124     type Error = CustomError;
125 
filter_proposals( &self, _: CommitDirection, _: CommitSource, _: &Roster, extension_list: &ExtensionList, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, Self::Error>126     fn filter_proposals(
127         &self,
128         _: CommitDirection,
129         _: CommitSource,
130         _: &Roster,
131         extension_list: &ExtensionList,
132         mut proposals: ProposalBundle,
133     ) -> Result<ProposalBundle, Self::Error> {
134         // Find our extension
135         let mut roster: RosterExtension =
136             extension_list.get_as().ok().flatten().ok_or(CustomError)?;
137 
138         // Find AddUser proposals
139         let add_user_proposals = proposals
140             .custom_proposals()
141             .iter()
142             .filter(|p| p.proposal.proposal_type() == ADD_USER_PROPOSAL_V1);
143 
144         for add_user_info in add_user_proposals {
145             let add_user = AddUserProposal::from_custom_proposal(&add_user_info.proposal)?;
146 
147             // Eventually we should check for duplicates
148             roster.roster.push(add_user.new_user);
149         }
150 
151         // Issue GroupContextExtensions proposal to modify our roster (eventually we don't have to do this if there were no AddUser proposals)
152         let mut new_extensions = extension_list.clone();
153         new_extensions.set_from(roster)?;
154         let gce_proposal = Proposal::GroupContextExtensions(new_extensions);
155         proposals.add(gce_proposal, Sender::Member(0), ProposalSource::Local);
156 
157         Ok(proposals)
158     }
159 
commit_options( &self, _: &Roster, _: &ExtensionList, _: &ProposalBundle, ) -> Result<CommitOptions, Self::Error>160     fn commit_options(
161         &self,
162         _: &Roster,
163         _: &ExtensionList,
164         _: &ProposalBundle,
165     ) -> Result<CommitOptions, Self::Error> {
166         Ok(CommitOptions::new())
167     }
168 
encryption_options( &self, _: &Roster, _: &ExtensionList, ) -> Result<EncryptionOptions, Self::Error>169     fn encryption_options(
170         &self,
171         _: &Roster,
172         _: &ExtensionList,
173     ) -> Result<EncryptionOptions, Self::Error> {
174         Ok(EncryptionOptions::new(false, PaddingMode::None))
175     }
176 }
177 
// The IdentityProvider tells MLS how to validate members' identities. We use a custom
// credential type (MemberCredential) that links each member to the user owning it.
180 impl MlsCredential for MemberCredential {
181     type Error = CustomError;
182 
credential_type() -> CredentialType183     fn credential_type() -> CredentialType {
184         CREDENTIAL_V1
185     }
186 
into_credential(self) -> Result<Credential, Self::Error>187     fn into_credential(self) -> Result<Credential, Self::Error> {
188         Ok(Credential::Custom(CustomCredential::new(
189             Self::credential_type(),
190             self.mls_encode_to_vec()?,
191         )))
192     }
193 }
194 
/// Stateless identity provider that validates `MemberCredential`s against the roster.
#[derive(Debug, Clone, Copy)]
struct CustomIdentityProvider;
197 
198 impl IdentityProvider for CustomIdentityProvider {
199     type Error = CustomError;
200 
validate_member( &self, signing_identity: &SigningIdentity, _: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error>201     fn validate_member(
202         &self,
203         signing_identity: &SigningIdentity,
204         _: Option<MlsTime>,
205         extensions: Option<&ExtensionList>,
206     ) -> Result<(), Self::Error> {
207         let Some(extensions) = extensions else {
208             return Ok(());
209         };
210 
211         let roster = extensions
212             .get_as::<RosterExtension>()
213             .ok()
214             .flatten()
215             .ok_or(CustomError)?;
216 
217         // Retrieve the MemberCredential from the MLS credential
218         let Credential::Custom(custom) = &signing_identity.credential else {
219             return Err(CustomError);
220         };
221 
222         if custom.credential_type != CREDENTIAL_V1 {
223             return Err(CustomError);
224         }
225 
226         let member = MemberCredential::mls_decode(&mut &*custom.data)?;
227 
228         // Validate the MemberCredential
229 
230         let tbs = MemberCredentialTBS {
231             name: &member.name,
232             user_public_key: &member.user_public_key,
233             public_key: &signing_identity.signature_key,
234         }
235         .mls_encode_to_vec()?;
236 
237         cipher_suite()
238             .verify(&member.user_public_key, &member.signature, &tbs)
239             .map_err(|_| CustomError)?;
240 
241         let user_in_roster = roster
242             .roster
243             .iter()
244             .any(|u| u.public_key == member.user_public_key);
245 
246         if !user_in_roster {
247             return Err(CustomError);
248         }
249 
250         Ok(())
251     }
252 
identity( &self, signing_identity: &SigningIdentity, _: &ExtensionList, ) -> Result<Vec<u8>, Self::Error>253     fn identity(
254         &self,
255         signing_identity: &SigningIdentity,
256         _: &ExtensionList,
257     ) -> Result<Vec<u8>, Self::Error> {
258         Ok(signing_identity.mls_encode_to_vec()?)
259     }
260 
supported_types(&self) -> Vec<CredentialType>261     fn supported_types(&self) -> Vec<CredentialType> {
262         vec![CREDENTIAL_V1]
263     }
264 
valid_successor( &self, _: &SigningIdentity, _: &SigningIdentity, _: &ExtensionList, ) -> Result<bool, Self::Error>265     fn valid_successor(
266         &self,
267         _: &SigningIdentity,
268         _: &SigningIdentity,
269         _: &ExtensionList,
270     ) -> Result<bool, Self::Error> {
271         Ok(true)
272     }
273 
validate_external_sender( &self, _: &SigningIdentity, _: Option<MlsTime>, _: Option<&ExtensionList>, ) -> Result<(), Self::Error>274     fn validate_external_sender(
275         &self,
276         _: &SigningIdentity,
277         _: Option<MlsTime>,
278         _: Option<&ExtensionList>,
279     ) -> Result<(), Self::Error> {
280         Ok(())
281     }
282 }
283 
284 // Convenience structs to create users and members
285 
286 struct User {
287     credential: UserCredential,
288     signer: SignatureSecretKey,
289 }
290 
291 impl User {
new(name: &str, role: UserRole) -> Result<Self, CustomError>292     fn new(name: &str, role: UserRole) -> Result<Self, CustomError> {
293         let (signer, public_key) = cipher_suite()
294             .signature_key_generate()
295             .map_err(|_| CustomError)?;
296 
297         let credential = UserCredential {
298             name: name.into(),
299             role,
300             public_key,
301         };
302 
303         Ok(Self { credential, signer })
304     }
305 }
306 
307 struct Member {
308     credential: MemberCredential,
309     public_key: SignaturePublicKey,
310     signer: SignatureSecretKey,
311 }
312 
313 impl Member {
new(name: &str, user: &User) -> Result<Self, CustomError>314     fn new(name: &str, user: &User) -> Result<Self, CustomError> {
315         let (signer, public_key) = cipher_suite()
316             .signature_key_generate()
317             .map_err(|_| CustomError)?;
318 
319         let tbs = MemberCredentialTBS {
320             name,
321             user_public_key: &user.credential.public_key,
322             public_key: &public_key,
323         }
324         .mls_encode_to_vec()?;
325 
326         let signature = cipher_suite()
327             .sign(&user.signer, &tbs)
328             .map_err(|_| CustomError)?;
329 
330         let credential = MemberCredential {
331             name: name.into(),
332             user_public_key: user.credential.public_key.clone(),
333             signature,
334         };
335 
336         Ok(Self {
337             credential,
338             signer,
339             public_key,
340         })
341     }
342 }
343 
344 // Set up Client to use our custom providers
make_client(member: Member) -> Result<Client<impl MlsConfig>, CustomError>345 fn make_client(member: Member) -> Result<Client<impl MlsConfig>, CustomError> {
346     let mls_credential = member.credential.into_credential()?;
347     let signing_identity = SigningIdentity::new(mls_credential, member.public_key);
348 
349     Ok(Client::builder()
350         .identity_provider(CustomIdentityProvider)
351         .mls_rules(CustomMlsRules)
352         .custom_proposal_type(ADD_USER_PROPOSAL_V1)
353         .extension_type(ROSTER_EXTENSION_V1)
354         .crypto_provider(crypto())
355         .signing_identity(signing_identity, member.signer, CIPHER_SUITE)
356         .build())
357 }
358 
main() -> Result<(), CustomError>359 fn main() -> Result<(), CustomError> {
360     let alice = User::new("alice", UserRole::Moderator)?;
361     let bob = User::new("bob", UserRole::Regular)?;
362 
363     let alice_tablet = Member::new("alice tablet", &alice)?;
364     let alice_pc = Member::new("alice pc", &alice)?;
365     let bob_tablet = Member::new("bob tablet", &bob)?;
366 
367     // Alice creates the group with our RosterExtension containing her user
368     let mut context_extensions = ExtensionList::new();
369     let roster = vec![alice.credential];
370     context_extensions.set_from(RosterExtension { roster })?;
371 
372     let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?;
373 
374     // Alice can add her other device
375     let alice_pc_client = make_client(alice_pc)?;
376     let key_package = alice_pc_client.generate_key_package_message()?;
377 
378     let welcome = alice_tablet_group
379         .commit_builder()
380         .add_member(key_package)?
381         .build()?
382         .welcome_messages
383         .remove(0);
384 
385     alice_tablet_group.apply_pending_commit()?;
386     let (mut alice_pc_group, _) = alice_pc_client.join_group(None, &welcome)?;
387 
388     // Alice cannot add bob's devices yet
389     let bob_tablet_client = make_client(bob_tablet)?;
390     let key_package = bob_tablet_client.generate_key_package_message()?;
391 
392     let res = alice_tablet_group
393         .commit_builder()
394         .add_member(key_package.clone())?
395         .build();
396 
397     assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
398 
399     // Alice can add bob's user and device
400     let add_bob = AddUserProposal {
401         new_user: bob.credential,
402     };
403 
404     let commit = alice_tablet_group
405         .commit_builder()
406         .custom_proposal(add_bob.to_custom_proposal()?)
407         .add_member(key_package)?
408         .build()?;
409 
410     bob_tablet_client.join_group(None, &commit.welcome_messages[0])?;
411     alice_tablet_group.apply_pending_commit()?;
412     alice_pc_group.process_incoming_message(commit.commit_message)?;
413 
414     Ok(())
415 }
416 
417 #[derive(Debug, thiserror::Error)]
418 struct CustomError;
419 
420 impl IntoAnyError for CustomError {
into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self>421     fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
422         Ok(Box::new(self))
423     }
424 }
425 
426 impl Display for CustomError {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result427     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428         f.write_str("Custom Error")
429     }
430 }
431 
432 impl From<MlsError> for CustomError {
from(_: MlsError) -> Self433     fn from(_: MlsError) -> Self {
434         Self
435     }
436 }
437 
438 impl From<mls_rs_codec::Error> for CustomError {
from(_: mls_rs_codec::Error) -> Self439     fn from(_: mls_rs_codec::Error) -> Self {
440         Self
441     }
442 }
443 
444 impl From<ExtensionError> for CustomError {
from(_: ExtensionError) -> Self445     fn from(_: ExtensionError) -> Self {
446         Self
447     }
448 }
449