From 3d3eb3b00517bbbd58333c3a5608ac939d3d01b0 Mon Sep 17 00:00:00 2001 From: DevilXD Date: Sun, 23 Jan 2022 12:40:16 +0100 Subject: [PATCH] Defer client session init --- channel.py | 8 +++-- gui.py | 2 +- main.py | 2 +- twitch.py | 85 +++++++++++++++++++++++++++++------------------------- utils.py | 4 +-- 5 files changed, 55 insertions(+), 46 deletions(-) diff --git a/channel.py b/channel.py index 4390576..8a7eb08 100644 --- a/channel.py +++ b/channel.py @@ -180,6 +180,7 @@ class Channel: To get this monstrous thing, you have to walk a chain of requests. Streamer page (HTML) --parse-> Streamer Settings (JavaScript) --parse-> Spade URL """ + assert self._twitch._session is not None async with self._twitch._session.get(self.url) as response: streamer_html = await response.text(encoding="utf8") match = re.search( @@ -303,14 +304,15 @@ class Channel: """ if not self.online: return False + session = self._twitch._session + if session is None: + return False if self._spade_url is None: self._spade_url = await self.get_spade_url() logger.debug(f"Sending minute-watched to {self.name}") for attempt in range(5): try: - async with self._twitch._session.post( - self._spade_url, data=self._payload - ) as response: + async with session.post(self._spade_url, data=self._payload) as response: return response.status == 204 except (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError): continue diff --git a/gui.py b/gui.py index 4a05621..048cbcb 100644 --- a/gui.py +++ b/gui.py @@ -712,7 +712,7 @@ class TrayIcon: def start(self): if self.icon is None: - loop = self._manager._twitch._loop + loop = asyncio.get_running_loop() drop = self._manager.progress._drop # we need this because tray icon lives in a separate thread diff --git a/main.py b/main.py index 7611e61..361e6a6 100644 --- a/main.py +++ b/main.py @@ -94,7 +94,7 @@ logging.getLogger("TwitchDrops.websocket").setLevel(options.debug_ws) # client run loop = asyncio.get_event_loop() -client = Twitch(loop, options) +client = Twitch(options) signal.signal(signal.SIGINT, lambda *_: client.close()) signal.signal(signal.SIGTERM, lambda *_: client.close()) try: diff --git a/twitch.py b/twitch.py index c09b085..34f24e6 100644 --- a/twitch.py +++ b/twitch.py @@ -20,9 +20,9 @@ except ModuleNotFoundError as exc: from channel import Channel from websocket import WebsocketPool from gui import GUIManager, LoginData -from inventory import DropsCampaign, Game, TimedDrop +from inventory import DropsCampaign, TimedDrop from exceptions import LoginException, CaptchaRequired -from utils import AwaitableValue, OrderedSet, task_wrapper +from utils import Game, AwaitableValue, OrderedSet, task_wrapper from constants import ( State, JsonType, @@ -47,29 +47,20 @@ gql_logger = logging.getLogger("TwitchDrops.gql") class Twitch: - def __init__(self, loop: asyncio.AbstractEventLoop, options: ParsedArgs): - self._loop = loop + def __init__(self, options: ParsedArgs): self.options = options - # GUI - self.gui = GUIManager(self) - # Cookies, session and auth - cookie_jar = aiohttp.CookieJar() - if os.path.isfile(COOKIES_PATH): - cookie_jar.load(COOKIES_PATH) - self._session = aiohttp.ClientSession( - loop=loop, - cookie_jar=cookie_jar, - headers={"User-Agent": USER_AGENT}, - timeout=aiohttp.ClientTimeout(connect=5, total=10), - ) - self._access_token: Optional[str] = None - self._user_id: Optional[int] = None - self._is_logged_in = asyncio.Event() # State management self._state: State = State.INVENTORY_FETCH self._state_change = asyncio.Event() - self.inventory: Dict[Game, List[DropsCampaign]] = {} self.game: Optional[Game] = None + self.inventory: Dict[Game, List[DropsCampaign]] = {} + # GUI + self.gui = GUIManager(self) + # Cookies, session and auth + self._session: Optional[aiohttp.ClientSession] = None + self._access_token: Optional[str] = None + self._user_id: Optional[int] = None + self._is_logged_in = asyncio.Event() # Storing and watching channels self.channels: OrderedDict[int, Channel] = OrderedDict() self.watching_channel: AwaitableValue[Channel] = AwaitableValue() @@ -79,6 +70,33 @@ class Twitch: # Websocket self.websocket = WebsocketPool(self) + async def initialize(self) -> None: + cookie_jar = aiohttp.CookieJar() + if os.path.isfile(COOKIES_PATH): + cookie_jar.load(COOKIES_PATH) + self._session = aiohttp.ClientSession( + cookie_jar=cookie_jar, + headers={"User-Agent": USER_AGENT}, + timeout=aiohttp.ClientTimeout(connect=5, total=10), + ) + + async def shutdown(self) -> None: + start_time = time() + self.gui.print("Exiting...") + self.stop_watching() + if self._watching_task is not None: + self._watching_task.cancel() + self._watching_task = None + # close session and stop websocket + if self._session is not None: + self._session.cookie_jar.save(COOKIES_PATH) # type: ignore + await self._session.close() + self._session = None + await self.websocket.stop() + # wait at least one full second + whatever it takes to complete the closing + # this allows aiohttp to safely close the session + await asyncio.sleep(start_time + 1 - time()) + def wait_until_login(self): return self._is_logged_in.wait() @@ -120,25 +138,6 @@ class Twitch: """ self.gui.print(*args, **kwargs) - def stop(self): - self.stop_watching() - if self._watching_task is not None: - self._watching_task.cancel() - self._watching_task = None - - async def shutdown(self): - start_time = time() - self.gui.print("Exiting...") - self.stop() - # save our cookies - self._session.cookie_jar.save(COOKIES_PATH) # type: ignore - # stop websocket and close session - await self.websocket.stop() - await self._session.close() - # wait at least one full second + whatever it takes to complete the closing - # this allows aiohttp to safely close the session - await asyncio.sleep(start_time + 1 - time()) - def is_watching(self, channel: Channel) -> bool: watching_channel = self.watching_channel.get_with_default(None) return watching_channel is not None and watching_channel == channel @@ -152,7 +151,11 @@ class Twitch: • Selecting a stream to watch, and watching it • Changing the stream that's being watched if necessary """ + if self._session is None: + await self.initialize() self.gui.start() + if self._watching_task is not None: + self._watching_task.cancel() self._watching_task = asyncio.create_task(self._watch_loop()) await self.check_login() # Add default topics @@ -575,6 +578,7 @@ class Twitch: """ if not 8 <= len(password) <= 71: return False + assert self._session is not None payload = {"password": password} async with self._session.post( f"{AUTH_URL}/api/v1/password_strength", json=payload @@ -590,6 +594,7 @@ class Twitch: async def _login(self) -> str: logger.debug("Login flow started") + assert self._session is not None payload: JsonType = { "client_id": CLIENT_ID, @@ -670,6 +675,7 @@ class Twitch: if self._access_token is not None and self._user_id is not None: # we're all good return + assert self._session is not None # looks like we're missing something logger.debug("Checking login") self.gui.login.update("Logging in...", None) @@ -712,6 +718,7 @@ class Twitch: async def gql_request(self, op: GQLOperation) -> JsonType: await self.check_login() + assert self._session is not None headers = { "Authorization": f"OAuth {self._access_token}", "Client-Id": CLIENT_ID, diff --git a/utils.py b/utils.py index 785df4f..3eb9056 100644 --- a/utils.py +++ b/utils.py @@ -13,8 +13,8 @@ from typing import Union, List, MutableSet, Iterable, Iterator, Generic, TypeVar from constants import JsonType -_V = TypeVar("_V") -_D = TypeVar("_D") +_V = TypeVar("_V") # value +_D = TypeVar("_D") # default logger = logging.getLogger("TwitchDrops") NONCE_CHARS = string.ascii_letters + string.digits