diff --git a/gui.py b/gui.py index 9c0c3bd..71f171d 100644 --- a/gui.py +++ b/gui.py @@ -489,7 +489,6 @@ class ConsoleOutput: class Buttons(TypedDict): frame: ttk.Frame - cleanup: ttk.Button switch: ttk.Button load_points: ttk.Button @@ -506,11 +505,6 @@ class ChannelList: buttons_frame = ttk.Frame(frame) self._buttons: Buttons = { "frame": buttons_frame, - "cleanup": ttk.Button( - buttons_frame, - text="Cleanup", - command=manager._twitch.state_change(State.CHANNEL_CLEANUP), - ), "switch": ttk.Button( buttons_frame, text="Switch", @@ -522,9 +516,8 @@ class ChannelList: ), } buttons_frame.grid(column=0, row=0, columnspan=2) - self._buttons["cleanup"].grid(column=0, row=0) - self._buttons["switch"].grid(column=1, row=0) - self._buttons["load_points"].grid(column=2, row=0) + self._buttons["switch"].grid(column=0, row=0) + self._buttons["load_points"].grid(column=1, row=0) scroll = ttk.Scrollbar(frame, orient="vertical") self._table = table = ttk.Treeview( frame, @@ -601,9 +594,10 @@ class ChannelList: # causes the columns to shrink back after long values have been removed from it columns = self._table.cget("columns") iids = self._table.get_children() - for column in columns: - width = max(self._measure(self._table.set(i, column)) for i in iids) - self._table.column(column, minwidth=width, width=width) + if iids: # table needs to have at least one item + for column in columns: + width = max(self._measure(self._table.set(i, column)) for i in iids) + self._table.column(column, minwidth=width, width=width) self._redraw() def _set(self, iid: str, column: str, value: str): diff --git a/twitch.py b/twitch.py index 9568ae4..b4b7f11 100644 --- a/twitch.py +++ b/twitch.py @@ -8,7 +8,9 @@ from yarl import URL from time import time from itertools import chain from functools import partial -from typing import Any, Callable, Iterable, Optional, Union, List, Dict, cast, TYPE_CHECKING +from typing import ( + Callable, Iterable, Optional, Union, List, Dict, Generic, TypeVar, cast, TYPE_CHECKING +) try: import aiohttp @@ -43,6 +45,35 @@ logger = logging.getLogger("TwitchDrops") gql_logger = logging.getLogger("TwitchDrops.gql") +_V = TypeVar("_V") +_D = TypeVar("_D") + + +class _AwaitableValue(Generic[_V]): + def __init__(self): + self._value: _V + self._event = asyncio.Event() + + def has_value(self) -> bool: + return self._event.is_set() + + def get_with_default(self, default: _D) -> Union[_D, _V]: + if self._event.is_set(): + return self._value + return default + + async def get(self) -> _V: + await self._event.wait() + return self._value + + def set(self, value: _V) -> None: + self._value = value + self._event.set() + + def clear(self) -> None: + self._event.clear() + + class Twitch: def __init__(self, options: ParsedArgs): self._options = options @@ -66,9 +97,9 @@ class Twitch: self.inventory: List[DropsCampaign] = [] # inventory # Storing and watching channels self.channels: Dict[int, Channel] = {} - self._watching_channel: Optional[Channel] = None - self._watching_task: Optional[asyncio.Task[Any]] = None - self._last_watch = time() - 60 + self._watching_channel: _AwaitableValue[Channel] = _AwaitableValue() + self._watching_task: Optional[asyncio.Task[None]] = None + self._watching_restart = asyncio.Event() self._drop_update: Optional[asyncio.Future[bool]] = None # Websocket self.websocket = WebsocketPool(self) @@ -137,6 +168,9 @@ class Twitch: start_time = time() self.gui.print("Exiting...") self.stop_watching() + if self._watching_task is not None: + self._watching_task.cancel() + self._watching_task = None self._session.cookie_jar.save(COOKIES_PATH) # type: ignore await self._session.close() await self.websocket.stop() @@ -157,6 +191,7 @@ class Twitch: • Changing the stream that's being watched if necessary """ self.gui.start() + self._watching_task = asyncio.create_task(self._watch_loop()) await self.check_login() # Add default topics assert self._user_id is not None @@ -271,16 +306,13 @@ class Twitch: self.change_state(State.CHANNEL_CLEANUP) await self._state_change.wait() - async def _watch_loop(self, channel: Channel): - # last_watch is a timestamp of the last time we've sent a watch payload - # We need this because watch_loop can be cancelled and rescheduled multiple times - # in quick succession, and apparently Twitch doesn't like that very much + async def _watch_loop(self) -> None: interval = WATCH_INTERVAL.total_seconds() - await asyncio.sleep(self._last_watch + interval - time()) - i = 0 + i = 1 while True: + channel = await self._watching_channel.get() await channel.send_watch() - self._last_watch = time() + last_watch = time() self._drop_update = asyncio.Future() use_active = False try: @@ -317,38 +349,42 @@ class Twitch: drop.display() else: logger.error("Active drop search failed") - if i == 0: + if i % 30 == 1: # ensure every 30 minutes that we don't have unclaimed points bonus await channel.claim_bonus() - i = (i + 1) % 30 - await asyncio.sleep(self._last_watch + interval - time()) + if i % 60 == 0: + # refresh inventory every hour + self.change_state(State.INVENTORY_FETCH) + i = (i + 1) % 3600 + self._watching_restart.clear() + try: + await asyncio.wait_for( + self._watching_restart.wait(), timeout=last_watch + interval - time() + ) + except asyncio.TimeoutError: + pass def watch(self, channel: Channel): if self.is_watching(channel): # we're already watching the same channel, so there's no point switching return - if self._watching_task is not None: - self._watching_task.cancel() self.gui.channels.set_watching(channel) - self._watching_channel = channel - self._watching_task = asyncio.create_task(self._watch_loop(channel)) + self._watching_channel.set(channel) def stop_watching(self): self.gui.progress.stop_timer() self.gui.channels.clear_watching() - if self._watching_task is not None: - self._watching_task.cancel() - self._watching_task = None - self._watching_channel = None + self._watching_channel.clear() def restart_watching(self, channel: Optional[Channel] = None): # this forcibly re-sends the watching payload to the specified or currently watched channel if channel is None: - channel = self._watching_channel + channel = self._watching_channel.get_with_default(None) if channel is not None: - self.stop_watching() - self._last_watch = time() - 60 - self.watch(channel) + self.gui.progress.stop_timer() + self._watching_channel.set(channel) + self.gui.channels.set_watching(channel) + self._watching_restart.set() async def process_stream_state(self, channel_id: int, message: JsonType): msg_type = message["type"]