diff --git a/Cargo.lock b/Cargo.lock index 2443bd1..a9b35e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,7 +72,9 @@ dependencies = [ "anyhow", "cove-core", "futures", + "rand", "serde_json", + "thiserror", "tokio", "tokio-stream", "tokio-tungstenite", diff --git a/cove-core/src/id.rs b/cove-core/src/id.rs index 95112c2..db51f32 100644 --- a/cove-core/src/id.rs +++ b/cove-core/src/id.rs @@ -8,7 +8,7 @@ use crate::macros::id_alias; // TODO Use base64 representation instead -#[derive(Debug, Clone, Copy, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)] struct Id(#[serde(with = "hex")] [u8; 32]); impl Id { diff --git a/cove-core/src/macros.rs b/cove-core/src/macros.rs index d1a58ad..882062a 100644 --- a/cove-core/src/macros.rs +++ b/cove-core/src/macros.rs @@ -3,7 +3,7 @@ macro_rules! id_alias { ($name:ident) => { - #[derive(Debug, Clone, Copy, Deserialize, Serialize)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)] pub struct $name(Id); impl $name { diff --git a/cove-core/src/message.rs b/cove-core/src/message.rs index 8fdbd79..b3ff412 100644 --- a/cove-core/src/message.rs +++ b/cove-core/src/message.rs @@ -4,7 +4,7 @@ use crate::{Identity, MessageId}; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Message { - pub time: i128, + pub time: u128, pub pred: MessageId, pub parent: Option, pub identity: Identity, diff --git a/cove-server/Cargo.toml b/cove-server/Cargo.toml index 70fe4c2..719a893 100644 --- a/cove-server/Cargo.toml +++ b/cove-server/Cargo.toml @@ -7,7 +7,9 @@ edition = "2021" anyhow = "1.0.53" cove-core = { path = "../cove-core" } futures = "0.3.21" +rand = "0.8.4" serde_json = "1.0.78" +thiserror = "1.0.30" tokio = { version = "1.16.1", features = ["full"] } tokio-stream = "0.1.8" tokio-tungstenite = "0.16.1" diff --git a/cove-server/src/conn.rs b/cove-server/src/conn.rs new file mode 100644 index 0000000..91cd71e --- /dev/null +++ b/cove-server/src/conn.rs @@ -0,0 +1,156 @@ +use std::result; +use std::sync::Arc; +use std::time::Duration; + +use cove_core::packets::Packet; +use futures::stream::{SplitSink, SplitStream}; +use futures::StreamExt; +use rand::Rng; +use tokio::net::TcpStream; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use tokio::sync::Mutex; +use tokio::try_join; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_tungstenite::tungstenite::{self, Message}; +use tokio_tungstenite::WebSocketStream; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("WS error: {0}")] + Ws(#[from] tungstenite::Error), + #[error("MPSC error: {0}")] + Mpsc(#[from] mpsc::error::SendError), + #[error("Serde error: {0}")] + Serde(#[from] serde_json::Error), + #[error("client did not pong")] + NoPong, + #[error("illegal binary packet")] + IllegalBinaryPacket, +} + +pub type Result = result::Result; + +#[derive(Clone)] +pub struct ConnTx { + tx: UnboundedSender, +} + +impl ConnTx { + pub fn send(&self, packet: &Packet) -> Result<()> { + let msg = Message::Text(serde_json::to_string(packet)?); + self.tx.send(msg)?; + Ok(()) + } +} + +pub struct ConnRx { + ws_rx: SplitStream>, + last_ping_payload: Arc>>, +} + +impl ConnRx { + pub async fn recv(&mut self) -> Result> { + loop { + let msg = match self.ws_rx.next().await { + None => return Ok(None), + Some(msg) => msg?, + }; + let str = match msg { + Message::Text(str) => str, + Message::Pong(payload) => { + *self.last_ping_payload.lock().await = payload; + continue; + } + Message::Ping(_) => { + // Tungstenite automatically replies to pings + continue; + } + Message::Binary(_) => return Err(Error::IllegalBinaryPacket), + Message::Close(_) => return Ok(None), + }; + let packet = serde_json::from_str(&str)?; + return Ok(Some(packet)); + } + } +} + +pub struct ConnMaintenance { + // Shoveling packets into the WS connection + rx: UnboundedReceiver, + ws_tx: SplitSink, Message>, + // Pinging and ponging + tx: UnboundedSender, + ping_delay: Duration, + last_ping_payload: Arc>>, +} + +impl ConnMaintenance { + pub async fn perform(self) -> Result<()> { + let result = try_join!( + Self::shovel(self.rx, self.ws_tx), + Self::ping_pong(self.tx, self.ping_delay, self.last_ping_payload) + ); + result.map(|_| ()) + } + + async fn shovel( + rx: UnboundedReceiver, + ws_tx: SplitSink, Message>, + ) -> Result<()> { + UnboundedReceiverStream::new(rx) + .map(Ok) + .forward(ws_tx) + .await?; + Ok(()) + } + + async fn ping_pong( + tx: UnboundedSender, + ping_delay: Duration, + last_ping_payload: Arc>>, + ) -> Result<()> { + let mut payload = [0u8; 8]; + + rand::thread_rng().fill(&mut payload); + // debug!("Sending first ping with payload {:?}", payload); + tx.send(Message::Ping(payload.to_vec()))?; + tokio::time::sleep(ping_delay).await; + + loop { + let last_payload = last_ping_payload.lock().await; + if (&payload as &[u8]) != (&last_payload as &[u8]) { + // warn!("Invalid ping payload, client probably dead"); + return Err(Error::NoPong); + } + + rand::thread_rng().fill(&mut payload); + // debug!("Sending ping with payload {:?}", payload); + tx.send(Message::Ping(payload.to_vec()))?; + tokio::time::sleep(ping_delay).await; + } + } +} + +pub async fn new( + stream: TcpStream, + ping_delay: Duration, +) -> Result<(ConnTx, ConnRx, ConnMaintenance)> { + let (ws_tx, ws_rx) = tokio_tungstenite::accept_async(stream).await?.split(); + let (tx, rx) = mpsc::unbounded_channel(); + let last_ping_payload = Arc::new(Mutex::new(vec![])); + + let conn_tx = ConnTx { tx: tx.clone() }; + let conn_rx = ConnRx { + ws_rx, + last_ping_payload: last_ping_payload.clone(), + }; + let conn_maintenance = ConnMaintenance { + ws_tx, + rx, + tx, + ping_delay, + last_ping_payload, + }; + + Ok((conn_tx, conn_rx, conn_maintenance)) +} diff --git a/cove-server/src/main.rs b/cove-server/src/main.rs index 1a2c814..89c8656 100644 --- a/cove-server/src/main.rs +++ b/cove-server/src/main.rs @@ -1,12 +1,17 @@ +mod conn; + use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; use anyhow::anyhow; use cove_core::packets::{Cmd, HelloRpl, Packet, Rpl}; use cove_core::{Identity, MessageId, Session, SessionId}; use futures::stream::{SplitSink, SplitStream}; use futures::{future, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use rand::prelude::ThreadRng; +use rand::Rng; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::{self, UnboundedSender}; use tokio::sync::{self, Mutex, RwLock}; @@ -26,18 +31,40 @@ struct Room { last_timestamp: u128, } +impl Room { + fn new() -> Self { + Self { + clients: HashMap::new(), + last_message: MessageId::of(&format!("{}", rand::thread_rng().gen::())), + last_timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("executed after 1970") + .as_millis(), + } + } +} + #[derive(Debug, Clone)] struct Server { - rooms: Arc>>>>, + rooms: Arc>>>>, } impl Server { fn new() -> Self { Self { - rooms: Arc::new(RwLock::new(HashMap::new())), + rooms: Arc::new(Mutex::new(HashMap::new())), } } + async fn room(&self, name: String) -> Arc> { + self.rooms + .lock() + .await + .entry(name) + .or_insert_with(|| Arc::new(Mutex::new(Room::new()))) + .clone() + } + async fn recv(rx: &mut SplitStream>) -> anyhow::Result { loop { let msg = rx.next().await.ok_or(anyhow!("connection closed"))??; @@ -106,7 +133,8 @@ impl Server { &self, tx: &mut SplitSink, TkMessage>, rx: &mut SplitStream>, - ) -> anyhow::Result<(String, String, Identity, u64)> { + ) -> anyhow::Result<(String, Session, u64)> { + // TODO Allow multiple Hello commands until the first succeeds let packet = Self::recv(rx).await?; let (id, cmd) = match packet { Packet::Cmd { @@ -127,18 +155,38 @@ impl Server { Self::send(tx, &Packet::rpl(id, HelloRpl::InvalidNick { reason })).await?; return Err(anyhow!("invalid identity")); } - let identity = Identity::of(&cmd.identity); - Ok((cmd.room, cmd.nick, identity, id)) + let session = Session { + id: SessionId::of(&format!("{}", rand::thread_rng().gen::())), + nick: cmd.nick, + identity: Identity::of(&cmd.identity), + }; + Ok((cmd.room, session, id)) } async fn on_conn(self, stream: TcpStream) { + // TODO Ping-pong starting from the beginning (not just after hello) println!("Connection from {}", stream.peer_addr().unwrap()); let stream = tokio_tungstenite::accept_async(stream).await.unwrap(); let (mut tx, mut rx) = stream.split(); - let (room, nick, identity, id) = match self.greet(&mut tx, &mut rx).await { + let (room, session, id) = match self.greet(&mut tx, &mut rx).await { Ok(info) => info, Err(_) => return, }; + let room = self.room(room).await; + let (packets, client_rx) = mpsc::unbounded_channel(); + { + let mut room = room.lock().await; + packets.send(Packet::rpl( + id, + HelloRpl::Success { + you: session.clone(), + others: room.clients.values().map(|c| c.session.clone()).collect(), + last_message: room.last_message, + }, + )); + let client = Client { session, packets }; + room.clients.insert(client.session.id, client); + } todo!() } }