Update proxy branch with the new websocket implementation

This commit is contained in:
DevilXD
2022-04-08 22:07:41 +02:00
7 changed files with 214 additions and 102 deletions

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
import re
import json
import asyncio
import logging
from base64 import b64encode
from functools import cached_property
from typing import Any, SupportsInt, TYPE_CHECKING
from utils import Game, invalidate_cache
from utils import invalidate_cache, json_minify, Game
from exceptions import MinerException, RequestException
from constants import BASE_URL, GQL_OPERATIONS, ONLINE_DELAY, DROPS_ENABLED_TAG, URLType
@@ -326,8 +325,7 @@ class Channel:
}
}
]
json_event = json.dumps(payload, separators=(",", ":"))
return {"data": (b64encode(json_event.encode("utf8"))).decode("utf8")}
return {"data": (b64encode(json_minify(payload).encode("utf8"))).decode("utf8")}
async def send_watch(self) -> bool:
"""

View File

@@ -33,6 +33,23 @@ class RequestException(MinerException):
super().__init__("Unknown error during request")
class WebsocketClosed(RequestException):
"""
Raised when the websocket connection has been closed.
Attributes:
-----------
received: bool
`True` if the closing was caused by our side receiving a close frame, `False` otherwise.
"""
def __init__(self, *args: object, received: bool = False):
if args:
super().__init__(*args)
else:
super().__init__("Websocket has been closed")
self.received: bool = received
class LoginException(RequestException):
"""
Raised when an exception occurs during login phase.

View File

@@ -17,10 +17,6 @@ try:
import aiohttp # noqa
except ModuleNotFoundError as exc:
raise ImportError("You have to run 'pip install aiohttp' first") from exc
try:
import websockets # noqa
except ModuleNotFoundError as exc:
raise ImportError("You have to run 'pip install websockets' first") from exc
try:
import pystray # noqa
except ModuleNotFoundError as exc:

View File

@@ -1,4 +1,3 @@
aiohttp>2.0,<4.0
Pillow
pystray
websockets>10.0

119
twitch.py
View File

@@ -72,7 +72,9 @@ class Twitch:
# Maintenance task
self._mnt_task: asyncio.Task[None] | None = None
def initialize(self) -> None:
async def get_session(self) -> aiohttp.ClientSession:
if self._session is not None:
return self._session
cookie_jar = aiohttp.CookieJar()
if COOKIES_PATH.exists():
cookie_jar.load(COOKIES_PATH)
@@ -81,6 +83,7 @@ class Twitch:
headers={"User-Agent": USER_AGENT},
timeout=aiohttp.ClientTimeout(connect=5, total=10),
)
return self._session
async def shutdown(self) -> None:
start_time = time()
@@ -135,7 +138,7 @@ class Twitch:
def prevent_close(self):
"""
Called when the application window has to be prevented from closing, even after the user
closes it with X. Usually used solely to display tracebacks drom the closing sequence.
closes it with X. Usually used solely to display tracebacks from the closing sequence.
"""
self.gui.prevent_close()
@@ -203,6 +206,7 @@ class Twitch:
self._state_change.clear()
elif self._state is State.INVENTORY_FETCH:
await self.fetch_inventory()
self.gui.set_games(set(campaign.game for campaign in self.inventory))
self.change_state(State.GAMES_UPDATE)
elif self._state is State.GAMES_UPDATE:
# Figure out which games to watch, and claim the drops we can
@@ -222,14 +226,15 @@ class Twitch:
game = campaign.game
if (
game not in self.games # isn't already there
and game.name not in exclude # isn't excluded
# isn't excluded by priority_only
and game.name not in exclude # and isn't excluded
# and isn't excluded by priority_only
and (not priority_only or game.name in priority)
and campaign.can_earn() # campaign can be progressed
and campaign.can_earn() # and can be progressed (active required)
):
# non-excluded games with no priority, are placed last, below priority ones
self.games[game] = priorities.get(game.name, 0)
self.gui.set_games(self.games.keys())
full_cleanup = True
self.restart_watching()
self.change_state(State.CHANNELS_CLEANUP)
elif self._state is State.CHANNELS_CLEANUP:
if not self.games or full_cleanup:
@@ -330,13 +335,14 @@ class Twitch:
])
# relink watching channel after cleanup,
# or stop watching it if it no longer qualifies
# NOTE: this replaces 'self.watching_channel's internal value with the new object
watching_channel = self.watching_channel.get_with_default(None)
if watching_channel is not None:
new_watching = channels.get(watching_channel.id)
if new_watching is not None and self.can_watch(new_watching):
self.watch(new_watching)
else:
# we're removing a channel we're watching
# we've removed a channel we were watching
self.stop_watching()
# pre-display the active drop with a substracted minute
for channel in channels.values():
@@ -349,36 +355,34 @@ class Twitch:
elif self._state is State.CHANNEL_SWITCH:
# Change into the selected channel, stay in the watching channel,
# or select a new channel that meets the required conditions
priority_channels: list[Channel] = []
selected_channel = self.gui.channels.get_selection()
if selected_channel is not None:
self.gui.channels.clear_selection()
priority_channels.append(selected_channel)
watching_channel = self.watching_channel.get_with_default(None)
if watching_channel is not None:
priority_channels.append(watching_channel)
# If there's no selected channel, change into a channel we can watch
new_watching = None
for channel in priority_channels:
if self.can_watch(channel):
new_watching = channel
break
if new_watching is None:
selected_channel = self.gui.channels.get_selection()
if selected_channel is not None and self.can_watch(selected_channel):
# selected channel is checked first, and set as long as we can watch it
new_watching = selected_channel
else:
# other channels additionally need to have a good reason
# for a switch (including the watching one)
# NOTE: we need to sort the channels every time because one channel
# can end up streaming any game, since channels aren't game-tied
# can end up streaming any game - channels aren't game-tied
for channel in sorted(channels.values(), key=self._game_key):
if self.can_watch(channel):
if self.can_watch(channel) and self.should_switch(channel):
new_watching = channel
break
if new_watching is not None:
self.watch(channel)
# break the state change chain by clearing the flag
self._state_change.clear()
else:
watching_channel = self.watching_channel.get_with_default(None)
if watching_channel is None and new_watching is None:
# not watching anything and there isn't anything to watch either
self.gui.print(
"No available channels to watch. Waiting for an ONLINE channel..."
)
self.change_state(State.IDLE)
else:
if new_watching is not None:
# if we have a better switch target - do so
# otherwise, continue watching what we had before
self.watch(new_watching)
# break the state change chain by clearing the flag
self._state_change.clear()
elif self._state is State.EXIT:
# we've been requested to exit the application
break
@@ -386,8 +390,7 @@ class Twitch:
# post-main-loop code goes here
async def _watch_sleep(self, delay: float) -> None:
# we use wait_for here to allow an asyncio.sleep that can be ended prematurely,
# without cancelling the containing task
# we use wait_for here to allow an asyncio.sleep-like that can be ended prematurely
self._watching_restart.clear()
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(self._watching_restart.wait(), timeout=delay)
@@ -432,6 +435,8 @@ class Twitch:
else:
drop.update_minutes(drop_data["currentMinutesWatched"])
drop.display()
else:
use_active = True
if use_active:
# Sometimes, even GQL fails to give us the correct drop.
# In that case, we can use the locally cached inventory to try
@@ -461,6 +466,9 @@ class Twitch:
await asyncio.sleep(60)
def can_watch(self, channel: Channel) -> bool:
"""
Determines if the given channel qualifies as a watching candidate.
"""
if not self.games:
return False
return (
@@ -472,6 +480,28 @@ class Twitch:
and any(campaign.can_earn(channel) for campaign in self.inventory)
)
def should_switch(self, channel: Channel) -> bool:
"""
Determines if the given channel qualifies as a switch candidate.
"""
watching_channel = self.watching_channel.get_with_default(None)
channel_order = self._game_key(channel)
if watching_channel is not None:
watching_order = self._game_key(watching_channel)
else:
# stub it with some high value, it doesn't really matter
# since 'is None' check returns earlier anyway
watching_order = 1
return (
watching_channel is None # there's no current watching channel
# or this channel's game is higher order than the watching one's
# NOTE: order is tied to the priority list position, so lower == higher
or channel_order < watching_order
or channel_order == watching_order # or the order is the same
# and this channel has priority over the watching channel
and channel.priority > watching_channel.priority
)
def watch(self, channel: Channel):
self.gui.channels.set_watching(channel)
self.watching_channel.set(channel)
@@ -516,23 +546,9 @@ class Twitch:
Called by a Channel when it goes online (after pending).
"""
logger.debug(f"{channel.name} goes ONLINE")
channel_order = self._game_key(channel)
watching_channel = self.watching_channel.get_with_default(None)
if watching_channel is not None:
watching_order = self._game_key(watching_channel)
else:
watching_order = 1
if (
(
self._state is State.IDLE # we're currently idle
or watching_channel is None # or aren't watching anything
# or this channel is higher order than the watching one
# NOTE: order is tied to the list position, so lower == higher
or channel_order < watching_order
or channel_order == watching_order # or the order is the same
and channel.priority # and this channel has priority
and not watching_channel.priority # and we're watching a non-priority channel
) and self.can_watch(channel)
self.can_watch(channel) # we can watch the channel that just got ONLINE
and self.should_switch(channel) # and we should!
):
self.gui.print(f"{channel.name} goes ONLINE, switching...")
self.watch(channel)
@@ -763,12 +779,10 @@ class Twitch:
logger.debug("Checking login")
login_form.update("Logging in...", None)
# NOTE: We need this here because of the jar being accessed
if self._session is None:
self.initialize()
assert self._session is not None
session = await self.get_session()
url = URL(BASE_URL)
assert url.host is not None
jar = cast(aiohttp.CookieJar, self._session.cookie_jar)
jar = cast(aiohttp.CookieJar, session.cookie_jar)
for attempt in range(2):
cookie = jar.filter_cookies(url)
if not cookie:
@@ -810,10 +824,7 @@ class Twitch:
async def request(
self, method: str, url: str, *, attempts: int = 5, **kwargs
) -> abc.AsyncIterator[aiohttp.ClientResponse]:
session = self._session
if session is None:
self.initialize()
assert session is not None
session = await self.get_session()
method = method.upper()
if self.settings.proxy and "proxy" not in kwargs:
kwargs["proxy"] = self.settings.proxy

View File

@@ -6,6 +6,7 @@ import random
import string
import asyncio
import logging
import traceback
from enum import Enum
from pathlib import Path
from functools import wraps
@@ -34,7 +35,21 @@ _D = TypeVar("_D") # default
_P = ParamSpec("_P") # params
_JSON_T = TypeVar("_JSON_T", bound=Mapping[Any, Any])
logger = logging.getLogger("TwitchDrops")
NONCE_CHARS = string.ascii_letters + string.digits
def format_traceback(exc: BaseException, **kwargs: Any) -> str:
"""
Like `traceback.print_exc` but returns a string. Uses the passed-in exception.
Any additional `**kwargs` are passed to the underlaying `traceback.format_exception`.
"""
return ''.join(traceback.format_exception(type(exc), exc, **kwargs))
def json_minify(data: JsonType | list[JsonType]) -> str:
"""
Returns minified JSON for payload usage.
"""
return json.dumps(data, separators=(',', ':'))
def resource_path(relative_path: Path | str) -> Path:
@@ -56,6 +71,9 @@ def timestamp(string: str) -> datetime:
return datetime.strptime(string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
NONCE_CHARS = string.ascii_letters + string.digits
def create_nonce(length: int = 30) -> str:
return ''.join(random.choices(NONCE_CHARS, k=length))
@@ -110,7 +128,7 @@ def _serialize(obj: Any) -> Any:
_MISSING = object()
serialize_env: dict[str, Callable[[Any], object]] = {
SERIALIZE_ENV: dict[str, Callable[[Any], object]] = {
"set": set,
"datetime": lambda d: datetime.fromtimestamp(d, timezone.utc),
"URL": yarl.URL,
@@ -130,8 +148,8 @@ def _remove_missing(obj: JsonType) -> JsonType:
def _deserialize(obj: JsonType) -> Any:
if "__type" in obj:
obj_type = obj["__type"]
if obj_type in serialize_env:
return serialize_env[obj_type](obj["data"])
if obj_type in SERIALIZE_ENV:
return SERIALIZE_ENV[obj_type](obj["data"])
else:
return _MISSING
return obj
@@ -150,6 +168,50 @@ def json_save(path: Path, contents: Mapping[Any, Any]) -> None:
json.dump(contents, file, default=_serialize, sort_keys=True, indent=4)
class ExponentialBackoff:
def __init__(
self,
*,
base: float = 2,
variance: float | tuple[float, float] = 0.1,
shift: float = -1,
maximum: float = 300,
):
if base <= 1:
raise ValueError("base has to be greater than 1")
self.exp: int = 0
self.base: float = float(base)
self.shift: float = float(shift)
self.maximum: float = float(maximum)
self.variance_min: float
self.variance_max: float
if isinstance(variance, tuple):
self.variance_min, self.variance_max = variance
else:
self.variance_min = 1 - variance
self.variance_max = 1 + variance
def __iter__(self) -> abc.Iterator[float]:
return self
def __next__(self) -> float:
value: float = (
pow(self.base, self.exp)
* random.uniform(self.variance_min, self.variance_max)
+ self.shift
)
if value > self.maximum:
return self.maximum
# NOTE: variance can cause the returned value to be lower than the previous one already,
# so this should be safe to move past the first return,
# to prevent the exponent from getting very big after reaching max and many iterations
self.exp += 1
return value
def reset(self) -> None:
self.exp = 0
class OrderedSet(MutableSet[_T]):
"""
Implementation of a set that preserves insertion order,

View File

@@ -7,32 +7,36 @@ from time import time
from contextlib import suppress
from typing import TYPE_CHECKING
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
from websockets.client import WebSocketClientProtocol, connect as websocket_connect
import aiohttp
from exceptions import MinerException
from utils import task_wrapper, create_nonce, AwaitableValue
from exceptions import MinerException, WebsocketClosed
from constants import PING_INTERVAL, PING_TIMEOUT, MAX_WEBSOCKETS, WS_TOPICS_LIMIT
from utils import (
task_wrapper, create_nonce, json_minify, format_traceback, AwaitableValue, ExponentialBackoff
)
if TYPE_CHECKING:
from collections import abc
from twitch import Twitch
from gui import WebsocketStatus
from constants import JsonType, WebsocketTopic
WSMsgType = aiohttp.WSMsgType
logger = logging.getLogger("TwitchDrops")
ws_logger = logging.getLogger("TwitchDrops.websocket")
class Websocket:
def __init__(self, pool: WebsocketPool, index: int):
self._pool = pool
self._twitch = pool._twitch
self._pool: WebsocketPool = pool
self._twitch: Twitch = pool._twitch
self._ws_gui: WebsocketStatus = self._twitch.gui.websockets
# websocket index
self._idx: int = index
# current websocket connection
self._ws: AwaitableValue[WebSocketClientProtocol] = AwaitableValue()
self._ws: AwaitableValue[aiohttp.ClientWebSocketResponse] = AwaitableValue()
# set when the websocket needs to reconnect
self._reconnect_requested = asyncio.Event()
# set when the topics changed
@@ -61,7 +65,6 @@ class Websocket:
)
def request_reconnect(self):
ws_logger.warning(f"Websocket[{self._idx}] requested reconnect.")
# reset our ping interval, so we send a PING after reconnect right away
self._next_ping = time()
self._reconnect_requested.set()
@@ -104,20 +107,37 @@ class Websocket:
self._twitch.gui.websockets.remove(self._idx)
asyncio.create_task(remover())
async def _backoff_connect(
self, ws_url: str, **kwargs
) -> abc.AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
session = await self._twitch.get_session()
backoff = ExponentialBackoff(**kwargs)
for delay in backoff:
try:
async with session.ws_connect(ws_url, ssl=True) as websocket:
backoff.reset()
yield websocket
except aiohttp.ClientConnectionError:
ws_logger.info(
f"Websocket[{self._idx}] connection error (sleep: {delay:.3}s)", exc_info=True
)
await asyncio.sleep(delay)
@task_wrapper
async def _handle(self):
# ensure we're logged in before connecting
self.set_status("Initializing...")
await self._twitch.wait_until_login()
ws_logger.info(f"Websocket[{self._idx}] connecting...")
self.set_status("Connecting...")
# Connect/Reconnect loop
async for websocket in websocket_connect(
"wss://pubsub-edge.twitch.tv/v1", ssl=True, ping_interval=None
async for websocket in self._backoff_connect(
"wss://pubsub-edge.twitch.tv/v1", maximum=3*60 # 3 minutes maximum backoff time
):
# 3 minutes of max backoff
websocket.BACKOFF_MAX = 3 * 60 # type: ignore
self._ws.set(websocket)
self._reconnect_requested.clear()
# NOTE: _topics_changed doesn't start set,
# because there's no initial topics we can sub to right away
self.set_status("Connected")
try:
try:
@@ -128,26 +148,20 @@ class Websocket:
finally:
self._ws.clear()
self._submitted.clear()
self._topics_changed.set() # lets the next WS connection resub to the topics
# set _topics_changed to let the next WS connection resub to the topics
self._topics_changed.set()
# A reconnect was requested
except ConnectionClosed as exc:
if isinstance(exc, ConnectionClosedOK):
if exc.rcvd_then_sent:
# server closed the connection, not us - reconnect
ws_logger.warning(f"Websocket[{self._idx}] got disconnected.")
else:
# we closed it - exit
ws_logger.info(f"Websocket[{self._idx}] stopped.")
self.set_status("Disconnected")
return
except WebsocketClosed as exc:
if exc.received:
# server closed the connection, not us - reconnect
ws_logger.warning(
f"Websocket[{self._idx}] closed unexpectedly: {websocket.close_code}"
)
else:
if exc.rcvd is not None:
code = exc.rcvd.code
elif exc.sent is not None:
code = exc.sent.code
else:
code = -1
ws_logger.warning(f"Websocket[{self._idx}] closed unexpectedly: {code}")
# we closed it - exit
ws_logger.info(f"Websocket[{self._idx}] stopped.")
self.set_status("Disconnected")
return
except Exception:
ws_logger.exception(f"Exception in Websocket[{self._idx}]")
self.set_status("Reconnecting...")
@@ -161,6 +175,7 @@ class Websocket:
await self.send({"type": "PING"})
elif now >= self._max_pong:
# it's been more than 10s and there was no PONG
ws_logger.warning(f"Websocket[{self._idx}] didn't receive a PONG, reconnecting...")
self.request_reconnect()
async def _handle_topics(self):
@@ -208,10 +223,23 @@ class Websocket:
ws = self._ws.get_with_default(None)
assert ws is not None
while True:
raw_message = await ws.recv()
message = json.loads(raw_message)
ws_logger.debug(f"Websocket[{self._idx}] received: {message}")
messages.append(message)
raw_message: aiohttp.WSMessage = await ws.receive()
ws_logger.debug(f"Websocket[{self._idx}] received: {raw_message}")
if raw_message.type is WSMsgType.TEXT:
message: JsonType = json.loads(raw_message.data)
messages.append(message)
continue # shortcut to avoid checking all other elifs if not necessary
elif raw_message.type is WSMsgType.CLOSE:
raise WebsocketClosed(received=True)
elif raw_message.type is WSMsgType.CLOSED:
raise WebsocketClosed(received=False)
elif raw_message.type is WSMsgType.CLOSING:
pass # skip these
elif raw_message.type is WSMsgType.ERROR:
logger.error(f"Websocket[{self._idx}] error: {format_traceback(raw_message.data)}")
raise WebsocketClosed()
else:
logger.error(f"Websocket[{self._idx}] error: Unknown message: {raw_message}")
def _handle_message(self, message):
# request the assigned topic to process the response
@@ -241,6 +269,7 @@ class Websocket:
pass
elif msg_type == "RECONNECT":
# We've received a reconnect request
ws_logger.warning(f"Websocket[{self._idx}] requested reconnect.")
self.request_reconnect()
else:
ws_logger.warning(f"Websocket[{self._idx}] received unknown payload: {message}")
@@ -271,7 +300,7 @@ class Websocket:
assert ws is not None
if message["type"] != "PING":
message["nonce"] = create_nonce()
await ws.send(json.dumps(message, separators=(',', ':')))
await ws.send_json(message, dumps=json_minify)
ws_logger.debug(f"Websocket[{self._idx}] sent: {message}")