mirror of
https://github.com/CopterExpress/clever-show.git
synced 2026-05-26 07:07:58 +00:00
212 lines
6.7 KiB
Python
212 lines
6.7 KiB
Python
import logging
|
|
import asyncio
|
|
|
|
import messages
|
|
import exceptions
|
|
|
|
from network import str_peername
|
|
from utils import Callback
|
|
|
|
# Module-wide logger shared by both protocol classes below. It uses the fixed
# name "protocols" (not __name__), so logging configuration must target that name.
logger = logging.getLogger(name="protocols")
|
|
|
|
|
|
class BroadcastProtocol:
|
|
def __init__(self, on_broadcast):
|
|
self.transport = None
|
|
self._on_broadcast = on_broadcast
|
|
|
|
self._closed = asyncio.Event()
|
|
|
|
@property
|
|
def closed(self):
|
|
return self._closed.wait()
|
|
|
|
def connection_made(self, transport):
|
|
self.transport = transport
|
|
logger.info("Broadcast connection established")
|
|
|
|
def connection_lost(self, exc):
|
|
logger.info(f"Broadcast connection lost: {'closed' if exc is None else exc}")
|
|
self._closed.set()
|
|
|
|
def datagram_received(self, data, addr):
|
|
message = messages.MessageDecoder(data)
|
|
message.process_message()
|
|
content = message.content
|
|
|
|
is_ip_broadcast = (content is not None and message.header["action"] == "server_ip")
|
|
|
|
if is_ip_broadcast:
|
|
logger.debug(f"Got broadcast message from {addr}: {content}")
|
|
asyncio.get_event_loop().call_soon(self._on_broadcast, message)
|
|
self.transport.close()
|
|
else:
|
|
logger.warning(f"Got wrong broadcast message from {addr}")
|
|
|
|
def error_received(self, exc):
|
|
logger.warning(f"Error on broadcast connection received: {exc}")
|
|
|
|
class PeerProtocol(asyncio.Protocol):
    """Stream protocol that exchanges framed messages with one peer.

    Sending: :meth:`send` queues messages, and a background ``_send`` task
    writes them while respecting transport flow control. Each message's
    ``sent`` future is resolved once it has been written.

    Receiving: a background ``_proceess_received`` task decodes incoming bytes
    into ``messages.MessageDecoder`` objects. Each complete message goes to
    ``message_callback``, and when the callback's ``gather_bool`` result is
    truthy it is queued for :meth:`receive_message`.

    ``connection_lost`` cancels the background tasks and resets per-connection
    state, and ``connection_made`` clears ``closed`` again. The same instance
    can therefore presumably be attached to a new transport after a disconnect
    (NOTE(review): confirm against the callers).

    Args:
        connected_callback: wrapped in ``Callback``; invoked with this protocol
            from connection_made().
        disconnected_callback: wrapped in ``Callback``; invoked with this
            protocol from connection_lost().
        message_callback: wrapped in ``Callback``; ``gather_bool(protocol,
            message)`` is evaluated for every decoded message.
    """

    def __init__(self, connected_callback=None, disconnected_callback=None, message_callback=None):
        self.connected_callback = Callback(connected_callback)
        self.disconnected_callback = Callback(disconnected_callback)
        self.message_callback = Callback(message_callback)

        self.transport: asyncio.Transport = None

        self._connected = asyncio.Event()
        # No transport yet, so the protocol starts out "closed".
        self._closed = asyncio.Event()
        self._closed.set()

        # Receive side: raw bytes, decoded messages, the decoder of a partially
        # received message and the task that runs the decoder.
        self._recv_buffer = bytearray()
        self._recv_queue = asyncio.Queue()
        self._current_msg = None
        self._recv_process_task = None

        # Send side: messages waiting to be written and the task writing them.
        self._send_queue = asyncio.Queue()
        self._send_task = None

        # Flow control: cleared by pause_writing(), set by resume_writing().
        self._can_write = asyncio.Event()

        # Task of the single in-progress receive_message() call, if any.
        self._recv_waiter = None

    @property
    def peername(self):
        """Remote address reported by the transport (``'peername'`` extra info)."""
        return self.transport.get_extra_info('peername')

    @property
    def is_connected(self):
        """True between connection_made() and connection_lost()."""
        return self._connected.is_set()

    @property
    def connected(self):
        """Awaitable that completes once connection_made() has been called."""
        return self._connected.wait()

    @property
    def closed(self):
        """Awaitable that completes while no connection is up."""
        return self._closed.wait()

    async def close(self):
        """Close the transport, if there is one, and wait until the connection is closed."""
        # Before connection_made() there is no transport, and _closed is already set.
        if self.transport is not None:
            self.transport.close()
        await self.closed

    # Drain control

    def pause_writing(self) -> None:
        """Transport callback: stop writing until resume_writing() is called."""
        self._can_write.clear()

    def resume_writing(self) -> None:
        """Transport callback: writing may continue."""
        self._can_write.set()

    async def drain(self) -> None:
        """Wait until writing is not paused."""
        await self._can_write.wait()

    async def send(self, msg: "messages.AbstractPendingMessage"):
        """Queue *msg* for sending and wait until its ``sent`` future resolves.

        Raises:
            RuntimeError: if the peer is not connected (``msg.sent`` is cancelled).
            exceptions.ConnectionClosedError: if the connection is lost before
                the message has been written (propagated from ``msg.sent``).
        """
        if not self.is_connected:  # and not send_disconnected
            msg.sent.cancel("Peer is disconnected, can't send")
            raise RuntimeError("Peer is disconnected, can't send")

        await self._send_queue.put(msg)
        # Start a sender if there is none, or if the previous one has finished
        # (e.g. it was cancelled by an earlier connection_lost()).
        if self._send_task is None or self._send_task.done():
            self._send_task = asyncio.create_task(self._send())

        logger.debug(f"Queued sending of message {msg} to {str_peername(self.peername)}")
        await msg.sent

    async def _send(self):
        """Background task: write queued messages one at a time while connected.

        Each message's ``sent`` future gets one of three outcomes:
        - a ``None`` result on success;
        - the raised exception if encoding or writing fails;
        - ``ConnectionClosedError`` if the task is cancelled mid-send.

        Futures that are already done (e.g. cancelled because the sender stopped
        waiting) are left untouched.
        """
        while self.is_connected:
            msg = None
            try:
                msg = await self._send_queue.get()
                self.transport.write(msg.encode())
                await self.drain()
            except asyncio.CancelledError:
                if msg is not None and not msg.sent.done():
                    msg.sent.set_exception(exceptions.ConnectionClosedError)
                raise
            except Exception as e:
                if msg is None:
                    raise
                # One bad message must not kill the sender loop. Otherwise every
                # later send() would wait forever on a task that no longer runs.
                logger.exception(f"Failed to send message {msg} to {str_peername(self.peername)}")
                if not msg.sent.done():
                    msg.sent.set_exception(e)
            else:
                if not msg.sent.done():
                    msg.sent.set_result(None)
                logger.debug(f"Sent message {msg} to {str_peername(self.peername)}")

    def _fail_pending_sends(self):
        """Resolve every message still waiting in the send queue with ConnectionClosedError."""
        while True:
            try:
                msg = self._send_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if not msg.sent.done():
                msg.sent.set_exception(exceptions.ConnectionClosedError)

    def connection_made(self, transport):
        self.transport = transport
        logger.info(f"Connected to {str_peername(self.peername)}")

        self._connected.set()
        self._closed.clear()

        self._can_write.set()

        self.connected_callback(self)

    def connection_lost(self, exc):
        """Tear down per-connection state and notify ``disconnected_callback``."""
        logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}")
        self._connected.clear()
        self._closed.set()

        # Wakes a pending receive_message(), which then raises ConnectionClosedError.
        if self._recv_waiter is not None:
            self._recv_waiter.cancel()

        # Drop the task references so the next connection starts fresh tasks,
        # even if a cancelled task never got to run its own cleanup.
        if self._recv_process_task is not None:
            self._recv_process_task.cancel()
            self._recv_process_task = None

        if self._send_task is not None:
            self._send_task.cancel()
            self._send_task = None

        # Messages that never reached the sender would otherwise keep their
        # send() callers waiting forever.
        self._fail_pending_sends()

        # Release anything blocked in drain().
        self._can_write.set()

        self._recv_buffer = bytearray()
        self._current_msg = None

        self.disconnected_callback(self)

    def error_received(self, exc):
        logger.warning(f"Error on {str_peername(self.peername)} connection received: {exc}")

    def data_received(self, data):
        """Buffer incoming bytes and make sure a decoding task is running."""
        self._recv_buffer += data
        logger.debug(f"Received {len(data)} bytes from {str_peername(self.peername)}")

        if self._recv_process_task is None or self._recv_process_task.done():
            self._recv_process_task = asyncio.create_task(self._proceess_received())

    async def _proceess_received(self):
        """Background task: decode buffered bytes into messages until the buffer is empty."""
        try:
            while self._recv_buffer:
                # Resume a partially decoded message, or start decoding a new one.
                if self._current_msg is None:
                    self._current_msg = messages.MessageDecoder(self._recv_buffer)
                else:
                    self._current_msg.set_buffer(self._recv_buffer)

                self._current_msg.process_message()
                # Whatever the decoder did not consume stays buffered.
                self._recv_buffer = self._current_msg.get_buffer()

                if self._current_msg.processed:
                    logger.debug(f"Received message {self._current_msg.content} from {str_peername(self.peername)}")

                    put_msg = self.message_callback.gather_bool(self, self._current_msg)
                    if put_msg:
                        await self._recv_queue.put(self._current_msg)
                    self._current_msg = None
                else:
                    # Incomplete message: yield so data_received() can append more.
                    # NOTE(review): if get_buffer() returns the unconsumed bytes here,
                    # this loop keeps yielding until more data arrives — confirm.
                    await asyncio.sleep(0)
        except Exception as e:
            logger.exception(f"Error during message processing: {e}")
        else:
            logger.debug("All data processed")
        finally:
            # connection_lost() may already have replaced or cleared the reference;
            # only clear it when it still points at this task.
            if self._recv_process_task is asyncio.current_task():
                self._recv_process_task = None

    async def receive_message(self):
        """Wait for and return the next queued decoded message.

        Only one call may wait at a time.

        Raises:
            RuntimeError: if another receive_message() call is already waiting.
            exceptions.ConnectionClosedError: if the connection is lost while waiting.
            asyncio.CancelledError: if the caller is cancelled while still connected.
        """
        if self._recv_waiter is not None:
            raise RuntimeError("receive_message() is already being awaited")

        self._recv_waiter = asyncio.create_task(self._recv_queue.get())

        try:
            return await self._recv_waiter
        except asyncio.CancelledError:
            if self.is_connected:
                # The caller itself was cancelled; honour the cancellation.
                raise
            raise exceptions.ConnectionClosedError from None
        finally:
            self._recv_waiter = None
|