From da1d23646a8774dd812b8627b6bd07d636bb550d Mon Sep 17 00:00:00 2001 From: Joscha Date: Sun, 31 Dec 2023 20:16:16 +0100 Subject: [PATCH] Migrate euph vault to respect domain --- cove/src/vault.rs | 2 +- cove/src/vault/euph.rs | 358 +++++++++++++++++++++++++---------------- 2 files changed, 221 insertions(+), 139 deletions(-) diff --git a/cove/src/vault.rs b/cove/src/vault.rs index 7a7e4ba..6861901 100644 --- a/cove/src/vault.rs +++ b/cove/src/vault.rs @@ -9,7 +9,7 @@ use rusqlite::Connection; use vault::tokio::TokioVault; use vault::Action; -pub use self::euph::{EuphRoomVault, EuphVault}; +pub use self::euph::{EuphRoomVault, EuphVault, RoomIdentifier}; #[derive(Debug, Clone)] pub struct Vault { diff --git a/cove/src/vault/euph.rs b/cove/src/vault/euph.rs index e9f363e..6f23261 100644 --- a/cove/src/vault/euph.rs +++ b/cove/src/vault/euph.rs @@ -1,5 +1,5 @@ -use std::mem; use std::str::FromStr; +use std::{fmt, mem}; use async_trait::async_trait; use cookie::{Cookie, CookieJar}; @@ -12,10 +12,6 @@ use vault::Action; use crate::euph::SmallMessage; use crate::store::{MsgStore, Path, Tree}; -/////////////////// -// Wrapper types // -/////////////////// - /// Wrapper for [`Snowflake`] that implements useful rusqlite traits. struct WSnowflake(Snowflake); @@ -50,6 +46,24 @@ impl FromSql for WTime { } } +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct RoomIdentifier { + pub domain: String, + pub name: String, +} + +impl fmt::Display for RoomIdentifier { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "&{}@{}", self.name, self.domain) + } +} + +impl RoomIdentifier { + pub fn new(domain: String, name: String) -> Self { + Self { domain, name } + } +} + /////////////// // EuphVault // /////////////// @@ -68,10 +82,10 @@ impl EuphVault { &self.vault } - pub fn room(&self, name: String) -> EuphRoomVault { + pub fn room(&self, room: RoomIdentifier) -> EuphRoomVault { EuphRoomVault { vault: self.clone(), - room: name, + room, } } } @@ -97,9 +111,9 @@ macro_rules! euph_vault_actions { } euph_vault_actions! { - GetCookies : cookies() -> CookieJar; - SetCookies : set_cookies(cookies: CookieJar) -> (); - GetRooms : rooms() -> Vec; + GetCookies : cookies(domain: String) -> CookieJar; + SetCookies : set_cookies(domain: String, cookies: CookieJar) -> (); + GetRooms : rooms() -> Vec; } impl Action for GetCookies { @@ -112,9 +126,10 @@ impl Action for GetCookies { " SELECT cookie FROM euph_cookies + WHERE domain = ? ", )? - .query_map([], |row| { + .query_map([self.domain], |row| { let cookie_str: String = row.get(0)?; Ok(Cookie::from_str(&cookie_str).expect("cookie in db is valid")) })? @@ -137,16 +152,21 @@ impl Action for SetCookies { // Since euphoria sets all cookies on every response, we can just delete // all previous cookies. - tx.execute_batch("DELETE FROM euph_cookies")?; + tx.execute( + " + DELETE FROM euph_cookies + WHERE domain = ?", + [&self.domain], + )?; let mut insert_cookie = tx.prepare( " - INSERT INTO euph_cookies (cookie) - VALUES (?) + INSERT INTO euph_cookies (domain, cookie) + VALUES (?, ?) ", )?; for cookie in self.cookies.iter() { - insert_cookie.execute([format!("{cookie}")])?; + insert_cookie.execute([self.domain, format!("{cookie}")])?; } drop(insert_cookie); @@ -156,17 +176,22 @@ impl Action for SetCookies { } impl Action for GetRooms { - type Output = Vec; + type Output = Vec; type Error = rusqlite::Error; fn run(self, conn: &mut Connection) -> Result { conn.prepare( " - SELECT room + SELECT room, domain FROM euph_rooms ", )? - .query_map([], |row| row.get(0))? + .query_map([], |row| { + Ok(RoomIdentifier { + domain: row.get(0)?, + name: row.get(1)?, + }) + })? .collect::>() } } @@ -178,7 +203,7 @@ impl Action for GetRooms { #[derive(Debug, Clone)] pub struct EuphRoomVault { vault: EuphVault, - room: String, + room: RoomIdentifier, } impl EuphRoomVault { @@ -186,7 +211,7 @@ impl EuphRoomVault { &self.vault } - pub fn room(&self) -> &str { + pub fn room(&self) -> &RoomIdentifier { &self.room } } @@ -197,7 +222,7 @@ macro_rules! euph_room_vault_actions { )* ) => { $( struct $struct { - room: String, + room: RoomIdentifier, $( $arg: $arg_ty, )* } )* @@ -253,12 +278,16 @@ impl Action for Join { fn run(self, conn: &mut Connection) -> Result { conn.execute( " - INSERT INTO euph_rooms (room, first_joined, last_joined) - VALUES (:room, :time, :time) - ON CONFLICT (room) DO UPDATE + INSERT INTO euph_rooms (domain, room, first_joined, last_joined) + VALUES (:domain, :room, :time, :time) + ON CONFLICT (domain, room) DO UPDATE SET last_joined = :time ", - named_params! {":room": self.room, ":time": WTime(self.time)}, + named_params! { + ":domain": self.room.domain, + ":room": self.room.name, + ":time": WTime(self.time), + }, )?; Ok(()) } @@ -272,9 +301,10 @@ impl Action for Delete { conn.execute( " DELETE FROM euph_rooms - WHERE room = ? + WHERE domain = ? + AND room = ? ", - [&self.room], + [&self.room.domain, &self.room.name], )?; Ok(()) } @@ -282,29 +312,33 @@ impl Action for Delete { fn insert_msgs( tx: &Transaction<'_>, - room: &str, + room: &RoomIdentifier, own_user_id: &Option, msgs: Vec, ) -> rusqlite::Result<()> { let mut insert_msg = tx.prepare( " INSERT INTO euph_msgs ( - room, id, parent, previous_edit_id, time, content, encryption_key_id, edited, deleted, truncated, + domain, room, + id, parent, previous_edit_id, time, content, encryption_key_id, edited, deleted, truncated, user_id, name, server_id, server_era, session_id, is_staff, is_manager, client_address, real_client_address, seen ) VALUES ( - :room, :id, :parent, :previous_edit_id, :time, :content, :encryption_key_id, :edited, :deleted, :truncated, + :domain, :room, + :id, :parent, :previous_edit_id, :time, :content, :encryption_key_id, :edited, :deleted, :truncated, :user_id, :name, :server_id, :server_era, :session_id, :is_staff, :is_manager, :client_address, :real_client_address, (:user_id == :own_user_id OR EXISTS( SELECT 1 FROM euph_rooms - WHERE room = :room + WHERE domain = :domain + AND room = :room AND :time < first_joined )) ) - ON CONFLICT (room, id) DO UPDATE + ON CONFLICT (domain, room, id) DO UPDATE SET + domain = :domain, room = :room, id = :id, parent = :parent, @@ -331,7 +365,8 @@ fn insert_msgs( let own_user_id = own_user_id.as_ref().map(|u| &u.0); for msg in msgs { insert_msg.execute(named_params! { - ":room": room, + ":domain": room.domain, + ":room": room.name, ":id": WSnowflake(msg.id.0), ":parent": msg.parent.map(|id| WSnowflake(id.0)), ":previous_edit_id": msg.previous_edit_id.map(WSnowflake), @@ -359,7 +394,7 @@ fn insert_msgs( fn add_span( tx: &Transaction<'_>, - room: &str, + room: &RoomIdentifier, start: Option, end: Option, ) -> rusqlite::Result<()> { @@ -369,10 +404,11 @@ fn add_span( " SELECT start, end FROM euph_spans - WHERE room = ? + WHERE domain = ? + AND room = ? ", )? - .query_map([room], |row| { + .query_map([&room.domain, &room.name], |row| { let start = row.get::<_, Option>(0)?.map(|s| MessageId(s.0)); let end = row.get::<_, Option>(1)?.map(|s| MessageId(s.0)); Ok((start, end)) @@ -412,21 +448,23 @@ fn add_span( tx.execute( " DELETE FROM euph_spans - WHERE room = ? + WHERE domain = ? + AND room = ? ", - [room], + [&room.domain, &room.name], )?; // Re-insert combined spans for the room let mut stmt = tx.prepare( " - INSERT INTO euph_spans (room, start, end) - VALUES (?, ?, ?) + INSERT INTO euph_spans (domain, room, start, end) + VALUES (?, ?, ?, ?) ", )?; for (start, end) in result { stmt.execute(params![ - room, + room.domain, + room.name, start.map(|id| WSnowflake(id.0)), end.map(|id| WSnowflake(id.0)) ])?; @@ -485,12 +523,13 @@ impl Action for GetLastSpan { " SELECT start, end FROM euph_spans - WHERE room = ? + WHERE domain = ? + AND room = ? ORDER BY start DESC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { Ok(( row.get::<_, Option>(0)?.map(|s| MessageId(s.0)), row.get::<_, Option>(1)?.map(|s| MessageId(s.0)), @@ -510,12 +549,12 @@ impl Action for GetPath { .prepare( " WITH RECURSIVE - path (room, id) AS ( - VALUES (?, ?) + path (domain, room, id) AS ( + VALUES (?, ?, ?) UNION - SELECT room, parent + SELECT domain, room, parent FROM euph_msgs - JOIN path USING (room, id) + JOIN path USING (domain, room, id) ) SELECT id FROM path @@ -523,9 +562,10 @@ impl Action for GetPath { ORDER BY id ASC ", )? - .query_map(params![self.room, WSnowflake(self.id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - })? + .query_map( + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + )? .collect::>()?; Ok(Path::new(path)) } @@ -541,10 +581,11 @@ impl Action for GetMsg { " SELECT id, parent, time, name, content, seen FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND id = ? ", - params![self.room, WSnowflake(self.id.0)], + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], |row| { Ok(SmallMessage { id: MessageId(row.get::<_, WSnowflake>(0)?.0), @@ -572,36 +613,40 @@ impl Action for GetFullMsg { id, parent, previous_edit_id, time, content, encryption_key_id, edited, deleted, truncated, user_id, name, server_id, server_era, session_id, is_staff, is_manager, client_address, real_client_address FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND id = ? " )?; let msg = query - .query_row(params![self.room, WSnowflake(self.id.0)], |row| { - Ok(Message { - id: MessageId(row.get::<_, WSnowflake>(0)?.0), - parent: row.get::<_, Option>(1)?.map(|s| MessageId(s.0)), - previous_edit_id: row.get::<_, Option>(2)?.map(|s| s.0), - time: row.get::<_, WTime>(3)?.0, - content: row.get(4)?, - encryption_key_id: row.get(5)?, - edited: row.get::<_, Option>(6)?.map(|t| t.0), - deleted: row.get::<_, Option>(7)?.map(|t| t.0), - truncated: row.get(8)?, - sender: SessionView { - id: UserId(row.get(9)?), - name: row.get(10)?, - server_id: row.get(11)?, - server_era: row.get(12)?, - session_id: SessionId(row.get(13)?), - is_staff: row.get(14)?, - is_manager: row.get(15)?, - client_address: row.get(16)?, - real_client_address: row.get(17)?, - }, - }) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], + |row| { + Ok(Message { + id: MessageId(row.get::<_, WSnowflake>(0)?.0), + parent: row.get::<_, Option>(1)?.map(|s| MessageId(s.0)), + previous_edit_id: row.get::<_, Option>(2)?.map(|s| s.0), + time: row.get::<_, WTime>(3)?.0, + content: row.get(4)?, + encryption_key_id: row.get(5)?, + edited: row.get::<_, Option>(6)?.map(|t| t.0), + deleted: row.get::<_, Option>(7)?.map(|t| t.0), + truncated: row.get(8)?, + sender: SessionView { + id: UserId(row.get(9)?), + name: row.get(10)?, + server_id: row.get(11)?, + server_era: row.get(12)?, + session_id: SessionId(row.get(13)?), + is_staff: row.get(14)?, + is_manager: row.get(15)?, + client_address: row.get(16)?, + real_client_address: row.get(17)?, + }, + }) + }, + ) .optional()?; Ok(msg) } @@ -616,31 +661,35 @@ impl Action for GetTree { .prepare( " WITH RECURSIVE - tree (room, id) AS ( - VALUES (?, ?) + tree (domain, room, id) AS ( + VALUES (?, ?, ?) UNION - SELECT euph_msgs.room, euph_msgs.id + SELECT euph_msgs.domain, euph_msgs.room, euph_msgs.id FROM euph_msgs JOIN tree - ON tree.room = euph_msgs.room + ON tree.domain = euph_msgs.domain + AND tree.room = euph_msgs.room AND tree.id = euph_msgs.parent ) SELECT id, parent, time, name, content, seen FROM euph_msgs - JOIN tree USING (room, id) + JOIN tree USING (domain, room, id) ORDER BY id ASC ", )? - .query_map(params![self.room, WSnowflake(self.root_id.0)], |row| { - Ok(SmallMessage { - id: MessageId(row.get::<_, WSnowflake>(0)?.0), - parent: row.get::<_, Option>(1)?.map(|s| MessageId(s.0)), - time: row.get::<_, WTime>(2)?.0, - nick: row.get(3)?, - content: row.get(4)?, - seen: row.get(5)?, - }) - })? + .query_map( + params![self.room.domain, self.room.name, WSnowflake(self.root_id.0)], + |row| { + Ok(SmallMessage { + id: MessageId(row.get::<_, WSnowflake>(0)?.0), + parent: row.get::<_, Option>(1)?.map(|s| MessageId(s.0)), + time: row.get::<_, WTime>(2)?.0, + nick: row.get(3)?, + content: row.get(4)?, + seen: row.get(5)?, + }) + }, + )? .collect::>()?; Ok(Tree::new(self.root_id, msgs)) } @@ -656,12 +705,13 @@ impl Action for GetFirstRootId { " SELECT id FROM euph_trees - WHERE room = ? + WHERE domain = ? + AND room = ? ORDER BY id ASC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; @@ -679,12 +729,13 @@ impl Action for GetLastRootId { " SELECT id FROM euph_trees - WHERE room = ? + WHERE domain = ? + AND room = ? ORDER BY id DESC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; @@ -702,15 +753,17 @@ impl Action for GetPrevRootId { " SELECT id FROM euph_trees - WHERE room = ? + WHERE domain = ? + AND room = ? AND id < ? ORDER BY id DESC LIMIT 1 ", )? - .query_row(params![self.room, WSnowflake(self.root_id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.root_id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + ) .optional()?; Ok(root_id) } @@ -726,15 +779,17 @@ impl Action for GetNextRootId { " SELECT id FROM euph_trees - WHERE room = ? + WHERE domain = ? + AND room = ? AND id > ? ORDER BY id ASC LIMIT 1 ", )? - .query_row(params![self.room, WSnowflake(self.root_id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.root_id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + ) .optional()?; Ok(root_id) } @@ -750,12 +805,13 @@ impl Action for GetOldestMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? ORDER BY id ASC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; @@ -773,12 +829,13 @@ impl Action for GetNewestMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? ORDER BY id DESC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; @@ -796,15 +853,17 @@ impl Action for GetOlderMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND id < ? ORDER BY id DESC LIMIT 1 ", )? - .query_row(params![self.room, WSnowflake(self.id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + ) .optional()?; Ok(msg_id) } @@ -819,15 +878,17 @@ impl Action for GetNewerMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND id > ? ORDER BY id ASC LIMIT 1 ", )? - .query_row(params![self.room, WSnowflake(self.id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + ) .optional()?; Ok(msg_id) } @@ -843,13 +904,14 @@ impl Action for GetOldestUnseenMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND NOT seen ORDER BY id ASC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; @@ -867,13 +929,14 @@ impl Action for GetNewestUnseenMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND NOT seen ORDER BY id DESC LIMIT 1 ", )? - .query_row([self.room], |row| { + .query_row([&self.room.domain, &self.room.name], |row| { row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) }) .optional()?; @@ -891,16 +954,18 @@ impl Action for GetOlderUnseenMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND NOT seen AND id < ? ORDER BY id DESC LIMIT 1 ", )? - .query_row(params![self.room, WSnowflake(self.id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + ) .optional()?; Ok(msg_id) } @@ -916,16 +981,18 @@ impl Action for GetNewerUnseenMsgId { " SELECT id FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND NOT seen AND id > ? ORDER BY id ASC LIMIT 1 ", )? - .query_row(params![self.room, WSnowflake(self.id.0)], |row| { - row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)) - }) + .query_row( + params![self.room.domain, self.room.name, WSnowflake(self.id.0)], + |row| row.get::<_, WSnowflake>(0).map(|s| MessageId(s.0)), + ) .optional()?; Ok(msg_id) } @@ -941,10 +1008,11 @@ impl Action for GetUnseenMsgsCount { " SELECT amount FROM euph_unseen_counts - WHERE room = ? + WHERE domain = ? + AND room = ? ", )? - .query_row(params![self.room], |row| row.get(0)) + .query_row(params![self.room.domain, self.room.name], |row| row.get(0)) .optional()? .unwrap_or(0); Ok(amount) @@ -960,10 +1028,16 @@ impl Action for SetSeen { " UPDATE euph_msgs SET seen = :seen - WHERE room = :room + WHERE domain = :domain + AND room = :room AND id = :id ", - named_params! { ":room": self.room, ":id": WSnowflake(self.id.0), ":seen": self.seen }, + named_params! { + ":domain": self.room.domain, + ":room": self.room.name, + ":id": WSnowflake(self.id.0), + ":seen": self.seen, + }, )?; Ok(()) } @@ -978,11 +1052,17 @@ impl Action for SetOlderSeen { " UPDATE euph_msgs SET seen = :seen - WHERE room = :room + WHERE domain = :domain + AND room = :room AND id <= :id AND seen != :seen ", - named_params! { ":room": self.room, ":id": WSnowflake(self.id.0), ":seen": self.seen }, + named_params! { + ":domain": self.room.domain, + ":room": self.room.name, + ":id": WSnowflake(self.id.0), + ":seen": self.seen, + }, )?; Ok(()) } @@ -1024,12 +1104,13 @@ impl Action for GetChunkAfter { id, parent, previous_edit_id, time, content, encryption_key_id, edited, deleted, truncated, user_id, name, server_id, server_era, session_id, is_staff, is_manager, client_address, real_client_address FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? AND id > ? ORDER BY id ASC LIMIT ? ")? - .query_map(params![self.room, WSnowflake(id.0), self.amount], row2msg)? + .query_map(params![self.room.domain, self.room.name, WSnowflake(id.0), self.amount], row2msg)? .collect::>()? } else { conn.prepare(" @@ -1037,11 +1118,12 @@ impl Action for GetChunkAfter { id, parent, previous_edit_id, time, content, encryption_key_id, edited, deleted, truncated, user_id, name, server_id, server_era, session_id, is_staff, is_manager, client_address, real_client_address FROM euph_msgs - WHERE room = ? + WHERE domain = ? + AND room = ? ORDER BY id ASC LIMIT ? ")? - .query_map(params![self.room, self.amount], row2msg)? + .query_map(params![self.room.domain, self.room.name, self.amount], row2msg)? .collect::>()? };