1 use super::{Codec, DecodeBuf, Decoder, Encoder};
2 use crate::codec::EncodeBuf;
3 use crate::{Code, Status};
4 use prost::Message;
5 use std::marker::PhantomData;
6
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type handed to the encoder and `U` is the message type
/// produced by the decoder (see the `Encode` / `Decode` associated types of
/// the `Codec` impl).
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    // Zero-sized marker that ties the codec to its message types; no value of
    // `T` or `U` is ever stored.
    _pd: PhantomData<(T, U)>,
}
12
13 impl<T, U> Default for ProstCodec<T, U> {
default() -> Self14 fn default() -> Self {
15 Self { _pd: PhantomData }
16 }
17 }
18
19 impl<T, U> Codec for ProstCodec<T, U>
20 where
21 T: Message + Send + 'static,
22 U: Message + Default + Send + 'static,
23 {
24 type Encode = T;
25 type Decode = U;
26
27 type Encoder = ProstEncoder<T>;
28 type Decoder = ProstDecoder<U>;
29
encoder(&mut self) -> Self::Encoder30 fn encoder(&mut self) -> Self::Encoder {
31 ProstEncoder(PhantomData)
32 }
33
decoder(&mut self) -> Self::Decoder34 fn decoder(&mut self) -> Self::Decoder {
35 ProstDecoder(PhantomData)
36 }
37 }
38
/// An [`Encoder`] that knows how to encode `T`.
///
/// The wrapped [`PhantomData`] only records the message type; the encoder
/// holds no runtime state.
#[derive(Debug, Clone, Default)]
pub struct ProstEncoder<T>(PhantomData<T>);
42
43 impl<T: Message> Encoder for ProstEncoder<T> {
44 type Item = T;
45 type Error = Status;
46
encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error>47 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
48 item.encode(buf)
49 .expect("Message only errors if not enough space");
50
51 Ok(())
52 }
53 }
54
/// A [`Decoder`] that knows how to decode `U`.
///
/// The wrapped [`PhantomData`] only records the message type; the decoder
/// holds no runtime state.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);
58
59 impl<U: Message + Default> Decoder for ProstDecoder<U> {
60 type Item = U;
61 type Error = Status;
62
decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>63 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
64 let item = Message::decode(buf)
65 .map(Option::Some)
66 .map_err(from_decode_error)?;
67
68 Ok(item)
69 }
70 }
71
from_decode_error(error: prost::DecodeError) -> crate::Status72 fn from_decode_error(error: prost::DecodeError) -> crate::Status {
73 // Map Protobuf parse errors to an INTERNAL status code, as per
74 // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
75 Status::new(Code::Internal, error.to_string())
76 }
77
78 #[cfg(test)]
79 mod tests {
80 use crate::codec::compression::SingleMessageCompressionOverride;
81 use crate::codec::{
82 encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
83 };
84 use crate::{Code, Status};
85 use bytes::{Buf, BufMut, BytesMut};
86 use http_body::Body;
87
// Payload size in bytes of the message built by the `decode` test; the
// `MockDecoder` also advances the buffer by this amount on every decode.
const LEN: usize = 10000;
// The maximum uncompressed size in bytes for a message. Set to 2 MiB
// (2 * 1024 * 1024 bytes).
const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;
91
92 #[tokio::test]
decode()93 async fn decode() {
94 let decoder = MockDecoder::default();
95
96 let msg = vec![0u8; LEN];
97
98 let mut buf = BytesMut::new();
99
100 buf.reserve(msg.len() + HEADER_SIZE);
101 buf.put_u8(0);
102 buf.put_u32(msg.len() as u32);
103
104 buf.put(&msg[..]);
105
106 let body = body::MockBody::new(&buf[..], 10005, 0);
107
108 let mut stream = Streaming::new_request(decoder, body, None, None);
109
110 let mut i = 0usize;
111 while let Some(output_msg) = stream.message().await.unwrap() {
112 assert_eq!(output_msg.len(), msg.len());
113 i += 1;
114 }
115 assert_eq!(i, 1);
116 }
117
118 #[tokio::test]
decode_max_message_size_exceeded()119 async fn decode_max_message_size_exceeded() {
120 let decoder = MockDecoder::default();
121
122 let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
123
124 let mut buf = BytesMut::new();
125
126 buf.reserve(msg.len() + HEADER_SIZE);
127 buf.put_u8(0);
128 buf.put_u32(msg.len() as u32);
129
130 buf.put(&msg[..]);
131
132 let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);
133
134 let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));
135
136 let actual = stream.message().await.unwrap_err();
137
138 let expected = Status::new(
139 Code::OutOfRange,
140 format!(
141 "Error, message length too large: found {} bytes, the limit is: {} bytes",
142 msg.len(),
143 MAX_MESSAGE_SIZE
144 ),
145 );
146
147 assert_eq!(actual.code(), expected.code());
148 assert_eq!(actual.message(), expected.message());
149 }
150
151 #[tokio::test]
encode()152 async fn encode() {
153 let encoder = MockEncoder::default();
154
155 let msg = Vec::from(&[0u8; 1024][..]);
156
157 let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
158 let source = tokio_stream::iter(messages);
159
160 let body = encode_server(
161 encoder,
162 source,
163 None,
164 SingleMessageCompressionOverride::default(),
165 None,
166 );
167
168 tokio::pin!(body);
169
170 while let Some(r) = body.data().await {
171 r.unwrap();
172 }
173 }
174
175 #[tokio::test]
encode_max_message_size_exceeded()176 async fn encode_max_message_size_exceeded() {
177 let encoder = MockEncoder::default();
178
179 let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
180
181 let messages = std::iter::once(Ok::<_, Status>(msg));
182 let source = tokio_stream::iter(messages);
183
184 let body = encode_server(
185 encoder,
186 source,
187 None,
188 SingleMessageCompressionOverride::default(),
189 Some(MAX_MESSAGE_SIZE),
190 );
191
192 tokio::pin!(body);
193
194 assert!(body.data().await.is_none());
195 assert_eq!(
196 body.trailers()
197 .await
198 .expect("no error polling trailers")
199 .expect("some trailers")
200 .get("grpc-status")
201 .expect("grpc-status header"),
202 "11"
203 );
204 assert!(body.is_end_stream());
205 }
206
207 // skip on windows because CI stumbles over our 4GB allocation
208 #[cfg(not(target_family = "windows"))]
209 #[tokio::test]
encode_too_big()210 async fn encode_too_big() {
211 let encoder = MockEncoder::default();
212
213 let msg = vec![0u8; u32::MAX as usize + 1];
214
215 let messages = std::iter::once(Ok::<_, Status>(msg));
216 let source = tokio_stream::iter(messages);
217
218 let body = encode_server(
219 encoder,
220 source,
221 None,
222 SingleMessageCompressionOverride::default(),
223 Some(usize::MAX),
224 );
225
226 tokio::pin!(body);
227
228 assert!(body.data().await.is_none());
229 assert_eq!(
230 body.trailers()
231 .await
232 .expect("no error polling trailers")
233 .expect("some trailers")
234 .get("grpc-status")
235 .expect("grpc-status header"),
236 "8"
237 );
238 assert!(body.is_end_stream());
239 }
240
// Test-only encoder whose items are raw byte vectors (see its `Encoder`
// impl); it carries no state.
#[derive(Debug, Clone, Default)]
struct MockEncoder;
243
244 impl Encoder for MockEncoder {
245 type Item = Vec<u8>;
246 type Error = Status;
247
encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error>248 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
249 buf.put(&item[..]);
250 Ok(())
251 }
252 }
253
// Test-only decoder that yields the buffer's current chunk as a byte vector
// (see its `Decoder` impl); it carries no state.
#[derive(Debug, Clone, Default)]
struct MockDecoder;
256
257 impl Decoder for MockDecoder {
258 type Item = Vec<u8>;
259 type Error = Status;
260
decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>261 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
262 let out = Vec::from(buf.chunk());
263 buf.advance(LEN);
264 Ok(Some(out))
265 }
266 }
267
268 mod body {
269 use crate::Status;
270 use bytes::Bytes;
271 use http_body::Body;
272 use std::{
273 pin::Pin,
274 task::{Context, Poll},
275 };
276
// Test body that hands out `data` in at most two pieces and returns
// `Poll::Pending` on every other poll (see `poll_data`).
#[derive(Debug)]
pub(super) struct MockBody {
    // bytes not yet handed out; shrinks as `poll_data` splits chunks off
    data: Bytes,

    // the size of the partial message to send (used for the first chunk only)
    partial_len: usize,

    // the number of times `poll_data` has been polled while data remained;
    // incremented on pending polls as well as on sends
    count: usize,
}
287
288 impl MockBody {
new(b: &[u8], partial_len: usize, count: usize) -> Self289 pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
290 MockBody {
291 data: Bytes::copy_from_slice(b),
292 partial_len,
293 count,
294 }
295 }
296 }
297
298 impl Body for MockBody {
299 type Data = Bytes;
300 type Error = Status;
301
poll_data( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>302 fn poll_data(
303 mut self: Pin<&mut Self>,
304 cx: &mut Context<'_>,
305 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
306 // every other call to poll_data returns data
307 let should_send = self.count % 2 == 0;
308 let data_len = self.data.len();
309 let partial_len = self.partial_len;
310 let count = self.count;
311 if data_len > 0 {
312 let result = if should_send {
313 let response =
314 self.data
315 .split_to(if count == 0 { partial_len } else { data_len });
316 Poll::Ready(Some(Ok(response)))
317 } else {
318 cx.waker().wake_by_ref();
319 Poll::Pending
320 };
321 // make some fake progress
322 self.count += 1;
323 result
324 } else {
325 Poll::Ready(None)
326 }
327 }
328
poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>>329 fn poll_trailers(
330 self: Pin<&mut Self>,
331 _cx: &mut Context<'_>,
332 ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
333 Poll::Ready(Ok(None))
334 }
335 }
336 }
337 }
338