Rework spawning and task structure
Still not working: See TestBot.py
This commit is contained in:
parent
34e1ae4b8f
commit
1c3b9d0a20
5 changed files with 252 additions and 93 deletions
63
TestBot.py
63
TestBot.py
|
|
@ -6,41 +6,36 @@ from yaboli.utils import *
|
|||
|
||||
#class TestBot(Bot):
|
||||
class TestBot(yaboli.Controller):
|
||||
def __init__(self, roomname):
|
||||
super().__init__(roomname)
|
||||
|
||||
async def on_snapshot(self, user_id, session_id, version, listing, log, nick=None,
|
||||
pm_with_nick=None, pm_with_user_id=None):
|
||||
await self.room.nick("TestBot")
|
||||
def __init__(self, nick):
|
||||
super().__init__(nick=nick)
|
||||
|
||||
async def on_send(self, message):
|
||||
await self.room.send("Hey, a message!", message.message_id)
|
||||
|
||||
async def on_join(self, session):
|
||||
if session.nick != "":
|
||||
await self.room.send(f"Hey, a @{mention(session.nick)}!")
|
||||
else:
|
||||
await self.room.send("Hey, a lurker!")
|
||||
|
||||
async def on_nick(self, session_id, user_id, from_nick, to_nick):
|
||||
if from_nick != "" and to_nick != "":
|
||||
if from_nick == to_nick:
|
||||
await self.room.send(f"You didn't even change your nick, @{mention(to_nick)} :(")
|
||||
else:
|
||||
await self.room.send(f"Bye @{mention(from_nick)}, hi @{mention(to_nick)}")
|
||||
elif from_nick != "":
|
||||
await self.room.send(f"Bye @{mention(from_nick)}? This message should never appear...")
|
||||
elif to_nick != "":
|
||||
await self.room.send(f"Hey, a @{mention(to_nick)}!")
|
||||
else:
|
||||
await self.room.send("I have no idea how you did that. This message should never appear...")
|
||||
|
||||
async def on_part(self, session):
|
||||
if session.nick != "":
|
||||
await self.room.send(f"Bye, you @{mention(session.nick)}!")
|
||||
else:
|
||||
await self.room.send("Bye, you lurker!")
|
||||
if message.content == "!spawnevil":
|
||||
bot = TestBot("TestSpawn")
|
||||
task, reason = await bot.connect("test")
|
||||
second = await self.room.send("We have " + ("a" if task else "no") + " task. Reason: " + reason, message.message_id)
|
||||
if task:
|
||||
await bot.stop()
|
||||
await self.room.send("Stopped." if task.done() else "Still running (!)", second.message_id)
|
||||
|
||||
await self.room.send("All's over now.", message.message_id)
|
||||
|
||||
elif message.content == "!tree":
|
||||
messages = [message]
|
||||
newmessages = []
|
||||
for i in range(2):
|
||||
for m in messages:
|
||||
for j in range(2):
|
||||
newm = await self.room.send(f"{m.content}.{j}", m.message_id)
|
||||
newmessages.append(newm)
|
||||
messages = newmessages
|
||||
newmessages = []
|
||||
|
||||
async def run_bot():
|
||||
bot = TestBot("TestSummoner")
|
||||
task, reason = await bot.connect("test")
|
||||
if task:
|
||||
await task
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = TestBot("test")
|
||||
asyncio.get_event_loop().run_until_complete(bot.run())
|
||||
asyncio.get_event_loop().run_until_complete(run_bot())
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
import logging
|
||||
#logging.basicConfig(level=logging.DEBUG)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
from .connection import *
|
||||
from .room import *
|
||||
from .controller import *
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import asyncio
|
||||
asyncio.get_event_loop().set_debug(True)
|
||||
|
||||
|
|
@ -17,37 +18,67 @@ class Connection:
|
|||
self.cookie = cookie
|
||||
self.packet_hook = packet_hook
|
||||
|
||||
self.stopped = False
|
||||
|
||||
self._ws = None
|
||||
self._pid = 0
|
||||
self._pid = 0 # successive packet ids
|
||||
self._spawned_tasks = set()
|
||||
self._pending_responses = {}
|
||||
#self._stopping = False
|
||||
self._runtask = None
|
||||
|
||||
async def run(self):
|
||||
self._ws = await websockets.connect(self.url, max_size=None)
|
||||
async def connect(self, max_tries=10, delay=60):
|
||||
"""
|
||||
success = await connect(max_tries=10, delay=60)
|
||||
|
||||
Attempt to connect to a room.
|
||||
Returns the task listening for packets, or None if the attempt failed.
|
||||
"""
|
||||
|
||||
await self.stop()
|
||||
|
||||
tries_left = max_tries
|
||||
while tries_left > 0:
|
||||
tries_left -= 1
|
||||
try:
|
||||
self._ws = await websockets.connect(self.url, max_size=None)
|
||||
except (websockets.InvalidURI, websockets.InvalidHandshake):
|
||||
self._ws = None
|
||||
if tries_left > 0:
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
self._runtask = asyncio.ensure_future(self._run())
|
||||
return self._runtask
|
||||
|
||||
async def _run(self):
|
||||
"""
|
||||
Listen for packets and deal with them accordingly.
|
||||
"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
response = await self._ws.recv()
|
||||
asyncio.ensure_future(self._handle_json(response))
|
||||
await self._handle_next_message()
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await self._ws.close() # just to make sure it's closed
|
||||
self._ws = None
|
||||
stopped = True
|
||||
self._clean_up_futures()
|
||||
self._clean_up_tasks()
|
||||
|
||||
for future in self._pending_responses:
|
||||
#future.set_error(ConnectionClosed)
|
||||
future.cancel()
|
||||
await self._ws.close() # just to make sure
|
||||
self._ws = None
|
||||
|
||||
async def stop(self):
|
||||
if not self.stopped and self._ws:
|
||||
"""
|
||||
Close websocket connection and wait for running task to stop.
|
||||
"""
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
|
||||
if self._runtask:
|
||||
await self._runtask
|
||||
|
||||
async def send(self, ptype, data=None, await_response=True):
|
||||
if self.stopped:
|
||||
raise ConnectionClosed
|
||||
if not self._ws:
|
||||
raise asyncio.CancelledError
|
||||
|
||||
pid = str(self._new_pid())
|
||||
packet = {
|
||||
|
|
@ -60,7 +91,8 @@ class Connection:
|
|||
if await_response:
|
||||
wait_for = self._wait_for_response(pid)
|
||||
|
||||
await self._ws.send(json.dumps(packet, separators=(',', ':')))
|
||||
logging.debug(f"Currently used websocket at self._ws: {self._ws}")
|
||||
await self._ws.send(json.dumps(packet, separators=(',', ':'))) # minimum size
|
||||
|
||||
if await_response:
|
||||
await wait_for
|
||||
|
|
@ -70,11 +102,32 @@ class Connection:
|
|||
self._pid += 1
|
||||
return self._pid
|
||||
|
||||
async def _handle_next_message(self):
|
||||
response = await self._ws.recv()
|
||||
task = asyncio.ensure_future(self._handle_json(response))
|
||||
self._track_task(task) # will be cancelled when the connection is closed
|
||||
|
||||
def _clean_up_futures(self):
|
||||
for pid, future in self._pending_responses.items():
|
||||
logger.debug(f"Cancelling future: {future}")
|
||||
future.cancel()
|
||||
self._pending_responses = {}
|
||||
|
||||
def _clean_up_tasks(self):
|
||||
for task in self._spawned_tasks:
|
||||
if not task.done():
|
||||
logger.debug(f"Cancelling task: {task}")
|
||||
task.cancel()
|
||||
else:
|
||||
logger.debug(f"Task already done: {task}")
|
||||
logger.debug(f"Exception: {task.exception()}")
|
||||
self._spawned_tasks = set()
|
||||
|
||||
async def _handle_json(self, text):
|
||||
packet = json.loads(text)
|
||||
|
||||
# Deal with pending responses
|
||||
pid = packet.get("id")
|
||||
pid = packet.get("id", None)
|
||||
future = self._pending_responses.pop(pid, None)
|
||||
if future:
|
||||
future.set_result(packet)
|
||||
|
|
@ -82,6 +135,20 @@ class Connection:
|
|||
# Pass packet onto room
|
||||
await self.packet_hook(packet)
|
||||
|
||||
def _track_task(self, task):
|
||||
self._spawned_tasks.add(task)
|
||||
|
||||
# only keep running tasks
|
||||
#tasks = set()
|
||||
#for task in self._spawned_tasks:
|
||||
#if not task.done():
|
||||
#logger.debug(f"Keeping task: {task}")
|
||||
#tasks.add(task)
|
||||
#else:
|
||||
#logger.debug(f"Deleting task: {task}")
|
||||
#self._spawned_tasks = tasks
|
||||
#self._spawned_tasks = {task for task in self._spawned_tasks if not task.done()} # TODO: Reenable
|
||||
|
||||
def _wait_for_response(self, pid):
|
||||
future = asyncio.Future()
|
||||
self._pending_responses[pid] = future
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from .room import Room
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
__all__ = ["Controller"]
|
||||
|
||||
|
||||
|
|
@ -24,50 +27,103 @@ class Controller:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, roomname, human=False, cookie=None):
|
||||
def __init__(self, nick, human=False, cookie=None):
|
||||
"""
|
||||
roomname - name of room to connect to
|
||||
human - whether the human flag should be set on connections
|
||||
cookie - cookie to use in HTTP request, if any
|
||||
"""
|
||||
|
||||
self.roomname = roomname
|
||||
self.nick = nick
|
||||
self.human = human
|
||||
self.cookie = cookie
|
||||
|
||||
self.roomname = "test"
|
||||
self.password = None
|
||||
|
||||
self.room = None
|
||||
self.running = True
|
||||
self._connect_result = None
|
||||
|
||||
async def run(self):
|
||||
await self.on_start()
|
||||
def _create_room(self, roomname):
|
||||
return Room(roomname, self, human=self.human, cookie=self.cookie)
|
||||
|
||||
def _set_connect_result(self, result):
|
||||
logger.debug(f"Attempting to set connect result to {result}")
|
||||
if self._connect_result and not self._connect_result.done():
|
||||
logger.debug(f"Setting connect result to {result}")
|
||||
self._connect_result.set_result(result)
|
||||
|
||||
async def connect(self, roomname, password=None, timeout=10):
|
||||
"""
|
||||
task, reason = await connect(roomname, password=None, timeout=10)
|
||||
|
||||
while self.running:
|
||||
self.room = Room(self.roomname, self, self.human, self.cookie)
|
||||
await self.room.run()
|
||||
Connect to a room and authenticate, if necessary.
|
||||
|
||||
roomname - name of the room to connect to
|
||||
password - password for the room, if needed
|
||||
timeout - wait this long for a reply from the server
|
||||
|
||||
Returns:
|
||||
task - the task running the bot, or None on failure
|
||||
reason - the reason for failure
|
||||
"no room" = could not establish connection, room doesn't exist
|
||||
"auth option" = can't authenticate with a password
|
||||
"no password" = password needed to connect to room
|
||||
"wrong password" = password given does not work
|
||||
"disconnected" = connection closed before client could access the room
|
||||
"success" = no failure
|
||||
"""
|
||||
|
||||
logger.info(f"Attempting to connect to &{roomname}")
|
||||
|
||||
# make sure nothing is running any more
|
||||
try:
|
||||
await self.stop()
|
||||
except asyncio.CancelledError:
|
||||
logger.error("Calling connect from the controller itself.")
|
||||
raise
|
||||
|
||||
self.password = password
|
||||
self.room = self._create_room(roomname)
|
||||
|
||||
# prepare for if connect() is successful
|
||||
self._connect_result = asyncio.Future()
|
||||
|
||||
# attempt to connect to the room
|
||||
task = await self.room.connect()
|
||||
if not task:
|
||||
logger.warn(f"Could not connect to &{roomname}.")
|
||||
self.room = None
|
||||
return None, "no room"
|
||||
|
||||
await self.on_end()
|
||||
# connection succeeded, now we need to know whether we can log in
|
||||
# wait for success/authentication/disconnect
|
||||
# TODO: add a timeout
|
||||
await self._connect_result
|
||||
result = self._connect_result.result()
|
||||
logger.debug(f"&{roomname}._connect_result: {result!r}")
|
||||
|
||||
# deal with result
|
||||
if result == "success":
|
||||
logger.info(f"Successfully connected to &{roomname}.")
|
||||
return task, result
|
||||
else: # not successful for some reason
|
||||
logger.warn(f"Could not join &{roomname}: {result!r}")
|
||||
await self.stop()
|
||||
return None, result
|
||||
|
||||
async def stop(self):
|
||||
if self.running:
|
||||
self.running = False
|
||||
if self.room:
|
||||
logger.info(f"&{self.room.roomname}: Stopping")
|
||||
await self.room.stop()
|
||||
logger.debug(f"&{self.room.roomname}: Stopped. Deleting room")
|
||||
self.room = None
|
||||
|
||||
async def set_nick(self, nick):
|
||||
if nick != self.nick:
|
||||
_, _, _, to_nick = await self.room.nick(nick)
|
||||
|
||||
if self.room:
|
||||
await self.room.stop()
|
||||
|
||||
async def on_start(self):
|
||||
"""
|
||||
The first callback called when the controller is run.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
async def on_stop(self):
|
||||
"""
|
||||
The last callback called when the controller is run.
|
||||
"""
|
||||
|
||||
pass
|
||||
if to_nick != nick:
|
||||
logger.warn(f"&{self.room.roomname}: Could not set nick to {nick!r}, set to {to_nick!r} instead.")
|
||||
|
||||
async def on_connected(self):
|
||||
"""
|
||||
|
|
@ -77,7 +133,7 @@ class Controller:
|
|||
such as resetting the message history.
|
||||
"""
|
||||
|
||||
pass
|
||||
self._set_connect_result("success")
|
||||
|
||||
async def on_disconnected(self):
|
||||
"""
|
||||
|
|
@ -88,10 +144,18 @@ class Controller:
|
|||
Need to store information from old room?
|
||||
"""
|
||||
|
||||
pass
|
||||
logger.debug(f"on_disconnected: self.room is {self.room}")
|
||||
self._set_connect_result("disconnected")
|
||||
|
||||
async def on_bounce(self, reason=None, auth_options=None, agent_id=None, ip=None):
|
||||
pass
|
||||
async def on_bounce(self, reason=None, auth_options=[], agent_id=None, ip=None):
|
||||
if "passcode" not in auth_options:
|
||||
self._set_connect_result("auth option")
|
||||
elif self.password is None:
|
||||
self._set_connect_result("no password")
|
||||
else:
|
||||
success, reason = await self.room.auth("passcode", passcode=self.password)
|
||||
if not success:
|
||||
self._set_connect_result("wrong password")
|
||||
|
||||
async def on_disconnect(self, reason):
|
||||
pass
|
||||
|
|
@ -125,7 +189,8 @@ class Controller:
|
|||
"""
|
||||
Default implementation, refer to api.euphoria.io
|
||||
"""
|
||||
|
||||
|
||||
logger.debug(f"&{self.room.roomname}: Pong!")
|
||||
await self.room.ping_reply(ptime)
|
||||
|
||||
async def on_pm_initiate(self, from_id, from_nick, from_room, pm_id):
|
||||
|
|
@ -136,4 +201,5 @@ class Controller:
|
|||
|
||||
async def on_snapshot(self, user_id, session_id, version, listing, log, nick=None,
|
||||
pm_with_nick=None, pm_with_user_id=None):
|
||||
pass
|
||||
if nick != self.nick:
|
||||
await self.room.nick(self.nick)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from .connection import *
|
||||
from .utils import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
__all__ = ["Room"]
|
||||
|
||||
|
||||
|
|
@ -34,17 +36,40 @@ class Room:
|
|||
self._callbacks = {}
|
||||
self._add_callbacks()
|
||||
|
||||
self._stopping = False
|
||||
self._runtask = None
|
||||
|
||||
if human:
|
||||
url = self.HUMAN_FORMAT.format(self.roomname)
|
||||
else:
|
||||
url = self.ROOM_FORMAT.format(self.roomname)
|
||||
self._conn = Connection(url, self._handle_packet, self.cookie)
|
||||
|
||||
async def run(self):
|
||||
await self._conn.run()
|
||||
async def connect(self, max_tries=10, delay=60):
|
||||
task = await self._conn.connect(max_tries=1)
|
||||
if task:
|
||||
self._runtask = asyncio.ensure_future(self._run(task, max_tries=max_tries, delay=delay))
|
||||
return self._runtask
|
||||
|
||||
async def _run(self, task, max_tries=10, delay=60):
|
||||
while not self._stopping:
|
||||
await task
|
||||
await self.controller.on_disconnected()
|
||||
|
||||
task = await self._conn.connect(max_tries=max_tries, delay=delay)
|
||||
if not task:
|
||||
return
|
||||
|
||||
self.stopping = False
|
||||
|
||||
async def stop(self):
|
||||
self._stopping = True
|
||||
await self._conn.stop()
|
||||
|
||||
if self._runtask:
|
||||
await self._runtask
|
||||
|
||||
|
||||
|
||||
# CATEGORY: SESSION COMMANDS
|
||||
|
||||
|
|
@ -270,12 +295,11 @@ class Room:
|
|||
try:
|
||||
await callback(packet)
|
||||
except asyncio.CancelledError as e:
|
||||
# TODO: log error
|
||||
print("HEHEHEHEY, CANCELLEDERROR", e)
|
||||
pass
|
||||
logger.info(f"&{self.roomname}: Callback of type {ptype!r} cancelled.")
|
||||
|
||||
def _check_for_errors(self, packet):
|
||||
# TODO: log throttled
|
||||
if packet.get("throttled", False):
|
||||
logger.warn(f"&{self.roomname}: Throttled for reason: {packet.get('throttled_reason', 'no reason')!r}")
|
||||
|
||||
if "error" in packet:
|
||||
raise ResponseError(response.get("error"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue