use std::{ fmt, future::Future, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// [`Service`] returned by the [`map_future`] combinator. /// /// [`map_future`]: crate::util::ServiceExt::map_future #[derive(Clone)] pub struct MapFuture { inner: S, f: F, } impl MapFuture { /// Creates a new [`MapFuture`] service. pub fn new(inner: S, f: F) -> Self { Self { inner, f } } /// Returns a new [`Layer`] that produces [`MapFuture`] services. /// /// This is a convenience function that simply calls [`MapFutureLayer::new`]. /// /// [`Layer`]: tower_layer::Layer pub fn layer(f: F) -> MapFutureLayer { MapFutureLayer::new(f) } /// Get a reference to the inner service pub fn get_ref(&self) -> &S { &self.inner } /// Get a mutable reference to the inner service pub fn get_mut(&mut self) -> &mut S { &mut self.inner } /// Consume `self`, returning the inner service pub fn into_inner(self) -> S { self.inner } } impl Service for MapFuture where S: Service, F: FnMut(S::Future) -> Fut, E: From, Fut: Future>, { type Response = T; type Error = E; type Future = Fut; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx).map_err(From::from) } fn call(&mut self, req: R) -> Self::Future { (self.f)(self.inner.call(req)) } } impl fmt::Debug for MapFuture where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapFuture") .field("inner", &self.inner) .field("f", &format_args!("{}", std::any::type_name::())) .finish() } } /// A [`Layer`] that produces a [`MapFuture`] service. /// /// [`Layer`]: tower_layer::Layer #[derive(Clone)] pub struct MapFutureLayer { f: F, } impl MapFutureLayer { /// Creates a new [`MapFutureLayer`] layer. pub fn new(f: F) -> Self { Self { f } } } impl Layer for MapFutureLayer where F: Clone, { type Service = MapFuture; fn layer(&self, inner: S) -> Self::Service { MapFuture::new(inner, self.f.clone()) } } impl fmt::Debug for MapFutureLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapFutureLayer") .field("f", &format_args!("{}", std::any::type_name::())) .finish() } }