1 use super::{Codec, DecodeBuf, Decoder, Encoder};
2 use crate::codec::EncodeBuf;
3 use crate::{Code, Status};
4 use prost::Message;
5 use std::marker::PhantomData;
6 
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type this codec encodes and `U` is the message type it
/// decodes. Neither is stored; they only exist at the type level through a
/// `PhantomData` marker.
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}
12 
13 impl<T, U> Default for ProstCodec<T, U> {
default() -> Self14     fn default() -> Self {
15         Self { _pd: PhantomData }
16     }
17 }
18 
19 impl<T, U> Codec for ProstCodec<T, U>
20 where
21     T: Message + Send + 'static,
22     U: Message + Default + Send + 'static,
23 {
24     type Encode = T;
25     type Decode = U;
26 
27     type Encoder = ProstEncoder<T>;
28     type Decoder = ProstDecoder<U>;
29 
encoder(&mut self) -> Self::Encoder30     fn encoder(&mut self) -> Self::Encoder {
31         ProstEncoder(PhantomData)
32     }
33 
decoder(&mut self) -> Self::Decoder34     fn decoder(&mut self) -> Self::Decoder {
35         ProstDecoder(PhantomData)
36     }
37 }
38 
/// An [`Encoder`] that knows how to encode `T`.
///
/// Holds no data; `T` is tracked only through a `PhantomData` marker.
#[derive(Debug, Clone, Default)]
pub struct ProstEncoder<T>(PhantomData<T>);
42 
43 impl<T: Message> Encoder for ProstEncoder<T> {
44     type Item = T;
45     type Error = Status;
46 
encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error>47     fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
48         item.encode(buf)
49             .expect("Message only errors if not enough space");
50 
51         Ok(())
52     }
53 }
54 
/// A [`Decoder`] that turns raw bytes back into a `U` message.
///
/// Stores nothing at runtime; `U` is tracked only through a `PhantomData` marker.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);
58 
59 impl<U: Message + Default> Decoder for ProstDecoder<U> {
60     type Item = U;
61     type Error = Status;
62 
decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>63     fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
64         let item = Message::decode(buf)
65             .map(Option::Some)
66             .map_err(from_decode_error)?;
67 
68         Ok(item)
69     }
70 }
71 
from_decode_error(error: prost::DecodeError) -> crate::Status72 fn from_decode_error(error: prost::DecodeError) -> crate::Status {
73     // Map Protobuf parse errors to an INTERNAL status code, as per
74     // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
75     Status::new(Code::Internal, error.to_string())
76 }
77 
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::{Code, Status};
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;

    // Payload size in bytes of the message round-tripped by the `decode` test.
    const LEN: usize = 10000;
    // The maximum uncompressed size in bytes for a message. Set to 2MB.
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

    /// Wraps `msg` in a length-prefixed gRPC frame.
    ///
    /// The frame is a one-byte flag (`0`, i.e. not compressed), the payload
    /// length as a `u32`, then the payload bytes.
    fn frame(msg: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(HEADER_SIZE + msg.len());
        buf.put_u8(0);
        // Every payload used by these tests is far below `u32::MAX`, so this
        // cast cannot truncate.
        buf.put_u32(msg.len() as u32);
        buf.put_slice(msg);
        buf
    }

    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];
        let buf = frame(&msg);

        // Hand the whole frame (header + payload) over in the first chunk.
        let body = body::MockBody::new(&buf[..], buf.len(), 0);

        let mut stream = Streaming::new_request(decoder, body, None, None);

        let mut i = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            i += 1;
        }
        assert_eq!(i, 1);
    }

    #[tokio::test]
    async fn decode_max_message_size_exceeded() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
        let buf = frame(&msg);

        // Hand the whole frame (header + payload) over in the first chunk.
        let body = body::MockBody::new(&buf[..], buf.len(), 0);

        let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));

        let actual = stream.message().await.unwrap_err();

        let expected = Status::new(
            Code::OutOfRange,
            format!(
                "Error, message length too large: found {} bytes, the limit is: {} bytes",
                msg.len(),
                MAX_MESSAGE_SIZE
            ),
        );

        assert_eq!(actual.code(), expected.code());
        assert_eq!(actual.message(), expected.message());
    }

    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; 1024];

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = tokio_stream::iter(messages);

        let body = encode_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            None,
        );

        tokio::pin!(body);

        while let Some(r) = body.data().await {
            r.unwrap();
        }
    }

    #[tokio::test]
    async fn encode_max_message_size_exceeded() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let body = encode_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(MAX_MESSAGE_SIZE),
        );

        tokio::pin!(body);

        assert!(body.data().await.is_none());
        assert_eq!(
            body.trailers()
                .await
                .expect("no error polling trailers")
                .expect("some trailers")
                .get("grpc-status")
                .expect("grpc-status header"),
            "11"
        );
        assert!(body.is_end_stream());
    }

    // skip on windows because CI stumbles over our 4GB allocation
    #[cfg(not(target_family = "windows"))]
    #[tokio::test]
    async fn encode_too_big() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; u32::MAX as usize + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let body = encode_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(usize::MAX),
        );

        tokio::pin!(body);

        assert!(body.data().await.is_none());
        assert_eq!(
            body.trailers()
                .await
                .expect("no error polling trailers")
                .expect("some trailers")
                .get("grpc-status")
                .expect("grpc-status header"),
            "8"
        );
        assert!(body.is_end_stream());
    }

    /// Test encoder that copies the raw bytes of each item into the buffer.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder;

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(&item[..]);
            Ok(())
        }
    }

    /// Test decoder that returns the current contiguous chunk of `buf` as-is.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder;

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            let out = Vec::from(buf.chunk());
            // Consume exactly what was copied instead of assuming a fixed
            // message length.
            buf.advance(out.len());
            Ok(Some(out))
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::Body;
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// A body that yields `data` in pieces and interleaves `Pending` polls.
        #[derive(Debug)]
        pub(super) struct MockBody {
            data: Bytes,

            // the size of the first chunk to send; later chunks drain the rest
            partial_len: usize,

            // the number of times `poll_data` was called while data remained
            count: usize,
        }

        impl MockBody {
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_data(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
                // every other call to poll_data returns data
                let should_send = self.count % 2 == 0;
                let data_len = self.data.len();
                let partial_len = self.partial_len;
                let count = self.count;
                if data_len > 0 {
                    let result = if should_send {
                        // Clamp so an oversized `partial_len` cannot make
                        // `split_to` panic.
                        let chunk_len = if count == 0 {
                            partial_len.min(data_len)
                        } else {
                            data_len
                        };
                        let response = self.data.split_to(chunk_len);
                        Poll::Ready(Some(Ok(response)))
                    } else {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    };
                    // make some fake progress
                    self.count += 1;
                    result
                } else {
                    Poll::Ready(None)
                }
            }

            fn poll_trailers(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
                Poll::Ready(Ok(None))
            }
        }
    }
}
338