diff --git a/websocket.py b/websocket.py index ff186e9..7e73960 100644 --- a/websocket.py +++ b/websocket.py @@ -119,8 +119,8 @@ class Websocket: for delay in backoff: try: async with session.ws_connect(ws_url, ssl=True, proxy=proxy) as websocket: - backoff.reset() yield websocket + backoff.reset() except aiohttp.ClientConnectionError: ws_logger.info( f"Websocket[{self._idx}] connection error (sleep: {delay:.3}s)", exc_info=True @@ -225,7 +225,7 @@ class Websocket: ) self._submitted.update(added) - async def _gather_recv(self, messages: list[JsonType]): + async def _gather_recv(self, messages: list[JsonType], timeout: float = 0.5): """ Gather incoming messages over the timeout specified. Note that there's no return value - this modifies `messages` in-place. @@ -233,12 +233,11 @@ class Websocket: ws = self._ws.get_with_default(None) assert ws is not None while True: - raw_message: aiohttp.WSMessage = await ws.receive() + raw_message: aiohttp.WSMessage = await ws.receive(timeout=timeout) ws_logger.debug(f"Websocket[{self._idx}] received: {raw_message}") if raw_message.type is WSMsgType.TEXT: message: JsonType = json.loads(raw_message.data) messages.append(message) - continue # shortcut to avoid checking all other elifs if not necessary elif raw_message.type is WSMsgType.CLOSE: raise WebsocketClosed(received=True) elif raw_message.type is WSMsgType.CLOSED: @@ -265,7 +264,7 @@ class Websocket: # listen over 0.5s for incoming messages messages: list[JsonType] = [] with suppress(asyncio.TimeoutError): - await asyncio.wait_for(self._gather_recv(messages), timeout=0.5) + await self._gather_recv(messages, timeout=0.5) # process them for message in messages: msg_type = message["type"]