1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
//! This example shows how to create an MLS extension implementing an access control policy
//! based on the concept of users, similar to
//! <https://bifurcation.github.io/ietf-mimi-protocol/draft-ralston-mimi-protocol.html>.
//!
//! A user, e.g. "bob@example.com", owns zero or more MLS members, e.g. Bob's tablet and PC.
//! Users do not have MLS cryptographic state, while MLS members do. At any point in time,
//! the MLS group has a fixed set of users and, for each user, zero or more MLS members they
//! own. Each user also has a role, e.g. a regular user or a moderator (which may change
//! over time).
//!
//! The goal is to implement the following rule:
//! 1. Each MLS member belongs to a user in the group.
//!
//! To this end, we implement the following:
//! * A GroupContext extension containing the current list of users. MLS guarantees agreement
//!   on the list.
//! * An AddUser proposal that modifies the user list.
//! * An MLS credential type for MLS members with the owning user's public key and signature.
//!   When MLS members join using MLS Add proposals, the signature is verified.
//! * Proposal validation rules that enforce rule 1 above.
//!
26 use assert_matches::assert_matches;
27 use mls_rs::{
28 client_builder::{MlsConfig, PaddingMode},
29 error::MlsError,
30 group::{
31 proposal::{MlsCustomProposal, Proposal},
32 Roster, Sender,
33 },
34 mls_rules::{
35 CommitDirection, CommitOptions, CommitSource, EncryptionOptions, ProposalBundle,
36 ProposalSource,
37 },
38 CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, IdentityProvider,
39 MlsRules,
40 };
41 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
42 use mls_rs_core::{
43 crypto::{SignaturePublicKey, SignatureSecretKey},
44 error::IntoAnyError,
45 extension::{ExtensionError, ExtensionType, MlsCodecExtension},
46 group::ProposalType,
47 identity::{Credential, CredentialType, CustomCredential, MlsCredential, SigningIdentity},
48 time::MlsTime,
49 };
50
51 use std::fmt::Display;
52
53 const CIPHER_SUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
54
55 const ROSTER_EXTENSION_V1: ExtensionType = ExtensionType::new(65000);
56 const ADD_USER_PROPOSAL_V1: ProposalType = ProposalType::new(65001);
57 const CREDENTIAL_V1: CredentialType = CredentialType::new(65002);
58
crypto() -> impl CryptoProvider + Clone59 fn crypto() -> impl CryptoProvider + Clone {
60 mls_rs_crypto_openssl::OpensslCryptoProvider::new()
61 }
62
cipher_suite() -> impl CipherSuiteProvider63 fn cipher_suite() -> impl CipherSuiteProvider {
64 crypto().cipher_suite_provider(CIPHER_SUITE).unwrap()
65 }
66
67 #[derive(MlsSize, MlsDecode, MlsEncode)]
68 #[repr(u8)]
69 enum UserRole {
70 Regular = 1u8,
71 Moderator = 2u8,
72 }
73
74 #[derive(MlsSize, MlsDecode, MlsEncode)]
75 struct UserCredential {
76 name: String,
77 role: UserRole,
78 public_key: SignaturePublicKey,
79 }
80
81 #[derive(MlsSize, MlsDecode, MlsEncode)]
82 struct MemberCredential {
83 name: String,
84 user_public_key: SignaturePublicKey, // Identifies the user
85 signature: Vec<u8>,
86 }
87
88 #[derive(MlsSize, MlsEncode)]
89 struct MemberCredentialTBS<'a> {
90 name: &'a str,
91 user_public_key: &'a SignaturePublicKey,
92 public_key: &'a SignaturePublicKey,
93 }
94
95 /// The roster will be stored in the custom RosterExtension, an extension in the MLS GroupContext
96 #[derive(MlsSize, MlsDecode, MlsEncode)]
97 struct RosterExtension {
98 roster: Vec<UserCredential>,
99 }
100
101 impl MlsCodecExtension for RosterExtension {
extension_type() -> ExtensionType102 fn extension_type() -> ExtensionType {
103 ROSTER_EXTENSION_V1
104 }
105 }
106
107 /// The custom AddUser proposal will be used to update the RosterExtension
108 #[derive(MlsSize, MlsDecode, MlsEncode)]
109 struct AddUserProposal {
110 new_user: UserCredential,
111 }
112
113 impl MlsCustomProposal for AddUserProposal {
proposal_type() -> ProposalType114 fn proposal_type() -> ProposalType {
115 ADD_USER_PROPOSAL_V1
116 }
117 }
118
/// MLS rules describing how the custom `AddUser` proposal is handled: committed `AddUser`
/// proposals are reflected in the roster stored in the group context.
#[derive(Debug, Clone, Copy)]
struct CustomMlsRules;
122
123 impl MlsRules for CustomMlsRules {
124 type Error = CustomError;
125
filter_proposals( &self, _: CommitDirection, _: CommitSource, _: &Roster, extension_list: &ExtensionList, mut proposals: ProposalBundle, ) -> Result<ProposalBundle, Self::Error>126 fn filter_proposals(
127 &self,
128 _: CommitDirection,
129 _: CommitSource,
130 _: &Roster,
131 extension_list: &ExtensionList,
132 mut proposals: ProposalBundle,
133 ) -> Result<ProposalBundle, Self::Error> {
134 // Find our extension
135 let mut roster: RosterExtension =
136 extension_list.get_as().ok().flatten().ok_or(CustomError)?;
137
138 // Find AddUser proposals
139 let add_user_proposals = proposals
140 .custom_proposals()
141 .iter()
142 .filter(|p| p.proposal.proposal_type() == ADD_USER_PROPOSAL_V1);
143
144 for add_user_info in add_user_proposals {
145 let add_user = AddUserProposal::from_custom_proposal(&add_user_info.proposal)?;
146
147 // Eventually we should check for duplicates
148 roster.roster.push(add_user.new_user);
149 }
150
151 // Issue GroupContextExtensions proposal to modify our roster (eventually we don't have to do this if there were no AddUser proposals)
152 let mut new_extensions = extension_list.clone();
153 new_extensions.set_from(roster)?;
154 let gce_proposal = Proposal::GroupContextExtensions(new_extensions);
155 proposals.add(gce_proposal, Sender::Member(0), ProposalSource::Local);
156
157 Ok(proposals)
158 }
159
commit_options( &self, _: &Roster, _: &ExtensionList, _: &ProposalBundle, ) -> Result<CommitOptions, Self::Error>160 fn commit_options(
161 &self,
162 _: &Roster,
163 _: &ExtensionList,
164 _: &ProposalBundle,
165 ) -> Result<CommitOptions, Self::Error> {
166 Ok(CommitOptions::new())
167 }
168
encryption_options( &self, _: &Roster, _: &ExtensionList, ) -> Result<EncryptionOptions, Self::Error>169 fn encryption_options(
170 &self,
171 _: &Roster,
172 _: &ExtensionList,
173 ) -> Result<EncryptionOptions, Self::Error> {
174 Ok(EncryptionOptions::new(false, PaddingMode::None))
175 }
176 }
177
178 // The IdentityProvider will tell MLS how to validate members' identities. We will use custom identity
179 // type to store our User structs.
180 impl MlsCredential for MemberCredential {
181 type Error = CustomError;
182
credential_type() -> CredentialType183 fn credential_type() -> CredentialType {
184 CREDENTIAL_V1
185 }
186
into_credential(self) -> Result<Credential, Self::Error>187 fn into_credential(self) -> Result<Credential, Self::Error> {
188 Ok(Credential::Custom(CustomCredential::new(
189 Self::credential_type(),
190 self.mls_encode_to_vec()?,
191 )))
192 }
193 }
194
/// Identity provider that validates `MemberCredential`s against the group's user roster.
#[derive(Debug, Clone, Copy)]
struct CustomIdentityProvider;
197
198 impl IdentityProvider for CustomIdentityProvider {
199 type Error = CustomError;
200
validate_member( &self, signing_identity: &SigningIdentity, _: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error>201 fn validate_member(
202 &self,
203 signing_identity: &SigningIdentity,
204 _: Option<MlsTime>,
205 extensions: Option<&ExtensionList>,
206 ) -> Result<(), Self::Error> {
207 let Some(extensions) = extensions else {
208 return Ok(());
209 };
210
211 let roster = extensions
212 .get_as::<RosterExtension>()
213 .ok()
214 .flatten()
215 .ok_or(CustomError)?;
216
217 // Retrieve the MemberCredential from the MLS credential
218 let Credential::Custom(custom) = &signing_identity.credential else {
219 return Err(CustomError);
220 };
221
222 if custom.credential_type != CREDENTIAL_V1 {
223 return Err(CustomError);
224 }
225
226 let member = MemberCredential::mls_decode(&mut &*custom.data)?;
227
228 // Validate the MemberCredential
229
230 let tbs = MemberCredentialTBS {
231 name: &member.name,
232 user_public_key: &member.user_public_key,
233 public_key: &signing_identity.signature_key,
234 }
235 .mls_encode_to_vec()?;
236
237 cipher_suite()
238 .verify(&member.user_public_key, &member.signature, &tbs)
239 .map_err(|_| CustomError)?;
240
241 let user_in_roster = roster
242 .roster
243 .iter()
244 .any(|u| u.public_key == member.user_public_key);
245
246 if !user_in_roster {
247 return Err(CustomError);
248 }
249
250 Ok(())
251 }
252
identity( &self, signing_identity: &SigningIdentity, _: &ExtensionList, ) -> Result<Vec<u8>, Self::Error>253 fn identity(
254 &self,
255 signing_identity: &SigningIdentity,
256 _: &ExtensionList,
257 ) -> Result<Vec<u8>, Self::Error> {
258 Ok(signing_identity.mls_encode_to_vec()?)
259 }
260
supported_types(&self) -> Vec<CredentialType>261 fn supported_types(&self) -> Vec<CredentialType> {
262 vec![CREDENTIAL_V1]
263 }
264
valid_successor( &self, _: &SigningIdentity, _: &SigningIdentity, _: &ExtensionList, ) -> Result<bool, Self::Error>265 fn valid_successor(
266 &self,
267 _: &SigningIdentity,
268 _: &SigningIdentity,
269 _: &ExtensionList,
270 ) -> Result<bool, Self::Error> {
271 Ok(true)
272 }
273
validate_external_sender( &self, _: &SigningIdentity, _: Option<MlsTime>, _: Option<&ExtensionList>, ) -> Result<(), Self::Error>274 fn validate_external_sender(
275 &self,
276 _: &SigningIdentity,
277 _: Option<MlsTime>,
278 _: Option<&ExtensionList>,
279 ) -> Result<(), Self::Error> {
280 Ok(())
281 }
282 }
283
284 // Convenience structs to create users and members
285
286 struct User {
287 credential: UserCredential,
288 signer: SignatureSecretKey,
289 }
290
291 impl User {
new(name: &str, role: UserRole) -> Result<Self, CustomError>292 fn new(name: &str, role: UserRole) -> Result<Self, CustomError> {
293 let (signer, public_key) = cipher_suite()
294 .signature_key_generate()
295 .map_err(|_| CustomError)?;
296
297 let credential = UserCredential {
298 name: name.into(),
299 role,
300 public_key,
301 };
302
303 Ok(Self { credential, signer })
304 }
305 }
306
307 struct Member {
308 credential: MemberCredential,
309 public_key: SignaturePublicKey,
310 signer: SignatureSecretKey,
311 }
312
313 impl Member {
new(name: &str, user: &User) -> Result<Self, CustomError>314 fn new(name: &str, user: &User) -> Result<Self, CustomError> {
315 let (signer, public_key) = cipher_suite()
316 .signature_key_generate()
317 .map_err(|_| CustomError)?;
318
319 let tbs = MemberCredentialTBS {
320 name,
321 user_public_key: &user.credential.public_key,
322 public_key: &public_key,
323 }
324 .mls_encode_to_vec()?;
325
326 let signature = cipher_suite()
327 .sign(&user.signer, &tbs)
328 .map_err(|_| CustomError)?;
329
330 let credential = MemberCredential {
331 name: name.into(),
332 user_public_key: user.credential.public_key.clone(),
333 signature,
334 };
335
336 Ok(Self {
337 credential,
338 signer,
339 public_key,
340 })
341 }
342 }
343
344 // Set up Client to use our custom providers
make_client(member: Member) -> Result<Client<impl MlsConfig>, CustomError>345 fn make_client(member: Member) -> Result<Client<impl MlsConfig>, CustomError> {
346 let mls_credential = member.credential.into_credential()?;
347 let signing_identity = SigningIdentity::new(mls_credential, member.public_key);
348
349 Ok(Client::builder()
350 .identity_provider(CustomIdentityProvider)
351 .mls_rules(CustomMlsRules)
352 .custom_proposal_type(ADD_USER_PROPOSAL_V1)
353 .extension_type(ROSTER_EXTENSION_V1)
354 .crypto_provider(crypto())
355 .signing_identity(signing_identity, member.signer, CIPHER_SUITE)
356 .build())
357 }
358
main() -> Result<(), CustomError>359 fn main() -> Result<(), CustomError> {
360 let alice = User::new("alice", UserRole::Moderator)?;
361 let bob = User::new("bob", UserRole::Regular)?;
362
363 let alice_tablet = Member::new("alice tablet", &alice)?;
364 let alice_pc = Member::new("alice pc", &alice)?;
365 let bob_tablet = Member::new("bob tablet", &bob)?;
366
367 // Alice creates the group with our RosterExtension containing her user
368 let mut context_extensions = ExtensionList::new();
369 let roster = vec![alice.credential];
370 context_extensions.set_from(RosterExtension { roster })?;
371
372 let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?;
373
374 // Alice can add her other device
375 let alice_pc_client = make_client(alice_pc)?;
376 let key_package = alice_pc_client.generate_key_package_message()?;
377
378 let welcome = alice_tablet_group
379 .commit_builder()
380 .add_member(key_package)?
381 .build()?
382 .welcome_messages
383 .remove(0);
384
385 alice_tablet_group.apply_pending_commit()?;
386 let (mut alice_pc_group, _) = alice_pc_client.join_group(None, &welcome)?;
387
388 // Alice cannot add bob's devices yet
389 let bob_tablet_client = make_client(bob_tablet)?;
390 let key_package = bob_tablet_client.generate_key_package_message()?;
391
392 let res = alice_tablet_group
393 .commit_builder()
394 .add_member(key_package.clone())?
395 .build();
396
397 assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
398
399 // Alice can add bob's user and device
400 let add_bob = AddUserProposal {
401 new_user: bob.credential,
402 };
403
404 let commit = alice_tablet_group
405 .commit_builder()
406 .custom_proposal(add_bob.to_custom_proposal()?)
407 .add_member(key_package)?
408 .build()?;
409
410 bob_tablet_client.join_group(None, &commit.welcome_messages[0])?;
411 alice_tablet_group.apply_pending_commit()?;
412 alice_pc_group.process_incoming_message(commit.commit_message)?;
413
414 Ok(())
415 }
416
417 #[derive(Debug, thiserror::Error)]
418 struct CustomError;
419
420 impl IntoAnyError for CustomError {
into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self>421 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
422 Ok(Box::new(self))
423 }
424 }
425
426 impl Display for CustomError {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result427 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428 f.write_str("Custom Error")
429 }
430 }
431
432 impl From<MlsError> for CustomError {
from(_: MlsError) -> Self433 fn from(_: MlsError) -> Self {
434 Self
435 }
436 }
437
438 impl From<mls_rs_codec::Error> for CustomError {
from(_: mls_rs_codec::Error) -> Self439 fn from(_: mls_rs_codec::Error) -> Self {
440 Self
441 }
442 }
443
444 impl From<ExtensionError> for CustomError {
from(_: ExtensionError) -> Self445 fn from(_: ExtensionError) -> Self {
446 Self
447 }
448 }
449