mirror of
https://github.com/CopterExpress/clever-show.git
synced 2026-05-26 07:07:58 +00:00
Request\response processing
This commit is contained in:
@@ -7,7 +7,7 @@ import logging
|
||||
|
||||
logging.basicConfig( # TODO all prints as logs
|
||||
level=logging.DEBUG, # INFO
|
||||
stream=sys.stdout,
|
||||
# stream=sys.stdout,
|
||||
format="%(asctime)s [%(name)-7.7s] [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
|
||||
44
examples/code/connection.py
Normal file
44
examples/code/connection.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import lib.connections as c
|
||||
import lib.messages as messages
|
||||
import lib.utils as utils
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format=utils.logger_format,
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
])
|
||||
|
||||
async def server():
|
||||
server, new_connections = await c.create_server(8181)
|
||||
connection1: c.Connection = await new_connections
|
||||
#await connection1.send(messages.Request("greetings"))
|
||||
await asyncio.sleep(2)
|
||||
t1 = asyncio.create_task(asyncio.wait_for(connection1.send(messages.Request("greetings")), 5))
|
||||
t2 = asyncio.create_task(connection1.send(messages.Request("hello", args=(1, 3))))
|
||||
|
||||
print(await asyncio.gather(t1, t2, return_exceptions=True))
|
||||
await connection1.close()
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
async def client():
|
||||
connection = await c.connect(8181, "")
|
||||
request = await connection.receive("hello")
|
||||
print(request.message.content)
|
||||
request2 = await connection.receive()
|
||||
print(request2.message.content)
|
||||
|
||||
await request.reply("hi")
|
||||
|
||||
# await connection.close()
|
||||
|
||||
async def main():
|
||||
s = asyncio.create_task(server())
|
||||
c = asyncio.create_task(client())
|
||||
await asyncio.gather(s, c)
|
||||
|
||||
asyncio.run(main(), debug=False)
|
||||
@@ -13,9 +13,25 @@ logging.basicConfig(
|
||||
|
||||
async def server():
|
||||
loop = asyncio.get_event_loop()
|
||||
clients = asyncio.Queue()
|
||||
|
||||
server = await loop.create_server(lambda: p.PeerProtocol(None, None), host='', port=8181)
|
||||
def factory(*args):
|
||||
protocol = p.PeerProtocol()
|
||||
clients.put_nowait(protocol)
|
||||
return protocol
|
||||
|
||||
server = await loop.create_server(factory, host='', port=8181)
|
||||
|
||||
client1 = await clients.get()
|
||||
await client1.connected
|
||||
msg = await client1.receive_message()
|
||||
print(msg.header, msg.content)
|
||||
|
||||
msg = messages.PendingMessage(msg.content + " from server")
|
||||
await msg.send_to(client1)
|
||||
|
||||
await client1.closed
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
async def client():
|
||||
@@ -23,9 +39,11 @@ async def client():
|
||||
|
||||
transport, protocol = await loop.create_connection(lambda: p.PeerProtocol(None, None),
|
||||
host='', port=8181)
|
||||
await protocol.send(messages.Request("greetings"))
|
||||
|
||||
await protocol.closed
|
||||
await protocol.send(messages.PendingMessage("greetings"))
|
||||
print((await protocol.receive_message()).content)
|
||||
|
||||
await protocol.close()
|
||||
|
||||
async def main():
|
||||
s = asyncio.create_task(server())
|
||||
|
||||
@@ -14,10 +14,7 @@ def modify_filename(path, pattern): # TODO move to core
|
||||
|
||||
|
||||
def parent_path(path, levels=1):
|
||||
for i in range(levels):
|
||||
path = os.path.abspath(os.path.join(path, os.pardir))
|
||||
return path
|
||||
|
||||
return os.path.realpath(os.path.join(path, *[os.pardir for _ in range(levels)]))
|
||||
|
||||
def parent_dir(path):
|
||||
return os.path.basename(os.path.normpath(path))
|
||||
|
||||
@@ -10,10 +10,11 @@ import traceback
|
||||
|
||||
import messages
|
||||
import exceptions
|
||||
from messages import MessageDecoder, Request
|
||||
from messages import MessageDecoder, Request, ReceivedRequest
|
||||
from protocols import PeerProtocol
|
||||
from utils import Callback, KeyQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("connections")
|
||||
|
||||
class CallbackManager:
|
||||
def __init__(self):
|
||||
@@ -113,24 +114,31 @@ class temp:
|
||||
# Sending api functions
|
||||
|
||||
class Connection:
|
||||
def __init__(self, callbacks):
|
||||
self.callbacks = callbacks
|
||||
def __init__(self, request_callback=None):
|
||||
self.request_callback = Callback(request_callback)
|
||||
|
||||
self.protocol: PeerProtocol = None
|
||||
self.transport = None
|
||||
|
||||
self._sent_requests = dict() # holds requests awaiting reply
|
||||
self._recv_requests = asyncio.Queue()
|
||||
self._recv_requests = KeyQueue()
|
||||
|
||||
def connect(self, protocol: PeerProtocol, transport):
|
||||
self.protocol = protocol
|
||||
def connect(self, transport, protocol: PeerProtocol):
|
||||
self.transport = transport
|
||||
self.protocol = protocol
|
||||
|
||||
self.protocol.message_callback.add(self._receive)
|
||||
self.protocol.disconnected_callback.add(self._on_disconnect)
|
||||
|
||||
def _on_disconnect(self, protocol):
|
||||
self._recv_requests.cancel()
|
||||
|
||||
async def close(self):
|
||||
self.transport.close()
|
||||
await self.protocol.closed
|
||||
|
||||
async def send(self, request: Request):
|
||||
async def send(self, request: Request) -> messages.MessageDecoder:
|
||||
request.set_chain_id()
|
||||
self._sent_requests[request.chain_id] = request
|
||||
|
||||
try:
|
||||
@@ -144,40 +152,94 @@ class Connection:
|
||||
else:
|
||||
return response
|
||||
|
||||
async def send_reply(self, response: messages.Response):
|
||||
await self.protocol.send(response)
|
||||
async def send_request(self, name, args=(), kwargs=None, callback=None):
|
||||
request = Request(name, args, kwargs, callback)
|
||||
return await self.send(request)
|
||||
|
||||
async def receive(self):
|
||||
if self._recv_waiter is not None:
|
||||
raise RuntimeError("receive() is already being awaited")
|
||||
|
||||
self._recv_waiter = asyncio.create_task(self._recv_requests.get())
|
||||
async def send_response(self, request, data=None, progress=None, err=None):
|
||||
pass
|
||||
|
||||
async def receive(self, name: str =None) -> messages.ReceivedRequest:
|
||||
try:
|
||||
return await self._recv_waiter
|
||||
return await self._recv_requests.get(key=name)
|
||||
except asyncio.CancelledError:
|
||||
raise exceptions.ConnectionClosedError
|
||||
finally:
|
||||
self._recv_waiter = None
|
||||
|
||||
async def _receive(self):
|
||||
msg = await self.protocol.receive_message()
|
||||
def _receive(self, protocol, msg):
|
||||
if msg.message_type == messages.MessageTypes.RESPONSE:
|
||||
await self._process_response(msg)
|
||||
self._process_response(msg)
|
||||
elif msg.message_type == messages.MessageTypes.REQUEST:
|
||||
pass
|
||||
self._process_request(msg)
|
||||
else:
|
||||
logger.warning(f"Got unknown message type {msg.message_type} for message {msg}")
|
||||
return True # put msg in protocol rcv queue
|
||||
return False # don't put msg in protocol rcv queue
|
||||
|
||||
def _process_response(self, response: MessageDecoder):
|
||||
logger.debug(f"Received response {response} with id {response.chain_id}")
|
||||
|
||||
async def _process_response(self, msg: MessageDecoder):
|
||||
try:
|
||||
request = self._sent_requests[msg.chain_id]
|
||||
request = self._sent_requests.pop(response.chain_id)
|
||||
except KeyError:
|
||||
logger.error(f"Unknown response {msg} with id {msg.chain_id}")
|
||||
logger.error(f"Unexpected response {response} with id {response.chain_id}")
|
||||
return
|
||||
|
||||
request.response.set_result(msg)
|
||||
try:
|
||||
response_type = response.header["response-type"]
|
||||
except KeyError:
|
||||
logger.error("Missing 'response-type' header in response")
|
||||
return
|
||||
|
||||
if response_type == messages.ResponseTypes.RESULT:
|
||||
result = response.content
|
||||
request.response.set_result(result)
|
||||
elif response_type == messages.ResponseTypes.ERROR: # TODOOOO
|
||||
error = None
|
||||
request.response.set_result(error)
|
||||
elif response_type == messages.ResponseTypes.STATUS:
|
||||
pass
|
||||
else:
|
||||
logger.error(f"Unknown response type '{response_type}' in response {response}")
|
||||
|
||||
def _process_request(self, request_msg):
|
||||
request = ReceivedRequest(self, request_msg)
|
||||
logger.debug(f"Received request {request_msg} with id {request_msg.chain_id} for {request.name}")
|
||||
|
||||
put_msg = self.request_callback.gather_bool(self, request)
|
||||
if put_msg:
|
||||
self._recv_requests.put_nowait(request, key=request.name)
|
||||
|
||||
async def connect(port, host, connection_factory=Connection, protocol_factory=PeerProtocol, **kwargs):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
transport, protocol = await loop.create_connection(protocol_factory, host=host, port=port, **kwargs)
|
||||
connection = connection_factory()
|
||||
connection.connect(transport, protocol)
|
||||
return connection
|
||||
|
||||
async def create_server(port, host="", connection_factory=Connection, protocol_factory=PeerProtocol,
|
||||
reuse_connections=True, **kwargs):
|
||||
loop = asyncio.get_event_loop()
|
||||
new_connections = asyncio.Queue()
|
||||
|
||||
def factory(addr=None):
|
||||
nonlocal new_connections
|
||||
protocol = protocol_factory()
|
||||
|
||||
def callback(p):
|
||||
nonlocal new_connections
|
||||
connection = connection_factory()
|
||||
connection.connect(p.transport, p)
|
||||
new_connections.put_nowait(connection)
|
||||
|
||||
protocol.connected_callback.add(callback)
|
||||
|
||||
return protocol
|
||||
|
||||
server = await loop.create_server(factory, host=host, port=port, **kwargs)
|
||||
|
||||
return server, new_connections.get()
|
||||
|
||||
async def _process_request(self):
|
||||
pass
|
||||
|
||||
class LegacyConnectionManager:
|
||||
"""
|
||||
@@ -206,14 +268,6 @@ class LegacyConnectionManager:
|
||||
|
||||
|
||||
def _process_response(self, message):
|
||||
request_id, requested_value = message.jsonheader["request_id"], message.jsonheader["requested_value"]
|
||||
|
||||
with self._request_lock:
|
||||
request = self._request_queue.pop(request_id, None)
|
||||
if (request is None) or (request.requested_value != requested_value):
|
||||
logger.warning("Unexpected response!")
|
||||
return
|
||||
|
||||
if requested_value == "filetransfer":
|
||||
value = True
|
||||
self._process_filetransfer(message.content, request.callback_kwargs["filepath"])
|
||||
@@ -283,17 +337,6 @@ class LegacyConnectionManager:
|
||||
self.get_response("filetransfer", callback, request_kwargs=request_kwargs,
|
||||
callback_args=callback_args, callback_kwargs=callback_kwargs)
|
||||
|
||||
def _resend_requests(self):
|
||||
for request_id, request in self._request_queue.items(): # TODO filter
|
||||
if request.resend:
|
||||
self._send(Message.create_request(
|
||||
request.requested_value, request_id, request.request_kwargs.update(resend=request.resend))
|
||||
)
|
||||
request.resend = False
|
||||
|
||||
def _send_response(self, requested_value, request_id, value, filetransfer=False):
|
||||
self._send(Message.create_response(requested_value, request_id, value, filetransfer))
|
||||
|
||||
def send_file(self, filepath, dest_filepath): # clever_restart=False
|
||||
"""
|
||||
Sends to peer a file from `filepath` to write it on `dest_filepath`.
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
|
||||
def b_partial(func, *args, **kwargs): # call argument blocker partial
|
||||
return lambda *a: func(*args, **kwargs)
|
||||
124
lib/messages.py
124
lib/messages.py
@@ -45,6 +45,7 @@ class ContentTypes:
|
||||
ENCODED = "encoded"
|
||||
|
||||
class MessageTypes:
|
||||
MESSAGE = "message"
|
||||
REQUEST = "request"
|
||||
RESPONSE = "response"
|
||||
|
||||
@@ -159,11 +160,14 @@ class MessageDecoder:
|
||||
class MessageEncoder:
|
||||
def __init__(self, codec=default_codec):
|
||||
self.chain_id = None
|
||||
|
||||
self.codec = codec
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
return self.encode_raw_message(*args, **kwargs)
|
||||
def set_chain_id(self, chain_id=None):
|
||||
if chain_id is None:
|
||||
if self.chain_id is None:
|
||||
self.chain_id = uuid.uuid4().hex
|
||||
else:
|
||||
self.chain_id = chain_id
|
||||
|
||||
def encode_raw_message(self, content: bytes, content_type, message_type, chain_id=None, additional_headers=None):
|
||||
"""Returns encoded message in bytes. It is recommended use other encoding functions for general purposes.
|
||||
@@ -177,19 +181,21 @@ class MessageEncoder:
|
||||
Returns:
|
||||
bytes: encoded message
|
||||
"""
|
||||
# if chain_id is None:
|
||||
# if self.chain_id is not None:
|
||||
# chain_id = self.chain_id
|
||||
# else:
|
||||
# chain_id = uuid.uuid4().hex
|
||||
# else:
|
||||
# self.chain_id = chain_id
|
||||
self.set_chain_id(chain_id)
|
||||
|
||||
if chain_id is None:
|
||||
chain_id = uuid.uuid4().int
|
||||
elif self.chain_id is not None:
|
||||
chain_id = self.chain_id
|
||||
else:
|
||||
self.chain_id = chain_id
|
||||
|
||||
header = {
|
||||
"content-length": len(content),
|
||||
"content-type": content_type,
|
||||
"message-type": message_type,
|
||||
"chain-id": chain_id
|
||||
"chain-id": self.chain_id
|
||||
}
|
||||
if additional_headers:
|
||||
header.update(additional_headers)
|
||||
@@ -256,7 +262,7 @@ class MessageEncoder:
|
||||
# "response", additional_headers=headers)
|
||||
# return message
|
||||
|
||||
class PendingMessage(MessageEncoder):
|
||||
class AbstractPendingMessage(MessageEncoder):
|
||||
def __init__(self, codec=default_codec):
|
||||
super().__init__(codec)
|
||||
self._sent = asyncio.Future()
|
||||
@@ -265,24 +271,44 @@ class PendingMessage(MessageEncoder):
|
||||
def sent(self):
|
||||
return self._sent
|
||||
|
||||
async def send(self, connection):
|
||||
if self._sent:
|
||||
async def send_to(self, connection):
|
||||
if self._sent.done():
|
||||
raise RuntimeError("This message was already sent, create another one")
|
||||
return connection.send(self)
|
||||
|
||||
class Response(PendingMessage):
|
||||
def __init__(self, chain_id, result, type, codec=default_codec):
|
||||
super().__init__(codec)
|
||||
|
||||
self._chain_id = chain_id
|
||||
self._result = result
|
||||
return await connection.send(self)
|
||||
|
||||
def encode(self):
|
||||
contents = {"value": self._result}
|
||||
return self.encode_message(contents, MessageTypes.RESPONSE, chain_id=self._chain_id)
|
||||
raise NotImplementedError
|
||||
|
||||
class PendingMessage(AbstractPendingMessage):
|
||||
def __init__(self, content, encode_content=True, codec=default_codec):
|
||||
super().__init__(codec)
|
||||
self._content = content
|
||||
self.encode_content = encode_content
|
||||
|
||||
def encode(self):
|
||||
if self.encode_content:
|
||||
content = self.codec.encode(self._content)
|
||||
content_type = ContentTypes.ENCODED
|
||||
else:
|
||||
content = self._content
|
||||
content_type = ContentTypes.BINARY
|
||||
return self.encode_raw_message(content, content_type, MessageTypes.MESSAGE)
|
||||
|
||||
|
||||
class Request(PendingMessage):
|
||||
class Response(AbstractPendingMessage):
|
||||
def __init__(self, chain_id, result, response_type, codec=default_codec):
|
||||
super().__init__(codec)
|
||||
|
||||
self.chain_id = chain_id
|
||||
self._result = result
|
||||
self._type = response_type
|
||||
|
||||
def encode(self):
|
||||
return self.encode_message(self._result, MessageTypes.RESPONSE, # chain_id=self.chain_id,
|
||||
additional_headers={"response-type": self._type})
|
||||
|
||||
|
||||
class Request(AbstractPendingMessage):
|
||||
def __init__(self, name, args=(), kwargs=None, callback=None, codec=default_codec):
|
||||
super().__init__(codec)
|
||||
|
||||
@@ -295,11 +321,24 @@ class Request(PendingMessage):
|
||||
|
||||
self.callback = callback
|
||||
self._response = asyncio.Future()
|
||||
#self.responses ?
|
||||
|
||||
self._progress = float('nan')
|
||||
self._got_progress = asyncio.Future()
|
||||
|
||||
@property
|
||||
def response(self):
|
||||
return self._response
|
||||
|
||||
@property
|
||||
async def progress(self):
|
||||
await self._got_progress
|
||||
return self._progress
|
||||
|
||||
@progress.setter
|
||||
def progress(self, value):
|
||||
pass
|
||||
|
||||
def encode(self):
|
||||
contents = {"name": self._name,
|
||||
"args": self._args,
|
||||
@@ -325,7 +364,7 @@ class RequestBatch:
|
||||
|
||||
async def send(self):
|
||||
for connection, request in self._request_dict.items():
|
||||
connection.send(request)
|
||||
connection.send_to(request)
|
||||
|
||||
@property
|
||||
def request_dict(self):
|
||||
@@ -345,11 +384,40 @@ class ReceivedRequest:
|
||||
self.connection = connection
|
||||
self.message: MessageDecoder = message
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.message.content["name"]
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self.message.content.get("args", list())
|
||||
|
||||
@property
|
||||
def kwargs(self):
|
||||
return self.message.content.get("kwargs", dict())
|
||||
|
||||
async def send(self, reply: Response):
|
||||
await self.connection.protocol.send(reply)
|
||||
await reply.sent
|
||||
|
||||
async def reply(self, data):
|
||||
reply = Response(self.message.chain_id)
|
||||
reply = Response(self.message.chain_id, data, ResponseTypes.RESULT)
|
||||
|
||||
await self.send(reply)
|
||||
return reply
|
||||
|
||||
async def reply_processing(self, progress: float=0):
|
||||
progress = max(0, min(1, progress))
|
||||
progress = max(0.0, min(1.0, progress))
|
||||
#contents = {"progress": progress}
|
||||
reply = Response(self.message.chain_id, progress, ResponseTypes.RESULT)
|
||||
|
||||
await self.send(reply)
|
||||
return reply
|
||||
|
||||
async def reply_error(self, err: Exception):
|
||||
pass
|
||||
#contents = {"error": err}
|
||||
reply = Response(self.message.chain_id, err, ResponseTypes.ERROR)
|
||||
|
||||
await self.send(reply)
|
||||
return reply
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ import messages
|
||||
import exceptions
|
||||
|
||||
from network import str_peername
|
||||
from utils import Callback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("protocols")
|
||||
|
||||
|
||||
class BroadcastProtocol:
|
||||
@@ -29,7 +30,7 @@ class BroadcastProtocol:
|
||||
self._closed.set()
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
message = message.MessageDecoder(data)
|
||||
message = messages.MessageDecoder(data)
|
||||
message.process_message()
|
||||
content = message.content
|
||||
|
||||
@@ -46,8 +47,10 @@ class BroadcastProtocol:
|
||||
logger.warning(f"Error on broadcast connection received: {exc}")
|
||||
|
||||
class PeerProtocol(asyncio.Protocol):
|
||||
def __init__(self, connected_callback, callbacks):
|
||||
self._callbacks = callbacks
|
||||
def __init__(self, connected_callback=None, disconnected_callback=None, message_callback=None):
|
||||
self.connected_callback = Callback(connected_callback)
|
||||
self.disconnected_callback = Callback(disconnected_callback)
|
||||
self.message_callback = Callback(message_callback)
|
||||
|
||||
self.transport: asyncio.Transport = None
|
||||
|
||||
@@ -83,6 +86,10 @@ class PeerProtocol(asyncio.Protocol):
|
||||
def closed(self):
|
||||
return self._closed.wait()
|
||||
|
||||
async def close(self):
|
||||
self.transport.close()
|
||||
await self.closed
|
||||
|
||||
# Drain control
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
@@ -94,7 +101,7 @@ class PeerProtocol(asyncio.Protocol):
|
||||
async def drain(self) -> None:
|
||||
await self._can_write.wait()
|
||||
|
||||
async def send(self, msg: messages.PendingMessage):
|
||||
async def send(self, msg: messages.AbstractPendingMessage):
|
||||
if not self.is_connected: # and not send_disconnected
|
||||
msg.sent.cancel("Peer is disconnected, can't send")
|
||||
raise RuntimeError("Peer is disconnected, can't send")
|
||||
@@ -110,7 +117,7 @@ class PeerProtocol(asyncio.Protocol):
|
||||
while self.is_connected:
|
||||
msg = None
|
||||
try:
|
||||
msg: messages.PendingMessage = await self._send_queue.get()
|
||||
msg: messages.AbstractPendingMessage = await self._send_queue.get()
|
||||
self.transport.write(msg.encode())
|
||||
await self.drain()
|
||||
except asyncio.CancelledError:
|
||||
@@ -130,6 +137,8 @@ class PeerProtocol(asyncio.Protocol):
|
||||
|
||||
self._can_write.set()
|
||||
|
||||
self.connected_callback(self)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}")
|
||||
self._connected.clear()
|
||||
@@ -149,6 +158,8 @@ class PeerProtocol(asyncio.Protocol):
|
||||
self._recv_buffer = bytearray()
|
||||
self._current_msg = None
|
||||
|
||||
self.disconnected_callback(self)
|
||||
|
||||
def error_received(self, exc):
|
||||
logger.warning(f"Error on {str_peername(self.peername)} connection received: {exc}")
|
||||
|
||||
@@ -171,8 +182,11 @@ class PeerProtocol(asyncio.Protocol):
|
||||
self._recv_buffer = self._current_msg.get_buffer()
|
||||
# self._current_msg.reset_buffer()
|
||||
if self._current_msg.processed:
|
||||
logger.debug(f"Recieved message {self._current_msg.content} from {str_peername(self.peername)}")
|
||||
await self._recv_queue.put(self._current_msg)
|
||||
logger.debug(f"Received message {self._current_msg.content} from {str_peername(self.peername)}")
|
||||
|
||||
put_msg = self.message_callback.gather_bool(self, self._current_msg)
|
||||
if put_msg:
|
||||
await self._recv_queue.put(self._current_msg)
|
||||
self._current_msg = None
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
277
lib/utils.py
Normal file
277
lib/utils.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import collections
|
||||
import asyncio
|
||||
|
||||
logger_format = "%(asctime)s [%(name)-11.11s] [%(levelname)-7.7s] %(message)s"
|
||||
|
||||
def b_partial(func, *args, **kwargs): # call argument blocker partial
|
||||
return lambda *a: func(*args, **kwargs)
|
||||
|
||||
|
||||
class Callback:
|
||||
def __init__(self, *callbacks):
|
||||
self._callbacks = set()
|
||||
self._add_callbacks(callbacks)
|
||||
|
||||
@property
|
||||
def all(self):
|
||||
return self._callbacks.copy()
|
||||
|
||||
def add(self, callback):
|
||||
self._add_callback(callback)
|
||||
|
||||
def remove(self, callback):
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
def _add_callbacks(self, callbacks):
|
||||
for callback in callbacks:
|
||||
self._add_callback(callback)
|
||||
|
||||
def _add_callback(self, callback):
|
||||
if callback is None:
|
||||
return
|
||||
elif isinstance(callback, Callback):
|
||||
self._callbacks.update(callback.all)
|
||||
elif isinstance(callback, (list, set)):
|
||||
self._add_callbacks(callback)
|
||||
elif callable(callback):
|
||||
self._callbacks.add(callback)
|
||||
else:
|
||||
raise ValueError(f"Callback {callback} is not callable object!")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return list(self.iter(*args, **kwargs))
|
||||
|
||||
def iter(self, *args, **kwargs):
|
||||
for callback in self._callbacks:
|
||||
yield callback(*args, **kwargs)
|
||||
|
||||
def gather_bool(self, *args, **kwargs):
|
||||
if self.all:
|
||||
return all(self(*args, **kwargs))
|
||||
return True
|
||||
|
||||
class KeyQueue:
|
||||
"""A queue, useful for coordinating producer and consumer coroutines.
|
||||
If maxsize is less than or equal to zero, the queue size is infinite. If it
|
||||
is an integer greater than 0, then "await put()" will block when the
|
||||
queue reaches maxsize, until an item is removed by get().
|
||||
Unlike the standard library Queue, you can reliably know this Queue's size
|
||||
with qsize(), since your single-threaded asyncio application won't be
|
||||
interrupted between calling qsize() and doing an operation on the Queue.
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize=0):
|
||||
self._loop = asyncio.events.get_event_loop()
|
||||
|
||||
self._maxsize = maxsize
|
||||
|
||||
# Futures.
|
||||
self._getters = collections.deque()
|
||||
# Futures.
|
||||
self._putters = collections.deque()
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = asyncio.Event(loop=self._loop)
|
||||
self._finished.set()
|
||||
self._init(maxsize)
|
||||
|
||||
# These three are overridable in subclasses.
|
||||
|
||||
def _init(self, maxsize):
|
||||
self._queue = collections.deque()
|
||||
|
||||
def _get(self, key=None):
|
||||
if key is None:
|
||||
return self._queue.popleft()[1]
|
||||
|
||||
for pair in self._queue:
|
||||
k, item = pair
|
||||
if k == key:
|
||||
self._queue.remove(pair)
|
||||
return item
|
||||
else:
|
||||
raise KeyError(f"No item with key {key} in queue")
|
||||
|
||||
def _put(self, item, key=None):
|
||||
self._queue.append((key, item))
|
||||
|
||||
# End of the overridable methods.
|
||||
|
||||
def _wakeup_next(self, waiters):
|
||||
# Wake up the next waiter (if any) that isn't cancelled.
|
||||
while waiters:
|
||||
waiter = waiters.popleft()
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
break
|
||||
|
||||
def _wakeup_get(self, key=None):
|
||||
if key is None:
|
||||
self._wakeup_next(self._getters)
|
||||
|
||||
for pair in self._getters:
|
||||
k, waiter = pair
|
||||
if k == key and not waiter.done():
|
||||
waiter.set_result(None)
|
||||
self._getters.remove(pair)
|
||||
break
|
||||
|
||||
def __repr__(self):
|
||||
return f'<{type(self).__name__} at {id(self):#x} {self._format()}>'
|
||||
|
||||
def __str__(self):
|
||||
return f'<{type(self).__name__} {self._format()}>'
|
||||
|
||||
def __class_getitem__(cls, type):
|
||||
return cls
|
||||
|
||||
def _format(self):
|
||||
result = f'maxsize={self._maxsize!r}'
|
||||
if getattr(self, '_queue', None):
|
||||
result += f' _queue={list(self._queue)!r}'
|
||||
if self._getters:
|
||||
result += f' _getters[{len(self._getters)}]'
|
||||
if self._putters:
|
||||
result += f' _putters[{len(self._putters)}]'
|
||||
if self._unfinished_tasks:
|
||||
result += f' tasks={self._unfinished_tasks}'
|
||||
return result
|
||||
|
||||
def qsize(self):
|
||||
"""Number of items in the queue."""
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def maxsize(self):
|
||||
"""Number of items allowed in the queue."""
|
||||
return self._maxsize
|
||||
|
||||
def empty(self):
|
||||
"""Return True if the queue is empty, False otherwise."""
|
||||
return not self._queue
|
||||
|
||||
def full(self):
|
||||
"""Return True if there are maxsize items in the queue.
|
||||
|
||||
Note: if the Queue was initialized with maxsize=0 (the default),
|
||||
then full() is never True.
|
||||
"""
|
||||
if self._maxsize <= 0:
|
||||
return False
|
||||
else:
|
||||
return self.qsize() >= self._maxsize
|
||||
|
||||
async def put(self, item, key=None):
|
||||
"""Put an item into the queue.
|
||||
|
||||
Put an item into the queue. If the queue is full, wait until a free
|
||||
slot is available before adding item.
|
||||
"""
|
||||
while self.full():
|
||||
putter = self._loop.create_future()
|
||||
self._putters.append(putter)
|
||||
try:
|
||||
await putter
|
||||
except:
|
||||
putter.cancel() # Just in case putter is not done yet.
|
||||
try:
|
||||
# Clean self._putters from canceled putters.
|
||||
self._putters.remove(putter)
|
||||
except ValueError:
|
||||
# The putter could be removed from self._putters by a
|
||||
# previous get_nowait call.
|
||||
pass
|
||||
if not self.full() and not putter.cancelled():
|
||||
# We were woken up by get_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._putters)
|
||||
raise
|
||||
return self.put_nowait(item, key)
|
||||
|
||||
def put_nowait(self, item, key=None):
|
||||
"""Put an item into the queue without blocking.
|
||||
|
||||
If no free slot is immediately available, raise QueueFull.
|
||||
"""
|
||||
if self.full():
|
||||
raise asyncio.QueueFull
|
||||
self._put(item, key)
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._wakeup_get(key)
|
||||
|
||||
async def get(self, key=None):
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
If queue is empty, wait until an item is available.
|
||||
"""
|
||||
while self.empty():
|
||||
getter = self._loop.create_future()
|
||||
key_getter = key, getter
|
||||
self._getters.append(key_getter)
|
||||
try:
|
||||
await getter
|
||||
except:
|
||||
getter.cancel() # Just in case getter is not done yet.
|
||||
try:
|
||||
# Clean self._getters from canceled getters.
|
||||
self._getters.remove(key_getter)
|
||||
except ValueError:
|
||||
# The getter could be removed from self._getters by a
|
||||
# previous put_nowait call.
|
||||
pass
|
||||
if not self.empty() and not getter.cancelled():
|
||||
# We were woken up by put_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_get(key)
|
||||
raise
|
||||
return self.get_nowait(key)
|
||||
|
||||
def get_nowait(self, key=None):
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
Return an item if one is immediately available, else raise QueueEmpty.
|
||||
"""
|
||||
if self.empty():
|
||||
raise asyncio.QueueEmpty
|
||||
item = self._get(key)
|
||||
self._wakeup_next(self._putters)
|
||||
return item
|
||||
|
||||
def task_done(self):
|
||||
"""Indicate that a formerly enqueued task is complete.
|
||||
|
||||
Used by queue consumers. For each get() used to fetch a task,
|
||||
a subsequent call to task_done() tells the queue that the processing
|
||||
on the task is complete.
|
||||
|
||||
If a join() is currently blocking, it will resume when all items have
|
||||
been processed (meaning that a task_done() call was received for every
|
||||
item that had been put() into the queue).
|
||||
|
||||
Raises ValueError if called more times than there were items placed in
|
||||
the queue.
|
||||
"""
|
||||
if self._unfinished_tasks <= 0:
|
||||
raise ValueError('task_done() called too many times')
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
async def join(self):
|
||||
"""Block until all items in the queue have been gotten and processed.
|
||||
|
||||
The count of unfinished tasks goes up whenever an item is added to the
|
||||
queue. The count goes down whenever a consumer calls task_done() to
|
||||
indicate that the item was retrieved and all work on it is complete.
|
||||
When the count of unfinished tasks drops to zero, join() unblocks.
|
||||
"""
|
||||
if self._unfinished_tasks > 0:
|
||||
await self._finished.wait()
|
||||
|
||||
def cancel(self):
|
||||
# self._finished.set()
|
||||
# self._unfinished_tasks = 0
|
||||
while self._getters:
|
||||
pair = self._getters.popleft()
|
||||
key, getter = pair
|
||||
getter.cancel()
|
||||
@@ -3,13 +3,13 @@ import sys
|
||||
import pytest
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.insert(0, os.path.realpath(os.path.join(current_dir, os.pardir)))
|
||||
|
||||
from lib.messaging import CallbackManager
|
||||
|
||||
def test_callback_registration():
|
||||
def f(arg):
|
||||
return arg
|
||||
c = CallbackManager()
|
||||
c.action_callback("123")(f)
|
||||
t = "test"
|
||||
assert t == c.action_callbacks["123"](t)
|
||||
#
|
||||
# from lib.mes import CallbackManager
|
||||
#
|
||||
# def test_callback_registration():
|
||||
# def f(arg):
|
||||
# return arg
|
||||
# c = CallbackManager()
|
||||
# c.action_callback("123")(f)
|
||||
# t = "test"
|
||||
# assert t == c.action_callbacks["123"](t)
|
||||
|
||||
28
tests/utils_test.py
Normal file
28
tests/utils_test.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from lib.utils import KeyQueue
|
||||
|
||||
def test_queue_sync():
|
||||
q = KeyQueue()
|
||||
val1 = "test1"
|
||||
val2 = "test2"
|
||||
key1 = "key1"
|
||||
q.put_nowait(val1)
|
||||
q.put_nowait(val2, key1)
|
||||
|
||||
assert q.get_nowait() == val1
|
||||
assert q.get_nowait(key1) == val2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_putting_async():
|
||||
q = KeyQueue()
|
||||
val1 = "test1"
|
||||
val2 = "test2"
|
||||
key1 = "key1"
|
||||
await q.put(val1)
|
||||
await q.put(val2, key1)
|
||||
|
||||
assert await q.get(key1) == val2
|
||||
assert await q.get() == val1
|
||||
Reference in New Issue
Block a user