Files
TwitchDropsMiner/websocket.py

396 lines
16 KiB
Python

from __future__ import annotations
import json
import asyncio
import logging
from time import time
from contextlib import suppress
from typing import Any, Literal, TYPE_CHECKING
import aiohttp
from translate import _
from exceptions import MinerException, WebsocketClosed
from constants import PING_INTERVAL, PING_TIMEOUT, MAX_WEBSOCKETS, WS_TOPICS_LIMIT
from utils import (
CHARS_ASCII,
task_wrapper,
create_nonce,
json_minify,
format_traceback,
AwaitableValue,
ExponentialBackoff,
)
if TYPE_CHECKING:
from collections import abc
from twitch import Twitch
from gui import WebsocketStatus
from constants import JsonType, WebsocketTopic
WSMsgType = aiohttp.WSMsgType
logger = logging.getLogger("TwitchDrops")
ws_logger = logging.getLogger("TwitchDrops.websocket")
class Websocket:
def __init__(self, pool: WebsocketPool, index: int):
self._pool: WebsocketPool = pool
self._twitch: Twitch = pool._twitch
self._ws_gui: WebsocketStatus = self._twitch.gui.websockets
self._state_lock = asyncio.Lock()
# websocket index
self._idx: int = index
# current websocket connection
self._ws: AwaitableValue[aiohttp.ClientWebSocketResponse] = AwaitableValue()
# set when the websocket needs to be closed or reconnect
self._closed = asyncio.Event()
self._reconnect_requested = asyncio.Event()
# set when the topics changed
self._topics_changed = asyncio.Event()
# ping timestamps
self._next_ping: float = time()
self._max_pong: float = self._next_ping + PING_TIMEOUT.total_seconds()
# main task, responsible for receiving messages, sending them, and websocket ping
self._handle_task: asyncio.Task[None] | None = None
# topics stuff
self.topics: dict[str, WebsocketTopic] = {}
self._submitted: set[WebsocketTopic] = set()
# notify GUI
self.set_status(_("gui", "websocket", "disconnected"))
@property
def connected(self) -> bool:
return self._ws.has_value()
def wait_until_connected(self):
return self._ws.wait()
def set_status(self, status: str | None = None, refresh_topics: bool = False):
self._twitch.gui.websockets.update(
self._idx, status=status, topics=(len(self.topics) if refresh_topics else None)
)
def request_reconnect(self):
# reset our ping interval, so we send a PING after reconnect right away
self._next_ping = time()
self._reconnect_requested.set()
async def start(self):
async with self._state_lock:
self.start_nowait()
await self.wait_until_connected()
def start_nowait(self):
if self._handle_task is None or self._handle_task.done():
self._handle_task = asyncio.create_task(self._handle())
async def stop(self, *, remove: bool = False):
async with self._state_lock:
if self._closed.is_set():
return
self._closed.set()
ws = self._ws.get_with_default(None)
if ws is not None:
self.set_status(_("gui", "websocket", "disconnecting"))
await ws.close()
if self._handle_task is not None:
with suppress(asyncio.TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(self._handle_task, timeout=2)
self._handle_task = None
if remove:
self.topics.clear()
self._topics_changed.set()
self._twitch.gui.websockets.remove(self._idx)
def stop_nowait(self, *, remove: bool = False):
# weird syntax but that's what we get for using a decorator for this
# return type of 'task_wrapper' is a coro, so we need to instance it for the task
asyncio.create_task(task_wrapper(self.stop)(remove=remove))
async def _backoff_connect(
self, ws_url: str, **kwargs
) -> abc.AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
session = await self._twitch.get_session()
backoff = ExponentialBackoff(**kwargs)
if self._twitch.settings.proxy:
proxy = self._twitch.settings.proxy
else:
proxy = None
for delay in backoff:
try:
async with session.ws_connect(ws_url, proxy=proxy) as websocket:
yield websocket
backoff.reset()
except (
asyncio.TimeoutError,
aiohttp.ClientResponseError,
aiohttp.ClientConnectionError,
):
ws_logger.info(
f"Websocket[{self._idx}] connection problem (sleep: {round(delay)}s)"
)
await asyncio.sleep(delay)
except RuntimeError:
ws_logger.warning(
f"Websocket[{self._idx}] exiting backoff connect loop "
"because session is closed (RuntimeError)"
)
break
@task_wrapper(critical=True)
async def _handle(self):
# ensure we're logged in before connecting
self.set_status(_("gui", "websocket", "initializing"))
await self._twitch.wait_until_login()
self.set_status(_("gui", "websocket", "connecting"))
ws_logger.info(f"Websocket[{self._idx}] connecting...")
self._closed.clear()
# Connect/Reconnect loop
async for websocket in self._backoff_connect(
"wss://pubsub-edge.twitch.tv/v1", maximum=3*60 # 3 minutes maximum backoff time
):
self._ws.set(websocket)
self._reconnect_requested.clear()
# NOTE: _topics_changed doesn't start set,
# because there's no initial topics we can sub to right away
self.set_status(_("gui", "websocket", "connected"))
ws_logger.info(f"Websocket[{self._idx}] connected.")
try:
try:
while not self._reconnect_requested.is_set():
await self._handle_ping()
await self._handle_topics()
await self._handle_recv()
finally:
self._ws.clear()
self._submitted.clear()
# set _topics_changed to let the next WS connection resub to the topics
self._topics_changed.set()
# A reconnect was requested
except WebsocketClosed as exc:
if exc.received:
# server closed the connection, not us - reconnect
ws_logger.warning(
f"Websocket[{self._idx}] closed unexpectedly: {websocket.close_code}"
)
elif self._closed.is_set():
# we closed it - exit
ws_logger.info(f"Websocket[{self._idx}] stopped.")
self.set_status(_("gui", "websocket", "disconnected"))
return
except Exception:
ws_logger.exception(f"Exception in Websocket[{self._idx}]")
self.set_status(_("gui", "websocket", "reconnecting"))
ws_logger.warning(f"Websocket[{self._idx}] reconnecting...")
async def _handle_ping(self):
now = time()
if now >= self._next_ping:
self._next_ping = now + PING_INTERVAL.total_seconds()
self._max_pong = now + PING_TIMEOUT.total_seconds() # wait for a PONG for up to 10s
await self.send({"type": "PING"})
elif now >= self._max_pong:
# it's been more than 10s and there was no PONG
ws_logger.warning(f"Websocket[{self._idx}] didn't receive a PONG, reconnecting...")
self.request_reconnect()
async def _handle_topics(self):
if not self._topics_changed.is_set():
# nothing to do
return
self._topics_changed.clear()
self.set_status(refresh_topics=True)
auth_state = await self._twitch.get_auth()
current: set[WebsocketTopic] = set(self.topics.values())
# handle removed topics
removed = self._submitted.difference(current)
if removed:
topics_list = list(map(str, removed))
ws_logger.debug(f"Websocket[{self._idx}]: Removing topics: {', '.join(topics_list)}")
await self.send(
{
"type": "UNLISTEN",
"data": {
"topics": topics_list,
"auth_token": auth_state.access_token,
}
}
)
self._submitted.difference_update(removed)
# handle added topics
added = current.difference(self._submitted)
if added:
topics_list = list(map(str, added))
ws_logger.debug(f"Websocket[{self._idx}]: Adding topics: {', '.join(topics_list)}")
await self.send(
{
"type": "LISTEN",
"data": {
"topics": topics_list,
"auth_token": auth_state.access_token,
}
}
)
self._submitted.update(added)
async def _gather_recv(self, messages: list[JsonType], timeout: float = 0.5):
"""
Gather incoming messages over the timeout specified.
Note that there's no return value - this modifies `messages` in-place.
"""
ws = self._ws.get_with_default(None)
assert ws is not None
while True:
raw_message: aiohttp.WSMessage = await ws.receive(timeout=timeout)
ws_logger.debug(f"Websocket[{self._idx}] received: {raw_message}")
if raw_message.type is WSMsgType.TEXT:
message: JsonType = json.loads(raw_message.data)
messages.append(message)
elif raw_message.type is WSMsgType.CLOSE:
raise WebsocketClosed(received=True)
elif raw_message.type is WSMsgType.CLOSED:
raise WebsocketClosed(received=False)
elif raw_message.type is WSMsgType.CLOSING:
pass # skip these
elif raw_message.type is WSMsgType.ERROR:
ws_logger.error(
f"Websocket[{self._idx}] error: {format_traceback(raw_message.data)}"
)
raise WebsocketClosed()
else:
ws_logger.error(f"Websocket[{self._idx}] error: Unknown message: {raw_message}")
def _handle_message(self, message):
# request the assigned topic to process the response
topic = self.topics.get(message["data"]["topic"])
if topic is not None:
# use a task to not block the websocket
asyncio.create_task(topic(json.loads(message["data"]["message"])))
async def _handle_recv(self):
"""
Handle receiving messages from the websocket.
"""
# listen over 0.5s for incoming messages
messages: list[JsonType] = []
with suppress(asyncio.TimeoutError):
await self._gather_recv(messages, timeout=0.5)
# process them
for message in messages:
msg_type = message["type"]
if msg_type == "MESSAGE":
self._handle_message(message)
elif msg_type == "PONG":
# move the timestamp to something much later
self._max_pong = self._next_ping
elif msg_type == "RESPONSE":
# no special handling for these (for now)
pass
elif msg_type == "RECONNECT":
# We've received a reconnect request
ws_logger.warning(f"Websocket[{self._idx}] requested reconnect.")
self.request_reconnect()
else:
ws_logger.warning(f"Websocket[{self._idx}] received unknown payload: {message}")
def add_topics(self, topics_set: set[WebsocketTopic]):
changed: bool = False
while topics_set and len(self.topics) < WS_TOPICS_LIMIT:
topic = topics_set.pop()
self.topics[str(topic)] = topic
changed = True
if changed:
self._topics_changed.set()
def remove_topics(self, topics_set: set[str]):
existing = topics_set.intersection(self.topics.keys())
if not existing:
# nothing to remove from here
return
topics_set.difference_update(existing)
for topic in existing:
del self.topics[topic]
self._topics_changed.set()
async def send(self, message: JsonType):
ws = self._ws.get_with_default(None)
assert ws is not None
if message["type"] != "PING":
message["nonce"] = create_nonce(CHARS_ASCII, 30)
await ws.send_json(message, dumps=json_minify)
ws_logger.debug(f"Websocket[{self._idx}] sent: {message}")
class WebsocketPool:
def __init__(self, twitch: Twitch):
self._twitch: Twitch = twitch
self._running = asyncio.Event()
self.websockets: list[Websocket] = []
@property
def running(self) -> bool:
return self._running.is_set()
def wait_until_connected(self) -> abc.Coroutine[Any, Any, Literal[True]]:
return self._running.wait()
async def start(self):
self._running.set()
await asyncio.gather(*(ws.start() for ws in self.websockets))
async def stop(self, *, clear_topics: bool = False):
self._running.clear()
await asyncio.gather(*(ws.stop(remove=clear_topics) for ws in self.websockets))
def add_topics(self, topics: abc.Iterable[WebsocketTopic]):
# ensure no topics end up duplicated
topics_set = set(topics)
if not topics_set:
# nothing to add
return
topics_set.difference_update(*(ws.topics.values() for ws in self.websockets))
if not topics_set:
# none left to add
return
for ws_idx in range(MAX_WEBSOCKETS):
if ws_idx < len(self.websockets):
# just read it back
ws = self.websockets[ws_idx]
else:
# create new
ws = Websocket(self, ws_idx)
if self.running:
ws.start_nowait()
self.websockets.append(ws)
# ask websocket to take any topics it can - this modifies the set in-place
ws.add_topics(topics_set)
# see if there's any leftover topics for the next websocket connection
if not topics_set:
return
# if we're here, there were leftover topics after filling up all websockets
raise MinerException("Maximum topics limit has been reached")
def remove_topics(self, topics: abc.Iterable[str]):
topics_set = set(topics)
if not topics_set:
# nothing to remove
return
for ws in self.websockets:
ws.remove_topics(topics_set)
# count up all the topics - if we happen to have more websockets connected than needed,
# stop the last one and recycle topics from it - repeat until we have enough
recycled_topics: list[WebsocketTopic] = []
while True:
count = sum(len(ws.topics) for ws in self.websockets)
if count <= (len(self.websockets) - 1) * WS_TOPICS_LIMIT:
ws = self.websockets.pop()
recycled_topics.extend(ws.topics.values())
ws.stop_nowait(remove=True)
else:
break
if recycled_topics:
self.add_topics(recycled_topics)