From 80cf5a8a618aa962a103b9b334d130fc78fe552c Mon Sep 17 00:00:00 2001 From: DevilXD Date: Sat, 17 Sep 2022 15:17:02 +0200 Subject: [PATCH] Simplify some of the websocket start and stop logic --- websocket.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/websocket.py b/websocket.py index d3a2e7d..560aa47 100644 --- a/websocket.py +++ b/websocket.py @@ -80,18 +80,14 @@ class Websocket: async def start(self): async with self._state_lock: - if not self.connected: - self.start_nowait() - await self.wait_until_connected() + self.start_nowait() + await self.wait_until_connected() def start_nowait(self): - if not self.connected: - if self._handle_task is None: - self._handle_task = asyncio.create_task(self._handle()) - elif self._handle_task.done(): - ws_logger.error(f"Detected zombified handle task: {self._handle_task!r}") + if self._handle_task is None or self._handle_task.done(): + self._handle_task = asyncio.create_task(self._handle()) - async def stop(self): + async def stop(self, *, remove: bool = False): async with self._state_lock: if self._closed.is_set(): return @@ -100,19 +96,17 @@ class Websocket: if ws is not None: self.set_status(_("gui", "websocket", "disconnecting")) await ws.close() + if self._handle_task is not None: + with suppress(asyncio.TimeoutError, asyncio.CancelledError): + await asyncio.wait_for(self._handle_task, timeout=2) + self._handle_task = None + if remove: + self._twitch.gui.websockets.remove(self._idx) - def stop_nowait(self): + def stop_nowait(self, *, remove: bool = False): # weird syntax but that's what we get for using a decorator for this # return type of 'task_wrapper' is a coro, so we need to instance it for the task - asyncio.create_task(task_wrapper(self.stop)()) - - def stop_and_remove(self): - # this stops the websocket, and then removes it from the gui list - @task_wrapper - async def remover(): - await self.stop() - self._twitch.gui.websockets.remove(self._idx) - asyncio.create_task(remover()) + asyncio.create_task(task_wrapper(self.stop)(remove=remove)) async def _backoff_connect( self, ws_url: str, **kwargs @@ -392,7 +386,7 @@ class WebsocketPool: if count <= (len(self.websockets) - 1) * WS_TOPICS_LIMIT: ws = self.websockets.pop() recycled_topics.extend(ws.topics.values()) - ws.stop_and_remove() + ws.stop_nowait(remove=True) else: break if recycled_topics: