use crate::connection::rtt::Rtt;
use crate::frame::header::ACK;
use crate::{
chunks::Chunks,
connection::{self, rtt, StreamCommand},
frame::{
header::{Data, Header, StreamId, WindowUpdate},
Frame,
},
Config, DEFAULT_CREDIT,
};
use flow_control::FlowController;
use futures::{
channel::mpsc,
future::Either,
io::{AsyncRead, AsyncWrite},
ready, SinkExt,
};
use parking_lot::{Mutex, MutexGuard};
use std::{
fmt, io,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
mod flow_control;
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
Open {
acknowledged: bool,
},
SendClosed,
RecvClosed,
Closed,
}
impl State {
pub fn can_read(self) -> bool {
!matches!(self, State::RecvClosed | State::Closed)
}
pub fn can_write(self) -> bool {
!matches!(self, State::SendClosed | State::Closed)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Flag {
None,
Syn,
Ack,
}
pub struct Stream {
id: StreamId,
conn: connection::Id,
config: Arc<Config>,
sender: mpsc::Sender<StreamCommand>,
flag: Flag,
shared: Arc<Mutex<Shared>>,
}
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Stream")
.field("id", &self.id.val())
.field("connection", &self.conn)
.finish()
}
}
impl fmt::Display for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(Stream {}/{})", self.conn, self.id.val())
}
}
impl Stream {
pub(crate) fn new_inbound(
id: StreamId,
conn: connection::Id,
config: Arc<Config>,
send_window: u32,
sender: mpsc::Sender<StreamCommand>,
rtt: rtt::Rtt,
accumulated_max_stream_windows: Arc<Mutex<usize>>,
) -> Self {
Self {
id,
conn,
config: config.clone(),
sender,
flag: Flag::Ack,
shared: Arc::new(Mutex::new(Shared::new(
DEFAULT_CREDIT,
send_window,
accumulated_max_stream_windows,
rtt,
config,
))),
}
}
pub(crate) fn new_outbound(
id: StreamId,
conn: connection::Id,
config: Arc<Config>,
sender: mpsc::Sender<StreamCommand>,
rtt: rtt::Rtt,
accumulated_max_stream_windows: Arc<Mutex<usize>>,
) -> Self {
Self {
id,
conn,
config: config.clone(),
sender,
flag: Flag::Syn,
shared: Arc::new(Mutex::new(Shared::new(
DEFAULT_CREDIT,
DEFAULT_CREDIT,
accumulated_max_stream_windows,
rtt,
config,
))),
}
}
pub fn id(&self) -> StreamId {
self.id
}
pub fn is_write_closed(&self) -> bool {
matches!(self.shared().state(), State::SendClosed)
}
pub fn is_closed(&self) -> bool {
matches!(self.shared().state(), State::Closed)
}
pub fn is_pending_ack(&self) -> bool {
self.shared().is_pending_ack()
}
pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
self.shared.lock()
}
pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
self.shared.clone()
}
fn write_zero_err(&self) -> io::Error {
let msg = format!("{}/{}: connection is closed", self.conn, self.id);
io::Error::new(io::ErrorKind::WriteZero, msg)
}
fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
match self.flag {
Flag::None => (),
Flag::Syn => {
header.syn();
self.flag = Flag::None
}
Flag::Ack => {
header.ack();
self.flag = Flag::None
}
}
}
fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
if !self.shared.lock().state.can_read() {
return Poll::Ready(Ok(()));
}
ready!(self
.sender
.poll_ready(cx)
.map_err(|_| self.write_zero_err())?);
let Some(credit) = self.shared.lock().next_window_update() else {
return Poll::Ready(Ok(()));
};
let mut frame = Frame::window_update(self.id, credit).right();
self.add_flag(frame.header_mut());
let cmd = StreamCommand::SendFrame(frame);
self.sender
.start_send(cmd)
.map_err(|_| self.write_zero_err())?;
Poll::Ready(Ok(()))
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet(Vec<u8>);
impl AsRef<[u8]> for Packet {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl futures::stream::Stream for Stream {
type Item = io::Result<Packet>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
if !self.config.read_after_close && self.sender.is_closed() {
return Poll::Ready(None);
}
match self.send_window_update(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
Poll::Pending => {}
}
let mut shared = self.shared();
if let Some(bytes) = shared.buffer.pop() {
let off = bytes.offset();
let mut vec = bytes.into_vec();
if off != 0 {
log::debug!(
"{}/{}: chunk has been partially consumed",
self.conn,
self.id
);
vec = vec.split_off(off)
}
return Poll::Ready(Some(Ok(Packet(vec))));
}
if !shared.state().can_read() {
log::debug!("{}/{}: eof", self.conn, self.id);
return Poll::Ready(None); }
shared.reader = Some(cx.waker().clone());
Poll::Pending
}
}
impl AsyncRead for Stream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if !self.config.read_after_close && self.sender.is_closed() {
return Poll::Ready(Ok(0));
}
match self.send_window_update(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {}
}
let mut shared = self.shared();
let mut n = 0;
while let Some(chunk) = shared.buffer.front_mut() {
if chunk.is_empty() {
shared.buffer.pop();
continue;
}
let k = std::cmp::min(chunk.len(), buf.len() - n);
buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
n += k;
chunk.advance(k);
if n == buf.len() {
break;
}
}
if n > 0 {
log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
return Poll::Ready(Ok(n));
}
if !shared.state().can_read() {
log::debug!("{}/{}: eof", self.conn, self.id);
return Poll::Ready(Ok(0)); }
shared.reader = Some(cx.waker().clone());
Poll::Pending
}
}
impl AsyncWrite for Stream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
ready!(self
.sender
.poll_ready(cx)
.map_err(|_| self.write_zero_err())?);
let body = {
let mut shared = self.shared();
if !shared.state().can_write() {
log::debug!("{}/{}: can no longer write", self.conn, self.id);
return Poll::Ready(Err(self.write_zero_err()));
}
if shared.send_window() == 0 {
log::trace!("{}/{}: no more credit left", self.conn, self.id);
shared.writer = Some(cx.waker().clone());
return Poll::Pending;
}
let k = std::cmp::min(
shared.send_window(),
buf.len().try_into().unwrap_or(u32::MAX),
);
let k = std::cmp::min(
k,
self.config.split_send_size.try_into().unwrap_or(u32::MAX),
);
shared.consume_send_window(k);
Vec::from(&buf[..k as usize])
};
let n = body.len();
let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
self.add_flag(frame.header_mut());
log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
if frame.header().flags().contains(ACK) {
self.shared()
.update_state(self.conn, self.id, State::Open { acknowledged: true });
}
let cmd = StreamCommand::SendFrame(frame);
self.sender
.start_send(cmd)
.map_err(|_| self.write_zero_err())?;
Poll::Ready(Ok(n))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.sender
.poll_flush_unpin(cx)
.map_err(|_| self.write_zero_err())
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if self.is_closed() {
return Poll::Ready(Ok(()));
}
ready!(self
.sender
.poll_ready(cx)
.map_err(|_| self.write_zero_err())?);
let ack = if self.flag == Flag::Ack {
self.flag = Flag::None;
true
} else {
false
};
log::trace!("{}/{}: close", self.conn, self.id);
let cmd = StreamCommand::CloseStream { ack };
self.sender
.start_send(cmd)
.map_err(|_| self.write_zero_err())?;
self.shared()
.update_state(self.conn, self.id, State::SendClosed);
Poll::Ready(Ok(()))
}
}
#[derive(Debug)]
pub(crate) struct Shared {
state: State,
flow_controller: FlowController,
pub(crate) buffer: Chunks,
pub(crate) reader: Option<Waker>,
pub(crate) writer: Option<Waker>,
}
impl Shared {
fn new(
receive_window: u32,
send_window: u32,
accumulated_max_stream_windows: Arc<Mutex<usize>>,
rtt: Rtt,
config: Arc<Config>,
) -> Self {
Shared {
state: State::Open {
acknowledged: false,
},
flow_controller: FlowController::new(
receive_window,
send_window,
accumulated_max_stream_windows,
rtt,
config,
),
buffer: Chunks::new(),
reader: None,
writer: None,
}
}
pub(crate) fn state(&self) -> State {
self.state
}
pub(crate) fn update_state(
&mut self,
cid: connection::Id,
sid: StreamId,
next: State,
) -> State {
use self::State::*;
let current = self.state;
match (current, next) {
(Closed, _) => {}
(Open { .. }, _) => self.state = next,
(RecvClosed, Closed) => self.state = Closed,
(RecvClosed, Open { .. }) => {}
(RecvClosed, RecvClosed) => {}
(RecvClosed, SendClosed) => self.state = Closed,
(SendClosed, Closed) => self.state = Closed,
(SendClosed, Open { .. }) => {}
(SendClosed, RecvClosed) => self.state = Closed,
(SendClosed, SendClosed) => {}
}
log::trace!(
"{}/{}: update state: (from {:?} to {:?} -> {:?})",
cid,
sid,
current,
next,
self.state
);
current }
pub(crate) fn next_window_update(&mut self) -> Option<u32> {
self.flow_controller.next_window_update(self.buffer.len())
}
pub fn is_pending_ack(&self) -> bool {
matches!(
self.state(),
State::Open {
acknowledged: false
}
)
}
pub(crate) fn send_window(&self) -> u32 {
self.flow_controller.send_window()
}
pub(crate) fn consume_send_window(&mut self, i: u32) {
self.flow_controller.consume_send_window(i)
}
pub(crate) fn increase_send_window_by(&mut self, i: u32) {
self.flow_controller.increase_send_window_by(i)
}
pub(crate) fn receive_window(&self) -> u32 {
self.flow_controller.receive_window()
}
pub(crate) fn consume_receive_window(&mut self, i: u32) {
self.flow_controller.consume_receive_window(i)
}
}