# Mirror of https://github.com/rangermix/TwitchDropsMiner.git
# Synced 2026-05-27 15:39:38 +00:00
# 373 lines, 14 KiB, Python
from __future__ import annotations
|
|
|
|
import json
|
|
import random
|
|
import string
|
|
import asyncio
|
|
import logging
|
|
from time import time
|
|
from functools import wraps
|
|
from contextlib import suppress
|
|
from typing import Any, Optional, List, Dict, Set, Iterable, TYPE_CHECKING
|
|
|
|
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
|
|
from websockets.client import WebSocketClientProtocol, connect as websocket_connect
|
|
|
|
from exceptions import MinerException
|
|
from constants import (
|
|
JsonType,
|
|
WebsocketTopic,
|
|
WEBSOCKET_URL,
|
|
PING_INTERVAL,
|
|
PING_TIMEOUT,
|
|
MAX_WEBSOCKETS,
|
|
WS_TOPICS_LIMIT,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from twitch import Twitch
|
|
|
|
|
|
# Loggers: the general application logger and the websocket-specific child.
logger = logging.getLogger("TwitchDrops")
ws_logger = logging.getLogger("TwitchDrops.websocket")

# Alphabet that nonces are drawn from: ASCII letters followed by digits.
NONCE_CHARS = "".join((string.ascii_letters, string.digits))


def create_nonce(length: int = 30) -> str:
    """
    Return a random string of `length` characters drawn from `NONCE_CHARS`.

    Characters are picked with replacement via `random.choices`.
    """
    picked = random.choices(NONCE_CHARS, k=length)
    return "".join(picked)
|
|
|
|
|
|
def task_wrapper(afunc):
    """
    Decorate a coroutine method so that exceptions escaping it get logged.

    The wrapped coroutine is awaited with the same arguments. Any `Exception`
    it raises is logged through `logger.exception` and then re-raised, so the
    task running the coroutine still ends up holding the exception.
    The value produced by the wrapped coroutine is passed back to the caller.
    """
    @wraps(afunc)
    async def wrapper(self, *args, **kwargs):
        try:
            # hand the result back instead of silently dropping it
            return await afunc(self, *args, **kwargs)
        except Exception:
            logger.exception("Exception in task")
            raise  # raise up to the wrapping task
    return wrapper
|
|
|
|
|
|
class Websocket:
    """
    A single websocket connection, owned by a `WebsocketPool`.

    Once started, a background task (`_handle`) connects to `WEBSOCKET_URL`
    and then loops over three steps: sending PINGs, submitting LISTEN/UNLISTEN
    requests for changed topics, and dispatching received messages to
    the topics registered in `self.topics`.
    """

    # Strong references to fire-and-forget tasks spawned via `_spawn`.
    # The event loop only keeps weak references to tasks, so a task nobody else
    # references can get garbage-collected before it finishes. The set lives on
    # the class (shared by all instances) so that it stays reachable even after
    # a websocket instance has been dropped from its pool.
    _background_tasks: Set[asyncio.Task[Any]] = set()

    def __init__(self, pool: WebsocketPool, index: int):
        """
        Set up the connection state; no connection is made until `start`.

        `pool` is the owning pool (its `_twitch` is reused here),
        and `index` identifies this websocket in logs and in the GUI.
        """
        self._pool = pool
        self._twitch = pool._twitch
        # websocket index
        self._idx: int = index
        # current websocket connection
        self._ws: Optional[WebSocketClientProtocol] = None
        # set when there's an active websocket connection
        self._connected_flag = asyncio.Event()
        # set when the websocket needs to reconnect
        self._reconnect_requested = asyncio.Event()
        # set when the topics changed
        self._topics_changed = asyncio.Event()
        # ping timestamps
        self._next_ping: float = time()
        self._max_pong: float = self._next_ping + PING_TIMEOUT.total_seconds()
        # main task, responsible for receiving messages, sending them, and websocket ping
        self._handle_task: Optional[asyncio.Task[Any]] = None
        # topics stuff
        self.topics: Dict[str, WebsocketTopic] = {}
        self._submitted: Set[WebsocketTopic] = set()
        # notify GUI
        self.set_status("Disconnected")

    def _spawn(self, coro: Any) -> asyncio.Task[Any]:
        """
        Schedule `coro` as a task, keeping a strong reference until it's done.

        The reference is dropped automatically from the done-callback.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    @property
    def connected(self) -> bool:
        """Whether the connected flag is currently set."""
        return self._connected_flag.is_set()

    def wait_until_connected(self):
        """Return an awaitable that completes once the connected flag is set."""
        return self._connected_flag.wait()

    def set_status(self, status: Optional[str] = None, refresh_topics: bool = False):
        """
        Push this websocket's state to the GUI.

        Only the pieces requested are sent: `status` if it's not None,
        and the current topics count if `refresh_topics` is true.
        """
        kwargs: Dict[str, Any] = {}
        if status is not None:
            kwargs["status"] = status
        if refresh_topics:
            kwargs["topics"] = len(self.topics)
        self._twitch.gui.websockets.update(self._idx, **kwargs)

    def request_reconnect(self):
        """Flag the connection for a reconnect; the handle loop acts on it."""
        ws_logger.warning(f"Websocket[{self._idx}] requested reconnect.")
        # reset our ping interval, so we send a PING after reconnect right away
        self._next_ping = time()
        self._reconnect_requested.set()

    async def close(self):
        """Close the current websocket connection, if there is one."""
        self.set_status("Disconnecting...")
        if self._ws is not None:
            await self._ws.close()

    async def start(self):
        """Start the handle task if needed, then wait until connected."""
        if self.connected:
            return
        if self._handle_task is None:
            self._handle_task = asyncio.create_task(self._handle())
        await self.wait_until_connected()

    def start_nowait(self):
        """Start the handle task if needed, without waiting for the connection."""
        if self.connected:
            return
        if self._handle_task is None:
            self._handle_task = asyncio.create_task(self._handle())

    async def stop(self):
        """Close the connection and wait for the handle task to finish."""
        await self.close()
        if self._handle_task is not None:
            # this raises back any stray exceptions
            await self._handle_task
            self._handle_task = None

    def stop_nowait(self):
        """Schedule closing of the connection, without waiting for it."""
        # keep a reference to the closing task, so it can't disappear mid-execution
        self._spawn(self.close())
        # note: this detaches the handle task, so we have to assume it closes properly
        self._handle_task = None

    def remove(self):
        """Schedule stopping this websocket and removing it from the GUI list."""
        # this stops the websocket, and then removes it from the gui list
        async def remover():
            await self.stop()
            self._twitch.gui.websockets.remove(self._idx)
        # keep a reference to the remover task, so it can't disappear mid-execution
        self._spawn(remover())

    @task_wrapper
    async def _handle(self):
        """
        Main connection loop: connect, service the connection, reconnect.

        Returns only once a connection we closed ourselves ends cleanly;
        any other disconnect or error leads to another connection attempt.
        """
        # ensure we're logged in before connecting
        await self._twitch.wait_until_login()
        self.set_status("Connecting...")
        ws_logger.info(f"Websocket[{self._idx}] connecting...")
        # Connect/Reconnect loop
        async for websocket in websocket_connect(WEBSOCKET_URL, ssl=True, ping_interval=None):
            websocket.BACKOFF_MAX = 3 * 60  # type: ignore  # 3 minutes
            self._ws = websocket
            self.set_status("Connected")
            try:
                try:
                    self._reconnect_requested.clear()
                    self._connected_flag.set()
                    while not self._reconnect_requested.is_set():
                        await self._handle_ping()
                        await self._handle_topics()
                        await self._handle_recv()
                finally:
                    # a new connection knows nothing about our topics,
                    # so all of them will need to be submitted again
                    self._submitted.clear()
                    self._connected_flag.clear()
                # A reconnect was requested
            except ConnectionClosed as exc:
                if isinstance(exc, ConnectionClosedOK):
                    if exc.rcvd_then_sent:
                        # server closed the connection, not us - reconnect
                        ws_logger.warning(f"Websocket[{self._idx}] got disconnected.")
                    else:
                        # we closed it - exit
                        self._ws = None
                        ws_logger.info(f"Websocket[{self._idx}] stopped.")
                        self.set_status("Disconnected")
                        return
                else:
                    # prefer the close code we've received over the one we've sent
                    if exc.rcvd is not None:
                        code = exc.rcvd.code
                    elif exc.sent is not None:
                        code = exc.sent.code
                    else:
                        code = -1
                    ws_logger.warning(f"Websocket[{self._idx}] closed unexpectedly: {code}")
            except Exception:
                ws_logger.exception(f"Exception in Websocket[{self._idx}]")
            self.set_status("Reconnecting...")
            ws_logger.warning(f"Websocket[{self._idx}] reconnecting...")

    async def _handle_ping(self):
        """Send a PING when it's due, or request a reconnect on a missing PONG."""
        now = time()
        if now >= self._next_ping:
            self._next_ping = now + PING_INTERVAL.total_seconds()
            # wait for a PONG for up to PING_TIMEOUT
            self._max_pong = now + PING_TIMEOUT.total_seconds()
            await self.send({"type": "PING"})
        elif now >= self._max_pong:
            # the PONG deadline has passed and there was no PONG
            self.request_reconnect()

    async def _handle_topics(self):
        """
        Bring the submitted topics in line with `self.topics`.

        Does nothing unless the topics-changed flag is set. Otherwise, sends
        an UNLISTEN for topics that were submitted but are no longer wanted,
        a LISTEN for wanted topics that weren't submitted yet,
        and finally clears the flag.
        """
        if not self._topics_changed.is_set():
            # nothing to do
            return
        current: Set[WebsocketTopic] = set(self.topics.values())
        # handle removed topics
        removed = self._submitted.difference(current)
        if removed:
            topics_list = list(map(str, removed))
            ws_logger.debug(f"Websocket[{self._idx}]: Removing topics: {', '.join(topics_list)}")
            await self.send(
                {
                    "type": "UNLISTEN",
                    "data": {
                        "topics": topics_list,
                        "auth_token": self._twitch._access_token,
                    }
                }
            )
            self._submitted.difference_update(removed)
        # handle added topics
        added = current.difference(self._submitted)
        if added:
            topics_list = list(map(str, added))
            ws_logger.debug(f"Websocket[{self._idx}]: Adding topics: {', '.join(topics_list)}")
            await self.send(
                {
                    "type": "LISTEN",
                    "data": {
                        "topics": topics_list,
                        "auth_token": self._twitch._access_token,
                    }
                }
            )
            self._submitted.update(added)
        self._topics_changed.clear()

    async def _gather_recv(self, messages: List[JsonType]):
        """
        Gather incoming messages until cancelled by the caller's timeout.

        Note that there's no return value - this modifies `messages` in-place.
        """
        assert self._ws is not None
        while True:
            raw_message = await self._ws.recv()
            message = json.loads(raw_message)
            ws_logger.debug(f"Websocket[{self._idx}] received: {message}")
            messages.append(message)

    def _handle_message(self, message):
        """Pass the payload of a MESSAGE to the topic it's addressed to, if any."""
        # request the assigned topic to process the response
        topic = self.topics.get(message["data"]["topic"])
        if topic is not None:
            # use a task to not block the websocket; keep a reference to it,
            # so it can't disappear mid-execution
            self._spawn(topic(json.loads(message["data"]["message"])))

    async def _handle_recv(self):
        """
        Handle receiving messages from the websocket.

        Messages are collected for up to 0.5s, then processed by their "type".
        """
        # listen over 0.5s for incoming messages
        messages: List[JsonType] = []
        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(self._gather_recv(messages), timeout=0.5)
        # process them
        for message in messages:
            msg_type = message["type"]
            if msg_type == "MESSAGE":
                self._handle_message(message)
            elif msg_type == "PONG":
                # move the timestamp to something much later
                self._max_pong = self._next_ping
            elif msg_type == "RESPONSE":
                # no special handling for these (for now)
                pass
            elif msg_type == "RECONNECT":
                # We've received a reconnect request
                self.request_reconnect()
            else:
                ws_logger.warning(f"Websocket[{self._idx}] received unknown payload: {message}")

    def add_topics(self, topics_set: Set[WebsocketTopic]):
        """
        Take topics out of `topics_set` until it's empty or this websocket is full.

        `topics_set` is modified in-place: every topic taken is popped from it,
        so whatever is left afterwards didn't fit within `WS_TOPICS_LIMIT`.
        """
        while topics_set and len(self.topics) < WS_TOPICS_LIMIT:
            topic = topics_set.pop()
            self.topics[str(topic)] = topic
        self._topics_changed.set()
        self.set_status(refresh_topics=True)

    def remove_topics(self, topics_set: Set[str]):
        """
        Remove those topics from `topics_set` that this websocket holds.

        `topics_set` is modified in-place: the names removed here are
        discarded from it. Nothing happens if none of the names are held here.
        """
        existing = topics_set.intersection(self.topics.keys())
        if not existing:
            # nothing to remove from here
            return
        topics_set.difference_update(existing)
        for topic in existing:
            del self.topics[topic]
        self._topics_changed.set()
        self.set_status(refresh_topics=True)

    async def send(self, message: JsonType):
        """
        Send `message` as compact JSON; does nothing without a connection.

        Every message other than a PING gets a fresh "nonce" added to it,
        which modifies the passed `message` in-place.
        """
        if self._ws is None:
            return
        if message["type"] != "PING":
            message["nonce"] = create_nonce()
        await self._ws.send(json.dumps(message, separators=(',', ':')))
        # NOTE(review): LISTEN/UNLISTEN messages carry "auth_token", so the access token
        # ends up in the debug log below - consider redacting it before logging
        ws_logger.debug(f"Websocket[{self._idx}] sent: {message}")
|
|
|
|
|
|
class WebsocketPool:
    """
    A group of `Websocket` connections that share the load of topics.

    Websockets are created on demand as topics are added (up to
    `MAX_WEBSOCKETS` of them), and removed again once the remaining topics
    fit into fewer connections.
    """

    def __init__(self, twitch: Twitch):
        """Create an empty, not yet running pool for the given `twitch`."""
        self.websockets: List[Websocket] = []
        self._running = asyncio.Event()
        self._twitch: Twitch = twitch

    @property
    def running(self) -> bool:
        """Whether the pool has been started (and not stopped since)."""
        return self._running.is_set()

    def wait_until_connected(self):
        """Return an awaitable that completes once the running flag is set."""
        return self._running.wait()

    async def start(self):
        """Wait for login, then start every websocket, unless already running."""
        await self._twitch.wait_until_login()
        if not self.running:
            self._running.set()
            starters = [ws.start() for ws in self.websockets]
            await asyncio.gather(*starters)

    async def stop(self):
        """Stop every websocket, unless the pool isn't running."""
        if self.running:
            self._running.clear()
            stoppers = [ws.stop() for ws in self.websockets]
            await asyncio.gather(*stoppers)

    def add_topics(self, topics: Iterable[WebsocketTopic]):
        """
        Distribute new topics over the websockets, creating more as needed.

        Topics already held by any websocket are skipped. Raises
        `MinerException` if some topics are left over after all
        `MAX_WEBSOCKETS` websockets have been filled up.
        """
        # a set guarantees that no topic gets added twice
        pending = set(topics)
        if not pending:
            # called with nothing to add
            return
        # skip everything that's being listened to already
        for ws in self.websockets:
            pending.difference_update(ws.topics.values())
        if not pending:
            # everything was there already
            return
        for slot in range(MAX_WEBSOCKETS):
            if slot >= len(self.websockets):
                # no websocket in this slot yet - make one
                ws = Websocket(self, slot)
                if self.running:
                    ws.start_nowait()
                self.websockets.append(ws)
            else:
                # reuse the existing one
                ws = self.websockets[slot]
            # the websocket takes what it can, popping it out of the pending set
            ws.add_topics(pending)
            if not pending:
                # all topics found a place
                return
        # every websocket is full, and there's still topics left
        raise MinerException("Maximum topics limit has been reached")

    def remove_topics(self, topics: Iterable[str]):
        """
        Remove the given topics, then shrink the pool if it got too big.

        As long as all remaining topics would fit into one websocket fewer,
        the last websocket is removed, and its topics are added back
        to the ones that are left.
        """
        leftover = set(topics)
        if not leftover:
            # called with nothing to remove
            return
        for ws in self.websockets:
            ws.remove_topics(leftover)
        # drop websockets from the end, for as long as the total number of topics
        # fits into one websocket fewer than we have now
        recycled: List[WebsocketTopic] = []
        while (
            sum(len(ws.topics) for ws in self.websockets)
            <= (len(self.websockets) - 1) * WS_TOPICS_LIMIT
        ):
            surplus = self.websockets.pop()
            recycled.extend(surplus.topics.values())
            surplus.remove()
        if recycled:
            self.add_topics(recycled)
|