started asyncio refactor

This commit is contained in:
Artem30801
2020-10-26 13:13:11 +03:00
parent ca8e960318
commit d4d3a4f45f
3 changed files with 187 additions and 359 deletions

View File

@@ -194,7 +194,7 @@ class Client(object):
attempt_count += 1
else:
message = messaging.MessageManager()
message.income_raw = data
message._income_raw = data
message.process_message()
if message.content and message.jsonheader["action"] == "server_ip":
logger.info("Received broadcast message {} from {}".format(message.content, addr))

View File

@@ -8,11 +8,11 @@ import json
import socket
import struct
import random
import asyncio
import logging
import threading
import collections
import platform
import traceback
import collections
from contextlib import closing
@@ -55,6 +55,15 @@ def get_ip_address():
return "localhost"
def get_ntp_time(ntp_host, ntp_port):
NTP_DELTA = 2208988800 # 1970-01-01 00:00:00
NTP_QUERY = b'\x1b' + bytes(47)
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as ntp_socket:
ntp_socket.sendto(NTP_QUERY, (ntp_host, ntp_port))
msg, _ = ntp_socket.recvfrom(1024)
return int.from_bytes(msg[-8:], 'big') / 2 ** 32 - NTP_DELTA
def set_keepalive(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
"""
Sets `keepalive` parameters of given socket.
@@ -94,25 +103,31 @@ def _set_keepalive_osx(sock, interval_sec):
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
class _Singleton(type):
"""
A metaclass that creates a Singleton base class when called.
"""
_instances = {}
class BroadcastProtocol:
def __init__(self, on_broadcast):
self._on_broadcast = on_broadcast
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(_Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
def datagram_received(self, data, addr):
message = MessageManager()
message.append_data(data)
message.process_message()
content = message.content
is_ip_broadcast = (content is not None and message.jsonheader["action"] == "server_ip")
class Singleton(_Singleton('SingletonMeta', (object,), {})):
"""
Singleton base class.
"""
if is_ip_broadcast:
asyncio.get_event_loop().call_soon(self._on_broadcast, message)
# different_id = content["kwargs"]["id"] != str(self.id)
# self_younger = float(content["kwargs"]["start_time"]) <= self.time_started
else:
logging.warning("Got wrong broadcast message from {}".format(addr))
# def error_received(self, exc):
# logging.error("Error received")
class BroadcastSendProtocol:
pass
class MessageManager:
"""
MessageManager class represents single incoming by TCP stream message and contains methods to decode and extract data from incoming data. It also contains static class methods for encoding various types of messages.
@@ -125,7 +140,7 @@ class MessageManager:
Attributes:
income_raw (bytes string): Holds incoming data bytes. Append incoming data to this attribute. It may not be empty after processing.
_income_raw (bytes string): Holds incoming data bytes. Append incoming data to this attribute. It may not be empty after processing.
jsonheader (dict): Headers dictionary with information about message encoding and purpose. Would be populated when receiving and processing of the json header will be completed.
content (object): Would be populated when receiving and processing of the message will be completed. Defaults to None.
"""
@@ -135,7 +150,7 @@ class MessageManager:
message = MessageManager()
```
"""
self.income_raw = b""
self._income_raw = b"" #todo bytearrray
self._jsonheader_len = None
self.jsonheader = None
self.content = None
@@ -262,15 +277,15 @@ class MessageManager:
def _process_protoheader(self):
header_len = 2
if len(self.income_raw) >= header_len:
self._jsonheader_len = struct.unpack(">H", self.income_raw[:header_len])[0]
self.income_raw = self.income_raw[header_len:]
if len(self._income_raw) >= header_len:
self._jsonheader_len = struct.unpack(">H", self._income_raw[:header_len])[0]
self._income_raw = self._income_raw[header_len:]
def _process_jsonheader(self):
header_len = self._jsonheader_len
if len(self.income_raw) >= header_len:
self.jsonheader = self._json_decode(self.income_raw[:header_len], "utf-8")
self.income_raw = self.income_raw[header_len:]
if len(self._income_raw) >= header_len:
self.jsonheader = self._json_decode(self._income_raw[:header_len], "utf-8")
self._income_raw = self._income_raw[header_len:]
for reqhdr in (
"byteorder",
"content-length",
@@ -283,16 +298,22 @@ class MessageManager:
def _process_content(self):
content_len = self.jsonheader["content-length"]
if not len(self.income_raw) >= content_len:
if not len(self._income_raw) >= content_len:
return
data = self.income_raw[:content_len]
self.income_raw = self.income_raw[content_len:]
data = self._income_raw[:content_len]
self._income_raw = self._income_raw[content_len:]
if self.jsonheader["content-type"] == "json":
encoding = self.jsonheader["content-encoding"]
self.content = self._json_decode(data, encoding)
else:
self.content = data
def append_data(self, data):
self._income_raw += data
def get_leftovers(self):
return self._income_raw
def process_message(self):
"""
Attempts processing the message. Chunks of `income_raw` would be consumed as different parts of the message will be processed. The result of processing (body of the message) will be available at `content` and `jsonheader`.
@@ -331,6 +352,46 @@ class CallbackManager:
def request_callback(self, key):
return self._register_function(self.request_callbacks, key)
class PeerProtocol:
def __init__(self, parent, calabacks):
self._parent = parent
self._callbacks = calabacks
self._recv_msg_queue = asyncio.Queue
def connection_made(self, transport):
peername = transport.get_extra_info('peername')
self._resend_requests()
self.transport = transport
def data_received(self, data):
while self._recv_buffer:
# add new message object if queue is empty or last message already processed
if self._recv_msg_queue.empty() or (self._received_queue[0].content is not None):
self._recv_msg_queue.put(MessageManager())
last_message = self._received_queue[0]
last_message.income_raw += self._recv_buffer
self._recv_buffer = b''
last_message.process_message()
# if something left after processing message - put it back
if last_message.content is not None and last_message.income_raw:
self._recv_buffer = last_message.income_raw + self._recv_buffer
last_message.income_raw = b''
if self._received_queue and last_message.content is not None:
self.process_received(self._received_queue.popleft())
def connection_lost(self, exc):
pass
# logger.info("Closing connection to {}".format(self.addr))
class ConnectionManager(object):
"""
@@ -370,50 +431,6 @@ class ConnectionManager(object):
self.whoami = whoami
self._send_queue = collections.deque()
self._received_queue = collections.deque()
self._request_queue = collections.OrderedDict()
self._send_lock = threading.Lock()
self._request_lock = threading.Lock()
self._close_lock = threading.Lock()
self.buffer_size = 1024
self.resume_queue = False
self.resend_requests = True
def _set_selector_events_mask(self, mode):
"""Set selector to listen for events: mode is 'r', 'w', 'rw'."""
if mode == "r":
events = selectors.EVENT_READ
elif mode == "w":
events = selectors.EVENT_WRITE
elif mode == "rw":
events = selectors.EVENT_READ | selectors.EVENT_WRITE
else:
raise ValueError("Invalid events mask mode {}.".format(mode))
key = self.selector.modify(self.socket, events, data=self)
logger.debug("Switched selector of {} to mode {}".format(self.addr, key.events))
return key
def connect(self, client_selector, client_socket, client_addr):
"""[summary]
Args:
client_selector (selector): Related selector object.
client_socket (socket): Socket object of the connection.
client_addr (str): Address of the peer.
"""
self.selector = client_selector
self.socket = client_socket
self.addr = client_addr
self._clear()
self._set_selector_events_mask('r')
if self.resend_requests:
self._resend_requests()
def _clear(self):
if not self.resume_queue: # maybe needs locks
@@ -422,82 +439,11 @@ class ConnectionManager(object):
self._received_queue.clear()
self._send_queue.clear()
def close(self):
"""
Closes connection with the peer.
"""
with self._close_lock:
self._should_close = True
self._set_selector_events_mask('w')
NotifierSock().notify()
def _close(self):
logger.info("Closing connection to {}".format(self.addr))
try:
logger.info("Unregistering selector of {}".format(self.addr))
self.selector.unregister(self.socket)
except AttributeError:
pass
except Exception as error:
logger.error("{}: Error during selector unregistering: {}".format(self.addr, error))
finally:
self.selector = None
try:
logger.info("Closing socket of of {}".format(self.addr))
self.socket.close()
except AttributeError:
pass
except OSError as error:
logger.error("{}: Error during socket closing: {}".format(self.addr, error))
finally:
self.socket = None
with self._close_lock:
self._should_close = False
self._clear()
logger.info("CLOSED connection to {}".format(self.addr))
def process_events(self, mask):
"""Processes read/write events with given mask.
Args:
mask (bytes): mask of the selector events.
"""
with self._close_lock:
close = self._should_close
if close:
self._close()
return
if mask & selectors.EVENT_READ:
self.read()
if mask & selectors.EVENT_WRITE:
self.write()
def read(self):
self._read()
while self._recv_buffer:
# add new message object if queue is empty or last message already processed
if not self._received_queue or (self._received_queue[0].content is not None):
self._received_queue.appendleft(MessageManager())
last_message = self._received_queue[0]
last_message.income_raw += self._recv_buffer
self._recv_buffer = b''
last_message.process_message()
# if something left after processing message - put it back
if last_message.content is not None and last_message.income_raw:
self._recv_buffer = last_message.income_raw + self._recv_buffer
last_message.income_raw = b''
if self._received_queue and last_message.content is not None:
self.process_received(self._received_queue.popleft())
def _read(self):
try:
@@ -757,57 +703,4 @@ class ConnectionManager(object):
else:
logger.info("Sending file {} to {} (as: {})".format(filepath, self.addr, dest_filepath))
self._send(MessageManager.create_message(data, "binary", "message",
additional_headers={"action": "filetransfer", "filepath": dest_filepath}))
class NotifierSock(Singleton):
def __init__(self):
self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self._server_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self._sending_sock = socket.socket()
self._send_lock = threading.Lock()
self._receiving_sock = None
def init(self, selector, port=26000):
port += random.randint(0, 100) # local testing fix
self._server_socket.bind(('', port))
self._server_socket.listen(1)
self._sending_sock.connect(('127.0.0.1', port))
self._receiving_sock, _ = self._server_socket.accept()
logger.info("Notify socket: connected")
selector.register(self._receiving_sock, selectors.EVENT_READ, data=self)
logger.info("Notify socket: selector registered")
def get_sock(self):
return self._receiving_sock
def notify(self):
with self._send_lock:
if self._receiving_sock is None:
return
self._sending_sock.sendall(bytes(1))
logger.debug("Notify socket: notified")
def process_events(self, mask):
if mask & selectors.EVENT_READ and self._receiving_sock is not None:
try:
self._receiving_sock.recv(1024)
logger.debug("Notify socket: received")
except io.BlockingIOError:
pass
except Exception as e:
logger.error(e)
def close(self):
try:
self._server_socket.close()
self._sending_sock.close()
self._receiving_sock.close()
except (OSError, KeyError) as error:
logger.error("Error during unregistring notifier socket: {}".format(error))
additional_headers={"action": "filetransfer", "filepath": dest_filepath}))

View File

@@ -1,12 +1,12 @@
import os
import sys
import time
import socket
import asyncio
import random
import logging
import datetime
import threading
import selectors
import collections
import traceback
@@ -33,144 +33,93 @@ if not os.path.exists(log_path):
logger = logging.getLogger(__name__)
class Server(messaging.Singleton):
class Server:
def __init__(self, config_path=os.path.join(current_dir, os.pardir, "config", "server.ini"), server_id=None):
self.id = server_id if server_id else str(random.randint(0, 9999)).zfill(4)
self.time_started = 0
# Init socket
self.sel = selectors.DefaultSelector()
self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
messaging.set_keepalive(self.server_socket)
self.server_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self._tcp_server = None
self.callbacks = messaging.CallbackManager()
self._clients = dict()
self.host = socket.gethostname()
self.ip = messaging.get_ip_address()
# Init configs
self.config = ConfigManager()
self.config_path = config_path
self.callbacks = messaging.CallbackManager()
# Init threads
self.autoconnect_thread = threading.Thread(target=self._client_processor, daemon=True,
name='Client processor')
self.client_processor_thread_running = threading.Event() # Can be used for manual thread killing
self.broadcast_thread = threading.Thread(target=self._ip_broadcast, daemon=True,
name='IP broadcast sender')
self.broadcast_thread_running = threading.Event() # TODO replace by interrupt
self.broadcast_thread_interrupt = threading.Event()
self.listener_thread = threading.Thread(target=self._broadcast_listen, daemon=True,
name='IP broadcast listener')
self.listener_thread_running = threading.Event()
self._broadcast_send_task = None
self._broadcast_listen_task = None
def load_config(self):
self.config.load_config_and_spec(self.config_path)
def start(self):
# load config on startup
self.load_config()
async def start(self):
loop = asyncio.get_event_loop()
loop.set_debug(True)
self.time_started = time.time()
logging.info("Starting server with id: {} on {}:{} ({})!".format(self.id, self.ip, self.config.server_port,
socket.gethostname()))
logging.info("Binding server socket!")
self.server_socket.bind((self.ip, self.config.server_port))
logging.info("Starting client processor thread!")
self.client_processor_thread_running.set()
self.autoconnect_thread.start()
# load config on startup
self.load_config() # TODO async
if self.config.broadcast_send:
logging.info("Starting broadcast sender thread!")
self.broadcast_thread_running.set()
self.broadcast_thread.start()
self._broadcast_send_task = loop.create_task(self._broadcast_send())
if self.config.broadcast_listen:
logging.info("Starting broadcast listener thread!")
self.listener_thread_running.set()
self.listener_thread.start()
self._broadcast_send_task = loop.create_task(self._broadcast_listen())
def stop(self):
logging.info(f"Starting server with id: {self.id} on {self.host} ({self.ip})")
self._tcp_server = await loop.create_server(asyncio.Protocol(),
port=self.config.server_port,
reuse_address=True,
start_serving=False)
for sock in self._tcp_server.sockets: # to handle multiple interfaces (IPv4\IPv6)
self._configure_sock(sock)
sockname = sock.getsockname()
logging.info(f"Running server socket on {sockname[0]} : {sockname[1]}")
logging.info("Starting serving")
await self._tcp_server.start_serving()
async def stop(self):
logging.info("Stopping server")
self._tcp_server.close()
to_await = [self._tcp_server.wait_closed()]
self.client_processor_thread_running.clear()
if self._broadcast_send_task is not None:
logging.info("Cancelling broadcast sending")
self._broadcast_send_task.cancel()
to_await.append(self._broadcast_send_task)
self.broadcast_thread_interrupt.set()
self.broadcast_thread_running.clear()
self.listener_thread_running.clear()
messaging.NotifierSock().notify()
self.server_socket.close()
self.sel.close()
messaging.NotifierSock().close()
if self._broadcast_listen_task is not None:
logging.info("Cancelling broadcast listening")
self._broadcast_listen_task.cancel()
to_await.append(self._broadcast_listen_task)
await asyncio.gather(*to_await, return_exceptions=True)
logging.info("Server stopped")
def terminate(self, reason="Terminated"):
self.stop()
logging.critical(reason)
@staticmethod
def get_ntp_time(ntp_host, ntp_port):
NTP_DELTA = 2208988800 # 1970-01-01 00:00:00
NTP_QUERY = b'\x1b' + bytes(47)
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as ntp_socket:
ntp_socket.sendto(NTP_QUERY, (ntp_host, ntp_port))
msg, _ = ntp_socket.recvfrom(1024)
return int.from_bytes(msg[-8:], 'big') / 2 ** 32 - NTP_DELTA
def time_now(self):
if self.config.ntp_use:
return self.get_ntp_time(self.config.ntp_host, self.config.ntp_port)
return messaging.get_ntp_time(self.config.ntp_host, self.config.ntp_port)
return time.time()
# noinspection PyArgumentList
def _client_processor(self):
logging.info("Client processor (selector) thread started!")
messaging.NotifierSock().init(self.sel)
self.server_socket.listen()
self.server_socket.setblocking(False)
self.sel.register(self.server_socket, selectors.EVENT_READ, data=None)
while self.client_processor_thread_running.is_set():
events = self.sel.select(timeout=1)
for key, mask in events:
client = key.data
if client is None:
self._connect_client(key.fileobj)
elif isinstance(client, messaging.ConnectionManager):
try:
client.process_events(mask)
except Exception as error:
logging.error("Exception {} occurred for {}! Resetting connection!".format(error, client.addr))
traceback.print_exc()
client.close(True)
else: # Notifier
client.process_events(mask)
logging.info("Client autoconnect thread stopped!")
@staticmethod
def _configure_sock(sock):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
messaging.set_keepalive(sock)
def _connect_client(self, sock):
try:
conn, addr = sock.accept()
except OSError:
logging.error("Error while connecting socket!")
return
logging.info("Got connection from: {}".format(str(addr)))
conn.setblocking(False)
if not any(client_addr == addr[0] for client_addr in Client.clients.keys()):
client = Client(self.callbacks, addr[0])
@@ -183,74 +132,48 @@ class Server(messaging.Singleton):
self.sel.register(conn, selectors.EVENT_READ, data=client)
client.connect(self.sel, conn, addr)
def _ip_broadcast(self):
logging.info("Broadcast sender thread started!")
async def _broadcast_send(self):
logging.info("Broadcast sender task started!")
msg = messaging.MessageManager.create_action_message(
"server_ip", kwargs={"host": self.ip, "port": str(self.config.server_port), "id": self.id,
"start_time": str(self.time_started)})
logging.debug("Formed broadcast message to {}:{}: {}".format(self.config.broadcast_send_ip, self.config.broadcast_port, msg))
"server_ip", kwargs={"host": self.ip, "port": self.config.server_port,
"id": self.id, "start_time": self.time_started})
await asyncio.sleep(0)
logging.debug(
f"Formed broadcast message to {self.config.broadcast_send_ip}:{self.config.broadcast_port}: {msg}")
broadcast_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
broadcast_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
broadcast_sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
while True:
try:
await asyncio.sleep(self.config.broadcast_delay)
broadcast_sock.sendto(msg, (self.config.broadcast_send_ip, self.config.broadcast_port))
except OSError as e:
logging.error(f"Cannot send broadcast due error {e}")
except asyncio.CancelledError:
print("cans")
raise
else:
logging.debug("Broadcast sent")
try:
while self.broadcast_thread_running.is_set():
self.broadcast_thread_interrupt.wait(timeout=self.config.broadcast_delay)
try:
broadcast_sock.sendto(msg, (self.config.broadcast_send_ip, self.config.broadcast_port))
except OSError as e:
logging.error(f"Cannot send broadcast due error {e}")
else:
logging.debug("Broadcast sent")
except Exception as e:
logging.error(f"Unexpected error {e}!")
raise
def _broadcast_listen(self):
logging.info("Broadcast listener thread started!")
async def _broadcast_listen(self):
logging.info("Broadcast listener task started!")
broadcast_client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
broadcast_client.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(messaging.BroadcastProtocol,
local_addr=('', self.config.broadcast_port))
# broadcast_client.settimeout(1)
try:
broadcast_client.bind(("", self.config.broadcast_port))
except OSError:
self.terminate("Another server is running on this computer, shutting down!")
return
try:
while self.listener_thread_running.is_set():
try:
data, addr = broadcast_client.recvfrom(1024) # TODO nonblock
except OSError:
logging.error(f"Cannot receive broadcast due error {e}")
continue
message = messaging.MessageManager()
message.income_raw = data
message.process_message()
content = message.content
right_command = (content and message.jsonheader["action"] == "server_ip")
if right_command:
different_id = content["kwargs"]["id"] != str(self.id)
self_younger = float(content["kwargs"]["start_time"]) <= self.time_started
if different_id and self_younger:
# younger server should shut down
self.terminate("Another server detected over the network, shutting down!")
else:
logging.warning("Got wrong broadcast message from {}".format(addr))
except Exception as e:
logging.error(f"Unexpected error {e}!")
raise
finally:
broadcast_client.close()
logging.info("Broadcast listener thread stopped, socked closed!")
# try:
# broadcast_client.bind(("", self.config.broadcast_port))
# except OSError:
# self.terminate("Another server is running on this computer, shutting down!")
# return
#
# finally:
# broadcast_client.close()
# logging.info("Broadcast listener thread stopped, socked closed!")
def send_starttime(self, copter, start_time):
copter.send_message("start", kwargs={"time": str(start_time)})
@@ -273,6 +196,7 @@ def requires_any_connected(f):
return wrapper
class Client(messaging.ConnectionManager):
clients = {}
@@ -303,7 +227,7 @@ class Client(messaging.ConnectionManager):
self.connected = True
#if self.copter_id is None:
# if self.copter_id is None:
self.get_response("id", self._got_id)
if self.on_connect:
@@ -375,7 +299,18 @@ if __name__ == '__main__':
])
server = Server()
server.start()
loop = asyncio.get_event_loop()
loop.set_debug(True)
print(loop)
try:
loop.run_until_complete(server.start())
loop.run_until_complete(asyncio.sleep(4))
loop.run_until_complete(server.stop())
print(3)
finally:
print(1)
print(4)
#loop.run_forever()
while True:
pass