diff --git a/cove-core/src/packets.rs b/cove-core/src/packets.rs index 7b27f08..0e179fc 100644 --- a/cove-core/src/packets.rs +++ b/cove-core/src/packets.rs @@ -51,7 +51,6 @@ pub struct SendCmd { #[serde(tag = "type")] pub enum SendRpl { Success { message: Message }, - InvalidNick { reason: String }, InvalidContent { reason: String }, } diff --git a/cove-server/src/main.rs b/cove-server/src/main.rs index 89c8656..1be03c9 100644 --- a/cove-server/src/main.rs +++ b/cove-server/src/main.rs @@ -1,27 +1,28 @@ mod conn; +mod util; +use std::any; use std::collections::HashMap; -use std::hash::Hash; use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use anyhow::anyhow; -use cove_core::packets::{Cmd, HelloRpl, Packet, Rpl}; -use cove_core::{Identity, MessageId, Session, SessionId}; -use futures::stream::{SplitSink, SplitStream}; -use futures::{future, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; -use rand::prelude::ThreadRng; +use conn::{ConnMaintenance, ConnRx, ConnTx}; +use cove_core::packets::{ + Cmd, HelloCmd, HelloRpl, JoinNtf, NickCmd, NickNtf, NickRpl, Ntf, Packet, PartNtf, SendCmd, + SendNtf, SendRpl, WhoCmd, +}; +use cove_core::{Identity, Message, MessageId, Session, SessionId}; use rand::Rng; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc::{self, UnboundedSender}; -use tokio::sync::{self, Mutex, RwLock}; -use tokio_tungstenite::tungstenite::Message as TkMessage; -use tokio_tungstenite::WebSocketStream; +use tokio::sync::Mutex; +use tokio_tungstenite::tungstenite::http::header::LAST_MODIFIED; +use util::timestamp; -#[derive(Debug)] +#[derive(Debug, Clone)] struct Client { session: Session, - packets: UnboundedSender, + send: ConnTx, } #[derive(Debug)] @@ -36,12 +37,157 @@ impl Room { Self { clients: HashMap::new(), last_message: MessageId::of(&format!("{}", rand::thread_rng().gen::())), - last_timestamp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("executed after 1970") - .as_millis(), + last_timestamp: util::timestamp(), } } + + fn client(&self, id: SessionId) -> &Client { + self.clients.get(&id).expect("invalid session id") + } + + fn client_mut(&mut self, id: SessionId) -> &mut Client { + self.clients.get_mut(&id).expect("invalid session id") + } + + fn notify_all(&self, packet: &Packet) { + for client in self.clients.values() { + let _ = client.send.send(packet); + } + } + + fn notify_except(&self, id: SessionId, packet: &Packet) { + for client in self.clients.values() { + if client.session.id != id { + let _ = client.send.send(packet); + } + } + } + + fn part(&mut self, id: SessionId) { + let client = self.clients.remove(&id).expect("invalid session id"); + + self.notify_all(&Packet::ntf(PartNtf { + who: client.session, + })); + } + + fn join(&mut self, client: Client) { + self.notify_all(&Packet::ntf(JoinNtf { + who: client.session.clone(), + })); + + self.clients.insert(client.session.id, client); + } + + fn nick(&mut self, id: SessionId, nick: String) { + let who = { + let client = self.client_mut(id); + client.session.nick = nick; + client.session.clone() + }; + + self.notify_except(id, &Packet::ntf(NickNtf { who })) + } + + fn send(&mut self, id: SessionId, parent: Option, content: String) -> Message { + let client = &self.clients[&id]; + + self.last_timestamp = util::timestamp_after(self.last_timestamp); + + let message = Message { + time: self.last_timestamp, + pred: self.last_message, + parent, + identity: client.session.identity, + nick: client.session.nick.clone(), + content, + }; + + self.notify_except( + id, + &Packet::ntf(SendNtf { + message: message.clone(), + }), + ); + + message + } +} + +#[derive(Debug)] +struct ServerSession { + tx: ConnTx, + rx: ConnRx, + room: Arc>, + session: Session, +} + +impl ServerSession { + fn new(tx: ConnTx, rx: ConnRx, room: Arc>, session: Session) -> Self { + Self { + tx, + rx, + room, + session, + } + } + + async fn handle_nick(&mut self, id: u64, cmd: NickCmd) -> anyhow::Result<()> { + if let Some(reason) = util::check_nick(&cmd.nick) { + self.tx + .send(&Packet::rpl(id, NickRpl::InvalidNick { reason }))?; + return Ok(()); + } + + self.session.nick = cmd.nick.clone(); + self.tx.send(&Packet::rpl(id, NickRpl::Success))?; + self.room.lock().await.nick(self.session.id, cmd.nick); + + Ok(()) + } + + async fn handle_send(&mut self, id: u64, cmd: SendCmd) -> anyhow::Result<()> { + if let Some(reason) = util::check_content(&cmd.content) { + self.tx + .send(&Packet::rpl(id, SendRpl::InvalidContent { reason }))?; + return Ok(()); + } + + let message = self + .room + .lock() + .await + .send(self.session.id, cmd.parent, cmd.content); + + self.tx + .send(&Packet::rpl(id, SendRpl::Success { message }))?; + + Ok(()) + } + + async fn handle_who(&mut self, id: u64, cmd: WhoCmd) -> anyhow::Result<()> { + todo!() + } + + async fn handle_packet(&mut self, packet: Packet) -> anyhow::Result<()> { + match packet { + Packet::Cmd { id, cmd } => match cmd { + Cmd::Hello(_) => Err(anyhow!("unexpected Hello cmd")), + Cmd::Nick(cmd) => self.handle_nick(id, cmd).await, + Cmd::Send(cmd) => self.handle_send(id, cmd).await, + Cmd::Who(cmd) => self.handle_who(id, cmd).await, + }, + Packet::Rpl { .. } => Err(anyhow!("unexpected rpl")), + Packet::Ntf { .. } => Err(anyhow!("unexpected ntf")), + } + } + + async fn run(&mut self) -> anyhow::Result<()> { + while let Some(packet) = self.rx.recv().await? { + self.handle_packet(packet).await?; + } + Ok(()) + } } #[derive(Debug, Clone)] @@ -65,129 +211,105 @@ impl Server { .clone() } - async fn recv(rx: &mut SplitStream>) -> anyhow::Result { - loop { - let msg = rx.next().await.ok_or(anyhow!("connection closed"))??; - let str = match msg { - TkMessage::Text(str) => str, - TkMessage::Ping(_) | TkMessage::Pong(_) => continue, - TkMessage::Binary(_) => return Err(anyhow!("invalid binary packet")), - TkMessage::Close(_) => return Err(anyhow!("connection closed")), - }; - break Ok(serde_json::from_str(&str)?); - } - } - - async fn send( - tx: &mut SplitSink, TkMessage>, - packet: &Packet, - ) -> anyhow::Result<()> { - let str = serde_json::to_string(packet).expect("serialisable packet"); - let msg = TkMessage::Text(str); - tx.feed(msg).await?; - tx.flush().await?; - Ok(()) - } - - fn check_room(room: &str) -> Option { - if !room.is_empty() { - return Some("is empty".to_string()); - } - if !room.is_ascii() { - return Some("contains non-ascii characters".to_string()); - } - if room.len() > 1024 { - return Some("contains more than 1024 characters".to_string()); - } - if !room - .chars() - .all(|c| c == '-' || c == '.' || ('a'..='z').contains(&c)) - { - return Some("must only contain a-z, '-' and '_'".to_string()); - } - None - } - - fn check_nick(nick: &str) -> Option { - if !nick.is_empty() { - return Some("is empty".to_string()); - } - if !nick.trim().is_empty() { - return Some("contains only whitespace".to_string()); - } - let nick = nick.trim(); - if nick.chars().count() > 1024 { - return Some("contains more than 1024 characters".to_string()); - } - None - } - - fn check_identity(identity: &str) -> Option { - if identity.chars().count() > 32768 { - return Some("contains more than 32768 characters".to_string()); - } - None - } - - async fn greet( + async fn handle_hello( &self, - tx: &mut SplitSink, TkMessage>, - rx: &mut SplitStream>, - ) -> anyhow::Result<(String, Session, u64)> { - // TODO Allow multiple Hello commands until the first succeeds - let packet = Self::recv(rx).await?; - let (id, cmd) = match packet { - Packet::Cmd { - id, - cmd: Cmd::Hello(cmd), - } => (id, cmd), - _ => return Err(anyhow!("not a hello command")), - }; - if let Some(reason) = Self::check_room(&cmd.room) { - Self::send(tx, &Packet::rpl(id, HelloRpl::InvalidRoom { reason })).await?; - return Err(anyhow!("invalid room")); + tx: &ConnTx, + id: u64, + cmd: HelloCmd, + ) -> anyhow::Result> { + if let Some(reason) = util::check_room(&cmd.room) { + tx.send(&Packet::rpl(id, HelloRpl::InvalidRoom { reason }))?; + return Ok(None); } - if let Some(reason) = Self::check_nick(&cmd.nick) { - Self::send(tx, &Packet::rpl(id, HelloRpl::InvalidNick { reason })).await?; - return Err(anyhow!("invalid nick")); + if let Some(reason) = util::check_nick(&cmd.nick) { + tx.send(&Packet::rpl(id, HelloRpl::InvalidNick { reason }))?; + return Ok(None); } - if let Some(reason) = Self::check_identity(&cmd.identity) { - Self::send(tx, &Packet::rpl(id, HelloRpl::InvalidNick { reason })).await?; - return Err(anyhow!("invalid identity")); + if let Some(reason) = util::check_identity(&cmd.identity) { + tx.send(&Packet::rpl(id, HelloRpl::InvalidIdentity { reason }))?; + return Ok(None); } + let session = Session { id: SessionId::of(&format!("{}", rand::thread_rng().gen::())), nick: cmd.nick, identity: Identity::of(&cmd.identity), }; - Ok((cmd.room, session, id)) + + Ok(Some((cmd.room, session))) } - async fn on_conn(self, stream: TcpStream) { + async fn greet(&self, tx: ConnTx, mut rx: ConnRx) -> anyhow::Result { + let (id, room, session) = loop { + let (id, cmd) = match rx.recv().await? { + Some(Packet::Cmd { + id, + cmd: Cmd::Hello(cmd), + }) => (id, cmd), + _ => return Err(anyhow!("not a Hello command")), + }; + + if let Some((room, session)) = self.handle_hello(&tx, id, cmd).await? { + break (id, room, session); + } + }; + + let room = self.room(room).await; + + { + let mut room = room.lock().await; + + let you = session.clone(); + let others = room + .clients + .values() + .map(|client| client.session.clone()) + .collect::>(); + let last_message = room.last_message; + + tx.send(&Packet::rpl( + id, + HelloRpl::Success { + you, + others, + last_message, + }, + ))?; + + room.join(Client { + session: session.clone(), + send: tx.clone(), + }); + } + + Ok(ServerSession { + tx, + rx, + room, + session, + }) + } + async fn greet_and_run(&self, tx: ConnTx, rx: ConnRx) -> anyhow::Result<()> { + let mut session = self.greet(tx, rx).await?; + let result = session.run().await; + session.room.lock().await.part(session.session.id); + result + } + + /// Wrapper for [`ConnMaintenance::perform`] so it returns an + /// [`anyhow::Result`]. + async fn maintain(maintenance: ConnMaintenance) -> anyhow::Result<()> { + maintenance.perform().await?; + Ok(()) + } + + async fn on_conn(self, stream: TcpStream) -> anyhow::Result<()> { // TODO Ping-pong starting from the beginning (not just after hello) println!("Connection from {}", stream.peer_addr().unwrap()); let stream = tokio_tungstenite::accept_async(stream).await.unwrap(); - let (mut tx, mut rx) = stream.split(); - let (room, session, id) = match self.greet(&mut tx, &mut rx).await { - Ok(info) => info, - Err(_) => return, - }; - let room = self.room(room).await; - let (packets, client_rx) = mpsc::unbounded_channel(); - { - let mut room = room.lock().await; - packets.send(Packet::rpl( - id, - HelloRpl::Success { - you: session.clone(), - others: room.clients.values().map(|c| c.session.clone()).collect(), - last_message: room.last_message, - }, - )); - let client = Client { session, packets }; - room.clients.insert(client.session.id, client); - } - todo!() + let (tx, rx, maintenance) = conn::new(stream, Duration::from_secs(10))?; + tokio::try_join!(self.greet_and_run(tx, rx), Self::maintain(maintenance))?; + Ok(()) } } diff --git a/cove-server/src/util.rs b/cove-server/src/util.rs new file mode 100644 index 0000000..3d48c41 --- /dev/null +++ b/cove-server/src/util.rs @@ -0,0 +1,60 @@ +use std::cmp; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("executed after 1970") + .as_millis() +} + +pub fn timestamp_after(previous: u128) -> u128 { + cmp::max(timestamp(), previous + 1) +} + +pub fn check_room(room: &str) -> Option { + if !room.is_empty() { + return Some("is empty".to_string()); + } + if !room.is_ascii() { + return Some("contains non-ascii characters".to_string()); + } + if room.len() > 1024 { + return Some("contains more than 1024 characters".to_string()); + } + if !room + .chars() + .all(|c| c == '-' || c == '.' || ('a'..='z').contains(&c)) + { + return Some("must only contain a-z, '-' and '_'".to_string()); + } + None +} + +pub fn check_nick(nick: &str) -> Option { + if !nick.is_empty() { + return Some("is empty".to_string()); + } + if !nick.trim().is_empty() { + return Some("contains only whitespace".to_string()); + } + let nick = nick.trim(); + if nick.chars().count() > 1024 { + return Some("contains more than 1024 characters".to_string()); + } + None +} + +pub fn check_identity(identity: &str) -> Option { + if identity.chars().count() > 32 * 1024 { + return Some("contains more than 32768 characters".to_string()); + } + None +} + +pub fn check_content(content: &str) -> Option { + if content.chars().count() > 128 * 1024 { + return Some("contains more than 131072 characters".to_string()); + } + None +}