From 8ea73f8ec9f2a4e9f0ca059310c63699ed37fd5c Mon Sep 17 00:00:00 2001 From: Artem30801 Date: Tue, 3 Nov 2020 01:31:04 +0300 Subject: [PATCH] Messaging and server_core improvements. Refactored client_core to use asyncio. Autoconnection is fully working --- drone/modules/client_core.py | 317 ++++++++++++++++------------------ lib/messaging.py | 140 +++++---------- server/modules/server_core.py | 101 +++++------ 3 files changed, 235 insertions(+), 323 deletions(-) diff --git a/drone/modules/client_core.py b/drone/modules/client_core.py index 769c290..3df3c52 100644 --- a/drone/modules/client_core.py +++ b/drone/modules/client_core.py @@ -1,5 +1,5 @@ """ -Is a client-side module (meant to be run on Python 2.7) containing base Client class, utility functions and basic callbacks declarations. Main focus of the module is client-specific communication without reliance on `clover` Raspberry Pi environment. +Is a client-side module containing base Client class, utility functions and basic callbacks declarations. Main focus of the module is client-specific communication without reliance on `clover` Raspberry Pi environment. """ import os @@ -8,11 +8,8 @@ import time import errno import random import socket -import struct +import asyncio import logging -import selectors2 as selectors - -from contextlib import closing # Add parent dir to PATH to import messaging_lib and config_lib current_dir = os.path.dirname(os.path.realpath(__file__)) @@ -23,10 +20,15 @@ logger = logging.getLogger(__name__) import lib.messaging as messaging from lib.config import ConfigManager -active_client = None # needs to be refactored: Singleton \ factory callbacks +class ServerPeer(messaging.PeerProtocol): + def connection_lost(self, exc): + super().connection_lost(exc) + if exc is not None: + loop = asyncio.get_event_loop() + loop.call_soon(self._parent.reconnect) + self._parent._server_connection = None - -class Client(object): +class Client: """ Client base class provides config loading, communication with server (including automatic reconnection, broadcast listening and binding). You can inherit this class in order to extend functionality for practical applications. @@ -48,22 +50,25 @@ class Client(object): Args: config_path (string, optional): Path to the file with configuration. There also should be config specification file at `\config\configspec_client.ini`. Defaults to `\os.pardir\config\client.ini`. """ - self.selector = selectors.DefaultSelector() - self.client_socket = None self.callbacks = messaging.CallbackManager() - self.server_connection = messaging.ConnectionManager(self.callbacks) - self.connected = False + self._server_connection: messaging.PeerProtocol = None + self.client_id = None # Init configs self.config = ConfigManager() self.config_path = config_path + self._reconnect_task = None - global active_client - active_client = self + self._stopping: asyncio.Future = None + self._stopped: asyncio.Future = None + + @property + def connected(self): + return self._server_connection is not None def load_config(self): """ @@ -74,7 +79,7 @@ class Client(object): config_id = self.config.id.lower() if config_id == '/default': self.client_id = 'copter' + str(random.randrange(9999)).zfill(4) - self.config.set('', 'id', self.client_id, write=True) # set and write + self.config.set('', 'id', self.client_id, write=True) # set and write elif config_id == '/hostname': self.client_id = socket.gethostname() elif config_id == '/ip': @@ -84,27 +89,6 @@ class Client(object): logger.info("Config loaded") - @staticmethod - def get_ntp_time(ntp_host, ntp_port): - """Gets and returns time from specified host and port of NTP server. - - Args: - ntp_host (string): hostname or address of the NTP server. - ntp_port (int): port of the NTP server. - - Returns: - int: Current time recieved from the NTP server - """ - NTP_PACKET_FORMAT = "!12I" - NTP_DELTA = 2208988800 # 1970-01-01 00:00:00 - NTP_QUERY = '\x1b' + 47 * '\0' - - with closing(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) as s: - s.sendto(bytes(NTP_QUERY), (ntp_host, ntp_port)) - msg, address = s.recvfrom(1024) - unpacked = struct.unpack(NTP_PACKET_FORMAT, msg[0:struct.calcsize(NTP_PACKET_FORMAT)]) - return unpacked[10] + float(unpacked[11]) / 2 ** 32 - NTP_DELTA - def time_now(self): """gets and returns system time or NTP time depending on the config. @@ -112,141 +96,132 @@ class Client(object): int: Current time. """ if self.config.ntp_use: - timenow = self.get_ntp_time(self.config.ntp_host, self.config.ntp_port) + timenow = messaging.get_ntp_time(self.config.ntp_host, self.config.ntp_port) else: timenow = time.time() return timenow - def start(self): + def serve_forever(self): + asyncio.run(self.run(serve_forever=True)) + + async def run(self, serve_forever=False): """ Reloads config and starts infinite loop of connecting to the server and processing said connection. Calling of this method will indefinitely halt execution of any subsequent code. """ + loop = asyncio.get_event_loop() + + self._stopping = loop.create_future() + self._stopped = loop.create_future() + self.load_config() self.register_callbacks() - logger.info("Starting client") - messaging.NotifierSock().init(self.selector) + logger.info(f"Starting client with id: '{self.client_id}' on '{socket.gethostname()}'" + f" ({messaging.get_ip_address()})") - try: - while True: - self._reconnect() - self._process_connections() + self.reconnect() - except (KeyboardInterrupt, ): - logger.critical("Caught interrupt, exiting!") - self.selector.close() + if serve_forever: + await self._stopped - def _reconnect(self, timeout=2.0, attempt_limit=3): # TODO reconnecting broadcast listener in another thread - logger.info("Trying to connect to {}:{} ...".format(self.config.server_host, self.config.server_port)) - attempt_count = 0 - while not self.connected: - logger.info("Waiting for connection, attempt {}".format(attempt_count)) - try: - self.client_socket = socket.socket() - self.client_socket.settimeout(timeout) - messaging.set_keepalive(self.client_socket) - self.client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.client_socket.connect((self.config.server_host, self.config.server_port)) - except socket.error as error: - if isinstance(error, OSError): - if error.errno == errno.EINTR: - logger.critical("Shutting down on keyboard interrupt") - raise KeyboardInterrupt - - logger.warning("Can not connect due error: {}".format(error)) - attempt_count += 1 - time.sleep(timeout) - - else: - logger.info("Connection to server successful!") - self._connect() - break - - if attempt_count >= attempt_limit: - logger.info("Too many attempts. Trying to get new server IP") - self.broadcast_bind(timeout*2, attempt_limit) - attempt_count = 0 - - def _connect(self): - self.connected = True - self.client_socket.setblocking(False) - self.selector.register(self.client_socket, selectors.EVENT_READ, data=self.server_connection) - self.server_connection.connect(self.selector, self.client_socket, - (self.config.server_host, self.config.server_port)) - - def broadcast_bind(self, timeout=2.0, attempt_limit=3): - broadcast_client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - broadcast_client.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - broadcast_client.settimeout(timeout) - try: - broadcast_client.bind(("", self.config.broadcast_port)) - except socket.error as error: - logger.error("Error during broadcast listening binding: {}".format(error)) + async def stop(self, reason: str=''): + if self._stopping.done(): + logging.error("Client is already stopping") return - attempt_count = 0 + self._stopping.set_result(True) + logging.info("Stopping client") + self._server_connection.transport.close() + to_await = [self._server_connection.closed] + + if self._reconnect_task is not None: + logging.info("Cancelling reconnection") + self._reconnect_task.cancel() + to_await.append(self._reconnect_task) + + await asyncio.gather(*to_await, return_exceptions=True) # wait until everything shuts down + + if not self._stopped.done(): + self._stopped.set_result(True) + + logging.info(f"Client stopped: {reason}") + + def reconnect(self): + if self._reconnect_task is not None: + logger.warning("Reconnection task is already running") + + logger.info("Starting reconnection task") + loop = asyncio.get_event_loop() + self._reconnect_task = loop.create_task(self._reconnect()) # todo args + + async def _reconnect(self, attempt_limit=3, timeout=20): + logger.info(f"Reconnection task started") + try: - while attempt_count <= attempt_limit: - try: - data, addr = broadcast_client.recvfrom(self.config.server_buffer_size) - except socket.error as error: - logger.warning("Could not receive broadcast due error: {}".format(error)) - attempt_count += 1 + while not self.connected: + logger.info(f"Trying to connect to {self.config.server_host}:{self.config.server_port} ...") + for attempt_count in range(1, attempt_limit+1): + logger.info(f"Waiting for connection, attempt {attempt_count}/{attempt_limit}") + await self._connect() + if self.connected: + return else: - message = messaging.MessageManager() - message._income_raw = data - message.process_message() - if message.content and message.jsonheader["action"] == "server_ip": - logger.info("Received broadcast message {} from {}".format(message.content, addr)) - - kwargs = message.content["kwargs"] - self.config.set("SERVER", "port", int(kwargs["port"])) - self.config.set("SERVER", "host", kwargs["host"]) - self.config.write() - - logger.info("Binding to new IP: {}:{}".format( - self.config.server_host, self.config.server_port)) - self.on_broadcast_bind() - break + logger.info("Too many attempts. Trying to get new server IP") + await self._broadcast_listen(timeout) finally: - broadcast_client.close() + logging.info("Reconnection task stopped") + self._reconnect_task = None - def on_broadcast_bind(self): # TODO move ALL binding code here + async def _connect(self): + loop = asyncio.get_event_loop() + + try: + transport, protocol = await loop.create_connection(lambda: ServerPeer(self, self.callbacks), + host=self.config.server_host, + port=self.config.server_port, + ) + + except OSError as e: + logger.error(f"Cannot connect to server due error: {e}") + self._server_connection = None + else: + logger.info("Connection to server successful!") + + messaging.set_keepalive(transport.get_extra_info('socket')) + self._server_connection = protocol + + async def _broadcast_listen(self, listen_timeout=None): + logging.info("Broadcast listener started") + + loop = asyncio.get_event_loop() + try: + transport, protocol = await loop.create_datagram_endpoint( + lambda: messaging.BroadcastProtocol(self._on_broadcast_bind), + local_addr=('', self.config.broadcast_port), + family=socket.AF_INET) + except OSError as e: + logging.info(f"Broadcast listener exited: port is busy: {e}") + return + + try: + await asyncio.wait_for(protocol.closed, timeout=listen_timeout) + except asyncio.TimeoutError: + logging.warning("Broadcast listener timed out") + finally: + transport.close() + await protocol.closed + logging.info("Broadcast listener stopped") + + def _on_broadcast_bind(self, message: messaging.MessageManager): """ Method called on binding to the server by broadcast. Override that method in order to add functionality. """ - pass + kwargs = message.content["kwargs"] + self.config.set("SERVER", "port", int(kwargs["port"])) + self.config.set("SERVER", "host", kwargs["host"]) + self.config.write() - def _process_connections(self): - while True: - events = self.selector.select(timeout=1) - for key, mask in events: - connection = key.data - if connection is not None: - try: - connection.process_events(mask) - - except Exception as error: - logger.error( - "Exception {} occurred for {}! Resetting connection!".format(error, connection.addr) - ) - self.server_connection._close() - self.connected = False - - if isinstance(error, OSError): - if error.errno == errno.EINTR: - raise KeyboardInterrupt - try: - mapping_fds = self.selector.get_map().keys() # file descriptors - notifier_fd = messaging.NotifierSock().get_sock().fileno() - except (KeyError, RuntimeError) as e: - logger.error("Exception {} occurred when getting connections map!".format(e)) - logger.error("Connections changed during getting connections map, passing") - else: - notify_only= len(mapping_fds) == 1 and notifier_fd in mapping_fds - if notify_only or not mapping_fds: - logger.warning("No active connections left!") - return + logger.info(f"Got new server IP: {self.config.server_host}:{self.config.server_port}") def register_callbacks(self): @self.callbacks.action_callback("config") @@ -254,50 +229,48 @@ class Client(object): mode = kwargs.get("mode", "modify") # exceptions would be risen in case of incorrect config if mode == "rewrite": - active_client.config.load_from_dict(kwargs["config"], configspec=active_client.config_path) # with validation + self.config.load_from_dict(kwargs["config"], + configspec=self.config_path) # with validation elif mode == "modify": new_config = ConfigManager() new_config.load_from_dict(kwargs["config"]) - active_client.config.merge(new_config, validate=True) + self.config.merge(new_config, validate=True) - active_client.config.write() + self.config.write() logger.info("Config successfully updated from command") - active_client.load_config() + self.load_config() @self.callbacks.request_callback("config") def _response_config(*args, **kwargs): send_configspec = kwargs.get("send_configspec", False) - response = {"config": active_client.config.full_dict()} + response = {"config": self.config.full_dict()} if send_configspec: - response.update({"configspec": dict(active_client.config.config.configspec)}) + response.update({"configspec": dict(self.config.config.configspec)}) return response @self.callbacks.request_callback("clover_dir") def _response_clover_dir(*args, **kwargs): - return active_client.config.clover_dir + return self.config.clover_dir @self.callbacks.request_callback("id") def _response_id(*args, **kwargs): new_id = kwargs.get("new_id", None) if new_id is not None: - active_client.config.set("PRIVATE", "id", new_id, True) - active_client.load_config() + self.config.set("PRIVATE", "id", new_id, True) + self.load_config() # TODO renaming here - return active_client.client_id + return self.client_id @self.callbacks.request_callback("time") def _response_time(*args, **kwargs): - return active_client.time_now() + return self.time_now() if __name__ == "__main__": startup_cwd = os.getcwd() - import threading - - print(Client.get_ntp_time("ntp1.stratum2.ru", 123)) - + # print(Client.get_ntp_time("ntp1.stratum2.ru", 123)) def restart(): # move to core args = sys.argv[:] @@ -308,17 +281,23 @@ if __name__ == "__main__": os.chdir(startup_cwd) os.execv(sys.executable, args) + def mock_telem(): while True: time.sleep(5) - #t = dict([('fcu_status', None), ('current_position', [-2.89, 2.12, 3.64, 15.22, 'aruco_map']), ('animation_id', 'two_drones_test'), ('selfcheck', 'OK'), ('battery', None), ('git_version', '01bf95e'), ('calibration_status', None), ('start_position', [0.2, 0.2, 0.0]), ('mode', 'MANUAL'), ('time_delta', 1581338473.438682), ('armed', False), ('config_version', None), ('last_task', 'No task')]) - t = dict([('fcu_status', 'STANDBY'), ('current_position', [-1.17, 2.04, 3.45, 0, "11"]), ('animation_id', 'two_drones_test'), ('selfcheck', 'OK'), ('battery', [12.2, 1.0]), ('git_version', '42aee96'), ('calibration_status', None), ('start_position', [0.2, 0.2, 0.0]), ('mode', 'MANUAL'), ('time_delta', 1581342970.889573), ('armed', False), ('config_version', 'Copter config V0.0'), ('last_task', 'No task')]) + # t = dict([('fcu_status', None), ('current_position', [-2.89, 2.12, 3.64, 15.22, 'aruco_map']), ('animation_id', 'two_drones_test'), ('selfcheck', 'OK'), ('battery', None), ('git_version', '01bf95e'), ('calibration_status', None), ('start_position', [0.2, 0.2, 0.0]), ('mode', 'MANUAL'), ('time_delta', 1581338473.438682), ('armed', False), ('config_version', None), ('last_task', 'No task')]) + t = dict([('fcu_status', 'STANDBY'), ('current_position', [-1.17, 2.04, 3.45, 0, "11"]), + ('animation_id', 'two_drones_test'), ('selfcheck', 'OK'), ('battery', [12.2, 1.0]), + ('git_version', '42aee96'), ('calibration_status', None), ('start_position', [0.2, 0.2, 0.0]), + ('mode', 'MANUAL'), ('time_delta', 1581342970.889573), ('armed', False), + ('config_version', 'Copter config V0.0'), ('last_task', 'No task')]) if active_client.connected: active_client.server_connection.send_message("telemetry", kwargs={"value": t}) + logging.basicConfig(level=logging.DEBUG) client = Client() - tr = threading.Thread(target=mock_telem) - tr.start() - client.start() - + # tr = threading.Thread(target=mock_telem) + # tr.start() + client.serve_forever() + #asyncio.run(client.run()) diff --git a/lib/messaging.py b/lib/messaging.py index adc90cc..a493fae 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -14,9 +14,6 @@ import platform import traceback import collections -from contextlib import closing - - class Namespace: def __init__(self, **kwargs): self.__dict__.update(kwargs) @@ -33,6 +30,8 @@ class PendingRequest(Namespace): pass logger = logging.getLogger(__name__) +def str_peername(peername): + return f"{peername[0]}:{peername[1]}" def get_ip_address(): """ @@ -40,14 +39,17 @@ def get_ip_address(): Returns: string: IP address of current computer or `localhost` if no network connection present - """ + """ + ip_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # TODO IPv6 try: - with closing(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) as ip_socket: - ip_socket.connect(("8.8.8.8", 80)) - return ip_socket.getsockname()[0] - except OSError: - logger.warning("No network connection detected, using localhost") - return "localhost" + ip_socket.connect(("8.8.8.8", 80)) + ip = ip_socket.getsockname()[0] + except OSError as e: + logging.warning(f"No network connection detected, using localhost: {e}") + ip = "localhost" + finally: + ip_socket.close() + return ip def get_ntp_time(ntp_host, ntp_port): @@ -125,8 +127,6 @@ class BroadcastProtocol: self.closed.set_result(True) def datagram_received(self, data, addr): - logging.info(data) - message = MessageManager(data) message.process_message() content = message.content @@ -136,6 +136,7 @@ class BroadcastProtocol: if is_ip_broadcast: logging.debug(f"Got broadcast message from {addr}: {content}") asyncio.get_event_loop().call_soon(self._on_broadcast, message) + self.transport.close() else: logging.warning(f"Got wrong broadcast message from {addr}") @@ -379,44 +380,59 @@ class CallbackManager: class PeerProtocol(asyncio.Protocol): def __init__(self, parent, callbacks): - # self._parent = parent + self._parent = parent self._callbacks = callbacks self._recv_buffer = bytearray() self._current_msg = None - self._recv_msg_queue = asyncio.Queue + self._msg_queue = asyncio.Queue() + + loop = asyncio.get_event_loop() + self.closed: asyncio.Future = loop.create_future() + + @property + def peername(self): + return self.transport.get_extra_info('peername') + + @property + def connected(self): + return not self.closed.done() def connection_made(self, transport): self.transport = transport + logging.info(f"Connected to {str_peername(self.peername)}") - peername = transport.get_extra_info('peername') - print(peername) # self._resend_requests() - def data_received(self, data): self._recv_buffer += data + logger.debug("Received {} bytes from {}".format(len(data), self.peername)) - def _proceess_received(self): + async def _proceess_received(self): while self._recv_buffer: - # add new message object if queue is empty or last message already processed if self._current_msg is None: - self._current_msg = MessageManager() + self._current_msg = MessageManager(self._recv_buffer) + else: + self._current_msg.set_buffer(self._recv_buffer) - #self._recv_buffer. - last_message.process_message() + self._current_msg.process_message() - # if something left after processing message - put it back - if last_message.content is not None and last_message.income_raw: - self._recv_buffer = last_message.income_raw + self._recv_buffer - last_message.income_raw = b'' + if self._current_msg.processed: + #self._recv_buffer = + await self._msg_queue.put(self._current_msg) + self._current_msg = None - if self._received_queue and last_message.content is not None: - self.process_received(self._received_queue.popleft()) + # if last_message.content is not None and last_message.income_raw: + # self._recv_buffer = last_message.income_raw + self._recv_buffer + # last_message.income_raw = b'' + # + # if self._received_queue and last_message.content is not None: + # self.process_received(self._received_queue.popleft()) def connection_lost(self, exc): - pass - # logger.info("Closing connection to {}".format(self.addr)) + logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}") + if not self.closed.done(): + self.closed.set_result(True) class ConnectionManager: @@ -446,15 +462,6 @@ class ConnectionManager: """ self.callbacks = callbacks - self.selector = None - self.socket = None - self.addr = None - - self._should_close = False - - self._recv_buffer = b"" - self._send_buffer = b"" - self.whoami = whoami @@ -466,26 +473,6 @@ class ConnectionManager: self._send_queue.clear() - - def read(self): - self._read() - - - def _read(self): - try: - data = self.socket.recv(self.buffer_size) - except io.BlockingIOError: - # Resource temporarily unavailable (errno EWOULDBLOCK) - pass - else: - if data: - self._recv_buffer += data - logger.debug("Received {} bytes from {}".format(len(data), self.addr)) - else: - logger.warning("Connection to {} lost!".format(self.addr)) - - raise RuntimeError("Peer closed.") - def process_received(self, message): message_type = message.jsonheader["message-type"] content = message.content if message.jsonheader["content-type"] != "binary"\ @@ -587,40 +574,8 @@ class ConnectionManager: logger.info("Return rights to pi:pi after file transfer") os.system("chown pi:pi {}".format(filepath)) - def write(self): - with self._send_lock: - if (not self._send_buffer) and self._send_queue: - message = self._send_queue.popleft() - self._send_buffer += message - if self._send_buffer: - self._write() - else: self._set_selector_events_mask('r') # we're done writing - def _write(self): - try: - sent = self.socket.send(self._send_buffer[:self.buffer_size]) - except io.BlockingIOError: - # Resource temporarily unavailable (errno EWOULDBLOCK) - pass - except Exception as error: - logger.warning( - "Attempt to send message {} to {} failed due error: {}".format(self._send_buffer, self.addr, error)) - - raise error - else: - self._send_buffer = self._send_buffer[sent:] - left = len(self._send_buffer) - logger.debug("Sent message to {}: sent {} bytes, {} bytes left.".format(self.addr, sent, left)) - - def _send(self, data): - with self._send_lock: - self._send_queue.append(data) - - if self.selector.get_key(self.socket).events != selectors.EVENT_WRITE: - self._set_selector_events_mask('rw') - NotifierSock().notify() - def get_response(self, requested_value, callback, # timeout=30, request_args=(), request_kwargs=None, callback_args=(), callback_kwargs=None, ): @@ -691,7 +646,6 @@ class ConnectionManager: callback_args=callback_args, callback_kwargs=callback_kwargs) def _resend_requests(self): - with self._request_lock: for request_id, request in self._request_queue.items(): # TODO filter if request.resend: self._send(MessageManager.create_request( @@ -729,4 +683,4 @@ class ConnectionManager: else: logger.info("Sending file {} to {} (as: {})".format(filepath, self.addr, dest_filepath)) self._send(MessageManager.create_message(data, "binary", "message", - additional_headers={"action": "filetransfer", "filepath": dest_filepath})) \ No newline at end of file + additional_headers={"action": "filetransfer", "filepath": dest_filepath})) diff --git a/server/modules/server_core.py b/server/modules/server_core.py index 03e76fd..fdcb811 100644 --- a/server/modules/server_core.py +++ b/server/modules/server_core.py @@ -36,15 +36,15 @@ logger = logging.getLogger(__name__) class Server: def __init__(self, config_path=os.path.join(current_dir, os.pardir, "config", "server.ini"), server_id=None): - self.id = server_id if server_id else str(random.randint(0, 9999)).zfill(4) - self.time_started = 0 + self.id = server_id if server_id is not None else str(random.randint(0, 9999)).zfill(4) + self.time_started = None self._tcp_server = None self.callbacks = messaging.CallbackManager() self._clients = dict() self.host = socket.gethostname() - self.ip = messaging.get_ip_address() + self.ip = messaging.get_ip_address() # TODO get all adresses self.config = ConfigManager() self.config_path = config_path @@ -59,13 +59,17 @@ class Server: self.config.load_config_and_spec(self.config_path) async def run(self, wait_closed: bool=False): + if self.time_started is not None: + logger.warning("Server is already running, restarting") + await self.stop("Restart") + + self.time_started = time.time() + loop = asyncio.get_event_loop() # loop.set_debug(True) self._stopping = loop.create_future() self._stopped = loop.create_future() - self.time_started = time.time() - # load config on startup self.load_config() # TODO async @@ -75,19 +79,19 @@ class Server: if self.config.broadcast_listen: self._broadcast_listen_task = loop.create_task(self._broadcast_listen()) - logging.info(f"Starting server with id: {self.id} on {self.host} ({self.ip})") + logging.info(f"Starting server with id: '{self.id}' on '{self.host}' ({self.ip})") - self._tcp_server = await loop.create_server(messaging.PeerProtocol, + self._tcp_server = await loop.create_server(lambda: messaging.PeerProtocol(self, self.callbacks), port=self.config.server_port, reuse_address=True, start_serving=False) for sock in self._tcp_server.sockets: # to handle multiple interfaces (IPv4\IPv6) self._configure_server_sock(sock) - sockname = sock.getsockname() - logging.info(f"Running server socket on {sockname[0]} : {sockname[1]}") + logging.info(f"Running server socket on {messaging.str_peername(sock.getsockname())}") logging.info("Starting serving") await self._tcp_server.start_serving() + if wait_closed: await self._stopped @@ -99,6 +103,7 @@ class Server: logging.error("Server is already stopping") return + self.time_started = None self._stopping.set_result(True) logging.info("Stopping server") self._tcp_server.close() @@ -136,21 +141,7 @@ class Server: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) messaging.set_keepalive(sock) - def _connect_client(self, sock): - - - if not any(client_addr == addr[0] for client_addr in Client.clients.keys()): - client = Client(self.callbacks, addr[0]) - client.buffer_size = self.config.server_buffer_size - logging.info("New client") - else: - client = Client.clients[addr[0]] - client.close(True) # to ensure in unregistering - logging.info("Reconnected client") - self.sel.register(conn, selectors.EVENT_READ, data=client) - client.connect(self.sel, conn, addr) - - async def _broadcast_send(self): + async def _broadcast_send(self): # TODO REDO logging.info("Broadcast sender task started") msg = messaging.MessageManager.create_action_message( "server_ip", kwargs={"host": self.ip, "port": self.config.server_port, @@ -159,7 +150,6 @@ class Server: f"Formed broadcast message to {self.config.broadcast_send_ip}:{self.config.broadcast_port}: {msg}") broadcast_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) - broadcast_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) broadcast_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) try: while True: @@ -177,9 +167,9 @@ class Server: logging.info("Broadcast listener task started!") def broadcast_callback(message: messaging.MessageManager): - content = message.content - different_id = content["kwargs"]["id"] != str(self.id) - self_younger = float(content["kwargs"]["start_time"]) <= self.time_started + kwargs = message.content["kwargs"] + different_id = kwargs["id"] != self.id + self_younger = float(kwargs["start_time"]) <= self.time_started if different_id and self_younger: loop.run_until_complete( self.terminate("Another server is running on this local network, shutting down!")) @@ -189,8 +179,8 @@ class Server: transport, protocol = await loop.create_datagram_endpoint(lambda: messaging.BroadcastProtocol(broadcast_callback), local_addr=('', self.config.broadcast_port), family=socket.AF_INET) - except OSError: - logging.info("Broadcast listener exited: port is busy") + except OSError as e: + logging.error(f"Broadcast listener exited: port is busy : {e}") loop.run_until_complete(self.terminate("Another server is likely running on this computer, shutting down!")) return @@ -198,6 +188,7 @@ class Server: await protocol.closed finally: transport.close() + await protocol.closed logging.info("Broadcast listener task stopped") def send_starttime(self, copter, start_time): @@ -221,12 +212,23 @@ def requires_any_connected(f): return wrapper -class RemoteClientProtocol(messaging.PeerProtocol): +class ClientPeer(messaging.PeerProtocol): + def connection_made(self, transport): + super().connection_made(transport) + self._parent.clients[self.peername[0]] = self + # if not any(client_addr == addr[0] for client_addr in Client.clients.keys()): + # client = Client(self.callbacks, addr[0]) + # client.buffer_size = self.config.server_buffer_size + # logging.info("New client") + # else: + # client = Client.clients[addr[0]] + # client.close(True) # to ensure in unregistering + # logging.info("Reconnected client") + # self.sel.register(conn, selectors.EVENT_READ, data=client) + # client.connect(self.sel, conn, addr) class Client(messaging.ConnectionManager): - clients = {} - on_connect = None # Use as callback functions on_first_connect = None on_disconnect = None @@ -237,7 +239,6 @@ class Client(messaging.ConnectionManager): self.clover_dir = None self.connected = False - self.clients[ip] = self @staticmethod def get_by_id(copter_id): @@ -280,13 +281,6 @@ class Client(messaging.ConnectionManager): if self.on_disconnect: self.on_disconnect(self) - if inner: - super()._close() - else: - super().close() - - logging.info("Connection to {} closed!".format(self.copter_id)) - def remove(self): if self.connected: self.close() @@ -298,9 +292,6 @@ class Client(messaging.ConnectionManager): logging.info("Client {} successfully removed!".format(self.copter_id)) - @requires_connect - def _send(self, data): - super()._send(data) logging.debug("Queued data to send (first 256 bytes): {}".format(data[:256])) @staticmethod @@ -319,28 +310,16 @@ class Client(messaging.ConnectionManager): if __name__ == '__main__': logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s [%(name)-7.7s] [%(threadName)-19.19s] [%(levelname)-7.7s] %(message)s", + format="%(asctime)s [%(name)-7.7s] [%(module)-19.19s] [%(levelname)-7.7s] %(message)s", handlers=[ logging.FileHandler(os.path.join(log_path, "{}.log".format(now))), logging.StreamHandler() ]) + # print(messaging.get_ip_address()) + server = Server() - loop = asyncio.get_event_loop() - loop.set_debug(True) - print(loop) - - try: - loop.run_until_complete(server.run()) - #loop.run_until_complete(asyncio.sleep(4)) - #loop.run_until_complete(server.stop()) - loop.run_forever() - print(3) - finally: - loop.run_until_complete(server.stop("hey")) - - print(1) - print(4) - # + server.serve_forever() +