from __future__ import annotations import json import asyncio import logging from time import time from contextlib import suppress from typing import Any, Literal, TYPE_CHECKING import aiohttp from translate import _ from exceptions import MinerException, WebsocketClosed from constants import PING_INTERVAL, PING_TIMEOUT, MAX_WEBSOCKETS, WS_TOPICS_LIMIT from utils import ( CHARS_ASCII, task_wrapper, create_nonce, json_minify, format_traceback, AwaitableValue, ExponentialBackoff, ) if TYPE_CHECKING: from collections import abc from twitch import Twitch from gui import WebsocketStatus from constants import JsonType, WebsocketTopic WSMsgType = aiohttp.WSMsgType logger = logging.getLogger("TwitchDrops") ws_logger = logging.getLogger("TwitchDrops.websocket") class Websocket: def __init__(self, pool: WebsocketPool, index: int): self._pool: WebsocketPool = pool self._twitch: Twitch = pool._twitch self._ws_gui: WebsocketStatus = self._twitch.gui.websockets self._state_lock = asyncio.Lock() # websocket index self._idx: int = index # current websocket connection self._ws: AwaitableValue[aiohttp.ClientWebSocketResponse] = AwaitableValue() # set when the websocket needs to be closed or reconnect self._closed = asyncio.Event() self._reconnect_requested = asyncio.Event() # set when the topics changed self._topics_changed = asyncio.Event() # ping timestamps self._next_ping: float = time() self._max_pong: float = self._next_ping + PING_TIMEOUT.total_seconds() # main task, responsible for receiving messages, sending them, and websocket ping self._handle_task: asyncio.Task[None] | None = None # topics stuff self.topics: dict[str, WebsocketTopic] = {} self._submitted: set[WebsocketTopic] = set() # notify GUI self.set_status("Disconnected") @property def connected(self) -> bool: return self._ws.has_value() def wait_until_connected(self): return self._ws.wait() def set_status(self, status: str | None = None, refresh_topics: bool = False): self._twitch.gui.websockets.update( self._idx, status=status, topics=(len(self.topics) if refresh_topics else None) ) def request_reconnect(self): # reset our ping interval, so we send a PING after reconnect right away self._next_ping = time() self._reconnect_requested.set() async def start(self): async with self._state_lock: self.start_nowait() await self.wait_until_connected() def start_nowait(self): if self._handle_task is None or self._handle_task.done(): self._handle_task = asyncio.create_task(self._handle()) async def stop(self, *, remove: bool = False): async with self._state_lock: if self._closed.is_set(): return self._closed.set() ws = self._ws.get_with_default(None) if ws is not None: self.set_status(_("gui", "websocket", "disconnecting")) await ws.close() if self._handle_task is not None: with suppress(asyncio.TimeoutError, asyncio.CancelledError): await asyncio.wait_for(self._handle_task, timeout=2) self._handle_task = None if remove: self.topics.clear() self._topics_changed.set() self._twitch.gui.websockets.remove(self._idx) def stop_nowait(self, *, remove: bool = False): # weird syntax but that's what we get for using a decorator for this # return type of 'task_wrapper' is a coro, so we need to instance it for the task asyncio.create_task(task_wrapper(self.stop)(remove=remove)) async def _backoff_connect( self, ws_url: str, **kwargs ) -> abc.AsyncGenerator[aiohttp.ClientWebSocketResponse, None]: session = await self._twitch.get_session() backoff = ExponentialBackoff(**kwargs) if self._twitch.settings.proxy: proxy = self._twitch.settings.proxy else: proxy = None for delay in backoff: try: async with session.ws_connect(ws_url, ssl=True, proxy=proxy) as websocket: yield websocket backoff.reset() except ( asyncio.TimeoutError, aiohttp.ClientResponseError, aiohttp.ClientConnectionError, ): ws_logger.info( f"Websocket[{self._idx}] connection problem (sleep: {round(delay)}s)", exc_info=True, ) await asyncio.sleep(delay) except RuntimeError: ws_logger.warning( f"Websocket[{self._idx}] exiting backoff connect loop " "because session is closed (RuntimeError)" ) break @task_wrapper async def _handle(self): # ensure we're logged in before connecting self.set_status(_("gui", "websocket", "initializing")) await self._twitch.wait_until_login() self.set_status(_("gui", "websocket", "connecting")) ws_logger.info(f"Websocket[{self._idx}] connecting...") self._closed.clear() # Connect/Reconnect loop async for websocket in self._backoff_connect( "wss://pubsub-edge.twitch.tv/v1", maximum=3*60 # 3 minutes maximum backoff time ): self._ws.set(websocket) self._reconnect_requested.clear() # NOTE: _topics_changed doesn't start set, # because there's no initial topics we can sub to right away self.set_status(_("gui", "websocket", "connected")) ws_logger.info(f"Websocket[{self._idx}] connected.") try: try: while not self._reconnect_requested.is_set(): await self._handle_ping() await self._handle_topics() await self._handle_recv() finally: self._ws.clear() self._submitted.clear() # set _topics_changed to let the next WS connection resub to the topics self._topics_changed.set() # A reconnect was requested except WebsocketClosed as exc: if exc.received: # server closed the connection, not us - reconnect ws_logger.warning( f"Websocket[{self._idx}] closed unexpectedly: {websocket.close_code}" ) elif self._closed.is_set(): # we closed it - exit ws_logger.info(f"Websocket[{self._idx}] stopped.") self.set_status(_("gui", "websocket", "disconnected")) return except Exception: ws_logger.exception(f"Exception in Websocket[{self._idx}]") self.set_status(_("gui", "websocket", "reconnecting")) ws_logger.warning(f"Websocket[{self._idx}] reconnecting...") async def _handle_ping(self): now = time() if now >= self._next_ping: self._next_ping = now + PING_INTERVAL.total_seconds() self._max_pong = now + PING_TIMEOUT.total_seconds() # wait for a PONG for up to 10s await self.send({"type": "PING"}) elif now >= self._max_pong: # it's been more than 10s and there was no PONG ws_logger.warning(f"Websocket[{self._idx}] didn't receive a PONG, reconnecting...") self.request_reconnect() async def _handle_topics(self): if not self._topics_changed.is_set(): # nothing to do return self._topics_changed.clear() self.set_status(refresh_topics=True) auth_state = await self._twitch.get_auth() current: set[WebsocketTopic] = set(self.topics.values()) # handle removed topics removed = self._submitted.difference(current) if removed: topics_list = list(map(str, removed)) ws_logger.debug(f"Websocket[{self._idx}]: Removing topics: {', '.join(topics_list)}") await self.send( { "type": "UNLISTEN", "data": { "topics": topics_list, "auth_token": auth_state.access_token, } } ) self._submitted.difference_update(removed) # handle added topics added = current.difference(self._submitted) if added: topics_list = list(map(str, added)) ws_logger.debug(f"Websocket[{self._idx}]: Adding topics: {', '.join(topics_list)}") await self.send( { "type": "LISTEN", "data": { "topics": topics_list, "auth_token": auth_state.access_token, } } ) self._submitted.update(added) async def _gather_recv(self, messages: list[JsonType], timeout: float = 0.5): """ Gather incoming messages over the timeout specified. Note that there's no return value - this modifies `messages` in-place. """ ws = self._ws.get_with_default(None) assert ws is not None while True: raw_message: aiohttp.WSMessage = await ws.receive(timeout=timeout) ws_logger.debug(f"Websocket[{self._idx}] received: {raw_message}") if raw_message.type is WSMsgType.TEXT: message: JsonType = json.loads(raw_message.data) messages.append(message) elif raw_message.type is WSMsgType.CLOSE: raise WebsocketClosed(received=True) elif raw_message.type is WSMsgType.CLOSED: raise WebsocketClosed(received=False) elif raw_message.type is WSMsgType.CLOSING: pass # skip these elif raw_message.type is WSMsgType.ERROR: ws_logger.error( f"Websocket[{self._idx}] error: {format_traceback(raw_message.data)}" ) raise WebsocketClosed() else: ws_logger.error(f"Websocket[{self._idx}] error: Unknown message: {raw_message}") def _handle_message(self, message): # request the assigned topic to process the response topic = self.topics.get(message["data"]["topic"]) if topic is not None: # use a task to not block the websocket asyncio.create_task(topic(json.loads(message["data"]["message"]))) async def _handle_recv(self): """ Handle receiving messages from the websocket. """ # listen over 0.5s for incoming messages messages: list[JsonType] = [] with suppress(asyncio.TimeoutError): await self._gather_recv(messages, timeout=0.5) # process them for message in messages: msg_type = message["type"] if msg_type == "MESSAGE": self._handle_message(message) elif msg_type == "PONG": # move the timestamp to something much later self._max_pong = self._next_ping elif msg_type == "RESPONSE": # no special handling for these (for now) pass elif msg_type == "RECONNECT": # We've received a reconnect request ws_logger.warning(f"Websocket[{self._idx}] requested reconnect.") self.request_reconnect() else: ws_logger.warning(f"Websocket[{self._idx}] received unknown payload: {message}") def add_topics(self, topics_set: set[WebsocketTopic]): changed: bool = False while topics_set and len(self.topics) < WS_TOPICS_LIMIT: topic = topics_set.pop() self.topics[str(topic)] = topic changed = True if changed: self._topics_changed.set() def remove_topics(self, topics_set: set[str]): existing = topics_set.intersection(self.topics.keys()) if not existing: # nothing to remove from here return topics_set.difference_update(existing) for topic in existing: del self.topics[topic] self._topics_changed.set() async def send(self, message: JsonType): ws = self._ws.get_with_default(None) assert ws is not None if message["type"] != "PING": message["nonce"] = create_nonce(CHARS_ASCII, 30) await ws.send_json(message, dumps=json_minify) ws_logger.debug(f"Websocket[{self._idx}] sent: {message}") class WebsocketPool: def __init__(self, twitch: Twitch): self._twitch: Twitch = twitch self._running = asyncio.Event() self.websockets: list[Websocket] = [] @property def running(self) -> bool: return self._running.is_set() def wait_until_connected(self) -> abc.Coroutine[Any, Any, Literal[True]]: return self._running.wait() async def start(self): self._running.set() await asyncio.gather(*(ws.start() for ws in self.websockets)) async def stop(self, *, clear_topics: bool = False): self._running.clear() await asyncio.gather(*(ws.stop(remove=clear_topics) for ws in self.websockets)) def add_topics(self, topics: abc.Iterable[WebsocketTopic]): # ensure no topics end up duplicated topics_set = set(topics) if not topics_set: # nothing to add return topics_set.difference_update(*(ws.topics.values() for ws in self.websockets)) if not topics_set: # none left to add return for ws_idx in range(MAX_WEBSOCKETS): if ws_idx < len(self.websockets): # just read it back ws = self.websockets[ws_idx] else: # create new ws = Websocket(self, ws_idx) if self.running: ws.start_nowait() self.websockets.append(ws) # ask websocket to take any topics it can - this modifies the set in-place ws.add_topics(topics_set) # see if there's any leftover topics for the next websocket connection if not topics_set: return # if we're here, there were leftover topics after filling up all websockets raise MinerException("Maximum topics limit has been reached") def remove_topics(self, topics: abc.Iterable[str]): topics_set = set(topics) if not topics_set: # nothing to remove return for ws in self.websockets: ws.remove_topics(topics_set) # count up all the topics - if we happen to have more websockets connected than needed, # stop the last one and recycle topics from it - repeat until we have enough recycled_topics: list[WebsocketTopic] = [] while True: count = sum(len(ws.topics) for ws in self.websockets) if count <= (len(self.websockets) - 1) * WS_TOPICS_LIMIT: ws = self.websockets.pop() recycled_topics.extend(ws.topics.values()) ws.stop_nowait(remove=True) else: break if recycled_topics: self.add_topics(recycled_topics)