Remove instance id generics

Instead, all instance ids are now usize (like message ids). This allows
me to enforce the fact that no two instances of a Bot must have the same
id by generating the ids in the Bot.

Reusing the same id for multiple instances that send their events to the
same place can lead to race conditions depending on how events are
handled. For example, the old instance might still be shutting down
while the new instance is already connected to a room, leading to an
InstanceEvent::Stopped from the old instance that seemingly applies to
the new instance.
This commit is contained in:
Joscha 2024-12-27 14:42:35 +01:00
parent 17ff660ab2
commit 8377695529
4 changed files with 68 additions and 83 deletions

View file

@ -77,7 +77,7 @@ async fn run() -> anyhow::Result<()> {
.instance("test") .instance("test")
.with_username("examplebot"); .with_username("examplebot");
bot.add_instance((), config); bot.add_instance(config);
while let Some(event) = bot.recv().await { while let Some(event) = bot.recv().await {
if let BotEvent::Packet { conn, packet, .. } = event { if let BotEvent::Packet { conn, packet, .. } = event {

View file

@ -77,7 +77,7 @@ async fn run() -> anyhow::Result<()> {
.with_username("examplebot"); .with_username("examplebot");
let (event_tx, mut event_rx) = mpsc::channel(10); let (event_tx, mut event_rx) = mpsc::channel(10);
let _instance = Instance::new((), config, event_tx); // Don't drop or instance stops let _instance = Instance::new(0, config, event_tx); // Don't drop or instance stops
while let Some(event) = event_rx.recv().await { while let Some(event) = event_rx.recv().await {
if let InstanceEvent::Packet { conn, packet, .. } = event { if let InstanceEvent::Packet { conn, packet, .. } = event {

View file

@ -1,6 +1,5 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt, hash,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
@ -13,39 +12,39 @@ use tokio::sync::mpsc;
use crate::{BotConfig, Instance, InstanceConfig, InstanceEvent}; use crate::{BotConfig, Instance, InstanceConfig, InstanceEvent};
#[derive(Debug)] #[derive(Debug)]
pub enum BotEvent<I> { pub enum BotEvent {
Started { Started {
instance: Instance<I>, instance: Instance,
}, },
Connecting { Connecting {
instance: Instance<I>, instance: Instance,
}, },
Connected { Connected {
instance: Instance<I>, instance: Instance,
conn: ClientConnHandle, conn: ClientConnHandle,
state: State, state: State,
}, },
Joined { Joined {
instance: Instance<I>, instance: Instance,
conn: ClientConnHandle, conn: ClientConnHandle,
state: State, state: State,
}, },
Packet { Packet {
instance: Instance<I>, instance: Instance,
conn: ClientConnHandle, conn: ClientConnHandle,
state: State, state: State,
packet: ParsedPacket, packet: ParsedPacket,
}, },
Disconnected { Disconnected {
instance: Instance<I>, instance: Instance,
}, },
Stopped { Stopped {
instance: Instance<I>, instance: Instance,
}, },
} }
impl<I> BotEvent<I> { impl BotEvent {
fn from_instance_event(instance: Instance<I>, event: InstanceEvent<I>) -> Self { fn from_instance_event(instance: Instance, event: InstanceEvent) -> Self {
match event { match event {
InstanceEvent::Started { id: _ } => Self::Started { instance }, InstanceEvent::Started { id: _ } => Self::Started { instance },
InstanceEvent::Connecting { id: _ } => Self::Connecting { instance }, InstanceEvent::Connecting { id: _ } => Self::Connecting { instance },
@ -75,7 +74,7 @@ impl<I> BotEvent<I> {
} }
} }
pub fn instance(&self) -> &Instance<I> { pub fn instance(&self) -> &Instance {
match self { match self {
Self::Started { instance } => instance, Self::Started { instance } => instance,
Self::Connecting { instance, .. } => instance, Self::Connecting { instance, .. } => instance,
@ -88,14 +87,15 @@ impl<I> BotEvent<I> {
} }
} }
pub struct Bot<I> { pub struct Bot {
config: BotConfig, config: BotConfig,
instances: Arc<RwLock<HashMap<I, Instance<I>>>>, next_id: usize,
event_tx: mpsc::Sender<InstanceEvent<I>>, instances: Arc<RwLock<HashMap<usize, Instance>>>,
event_rx: mpsc::Receiver<InstanceEvent<I>>, event_tx: mpsc::Sender<InstanceEvent>,
event_rx: mpsc::Receiver<InstanceEvent>,
} }
impl<I> Bot<I> { impl Bot {
pub fn new() -> Self { pub fn new() -> Self {
Self::new_with_config(BotConfig::default()) Self::new_with_config(BotConfig::default())
} }
@ -104,6 +104,7 @@ impl<I> Bot<I> {
let (event_tx, event_rx) = mpsc::channel(10); let (event_tx, event_rx) = mpsc::channel(10);
Self { Self {
config, config,
next_id: 0,
instances: Arc::new(RwLock::new(HashMap::new())), instances: Arc::new(RwLock::new(HashMap::new())),
event_tx, event_tx,
event_rx, event_rx,
@ -115,30 +116,24 @@ impl<I> Bot<I> {
guard.retain(|_, v| !v.stopped()); guard.retain(|_, v| !v.stopped());
} }
pub fn get_instances(&self) -> Vec<Instance<I>> pub fn get_instances(&self) -> Vec<Instance> {
where
I: Clone,
{
self.instances.read().unwrap().values().cloned().collect() self.instances.read().unwrap().values().cloned().collect()
} }
pub fn add_instance(&self, id: I, config: InstanceConfig) pub fn add_instance(&mut self, config: InstanceConfig) -> Instance {
where let id = self.next_id;
I: Clone + fmt::Debug + Send + 'static + Eq + hash::Hash, self.next_id += 1;
{
let mut guard = self.instances.write().unwrap(); let mut guard = self.instances.write().unwrap();
assert!(!guard.contains_key(&id));
if guard.contains_key(&id) { let instance = Instance::new(id, config, self.event_tx.clone());
return; guard.insert(id, instance.clone());
instance
} }
guard.insert(id.clone(), Instance::new(id, config, self.event_tx.clone())); pub async fn recv(&mut self) -> Option<BotEvent> {
}
pub async fn recv(&mut self) -> Option<BotEvent<I>>
where
I: Clone + Eq + hash::Hash,
{
// We hold exactly one sender. If no other senders exist, then all // We hold exactly one sender. If no other senders exist, then all
// instances are dead and we'll never receive any more events unless we // instances are dead and we'll never receive any more events unless we
// return and allow the user to add more instances again. // return and allow the user to add more instances again.
@ -159,7 +154,7 @@ impl<I> Bot<I> {
// own one sender, this can't happen. // own one sender, this can't happen.
let event = event.expect("event channel should never close since we own a sender"); let event = event.expect("event channel should never close since we own a sender");
if let Some(instance) = self.instances.read().unwrap().get(event.id()) { if let Some(instance) = self.instances.read().unwrap().get(&event.id()) {
return Some(BotEvent::from_instance_event(instance.clone(), event)); return Some(BotEvent::from_instance_event(instance.clone(), event));
} }
} }
@ -168,7 +163,7 @@ impl<I> Bot<I> {
} }
} }
impl<I> Default for Bot<I> { impl Default for Bot {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }

View file

@ -72,63 +72,63 @@ enum Command {
} }
#[derive(Debug)] #[derive(Debug)]
pub enum InstanceEvent<I> { pub enum InstanceEvent {
Started { Started {
id: I, id: usize,
}, },
Connecting { Connecting {
id: I, id: usize,
}, },
Connected { Connected {
id: I, id: usize,
conn: ClientConnHandle, conn: ClientConnHandle,
state: State, state: State,
}, },
Joined { Joined {
id: I, id: usize,
conn: ClientConnHandle, conn: ClientConnHandle,
state: State, state: State,
}, },
Packet { Packet {
id: I, id: usize,
conn: ClientConnHandle, conn: ClientConnHandle,
state: State, state: State,
packet: ParsedPacket, packet: ParsedPacket,
}, },
Disconnected { Disconnected {
id: I, id: usize,
}, },
Stopped { Stopped {
id: I, id: usize,
}, },
} }
impl<I> InstanceEvent<I> { impl InstanceEvent {
pub fn id(&self) -> &I { pub fn id(&self) -> usize {
match self { match self {
Self::Started { id } => id, Self::Started { id } => *id,
Self::Connecting { id } => id, Self::Connecting { id } => *id,
Self::Connected { id, .. } => id, Self::Connected { id, .. } => *id,
Self::Joined { id, .. } => id, Self::Joined { id, .. } => *id,
Self::Packet { id, .. } => id, Self::Packet { id, .. } => *id,
Self::Disconnected { id } => id, Self::Disconnected { id } => *id,
Self::Stopped { id } => id, Self::Stopped { id } => *id,
} }
} }
} }
struct InstanceTask<I> { struct InstanceTask {
id: I, id: usize,
config: InstanceConfig, config: InstanceConfig,
cmd_rx: mpsc::Receiver<Command>, cmd_rx: mpsc::Receiver<Command>,
event_tx: mpsc::Sender<InstanceEvent<I>>, event_tx: mpsc::Sender<InstanceEvent>,
attempts: usize, attempts: usize,
never_joined: bool, never_joined: bool,
} }
impl<I: Clone + fmt::Debug> InstanceTask<I> { impl InstanceTask {
fn get_cookies(&self) -> Option<HeaderValue> { fn get_cookies(&self) -> Option<HeaderValue> {
self.config self.config
.server .server
@ -173,7 +173,7 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Joined { .send(InstanceEvent::Joined {
id: self.id.clone(), id: self.id,
conn: conn.handle(), conn: conn.handle(),
state: conn.state().clone(), state: conn.state().clone(),
}) })
@ -184,7 +184,7 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Packet { .send(InstanceEvent::Packet {
id: self.id.clone(), id: self.id,
conn: conn.handle(), conn: conn.handle(),
state: conn.state().clone(), state: conn.state().clone(),
packet: packet.clone(), packet: packet.clone(),
@ -261,9 +261,7 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Connecting { .send(InstanceEvent::Connecting { id: self.id })
id: self.id.clone(),
})
.await; .await;
let mut conn = match self.connect().await { let mut conn = match self.connect().await {
@ -283,7 +281,7 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Connected { .send(InstanceEvent::Connected {
id: self.id.clone(), id: self.id,
conn: conn.handle(), conn: conn.handle(),
state: conn.state().clone(), state: conn.state().clone(),
}) })
@ -307,9 +305,7 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Disconnected { .send(InstanceEvent::Disconnected { id: self.id })
id: self.id.clone(),
})
.await; .await;
result result
@ -318,9 +314,7 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
async fn run(mut self) { async fn run(mut self) {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Started { .send(InstanceEvent::Started { id: self.id })
id: self.id.clone(),
})
.await; .await;
loop { loop {
@ -334,20 +328,18 @@ impl<I: Clone + fmt::Debug> InstanceTask<I> {
let _ = self let _ = self
.event_tx .event_tx
.send(InstanceEvent::Stopped { .send(InstanceEvent::Stopped { id: self.id })
id: self.id.clone(),
})
.await; .await;
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Instance<I> { pub struct Instance {
id: I, id: usize,
cmd_tx: mpsc::Sender<Command>, cmd_tx: mpsc::Sender<Command>,
} }
impl<I: fmt::Debug> fmt::Debug for Instance<I> { impl fmt::Debug for Instance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Instance") f.debug_struct("Instance")
.field("id", &self.id) .field("id", &self.id)
@ -355,12 +347,12 @@ impl<I: fmt::Debug> fmt::Debug for Instance<I> {
} }
} }
impl<I: Clone + fmt::Debug + Send + 'static> Instance<I> { impl Instance {
pub fn new(id: I, config: InstanceConfig, event_tx: mpsc::Sender<InstanceEvent<I>>) -> Self { pub fn new(id: usize, config: InstanceConfig, event_tx: mpsc::Sender<InstanceEvent>) -> Self {
let (cmd_tx, cmd_rx) = mpsc::channel(config.server.cmd_channel_bufsize); let (cmd_tx, cmd_rx) = mpsc::channel(config.server.cmd_channel_bufsize);
let task = InstanceTask { let task = InstanceTask {
id: id.clone(), id,
config, config,
attempts: 0, attempts: 0,
never_joined: false, never_joined: false,
@ -372,11 +364,9 @@ impl<I: Clone + fmt::Debug + Send + 'static> Instance<I> {
Self { id, cmd_tx } Self { id, cmd_tx }
} }
}
impl<I> Instance<I> { pub fn id(&self) -> usize {
pub fn id(&self) -> &I { self.id
&self.id
} }
pub fn stopped(&self) -> bool { pub fn stopped(&self) -> bool {