1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use mls_rs::{
6     client_builder::MlsConfig,
7     error::MlsError,
8     external_client::{
9         builder::MlsConfig as ExternalMlsConfig, ExternalClient, ExternalReceivedMessage,
10         ExternalSnapshot,
11     },
12     group::{CachedProposal, ReceivedMessage},
13     identity::{
14         basic::{BasicCredential, BasicIdentityProvider},
15         SigningIdentity,
16     },
17     CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, MlsMessage,
18 };
19 use mls_rs_core::crypto::SignatureSecretKey;
20 
21 const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
22 
cipher_suite_provider() -> impl CipherSuiteProvider23 fn cipher_suite_provider() -> impl CipherSuiteProvider {
24     crypto_provider()
25         .cipher_suite_provider(CIPHERSUITE)
26         .unwrap()
27 }
28 
crypto_provider() -> impl CryptoProvider + Clone29 fn crypto_provider() -> impl CryptoProvider + Clone {
30     mls_rs_crypto_openssl::OpensslCryptoProvider::default()
31 }
32 
/// In-memory state of a minimal delivery server that observes a single group.
#[derive(Default)]
struct BasicServer {
    /// Serialized `ExternalSnapshot` of the group as last accepted by the server.
    group_state: Vec<u8>,
    /// Serialized cached proposals received since the last accepted commit.
    cached_proposals: Vec<Vec<u8>>,
    /// Every accepted proposal and commit, stored as uploaded, in arrival order.
    message_queue: Vec<Vec<u8>>,
}
39 
40 impl BasicServer {
41     // Client uploads group data after creating the group
create_group(group_info: &[u8]) -> Result<Self, MlsError>42     fn create_group(group_info: &[u8]) -> Result<Self, MlsError> {
43         let server = make_server();
44         let group_info = MlsMessage::from_bytes(group_info)?;
45 
46         let group = server.observe_group(group_info, None)?;
47 
48         Ok(Self {
49             group_state: group.snapshot().to_bytes()?,
50             ..Default::default()
51         })
52     }
53 
54     // Client uploads a proposal. This doesn't change the server's group state, so clients can
55     // upload prposals without synchronization (`cached_proposals` and `message_queue` collect
56     // all proposals in any order).
upload_proposal(&mut self, proposal: Vec<u8>) -> Result<(), MlsError>57     fn upload_proposal(&mut self, proposal: Vec<u8>) -> Result<(), MlsError> {
58         let server = make_server();
59         let group_state = ExternalSnapshot::from_bytes(&self.group_state)?;
60         let mut group = server.load_group(group_state)?;
61 
62         let proposal_msg = MlsMessage::from_bytes(&proposal)?;
63         let res = group.process_incoming_message(proposal_msg)?;
64 
65         let ExternalReceivedMessage::Proposal(proposal_desc) = res else {
66             panic!("expected proposal message!")
67         };
68 
69         self.cached_proposals
70             .push(proposal_desc.cached_proposal().to_bytes()?);
71 
72         self.message_queue.push(proposal);
73 
74         Ok(())
75     }
76 
77     // Client uploads a commit. This changes the server's group state, so in a real application,
78     // it must be synchronized. That is, only one `upload_commit` operation can succeed.
upload_commit(&mut self, commit: Vec<u8>) -> Result<(), MlsError>79     fn upload_commit(&mut self, commit: Vec<u8>) -> Result<(), MlsError> {
80         let server = make_server();
81         let group_state = ExternalSnapshot::from_bytes(&self.group_state)?;
82         let mut group = server.load_group(group_state)?;
83 
84         for p in &self.cached_proposals {
85             group.insert_proposal(CachedProposal::from_bytes(p)?);
86         }
87 
88         let commit_msg = MlsMessage::from_bytes(&commit)?;
89         let res = group.process_incoming_message(commit_msg)?;
90 
91         let ExternalReceivedMessage::Commit(_commit_desc) = res else {
92             panic!("expected commit message!")
93         };
94 
95         self.cached_proposals = Vec::new();
96         self.group_state = group.snapshot().to_bytes()?;
97         self.message_queue.push(commit);
98 
99         Ok(())
100     }
101 
download_messages(&self, i: usize) -> &[Vec<u8>]102     pub fn download_messages(&self, i: usize) -> &[Vec<u8>] {
103         &self.message_queue[i..]
104     }
105 }
106 
make_server() -> ExternalClient<impl ExternalMlsConfig>107 fn make_server() -> ExternalClient<impl ExternalMlsConfig> {
108     ExternalClient::builder()
109         .identity_provider(BasicIdentityProvider)
110         .crypto_provider(crypto_provider())
111         .build()
112 }
113 
make_client(name: &str) -> Result<Client<impl MlsConfig>, MlsError>114 fn make_client(name: &str) -> Result<Client<impl MlsConfig>, MlsError> {
115     let (secret, signing_identity) = make_identity(name);
116 
117     Ok(Client::builder()
118         .identity_provider(BasicIdentityProvider)
119         .crypto_provider(crypto_provider())
120         .signing_identity(signing_identity, secret, CIPHERSUITE)
121         .build())
122 }
123 
make_identity(name: &str) -> (SignatureSecretKey, SigningIdentity)124 fn make_identity(name: &str) -> (SignatureSecretKey, SigningIdentity) {
125     let cipher_suite = cipher_suite_provider();
126     let (secret, public) = cipher_suite.signature_key_generate().unwrap();
127 
128     // Create a basic credential for the session.
129     // NOTE: BasicCredential is for demonstration purposes and not recommended for production.
130     // X.509 credentials are recommended.
131     let basic_identity = BasicCredential::new(name.as_bytes().to_vec());
132     let identity = SigningIdentity::new(basic_identity.into_credential(), public);
133 
134     (secret, identity)
135 }
136 
main() -> Result<(), MlsError>137 fn main() -> Result<(), MlsError> {
138     // Create clients for Alice and Bob
139     let alice = make_client("alice")?;
140     let bob = make_client("bob")?;
141 
142     // Alice creates a group with bob
143     let mut alice_group = alice.create_group(ExtensionList::default())?;
144     let bob_key_package = bob.generate_key_package_message()?;
145 
146     let welcome = &alice_group
147         .commit_builder()
148         .add_member(bob_key_package)?
149         .build()?
150         .welcome_messages[0];
151 
152     let (mut bob_group, _) = bob.join_group(None, welcome)?;
153     alice_group.apply_pending_commit()?;
154 
155     // Server starts observing Alice's group
156     let group_info = alice_group.group_info_message(true)?.to_bytes()?;
157     let mut server = BasicServer::create_group(&group_info)?;
158 
159     // Bob uploads a proposal
160     let proposal = bob_group
161         .propose_group_context_extensions(ExtensionList::new(), Vec::new())?
162         .to_bytes()?;
163 
164     server.upload_proposal(proposal)?;
165 
166     // Alice downloads all messages and commits
167     for m in server.download_messages(0) {
168         alice_group.process_incoming_message(MlsMessage::from_bytes(m)?)?;
169     }
170 
171     let commit = alice_group
172         .commit(b"changing extensions".to_vec())?
173         .commit_message
174         .to_bytes()?;
175 
176     server.upload_commit(commit)?;
177 
178     // Alice waits for an ACK from the server and applies the commit
179     alice_group.apply_pending_commit()?;
180 
181     // Bob downloads the commit
182     let message = server.download_messages(1).first().unwrap();
183 
184     let res = bob_group.process_incoming_message(MlsMessage::from_bytes(message)?)?;
185 
186     let ReceivedMessage::Commit(commit_desc) = res else {
187         panic!("expected commit message")
188     };
189 
190     assert_eq!(&commit_desc.authenticated_data, b"changing extensions");
191 
192     Ok(())
193 }
194