Finished up reconnection handling, processing of recieved data, improved API

This commit is contained in:
Artem30801
2020-11-07 01:09:35 +03:00
parent 8ea73f8ec9
commit 11c6953041
4 changed files with 170 additions and 66 deletions

View File

@@ -24,9 +24,7 @@ class ServerPeer(messaging.PeerProtocol):
def connection_lost(self, exc):
super().connection_lost(exc)
if exc is not None:
loop = asyncio.get_event_loop()
loop.call_soon(self._parent.reconnect)
self._parent._server_connection = None
self._parent.reconnect()
class Client:
"""
@@ -53,7 +51,7 @@ class Client:
self.callbacks = messaging.CallbackManager()
self._server_connection: messaging.PeerProtocol = None
self._server_connection: ServerPeer = None
self.client_id = None
@@ -68,7 +66,9 @@ class Client:
@property
def connected(self):
return self._server_connection is not None
if self._server_connection is None:
return False
return self._server_connection.is_connected
def load_config(self):
"""
@@ -102,7 +102,7 @@ class Client:
return timenow
def serve_forever(self):
asyncio.run(self.run(serve_forever=True))
asyncio.run(self.run(serve_forever=True), debug=True)
async def run(self, serve_forever=False):
"""
@@ -116,15 +116,22 @@ class Client:
self.load_config()
self.register_callbacks()
self._server_connection = ServerPeer(self, self.callbacks)
logger.info(f"Starting client with id: '{self.client_id}' on '{socket.gethostname()}'"
f" ({messaging.get_ip_address()})")
self.reconnect()
await self._server_connection.connected
self._server_connection.send_message("hi", args=(1, 2), kwargs={'k': 4})
if serve_forever:
await self._stopped
async def stop(self, reason: str=''):
if self._stopped.done():
logging.error("Client is already stopped")
return
if self._stopping.done():
logging.error("Client is already stopping")
return
@@ -151,10 +158,9 @@ class Client:
logger.warning("Reconnection task is already running")
logger.info("Starting reconnection task")
loop = asyncio.get_event_loop()
self._reconnect_task = loop.create_task(self._reconnect()) # todo args
self._reconnect_task = asyncio.create_task(self._reconnect()) # todo args
async def _reconnect(self, attempt_limit=3, timeout=20):
async def _reconnect(self, attempt_limit=3, timeout=10):
logger.info(f"Reconnection task started")
try:
@@ -166,29 +172,27 @@ class Client:
if self.connected:
return
else:
logger.info("Too many attempts. Trying to get new server IP")
logger.warning("Too many attempts. Trying to get new server IP")
await self._broadcast_listen(timeout)
except Exception as e:
print(e, 1, e.args, e, type(e))
finally:
logging.info("Reconnection task stopped")
self._reconnect_task = None
logger.info("Reconnection task stopped")
async def _connect(self):
loop = asyncio.get_event_loop()
try:
transport, protocol = await loop.create_connection(lambda: ServerPeer(self, self.callbacks),
transport, protocol = await loop.create_connection(lambda: self._server_connection,
host=self.config.server_host,
port=self.config.server_port,
)
except OSError as e:
logger.error(f"Cannot connect to server due error: {e}")
self._server_connection = None
else:
logger.info("Connection to server successful!")
messaging.set_keepalive(transport.get_extra_info('socket'))
self._server_connection = protocol
logger.info("Connection to server successful")
async def _broadcast_listen(self, listen_timeout=None):
logging.info("Broadcast listener started")
@@ -291,7 +295,7 @@ if __name__ == "__main__":
('git_version', '42aee96'), ('calibration_status', None), ('start_position', [0.2, 0.2, 0.0]),
('mode', 'MANUAL'), ('time_delta', 1581342970.889573), ('armed', False),
('config_version', 'Copter config V0.0'), ('last_task', 'No task')])
if active_client.connected:
if active_client.wait_connected:
active_client.server_connection.send_message("telemetry", kwargs={"value": t})

45
lib/asyncio_mod.py Normal file
View File

@@ -0,0 +1,45 @@
import asyncio
import asyncio.constants as constants
async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
protocol = None
transport = None
try:
protocol = protocol_factory(conn.getpeername())
waiter = self.create_future()
if sslcontext:
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
server=server)
try:
await waiter
except BaseException:
transport.close()
raise
# It's now up to the protocol to handle the connection.
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
if self._debug:
context = {
'message':
'Error on transport creation for incoming connection',
'exception': exc,
}
if protocol is not None:
context['protocol'] = protocol
if transport is not None:
context['transport'] = transport
self.call_exception_handler(context)
asyncio.SelectorEventLoop._accept_connection2 = _accept_connection2

View File

@@ -5,6 +5,7 @@ import io
import os
import sys
import json
import time
import socket
import struct
import random
@@ -14,6 +15,7 @@ import platform
import traceback
import collections
class Namespace:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@@ -30,10 +32,12 @@ class PendingRequest(Namespace): pass
logger = logging.getLogger(__name__)
def str_peername(peername):
return f"{peername[0]}:{peername[1]}"
def get_ip_address():
def get_ip_address(): # dodo async
"""
Returns the IP address of current computer or `localhost` if no network connection present.
@@ -83,9 +87,9 @@ def set_keepalive(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
Raises:
NotImplementedError: for unknown platform.
"""
"""
current_platform = platform.system() # could be empty
if current_platform == "Linux":
return _set_keepalive_linux(sock, after_idle_sec, interval_sec, max_fails)
if current_platform == "Windows":
@@ -95,14 +99,17 @@ def set_keepalive(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
raise NotImplementedError
def _set_keepalive_linux(sock, after_idle_sec, interval_sec, max_fails):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, after_idle_sec)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval_sec)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, max_fails)
def _set_keepalive_windows(sock, after_idle_sec, interval_sec):
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, after_idle_sec*1000, interval_sec*1000))
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, after_idle_sec * 1000, interval_sec * 1000))
def _set_keepalive_osx(sock, interval_sec):
TCP_KEEPALIVE = 0x10
@@ -114,8 +121,11 @@ class BroadcastProtocol:
def __init__(self, on_broadcast):
self._on_broadcast = on_broadcast
loop = asyncio.get_event_loop()
self.closed: asyncio.Future = loop.create_future()
self._closed = asyncio.Event()
@property
def closed(self):
return self._closed.wait()
def connection_made(self, transport):
self.transport = transport
@@ -123,8 +133,7 @@ class BroadcastProtocol:
def connection_lost(self, exc):
logging.info(f"Broadcast connection lost: {'closed' if exc is None else exc}")
if not self.closed.done():
self.closed.set_result(True)
self._closed.set()
def datagram_received(self, data, addr):
message = MessageManager(data)
@@ -143,9 +152,11 @@ class BroadcastProtocol:
def error_received(self, exc):
logging.warning(f"Error on broadcast connection received: {exc}")
class BroadcastSendProtocol:
pass
class MessageManager:
"""
MessageManager class represents single incoming by TCP stream message and contains methods to decode and extract data from incoming data. It also contains static class methods for encoding various types of messages.
@@ -162,6 +173,7 @@ class MessageManager:
jsonheader (dict): Headers dictionary with information about message encoding and purpose. Would be populated when receiving and processing of the json header will be completed.
content (object): Would be populated when receiving and processing of the message will be completed. Defaults to None.
"""
def __init__(self, data):
"""
```python
@@ -337,6 +349,9 @@ class MessageManager:
def get_buffer(self):
return self._income_raw
def reset_buffer(self):
self._income_raw = None
def process_message(self):
"""
Attempts processing the message. Chunks of `income_raw` would be consumed as different parts of the message will be processed. The result of processing (body of the message) will be available at `content` and `jsonheader`.
@@ -351,10 +366,12 @@ class MessageManager:
if self.jsonheader:
if not self.processed:
self._process_content()
@property
def processed(self):
return self.content is not None
class CallbackManager:
def __init__(self):
self.action_callbacks = dict()
@@ -370,6 +387,7 @@ class CallbackManager:
d[key] = f
logger.debug("Registered callback function {} for {}".format(f, key))
return f
return inner
def action_callback(self, key):
@@ -378,35 +396,61 @@ class CallbackManager:
def request_callback(self, key):
return self._register_function(self.request_callbacks, key)
class PeerProtocol(asyncio.Protocol):
def __init__(self, parent, callbacks):
self._parent = parent
self._callbacks = callbacks
self.transport = None
self._recv_buffer = bytearray()
self._current_msg = None
self._msg_queue = asyncio.Queue()
loop = asyncio.get_event_loop()
self.closed: asyncio.Future = loop.create_future()
self._connected = asyncio.Event()
self._closed = asyncio.Event()
self._closed.set()
@property
def peername(self):
return self.transport.get_extra_info('peername')
@property
def is_connected(self):
return self._connected.is_set()
@property
def connected(self):
return not self.closed.done()
return self._connected.wait()
@property
def closed(self):
return self._closed.wait()
def connection_made(self, transport):
self.transport = transport
logging.info(f"Connected to {str_peername(self.peername)}")
self._connected.set()
self._closed.clear()
# self._resend_requests()
def connection_lost(self, exc):
logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}")
self._connected.clear()
self._closed.set()
def error_received(self, exc):
logging.warning(f"Error on broadcast connection received: {exc}")
def data_received(self, data):
self._recv_buffer += data
logger.debug("Received {} bytes from {}".format(len(data), self.peername))
logger.debug(f'{data}, {self._recv_buffer}')
asyncio.create_task(self._proceess_received())
async def _proceess_received(self):
while self._recv_buffer:
@@ -416,23 +460,37 @@ class PeerProtocol(asyncio.Protocol):
self._current_msg.set_buffer(self._recv_buffer)
self._current_msg.process_message()
self._recv_buffer = bytearray(self._current_msg.get_buffer())
# self._current_msg.reset_buffer()
if self._current_msg.processed:
#self._recv_buffer =
logging.info(self._current_msg.content)
await self._msg_queue.put(self._current_msg)
self._current_msg = None
else:
await asyncio.sleep(0)
# if last_message.content is not None and last_message.income_raw:
# self._recv_buffer = last_message.income_raw + self._recv_buffer
# last_message.income_raw = b''
#
# if self._received_queue and last_message.content is not None:
# self.process_received(self._received_queue.popleft())
# Sending api functions
def connection_lost(self, exc):
logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}")
if not self.closed.done():
self.closed.set_result(True)
def _send(self, msg):
if not self.is_connected:
logger.error("Peer is disconnected, can't send")
return
self.transport.write(msg[:10])
time.sleep(0.5)
self.transport.write(msg[10:])
def send_message(self, action, args=(), kwargs=None):
"""
Sends to peer message with specified action, arguments and keyword arguments.
Args:
action (str): action(command) to perform upon receiving. Should correspond with `action_string` of function registered in `message_callback()` on the peer.
args (tuple, optional): Arguments for the command. Defaults to ().
kwargs (dict, optional): Keyword arguments for the command. Defaults to None.
"""
self._send(MessageManager.create_action_message(action, args, kwargs))
class ConnectionManager:
@@ -464,7 +522,6 @@ class ConnectionManager:
self.whoami = whoami
def _clear(self):
if not self.resume_queue: # maybe needs locks
self._recv_buffer = b''
@@ -472,10 +529,9 @@ class ConnectionManager:
self._received_queue.clear()
self._send_queue.clear()
def process_received(self, message):
message_type = message.jsonheader["message-type"]
content = message.content if message.jsonheader["content-type"] != "binary"\
content = message.content if message.jsonheader["content-type"] != "binary" \
else message.content[:256]
logger.debug(
"Received message! Header: {}, content: {}".format(message.jsonheader, content))
@@ -646,23 +702,12 @@ class ConnectionManager:
callback_args=callback_args, callback_kwargs=callback_kwargs)
def _resend_requests(self):
for request_id, request in self._request_queue.items(): # TODO filter
if request.resend:
self._send(MessageManager.create_request(
request.requested_value, request_id, request.request_kwargs.update(resend=request.resend))
)
request.resend = False
def send_message(self, action, args=(), kwargs=None):
"""
Sends to peer message with specified action, arguments and keyword arguments.
Args:
action (str): action(command) to perform upon receiving. Should correspond with `action_string` of function registered in `message_callback()` on the peer.
args (tuple, optional): Arguments for the command. Defaults to ().
kwargs (dict, optional): Keyword arguments for the command. Defaults to None.
"""
self._send(MessageManager.create_action_message(action, args, kwargs))
for request_id, request in self._request_queue.items(): # TODO filter
if request.resend:
self._send(MessageManager.create_request(
request.requested_value, request_id, request.request_kwargs.update(resend=request.resend))
)
request.resend = False
def _send_response(self, requested_value, request_id, value, filetransfer=False):
self._send(MessageManager.create_response(requested_value, request_id, value, filetransfer))
@@ -683,4 +728,5 @@ class ConnectionManager:
else:
logger.info("Sending file {} to {} (as: {})".format(filepath, self.addr, dest_filepath))
self._send(MessageManager.create_message(data, "binary", "message",
additional_headers={"action": "filetransfer", "filepath": dest_filepath}))
additional_headers={"action": "filetransfer",
"filepath": dest_filepath}))

View File

@@ -8,7 +8,7 @@ import random
import logging
import datetime
import collections
import traceback
# import traceback
# Add parent dir to PATH to import messaging_lib and config_lib
current_dir = os.path.dirname(os.path.realpath(__file__))
@@ -16,6 +16,7 @@ sys.path.insert(0, os.path.realpath(os.path.join(current_dir, os.pardir, os.pard
# Import modules from lib dir
import lib.messaging as messaging
import lib.asyncio_mod
from lib.config import ConfigManager
random.seed()
@@ -58,6 +59,11 @@ class Server:
def load_config(self):
self.config.load_config_and_spec(self.config_path)
def _protocol_factory(self, addr): # protocol call was modded in factory call (in asyncio-mod)
logging.debug(f"Got connection from {messaging.str_peername(addr)}")
instance = self._clients.get(addr[0], messaging.PeerProtocol(self, self.callbacks))
return instance
async def run(self, wait_closed: bool=False):
if self.time_started is not None:
logger.warning("Server is already running, restarting")
@@ -81,7 +87,7 @@ class Server:
logging.info(f"Starting server with id: '{self.id}' on '{self.host}' ({self.ip})")
self._tcp_server = await loop.create_server(lambda: messaging.PeerProtocol(self, self.callbacks),
self._tcp_server = await loop.create_server(self._protocol_factory,
port=self.config.server_port,
reuse_address=True,
start_serving=False)
@@ -96,9 +102,12 @@ class Server:
await self._stopped
def serve_forever(self):
asyncio.run(self.run(wait_closed=True))
asyncio.run(self.run(wait_closed=True), debug=True)
async def stop(self, reason: str=''):
if self._stopped.done():
logging.error("Server is already stopped")
return
if self._stopping.done():
logging.error("Server is already stopping")
return
@@ -197,7 +206,7 @@ class Server:
def requires_connect(f):
def wrapper(*args, **kwargs):
if args[0].connected:
if args[0].wait_connected:
return f(*args, **kwargs)
logging.warning("Function requires client to be connected!")
@@ -298,7 +307,7 @@ class Client(messaging.ConnectionManager):
@requires_any_connected
def broadcast(message, force_all=False):
for client in Client.clients.values():
if client.connected or force_all:
if client.wait_connected or force_all:
client._send(message)
@classmethod