diff --git a/Cargo.lock b/Cargo.lock index c95586d..20b9f53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -203,6 +203,7 @@ dependencies = [ "toss", "unicode-segmentation", "unicode-width", + "vault", ] [[package]] @@ -1392,6 +1393,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "vault" +version = "0.1.0" +source = "git+https://github.com/Garmelon/vault.git?tag=v0.1.0#028c72cac4e84bfbbf9fb03b15acb59989a31df9" +dependencies = [ + "rusqlite", + "tokio", +] + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index db7ea53..4fd5a56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,3 +47,11 @@ rev = "0d59116012a51516a821991e2969b1cf4779770f" # [patch."https://github.com/Garmelon/toss.git"] # toss = { path = "../toss/" } + +[dependencies.vault] +git = "https://github.com/Garmelon/vault.git" +tag = "v0.1.0" +features = ["tokio"] + +# [patch."https://github.com/Garmelon/vault.git"] +# vault = { path = "../vault/" } diff --git a/src/euph/room.rs b/src/euph/room.rs index 8eeb8e5..c26685f 100644 --- a/src/euph/room.rs +++ b/src/euph/room.rs @@ -14,7 +14,7 @@ use log::{debug, error, info, warn}; use tokio::select; use tokio::sync::oneshot; -use crate::macros::ok_or_return; +use crate::macros::{logging_unwrap, ok_or_return}; use crate::vault::EuphRoomVault; const LOG_INTERVAL: Duration = Duration::from_secs(10); @@ -93,7 +93,7 @@ impl Room { self.state.conn_tx().ok_or(Error::NotConnected) } - pub fn handle_event(&mut self, event: Event) { + pub async fn handle_event(&mut self, event: Event) { match event { Event::Connecting(_) => { self.state = State::Connecting; @@ -121,11 +121,11 @@ impl Room { let cookies = &*self.instance.config().server.cookies; let cookies = cookies.lock().unwrap().clone(); - self.vault.vault().set_cookies(cookies); + logging_unwrap!(self.vault.vault().set_cookies(cookies).await); } Event::Packet(_, packet, Snapshot { conn_tx, state }) => { self.state = State::Connected(conn_tx, state); - self.on_packet(packet); + self.on_packet(packet).await; } Event::Disconnected(_) => { self.state = State::Disconnected; @@ -173,7 +173,7 @@ impl Room { } async fn request_logs(vault: &EuphRoomVault, conn_tx: &ConnTx) { - let before = match vault.last_span().await { + let before = match logging_unwrap!(vault.last_span().await) { Some((None, _)) => return, // Already at top of room history Some((Some(before), _)) => Some(before), None => None, @@ -203,7 +203,7 @@ impl Room { } } - fn on_packet(&mut self, packet: ParsedPacket) { + async fn on_packet(&mut self, packet: ParsedPacket) { let instance_name = &self.instance.config().name; let data = ok_or_return!(&packet.content); match data { @@ -238,26 +238,39 @@ impl Room { Data::SendEvent(SendEvent(msg)) => { let own_user_id = self.own_user_id(); if let Some(last_msg_id) = &mut self.last_msg_id { - self.vault - .add_msg(Box::new(msg.clone()), *last_msg_id, own_user_id); + logging_unwrap!( + self.vault + .add_msg(Box::new(msg.clone()), *last_msg_id, own_user_id) + .await + ); *last_msg_id = Some(msg.id); } } Data::SnapshotEvent(d) => { info!("{instance_name}: successfully joined"); - self.vault.join(Time::now()); + logging_unwrap!(self.vault.join(Time::now()).await); self.last_msg_id = Some(d.log.last().map(|m| m.id)); - self.vault.add_msgs(d.log.clone(), None, self.own_user_id()); + logging_unwrap!( + self.vault + .add_msgs(d.log.clone(), None, self.own_user_id()) + .await + ); } Data::LogReply(d) => { - self.vault - .add_msgs(d.log.clone(), d.before, self.own_user_id()); + logging_unwrap!( + self.vault + .add_msgs(d.log.clone(), d.before, self.own_user_id()) + .await + ); } Data::SendReply(SendReply(msg)) => { let own_user_id = self.own_user_id(); if let Some(last_msg_id) = &mut self.last_msg_id { - self.vault - .add_msg(Box::new(msg.clone()), *last_msg_id, own_user_id); + logging_unwrap!( + self.vault + .add_msg(Box::new(msg.clone()), *last_msg_id, own_user_id) + .await + ); *last_msg_id = Some(msg.id); } } diff --git a/src/export.rs b/src/export.rs index 15db9b7..545f48b 100644 --- a/src/export.rs +++ b/src/export.rs @@ -85,7 +85,7 @@ pub async fn export(vault: &EuphVault, mut args: Args) -> anyhow::Result<()> { } let rooms = if args.all { - let mut rooms = vault.rooms().await; + let mut rooms = vault.rooms().await?; rooms.sort_unstable(); rooms } else { diff --git a/src/export/json.rs b/src/export/json.rs index 8921229..258b7bd 100644 --- a/src/export/json.rs +++ b/src/export/json.rs @@ -10,7 +10,7 @@ pub async fn export(vault: &EuphRoomVault, file: &mut W) -> anyhow::Re let mut total = 0; let mut offset = 0; loop { - let messages = vault.chunk_at_offset(CHUNK_SIZE, offset).await; + let messages = vault.chunk_at_offset(CHUNK_SIZE, offset).await?; offset += messages.len(); if messages.is_empty() { @@ -42,7 +42,7 @@ pub async fn export_stream(vault: &EuphRoomVault, file: &mut W) -> any let mut total = 0; let mut offset = 0; loop { - let messages = vault.chunk_at_offset(CHUNK_SIZE, offset).await; + let messages = vault.chunk_at_offset(CHUNK_SIZE, offset).await?; offset += messages.len(); if messages.is_empty() { diff --git a/src/export/text.rs b/src/export/text.rs index 0cdd138..bb3cfa1 100644 --- a/src/export/text.rs +++ b/src/export/text.rs @@ -16,11 +16,11 @@ const TIME_EMPTY: &str = " "; pub async fn export(vault: &EuphRoomVault, out: &mut W) -> anyhow::Result<()> { let mut exported_trees = 0; let mut exported_msgs = 0; - let mut root_id = vault.first_root_id().await; + let mut root_id = vault.first_root_id().await?; while let Some(some_root_id) = root_id { - let tree = vault.tree(some_root_id).await; + let tree = vault.tree(some_root_id).await?; write_tree(out, &tree, some_root_id, 0)?; - root_id = vault.next_root_id(some_root_id).await; + root_id = vault.next_root_id(some_root_id).await?; exported_trees += 1; exported_msgs += tree.len(); diff --git a/src/macros.rs b/src/macros.rs index 36372d7..3e03b07 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -29,3 +29,17 @@ macro_rules! ok_or_return { }; } pub(crate) use ok_or_return; + +// TODO Get rid of this macro as much as possible +macro_rules! logging_unwrap { + ($e:expr) => { + match $e { + Ok(value) => value, + Err(err) => { + log::error!("{err}"); + panic!("{err}"); + } + } + }; +} +pub(crate) use logging_unwrap; diff --git a/src/main.rs b/src/main.rs index 1c6071b..6023c76 100644 --- a/src/main.rs +++ b/src/main.rs @@ -149,11 +149,11 @@ async fn main() -> anyhow::Result<()> { Command::Gc => { eprintln!("Cleaning up and compacting vault"); eprintln!("This may take a while..."); - vault.gc().await; + vault.gc().await?; } Command::ClearCookies => { eprintln!("Clearing cookies"); - vault.euph().set_cookies(CookieJar::new()); + vault.euph().set_cookies(CookieJar::new()).await?; } } diff --git a/src/ui.rs b/src/ui.rs index 8ff3e13..5d9007e 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -10,7 +10,6 @@ use std::io; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; -use log::error; use parking_lot::FairMutex; use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; @@ -19,7 +18,7 @@ use toss::terminal::Terminal; use crate::config::Config; use crate::logger::{LogMsg, Logger}; -use crate::macros::{ok_or_return, some_or_return}; +use crate::macros::{logging_unwrap, ok_or_return, some_or_return}; use crate::vault::Vault; pub use self::chat::ChatMsg; @@ -231,7 +230,7 @@ impl Ui { .await } UiEvent::Euph(event) => { - if self.rooms.handle_euph_event(event) { + if self.rooms.handle_euph_event(event).await { EventHandleResult::Redraw } else { EventHandleResult::Continue @@ -288,17 +287,11 @@ impl Ui { .await } Mode::Log => { - let reaction = match self + let reaction = self .log_chat .handle_input_event(terminal, crossterm_lock, &event, false) - .await - { - Ok(reaction) => reaction, - Err(err) => { - error!("{err}"); - panic!("{err}"); - } - }; + .await; + let reaction = logging_unwrap!(reaction); reaction.handled() } }; diff --git a/src/ui/chat/tree.rs b/src/ui/chat/tree.rs index 9325d3e..ad8e6ab 100644 --- a/src/ui/chat/tree.rs +++ b/src/ui/chat/tree.rs @@ -10,12 +10,12 @@ use std::fmt; use std::sync::Arc; use async_trait::async_trait; -use log::error; use parking_lot::FairMutex; use tokio::sync::Mutex; use toss::frame::{Frame, Pos, Size}; use toss::terminal::Terminal; +use crate::macros::logging_unwrap; use crate::store::{Msg, MsgStore}; use crate::ui::input::{key, InputEvent, KeyBindingsList}; use crate::ui::util; @@ -439,13 +439,7 @@ where async fn render(self: Box, frame: &mut Frame) { let mut guard = self.inner.lock().await; - let blocks = match guard.relayout(self.nick, self.focused, frame).await { - Ok(blocks) => blocks, - Err(err) => { - error!("{err}"); - panic!("{err}"); - } - }; + let blocks = logging_unwrap!(guard.relayout(self.nick, self.focused, frame).await); let size = frame.size(); for block in blocks.into_blocks().blocks { diff --git a/src/ui/euph/room.rs b/src/ui/euph/room.rs index 4572cf6..59a113a 100644 --- a/src/ui/euph/room.rs +++ b/src/ui/euph/room.rs @@ -5,7 +5,6 @@ use crossterm::style::{ContentStyle, Stylize}; use euphoxide::api::{Data, Message, MessageId, PacketType, SessionId}; use euphoxide::bot::instance::{Event, ServerConfig}; use euphoxide::conn::{self, Joined, Joining, SessionInfo}; -use log::error; use parking_lot::FairMutex; use tokio::sync::oneshot::error::TryRecvError; use tokio::sync::{mpsc, oneshot}; @@ -14,6 +13,7 @@ use toss::terminal::Terminal; use crate::config; use crate::euph; +use crate::macros::logging_unwrap; use crate::ui::chat::{ChatState, Reaction}; use crate::ui::input::{key, InputEvent, KeyBindingsList}; use crate::ui::widgets::border::Border; @@ -143,7 +143,7 @@ impl EuphRoom { } pub async fn unseen_msgs_count(&self) -> usize { - self.vault().unseen_msgs_count().await + logging_unwrap!(self.vault().unseen_msgs_count().await) } async fn stabilize_pseudo_msg(&mut self) { @@ -327,17 +327,11 @@ impl EuphRoom { Some(euph::State::Connected(_, conn::State::Joined(_))) ); - let reaction = match self + let reaction = self .chat .handle_input_event(terminal, crossterm_lock, event, can_compose) - .await - { - Ok(reaction) => reaction, - Err(err) => { - error!("{err}"); - panic!("{err}"); - } - }; + .await; + let reaction = logging_unwrap!(reaction); match reaction { Reaction::NotHandled => {} @@ -434,7 +428,7 @@ impl EuphRoom { match event { key!('i') => { if let Some(id) = self.chat.cursor().await { - if let Some(msg) = self.vault().full_msg(id).await { + if let Some(msg) = logging_unwrap!(self.vault().full_msg(id).await) { self.state = State::InspectMessage(msg); } } @@ -442,7 +436,7 @@ impl EuphRoom { } key!('I') => { if let Some(id) = self.chat.cursor().await { - if let Some(msg) = self.vault().msg(id).await { + if let Some(msg) = logging_unwrap!(self.vault().msg(id).await) { self.state = State::Links(LinksState::new(&msg.content)); } } @@ -679,7 +673,7 @@ impl EuphRoom { } } - pub fn handle_event(&mut self, event: Event) -> bool { + pub async fn handle_event(&mut self, event: Event) -> bool { let handled = if self.room.is_some() { if let Event::Packet(_, packet, _) = &event { match &packet.content { @@ -694,7 +688,7 @@ impl EuphRoom { }; if let Some(room) = &mut self.room { - room.handle_event(event); + room.handle_event(event).await; } handled diff --git a/src/ui/rooms.rs b/src/ui/rooms.rs index 07ef09a..56dac95 100644 --- a/src/ui/rooms.rs +++ b/src/ui/rooms.rs @@ -13,6 +13,7 @@ use toss::terminal::Terminal; use crate::config::{Config, RoomsSortOrder}; use crate::euph; +use crate::macros::logging_unwrap; use crate::vault::Vault; use super::euph::room::EuphRoom; @@ -69,8 +70,8 @@ impl Rooms { vault: Vault, ui_event_tx: mpsc::UnboundedSender, ) -> Self { - let euph_server_config = - ServerConfig::default().cookies(Arc::new(Mutex::new(vault.euph().cookies().await))); + let cookies = logging_unwrap!(vault.euph().cookies().await); + let euph_server_config = ServerConfig::default().cookies(Arc::new(Mutex::new(cookies))); let mut result = Self { config, @@ -112,13 +113,8 @@ impl Rooms { /// - failed connection attempts, or /// - rooms that were deleted from the db. async fn stabilize_rooms(&mut self) { - let mut rooms_set = self - .vault - .euph() - .rooms() - .await - .into_iter() - .collect::>(); + let rooms = logging_unwrap!(self.vault.euph().rooms().await); + let mut rooms_set = rooms.into_iter().collect::>(); // Prevent room that is currently being shown from being removed. This // could otherwise happen when connecting to a room that doesn't exist. @@ -533,7 +529,7 @@ impl Rooms { } key!(Enter) if editor.text() == *name => { self.euph_rooms.remove(name); - self.vault.euph().delete(name.clone()); + logging_unwrap!(self.vault.euph().room(name.clone()).delete().await); self.state = State::ShowList; return true; } @@ -548,10 +544,10 @@ impl Rooms { false } - pub fn handle_euph_event(&mut self, event: Event) -> bool { + pub async fn handle_euph_event(&mut self, event: Event) -> bool { let instance_name = event.config().name.clone(); let room = self.get_or_insert_room(instance_name.clone()); - let handled = room.handle_event(event); + let handled = room.handle_event(event).await; let room_visible = match &self.state { State::ShowRoom(name) => *name == instance_name, diff --git a/src/vault.rs b/src/vault.rs index 9352e7d..4f49e45 100644 --- a/src/vault.rs +++ b/src/vault.rs @@ -2,43 +2,42 @@ mod euph; mod migrate; mod prepare; +use std::fs; use std::path::Path; -use std::{fs, thread}; -use log::error; use rusqlite::Connection; -use tokio::sync::{mpsc, oneshot}; +use vault::tokio::TokioVault; +use vault::Action; -use self::euph::EuphRequest; pub use self::euph::{EuphRoomVault, EuphVault}; -enum Request { - Close(oneshot::Sender<()>), - Gc(oneshot::Sender<()>), - Euph(EuphRequest), -} - #[derive(Debug, Clone)] pub struct Vault { - tx: mpsc::UnboundedSender, + tokio_vault: TokioVault, ephemeral: bool, } +struct GcAction; + +impl Action for GcAction { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + conn.execute_batch("ANALYZE; VACUUM;") + } +} + impl Vault { pub fn ephemeral(&self) -> bool { self.ephemeral } pub async fn close(&self) { - let (tx, rx) = oneshot::channel(); - let _ = self.tx.send(Request::Close(tx)); - let _ = rx.await; + self.tokio_vault.stop().await; } - pub async fn gc(&self) { - let (tx, rx) = oneshot::channel(); - let _ = self.tx.send(Request::Gc(tx)); - let _ = rx.await; + pub async fn gc(&self) -> vault::tokio::Result<()> { + self.tokio_vault.execute(GcAction).await } pub fn euph(&self) -> EuphVault { @@ -46,47 +45,17 @@ impl Vault { } } -fn run(mut conn: Connection, mut rx: mpsc::UnboundedReceiver) { - while let Some(request) = rx.blocking_recv() { - match request { - Request::Close(tx) => { - eprintln!("Closing vault"); - if let Err(e) = conn.execute_batch("PRAGMA optimize") { - error!("{e}"); - } - // Ensure `Vault::close` exits only after the sqlite connection - // has been closed properly. - drop(conn); - drop(tx); - break; - } - Request::Gc(tx) => { - if let Err(e) = conn.execute_batch("ANALYZE; VACUUM;") { - error!("{e}"); - } - drop(tx); - } - Request::Euph(r) => { - if let Err(e) = r.perform(&mut conn) { - error!("{e}"); - } - } - } - } -} - -fn launch_from_connection(mut conn: Connection, ephemeral: bool) -> rusqlite::Result { +fn launch_from_connection(conn: Connection, ephemeral: bool) -> rusqlite::Result { conn.pragma_update(None, "foreign_keys", true)?; conn.pragma_update(None, "trusted_schema", false)?; eprintln!("Opening vault"); - migrate::migrate(&mut conn)?; - prepare::prepare(&mut conn)?; - - let (tx, rx) = mpsc::unbounded_channel(); - thread::spawn(move || run(conn, rx)); - Ok(Vault { tx, ephemeral }) + let tokio_vault = TokioVault::launch_and_prepare(conn, &migrate::MIGRATIONS, prepare::prepare)?; + Ok(Vault { + tokio_vault, + ephemeral, + }) } pub fn launch(path: &Path) -> rusqlite::Result { diff --git a/src/vault/euph.rs b/src/vault/euph.rs index 9fd5fdc..ff32754 100644 --- a/src/vault/euph.rs +++ b/src/vault/euph.rs @@ -1,4 +1,3 @@ -use std::convert::Infallible; use std::mem; use std::str::FromStr; @@ -8,11 +7,15 @@ use euphoxide::api::{Message, MessageId, SessionId, SessionView, Snowflake, Time use rusqlite::types::{FromSql, FromSqlError, ToSqlOutput, Value, ValueRef}; use rusqlite::{named_params, params, Connection, OptionalExtension, ToSql, Transaction}; use time::OffsetDateTime; -use tokio::sync::oneshot; +use vault::Action; use crate::euph::SmallMessage; use crate::store::{MsgStore, Path, Tree}; +/////////////////// +// Wrapper types // +/////////////////// + /// Wrapper for [`Snowflake`] that implements useful rusqlite traits. struct WSnowflake(Snowflake); @@ -47,6 +50,10 @@ impl FromSql for WTime { } } +/////////////// +// EuphVault // +/////////////// + #[derive(Debug, Clone)] pub struct EuphVault { vault: super::Vault, @@ -69,216 +76,36 @@ impl EuphVault { } } -#[derive(Debug, Clone)] -pub struct EuphRoomVault { - vault: EuphVault, - room: String, -} - -impl EuphRoomVault { - pub fn vault(&self) -> &EuphVault { - &self.vault - } - - pub fn room(&self) -> &str { - &self.room - } -} - -#[async_trait] -impl MsgStore for EuphRoomVault { - type Error = Infallible; - - async fn path(&self, id: &MessageId) -> Result, Self::Error> { - Ok(self.path(*id).await) - } - - async fn msg(&self, id: &MessageId) -> Result, Self::Error> { - Ok(self.msg(*id).await) - } - - async fn tree(&self, root_id: &MessageId) -> Result, Self::Error> { - Ok(self.tree(*root_id).await) - } - - async fn first_root_id(&self) -> Result, Self::Error> { - Ok(self.first_root_id().await) - } - - async fn last_root_id(&self) -> Result, Self::Error> { - Ok(self.last_root_id().await) - } - - async fn prev_root_id(&self, root_id: &MessageId) -> Result, Self::Error> { - Ok(self.prev_root_id(*root_id).await) - } - - async fn next_root_id(&self, root_id: &MessageId) -> Result, Self::Error> { - Ok(self.next_root_id(*root_id).await) - } - - async fn oldest_msg_id(&self) -> Result, Self::Error> { - Ok(self.oldest_msg_id().await) - } - - async fn newest_msg_id(&self) -> Result, Self::Error> { - Ok(self.newest_msg_id().await) - } - - async fn older_msg_id(&self, id: &MessageId) -> Result, Self::Error> { - Ok(self.older_msg_id(*id).await) - } - - async fn newer_msg_id(&self, id: &MessageId) -> Result, Self::Error> { - Ok(self.newer_msg_id(*id).await) - } - - async fn oldest_unseen_msg_id(&self) -> Result, Self::Error> { - Ok(self.oldest_unseen_msg_id().await) - } - - async fn newest_unseen_msg_id(&self) -> Result, Self::Error> { - Ok(self.newest_unseen_msg_id().await) - } - - async fn older_unseen_msg_id(&self, id: &MessageId) -> Result, Self::Error> { - Ok(self.older_unseen_msg_id(*id).await) - } - - async fn newer_unseen_msg_id(&self, id: &MessageId) -> Result, Self::Error> { - Ok(self.newer_unseen_msg_id(*id).await) - } - - async fn unseen_msgs_count(&self) -> Result { - Ok(self.unseen_msgs_count().await) - } - - async fn set_seen(&self, id: &MessageId, seen: bool) -> Result<(), Self::Error> { - self.set_seen(*id, seen); - Ok(()) - } - - async fn set_older_seen(&self, id: &MessageId, seen: bool) -> Result<(), Self::Error> { - self.set_older_seen(*id, seen); - Ok(()) - } -} - -trait Request { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()>; -} - -macro_rules! requests_vault_fn { - ( $var:ident : $fn:ident( $( $arg:ident : $ty:ty ),* ) ) => { - pub fn $fn(&self $( , $arg: $ty )* ) { - let request = EuphRequest::$var($var { $( $arg, )* }); - let _ = self.vault.tx.send(super::Request::Euph(request)); - } - }; - ( $var:ident : $fn:ident( $( $arg:ident : $ty:ty ),* ) -> $res:ty ) => { - pub async fn $fn(&self $( , $arg: $ty )* ) -> $res { - let (tx, rx) = oneshot::channel(); - let request = EuphRequest::$var($var { - $( $arg, )* - result: tx, - }); - let _ = self.vault.tx.send(super::Request::Euph(request)); - rx.await.unwrap() - } - }; -} - -// This doesn't match the type of the `room` argument because that's apparently -// impossible to match to `String`. See also the readme of -// https://github.com/danielhenrymantilla/rust-defile for a description of this -// phenomenon and some examples. -macro_rules! requests_room_vault_fn { - ( $fn:ident ( room: $mustbestring:ty $( , $arg:ident : $ty:ty )* ) ) => { - pub fn $fn(&self $( , $arg: $ty )* ) { - self.vault.$fn(self.room.clone() $( , $arg )* ); - } - }; - ( $fn:ident ( room: $mustbestring:ty $( , $arg:ident : $ty:ty )* ) -> $res:ty ) => { - pub async fn $fn(&self $( , $arg: $ty )* ) -> $res { - self.vault.$fn(self.room.clone() $( , $arg )* ).await - } - }; - ( $( $tt:tt )* ) => { }; -} - -macro_rules! requests { +macro_rules! euph_vault_actions { ( $( - $var:ident : $fn:ident ( $( $arg:ident : $ty:ty ),* ) $( -> $res:ty )? ; + $struct:ident : $fn:ident ( $( $arg:ident : $arg_ty:ty ),* ) -> $res:ty ; )* ) => { $( - pub(super) struct $var { - $( $arg: $ty, )* - $( result: oneshot::Sender<$res>, )? + struct $struct { + $( $arg: $arg_ty, )* } )* - pub(super) enum EuphRequest { - $( $var($var), )* - } - - impl EuphRequest { - pub(super) fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - match self { - $( Self::$var(request) => request.perform(conn), )* - } - } - } - - #[allow(dead_code)] impl EuphVault { - $( requests_vault_fn!($var : $fn( $( $arg: $ty ),* ) $( -> $res )? ); )* - } - - #[allow(dead_code)] - impl EuphRoomVault { - $( requests_room_vault_fn!($fn( $( $arg: $ty ),* ) $( -> $res )? ); )* + $( + pub async fn $fn(&self, $( $arg: $arg_ty, )* ) -> vault::tokio::Result<$res> { + self.vault.tokio_vault.execute($struct { $( $arg, )* }).await + } + )* } }; } -requests! { - // Cookies +euph_vault_actions! { GetCookies : cookies() -> CookieJar; - SetCookies : set_cookies(cookies: CookieJar); - - // Rooms + SetCookies : set_cookies(cookies: CookieJar) -> (); GetRooms : rooms() -> Vec; - Join : join(room: String, time: Time); - Delete : delete(room: String); - - // Message - AddMsg : add_msg(room: String, msg: Box, prev_msg_id: Option, own_user_id: Option); - AddMsgs : add_msgs(room: String, msgs: Vec, next_msg_id: Option, own_user_id: Option); - GetLastSpan : last_span(room: String) -> Option<(Option, Option)>; - GetPath : path(room: String, id: MessageId) -> Path; - GetMsg : msg(room: String, id: MessageId) -> Option; - GetFullMsg : full_msg(room: String, id: MessageId) -> Option; - GetTree : tree(room: String, root_id: MessageId) -> Tree; - GetFirstRootId : first_root_id(room: String) -> Option; - GetLastRootId : last_root_id(room: String) -> Option; - GetPrevRootId : prev_root_id(room: String, root_id: MessageId) -> Option; - GetNextRootId : next_root_id(room: String, root_id: MessageId) -> Option; - GetOldestMsgId : oldest_msg_id(room: String) -> Option; - GetNewestMsgId : newest_msg_id(room: String) -> Option; - GetOlderMsgId : older_msg_id(room: String, id: MessageId) -> Option; - GetNewerMsgId : newer_msg_id(room: String, id: MessageId) -> Option; - GetOldestUnseenMsgId : oldest_unseen_msg_id(room: String) -> Option; - GetNewestUnseenMsgId : newest_unseen_msg_id(room: String) -> Option; - GetOlderUnseenMsgId : older_unseen_msg_id(room: String, id: MessageId) -> Option; - GetNewerUnseenMsgId : newer_unseen_msg_id(room: String, id: MessageId) -> Option; - GetUnseenMsgsCount : unseen_msgs_count(room: String) -> usize; - SetSeen : set_seen(room: String, id: MessageId, seen: bool); - SetOlderSeen : set_older_seen(room: String, id: MessageId, seen: bool); - GetChunkAtOffset : chunk_at_offset(room: String, amount: usize, offset: usize) -> Vec; } -impl Request for GetCookies { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetCookies { + type Result = CookieJar; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let cookies = conn .prepare( " @@ -296,14 +123,14 @@ impl Request for GetCookies { for cookie in cookies { cookie_jar.add_original(cookie); } - - let _ = self.result.send(cookie_jar); - Ok(()) + Ok(cookie_jar) } } -impl Request for SetCookies { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for SetCookies { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let tx = conn.transaction()?; // Since euphoria sets all cookies on every response, we can just delete @@ -326,24 +153,100 @@ impl Request for SetCookies { } } -impl Request for GetRooms { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let rooms = conn - .prepare( - " +impl Action for GetRooms { + type Result = Vec; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + conn.prepare( + " SELECT room FROM euph_rooms ", - )? - .query_map([], |row| row.get(0))? - .collect::>()?; - let _ = self.result.send(rooms); - Ok(()) + )? + .query_map([], |row| row.get(0))? + .collect::>() } } -impl Request for Join { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +/////////////////// +// EuphRoomVault // +/////////////////// + +#[derive(Debug, Clone)] +pub struct EuphRoomVault { + vault: EuphVault, + room: String, +} + +impl EuphRoomVault { + pub fn vault(&self) -> &EuphVault { + &self.vault + } + + pub fn room(&self) -> &str { + &self.room + } +} + +macro_rules! euph_room_vault_actions { + ( $( + $struct:ident : $fn:ident ( $( $arg:ident : $arg_ty:ty ),* ) -> $res:ty ; + )* ) => { + $( + struct $struct { + room: String, + $( $arg: $arg_ty, )* + } + )* + + impl EuphRoomVault { + $( + pub async fn $fn(&self, $( $arg: $arg_ty, )* ) -> vault::tokio::Result<$res> { + self.vault.vault.tokio_vault.execute($struct { + room: self.room.clone(), + $( $arg, )* + }).await + } + )* + } + }; +} + +euph_room_vault_actions! { + // Room + Join : join(time: Time) -> (); + Delete : delete() -> (); + + // Message + AddMsg : add_msg(msg: Box, prev_msg_id: Option, own_user_id: Option) -> (); + AddMsgs : add_msgs(msgs: Vec, next_msg_id: Option, own_user_id: Option) -> (); + GetLastSpan : last_span() -> Option<(Option, Option)>; + GetPath : path(id: MessageId) -> Path; + GetMsg : msg(id: MessageId) -> Option; + GetFullMsg : full_msg(id: MessageId) -> Option; + GetTree : tree(root_id: MessageId) -> Tree; + GetFirstRootId : first_root_id() -> Option; + GetLastRootId : last_root_id() -> Option; + GetPrevRootId : prev_root_id(root_id: MessageId) -> Option; + GetNextRootId : next_root_id(root_id: MessageId) -> Option; + GetOldestMsgId : oldest_msg_id() -> Option; + GetNewestMsgId : newest_msg_id() -> Option; + GetOlderMsgId : older_msg_id(id: MessageId) -> Option; + GetNewerMsgId : newer_msg_id(id: MessageId) -> Option; + GetOldestUnseenMsgId : oldest_unseen_msg_id() -> Option; + GetNewestUnseenMsgId : newest_unseen_msg_id() -> Option; + GetOlderUnseenMsgId : older_unseen_msg_id(id: MessageId) -> Option; + GetNewerUnseenMsgId : newer_unseen_msg_id(id: MessageId) -> Option; + GetUnseenMsgsCount : unseen_msgs_count() -> usize; + SetSeen : set_seen(id: MessageId, seen: bool) -> (); + SetOlderSeen : set_older_seen(id: MessageId, seen: bool) -> (); + GetChunkAtOffset : chunk_at_offset(amount: usize, offset: usize) -> Vec; +} + +impl Action for Join { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { conn.execute( " INSERT INTO euph_rooms (room, first_joined, last_joined) @@ -357,8 +260,10 @@ impl Request for Join { } } -impl Request for Delete { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for Delete { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { conn.execute( " DELETE FROM euph_rooms @@ -525,8 +430,10 @@ fn add_span( Ok(()) } -impl Request for AddMsg { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for AddMsg { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let tx = conn.transaction()?; let end = self.msg.id; @@ -538,8 +445,10 @@ impl Request for AddMsg { } } -impl Request for AddMsgs { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for AddMsgs { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let tx = conn.transaction()?; if self.msgs.is_empty() { @@ -559,8 +468,10 @@ impl Request for AddMsgs { } } -impl Request for GetLastSpan { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetLastSpan { + type Result = Option<(Option, Option)>; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let span = conn .prepare( " @@ -578,13 +489,14 @@ impl Request for GetLastSpan { )) }) .optional()?; - let _ = self.result.send(span); - Ok(()) + Ok(span) } } -impl Request for GetPath { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetPath { + type Result = Path; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let path = conn .prepare( " @@ -606,14 +518,14 @@ impl Request for GetPath { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) })? .collect::>()?; - let path = Path::new(path); - let _ = self.result.send(path); - Ok(()) + Ok(Path::new(path)) } } -impl Request for GetMsg { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetMsg { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let msg = conn .query_row( " @@ -635,13 +547,14 @@ impl Request for GetMsg { }, ) .optional()?; - let _ = self.result.send(msg); - Ok(()) + Ok(msg) } } -impl Request for GetFullMsg { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetFullMsg { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let mut query = conn.prepare( " SELECT @@ -679,13 +592,14 @@ impl Request for GetFullMsg { }) }) .optional()?; - let _ = self.result.send(msg); - Ok(()) + Ok(msg) } } -impl Request for GetTree { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetTree { + type Result = Tree; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let msgs = conn .prepare( " @@ -716,15 +630,15 @@ impl Request for GetTree { }) })? .collect::>()?; - let tree = Tree::new(self.root_id, msgs); - let _ = self.result.send(tree); - Ok(()) + Ok(Tree::new(self.root_id, msgs)) } } -impl Request for GetFirstRootId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetFirstRootId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let root_id = conn .prepare( " SELECT id @@ -738,14 +652,15 @@ impl Request for GetFirstRootId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(root_id) } } -impl Request for GetLastRootId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetLastRootId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let root_id = conn .prepare( " SELECT id @@ -759,14 +674,15 @@ impl Request for GetLastRootId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(root_id) } } -impl Request for GetPrevRootId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetPrevRootId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let root_id = conn .prepare( " SELECT id @@ -781,14 +697,15 @@ impl Request for GetPrevRootId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(root_id) } } -impl Request for GetNextRootId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetNextRootId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let root_id = conn .prepare( " SELECT id @@ -803,14 +720,15 @@ impl Request for GetNextRootId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(root_id) } } -impl Request for GetOldestMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetOldestMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -824,14 +742,15 @@ impl Request for GetOldestMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetNewestMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetNewestMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -845,14 +764,15 @@ impl Request for GetNewestMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetOlderMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetOlderMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -867,13 +787,14 @@ impl Request for GetOlderMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetNewerMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetNewerMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -888,14 +809,15 @@ impl Request for GetNewerMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetOldestUnseenMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetOldestUnseenMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -910,14 +832,15 @@ impl Request for GetOldestUnseenMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetNewestUnseenMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetNewestUnseenMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -932,14 +855,15 @@ impl Request for GetNewestUnseenMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetOlderUnseenMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetOlderUnseenMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -955,14 +879,15 @@ impl Request for GetOlderUnseenMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetNewerUnseenMsgId { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { - let tree = conn +impl Action for GetNewerUnseenMsgId { + type Result = Option; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { + let msg_id = conn .prepare( " SELECT id @@ -978,13 +903,14 @@ impl Request for GetNewerUnseenMsgId { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; - let _ = self.result.send(tree); - Ok(()) + Ok(msg_id) } } -impl Request for GetUnseenMsgsCount { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetUnseenMsgsCount { + type Result = usize; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let amount = conn .prepare( " @@ -996,13 +922,14 @@ impl Request for GetUnseenMsgsCount { .query_row(params![self.room], |row| row.get(0)) .optional()? .unwrap_or(0); - let _ = self.result.send(amount); - Ok(()) + Ok(amount) } } -impl Request for SetSeen { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for SetSeen { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { conn.execute( " UPDATE euph_msgs @@ -1016,8 +943,10 @@ impl Request for SetSeen { } } -impl Request for SetOlderSeen { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for SetOlderSeen { + type Result = (); + + fn run(self, conn: &mut Connection) -> rusqlite::Result { conn.execute( " UPDATE euph_msgs @@ -1032,8 +961,10 @@ impl Request for SetOlderSeen { } } -impl Request for GetChunkAtOffset { - fn perform(self, conn: &mut Connection) -> rusqlite::Result<()> { +impl Action for GetChunkAtOffset { + type Result = Vec; + + fn run(self, conn: &mut Connection) -> rusqlite::Result { let mut query = conn.prepare( " SELECT @@ -1073,7 +1004,83 @@ impl Request for GetChunkAtOffset { }) })? .collect::>()?; - let _ = self.result.send(messages); - Ok(()) + Ok(messages) + } +} + +#[async_trait] +impl MsgStore for EuphRoomVault { + type Error = vault::tokio::Error; + + async fn path(&self, id: &MessageId) -> Result, Self::Error> { + self.path(*id).await + } + + async fn msg(&self, id: &MessageId) -> Result, Self::Error> { + self.msg(*id).await + } + + async fn tree(&self, root_id: &MessageId) -> Result, Self::Error> { + self.tree(*root_id).await + } + + async fn first_root_id(&self) -> Result, Self::Error> { + self.first_root_id().await + } + + async fn last_root_id(&self) -> Result, Self::Error> { + self.last_root_id().await + } + + async fn prev_root_id(&self, root_id: &MessageId) -> Result, Self::Error> { + self.prev_root_id(*root_id).await + } + + async fn next_root_id(&self, root_id: &MessageId) -> Result, Self::Error> { + self.next_root_id(*root_id).await + } + + async fn oldest_msg_id(&self) -> Result, Self::Error> { + self.oldest_msg_id().await + } + + async fn newest_msg_id(&self) -> Result, Self::Error> { + self.newest_msg_id().await + } + + async fn older_msg_id(&self, id: &MessageId) -> Result, Self::Error> { + self.older_msg_id(*id).await + } + + async fn newer_msg_id(&self, id: &MessageId) -> Result, Self::Error> { + self.newer_msg_id(*id).await + } + + async fn oldest_unseen_msg_id(&self) -> Result, Self::Error> { + self.oldest_unseen_msg_id().await + } + + async fn newest_unseen_msg_id(&self) -> Result, Self::Error> { + self.newest_unseen_msg_id().await + } + + async fn older_unseen_msg_id(&self, id: &MessageId) -> Result, Self::Error> { + self.older_unseen_msg_id(*id).await + } + + async fn newer_unseen_msg_id(&self, id: &MessageId) -> Result, Self::Error> { + self.newer_unseen_msg_id(*id).await + } + + async fn unseen_msgs_count(&self) -> Result { + self.unseen_msgs_count().await + } + + async fn set_seen(&self, id: &MessageId, seen: bool) -> Result<(), Self::Error> { + self.set_seen(*id, seen).await + } + + async fn set_older_seen(&self, id: &MessageId, seen: bool) -> Result<(), Self::Error> { + self.set_older_seen(*id, seen).await } } diff --git a/src/vault/migrate.rs b/src/vault/migrate.rs index 8cd42c6..e5d16da 100644 --- a/src/vault/migrate.rs +++ b/src/vault/migrate.rs @@ -1,25 +1,10 @@ -use rusqlite::{Connection, Transaction}; +use rusqlite::Transaction; +use vault::Migration; -pub fn migrate(conn: &mut Connection) -> rusqlite::Result<()> { - let mut tx = conn.transaction()?; +pub const MIGRATIONS: [Migration; 2] = [m1, m2]; - let user_version: usize = - tx.query_row("SELECT * FROM pragma_user_version", [], |r| r.get(0))?; - - let total = MIGRATIONS.len(); - assert!(user_version <= total, "malformed database schema"); - for (i, migration) in MIGRATIONS.iter().enumerate().skip(user_version) { - eprintln!("Migrating vault from {} to {} (out of {})", i, i + 1, total); - migration(&mut tx)?; - } - - tx.pragma_update(None, "user_version", total)?; - tx.commit() -} - -const MIGRATIONS: [fn(&mut Transaction<'_>) -> rusqlite::Result<()>; 2] = [m1, m2]; - -fn m1(tx: &mut Transaction<'_>) -> rusqlite::Result<()> { +fn m1(tx: &mut Transaction<'_>, nr: usize, total: usize) -> rusqlite::Result<()> { + eprintln!("Migrating vault from {} to {} (out of {total})", nr, nr + 1); tx.execute_batch( " CREATE TABLE euph_rooms ( @@ -81,7 +66,8 @@ fn m1(tx: &mut Transaction<'_>) -> rusqlite::Result<()> { ) } -fn m2(tx: &mut Transaction<'_>) -> rusqlite::Result<()> { +fn m2(tx: &mut Transaction<'_>, nr: usize, total: usize) -> rusqlite::Result<()> { + eprintln!("Migrating vault from {} to {} (out of {total})", nr, nr + 1); tx.execute_batch( " ALTER TABLE euph_msgs