Wrap state in Arc<Mutex<_>> instead of whole connection

This commit is contained in:
Joscha 2022-03-03 00:49:17 +01:00
parent dfb5ade023
commit 17d08438d1

View file

@ -8,7 +8,7 @@ use cove_core::packets::{
RoomRpl, Rpl, SendNtf, SendRpl, WhoRpl, RoomRpl, Rpl, SendNtf, SendRpl, WhoRpl,
}; };
use cove_core::{Session, SessionId}; use cove_core::{Session, SessionId};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::replies::{self, Replies}; use crate::replies::{self, Replies};
@ -144,36 +144,21 @@ impl State {
} }
} }
#[derive(Clone)]
pub struct CoveConn { pub struct CoveConn {
state: State, state: Arc<Mutex<State>>,
ev_tx: UnboundedSender<Event>, ev_tx: UnboundedSender<Event>,
} }
impl CoveConn { impl CoveConn {
pub fn state(&self) -> &State { async fn cmd<C, R>(&self, cmd: C) -> Result<R, Error>
&self.state
}
pub fn state_mut(&mut self) -> &mut State {
&mut self.state
}
pub fn connected(&self) -> Option<&Connected> {
self.state.connected()
}
pub fn connected_mut(&mut self) -> Option<&mut Connected> {
self.state.connected_mut()
}
async fn cmd<C, R>(conn: &Mutex<Self>, cmd: C) -> Result<R, Error>
where where
C: Into<Cmd>, C: Into<Cmd>,
Rpl: TryInto<R>, Rpl: TryInto<R>,
{ {
let pending_reply = { let pending_reply = {
let mut conn = conn.lock().await; let mut state = self.state.lock().await;
let mut connected = conn.connected_mut().ok_or(Error::NotConnected)?; let mut connected = state.connected_mut().ok_or(Error::NotConnected)?;
let id = connected.next_id; let id = connected.next_id;
connected.next_id += 1; connected.next_id += 1;
@ -193,13 +178,13 @@ impl CoveConn {
/// ///
/// This method is intended to be called whenever a CoveConn user suspects /// This method is intended to be called whenever a CoveConn user suspects
/// identification to be necessary. It has little overhead. /// identification to be necessary. It has little overhead.
pub async fn identify(conn: Arc<Mutex<Self>>, nick: &str, identity: &str) { pub async fn identify(&self, nick: &str, identity: &str) {
{ {
let mut conn = conn.lock().await; let mut state = self.state.lock().await;
if let Some(connected) = conn.connected_mut() { if let Some(connected) = state.connected_mut() {
if let Status::IdRequired(_) = connected.status { if let Status::IdRequired(_) = connected.status {
connected.status = Status::Identifying; connected.status = Status::Identifying;
conn.ev_tx.send(Event::StateChanged); self.ev_tx.send(Event::StateChanged);
} else { } else {
return; return;
} }
@ -208,13 +193,15 @@ impl CoveConn {
} }
} }
let conn = self.clone();
let nick = nick.to_string(); let nick = nick.to_string();
let identity = identity.to_string(); let identity = identity.to_string();
tokio::spawn(async move { tokio::spawn(async move {
// There's no need for a second locking block, or for us to see the // There's no need for a second locking block, or for us to see the
// result of this command. CoveConnMt::run will set the connection's // result of this command. CoveConnMt::run will set the connection's
// status as appropriate. // status as appropriate.
Self::cmd::<IdentifyCmd, IdentifyRpl>(&conn, IdentifyCmd { nick, identity }).await conn.cmd::<IdentifyCmd, IdentifyRpl>(IdentifyCmd { nick, identity })
.await
}); });
} }
} }
@ -224,8 +211,7 @@ pub struct CoveConnMt {
url: String, url: String,
room: String, room: String,
timeout: Duration, timeout: Duration,
conn: Arc<Mutex<CoveConn>>, conn: CoveConn,
ev_tx: UnboundedSender<Event>,
} }
impl CoveConnMt { impl CoveConnMt {
@ -233,23 +219,23 @@ impl CoveConnMt {
let (tx, rx, mt) = match Self::connect(&self.url, self.timeout).await { let (tx, rx, mt) = match Self::connect(&self.url, self.timeout).await {
Ok(conn) => conn, Ok(conn) => conn,
Err(e) => { Err(e) => {
*self.conn.lock().await.state_mut() = State::Stopped; *self.conn.state.lock().await = State::Stopped;
self.ev_tx.send(Event::StateChanged); self.conn.ev_tx.send(Event::StateChanged);
return Err(Error::CouldNotConnect(e)); return Err(Error::CouldNotConnect(e));
} }
}; };
*self.conn.lock().await.state_mut() = State::Connected(Connected::new(tx, self.timeout)); *self.conn.state.lock().await = State::Connected(Connected::new(tx, self.timeout));
self.ev_tx.send(Event::StateChanged); self.conn.ev_tx.send(Event::StateChanged);
tokio::spawn(Self::join_room(self.conn.clone(), self.room)); tokio::spawn(Self::join_room(self.conn.clone(), self.room));
let result = tokio::select! { let result = tokio::select! {
result = Self::recv(&self.conn, &self.ev_tx, rx) => result, result = Self::recv(&self.conn, rx) => result,
_ = mt.perform() => Err(Error::MaintenanceAborted), _ = mt.perform() => Err(Error::MaintenanceAborted),
}; };
*self.conn.lock().await.state_mut() = State::Stopped; *self.conn.state.lock().await = State::Stopped;
self.ev_tx.send(Event::StateChanged); self.conn.ev_tx.send(Event::StateChanged);
result result
} }
@ -263,36 +249,27 @@ impl CoveConnMt {
Ok(conn) Ok(conn)
} }
async fn join_room(conn: Arc<Mutex<CoveConn>>, name: String) -> Result<(), Error> { async fn join_room(conn: CoveConn, name: String) -> Result<(), Error> {
let reply: RoomRpl = CoveConn::cmd(&conn, RoomCmd { name }).await?; let reply: RoomRpl = conn.cmd(RoomCmd { name }).await?;
Ok(()) Ok(())
} }
async fn recv( async fn recv(conn: &CoveConn, mut rx: ConnRx) -> Result<(), Error> {
conn: &Mutex<CoveConn>,
ev_tx: &UnboundedSender<Event>,
mut rx: ConnRx,
) -> Result<(), Error> {
while let Some(packet) = rx.recv().await? { while let Some(packet) = rx.recv().await? {
match packet { match packet {
Packet::Cmd { id, cmd } => { Packet::Cmd { id, cmd } => {
// Ignore commands as the server doesn't send them. // Ignore commands as the server doesn't send them.
} }
Packet::Rpl { id, rpl } => Self::on_rpl(&conn, &ev_tx, id, rpl).await?, Packet::Rpl { id, rpl } => Self::on_rpl(conn, id, rpl).await?,
Packet::Ntf { ntf } => Self::on_ntf(&conn, &ev_tx, ntf).await?, Packet::Ntf { ntf } => Self::on_ntf(conn, ntf).await?,
} }
} }
Ok(()) Ok(())
} }
async fn on_rpl( async fn on_rpl(conn: &CoveConn, id: u64, rpl: Rpl) -> Result<(), Error> {
conn: &Mutex<CoveConn>, let mut state = conn.state.lock().await;
ev_tx: &UnboundedSender<Event>, let connected = match state.connected_mut() {
id: u64,
rpl: Rpl,
) -> Result<(), Error> {
let mut conn = conn.lock().await;
let connected = match conn.connected_mut() {
Some(connected) => connected, Some(connected) => connected,
None => return Ok(()), None => return Ok(()),
}; };
@ -300,18 +277,18 @@ impl CoveConnMt {
match &rpl { match &rpl {
Rpl::Room(RoomRpl::Success) => { Rpl::Room(RoomRpl::Success) => {
connected.status = Status::IdRequired(None); connected.status = Status::IdRequired(None);
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
Rpl::Room(RoomRpl::InvalidRoom { reason }) => { Rpl::Room(RoomRpl::InvalidRoom { reason }) => {
return Err(Error::InvalidRoom(reason.clone())) return Err(Error::InvalidRoom(reason.clone()))
} }
Rpl::Identify(IdentifyRpl::Success { you, others, .. }) => { Rpl::Identify(IdentifyRpl::Success { you, others, .. }) => {
connected.status = Status::Present(Present::new(you, others)); connected.status = Status::Present(Present::new(you, others));
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
Rpl::Identify(IdentifyRpl::InvalidNick { reason }) => { Rpl::Identify(IdentifyRpl::InvalidNick { reason }) => {
connected.status = Status::IdRequired(Some(reason.clone())); connected.status = Status::IdRequired(Some(reason.clone()));
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
Rpl::Identify(IdentifyRpl::InvalidIdentity { reason }) => { Rpl::Identify(IdentifyRpl::InvalidIdentity { reason }) => {
return Err(Error::InvalidIdentity(reason.clone())) return Err(Error::InvalidIdentity(reason.clone()))
@ -319,7 +296,7 @@ impl CoveConnMt {
Rpl::Nick(NickRpl::Success { you }) => { Rpl::Nick(NickRpl::Success { you }) => {
if let Some(present) = connected.status.present_mut() { if let Some(present) = connected.status.present_mut() {
present.update_session(you); present.update_session(you);
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
} }
Rpl::Nick(NickRpl::InvalidNick { reason }) => {} Rpl::Nick(NickRpl::InvalidNick { reason }) => {}
@ -330,7 +307,7 @@ impl CoveConnMt {
Rpl::Who(WhoRpl { you, others }) => { Rpl::Who(WhoRpl { you, others }) => {
if let Some(present) = connected.status.present_mut() { if let Some(present) = connected.status.present_mut() {
present.update(you, others); present.update(you, others);
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
} }
} }
@ -340,13 +317,9 @@ impl CoveConnMt {
Ok(()) Ok(())
} }
async fn on_ntf( async fn on_ntf(conn: &CoveConn, ntf: Ntf) -> Result<(), Error> {
conn: &Mutex<CoveConn>, let mut state = conn.state.lock().await;
ev_tx: &UnboundedSender<Event>, let connected = match state.connected_mut() {
ntf: Ntf,
) -> Result<(), Error> {
let mut conn = conn.lock().await;
let connected = match conn.connected_mut() {
Some(connected) => connected, Some(connected) => connected,
None => return Ok(()), None => return Ok(()),
}; };
@ -355,19 +328,19 @@ impl CoveConnMt {
Ntf::Join(JoinNtf { who }) => { Ntf::Join(JoinNtf { who }) => {
if let Some(present) = connected.status.present_mut() { if let Some(present) = connected.status.present_mut() {
present.join(who); present.join(who);
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
} }
Ntf::Nick(NickNtf { who }) => { Ntf::Nick(NickNtf { who }) => {
if let Some(present) = connected.status.present_mut() { if let Some(present) = connected.status.present_mut() {
present.nick(who); present.nick(who);
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
} }
Ntf::Part(PartNtf { who }) => { Ntf::Part(PartNtf { who }) => {
if let Some(present) = connected.status.present_mut() { if let Some(present) = connected.status.present_mut() {
present.part(who); present.part(who);
ev_tx.send(Event::StateChanged); conn.ev_tx.send(Event::StateChanged);
} }
} }
Ntf::Send(SendNtf { message }) => { Ntf::Send(SendNtf { message }) => {
@ -384,17 +357,16 @@ pub async fn new(
room: String, room: String,
timeout: Duration, timeout: Duration,
ev_tx: UnboundedSender<Event>, ev_tx: UnboundedSender<Event>,
) -> (Arc<Mutex<CoveConn>>, CoveConnMt) { ) -> (CoveConn, CoveConnMt) {
let conn = Arc::new(Mutex::new(CoveConn { let conn = CoveConn {
state: State::Connecting, state: Arc::new(Mutex::new(State::Connecting)),
ev_tx: ev_tx.clone(), ev_tx: ev_tx.clone(),
})); };
let mt = CoveConnMt { let mt = CoveConnMt {
url, url,
room, room,
timeout, timeout,
conn, conn,
ev_tx,
}; };
(mt.conn.clone(), mt) (mt.conn.clone(), mt)
} }