mod codec;
use std::{cmp, iter, mem, pin::Pin, task::Context, task::Poll};
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::sync::Arc;
use std::task::Waker;
use bytes::Bytes;
use libp2p_core::{
Endpoint,
StreamMuxer,
muxing::StreamMuxerEvent,
upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
};
use log::{debug, trace};
use parking_lot::Mutex;
use fnv::FnvHashSet;
use futures::{prelude::*, future, ready, stream::Fuse};
use futures::task::{ArcWake, waker_ref};
use futures_codec::Framed;
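/// Configuration of the mplex multiplexer.
///
/// Implements `UpgradeInfo`, `InboundUpgrade` and `OutboundUpgrade`, so it can be used
/// as a connection upgrade that yields a [`Multiplex`].
///
/// A minimal configuration sketch (assuming the items of this module are in scope):
///
/// ```ignore
/// let mut config = MplexConfig::new();
/// config
///     .max_substreams(256)
///     .max_buffer_len(8192)
///     .max_buffer_len_behaviour(MaxBufferBehaviour::Block);
/// ```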
#[derive(Debug, Clone)]
pub struct MplexConfig {
max_substreams: usize,
max_buffer_len: usize,
max_buffer_behaviour: MaxBufferBehaviour,
split_send_size: usize,
}
impl MplexConfig {
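/// Builds the default configuration.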
pub fn new() -> MplexConfig {
Default::default()
}
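/// Sets the maximum number of simultaneously open substreams.
///
/// Once the limit is reached, polling the connection for new inbound substreams
/// produces an error. Default: 128.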
pub fn max_substreams(&mut self, max: usize) -> &mut Self {
self.max_substreams = max;
self
}
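/// Sets the maximum number of frames buffered for substreams that are not read from.
///
/// What happens once the limit is reached is determined by `max_buffer_len_behaviour`.
/// Default: 4096.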
pub fn max_buffer_len(&mut self, max: usize) -> &mut Self {
self.max_buffer_len = max;
self
}
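/// Sets the behaviour once the maximum buffer length has been reached.
///
/// Default: `MaxBufferBehaviour::CloseAll`.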
pub fn max_buffer_len_behaviour(&mut self, behaviour: MaxBufferBehaviour) -> &mut Self {
self.max_buffer_behaviour = behaviour;
self
}
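/// Sets the maximum number of bytes sent in a single data frame; larger writes are
/// split into multiple frames. The value is capped at `codec::MAX_FRAME_SIZE`.
///
/// Default: 1024.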
pub fn split_send_size(&mut self, size: usize) -> &mut Self {
let size = cmp::min(size, codec::MAX_FRAME_SIZE);
self.split_send_size = size;
self
}
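/// Turns the I/O resource `i` into a `Multiplex`, building the shared state used by
/// all substreams of the connection.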
fn upgrade<C>(self, i: C) -> Multiplex<C>
where
C: AsyncRead + AsyncWrite + Unpin
{
let max_buffer_len = self.max_buffer_len;
Multiplex {
inner: Mutex::new(MultiplexInner {
error: Ok(()),
inner: Framed::new(i, codec::Codec::new()).fuse(),
config: self,
buffer: Vec::with_capacity(cmp::min(max_buffer_len, 512)),
opened_substreams: Default::default(),
next_outbound_stream_id: 0,
notifier_read: Arc::new(Notifier {
to_wake: Mutex::new(Default::default()),
}),
notifier_write: Arc::new(Notifier {
to_wake: Mutex::new(Default::default()),
}),
is_shutdown: false,
})
}
}
}
impl Default for MplexConfig {
fn default() -> MplexConfig {
MplexConfig {
max_substreams: 128,
max_buffer_len: 4096,
max_buffer_behaviour: MaxBufferBehaviour::CloseAll,
split_send_size: 1024,
}
}
}
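/// Behaviour when the maximum length of the buffer is reached.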
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MaxBufferBehaviour {
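/// Produce an error on the connection, effectively closing it for all substreams.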
CloseAll,
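/// Stop reading frames from the underlying connection until an element from the
/// buffer has been consumed.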
Block,
}
impl UpgradeInfo for MplexConfig {
type Info = &'static [u8];
type InfoIter = iter::Once<Self::Info>;
fn protocol_info(&self) -> Self::InfoIter {
iter::once(b"/mplex/6.7.0")
}
}
impl<C> InboundUpgrade<C> for MplexConfig
where
C: AsyncRead + AsyncWrite + Unpin,
{
type Output = Multiplex<C>;
type Error = IoError;
type Future = future::Ready<Result<Self::Output, IoError>>;
fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future {
future::ready(Ok(self.upgrade(socket)))
}
}
impl<C> OutboundUpgrade<C> for MplexConfig
where
C: AsyncRead + AsyncWrite + Unpin,
{
type Output = Multiplex<C>;
type Error = IoError;
type Future = future::Ready<Result<Self::Output, IoError>>;
fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future {
future::ready(Ok(self.upgrade(socket)))
}
}
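/// Multiplexer. Implements the [`StreamMuxer`] trait.
///
/// The internal state is behind a mutex so that the muxer can be used from multiple tasks.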
pub struct Multiplex<C> {
inner: Mutex<MultiplexInner<C>>,
}
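/// Shared state of the multiplexer, protected by the mutex in [`Multiplex`].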
struct MultiplexInner<C> {
error: Result<(), IoError>,
inner: Fuse<Framed<C, codec::Codec>>,
config: MplexConfig,
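/// Frames read from the connection but not yet delivered to the substream they belong to.
/// Its length never exceeds `config.max_buffer_len`.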
buffer: Vec<codec::Elem>,
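/// Substreams currently open, identified by their id and by the endpoint that opened them.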
opened_substreams: FnvHashSet<(u32, Endpoint)>,
next_outbound_stream_id: u32,
notifier_read: Arc<Notifier>,
notifier_write: Arc<Notifier>,
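/// True once `close` has completed; all further reads and writes fail.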
is_shutdown: bool,
}
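/// Registry of tasks to wake up when the underlying connection makes progress on
/// reading or writing.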
struct Notifier {
to_wake: Mutex<Vec<Waker>>,
}
impl Notifier {
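/// Registers `waker` to be woken, skipping wakers that would wake the same task.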
fn insert(&self, waker: &Waker) {
let mut to_wake = self.to_wake.lock();
if to_wake.iter().all(|w| !w.will_wake(waker)) {
to_wake.push(waker.clone());
}
}
}
impl ArcWake for Notifier {
fn wake_by_ref(arc_self: &Arc<Self>) {
let wakers = mem::take(&mut *arc_self.to_wake.lock());
for waker in wakers {
waker.wake();
}
}
}
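/// Processes incoming frames until one matching `filter` is found, starting with the
/// frames already buffered. Non-matching frames belonging to open substreams are
/// appended to the buffer, subject to `max_buffer_len` and `max_buffer_behaviour`.
/// Returns an error if the connection has errored or reached EOF.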
fn next_match<C, F, O>(inner: &mut MultiplexInner<C>, cx: &mut Context<'_>, mut filter: F) -> Poll<Result<O, IoError>>
where C: AsyncRead + AsyncWrite + Unpin,
F: FnMut(&codec::Elem) -> Option<O>,
{
if let Err(ref err) = inner.error {
return Poll::Ready(Err(IoError::new(err.kind(), err.to_string())));
}
if let Some((offset, out)) = inner.buffer.iter().enumerate().filter_map(|(n, v)| filter(v).map(|v| (n, v))).next() {
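// The buffer was at capacity; removing the matched element frees a slot, so wake any
// task that was blocked by `MaxBufferBehaviour::Block`.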
if inner.buffer.len() == inner.config.max_buffer_len {
ArcWake::wake_by_ref(&inner.notifier_read);
}
inner.buffer.remove(offset);
return Poll::Ready(Ok(out));
}
loop {
debug_assert!(inner.buffer.len() <= inner.config.max_buffer_len);
if inner.buffer.len() == inner.config.max_buffer_len {
debug!("Reached mplex maximum buffer length");
match inner.config.max_buffer_behaviour {
MaxBufferBehaviour::CloseAll => {
inner.error = Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length"));
return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length")));
},
MaxBufferBehaviour::Block => {
inner.notifier_read.insert(cx.waker());
return Poll::Pending
},
}
}
inner.notifier_read.insert(cx.waker());
let elem = match Stream::poll_next(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_read))) {
Poll::Ready(Some(Ok(item))) => item,
Poll::Ready(None) => return Poll::Ready(Err(IoErrorKind::BrokenPipe.into())),
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Err(err))) => {
let err2 = IoError::new(err.kind(), err.to_string());
inner.error = Err(err);
return Poll::Ready(Err(err2));
},
};
trace!("Received message: {:?}", elem);
match elem {
codec::Elem::Open { substream_id } => {
if !inner.opened_substreams.insert((substream_id, Endpoint::Listener)) {
debug!("Received open message for substream {} which was already open", substream_id)
}
}
codec::Elem::Close { substream_id, endpoint, .. } | codec::Elem::Reset { substream_id, endpoint, .. } => {
inner.opened_substreams.remove(&(substream_id, !endpoint));
}
_ => ()
}
if let Some(out) = filter(&elem) {
return Poll::Ready(Ok(out));
} else {
let endpoint = elem.endpoint().unwrap_or(Endpoint::Dialer);
if inner.opened_substreams.contains(&(elem.substream_id(), !endpoint)) || elem.is_open_msg() {
inner.buffer.push(elem);
} else if !elem.is_close_or_reset_msg() {
debug!("Ignored message {:?} because the substream wasn't open", elem);
}
}
}
}
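/// Attempts to send `elem` on the underlying connection, registering the current task
/// with the write notifier if the sink is not ready.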
fn poll_send<C>(inner: &mut MultiplexInner<C>, cx: &mut Context<'_>, elem: codec::Elem) -> Poll<Result<(), IoError>>
where C: AsyncRead + AsyncWrite + Unpin
{
ensure_no_error_no_close(inner)?;
inner.notifier_write.insert(cx.waker());
match Sink::poll_ready(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) {
Poll::Ready(Ok(())) => {
match Sink::start_send(Pin::new(&mut inner.inner), elem) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(err))
}
},
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
inner.error = Err(IoError::new(err.kind(), err.to_string()));
Poll::Ready(Err(err))
}
}
}
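/// Returns an error if the connection has been shut down or has previously
/// encountered an error.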
fn ensure_no_error_no_close<C>(inner: &mut MultiplexInner<C>) -> Result<(), IoError>
where
C: AsyncRead + AsyncWrite + Unpin
{
if inner.is_shutdown {
return Err(IoError::new(IoErrorKind::Other, "connection is shut down"))
}
if let Err(ref e) = inner.error {
return Err(IoError::new(e.kind(), e.to_string()))
}
Ok(())
}
impl<C> StreamMuxer for Multiplex<C>
where C: AsyncRead + AsyncWrite + Unpin
{
type Substream = Substream;
type OutboundSubstream = OutboundSubstream;
type Error = IoError;
fn poll_event(&self, cx: &mut Context<'_>) -> Poll<Result<StreamMuxerEvent<Self::Substream>, IoError>> {
let mut inner = self.inner.lock();
if inner.opened_substreams.len() >= inner.config.max_substreams {
debug!("Refused substream; reached maximum number of substreams {}", inner.config.max_substreams);
return Poll::Ready(Err(IoError::new(IoErrorKind::ConnectionRefused,
"exceeded maximum number of open substreams")));
}
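// Wait for the next `Open` frame from the remote; any other frame is buffered for
// the substream it belongs to.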
let num = ready!(next_match(&mut inner, cx, |elem| {
match elem {
codec::Elem::Open { substream_id } => Some(*substream_id),
_ => None,
}
}));
let num = match num {
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};
debug!("Successfully opened inbound substream {}", num);
Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(Substream {
current_data: Bytes::new(),
num,
endpoint: Endpoint::Listener,
local_open: true,
remote_open: true,
})))
}
fn open_outbound(&self) -> Self::OutboundSubstream {
let mut inner = self.inner.lock();
let substream_id = {
let n = inner.next_outbound_stream_id;
inner.next_outbound_stream_id = inner.next_outbound_stream_id.checked_add(1)
.expect("Mplex substream ID overflowed");
n
};
inner.opened_substreams.insert((substream_id, Endpoint::Dialer));
OutboundSubstream {
num: substream_id,
state: OutboundSubstreamState::SendElem(codec::Elem::Open { substream_id }),
}
}
fn poll_outbound(&self, cx: &mut Context<'_>, substream: &mut Self::OutboundSubstream) -> Poll<Result<Self::Substream, IoError>> {
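// Two-step state machine: send the `Open` frame, then flush it, then report the
// substream as open.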
loop {
let mut inner = self.inner.lock();
let polling = match substream.state {
OutboundSubstreamState::SendElem(ref elem) => {
poll_send(&mut inner, cx, elem.clone())
},
OutboundSubstreamState::Flush => {
ensure_no_error_no_close(&mut inner)?;
let inner = &mut *inner;
inner.notifier_write.insert(cx.waker());
Sink::poll_flush(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write)))
},
OutboundSubstreamState::Done => {
panic!("Polling outbound substream after it's been succesfully open");
},
};
match polling {
Poll::Ready(Ok(())) => (),
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => {
debug!("Failed to open outbound substream {}", substream.num);
inner.buffer.retain(|elem| {
elem.substream_id() != substream.num || elem.endpoint() == Some(Endpoint::Dialer)
});
inner.error = Err(IoError::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err));
},
};
drop(inner);
match substream.state {
OutboundSubstreamState::SendElem(_) => {
substream.state = OutboundSubstreamState::Flush;
},
OutboundSubstreamState::Flush => {
debug!("Successfully opened outbound substream {}", substream.num);
substream.state = OutboundSubstreamState::Done;
return Poll::Ready(Ok(Substream {
num: substream.num,
current_data: Bytes::new(),
endpoint: Endpoint::Dialer,
local_open: true,
remote_open: true,
}));
},
OutboundSubstreamState::Done => unreachable!(),
}
}
}
fn destroy_outbound(&self, _substream: Self::OutboundSubstream) {
}
fn read_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream, buf: &mut [u8]) -> Poll<Result<usize, IoError>> {
loop {
if !substream.current_data.is_empty() {
let len = cmp::min(substream.current_data.len(), buf.len());
buf[..len].copy_from_slice(&substream.current_data.split_to(len));
return Poll::Ready(Ok(len));
}
if !substream.remote_open {
return Poll::Ready(Ok(0));
}
let mut inner = self.inner.lock();
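// Look for the next `Data` or `Close` frame addressed to this substream, either in
// the buffer or on the connection.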
let next_data_poll = next_match(&mut inner, cx, |elem| {
match elem {
codec::Elem::Data { substream_id, endpoint, data, .. }
if *substream_id == substream.num && *endpoint != substream.endpoint =>
{
Some(Some(data.clone()))
}
codec::Elem::Close { substream_id, endpoint }
if *substream_id == substream.num && *endpoint != substream.endpoint =>
{
Some(None)
}
_ => None
}
});
match next_data_poll {
Poll::Ready(Ok(Some(data))) => substream.current_data = data,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Ready(Ok(None)) => {
substream.remote_open = false;
return Poll::Ready(Ok(0));
},
Poll::Pending => {
if inner.opened_substreams.contains(&(substream.num, substream.endpoint)) {
return Poll::Pending
} else {
return Poll::Ready(Ok(0))
}
},
}
}
}
fn write_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream, buf: &[u8]) -> Poll<Result<usize, IoError>> {
if !substream.local_open {
return Poll::Ready(Err(IoErrorKind::BrokenPipe.into()));
}
let mut inner = self.inner.lock();
let to_write = cmp::min(buf.len(), inner.config.split_send_size);
let elem = codec::Elem::Data {
substream_id: substream.num,
data: Bytes::copy_from_slice(&buf[..to_write]),
endpoint: substream.endpoint,
};
match poll_send(&mut inner, cx, elem) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(to_write)),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending,
}
}
fn flush_substream(&self, cx: &mut Context<'_>, _substream: &mut Self::Substream) -> Poll<Result<(), IoError>> {
let mut inner = self.inner.lock();
ensure_no_error_no_close(&mut inner)?;
let inner = &mut *inner;
inner.notifier_write.insert(cx.waker());
let result = Sink::poll_flush(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write)));
if let Poll::Ready(Err(err)) = &result {
inner.error = Err(IoError::new(err.kind(), err.to_string()));
}
result
}
fn shutdown_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream) -> Poll<Result<(), IoError>> {
if !sub.local_open {
return Poll::Ready(Ok(()));
}
let elem = codec::Elem::Close {
substream_id: sub.num,
endpoint: sub.endpoint,
};
let mut inner = self.inner.lock();
let result = poll_send(&mut inner, cx, elem);
if let Poll::Ready(Ok(())) = result {
sub.local_open = false;
}
result
}
fn destroy_substream(&self, sub: Self::Substream) {
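// Drop buffered frames that were destined for this substream.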
self.inner.lock().buffer.retain(|elem| {
elem.substream_id() != sub.num || elem.endpoint() == Some(sub.endpoint)
})
}
fn close(&self, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
let inner = &mut *self.inner.lock();
if inner.is_shutdown {
return Poll::Ready(Ok(()))
}
if let Err(ref e) = inner.error {
return Poll::Ready(Err(IoError::new(e.kind(), e.to_string())))
}
inner.notifier_write.insert(cx.waker());
match Sink::poll_close(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write))) {
Poll::Ready(Ok(())) => {
inner.is_shutdown = true;
Poll::Ready(Ok(()))
}
Poll::Ready(Err(err)) => {
inner.error = Err(IoError::new(err.kind(), err.to_string()));
Poll::Ready(Err(err))
}
Poll::Pending => Poll::Pending,
}
}
fn flush_all(&self, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
let inner = &mut *self.inner.lock();
if inner.is_shutdown {
return Poll::Ready(Ok(()))
}
if let Err(ref e) = inner.error {
return Poll::Ready(Err(IoError::new(e.kind(), e.to_string())))
}
inner.notifier_write.insert(cx.waker());
let result = Sink::poll_flush(Pin::new(&mut inner.inner), &mut Context::from_waker(&waker_ref(&inner.notifier_write)));
if let Poll::Ready(Err(err)) = &result {
inner.error = Err(IoError::new(err.kind(), err.to_string()));
}
result
}
}
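/// Active attempt to open an outbound substream.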
pub struct OutboundSubstream {
num: u32,
state: OutboundSubstreamState,
}
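/// State of an outbound substream that is being opened.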
enum OutboundSubstreamState {
SendElem(codec::Elem),
Flush,
Done,
}
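/// Active substream to the remote.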
pub struct Substream {
num: u32,
current_data: Bytes,
endpoint: Endpoint,
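/// True if our writing side of the substream is still open.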
local_open: bool,
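/// True if the remote's writing side of the substream is still open.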
remote_open: bool,
}