Start rewrite
This commit is contained in:
parent
04f4c3a8b6
commit
6b65bef5e0
5 changed files with 354 additions and 776 deletions
|
|
@ -4,120 +4,38 @@ import logging
|
|||
import socket
|
||||
import websockets
|
||||
|
||||
from .exceptions import ConnectionClosed
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
__all__ = ["Connection"]
|
||||
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, url, packet_hook, cookie=None, ping_timeout=10, ping_delay=30):
|
||||
def __init__(self, url, packet_callback, disconnect_callback, cookiejar=None, ping_timeout=10, ping_delay=30, reconnect_attempts=10):
|
||||
self.url = url
|
||||
self.packet_hook = packet_hook
|
||||
self.cookie = cookie
|
||||
self.packet_callback = packet_callback
|
||||
self.disconnect_callback = disconnect_callback
|
||||
self.cookiejar = cookiejar
|
||||
self.ping_timeout = ping_timeout # how long to wait for websocket ping reply
|
||||
self.ping_delay = ping_delay # how long to wait between pings
|
||||
|
||||
self.reconnect_attempts = reconnect_attempts
|
||||
|
||||
self._ws = None
|
||||
self._pid = 0 # successive packet ids
|
||||
self._spawned_tasks = set()
|
||||
#self._spawned_tasks = set()
|
||||
self._pending_responses = {}
|
||||
|
||||
self._runtask = None
|
||||
self._pingtask = None # pings
|
||||
|
||||
async def connect(self, max_tries=10, delay=60):
|
||||
"""
|
||||
success = await connect(max_tries=10, delay=60)
|
||||
|
||||
Attempt to connect to a room.
|
||||
Returns the task listening for packets, or None if the attempt failed.
|
||||
|
||||
max_tries - maximum number of reconnect attempts before stopping
|
||||
delay - time (in seconds) between reconnect attempts
|
||||
"""
|
||||
|
||||
logger.debug(f"Attempting to connect, max_tries={max_tries}")
|
||||
|
||||
await self.stop()
|
||||
|
||||
logger.debug(f"Stopped previously running things.")
|
||||
|
||||
for tries_left in reversed(range(max_tries)):
|
||||
logger.info(f"Attempting to connect, {tries_left} tries left.")
|
||||
|
||||
try:
|
||||
self._ws = await websockets.connect(self.url, max_size=None)
|
||||
except (websockets.InvalidURI, websockets.InvalidHandshake, socket.gaierror):
|
||||
self._ws = None
|
||||
if tries_left > 0:
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
self._runtask = asyncio.ensure_future(self._run())
|
||||
self._pingtask = asyncio.ensure_future(self._ping())
|
||||
logger.debug(f"Started run and ping tasks")
|
||||
|
||||
return self._runtask
|
||||
|
||||
async def _run(self):
|
||||
"""
|
||||
Listen for packets and deal with them accordingly.
|
||||
"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
await self._handle_next_message()
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
self._clean_up_futures()
|
||||
self._clean_up_tasks()
|
||||
|
||||
try:
|
||||
await self._ws.close() # just to make sure
|
||||
except:
|
||||
pass # errors are not useful here
|
||||
|
||||
self._pingtask.cancel()
|
||||
await self._pingtask # should stop now that the ws is closed
|
||||
self._ws = None
|
||||
|
||||
async def _ping(self):
|
||||
"""
|
||||
Periodically ping the server to detect a timeout.
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
logger.debug("Pinging...")
|
||||
wait_for_reply = await self._ws.ping()
|
||||
await asyncio.wait_for(wait_for_reply, self.ping_timeout)
|
||||
logger.debug("Pinged!")
|
||||
await asyncio.sleep(self.ping_delay)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Ping timed out.")
|
||||
await self._ws.close()
|
||||
break
|
||||
except (websockets.ConnectionClosed, ConnectionResetError, asyncio.CancelledError):
|
||||
return
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Close websocket connection and wait for running task to stop.
|
||||
"""
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass # errors not useful here
|
||||
|
||||
if self._runtask:
|
||||
await self._runtask
|
||||
|
||||
self._stopped = False
|
||||
self._pingtask = None
|
||||
self._runtask = asyncio.create_task(self._run())
|
||||
# ... aaand the connection is started.
|
||||
|
||||
async def send(self, ptype, data=None, await_response=True):
|
||||
if not self._ws:
|
||||
raise asyncio.CancelledError
|
||||
|
||||
raise exceptions.ConnectionClosed
|
||||
#raise asyncio.CancelledError
|
||||
|
||||
pid = str(self._new_pid())
|
||||
packet = {
|
||||
"type": ptype,
|
||||
|
|
@ -125,62 +43,157 @@ class Connection:
|
|||
}
|
||||
if data:
|
||||
packet["data"] = data
|
||||
|
||||
|
||||
if await_response:
|
||||
wait_for = self._wait_for_response(pid)
|
||||
|
||||
|
||||
logging.debug(f"Currently used websocket at self._ws: {self._ws}")
|
||||
await self._ws.send(json.dumps(packet, separators=(',', ':'))) # minimum size
|
||||
|
||||
|
||||
if await_response:
|
||||
await wait_for
|
||||
return wait_for.result()
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Close websocket connection and wait for running task to stop.
|
||||
|
||||
No connection function are to be called after calling stop().
|
||||
This means that stop() can only be called once.
|
||||
"""
|
||||
|
||||
self._stopped = True
|
||||
if self._ws:
|
||||
await self._ws.close() # _run() does the cleaning up now.
|
||||
await self._runtask
|
||||
|
||||
async def _connect(self, tries):
|
||||
"""
|
||||
Attempt to connect to a room.
|
||||
If the Connection is already connected, it attempts to reconnect.
|
||||
|
||||
Returns True on success, False on failure.
|
||||
|
||||
If tries is None, connect retries infinitely.
|
||||
The delay between connection attempts doubles every attempt (starts with 1s).
|
||||
"""
|
||||
|
||||
# Assumes _disconnect() has already been called in _run()
|
||||
|
||||
delay = 1 # seconds
|
||||
while True:
|
||||
try:
|
||||
if self._cookiejar:
|
||||
cookies = [("Cookie", cookie) for cookie in self._cookiejar.sniff()]
|
||||
self._ws = await websockets.connect(self.url, max_size=None, extra_headers=cookies)
|
||||
else:
|
||||
self._ws = await websockets.connect(self.url, max_size=None)
|
||||
except (websockets.InvalidHandshake, socket.gaierror): # not websockets.InvalidURI
|
||||
self._ws = None
|
||||
|
||||
if tries is not None:
|
||||
tries -= 1
|
||||
if tries <= 0:
|
||||
return False
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
if self._cookiejar:
|
||||
for set_cookie in self._ws.response_headers.get_all("Set-Cookie"):
|
||||
self._cookiejar.bake(set_cookie)
|
||||
|
||||
self._pingtask = asyncio.create_task(self._ping())
|
||||
|
||||
return True
|
||||
|
||||
async def _disconnect(self):
|
||||
"""
|
||||
Disconnect and clean up all "residue", such as:
|
||||
- close existing websocket connection
|
||||
- cancel all pending response futures with a ConnectionClosed exception
|
||||
- reset package ID counter
|
||||
- make sure the ping task has finished
|
||||
"""
|
||||
|
||||
# stop ping task
|
||||
if self._pingtask:
|
||||
self._pingtask.cancel()
|
||||
await self._pingtask
|
||||
self._pingtask = None
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
self._pid = 0
|
||||
|
||||
# clean up pending response futures
|
||||
for _, future in self._pending_responses.items():
|
||||
logger.debug(f"Cancelling future with ConnectionClosed: {future}")
|
||||
future.set_exception(exceptions.ConnectionClosed("No server response"))
|
||||
self._pending_responses = {}
|
||||
|
||||
async def _run(self):
|
||||
"""
|
||||
Listen for packets and deal with them accordingly.
|
||||
"""
|
||||
|
||||
while not self._stopped:
|
||||
self._connect(self.reconnect_attempts)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await self._handle_next_message()
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
await self._disconnect() # disconnect and clean up
|
||||
|
||||
async def _ping(self):
|
||||
"""
|
||||
Periodically ping the server to detect a timeout.
|
||||
"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Pinging...")
|
||||
wait_for_reply = await self._ws.ping()
|
||||
await asyncio.wait_for(wait_for_reply, self.ping_timeout)
|
||||
logger.debug("Pinged!")
|
||||
await asyncio.sleep(self.ping_delay)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Ping timed out.")
|
||||
await self._ws.close() # trigger a reconnect attempt
|
||||
except (websockets.ConnectionClosed, ConnectionResetError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
def _new_pid(self):
|
||||
self._pid += 1
|
||||
return self._pid
|
||||
|
||||
|
||||
async def _handle_next_message(self):
|
||||
response = await self._ws.recv()
|
||||
task = asyncio.ensure_future(self._handle_json(response))
|
||||
self._track_task(task) # will be cancelled when the connection is closed
|
||||
|
||||
def _clean_up_futures(self):
|
||||
for pid, future in self._pending_responses.items():
|
||||
logger.debug(f"Cancelling future: {future}")
|
||||
future.cancel()
|
||||
self._pending_responses = {}
|
||||
|
||||
def _clean_up_tasks(self):
|
||||
for task in self._spawned_tasks:
|
||||
if not task.done():
|
||||
logger.debug(f"Cancelling task: {task}")
|
||||
task.cancel()
|
||||
else:
|
||||
logger.debug(f"Task already done: {task}")
|
||||
logger.debug(f"Exception: {task.exception()}")
|
||||
self._spawned_tasks = set()
|
||||
|
||||
async def _handle_json(self, text):
|
||||
packet = json.loads(text)
|
||||
|
||||
|
||||
# Deal with pending responses
|
||||
pid = packet.get("id", None)
|
||||
future = self._pending_responses.pop(pid, None)
|
||||
if future:
|
||||
future.set_result(packet)
|
||||
|
||||
|
||||
ptype = packet.get("type")
|
||||
data = packet.get("data", None)
|
||||
error = packet.get("error", None)
|
||||
if packet.get("throttled", False):
|
||||
throttled = packet.get("throttled_reason")
|
||||
else:
|
||||
throttled = None
|
||||
|
||||
# Pass packet onto room
|
||||
await self.packet_hook(packet)
|
||||
|
||||
def _track_task(self, task):
|
||||
self._spawned_tasks.add(task)
|
||||
|
||||
# only keep running tasks
|
||||
self._spawned_tasks = {task for task in self._spawned_tasks if not task.done()}
|
||||
|
||||
asyncio.create_task(self.packet_callback(ptype, data, error, throttled))
|
||||
|
||||
def _wait_for_response(self, pid):
|
||||
future = asyncio.Future()
|
||||
self._pending_responses[pid] = future
|
||||
|
||||
return future
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue