From 17d08438d1ee97129c74f90cad3e0012819d734c Mon Sep 17 00:00:00 2001 From: Joscha Date: Thu, 3 Mar 2022 00:49:17 +0100 Subject: [PATCH] Wrap state in Arc> instead of whole connection --- cove-tui/src/cove/conn.rs | 116 +++++++++++++++----------------------- 1 file changed, 44 insertions(+), 72 deletions(-) diff --git a/cove-tui/src/cove/conn.rs b/cove-tui/src/cove/conn.rs index 7dd3195..e800d81 100644 --- a/cove-tui/src/cove/conn.rs +++ b/cove-tui/src/cove/conn.rs @@ -8,7 +8,7 @@ use cove_core::packets::{ RoomRpl, Rpl, SendNtf, SendRpl, WhoRpl, }; use cove_core::{Session, SessionId}; -use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use tokio::sync::mpsc::UnboundedSender; use tokio::sync::Mutex; use crate::replies::{self, Replies}; @@ -144,36 +144,21 @@ impl State { } } +#[derive(Clone)] pub struct CoveConn { - state: State, + state: Arc>, ev_tx: UnboundedSender, } impl CoveConn { - pub fn state(&self) -> &State { - &self.state - } - - pub fn state_mut(&mut self) -> &mut State { - &mut self.state - } - - pub fn connected(&self) -> Option<&Connected> { - self.state.connected() - } - - pub fn connected_mut(&mut self) -> Option<&mut Connected> { - self.state.connected_mut() - } - - async fn cmd(conn: &Mutex, cmd: C) -> Result + async fn cmd(&self, cmd: C) -> Result where C: Into, Rpl: TryInto, { let pending_reply = { - let mut conn = conn.lock().await; - let mut connected = conn.connected_mut().ok_or(Error::NotConnected)?; + let mut state = self.state.lock().await; + let mut connected = state.connected_mut().ok_or(Error::NotConnected)?; let id = connected.next_id; connected.next_id += 1; @@ -193,13 +178,13 @@ impl CoveConn { /// /// This method is intended to be called whenever a CoveConn user suspects /// identification to be necessary. It has little overhead. - pub async fn identify(conn: Arc>, nick: &str, identity: &str) { + pub async fn identify(&self, nick: &str, identity: &str) { { - let mut conn = conn.lock().await; - if let Some(connected) = conn.connected_mut() { + let mut state = self.state.lock().await; + if let Some(connected) = state.connected_mut() { if let Status::IdRequired(_) = connected.status { connected.status = Status::Identifying; - conn.ev_tx.send(Event::StateChanged); + self.ev_tx.send(Event::StateChanged); } else { return; } @@ -208,13 +193,15 @@ impl CoveConn { } } + let conn = self.clone(); let nick = nick.to_string(); let identity = identity.to_string(); tokio::spawn(async move { // There's no need for a second locking block, or for us to see the // result of this command. CoveConnMt::run will set the connection's // status as appropriate. - Self::cmd::(&conn, IdentifyCmd { nick, identity }).await + conn.cmd::(IdentifyCmd { nick, identity }) + .await }); } } @@ -224,8 +211,7 @@ pub struct CoveConnMt { url: String, room: String, timeout: Duration, - conn: Arc>, - ev_tx: UnboundedSender, + conn: CoveConn, } impl CoveConnMt { @@ -233,23 +219,23 @@ impl CoveConnMt { let (tx, rx, mt) = match Self::connect(&self.url, self.timeout).await { Ok(conn) => conn, Err(e) => { - *self.conn.lock().await.state_mut() = State::Stopped; - self.ev_tx.send(Event::StateChanged); + *self.conn.state.lock().await = State::Stopped; + self.conn.ev_tx.send(Event::StateChanged); return Err(Error::CouldNotConnect(e)); } }; - *self.conn.lock().await.state_mut() = State::Connected(Connected::new(tx, self.timeout)); - self.ev_tx.send(Event::StateChanged); + *self.conn.state.lock().await = State::Connected(Connected::new(tx, self.timeout)); + self.conn.ev_tx.send(Event::StateChanged); tokio::spawn(Self::join_room(self.conn.clone(), self.room)); let result = tokio::select! { - result = Self::recv(&self.conn, &self.ev_tx, rx) => result, + result = Self::recv(&self.conn, rx) => result, _ = mt.perform() => Err(Error::MaintenanceAborted), }; - *self.conn.lock().await.state_mut() = State::Stopped; - self.ev_tx.send(Event::StateChanged); + *self.conn.state.lock().await = State::Stopped; + self.conn.ev_tx.send(Event::StateChanged); result } @@ -263,36 +249,27 @@ impl CoveConnMt { Ok(conn) } - async fn join_room(conn: Arc>, name: String) -> Result<(), Error> { - let reply: RoomRpl = CoveConn::cmd(&conn, RoomCmd { name }).await?; + async fn join_room(conn: CoveConn, name: String) -> Result<(), Error> { + let reply: RoomRpl = conn.cmd(RoomCmd { name }).await?; Ok(()) } - async fn recv( - conn: &Mutex, - ev_tx: &UnboundedSender, - mut rx: ConnRx, - ) -> Result<(), Error> { + async fn recv(conn: &CoveConn, mut rx: ConnRx) -> Result<(), Error> { while let Some(packet) = rx.recv().await? { match packet { Packet::Cmd { id, cmd } => { // Ignore commands as the server doesn't send them. } - Packet::Rpl { id, rpl } => Self::on_rpl(&conn, &ev_tx, id, rpl).await?, - Packet::Ntf { ntf } => Self::on_ntf(&conn, &ev_tx, ntf).await?, + Packet::Rpl { id, rpl } => Self::on_rpl(conn, id, rpl).await?, + Packet::Ntf { ntf } => Self::on_ntf(conn, ntf).await?, } } Ok(()) } - async fn on_rpl( - conn: &Mutex, - ev_tx: &UnboundedSender, - id: u64, - rpl: Rpl, - ) -> Result<(), Error> { - let mut conn = conn.lock().await; - let connected = match conn.connected_mut() { + async fn on_rpl(conn: &CoveConn, id: u64, rpl: Rpl) -> Result<(), Error> { + let mut state = conn.state.lock().await; + let connected = match state.connected_mut() { Some(connected) => connected, None => return Ok(()), }; @@ -300,18 +277,18 @@ impl CoveConnMt { match &rpl { Rpl::Room(RoomRpl::Success) => { connected.status = Status::IdRequired(None); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } Rpl::Room(RoomRpl::InvalidRoom { reason }) => { return Err(Error::InvalidRoom(reason.clone())) } Rpl::Identify(IdentifyRpl::Success { you, others, .. }) => { connected.status = Status::Present(Present::new(you, others)); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } Rpl::Identify(IdentifyRpl::InvalidNick { reason }) => { connected.status = Status::IdRequired(Some(reason.clone())); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } Rpl::Identify(IdentifyRpl::InvalidIdentity { reason }) => { return Err(Error::InvalidIdentity(reason.clone())) @@ -319,7 +296,7 @@ impl CoveConnMt { Rpl::Nick(NickRpl::Success { you }) => { if let Some(present) = connected.status.present_mut() { present.update_session(you); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } } Rpl::Nick(NickRpl::InvalidNick { reason }) => {} @@ -330,7 +307,7 @@ impl CoveConnMt { Rpl::Who(WhoRpl { you, others }) => { if let Some(present) = connected.status.present_mut() { present.update(you, others); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } } } @@ -340,13 +317,9 @@ impl CoveConnMt { Ok(()) } - async fn on_ntf( - conn: &Mutex, - ev_tx: &UnboundedSender, - ntf: Ntf, - ) -> Result<(), Error> { - let mut conn = conn.lock().await; - let connected = match conn.connected_mut() { + async fn on_ntf(conn: &CoveConn, ntf: Ntf) -> Result<(), Error> { + let mut state = conn.state.lock().await; + let connected = match state.connected_mut() { Some(connected) => connected, None => return Ok(()), }; @@ -355,19 +328,19 @@ impl CoveConnMt { Ntf::Join(JoinNtf { who }) => { if let Some(present) = connected.status.present_mut() { present.join(who); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } } Ntf::Nick(NickNtf { who }) => { if let Some(present) = connected.status.present_mut() { present.nick(who); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } } Ntf::Part(PartNtf { who }) => { if let Some(present) = connected.status.present_mut() { present.part(who); - ev_tx.send(Event::StateChanged); + conn.ev_tx.send(Event::StateChanged); } } Ntf::Send(SendNtf { message }) => { @@ -384,17 +357,16 @@ pub async fn new( room: String, timeout: Duration, ev_tx: UnboundedSender, -) -> (Arc>, CoveConnMt) { - let conn = Arc::new(Mutex::new(CoveConn { - state: State::Connecting, +) -> (CoveConn, CoveConnMt) { + let conn = CoveConn { + state: Arc::new(Mutex::new(State::Connecting)), ev_tx: ev_tx.clone(), - })); + }; let mt = CoveConnMt { url, room, timeout, conn, - ev_tx, }; (mt.conn.clone(), mt) }