use futures::{channel::{mpsc, oneshot}, prelude::*};
use libp2p::{
core::{
connection::{ConnectionId, ListenerId},
ConnectedPoint, Multiaddr, PeerId,
},
request_response::{
RequestResponse, RequestResponseCodec, RequestResponseConfig, RequestResponseEvent,
RequestResponseMessage, ResponseChannel, ProtocolSupport
},
swarm::{
protocols_handler::multi::MultiHandler, NetworkBehaviour, NetworkBehaviourAction,
PollParameters, ProtocolsHandler,
},
};
use std::{
borrow::Cow, collections::{hash_map::Entry, HashMap}, convert::TryFrom as _, io, iter,
pin::Pin, task::{Context, Poll}, time::Duration,
};
pub use libp2p::request_response::{InboundFailure, OutboundFailure, RequestId};
/// Configuration for a single request-response protocol.
///
/// One instance is consumed per protocol by [`RequestResponsesBehaviour::new`].
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
/// Name of the protocol. Its bytes are what gets registered with the underlying
/// `RequestResponse` behaviour, and it is the key under which the protocol is stored.
pub name: Cow<'static, str>,
/// Maximum allowed size, in bytes, of a request.
///
/// The codec refuses to read an incoming request whose declared length exceeds this value.
pub max_request_size: u64,
/// Maximum allowed size, in bytes, of a response.
///
/// The codec refuses to read an incoming response whose declared length exceeds this value.
pub max_response_size: u64,
/// Value passed to `RequestResponseConfig::set_request_timeout` for this protocol.
pub request_timeout: Duration,
/// Channel on which inbound requests are delivered.
///
/// If `Some`, the protocol is registered with `ProtocolSupport::Full`; if `None`, with
/// `ProtocolSupport::Outbound`.
pub inbound_queue: Option<mpsc::Sender<IncomingRequest>>,
}
/// A request coming from a remote, as pushed onto [`ProtocolConfig::inbound_queue`].
#[derive(Debug)]
pub struct IncomingRequest {
/// Peer that sent the request.
pub peer: PeerId,
/// Raw bytes of the request, as read by the codec.
pub payload: Vec<u8>,
/// Channel on which to send back the response bytes.
///
/// If this sender is dropped without a response being sent, the behaviour reports the
/// request as [`ResponseFailure::Busy`].
pub pending_response: oneshot::Sender<Vec<u8>>,
}
/// Event generated by the [`RequestResponsesBehaviour`].
#[derive(Debug)]
pub enum Event {
/// Outcome of handling a request sent by a remote.
InboundRequest {
/// Peer which has emitted the request.
peer: PeerId,
/// Name of the protocol in question.
protocol: Cow<'static, str>,
/// If `Err`, the reason why no response could be delivered.
///
/// NOTE(review): nothing in this file constructs the `Ok` case; the `Duration` is
/// presumably a processing time — confirm against the producer.
result: Result<Duration, ResponseFailure>,
},
/// A request that was started with [`RequestResponsesBehaviour::send_request`] has either
/// received a response or failed.
RequestFinished {
/// Identifier of the request, as reported by the underlying `RequestResponse` behaviour.
request_id: RequestId,
/// Response bytes sent by the remote, or the reason why the request failed.
result: Result<Vec<u8>, RequestFailure>,
},
}
/// Network behaviour that combines several request-response protocols, each handled by its own
/// libp2p `RequestResponse` instance.
pub struct RequestResponsesBehaviour {
/// Registered protocols, keyed by name. Each entry holds the `RequestResponse` behaviour for
/// that protocol and, if inbound requests are supported, the sender on which they are
/// forwarded.
protocols: HashMap<
Cow<'static, str>,
(RequestResponse<GenericCodec>, Option<mpsc::Sender<IncomingRequest>>)
>,
/// One future per inbound request that is waiting for a locally produced response.
pending_responses: stream::FuturesUnordered<
Pin<Box<dyn Future<Output = RequestProcessingOutcome> + Send>>
>,
}
/// Output of the futures stored in `RequestResponsesBehaviour::pending_responses`.
enum RequestProcessingOutcome {
/// A response was received on the oneshot channel and must now be sent to the remote.
Response {
/// Name of the protocol the request arrived on; used to find the right behaviour.
protocol: Cow<'static, str>,
/// libp2p channel on which the response has to be sent.
inner_channel: ResponseChannel<Result<Vec<u8>, ()>>,
/// Response bytes to send.
response: Vec<u8>,
},
/// The oneshot sender was dropped without a response having been sent.
Busy {
/// Peer that sent the request.
peer: PeerId,
/// Name of the protocol the request arrived on.
protocol: Cow<'static, str>,
},
}
impl RequestResponsesBehaviour {
pub fn new(list: impl Iterator<Item = ProtocolConfig>) -> Result<Self, RegisterError> {
let mut protocols = HashMap::new();
for protocol in list {
let mut cfg = RequestResponseConfig::default();
cfg.set_connection_keep_alive(Duration::from_secs(10));
cfg.set_request_timeout(protocol.request_timeout);
let protocol_support = if protocol.inbound_queue.is_some() {
ProtocolSupport::Full
} else {
ProtocolSupport::Outbound
};
let rq_rp = RequestResponse::new(GenericCodec {
max_request_size: protocol.max_request_size,
max_response_size: protocol.max_response_size,
}, iter::once((protocol.name.as_bytes().to_vec(), protocol_support)), cfg);
match protocols.entry(protocol.name) {
Entry::Vacant(e) => e.insert((rq_rp, protocol.inbound_queue)),
Entry::Occupied(e) =>
return Err(RegisterError::DuplicateProtocol(e.key().clone())),
};
}
Ok(Self {
protocols,
pending_responses: stream::FuturesUnordered::new(),
})
}
pub fn send_request(&mut self, target: &PeerId, protocol: &str, request: Vec<u8>)
-> Result<RequestId, SendRequestError>
{
if let Some((protocol, _)) = self.protocols.get_mut(protocol) {
if protocol.is_connected(target) {
Ok(protocol.send_request(target, request))
} else {
Err(SendRequestError::NotConnected)
}
} else {
Err(SendRequestError::UnknownProtocol)
}
}
}
/// `NetworkBehaviour` implementation that fans every notification out to the per-protocol
/// `RequestResponse` behaviours and translates their events into [`Event`].
impl NetworkBehaviour for RequestResponsesBehaviour {
/// One sub-handler per registered protocol, keyed by the protocol name as a `String`.
type ProtocolsHandler = MultiHandler<
String,
<RequestResponse<GenericCodec> as NetworkBehaviour>::ProtocolsHandler,
>;
type OutEvent = Event;
/// Builds a `MultiHandler` containing a fresh handler from each registered protocol.
fn new_handler(&mut self) -> Self::ProtocolsHandler {
let iter = self.protocols.iter_mut()
.map(|(p, (r, _))| (p.to_string(), NetworkBehaviour::new_handler(r)));
MultiHandler::try_from_iter(iter)
.expect("Protocols are in a HashMap and there can be at most one handler per \
protocol name, which is the only possible error; qed")
}
/// This behaviour never provides addresses for a peer.
fn addresses_of_peer(&mut self, _: &PeerId) -> Vec<Multiaddr> {
Vec::new()
}
/// Forwards the notification to every registered protocol.
fn inject_connection_established(
&mut self,
peer_id: &PeerId,
conn: &ConnectionId,
endpoint: &ConnectedPoint,
) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_connection_established(p, peer_id, conn, endpoint)
}
}
/// Forwards the notification to every registered protocol.
fn inject_connected(&mut self, peer_id: &PeerId) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_connected(p, peer_id)
}
}
/// Forwards the notification to every registered protocol.
fn inject_connection_closed(&mut self, peer_id: &PeerId, conn: &ConnectionId, endpoint: &ConnectedPoint) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_connection_closed(p, peer_id, conn, endpoint)
}
}
/// Forwards the notification to every registered protocol.
fn inject_disconnected(&mut self, peer_id: &PeerId) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_disconnected(p, peer_id)
}
}
/// Forwards the notification to every registered protocol.
fn inject_addr_reach_failure(
&mut self,
peer_id: Option<&PeerId>,
addr: &Multiaddr,
error: &dyn std::error::Error
) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_addr_reach_failure(p, peer_id, addr, error)
}
}
/// Routes a handler event to the behaviour of the protocol named in the event.
///
/// Events tagged with a name that is not registered are logged and dropped.
fn inject_event(
&mut self,
peer_id: PeerId,
connection: ConnectionId,
(p_name, event): <Self::ProtocolsHandler as ProtocolsHandler>::OutEvent,
) {
if let Some((proto, _)) = self.protocols.get_mut(&*p_name) {
return proto.inject_event(peer_id, connection, event)
}
log::warn!(target: "sub-libp2p",
"inject_node_event: no request-response instance registered for protocol {:?}",
p_name)
}
/// Forwards the notification to every registered protocol.
fn inject_new_external_addr(&mut self, addr: &Multiaddr) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_new_external_addr(p, addr)
}
}
/// Forwards the notification to every registered protocol.
fn inject_expired_listen_addr(&mut self, addr: &Multiaddr) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_expired_listen_addr(p, addr)
}
}
/// Forwards the notification to every registered protocol.
fn inject_dial_failure(&mut self, peer_id: &PeerId) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_dial_failure(p, peer_id)
}
}
/// Forwards the notification to every registered protocol.
fn inject_new_listen_addr(&mut self, addr: &Multiaddr) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_new_listen_addr(p, addr)
}
}
/// Forwards the notification to every registered protocol.
fn inject_listener_error(&mut self, id: ListenerId, err: &(dyn std::error::Error + 'static)) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_listener_error(p, id, err)
}
}
/// Forwards the notification to every registered protocol.
fn inject_listener_closed(&mut self, id: ListenerId, reason: Result<(), &io::Error>) {
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::inject_listener_closed(p, id, reason)
}
}
/// Drives the pending local responses and every per-protocol behaviour.
///
/// Returns as soon as there is an action to report; returns `Poll::Pending` once neither the
/// pending responses nor any protocol has anything ready.
fn poll(
&mut self,
cx: &mut Context,
params: &mut impl PollParameters,
) -> Poll<
NetworkBehaviourAction<
<Self::ProtocolsHandler as ProtocolsHandler>::InEvent,
Self::OutEvent,
>,
> {
// The loop is restarted from the top each time a new pending-response future is pushed, so
// that the freshly pushed future gets polled by the `while let` just below.
'poll_all: loop {
// First, process the outcomes of the inbound requests that were waiting for a response.
while let Poll::Ready(Some(result)) = self.pending_responses.poll_next_unpin(cx) {
match result {
RequestProcessingOutcome::Response {
protocol, inner_channel, response
} => {
// If the protocol is not registered, the response is silently discarded.
if let Some((protocol, _)) = self.protocols.get_mut(&*protocol) {
protocol.send_response(inner_channel, Ok(response));
}
}
RequestProcessingOutcome::Busy { peer, protocol } => {
let out = Event::InboundRequest {
peer,
protocol,
result: Err(ResponseFailure::Busy),
};
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(out));
}
}
}
// Then, poll the behaviour of each registered protocol.
for (protocol, (behaviour, resp_builder)) in &mut self.protocols {
while let Poll::Ready(ev) = behaviour.poll(cx, params) {
// Every action other than `GenerateEvent` is forwarded (returned) right away.
let ev = match ev {
NetworkBehaviourAction::GenerateEvent(ev) => ev,
// Dialing actions are unexpected and logged as errors, but still forwarded.
NetworkBehaviourAction::DialAddress { address } => {
log::error!("The request-response isn't supposed to start dialing peers");
return Poll::Ready(NetworkBehaviourAction::DialAddress { address })
}
NetworkBehaviourAction::DialPeer { peer_id, condition } => {
log::error!("The request-response isn't supposed to start dialing peers");
return Poll::Ready(NetworkBehaviourAction::DialPeer {
peer_id,
condition,
})
}
// Handler events are tagged with the protocol name, matching the
// `(String, _)` shape used by the `MultiHandler`.
NetworkBehaviourAction::NotifyHandler {
peer_id,
handler,
event,
} => {
return Poll::Ready(NetworkBehaviourAction::NotifyHandler {
peer_id,
handler,
event: ((*protocol).to_string(), event),
})
}
NetworkBehaviourAction::ReportObservedAddr { address } => {
return Poll::Ready(NetworkBehaviourAction::ReportObservedAddr {
address,
})
}
};
match ev {
// A remote has sent us a request.
RequestResponseEvent::Message {
peer,
message: RequestResponseMessage::Request { request, channel, .. },
} => {
let (tx, rx) = oneshot::channel();
// Hand the request over to the inbound queue. If there is no queue, or if
// `try_send` fails, `tx` is dropped, which makes `rx` below resolve to an
// error and the request be reported as `Busy`.
if let Some(resp_builder) = resp_builder {
let _ = resp_builder.try_send(IncomingRequest {
peer: peer.clone(),
payload: request,
pending_response: tx,
});
}
let protocol = protocol.clone();
self.pending_responses.push(Box::pin(async move {
if let Ok(response) = rx.await {
RequestProcessingOutcome::Response {
protocol, inner_channel: channel, response
}
} else {
RequestProcessingOutcome::Busy { peer, protocol }
}
}));
// Go back to the top so that the future that was just pushed is polled.
continue 'poll_all;
}
// A remote has answered one of our requests. An `Err(())` response is
// reported as `RequestFailure::Refused`.
RequestResponseEvent::Message {
message:
RequestResponseMessage::Response {
request_id,
response,
},
..
} => {
let out = Event::RequestFinished {
request_id,
result: response.map_err(|()| RequestFailure::Refused),
};
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(out));
}
// One of our requests has failed.
RequestResponseEvent::OutboundFailure {
request_id,
error,
..
} => {
let out = Event::RequestFinished {
request_id,
result: Err(RequestFailure::Network(error)),
};
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(out));
}
// Handling of a remote's request has failed.
RequestResponseEvent::InboundFailure { peer, error, .. } => {
let out = Event::InboundRequest {
peer,
protocol: protocol.clone(),
result: Err(ResponseFailure::Network(error)),
};
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(out));
}
};
}
}
// Nothing was ready, either in the pending responses or in any protocol.
break Poll::Pending;
}
}
}
/// Error when registering a protocol in [`RequestResponsesBehaviour::new`].
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum RegisterError {
/// A protocol with this name has been passed more than once.
DuplicateProtocol(#[error(ignore)] Cow<'static, str>),
}
/// Error returned by [`RequestResponsesBehaviour::send_request`].
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum SendRequestError {
/// The protocol's behaviour reports that we are not connected to the target peer.
NotConnected,
/// The given protocol name has not been registered.
UnknownProtocol,
}
/// Reason why a request that we sent did not yield a response.
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum RequestFailure {
/// The remote answered with an error response, which the codec produces when the substream
/// reaches end-of-file before the length prefix of a response could be read.
Refused,
/// The underlying libp2p behaviour reported an outbound failure.
#[display(fmt = "Problem on the network")]
Network(#[error(ignore)] OutboundFailure),
}
/// Reason why a request received from a remote did not get a response.
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum ResponseFailure {
/// The sender for the response was dropped without a response having been sent, for example
/// because the request could not be pushed onto the inbound queue.
Busy,
/// The underlying libp2p behaviour reported an inbound failure.
#[display(fmt = "Problem on the network")]
Network(#[error(ignore)] InboundFailure),
}
/// Codec used by every registered protocol: requests and responses are opaque byte vectors.
///
/// Public only because it appears in the `NetworkBehaviour` associated types; hidden from the
/// documentation.
#[derive(Debug, Clone)]
#[doc(hidden)]
pub struct GenericCodec {
/// Maximum accepted length, in bytes, of a request being read.
max_request_size: u64,
/// Maximum accepted length, in bytes, of a response being read.
max_response_size: u64,
}
#[async_trait::async_trait]
impl RequestResponseCodec for GenericCodec {
    type Protocol = Vec<u8>;
    type Request = Vec<u8>;
    type Response = Result<Vec<u8>, ()>;

    /// Reads a request: an unsigned-varint length prefix followed by that many bytes.
    ///
    /// Fails with `InvalidInput` if the prefix cannot be read or if the announced length is
    /// larger than `max_request_size`.
    async fn read_request<T>(
        &mut self,
        _: &Self::Protocol,
        mut io: &mut T,
    ) -> io::Result<Self::Request>
    where
        T: AsyncRead + Unpin + Send,
    {
        let length = match unsigned_varint::aio::read_usize(&mut io).await {
            Ok(length) => length,
            Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidInput, err)),
        };

        // A limit that doesn't fit in a `usize` can never be exceeded by a `usize` length.
        let limit = usize::try_from(self.max_request_size).unwrap_or(usize::max_value());
        if length > limit {
            let message =
                format!("Request size exceeds limit: {} > {}", length, self.max_request_size);
            return Err(io::Error::new(io::ErrorKind::InvalidInput, message));
        }

        // The allocation is done only after the length has been checked against the limit.
        let mut payload = vec![0; length];
        io.read_exact(&mut payload).await?;
        Ok(payload)
    }

    /// Reads a response: an unsigned-varint length prefix followed by that many bytes.
    ///
    /// Reaching end-of-file while reading the prefix is not an I/O error: it yields
    /// `Ok(Err(()))`. Any other prefix error, or an announced length larger than
    /// `max_response_size`, fails with `InvalidInput`.
    async fn read_response<T>(
        &mut self,
        _: &Self::Protocol,
        mut io: &mut T,
    ) -> io::Result<Self::Response>
    where
        T: AsyncRead + Unpin + Send,
    {
        let length = match unsigned_varint::aio::read_usize(&mut io).await {
            Ok(length) => length,
            Err(unsigned_varint::io::ReadError::Io(err))
                if err.kind() == io::ErrorKind::UnexpectedEof =>
            {
                // The substream ended before a length prefix could be read.
                return Ok(Err(()));
            }
            Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidInput, err)),
        };

        let limit = usize::try_from(self.max_response_size).unwrap_or(usize::max_value());
        if length > limit {
            let message =
                format!("Response size exceeds limit: {} > {}", length, self.max_response_size);
            return Err(io::Error::new(io::ErrorKind::InvalidInput, message));
        }

        let mut payload = vec![0; length];
        io.read_exact(&mut payload).await?;
        Ok(Ok(payload))
    }

    /// Writes a request as an unsigned-varint length prefix followed by the payload, then
    /// closes the writing side.
    async fn write_request<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        req: Self::Request,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        {
            let mut prefix_storage = unsigned_varint::encode::usize_buffer();
            let prefix = unsigned_varint::encode::usize(req.len(), &mut prefix_storage);
            io.write_all(prefix).await?;
        }

        io.write_all(&req).await?;
        io.close().await
    }

    /// Writes a response, then closes the writing side.
    ///
    /// An `Ok` response is written as an unsigned-varint length prefix followed by the
    /// payload. For an `Err(())` response nothing is written before closing.
    async fn write_response<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        res: Self::Response,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        match res {
            Ok(payload) => {
                {
                    let mut prefix_storage = unsigned_varint::encode::usize_buffer();
                    let prefix =
                        unsigned_varint::encode::usize(payload.len(), &mut prefix_storage);
                    io.write_all(prefix).await?;
                }
                io.write_all(&payload).await?;
            }
            Err(()) => {}
        }

        io.close().await
    }
}
// Tests running two in-memory swarms against each other.
#[cfg(test)]
mod tests {
use futures::{channel::mpsc, prelude::*};
use libp2p::identity::Keypair;
use libp2p::Multiaddr;
use libp2p::core::upgrade;
use libp2p::core::transport::{Transport, MemoryTransport};
use libp2p::noise;
use libp2p::swarm::{Swarm, SwarmEvent};
use std::{iter, time::Duration};
// One node sends a request over the test protocol and checks that it receives the expected
// response.
#[test]
fn basic_request_response_works() {
let protocol_name = "/test/req-rep/1";
// Build two swarms, each with its own identity and each listening on a random in-memory
// address.
let mut swarms = (0..2)
.map(|_| {
let keypair = Keypair::generate_ed25519();
let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
.into_authentic(&keypair)
.unwrap();
let transport = MemoryTransport
.upgrade(upgrade::Version::V1)
.authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated())
.multiplex(libp2p::yamux::Config::default());
let behaviour = {
let (tx, mut rx) = mpsc::channel(64);
let b = super::RequestResponsesBehaviour::new(iter::once(super::ProtocolConfig {
name: From::from(protocol_name),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
})).unwrap();
// Background task answering every inbound request with a fixed response.
async_std::task::spawn(async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(b"this is a response".to_vec());
}
});
b
};
let mut swarm = Swarm::new(transport, behaviour, keypair.public().into_peer_id());
let listen_addr: Multiaddr = format!("/memory/{}", rand::random::<u64>()).parse().unwrap();
Swarm::listen_on(&mut swarm, listen_addr.clone()).unwrap();
(swarm, listen_addr)
})
.collect::<Vec<_>>();
// Ask the first swarm to dial the second one.
{
let dial_addr = swarms[1].1.clone();
Swarm::dial_addr(&mut swarms[0].0, dial_addr).unwrap();
}
// Drive the first swarm (the dialer) in the background until it reports an
// `InboundRequest` event.
async_std::task::spawn({
let (mut swarm, _) = swarms.remove(0);
async move {
loop {
match swarm.next_event().await {
SwarmEvent::Behaviour(super::Event::InboundRequest { result, .. }) => {
assert!(result.is_ok());
break
},
_ => {}
}
}
}
});
// The remaining swarm sends a request as soon as a connection is established, then waits
// for the matching `RequestFinished` event.
let (mut swarm, _) = swarms.remove(0);
async_std::task::block_on(async move {
let mut sent_request_id = None;
loop {
match swarm.next_event().await {
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let id = swarm.send_request(
&peer_id,
protocol_name,
b"this is a request".to_vec()
).unwrap();
// Only a single request is expected to be sent.
assert!(sent_request_id.is_none());
sent_request_id = Some(id);
}
SwarmEvent::Behaviour(super::Event::RequestFinished {
request_id,
result,
}) => {
assert_eq!(Some(request_id), sent_request_id);
let result = result.unwrap();
assert_eq!(result, b"this is a response");
break;
}
_ => {}
}
}
});
}
// Same setup as above, except that responses are limited to 8 bytes while the answer sent is
// longer: the request is expected to fail with `OutboundFailure::ConnectionClosed`.
#[test]
fn max_response_size_exceeded() {
let protocol_name = "/test/req-rep/1";
// Build two swarms, each with its own identity and each listening on a random in-memory
// address.
let mut swarms = (0..2)
.map(|_| {
let keypair = Keypair::generate_ed25519();
let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
.into_authentic(&keypair)
.unwrap();
let transport = MemoryTransport
.upgrade(upgrade::Version::V1)
.authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated())
.multiplex(libp2p::yamux::Config::default());
let behaviour = {
let (tx, mut rx) = mpsc::channel(64);
let b = super::RequestResponsesBehaviour::new(iter::once(super::ProtocolConfig {
name: From::from(protocol_name),
max_request_size: 1024,
max_response_size: 8,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
})).unwrap();
// Background task answering every inbound request with a response longer than the
// 8 bytes configured above.
async_std::task::spawn(async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(b"this response exceeds the limit".to_vec());
}
});
b
};
let mut swarm = Swarm::new(transport, behaviour, keypair.public().into_peer_id());
let listen_addr: Multiaddr = format!("/memory/{}", rand::random::<u64>()).parse().unwrap();
Swarm::listen_on(&mut swarm, listen_addr.clone()).unwrap();
(swarm, listen_addr)
})
.collect::<Vec<_>>();
// Ask the first swarm to dial the second one.
{
let dial_addr = swarms[1].1.clone();
Swarm::dial_addr(&mut swarms[0].0, dial_addr).unwrap();
}
// Drive the first swarm (the dialer) in the background until it reports an
// `InboundRequest` event.
async_std::task::spawn({
let (mut swarm, _) = swarms.remove(0);
async move {
loop {
match swarm.next_event().await {
SwarmEvent::Behaviour(super::Event::InboundRequest { result, .. }) => {
assert!(result.is_ok());
break
},
_ => {}
}
}
}
});
// The remaining swarm sends a request as soon as a connection is established, then checks
// how the request finishes.
let (mut swarm, _) = swarms.remove(0);
async_std::task::block_on(async move {
let mut sent_request_id = None;
loop {
match swarm.next_event().await {
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let id = swarm.send_request(
&peer_id,
protocol_name,
b"this is a request".to_vec()
).unwrap();
// Only a single request is expected to be sent.
assert!(sent_request_id.is_none());
sent_request_id = Some(id);
}
SwarmEvent::Behaviour(super::Event::RequestFinished {
request_id,
result,
}) => {
assert_eq!(Some(request_id), sent_request_id);
// Any outcome other than a closed connection fails the test.
match result {
Err(super::RequestFailure::Network(super::OutboundFailure::ConnectionClosed)) => {},
_ => panic!()
}
break;
}
_ => {}
}
}
});
}
}