mod control;
mod stream;
use crate::{
Config,
DEFAULT_CREDIT,
WindowUpdateMode,
error::ConnectionError,
frame::{self, Frame},
frame::header::{self, CONNECTION_ID, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate},
pause::Pausable
};
use futures::{
channel::{mpsc, oneshot},
future::{self, Either},
prelude::*,
stream::{Fuse, FusedStream}
};
use nohash_hasher::IntMap;
use std::{fmt, sync::Arc, task::{Context, Poll}};
pub use control::Control;
pub use stream::{Packet, State, Stream};
/// Buffer size passed to `mpsc::channel` for both the stream-command and
/// the control-command channel created in `Connection::new`.
const MAX_COMMAND_BACKLOG: usize = 32;
/// Result type used throughout this module; errors are `ConnectionError`s.
type Result<T> = std::result::Result<T, ConnectionError>;
/// The role this endpoint plays on the connection.
///
/// The role fixes the parity of locally allocated stream ids (`Connection::new`
/// starts clients at 1 and servers at 2, advancing by 2). It also decides
/// which stream ids are accepted from the remote (`is_valid_remote_id`).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum Mode {
/// Client role: locally opened streams get odd ids.
Client,
/// Server role: locally opened streams get even ids.
Server
}
/// Randomly chosen connection identifier, used to tag log output.
#[derive(Clone, Copy)]
pub(crate) struct Id(u32);

impl Id {
    /// Creates an `Id` from a freshly generated random `u32`.
    pub(crate) fn random() -> Self {
        let raw: u32 = rand::random();
        Self(raw)
    }
}
impl fmt::Debug for Id {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:08x}", self.0)
}
}
impl fmt::Display for Id {
    /// Renders the id as eight zero-padded lowercase hexadecimal digits,
    /// identical to its `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{:08x}", self.0))
    }
}
/// A multiplexed connection over a socket of type `T`.
///
/// Owns the framed socket, the table of known streams and the receiving ends
/// of the command channels fed by `Control` handles and by `Stream`s.
pub struct Connection<T> {
/// Random id used to tag log messages.
id: Id,
/// Client or server role of this endpoint.
mode: Mode,
/// Shared configuration, also handed to every `Stream`.
config: Arc<Config>,
/// Framed socket I/O, fused so termination can be queried.
socket: Fuse<frame::Io<T>>,
/// Next locally allocated stream id (advanced by 2, see `next_stream_id`).
next_id: u32,
/// Streams currently known to this connection, keyed by id.
streams: IntMap<StreamId, Stream>,
/// Sender cloned into every `Control` handle.
control_sender: mpsc::Sender<ControlCommand>,
/// Receiver of control commands; paused while a shutdown is in progress.
control_receiver: Pausable<mpsc::Receiver<ControlCommand>>,
/// Sender cloned into every `Stream`.
stream_sender: mpsc::Sender<StreamCommand>,
/// Receiver of frames and close requests emitted by streams.
stream_receiver: mpsc::Receiver<StreamCommand>,
/// Scratch list of stream ids to remove, used by `garbage_collect`.
garbage: Vec<StreamId>,
/// Progress of a requested connection shutdown.
shutdown: Shutdown,
/// Set once `next_stream` has observed the end of the connection.
is_closed: bool
}
/// Commands sent to the connection through a `Control` handle.
#[derive(Debug)]
pub(crate) enum ControlCommand {
/// Open a new outbound stream and reply with it (or with an error).
OpenStream(oneshot::Sender<Result<Stream>>),
/// Shut the connection down; the sender is signalled once shutdown completes.
CloseConnection(oneshot::Sender<()>)
}
/// Commands sent to the connection by `Stream`s.
#[derive(Debug)]
pub(crate) enum StreamCommand {
/// Write the given data or window-update frame to the socket.
SendFrame(Frame<Either<Data, WindowUpdate>>),
/// Send a FIN for stream `id`, with the ACK flag set too if `ack` is true.
CloseStream { id: StreamId, ack: bool }
}
/// Outcome of processing one inbound frame, carried out by `Connection::on_frame`.
#[derive(Debug)]
enum Action {
/// Nothing further to do.
None,
/// A new inbound stream; the optional window update is sent before it is returned.
New(Stream, Option<Frame<WindowUpdate>>),
/// Send this window update.
Update(Frame<WindowUpdate>),
/// Send this ping reply.
Ping(Frame<Ping>),
/// Send this stream reset.
Reset(Frame<Data>),
/// Send this go-away frame.
Terminate(Frame<GoAway>)
}
/// Progress of a shutdown requested via `ControlCommand::CloseConnection`.
#[derive(Debug)]
enum Shutdown {
/// No shutdown has been requested.
NotStarted,
/// Shutdown requested; the sender is notified when it completes.
InProgress(oneshot::Sender<()>),
/// Shutdown finished (the go-away frame has been sent).
Complete
}
impl Shutdown {
    /// Returns `true` if no shutdown has been requested yet.
    fn has_not_started(&self) -> bool {
        match self {
            Shutdown::NotStarted => true,
            _ => false,
        }
    }

    /// Returns `true` while a requested shutdown is still pending.
    fn is_in_progress(&self) -> bool {
        match self {
            Shutdown::InProgress(_) => true,
            _ => false,
        }
    }

    /// Returns `true` once the shutdown has finished.
    fn is_complete(&self) -> bool {
        match self {
            Shutdown::Complete => true,
            _ => false,
        }
    }
}
impl<T> fmt::Debug for Connection<T> {
    /// Summarises the connection: its id, mode, number of known streams,
    /// next local stream id and whether it has been closed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let num_streams = self.streams.len();
        let mut builder = f.debug_struct("Connection");
        builder.field("id", &self.id);
        builder.field("mode", &self.mode);
        builder.field("streams", &num_streams);
        builder.field("next_id", &self.next_id);
        builder.field("is_closed", &self.is_closed);
        builder.finish()
    }
}
impl<T> fmt::Display for Connection<T> {
    /// Short form used in log messages, e.g.
    /// `(Connection 0000abcd Client (streams 3))`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (id, mode, count) = (self.id, self.mode, self.streams.len());
        write!(f, "(Connection {} {:?} (streams {}))", id, mode, count)
    }
}
impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
pub fn new(socket: T, cfg: Config, mode: Mode) -> Self {
    // Creates a connection over `socket` with the given configuration and role.
    // Wraps the socket in the frame codec, sets up the stream and control
    // command channels and chooses the first local stream id by role.
    let id = Id::random();
    log::debug!("new connection: {} ({:?})", id, mode);

    // Locally opened streams get odd ids for clients and even ids for servers.
    let first_id = match mode {
        Mode::Client => 1,
        Mode::Server => 2
    };

    let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
    let (control_sender, control_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
    let io = frame::Io::new(id, socket, cfg.max_buffer_size);

    Connection {
        id,
        mode,
        config: Arc::new(cfg),
        socket: io.fuse(),
        next_id: first_id,
        streams: IntMap::default(),
        control_sender,
        control_receiver: Pausable::new(control_receiver),
        stream_sender,
        stream_receiver,
        garbage: Vec::new(),
        shutdown: Shutdown::NotStarted,
        is_closed: false
    }
}
/// Returns a new `Control` handle that sends its commands to this connection.
pub fn control(&self) -> Control {
    let sender = self.control_sender.clone();
    Control::new(sender)
}
/// Returns the next stream opened by the remote.
///
/// Yields `Ok(Some(stream))` for each new inbound stream. When the connection
/// ends, it is marked closed and cleaned up:
/// - pending `OpenStream` requests are answered with `ConnectionError::Closed`;
/// - pending `CloseConnection` requests are acknowledged;
/// - all streams are closed and woken;
/// - queued stream commands are discarded.
///
/// A normal close (`ConnectionError::Closed`) is then reported as `Ok(None)`.
/// Any other error is returned unchanged. After that, every call returns
/// `Ok(None)` immediately.
pub async fn next_stream(&mut self) -> Result<Option<Stream>> {
if self.is_closed {
log::debug!("{}: connection is closed", self.id);
return Ok(None)
}
let result = self.next().await;
if let Ok(Some(_)) = result {
return result
}
// Anything other than a new stream means the connection is finished.
self.is_closed = true;
if !self.control_receiver.stream().is_terminated() {
// Stop accepting control commands and answer the ones already queued.
// The receiver may still be paused by an in-progress shutdown (see
// `on_control_command`); presumably a paused receiver yields nothing,
// hence the unpause before draining.
self.control_receiver.stream().close();
self.control_receiver.unpause();
while let Some(cmd) = self.control_receiver.next().await {
match cmd {
ControlCommand::OpenStream(reply) => {
let _ = reply.send(Err(ConnectionError::Closed));
}
ControlCommand::CloseConnection(reply) => {
let _ = reply.send(());
}
}
}
}
// Close and wake every remaining stream.
self.drop_all_streams();
// Discard queued stream commands; nothing is sent on this socket anymore.
if !self.stream_receiver.is_terminated() {
self.stream_receiver.close();
while let Some(_cmd) = self.stream_receiver.next().await {
}
}
// A regular close is reported as end-of-stream rather than as an error.
if let Err(ConnectionError::Closed) = result {
return Ok(None)
}
result
}
/// Drives the connection until the remote opens a new stream or the
/// connection ends.
///
/// Each iteration works as follows:
/// - garbage-collect dropped streams;
/// - wait until at least one of the stream-command channel, the
///   control-command channel or the socket is ready;
/// - handle ready items in this order: control command, stream command,
///   inbound frame;
/// - flush the socket.
///
/// Returns `Ok(Some(stream))` for a new inbound stream. Returns
/// `Err(ConnectionError::Closed)` once the socket and both channels have
/// terminated. Errors from the handlers are propagated.
async fn next(&mut self) -> Result<Option<Stream>> {
loop {
self.garbage_collect().await?;
// Sources that have already terminated are replaced by a future that
// never completes, so the poll below skips them. They are counted to
// detect when nothing is left to wait for.
let mut num_terminated = 0;
let mut next_inbound_frame =
if self.socket.is_terminated() {
num_terminated += 1;
Either::Left(future::pending())
} else {
Either::Right(self.socket.try_next().err_into())
};
let mut next_stream_command =
if self.stream_receiver.is_terminated() {
num_terminated += 1;
Either::Left(future::pending())
} else {
Either::Right(self.stream_receiver.next())
};
let mut next_control_command =
if self.control_receiver.is_terminated() {
num_terminated += 1;
Either::Left(future::pending())
} else {
Either::Right(self.control_receiver.next())
};
if num_terminated == 3 {
log::debug!("{}: socket and channels are terminated", self.id);
return Err(ConnectionError::Closed)
}
// Poll all three sources on every wake-up and resolve as soon as at least
// one is ready. Every source is polled, not just the first ready one, so
// several results may be `Ready` at once.
let next_item =
future::poll_fn(move |cx: &mut Context| {
let a = next_stream_command.poll_unpin(cx);
let b = next_control_command.poll_unpin(cx);
let c = next_inbound_frame.poll_unpin(cx);
if a.is_pending() && b.is_pending() && c.is_pending() {
return Poll::Pending
}
Poll::Ready((a, b, c))
});
let (stream_command, control_command, inbound_frame) = next_item.await;
if let Poll::Ready(cmd) = control_command {
self.on_control_command(cmd).await?
}
if let Poll::Ready(cmd) = stream_command {
self.on_stream_command(cmd).await?
}
if let Poll::Ready(frame) = inbound_frame {
if let Some(stream) = self.on_frame(frame).await? {
// Flush before handing the new stream to the caller.
self.socket.get_mut().flush().await.or(Err(ConnectionError::Closed))?;
return Ok(Some(stream))
}
}
self.socket.get_mut().flush().await.or(Err(ConnectionError::Closed))?
}
}
/// Handles a command received from a `Control` handle.
///
/// * `OpenStream`:
///   - fails with `ConnectionError::Closed` after a completed shutdown, and
///     with `ConnectionError::TooManyStreams` at the stream limit;
///   - otherwise allocates the next outbound stream id;
///   - if `receive_window` exceeds `DEFAULT_CREDIT`, immediately sends a SYN
///     window update granting the extra credit;
///   - without extra credit, the stream gets `Flag::Syn` instead (presumably
///     so its first outgoing frame carries SYN);
///   - the stream is registered only if the requester still awaits the reply.
/// * `CloseConnection`:
///   - acknowledged immediately after a completed shutdown;
///   - otherwise stores the reply sender, pauses the control channel and
///     closes the stream channel; the end of that channel completes the
///     shutdown in `on_stream_command`.
/// * `None` (control channel ended): closes the socket and returns
///   `Err(ConnectionError::Closed)`.
///
/// # Errors
/// `ConnectionError::Closed` if writing to or closing the socket fails, plus
/// any error from `next_stream_id`.
async fn on_control_command(&mut self, cmd: Option<ControlCommand>) -> Result<()> {
match cmd {
Some(ControlCommand::OpenStream(reply)) => {
if self.shutdown.is_complete() {
let _ = reply.send(Err(ConnectionError::Closed));
return Ok(())
}
if self.streams.len() >= self.config.max_num_streams {
log::error!("{}: maximum number of streams reached", self.id);
let _ = reply.send(Err(ConnectionError::TooManyStreams));
return Ok(())
}
log::trace!("{}: creating new outbound stream", self.id);
// NOTE(review): id exhaustion propagates out of the connection loop here,
// and `reply` is dropped without an answer.
let id = self.next_stream_id()?;
// NOTE(review): assumes Config guarantees receive_window >= DEFAULT_CREDIT.
// Otherwise this subtraction underflows (a panic in debug builds, assuming
// unsigned integers) — confirm.
let extra_credit = self.config.receive_window - DEFAULT_CREDIT;
if extra_credit > 0 {
let mut frame = Frame::window_update(id, extra_credit);
frame.header_mut().syn();
log::trace!("{}: sending initial {}", self.id, frame.header());
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
}
let stream = {
let config = self.config.clone();
let sender = self.stream_sender.clone();
let window = self.config.receive_window;
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
stream
};
if reply.send(Ok(stream.clone())).is_ok() {
log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream);
} else {
// The requester went away, so the stream is not registered.
log::debug!("{}: open stream {} has been cancelled", self.id, id);
// A SYN window update already went out for this id; reset the stream
// so the remote does not keep it open.
if extra_credit > 0 {
let mut header = Header::data(id, 0);
header.rst();
let frame = Frame::new(header);
self.socket.get_mut().send(&frame).await.or(Err(ConnectionError::Closed))?
}
}
}
Some(ControlCommand::CloseConnection(reply)) => {
if self.shutdown.is_complete() {
let _ = reply.send(());
return Ok(())
}
// The control channel is paused below while a shutdown is in progress, so a
// second CloseConnection is not expected before completion.
debug_assert!(self.shutdown.has_not_started());
self.shutdown = Shutdown::InProgress(reply);
log::trace!("{}: shutting down connection", self.id);
self.control_receiver.pause();
// The end of the stream channel (`None` in `on_stream_command`) completes
// the shutdown.
self.stream_receiver.close()
}
None => {
// The control channel is closed by `on_stream_command` once shutdown completed.
debug_assert!(self.shutdown.is_complete());
self.socket.get_mut().close().await.or(Err(ConnectionError::Closed))?;
return Err(ConnectionError::Closed)
}
}
Ok(())
}
async fn on_stream_command(&mut self, cmd: Option<StreamCommand>) -> Result<()> {
    // Handles a command sent by a `Stream`.
    //
    // `SendFrame` writes the frame to the socket. `CloseStream` sends a FIN
    // (plus ACK if requested). `None` means the stream channel has ended.
    // That completes a pending shutdown: a go-away frame is sent, the
    // `CloseConnection` requester is notified, and the control channel is
    // unpaused and closed.
    match cmd {
        Some(StreamCommand::SendFrame(frame)) => {
            log::trace!("{}: sending: {}", self.id, frame.header());
            self.socket.get_mut().send(&frame).await.map_err(|_| ConnectionError::Closed)?;
        }
        Some(StreamCommand::CloseStream { id, ack }) => {
            log::trace!("{}: closing stream {} of {}", self.id, id, self);
            let mut fin_header = Header::data(id, 0);
            fin_header.fin();
            if ack {
                fin_header.ack();
            }
            let fin_frame = Frame::new(fin_header);
            self.socket.get_mut().send(&fin_frame).await.map_err(|_| ConnectionError::Closed)?;
        }
        None => {
            // The connection keeps its own `stream_sender`, so this channel
            // only ends after `CloseConnection` closed it.
            debug_assert!(self.shutdown.is_in_progress());
            log::debug!("{}: closing {}", self.id, self);
            let term = Frame::term();
            self.socket.get_mut().send(&term).await.map_err(|_| ConnectionError::Closed)?;
            if let Shutdown::InProgress(notify) = std::mem::replace(&mut self.shutdown, Shutdown::Complete) {
                let _ = notify.send(());
            }
            debug_assert!(self.control_receiver.is_paused());
            self.control_receiver.unpause();
            self.control_receiver.stream().close();
        }
    }
    Ok(())
}
async fn on_frame(&mut self, frame: Result<Option<Frame<()>>>) -> Result<Option<Stream>> {
    // Processes the result of reading one frame from the socket.
    //
    // End of file and connection resets map to `ConnectionError::Closed`.
    // Other read errors are logged and returned. A go-away frame from the
    // remote also yields `ConnectionError::Closed`. Any other frame is
    // dispatched by tag and the resulting reply (if any) is written to the
    // socket. Returns `Ok(Some(stream))` when the frame opened a new inbound
    // stream.
    let frame = match frame {
        Ok(Some(frame)) => frame,
        Ok(None) => {
            log::debug!("{}: socket eof", self.id);
            return Err(ConnectionError::Closed)
        }
        Err(e) if e.io_kind() == Some(std::io::ErrorKind::ConnectionReset) => {
            log::debug!("{}: connection reset", self.id);
            return Err(ConnectionError::Closed)
        }
        Err(e) => {
            log::error!("{}: socket error: {}", self.id, e);
            return Err(e)
        }
    };

    log::trace!("{}: received: {}", self.id, frame.header());
    let action = match frame.header().tag() {
        Tag::Data => self.on_data(frame.into_data()),
        Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()),
        Tag::Ping => self.on_ping(&frame.into_ping()),
        Tag::GoAway => return Err(ConnectionError::Closed)
    };

    match action {
        Action::None => Ok(None),
        Action::New(stream, update) => {
            log::trace!("{}: new inbound {} of {}", self.id, stream, self);
            if let Some(f) = update {
                log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
                self.socket.get_mut().send(&f).await.map_err(|_| ConnectionError::Closed)?;
            }
            Ok(Some(stream))
        }
        Action::Update(f) => {
            log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
            self.socket.get_mut().send(&f).await.map_err(|_| ConnectionError::Closed)?;
            Ok(None)
        }
        Action::Ping(f) => {
            log::trace!("{}/{}: pong", self.id, f.header().stream_id());
            self.socket.get_mut().send(&f).await.map_err(|_| ConnectionError::Closed)?;
            Ok(None)
        }
        Action::Reset(f) => {
            log::trace!("{}/{}: sending reset", self.id, f.header().stream_id());
            self.socket.get_mut().send(&f).await.map_err(|_| ConnectionError::Closed)?;
            Ok(None)
        }
        Action::Terminate(f) => {
            log::trace!("{}: sending term", self.id);
            self.socket.get_mut().send(&f).await.map_err(|_| ConnectionError::Closed)?;
            Ok(None)
        }
    }
}
fn on_data(&mut self, frame: Frame<Data>) -> Action {
    // Handles an inbound data frame.
    //
    // * RST: closes the referenced stream (if known) and wakes its reader
    //   and writer.
    // * SYN: opens a new inbound stream and buffers the frame body.
    //   - An invalid or duplicate id, or a first body larger than
    //     `DEFAULT_CREDIT`, terminates the connection with a protocol error.
    //   - Reaching `max_num_streams` terminates it with an internal error.
    //   - With `WindowUpdateMode::OnReceive`, a window update (with ACK) is
    //     returned alongside the stream when the body used up its window.
    //   - Otherwise the stream gets `Flag::Ack` (presumably so that its
    //     first outgoing frame acknowledges the stream).
    // * Otherwise, for a known stream:
    //   - a body larger than the stream's window terminates the connection;
    //   - a full buffer resets the stream;
    //   - else the body is buffered and the reader woken;
    //   - with `OnReceive`, an exhausted window is refilled to
    //     `receive_window` and a window update is returned.
    //   Data for unknown streams is ignored.
    //
    // A FIN flag marks the stream's receiving side as closed.
    let stream_id = frame.header().stream_id();

    if frame.header().flags().contains(header::RST) {
        if let Some(s) = self.streams.get_mut(&stream_id) {
            let mut shared = s.shared();
            shared.update_state(self.id, stream_id, State::Closed);
            if let Some(w) = shared.reader.take() {
                w.wake()
            }
            if let Some(w) = shared.writer.take() {
                w.wake()
            }
        }
        return Action::None
    }

    let is_finish = frame.header().flags().contains(header::FIN);

    if frame.header().flags().contains(header::SYN) {
        if !self.is_valid_remote_id(stream_id, Tag::Data) {
            log::error!("{}: invalid stream id {}", self.id, stream_id);
            return Action::Terminate(Frame::protocol_error())
        }
        if frame.body().len() > DEFAULT_CREDIT as usize {
            log::error!("{}/{}: 1st body of stream exceeds default credit", self.id, stream_id);
            return Action::Terminate(Frame::protocol_error())
        }
        if self.streams.contains_key(&stream_id) {
            log::error!("{}/{}: stream already exists", self.id, stream_id);
            return Action::Terminate(Frame::protocol_error())
        }
        // `>=` (as in `on_control_command`) rather than `==`, so the limit
        // holds even if the stream count were ever to overshoot it.
        if self.streams.len() >= self.config.max_num_streams {
            log::error!("{}: maximum number of streams reached", self.id);
            return Action::Terminate(Frame::internal_error())
        }
        let mut stream = {
            let config = self.config.clone();
            let credit = DEFAULT_CREDIT;
            let sender = self.stream_sender.clone();
            Stream::new(stream_id, self.id, config, credit, credit, sender)
        };
        let window_update;
        {
            let mut shared = stream.shared();
            if is_finish {
                shared.update_state(self.id, stream_id, State::RecvClosed);
            }
            shared.window = shared.window.saturating_sub(frame.body_len());
            shared.buffer.push(frame.into_body());
            if !is_finish
                && shared.window == 0
                && self.config.window_update_mode == WindowUpdateMode::OnReceive
            {
                shared.window = self.config.receive_window;
                let mut frame = Frame::window_update(stream_id, self.config.receive_window);
                frame.header_mut().ack();
                window_update = Some(frame)
            } else {
                window_update = None
            }
        }
        // Without a window update to carry the ACK, the stream itself is flagged.
        if window_update.is_none() {
            stream.set_flag(stream::Flag::Ack)
        }
        self.streams.insert(stream_id, stream.clone());
        return Action::New(stream, window_update)
    }

    if let Some(stream) = self.streams.get_mut(&stream_id) {
        let mut shared = stream.shared();
        if frame.body().len() > shared.window as usize {
            log::error!("{}/{}: frame body larger than window of stream", self.id, stream_id);
            return Action::Terminate(Frame::protocol_error())
        }
        if is_finish {
            shared.update_state(self.id, stream_id, State::RecvClosed);
        }
        let max_buffer_size = self.config.max_buffer_size;
        // A `None` length (presumably an overflowing byte count) is treated as full.
        if shared.buffer.len().map(move |n| n >= max_buffer_size).unwrap_or(true) {
            log::error!("{}/{}: buffer of stream grows beyond limit", self.id, stream_id);
            let mut header = Header::data(stream_id, 0);
            header.rst();
            return Action::Reset(Frame::new(header))
        }
        shared.window = shared.window.saturating_sub(frame.body_len());
        shared.buffer.push(frame.into_body());
        if let Some(w) = shared.reader.take() {
            w.wake()
        }
        if !is_finish
            && shared.window == 0
            && self.config.window_update_mode == WindowUpdateMode::OnReceive
        {
            shared.window = self.config.receive_window;
            let frame = Frame::window_update(stream_id, self.config.receive_window);
            return Action::Update(frame)
        }
    } else if !is_finish {
        log::debug!("{}/{}: data for unknown stream, ignoring", self.id, stream_id);
    }

    Action::None
}
fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
    // Handles an inbound window-update frame.
    //
    // * RST: closes the referenced stream (if known) and wakes its reader
    //   and writer.
    // * SYN: opens a new inbound stream whose send credit is
    //   `DEFAULT_CREDIT` plus the credit carried by the frame.
    //   - An invalid or duplicate id terminates the connection with a
    //     protocol error.
    //   - Reaching `max_num_streams` terminates it with an internal error,
    //     consistent with `on_data`.
    // * Otherwise: adds the frame's credit to a known stream and wakes its
    //   writer.
    //
    // A FIN flag marks the stream's receiving side as closed. Credit values
    // come from the remote and are added with saturation, so a peer cannot
    // trigger an overflow panic (debug builds) or a wrap-around to almost no
    // credit (release builds).
    let stream_id = frame.header().stream_id();

    if frame.header().flags().contains(header::RST) {
        if let Some(s) = self.streams.get_mut(&stream_id) {
            let mut shared = s.shared();
            shared.update_state(self.id, stream_id, State::Closed);
            if let Some(w) = shared.reader.take() {
                w.wake()
            }
            if let Some(w) = shared.writer.take() {
                w.wake()
            }
        }
        return Action::None
    }

    let is_finish = frame.header().flags().contains(header::FIN);

    if frame.header().flags().contains(header::SYN) {
        if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) {
            log::error!("{}: invalid stream id {}", self.id, stream_id);
            return Action::Terminate(Frame::protocol_error())
        }
        if self.streams.contains_key(&stream_id) {
            log::error!("{}/{}: stream already exists", self.id, stream_id);
            return Action::Terminate(Frame::protocol_error())
        }
        // Hitting our own stream limit is a local resource condition rather
        // than a protocol violation by the remote; report it like `on_data`.
        if self.streams.len() >= self.config.max_num_streams {
            log::error!("{}: maximum number of streams reached", self.id);
            return Action::Terminate(Frame::internal_error())
        }
        let stream = {
            let credit = frame.header().credit().saturating_add(DEFAULT_CREDIT);
            let config = self.config.clone();
            let sender = self.stream_sender.clone();
            let mut stream = Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
            stream.set_flag(stream::Flag::Ack);
            stream
        };
        if is_finish {
            stream.shared().update_state(self.id, stream_id, State::RecvClosed);
        }
        self.streams.insert(stream_id, stream.clone());
        return Action::New(stream, None)
    }

    if let Some(stream) = self.streams.get_mut(&stream_id) {
        let mut shared = stream.shared();
        shared.credit = shared.credit.saturating_add(frame.header().credit());
        if is_finish {
            shared.update_state(self.id, stream_id, State::RecvClosed);
        }
        if let Some(w) = shared.writer.take() {
            w.wake()
        }
    } else if !is_finish {
        log::debug!("{}/{}: window update for unknown stream", self.id, stream_id);
    }

    Action::None
}
fn on_ping(&mut self, frame: &Frame<Ping>) -> Action {
    // Answers a ping from the remote.
    //
    // Pings with the ACK flag set are ignored (presumably replies to our own
    // pings). Other pings addressed to the connection id or to a known stream
    // are echoed back with the same nonce and the ACK flag set. Pings for
    // unknown streams are logged and ignored.
    let hdr = frame.header();
    if hdr.flags().contains(header::ACK) {
        return Action::None
    }

    let stream_id = hdr.stream_id();
    let is_known = stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id);
    if !is_known {
        log::debug!("{}/{}: ping for unknown stream", self.id, stream_id);
        return Action::None
    }

    let mut pong = Header::ping(hdr.nonce());
    pong.ack();
    Action::Ping(Frame::new(pong))
}
fn next_stream_id(&mut self) -> Result<StreamId> {
    // Allocates the next locally initiated stream id.
    //
    // Ids advance by two, keeping the parity chosen in `new` (odd for
    // clients, even for servers). Returns `ConnectionError::NoMoreStreamIds`
    // once advancing the counter would overflow `u32`. In that case
    // `self.next_id` is left unchanged.
    //
    // NOTE: allocation fails as soon as the *following* id would overflow,
    // so the last representable id of each parity is never handed out.
    //
    // Panics if the allocated id does not match this endpoint's mode
    // (an internal invariant violation).
    let current = self.next_id;
    let proposed = StreamId::new(current);
    self.next_id = match current.checked_add(2) {
        Some(following) => following,
        None => return Err(ConnectionError::NoMoreStreamIds)
    };
    match self.mode {
        Mode::Client => assert!(proposed.is_client()),
        Mode::Server => assert!(proposed.is_server())
    }
    Ok(proposed)
}
fn is_valid_remote_id(&self, id: StreamId, tag: Tag) -> bool {
    // Checks whether `id` is acceptable for a remotely initiated frame of kind `tag`.
    //
    // Ping and go-away frames must use the session id. Any other frame must
    // carry an id of the remote's parity: server ids when we are the client,
    // client ids when we are the server.
    match (tag, self.mode) {
        (Tag::Ping, _) | (Tag::GoAway, _) => id.is_session(),
        (_, Mode::Client) => id.is_server(),
        (_, Mode::Server) => id.is_client()
    }
}
/// Removes streams that are no longer referenced outside the connection.
///
/// A stream qualifies when its `strong_count()` is at most 1, i.e. only the
/// connection's own entry is left. Such a stream is marked `Closed` and its
/// reader and writer are woken. Depending on the state it was in, a frame is
/// sent to the remote first:
/// - `Open` -> RST;
/// - `RecvClosed` -> FIN;
/// - `SendClosed` -> RST only if the window update mode is `OnRead` and the
///   stream's receive window is exhausted;
/// - `Closed` -> nothing.
///
/// The collected ids are then removed from `self.streams`.
///
/// # Errors
/// `ConnectionError::Closed` if sending a frame fails.
async fn garbage_collect(&mut self) -> Result<()> {
let conn_id = self.id;
let win_update_mode = self.config.window_update_mode;
for stream in self.streams.values_mut() {
// Another clone still exists, so the stream is still in use.
if stream.strong_count() > 1 {
continue
}
log::trace!("{}: removing dropped {}", conn_id, stream);
let stream_id = stream.id();
let frame = {
let mut shared = stream.shared();
// The value returned by `update_state` is matched as the stream's previous state.
let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
State::Open => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
}
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::new(header))
}
// NOTE(review): presumably with `OnRead` no further window update is sent
// once the reader is gone, so a remote without credit could wait forever;
// reset the stream in that case.
State::SendClosed =>
if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
None
}
State::Closed => None
};
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
frame
};
// NOTE(review): if this send fails, ids already pushed to `self.garbage` are
// not removed from `self.streams` before the error is returned.
if let Some(f) = frame {
log::trace!("{}: sending: {}", self.id, f.header());
self.socket.get_mut().send(&f).await.or(Err(ConnectionError::Closed))?
}
self.garbage.push(stream_id)
}
// Entries are removed only after the loop, which holds a mutable borrow of the map.
for id in self.garbage.drain(..) {
self.streams.remove(&id);
}
Ok(())
}
}
impl<T> Connection<T> {
    /// Removes every stream from the connection, marking each as `Closed`
    /// and waking any task registered as its reader or writer.
    fn drop_all_streams(&mut self) {
        let conn_id = self.id;
        for (stream_id, stream) in self.streams.drain() {
            let mut shared = stream.shared();
            shared.update_state(conn_id, stream_id, State::Closed);
            if let Some(reader) = shared.reader.take() {
                reader.wake();
            }
            if let Some(writer) = shared.writer.take() {
                writer.wake();
            }
        }
    }
}
impl<T> Drop for Connection<T> {
    /// Closes all remaining streams and wakes their waiting readers and
    /// writers when the connection is dropped.
    fn drop(&mut self) {
        self.drop_all_streams();
    }
}
/// Turns a connection into a `Stream` of inbound streams.
///
/// Each item is the result of one `Connection::next_stream` call. The
/// returned stream ends when `next_stream` reports `Ok(None)`. Errors are
/// yielded as items and the connection is kept for further polling.
pub fn into_stream<T>(c: Connection<T>) -> impl futures::stream::Stream<Item = Result<Stream>>
where
    T: AsyncRead + AsyncWrite + Unpin
{
    futures::stream::unfold(c, |mut conn| async move {
        let item = match conn.next_stream().await {
            Ok(Some(stream)) => Ok(stream),
            Ok(None) => return None,
            Err(e) => Err(e)
        };
        Some((item, conn))
    })
}