From bb2dadb8629a3064c1d2bc3fc637acf694b52720 Mon Sep 17 00:00:00 2001 From: Joscha Date: Tue, 31 Jul 2018 15:18:57 +0000 Subject: [PATCH] Add database module --- yaboli/__init__.py | 2 ++ yaboli/database.py | 31 +++++++++++++++++++++++++++++++ yaboli/utils.py | 7 ++++++- 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 yaboli/database.py diff --git a/yaboli/__init__.py b/yaboli/__init__.py index 30fe578..89eccb2 100644 --- a/yaboli/__init__.py +++ b/yaboli/__init__.py @@ -1,6 +1,7 @@ from .bot import * from .cookiejar import * from .connection import * +from .database import * from .exceptions import * from .room import * from .utils import * @@ -9,6 +10,7 @@ __all__ = ( bot.__all__ + connection.__all__ + cookiejar.__all__ + + database.__all__ + exceptions.__all__ + room.__all__ + utils.__all__ diff --git a/yaboli/database.py b/yaboli/database.py new file mode 100644 index 0000000..36e72d4 --- /dev/null +++ b/yaboli/database.py @@ -0,0 +1,31 @@ +import asyncio +import sqlite3 + +from .utils import * + + +__all__ = ["Database", "operation"] + + +def operation(func): + async def wrapper(self, *args, **kwargs): + async with self as db: + return await asyncify(func, db, *args, **kwargs) + return wrapper + +class Database: + def __init__(self, database): + self._connection = sqlite3.connect(database, check_same_thread=False) + self._lock = asyncio.Lock() + + self.initialize(self._connection) + + def initialize(self, db): + pass + + async def __aenter__(self, *args, **kwargs): + await self._lock.__aenter__(*args, **kwargs) + return self._connection + + async def __aexit__(self, *args, **kwargs): + return await self._lock.__aexit__(*args, **kwargs) diff --git a/yaboli/utils.py b/yaboli/utils.py index e810998..d1f9ccb 100644 --- a/yaboli/utils.py +++ b/yaboli/utils.py @@ -1,10 +1,11 @@ import asyncio import logging import time +import functools logger = logging.getLogger(__name__) __all__ = [ - "parallel", + "parallel", "asyncify", "mention", "mention_reduced", "similar", "format_time", "format_time_delta", "Session", "Listing", "Message", @@ -14,6 +15,10 @@ __all__ = [ # alias for parallel message sending parallel = asyncio.ensure_future +async def asyncify(func, *args, **kwargs): + func_with_args = functools.partial(func, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func_with_args) + def mention(nick): return "".join(c for c in nick if c not in ".!?;&<'\"" and not c.isspace())