started asyncio refactor

This commit is contained in:
Artem30801
2020-10-26 13:13:11 +03:00
parent ca8e960318
commit d4d3a4f45f
3 changed files with 187 additions and 359 deletions

View File

@@ -8,11 +8,11 @@ import json
import socket
import struct
import random
import asyncio
import logging
import threading
import collections
import platform
import traceback
import collections
from contextlib import closing
@@ -55,6 +55,15 @@ def get_ip_address():
return "localhost"
def get_ntp_time(ntp_host, ntp_port):
NTP_DELTA = 2208988800 # 1970-01-01 00:00:00
NTP_QUERY = b'\x1b' + bytes(47)
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as ntp_socket:
ntp_socket.sendto(NTP_QUERY, (ntp_host, ntp_port))
msg, _ = ntp_socket.recvfrom(1024)
return int.from_bytes(msg[-8:], 'big') / 2 ** 32 - NTP_DELTA
def set_keepalive(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
"""
Sets `keepalive` parameters of given socket.
@@ -94,25 +103,31 @@ def _set_keepalive_osx(sock, interval_sec):
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
class _Singleton(type):
"""
A metaclass that creates a Singleton base class when called.
"""
_instances = {}
class BroadcastProtocol:
def __init__(self, on_broadcast):
self._on_broadcast = on_broadcast
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(_Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
def datagram_received(self, data, addr):
message = MessageManager()
message.append_data(data)
message.process_message()
content = message.content
is_ip_broadcast = (content is not None and message.jsonheader["action"] == "server_ip")
class Singleton(_Singleton('SingletonMeta', (object,), {})):
"""
Singleton base class.
"""
if is_ip_broadcast:
asyncio.get_event_loop().call_soon(self._on_broadcast, message)
# different_id = content["kwargs"]["id"] != str(self.id)
# self_younger = float(content["kwargs"]["start_time"]) <= self.time_started
else:
logging.warning("Got wrong broadcast message from {}".format(addr))
# def error_received(self, exc):
# logging.error("Error received")
class BroadcastSendProtocol:
pass
class MessageManager:
"""
MessageManager class represents single incoming by TCP stream message and contains methods to decode and extract data from incoming data. It also contains static class methods for encoding various types of messages.
@@ -125,7 +140,7 @@ class MessageManager:
Attributes:
income_raw (bytes string): Holds incoming data bytes. Append incoming data to this attribute. It may not be empty after processing.
_income_raw (bytes string): Holds incoming data bytes. Append incoming data to this attribute. It may not be empty after processing.
jsonheader (dict): Headers dictionary with information about message encoding and purpose. Would be populated when receiving and processing of the json header will be completed.
content (object): Would be populated when receiving and processing of the message will be completed. Defaults to None.
"""
@@ -135,7 +150,7 @@ class MessageManager:
message = MessageManager()
```
"""
self.income_raw = b""
self._income_raw = b"" #todo bytearrray
self._jsonheader_len = None
self.jsonheader = None
self.content = None
@@ -262,15 +277,15 @@ class MessageManager:
def _process_protoheader(self):
header_len = 2
if len(self.income_raw) >= header_len:
self._jsonheader_len = struct.unpack(">H", self.income_raw[:header_len])[0]
self.income_raw = self.income_raw[header_len:]
if len(self._income_raw) >= header_len:
self._jsonheader_len = struct.unpack(">H", self._income_raw[:header_len])[0]
self._income_raw = self._income_raw[header_len:]
def _process_jsonheader(self):
header_len = self._jsonheader_len
if len(self.income_raw) >= header_len:
self.jsonheader = self._json_decode(self.income_raw[:header_len], "utf-8")
self.income_raw = self.income_raw[header_len:]
if len(self._income_raw) >= header_len:
self.jsonheader = self._json_decode(self._income_raw[:header_len], "utf-8")
self._income_raw = self._income_raw[header_len:]
for reqhdr in (
"byteorder",
"content-length",
@@ -283,16 +298,22 @@ class MessageManager:
def _process_content(self):
content_len = self.jsonheader["content-length"]
if not len(self.income_raw) >= content_len:
if not len(self._income_raw) >= content_len:
return
data = self.income_raw[:content_len]
self.income_raw = self.income_raw[content_len:]
data = self._income_raw[:content_len]
self._income_raw = self._income_raw[content_len:]
if self.jsonheader["content-type"] == "json":
encoding = self.jsonheader["content-encoding"]
self.content = self._json_decode(data, encoding)
else:
self.content = data
def append_data(self, data):
self._income_raw += data
def get_leftovers(self):
return self._income_raw
def process_message(self):
"""
Attempts processing the message. Chunks of `income_raw` would be consumed as different parts of the message will be processed. The result of processing (body of the message) will be available at `content` and `jsonheader`.
@@ -331,6 +352,46 @@ class CallbackManager:
def request_callback(self, key):
return self._register_function(self.request_callbacks, key)
class PeerProtocol:
def __init__(self, parent, calabacks):
self._parent = parent
self._callbacks = calabacks
self._recv_msg_queue = asyncio.Queue
def connection_made(self, transport):
peername = transport.get_extra_info('peername')
self._resend_requests()
self.transport = transport
def data_received(self, data):
while self._recv_buffer:
# add new message object if queue is empty or last message already processed
if self._recv_msg_queue.empty() or (self._received_queue[0].content is not None):
self._recv_msg_queue.put(MessageManager())
last_message = self._received_queue[0]
last_message.income_raw += self._recv_buffer
self._recv_buffer = b''
last_message.process_message()
# if something left after processing message - put it back
if last_message.content is not None and last_message.income_raw:
self._recv_buffer = last_message.income_raw + self._recv_buffer
last_message.income_raw = b''
if self._received_queue and last_message.content is not None:
self.process_received(self._received_queue.popleft())
def connection_lost(self, exc):
pass
# logger.info("Closing connection to {}".format(self.addr))
class ConnectionManager(object):
"""
@@ -370,50 +431,6 @@ class ConnectionManager(object):
self.whoami = whoami
self._send_queue = collections.deque()
self._received_queue = collections.deque()
self._request_queue = collections.OrderedDict()
self._send_lock = threading.Lock()
self._request_lock = threading.Lock()
self._close_lock = threading.Lock()
self.buffer_size = 1024
self.resume_queue = False
self.resend_requests = True
def _set_selector_events_mask(self, mode):
"""Set selector to listen for events: mode is 'r', 'w', 'rw'."""
if mode == "r":
events = selectors.EVENT_READ
elif mode == "w":
events = selectors.EVENT_WRITE
elif mode == "rw":
events = selectors.EVENT_READ | selectors.EVENT_WRITE
else:
raise ValueError("Invalid events mask mode {}.".format(mode))
key = self.selector.modify(self.socket, events, data=self)
logger.debug("Switched selector of {} to mode {}".format(self.addr, key.events))
return key
def connect(self, client_selector, client_socket, client_addr):
"""[summary]
Args:
client_selector (selector): Related selector object.
client_socket (socket): Socket object of the connection.
client_addr (str): Address of the peer.
"""
self.selector = client_selector
self.socket = client_socket
self.addr = client_addr
self._clear()
self._set_selector_events_mask('r')
if self.resend_requests:
self._resend_requests()
def _clear(self):
if not self.resume_queue: # maybe needs locks
@@ -422,82 +439,11 @@ class ConnectionManager(object):
self._received_queue.clear()
self._send_queue.clear()
def close(self):
"""
Closes connection with the peer.
"""
with self._close_lock:
self._should_close = True
self._set_selector_events_mask('w')
NotifierSock().notify()
def _close(self):
logger.info("Closing connection to {}".format(self.addr))
try:
logger.info("Unregistering selector of {}".format(self.addr))
self.selector.unregister(self.socket)
except AttributeError:
pass
except Exception as error:
logger.error("{}: Error during selector unregistering: {}".format(self.addr, error))
finally:
self.selector = None
try:
logger.info("Closing socket of of {}".format(self.addr))
self.socket.close()
except AttributeError:
pass
except OSError as error:
logger.error("{}: Error during socket closing: {}".format(self.addr, error))
finally:
self.socket = None
with self._close_lock:
self._should_close = False
self._clear()
logger.info("CLOSED connection to {}".format(self.addr))
def process_events(self, mask):
"""Processes read/write events with given mask.
Args:
mask (bytes): mask of the selector events.
"""
with self._close_lock:
close = self._should_close
if close:
self._close()
return
if mask & selectors.EVENT_READ:
self.read()
if mask & selectors.EVENT_WRITE:
self.write()
def read(self):
self._read()
while self._recv_buffer:
# add new message object if queue is empty or last message already processed
if not self._received_queue or (self._received_queue[0].content is not None):
self._received_queue.appendleft(MessageManager())
last_message = self._received_queue[0]
last_message.income_raw += self._recv_buffer
self._recv_buffer = b''
last_message.process_message()
# if something left after processing message - put it back
if last_message.content is not None and last_message.income_raw:
self._recv_buffer = last_message.income_raw + self._recv_buffer
last_message.income_raw = b''
if self._received_queue and last_message.content is not None:
self.process_received(self._received_queue.popleft())
def _read(self):
try:
@@ -757,57 +703,4 @@ class ConnectionManager(object):
else:
logger.info("Sending file {} to {} (as: {})".format(filepath, self.addr, dest_filepath))
self._send(MessageManager.create_message(data, "binary", "message",
additional_headers={"action": "filetransfer", "filepath": dest_filepath}))
class NotifierSock(Singleton):
def __init__(self):
self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self._server_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self._sending_sock = socket.socket()
self._send_lock = threading.Lock()
self._receiving_sock = None
def init(self, selector, port=26000):
port += random.randint(0, 100) # local testing fix
self._server_socket.bind(('', port))
self._server_socket.listen(1)
self._sending_sock.connect(('127.0.0.1', port))
self._receiving_sock, _ = self._server_socket.accept()
logger.info("Notify socket: connected")
selector.register(self._receiving_sock, selectors.EVENT_READ, data=self)
logger.info("Notify socket: selector registered")
def get_sock(self):
return self._receiving_sock
def notify(self):
with self._send_lock:
if self._receiving_sock is None:
return
self._sending_sock.sendall(bytes(1))
logger.debug("Notify socket: notified")
def process_events(self, mask):
if mask & selectors.EVENT_READ and self._receiving_sock is not None:
try:
self._receiving_sock.recv(1024)
logger.debug("Notify socket: received")
except io.BlockingIOError:
pass
except Exception as e:
logger.error(e)
def close(self):
try:
self._server_socket.close()
self._sending_sock.close()
self._receiving_sock.close()
except (OSError, KeyError) as error:
logger.error("Error during unregistring notifier socket: {}".format(error))
additional_headers={"action": "filetransfer", "filepath": dest_filepath}))