Start rewrite

This commit is contained in:
Joscha 2018-07-25 16:02:38 +00:00
parent 04f4c3a8b6
commit 6b65bef5e0
5 changed files with 354 additions and 776 deletions

View file

@ -4,120 +4,38 @@ import logging
import socket
import websockets
from .exceptions import ConnectionClosed
logger = logging.getLogger(__name__)
__all__ = ["Connection"]
class Connection:
def __init__(self, url, packet_hook, cookie=None, ping_timeout=10, ping_delay=30):
def __init__(self, url, packet_callback, disconnect_callback, cookiejar=None, ping_timeout=10, ping_delay=30, reconnect_attempts=10):
self.url = url
self.packet_hook = packet_hook
self.cookie = cookie
self.packet_callback = packet_callback
self.disconnect_callback = disconnect_callback
self.cookiejar = cookiejar
self.ping_timeout = ping_timeout # how long to wait for websocket ping reply
self.ping_delay = ping_delay # how long to wait between pings
self.reconnect_attempts = reconnect_attempts
self._ws = None
self._pid = 0 # successive packet ids
self._spawned_tasks = set()
#self._spawned_tasks = set()
self._pending_responses = {}
self._runtask = None
self._pingtask = None # pings
async def connect(self, max_tries=10, delay=60):
"""
success = await connect(max_tries=10, delay=60)
Attempt to connect to a room.
Returns the task listening for packets, or None if the attempt failed.
max_tries - maximum number of reconnect attempts before stopping
delay - time (in seconds) between reconnect attempts
"""
logger.debug(f"Attempting to connect, max_tries={max_tries}")
await self.stop()
logger.debug(f"Stopped previously running things.")
for tries_left in reversed(range(max_tries)):
logger.info(f"Attempting to connect, {tries_left} tries left.")
try:
self._ws = await websockets.connect(self.url, max_size=None)
except (websockets.InvalidURI, websockets.InvalidHandshake, socket.gaierror):
self._ws = None
if tries_left > 0:
await asyncio.sleep(delay)
else:
self._runtask = asyncio.ensure_future(self._run())
self._pingtask = asyncio.ensure_future(self._ping())
logger.debug(f"Started run and ping tasks")
return self._runtask
async def _run(self):
"""
Listen for packets and deal with them accordingly.
"""
try:
while True:
await self._handle_next_message()
except websockets.ConnectionClosed:
pass
finally:
self._clean_up_futures()
self._clean_up_tasks()
try:
await self._ws.close() # just to make sure
except:
pass # errors are not useful here
self._pingtask.cancel()
await self._pingtask # should stop now that the ws is closed
self._ws = None
async def _ping(self):
"""
Periodically ping the server to detect a timeout.
"""
while True:
try:
logger.debug("Pinging...")
wait_for_reply = await self._ws.ping()
await asyncio.wait_for(wait_for_reply, self.ping_timeout)
logger.debug("Pinged!")
await asyncio.sleep(self.ping_delay)
except asyncio.TimeoutError:
logger.warning("Ping timed out.")
await self._ws.close()
break
except (websockets.ConnectionClosed, ConnectionResetError, asyncio.CancelledError):
return
async def stop(self):
"""
Close websocket connection and wait for running task to stop.
"""
if self._ws:
try:
await self._ws.close()
except:
pass # errors not useful here
if self._runtask:
await self._runtask
self._stopped = False
self._pingtask = None
self._runtask = asyncio.create_task(self._run())
# ... aaand the connection is started.
async def send(self, ptype, data=None, await_response=True):
if not self._ws:
raise asyncio.CancelledError
raise exceptions.ConnectionClosed
#raise asyncio.CancelledError
pid = str(self._new_pid())
packet = {
"type": ptype,
@ -125,62 +43,157 @@ class Connection:
}
if data:
packet["data"] = data
if await_response:
wait_for = self._wait_for_response(pid)
logging.debug(f"Currently used websocket at self._ws: {self._ws}")
await self._ws.send(json.dumps(packet, separators=(',', ':'))) # minimum size
if await_response:
await wait_for
return wait_for.result()
async def stop(self):
"""
Close websocket connection and wait for running task to stop.
No connection function are to be called after calling stop().
This means that stop() can only be called once.
"""
self._stopped = True
if self._ws:
await self._ws.close() # _run() does the cleaning up now.
await self._runtask
async def _connect(self, tries):
"""
Attempt to connect to a room.
If the Connection is already connected, it attempts to reconnect.
Returns True on success, False on failure.
If tries is None, connect retries infinitely.
The delay between connection attempts doubles every attempt (starts with 1s).
"""
# Assumes _disconnect() has already been called in _run()
delay = 1 # seconds
while True:
try:
if self._cookiejar:
cookies = [("Cookie", cookie) for cookie in self._cookiejar.sniff()]
self._ws = await websockets.connect(self.url, max_size=None, extra_headers=cookies)
else:
self._ws = await websockets.connect(self.url, max_size=None)
except (websockets.InvalidHandshake, socket.gaierror): # not websockets.InvalidURI
self._ws = None
if tries is not None:
tries -= 1
if tries <= 0:
return False
await asyncio.sleep(delay)
delay *= 2
else:
if self._cookiejar:
for set_cookie in self._ws.response_headers.get_all("Set-Cookie"):
self._cookiejar.bake(set_cookie)
self._pingtask = asyncio.create_task(self._ping())
return True
async def _disconnect(self):
"""
Disconnect and clean up all "residue", such as:
- close existing websocket connection
- cancel all pending response futures with a ConnectionClosed exception
- reset package ID counter
- make sure the ping task has finished
"""
# stop ping task
if self._pingtask:
self._pingtask.cancel()
await self._pingtask
self._pingtask = None
if self._ws:
await self._ws.close()
self._ws = None
self._pid = 0
# clean up pending response futures
for _, future in self._pending_responses.items():
logger.debug(f"Cancelling future with ConnectionClosed: {future}")
future.set_exception(exceptions.ConnectionClosed("No server response"))
self._pending_responses = {}
async def _run(self):
"""
Listen for packets and deal with them accordingly.
"""
while not self._stopped:
self._connect(self.reconnect_attempts)
try:
while True:
await self._handle_next_message()
except websockets.ConnectionClosed:
pass
finally:
await self._disconnect() # disconnect and clean up
async def _ping(self):
"""
Periodically ping the server to detect a timeout.
"""
try:
while True:
logger.debug("Pinging...")
wait_for_reply = await self._ws.ping()
await asyncio.wait_for(wait_for_reply, self.ping_timeout)
logger.debug("Pinged!")
await asyncio.sleep(self.ping_delay)
except asyncio.TimeoutError:
logger.warning("Ping timed out.")
await self._ws.close() # trigger a reconnect attempt
except (websockets.ConnectionClosed, ConnectionResetError, asyncio.CancelledError):
pass
def _new_pid(self):
self._pid += 1
return self._pid
async def _handle_next_message(self):
response = await self._ws.recv()
task = asyncio.ensure_future(self._handle_json(response))
self._track_task(task) # will be cancelled when the connection is closed
def _clean_up_futures(self):
for pid, future in self._pending_responses.items():
logger.debug(f"Cancelling future: {future}")
future.cancel()
self._pending_responses = {}
def _clean_up_tasks(self):
for task in self._spawned_tasks:
if not task.done():
logger.debug(f"Cancelling task: {task}")
task.cancel()
else:
logger.debug(f"Task already done: {task}")
logger.debug(f"Exception: {task.exception()}")
self._spawned_tasks = set()
async def _handle_json(self, text):
packet = json.loads(text)
# Deal with pending responses
pid = packet.get("id", None)
future = self._pending_responses.pop(pid, None)
if future:
future.set_result(packet)
ptype = packet.get("type")
data = packet.get("data", None)
error = packet.get("error", None)
if packet.get("throttled", False):
throttled = packet.get("throttled_reason")
else:
throttled = None
# Pass packet onto room
await self.packet_hook(packet)
def _track_task(self, task):
self._spawned_tasks.add(task)
# only keep running tasks
self._spawned_tasks = {task for task in self._spawned_tasks if not task.done()}
asyncio.create_task(self.packet_callback(ptype, data, error, throttled))
def _wait_for_response(self, pid):
future = asyncio.Future()
self._pending_responses[pid] = future
return future