diff --git a/lib/messaging.py b/lib/messaging.py index 8945b6c..adc90cc 100644 --- a/lib/messaging.py +++ b/lib/messaging.py @@ -16,11 +16,6 @@ import collections from contextlib import closing -try: - import selectors -except ImportError: - import selectors2 as selectors - class Namespace: def __init__(self, **kwargs): @@ -56,6 +51,16 @@ def get_ip_address(): def get_ntp_time(ntp_host, ntp_port): + """ + Gets and returns time from specified host and port of NTP server. + + Args: + ntp_host (string): hostname or address of the NTP server. + ntp_port (int): port of the NTP server. + + Returns: + int: Current time recieved from the NTP server + """ NTP_DELTA = 2208988800 # 1970-01-01 00:00:00 NTP_QUERY = b'\x1b' + bytes(47) with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as ntp_socket: @@ -107,23 +112,35 @@ class BroadcastProtocol: def __init__(self, on_broadcast): self._on_broadcast = on_broadcast + loop = asyncio.get_event_loop() + self.closed: asyncio.Future = loop.create_future() + + def connection_made(self, transport): + self.transport = transport + logging.info("Broadcast connection established") + + def connection_lost(self, exc): + logging.info(f"Broadcast connection lost: {'closed' if exc is None else exc}") + if not self.closed.done(): + self.closed.set_result(True) + def datagram_received(self, data, addr): - message = MessageManager() - message.append_data(data) + logging.info(data) + + message = MessageManager(data) message.process_message() content = message.content is_ip_broadcast = (content is not None and message.jsonheader["action"] == "server_ip") if is_ip_broadcast: + logging.debug(f"Got broadcast message from {addr}: {content}") asyncio.get_event_loop().call_soon(self._on_broadcast, message) - # different_id = content["kwargs"]["id"] != str(self.id) - # self_younger = float(content["kwargs"]["start_time"]) <= self.time_started else: - logging.warning("Got wrong broadcast message from {}".format(addr)) + logging.warning(f"Got wrong broadcast message from {addr}") - # def error_received(self, exc): - # logging.error("Error received") + def error_received(self, exc): + logging.warning(f"Error on broadcast connection received: {exc}") class BroadcastSendProtocol: pass @@ -135,7 +152,7 @@ class MessageManager: Messages in protocol implemented by this class consists of 3 parts: * Fixed-length (2 bytes) protoheader - contains length of json header - * json header - contains information about message contents: length, encoding, byteorder, type of message and contents, etc. + * json header - contains information about message contents: length, encoding, type of message and contents, etc. * content - contains actual contents of message (json information, bytes, etc.) @@ -144,17 +161,22 @@ class MessageManager: jsonheader (dict): Headers dictionary with information about message encoding and purpose. Would be populated when receiving and processing of the json header will be completed. content (object): Would be populated when receiving and processing of the message will be completed. Defaults to None. """ - def __init__(self): + def __init__(self, data): """ ```python message = MessageManager() ``` """ - self._income_raw = b"" #todo bytearrray + self._income_raw = None + self._jsonheader_len = None self.jsonheader = None self.content = None + # self._processed = False + + self.set_buffer(data) + @staticmethod def _json_encode(obj, encoding="utf-8"): return json.dumps(obj, ensure_ascii=False).encode(encoding) @@ -181,11 +203,11 @@ class MessageManager: """ jsonheader = { - "byteorder": sys.byteorder, + "content-length": len(content_bytes), "content-type": content_type, "content-encoding": content_encoding, - "content-length": len(content_bytes), "message-type": message_type, + # "message-uuid": } if additional_headers: jsonheader.update(additional_headers) @@ -283,18 +305,18 @@ class MessageManager: def _process_jsonheader(self): header_len = self._jsonheader_len - if len(self._income_raw) >= header_len: - self.jsonheader = self._json_decode(self._income_raw[:header_len], "utf-8") - self._income_raw = self._income_raw[header_len:] - for reqhdr in ( - "byteorder", - "content-length", - "content-type", - "content-encoding", - "message-type", - ): - if reqhdr not in self.jsonheader: - raise ValueError('Missing required header {}'.format(reqhdr)) + if not len(self._income_raw) >= header_len: + return + self.jsonheader = self._json_decode(self._income_raw[:header_len], "utf-8") + self._income_raw = self._income_raw[header_len:] + for reqhdr in ( + "content-length", + "content-type", + "content-encoding", + "message-type", + ): + if reqhdr not in self.jsonheader: + raise ValueError('Missing required header {}'.format(reqhdr)) def _process_content(self): content_len = self.jsonheader["content-length"] @@ -308,10 +330,10 @@ class MessageManager: else: self.content = data - def append_data(self, data): - self._income_raw += data + def set_buffer(self, data): + self._income_raw = memoryview(data) - def get_leftovers(self): + def get_buffer(self): return self._income_raw def process_message(self): @@ -326,8 +348,11 @@ class MessageManager: self._process_jsonheader() if self.jsonheader: - if self.content is None: + if not self.processed: self._process_content() + @property + def processed(self): + return self.content is not None class CallbackManager: def __init__(self): @@ -352,32 +377,33 @@ class CallbackManager: def request_callback(self, key): return self._register_function(self.request_callbacks, key) -class PeerProtocol: - def __init__(self, parent, calabacks): - self._parent = parent - self._callbacks = calabacks +class PeerProtocol(asyncio.Protocol): + def __init__(self, parent, callbacks): + # self._parent = parent + self._callbacks = callbacks + self._recv_buffer = bytearray() + self._current_msg = None self._recv_msg_queue = asyncio.Queue def connection_made(self, transport): - peername = transport.get_extra_info('peername') - - self._resend_requests() - self.transport = transport + peername = transport.get_extra_info('peername') + print(peername) + # self._resend_requests() + def data_received(self, data): + self._recv_buffer += data + def _proceess_received(self): while self._recv_buffer: # add new message object if queue is empty or last message already processed - if self._recv_msg_queue.empty() or (self._received_queue[0].content is not None): - self._recv_msg_queue.put(MessageManager()) + if self._current_msg is None: + self._current_msg = MessageManager() - last_message = self._received_queue[0] - - last_message.income_raw += self._recv_buffer - self._recv_buffer = b'' + #self._recv_buffer. last_message.process_message() # if something left after processing message - put it back @@ -393,7 +419,7 @@ class PeerProtocol: # logger.info("Closing connection to {}".format(self.addr)) -class ConnectionManager(object): +class ConnectionManager: """ This class represents high-level protocol of TCP connection. diff --git a/server/modules/server_core.py b/server/modules/server_core.py index 16ce14a..03e76fd 100644 --- a/server/modules/server_core.py +++ b/server/modules/server_core.py @@ -43,22 +43,26 @@ class Server: self.callbacks = messaging.CallbackManager() self._clients = dict() - self.host = socket.gethostname() self.ip = messaging.get_ip_address() self.config = ConfigManager() self.config_path = config_path - self._broadcast_send_task = None - self._broadcast_listen_task = None + self._broadcast_send_task: asyncio.Task = None + self._broadcast_listen_task: asyncio.Task = None + + self._stopping: asyncio.Future = None + self._stopped: asyncio.Future = None def load_config(self): self.config.load_config_and_spec(self.config_path) - async def start(self): + async def run(self, wait_closed: bool=False): loop = asyncio.get_event_loop() - loop.set_debug(True) + # loop.set_debug(True) + self._stopping = loop.create_future() + self._stopped = loop.create_future() self.time_started = time.time() @@ -69,23 +73,33 @@ class Server: self._broadcast_send_task = loop.create_task(self._broadcast_send()) if self.config.broadcast_listen: - self._broadcast_send_task = loop.create_task(self._broadcast_listen()) + self._broadcast_listen_task = loop.create_task(self._broadcast_listen()) logging.info(f"Starting server with id: {self.id} on {self.host} ({self.ip})") - self._tcp_server = await loop.create_server(asyncio.Protocol(), + self._tcp_server = await loop.create_server(messaging.PeerProtocol, port=self.config.server_port, reuse_address=True, start_serving=False) for sock in self._tcp_server.sockets: # to handle multiple interfaces (IPv4\IPv6) - self._configure_sock(sock) + self._configure_server_sock(sock) sockname = sock.getsockname() logging.info(f"Running server socket on {sockname[0]} : {sockname[1]}") logging.info("Starting serving") await self._tcp_server.start_serving() + if wait_closed: + await self._stopped - async def stop(self): + def serve_forever(self): + asyncio.run(self.run(wait_closed=True)) + + async def stop(self, reason: str=''): + if self._stopping.done(): + logging.error("Server is already stopping") + return + + self._stopping.set_result(True) logging.info("Stopping server") self._tcp_server.close() to_await = [self._tcp_server.wait_closed()] @@ -100,12 +114,16 @@ class Server: self._broadcast_listen_task.cancel() to_await.append(self._broadcast_listen_task) - await asyncio.gather(*to_await, return_exceptions=True) - logging.info("Server stopped") + await asyncio.gather(*to_await, return_exceptions=True) # wait until everything shuts down - def terminate(self, reason="Terminated"): - self.stop() + if not self._stopped.done(): + self._stopped.set_result(True) + + logging.info(f"Server stopped: {reason}") + + async def terminate(self, reason:str ="Terminated"): logging.critical(reason) + await self.stop(reason) def time_now(self): if self.config.ntp_use: @@ -114,7 +132,7 @@ class Server: return time.time() @staticmethod - def _configure_sock(sock): + def _configure_server_sock(sock): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) messaging.set_keepalive(sock) @@ -133,47 +151,54 @@ class Server: client.connect(self.sel, conn, addr) async def _broadcast_send(self): - logging.info("Broadcast sender task started!") + logging.info("Broadcast sender task started") msg = messaging.MessageManager.create_action_message( "server_ip", kwargs={"host": self.ip, "port": self.config.server_port, "id": self.id, "start_time": self.time_started}) - await asyncio.sleep(0) logging.debug( f"Formed broadcast message to {self.config.broadcast_send_ip}:{self.config.broadcast_port}: {msg}") broadcast_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) broadcast_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) broadcast_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - while True: - try: - await asyncio.sleep(self.config.broadcast_delay) - broadcast_sock.sendto(msg, (self.config.broadcast_send_ip, self.config.broadcast_port)) - except OSError as e: - logging.error(f"Cannot send broadcast due error {e}") - except asyncio.CancelledError: - print("cans") - raise - else: - logging.debug("Broadcast sent") + try: + while True: + try: + await asyncio.sleep(self.config.broadcast_delay) + broadcast_sock.sendto(msg, (self.config.broadcast_send_ip, self.config.broadcast_port)) + except OSError as e: + logging.error(f"Cannot send broadcast due error {e}") + else: + logging.debug("Broadcast sent") + finally: + logging.info("Broadcast sender task stopped") async def _broadcast_listen(self): logging.info("Broadcast listener task started!") - broadcast_client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - broadcast_client.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + + def broadcast_callback(message: messaging.MessageManager): + content = message.content + different_id = content["kwargs"]["id"] != str(self.id) + self_younger = float(content["kwargs"]["start_time"]) <= self.time_started + if different_id and self_younger: + loop.run_until_complete( + self.terminate("Another server is running on this local network, shutting down!")) loop = asyncio.get_event_loop() - transport, protocol = await loop.create_datagram_endpoint(messaging.BroadcastProtocol, - local_addr=('', self.config.broadcast_port)) - # broadcast_client.settimeout(1) - # try: - # broadcast_client.bind(("", self.config.broadcast_port)) - # except OSError: - # self.terminate("Another server is running on this computer, shutting down!") - # return - # - # finally: - # broadcast_client.close() - # logging.info("Broadcast listener thread stopped, socked closed!") + try: + transport, protocol = await loop.create_datagram_endpoint(lambda: messaging.BroadcastProtocol(broadcast_callback), + local_addr=('', self.config.broadcast_port), + family=socket.AF_INET) + except OSError: + logging.info("Broadcast listener exited: port is busy") + loop.run_until_complete(self.terminate("Another server is likely running on this computer, shutting down!")) + return + + try: + await protocol.closed + finally: + transport.close() + logging.info("Broadcast listener task stopped") def send_starttime(self, copter, start_time): copter.send_message("start", kwargs={"time": str(start_time)}) @@ -196,6 +221,8 @@ def requires_any_connected(f): return wrapper +class RemoteClientProtocol(messaging.PeerProtocol): + class Client(messaging.ConnectionManager): clients = {} @@ -304,13 +331,16 @@ if __name__ == '__main__': print(loop) try: - loop.run_until_complete(server.start()) - loop.run_until_complete(asyncio.sleep(4)) - loop.run_until_complete(server.stop()) + loop.run_until_complete(server.run()) + #loop.run_until_complete(asyncio.sleep(4)) + #loop.run_until_complete(server.stop()) + loop.run_forever() print(3) finally: + loop.run_until_complete(server.stop("hey")) + print(1) print(4) - #loop.run_forever() + #