diff --git a/database.py b/database.py deleted file mode 100644 index 7f9eb37..0000000 --- a/database.py +++ /dev/null @@ -1,86 +0,0 @@ -import asyncio -from functools import wraps -import sqlite3 -import threading - -__all__ = ["Database"] - - -def shielded(afunc): - #@wraps(afunc) - async def wrapper(*args, **kwargs): - return await asyncio.shield(afunc(*args, **kwargs)) - return wrapper - -class PooledConnection: - def __init__(self, pool): - self._pool = pool - - self.connection = None - - async def open(self): - self.connection = await self._pool._request() - - async def close(self): - conn = self.connection - self.connection = None - await self._pool._return(conn) - - async def __aenter__(self): - await self.open() - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.close() - -class Pool: - def __init__(self, filename, size=10): - self.filename = filename - self.size = size - - self._available_connections = asyncio.Queue() - - for i in range(size): - conn = sqlite3.connect(self.filename, check_same_thread=False) - self._available_connections.put_nowait(conn) - - def connection(self): - return PooledConnection(self) - - async def _request(self): - return await self._available_connections.get() - - async def _return(self, conn): - await self._available_connections.put(conn) - -class Database: - def __init__(self, filename, pool_size=10, event_loop=None): - self._filename = filename - self._pool = Pool(filename, size=pool_size) - self._loop = event_loop or asyncio.get_event_loop() - - def operation(func): - @wraps(func) - @shielded - async def wrapper(self, *args, **kwargs): - async with self._pool.connection() as conn: - return await self._run_in_thread(func, conn.connection, *args, **kwargs) - return wrapper - - @staticmethod - def _target_function(loop, future, func, *args, **kwargs): - result = None - try: - result = func(*args, **kwargs) - finally: - loop.call_soon_threadsafe(future.set_result, result) - - async def _run_in_thread(self, func, *args, **kwargs): - finished = asyncio.Future() - target_args = (self._loop, finished, func, *args) - - thread = threading.Thread(target=self._target_function, args=target_args, kwargs=kwargs) - thread.start() - - await finished - return finished.result() diff --git a/plusone.py b/plusone.py index 12cb8dd..7cc313c 100644 --- a/plusone.py +++ b/plusone.py @@ -6,46 +6,38 @@ import yaboli from yaboli.utils import * from join_rooms import join_rooms # List of rooms kept in separate file, which is .gitignore'd -import database - # Turn all debugging on asyncio.get_event_loop().set_debug(True) logging.getLogger("asyncio").setLevel(logging.INFO) logging.getLogger("yaboli").setLevel(logging.DEBUG) -class PointDB(database.Database): - @database.Database.operation - def initialize(conn): - cur = conn.cursor() - cur.execute(( - "CREATE TABLE IF NOT EXISTS Points (" - "nick TEXT UNIQUE NOT NULL," - "points INTEGER" +class PointDB(yaboli.Database): + def initialize(self, db): # called automatically + db.execute(( + "CREATE TABLE IF NOT EXISTS Points ( " + "nick TEXT UNIQUE NOT NULL, " + "points INTEGER " ")" )) - conn.commit() + db.commit() - @database.Database.operation - def add_point(conn, nick): + @yaboli.operation + def add_point(db, nick): nick = mention_reduced(nick) - cur = conn.cursor() + cur = db.cursor() cur.execute("INSERT OR IGNORE INTO Points (nick, points) VALUES (?, 0)", (nick,)) cur.execute("UPDATE Points SET points=points+1 WHERE nick=?", (nick,)) - conn.commit() + db.commit() - @database.Database.operation - def points_of(conn, nick): + @yaboli.operation + def points_of(db, nick): nick = mention_reduced(nick) - cur = conn.cursor() - cur.execute("SELECT points FROM Points WHERE nick=?", (nick,)) + cur = db.execute("SELECT points FROM Points WHERE nick=?", (nick,)) res = cur.fetchone() - if res is not None: - return res[0] - else: - return 0 + return res[0] if res is not None else 0 PLUSONE_RE = r"(\+1|:\+1:|:bronze(!\?|\?!)?:)\s*(.*)" @@ -57,8 +49,7 @@ class PlusOne(yaboli.Bot): """ async def on_created(self, room): - room.pointsdb = PointDB(f"points-{room.roomname}.db") - await room.pointsdb.initialize() + room.pointdb = PointDB(f"points-{room.roomname}.db") async def on_send(self, room, message): ping_text = ":bronze!?:" @@ -88,7 +79,7 @@ class PlusOne(yaboli.Bot): async def command_points(self, room, message, argstr): args = self.parse_args(argstr) if not args: - points = await room.pointsdb.points_of(message.sender.nick) + points = await room.pointdb.points_of(message.sender.nick) await room.send( f"You have {points} point{'s' if points != 1 else ''}.", message.mid @@ -100,7 +91,7 @@ class PlusOne(yaboli.Bot): nick = arg[1:] else: nick = arg - points = await room.pointsdb.points_of(nick) + points = await room.pointdb.points_of(nick) response.append(f"{mention(nick)} has {points} point{'' if points == 1 else 's'}.") await room.send("\n".join(response), message.mid) @@ -122,7 +113,7 @@ class PlusOne(yaboli.Bot): elif similar(nick, message.sender.nick): await room.send("Don't +1 yourself, that's not how things work.", message.mid) else: - await room.pointsdb.add_point(nick) + await room.pointdb.add_point(nick) await room.send(f"Point for user {mention(nick)} registered.", message.mid) def main():