from __future__ import annotations

import json
import random
import string
import asyncio
import logging
from time import time
from functools import wraps
from contextlib import suppress
from typing import Any, Optional, List, Dict, Set, Iterable, TYPE_CHECKING

from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
from websockets.client import WebSocketClientProtocol, connect as websocket_connect

from inventory import TimedDrop
from exceptions import MinerException
from constants import (
    JsonType,
    WebsocketTopic,
    DEBUG_RAW,
    WEBSOCKET_URL,
    PING_INTERVAL,
    MAX_WEBSOCKETS,
    WS_TOPICS_LIMIT,
)

if TYPE_CHECKING:
    from twitch import Twitch


logger = logging.getLogger("TwitchDrops")
NONCE_CHARS = string.ascii_letters + string.digits


def create_nonce(length: int = 30) -> str:
    """
    Return a random string of `length` ASCII letters and digits.
    """
    return ''.join(random.choices(NONCE_CHARS, k=length))


def task_wrapper(afunc):
    """
    Decorate a coroutine method so that any exception escaping it is logged.

    The exception is re-raised after logging, and the wrapped coroutine's
    return value is passed through unchanged.
    """
    @wraps(afunc)
    async def wrapper(self, *args, **kwargs):
        try:
            return await afunc(self, *args, **kwargs)
        except Exception:
            logger.exception("Exception in websocket task")
            raise  # raise up to the wrapping task
    return wrapper


class Websocket:
    """
    A single websocket connection, listening to up to `WS_TOPICS_LIMIT` topics.

    All the work happens inside one task running `_handle`, which connects,
    pings, submits topic changes and receives messages.
    """

    def __init__(self, pool: WebsocketPool, index: int):
        self._pool = pool
        self._twitch = pool._twitch
        # websocket index
        self._idx: int = index
        # current websocket connection, None whenever there isn't an active one
        self._ws: Optional[WebSocketClientProtocol] = None
        # set when there's an active websocket connection
        self._connected_flag = asyncio.Event()
        # set when the websocket needs to reconnect
        self._reconnect_requested = asyncio.Event()
        # set when the topics changed
        self._topics_changed = asyncio.Event()
        # ping timestamps
        self._next_ping = time()
        self._max_pong = time()
        # main task, responsible for receiving messages, sending them, and websocket ping
        self._handle_task: Optional[asyncio.Task[Any]] = None
        # topics stuff: `topics` is what we want to listen to (keyed by the topic's string form),
        # `_submitted` is what the current connection has already been asked to listen to
        self.topics: Dict[str, WebsocketTopic] = {}
        self._submitted: Set[WebsocketTopic] = set()

    @property
    def connected(self) -> bool:
        """
        True while there's an active websocket connection.
        """
        return self._connected_flag.is_set()

    def wait_until_connected(self):
        """
        Return an awaitable that completes once the websocket is connected.
        """
        return self._connected_flag.wait()

    def request_reconnect(self):
        """
        Ask the handling loop to drop the current connection and make a new one.
        """
        self._reconnect_requested.set()

    async def start(self):
        """
        Start the handling task (or request a reconnect), and wait until connected.
        """
        self.start_nowait()
        await self.wait_until_connected()

    async def stop(self):
        """
        Close the connection and wait for the handling task to finish.

        If there's no active connection, there's nothing whose closing would end
        the handling loop, so the handling task is cancelled instead.
        """
        task = self._handle_task
        ws = self._ws
        if ws is None:
            if task is not None:
                task.cancel()
        else:
            await ws.close()
            if task is not None:
                await task
        self._handle_task = None

    def start_nowait(self):
        """
        Start the handling task (or request a reconnect), without waiting for the connection.
        """
        if self.connected:
            return
        if self._handle_task is None:
            self._handle_task = asyncio.create_task(self._handle())
        else:
            self.request_reconnect()

    def stop_nowait(self):
        """
        Schedule closing of the connection, without waiting for it to happen.

        If there's no active connection, the handling task is cancelled instead.
        """
        if self._ws is not None:
            asyncio.create_task(self._ws.close())
        elif self._handle_task is not None:
            self._handle_task.cancel()
        self._handle_task = None

    @task_wrapper
    async def _handle(self):
        """
        Main loop: connect, then keep pinging, submitting topics and receiving messages.

        Reconnects on a reconnect request, on the server closing the connection
        and on any other exception. Returns only when we've closed the connection ourselves.
        """
        # ensure we're logged in before connecting
        await self._twitch.wait_until_login()
        logger.info("Connecting to Websocket")
        # Connect/Reconnect loop
        async for websocket in websocket_connect(WEBSOCKET_URL, ssl=True, ping_interval=None):
            websocket.BACKOFF_MAX = 3 * 60  # type: ignore  # 3 minutes
            self._ws = websocket
            try:
                try:
                    self._reconnect_requested.clear()
                    self._connected_flag.set()
                    while not self._reconnect_requested.is_set():
                        await self._handle_ping()
                        await self._handle_topics()
                        await self._handle_recv()
                finally:
                    self._ws = None
                    self._submitted.clear()
                    self._connected_flag.clear()
                    # nothing is submitted on the next connection yet,
                    # so flag the topics for (re)submission
                    self._topics_changed.set()
                # A reconnect was requested
                continue
            except ConnectionClosed as exc:
                if isinstance(exc, ConnectionClosedOK):
                    if exc.rcvd_then_sent:
                        # server closed the connection, not us - reconnect
                        logger.warning("Server Disconnected - Reconnecting")
                        continue
                    # we closed it - exit
                    return
                # otherwise, reconnect
                logger.warning("Websocket Closed - Reconnecting")
                continue
            except Exception:
                logger.exception("Exception in Websocket - Reconnecting")
                continue

    async def _handle_ping(self):
        """
        Send a PING when it's due, or request a reconnect if the PONG didn't arrive in time.
        """
        now = time()
        if now >= self._next_ping:
            self._next_ping = now + PING_INTERVAL.total_seconds()
            self._max_pong = now + 10  # wait for a PONG for up to 10s
            await self.send({"type": "PING"})
        elif now >= self._max_pong:
            # it's been more than 10s and there was no PONG
            self.request_reconnect()

    async def _handle_topics(self):
        """
        Bring the topics submitted over this connection in line with `self.topics`.

        Sends UNLISTEN for topics that are no longer wanted and LISTEN for the new ones.
        Does nothing unless the topics were flagged as changed.
        """
        if not self._topics_changed.is_set():
            # nothing to do
            return
        # clear the flag before taking the snapshot - a change made while we're awaiting
        # the sends below sets it again, and gets picked up on the next pass
        self._topics_changed.clear()
        current: Set[WebsocketTopic] = set(self.topics.values())
        # handle removed topics
        removed = self._submitted.difference(current)
        if removed:
            topics_list = list(map(str, removed))
            logger.debug(f"Websocket[{self._idx}]: Removing topics: {', '.join(topics_list)}")
            await self.send(
                {
                    "type": "UNLISTEN",
                    "data": {
                        "topics": topics_list,
                        "auth_token": self._twitch._access_token,
                    }
                }
            )
            self._submitted.difference_update(removed)
        # handle added topics
        added = current.difference(self._submitted)
        if added:
            topics_list = list(map(str, added))
            logger.debug(f"Websocket[{self._idx}]: Adding topics: {', '.join(topics_list)}")
            await self.send(
                {
                    "type": "LISTEN",
                    "data": {
                        "topics": topics_list,
                        "auth_token": self._twitch._access_token,
                    }
                }
            )
            self._submitted.update(added)

    async def _gather_recv(self, messages: List[JsonType]):
        """
        Gather incoming messages until cancelled (the caller applies the timeout).
        Note that there's no return value - this modifies `messages` in-place.
        """
        ws = self._ws
        assert ws is not None  # only called from within the handling loop
        while True:
            raw_message = await ws.recv()
            message = json.loads(raw_message)
            logger.log(DEBUG_RAW, f"Websocket received: {message}")
            messages.append(message)

    def _handle_message(self, message):
        """
        Pass the JSON-decoded payload of a MESSAGE to the topic it was sent for.

        Messages for topics we don't have are ignored.
        """
        # request the assigned topic to process the response
        topic_process = self.topics.get(message["data"]["topic"])
        if topic_process is not None:
            # use a task to not block the websocket
            asyncio.create_task(topic_process(json.loads(message["data"]["message"])))

    async def _handle_recv(self):
        """
        Handle receiving messages from the websocket.
        """
        # listen over 0.5s for incoming messages
        messages: List[JsonType] = []
        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(self._gather_recv(messages), timeout=0.5)
        # process them
        for message in messages:
            msg_type = message["type"]
            if msg_type == "MESSAGE":
                self._handle_message(message)
            elif msg_type == "PONG":
                # move the timestamp to something much later
                self._max_pong = self._next_ping
            elif msg_type == "RESPONSE":
                # no special handling for these (for now)
                pass
            elif msg_type == "RECONNECT":
                # We've received a reconnect request
                logger.warning("Received a Websocket Reconnect Request")
                self.request_reconnect()
            else:
                logger.error(f"Received unknown websocket payload: {message}")

    def add_topics(self, topics_set: Set[WebsocketTopic]):
        """
        Take topics out of `topics_set` until it's empty or this websocket is full.

        This modifies `topics_set` in-place - whatever is left in it didn't fit here.
        """
        changed = False
        while topics_set and len(self.topics) < WS_TOPICS_LIMIT:
            topic = topics_set.pop()
            self.topics[str(topic)] = topic
            changed = True
        if changed:
            self._topics_changed.set()

    def remove_topics(self, topics_set: Set[WebsocketTopic]):
        """
        Remove from this websocket those of the `topics_set` topics that it has.

        This modifies `topics_set` in-place - the removed topics are taken out of it.
        """
        existing = topics_set.intersection(self.topics.values())
        if not existing:
            # nothing to remove from here
            return
        topics_set.difference_update(existing)
        for topic in existing:
            del self.topics[str(topic)]
        self._topics_changed.set()

    async def send(self, message: JsonType):
        """
        Send `message` as compact JSON. Does nothing if there's no active connection.

        A "nonce" key is added to every message that isn't a PING - this modifies
        `message` in-place.
        """
        ws = self._ws
        if ws is None:
            return
        if message["type"] != "PING":
            message["nonce"] = create_nonce()
        await ws.send(json.dumps(message, separators=(',', ':')))
        logger.log(DEBUG_RAW, f"Websocket sent: {message}")


class WebsocketPool:
    """
    Spreads topics over up to `MAX_WEBSOCKETS` websocket connections.
    """

    def __init__(self, twitch: Twitch):
        self._twitch: Twitch = twitch
        # set between `start` and `stop`
        self._running = asyncio.Event()
        self.websockets: List[Websocket] = []

    @property
    def running(self) -> bool:
        """
        True between the calls to `start` and `stop`.
        """
        return self._running.is_set()

    def wait_until_connected(self):
        """
        Return an awaitable that completes once the pool has been started.
        """
        return self._running.wait()

    async def start(self):
        """
        Mark the pool as running, add the default user topics and start all websockets.
        """
        self._running.set()
        await self._twitch.wait_until_login()
        # Add default topics
        assert self._twitch._user_id is not None
        user_id = self._twitch._user_id
        self.add_topics([
            WebsocketTopic("User", "Drops", user_id, self.process_drops),
            WebsocketTopic("User", "CommunityPoints", user_id, self.process_points),
        ])
        await asyncio.gather(*(ws.start() for ws in self.websockets))

    async def stop(self):
        """
        Mark the pool as not running and stop all websockets.
        """
        self._running.clear()
        await asyncio.gather(*(ws.stop() for ws in self.websockets))

    def add_topics(self, topics: Iterable[WebsocketTopic]):
        """
        Add `topics` to the pool, creating new websockets as needed.

        Topics that are already on one of the websockets are skipped.

        Raises:
            MinerException: if there are topics left after filling up all the websockets.
        """
        # ensure no topics end up duplicated
        topics_set = set(topics)
        if not topics_set:
            # nothing to add
            return
        # each collection of topics has to be passed in as a separate argument - passing
        # a single iterable of collections would try to discard the collections themselves
        topics_set.difference_update(*(ws.topics.values() for ws in self.websockets))
        if not topics_set:
            # none left to add
            return
        for ws_idx in range(MAX_WEBSOCKETS):
            if ws_idx < len(self.websockets):
                # just read it back
                ws = self.websockets[ws_idx]
            else:
                # create new
                ws = Websocket(self, len(self.websockets))
                # before the pool is started, it's `start` that starts the websockets
                if self.running:
                    ws.start_nowait()
                self.websockets.append(ws)
            # ask websocket to take any topics it can - this modifies the set in-place
            ws.add_topics(topics_set)
            # see if there's any leftover topics for the next websocket connection
            if not topics_set:
                return
        # if we're here, there were leftover topics after filling up all websockets
        raise MinerException("Maximum topics limit has been reached")

    def remove_topics(self, topics: Iterable[WebsocketTopic]):
        """
        Remove `topics` from the pool, and stop the websockets that are no longer needed.

        Topics of the websockets that get stopped are moved over to the remaining ones.
        """
        topics_set = set(topics)
        if not topics_set:
            # nothing to remove
            return
        for ws in self.websockets:
            ws.remove_topics(topics_set)
        # count up all the topics - if we happen to have more websockets connected than needed,
        # stop the last one and recycle topics from it - repeat until we have enough.
        # The total stays the same throughout, as the recycled topics still need a place
        # on the remaining websockets.
        total = sum(len(ws.topics) for ws in self.websockets)
        recycled: List[WebsocketTopic] = []
        while self.websockets and total <= (len(self.websockets) - 1) * WS_TOPICS_LIMIT:
            ws = self.websockets.pop()
            recycled.extend(ws.topics.values())
            ws.stop_nowait()
        if recycled:
            self.add_topics(recycled)

    @task_wrapper
    async def process_drops(self, message: JsonType):
        """
        Handle a message from the user's "Drops" topic.

        Updates the drop the message is about, prints its progress on "drop-progress",
        and claims it on "drop-claim". Messages about unknown drops are logged and ignored.
        """
        drop_id = message["data"]["drop_id"]
        drop: Optional[TimedDrop] = None
        for campaign in self._twitch.inventory:
            drop = campaign.get_drop(drop_id)
            if drop is not None:
                break
        else:
            logger.error(f"Drop with ID of {drop_id} not found!")
            return
        drop.update(message)
        msg_type = message["type"]
        if msg_type == "drop-progress":
            print(
                f"Drop: {drop.rewards_text()}: {drop.progress:6.1%} "
                f"({drop.remaining_minutes} minutes remaining)"
            )
        elif msg_type == "drop-claim":
            campaign = drop.campaign
            await drop.claim()
            print(
                f"Claimed drop: {drop.rewards_text()} "
                f"({campaign.claimed_drops}/{campaign.total_drops})"
            )
            if campaign.remaining_drops == 0:
                self._twitch.reevaluate_campaigns()

    @task_wrapper
    async def process_points(self, message: JsonType):
        """
        Handle a message from the user's "CommunityPoints" topic.

        Prints the points earned on "points-earned", and claims the bonus points
        on "claim-available".
        """
        msg_type = message["type"]
        if msg_type == "points-earned":
            points = message["data"]["point_gain"]["total_points"]
            balance = message["data"]["balance"]["balance"]
            print(f"Earned points for watching: {points:3}, total: {balance}")
        elif msg_type == "claim-available":
            claim_data = message["data"]["claim"]
            points = claim_data["point_gain"]["total_points"]
            await self._twitch.claim_points(claim_data["channel_id"], claim_data["id"])
            print(f"Claimed bonus points: {points}")