From 4dcc021f7306f55a1c0a1829247596899bbdca6f Mon Sep 17 00:00:00 2001 From: Joscha Date: Tue, 24 Jan 2023 14:12:50 +0100 Subject: [PATCH] Add bot::command::clap --- Cargo.toml | 10 ++- src/bot/command.rs | 19 +++-- src/bot/command/clap.rs | 166 ++++++++++++++++++++++++++++++++++++++++ src/bot/commands.rs | 1 + 4 files changed, 189 insertions(+), 7 deletions(-) create mode 100644 src/bot/command/clap.rs diff --git a/Cargo.toml b/Cargo.toml index 8aaca88..ec9de74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,10 +4,10 @@ version = "0.2.0" edition = "2021" [features] -bot = ["dep:cookie"] +bot = ["dep:async-trait", "dep:clap", "dep:cookie"] [dependencies] -async-trait = "0.1.63" +async-trait = { version = "0.1.63", optional = true } cookie = { version = "0.16.2", optional = true } futures-util = { version = "0.3.25", default-features = false, features = ["sink"] } log = "0.4.17" @@ -18,6 +18,12 @@ tokio = { version = "1.23.0", features = ["time", "sync", "macros", "rt"] } tokio-stream = "0.1.11" tokio-tungstenite = { version = "0.18.0", features = ["rustls-tls-native-roots"] } +[dependencies.clap] +version = "4.1.3" +optional = true +default-features = false +features = ["std"] + [dev-dependencies] # For example bot tokio = { version = "1.23.0", features = ["rt-multi-thread"] } diff --git a/src/bot/command.rs b/src/bot/command.rs index 5b9f749..749cdac 100644 --- a/src/bot/command.rs +++ b/src/bot/command.rs @@ -1,8 +1,14 @@ +mod clap; + +use std::future::Future; + use async_trait::async_trait; use crate::api::{self, Message, MessageId}; use crate::conn::{self, ConnTx, Joined}; +pub use self::clap::{Clap, ClapCommand}; + use super::instance::InstanceConfig; #[derive(Clone, Copy, PartialEq, Eq)] @@ -38,6 +44,7 @@ impl Kind { } pub struct Context { + pub name: String, pub kind: Kind, pub config: InstanceConfig, pub conn_tx: ConnTx, @@ -45,24 +52,26 @@ pub struct Context { } impl Context { - pub async fn send(&self, content: S) -> Result { + pub fn send(&self, content: S) -> impl Future> { let cmd = api::Send { content: content.to_string(), parent: None, }; - Ok(self.conn_tx.send(cmd).await?.0) + let reply = self.conn_tx.send(cmd); + async move { reply.await.map(|r| r.0) } } - pub async fn reply( + pub fn reply( &self, parent: MessageId, content: S, - ) -> Result { + ) -> impl Future> { let cmd = api::Send { content: content.to_string(), parent: Some(parent), }; - Ok(self.conn_tx.send(cmd).await?.0) + let reply = self.conn_tx.send(cmd); + async move { reply.await.map(|r| r.0) } } } diff --git a/src/bot/command/clap.rs b/src/bot/command/clap.rs new file mode 100644 index 0000000..975f2c6 --- /dev/null +++ b/src/bot/command/clap.rs @@ -0,0 +1,166 @@ +use async_trait::async_trait; +use clap::{CommandFactory, Parser}; + +use crate::api::Message; + +use super::{Command, Context}; + +#[async_trait] +pub trait ClapCommand { + type Args; + + async fn execute(&self, args: Self::Args, msg: &Message, ctx: &Context, bot: &mut B); +} + +/// Parse bash-like quoted arguments separated by whitespace. +/// +/// Outside of quotes, the backslash either escapes the next character or forms +/// an escape sequence. \n is a newline, \r a carriage return and \t a tab. +/// TODO Escape sequences +/// +/// Special characters like the backslash and whitespace can also be quoted +/// using double quotes. Within double quotes, \" escapes a double quote and \\ +/// escapes a backslash. Other occurrences of \ have no special meaning. +fn parse_quoted_args(text: &str) -> Result, &'static str> { + let mut args = vec![]; + let mut arg = String::new(); + let mut arg_exists = false; + + let mut quoted = false; + let mut escaped = false; + for c in text.chars() { + if quoted { + match c { + '\\' if escaped => { + arg.push('\\'); + escaped = false; + } + '"' if escaped => { + arg.push('"'); + escaped = false; + } + c if escaped => { + arg.push('\\'); + arg.push(c); + escaped = false; + } + '\\' => escaped = true, + '"' => quoted = false, + c => arg.push(c), + } + } else { + match c { + c if escaped => { + arg.push(c); + arg_exists = true; + escaped = false; + } + c if c.is_whitespace() => { + if arg_exists { + args.push(arg); + arg = String::new(); + arg_exists = false; + } + } + '\\' => escaped = true, + '"' => { + quoted = true; + arg_exists = true; + } + c => { + arg.push(c); + arg_exists = true; + } + } + } + } + + if quoted { + return Err("Unclosed trailing quote"); + } + if escaped { + return Err("Unfinished trailing escape"); + } + + if arg_exists { + args.push(arg); + } + + Ok(args) +} + +pub struct Clap(pub C); + +#[async_trait] +impl Command for Clap +where + B: Send, + C: ClapCommand + Send + Sync, + C::Args: Parser + Send, +{ + fn description(&self) -> Option { + C::Args::command().get_about().map(|s| format!("{s}")) + } + + async fn execute(&self, arg: &str, msg: &Message, ctx: &Context, bot: &mut B) { + let mut args = match parse_quoted_args(arg) { + Ok(args) => args, + Err(err) => { + let _ = ctx.reply(msg.id, err); + return; + } + }; + + args.insert(0, ctx.kind.usage(&ctx.name, &ctx.joined.session.name)); + + let args = match C::Args::try_parse_from(args) { + Ok(args) => args, + Err(err) => { + let _ = ctx.reply(msg.id, format!("{}", err.render())); + return; + } + }; + + self.0.execute(args, msg, ctx, bot).await + } +} + +#[cfg(test)] +mod test { + use super::parse_quoted_args; + + fn assert_quoted(raw: &str, parsed: &[&str]) { + let parsed = parsed.iter().map(|s| s.to_string()).collect(); + assert_eq!(parse_quoted_args(raw), Ok(parsed)) + } + + #[test] + fn test_parse_quoted_args() { + assert_quoted("foo bar baz", &["foo", "bar", "baz"]); + assert_quoted(" foo bar baz ", &["foo", "bar", "baz"]); + assert_quoted("foo\\ ba\"r ba\"z", &["foo bar baz"]); + assert_quoted( + "It's a nice day, isn't it?", + &["It's", "a", "nice", "day,", "isn't", "it?"], + ); + + // Trailing whitespace + assert_quoted("a ", &["a"]); + assert_quoted("a\\ ", &["a "]); + assert_quoted("a\\ ", &["a "]); + + // Zero-length arguments + assert_quoted("a \"\" b \"\"", &["a", "", "b", ""]); + assert_quoted("a \"\" b \"\" ", &["a", "", "b", ""]); + + // Backslashes in quotes + assert_quoted("\"a \\b \\\" \\\\\"", &["a \\b \" \\"]); + + // Unclosed quotes and unfinished escapes + assert!(parse_quoted_args("foo 'bar \"baz").is_err()); + assert!(parse_quoted_args("foo \"bar baz").is_err()); + assert!(parse_quoted_args("foo \"bar 'baz").is_err()); + assert!(parse_quoted_args("foo \\").is_err()); + assert!(parse_quoted_args("foo 'bar\\").is_err()); + } +} diff --git a/src/bot/commands.rs b/src/bot/commands.rs index cef0a3a..c33ff07 100644 --- a/src/bot/commands.rs +++ b/src/bot/commands.rs @@ -166,6 +166,7 @@ impl Commands { }; let mut ctx = Context { + name: cmd_name.to_string(), kind: Kind::Global, config: config.clone(), conn_tx: snapshot.conn_tx.clone(),