From 70ec36b99291e1f33ae6d83b1b4625d98fbab5e7 Mon Sep 17 00:00:00 2001 From: Artem30801 Date: Wed, 2 Dec 2020 09:49:44 +0300 Subject: [PATCH] Request\response processing --- drone/tests/animation_test.py | 2 +- examples/code/connection.py | 44 ++++++ examples/code/protocol.py | 24 ++- lib/config.py | 5 +- lib/connections.py | 137 +++++++++++------ lib/lib.py | 3 - lib/messages.py | 124 +++++++++++---- lib/protocols.py | 30 +++- lib/utils.py | 277 ++++++++++++++++++++++++++++++++++ tests/messaging_test.py | 20 +-- tests/utils_test.py | 28 ++++ 11 files changed, 590 insertions(+), 104 deletions(-) create mode 100644 examples/code/connection.py delete mode 100644 lib/lib.py create mode 100644 lib/utils.py create mode 100644 tests/utils_test.py diff --git a/drone/tests/animation_test.py b/drone/tests/animation_test.py index 534509f..8e65b51 100644 --- a/drone/tests/animation_test.py +++ b/drone/tests/animation_test.py @@ -7,7 +7,7 @@ import logging logging.basicConfig( # TODO all prints as logs level=logging.DEBUG, # INFO - stream=sys.stdout, +# stream=sys.stdout, format="%(asctime)s [%(name)-7.7s] [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s", handlers=[ logging.StreamHandler(sys.stdout), diff --git a/examples/code/connection.py b/examples/code/connection.py new file mode 100644 index 0000000..7a88c9b --- /dev/null +++ b/examples/code/connection.py @@ -0,0 +1,44 @@ +import asyncio +import logging + +import lib.connections as c +import lib.messages as messages +import lib.utils as utils + +logging.basicConfig( + level=logging.DEBUG, + format=utils.logger_format, + handlers=[ + logging.StreamHandler() + ]) + +async def server(): + server, new_connections = await c.create_server(8181) + connection1: c.Connection = await new_connections + #await connection1.send(messages.Request("greetings")) + await asyncio.sleep(2) + t1 = asyncio.create_task(asyncio.wait_for(connection1.send(messages.Request("greetings")), 5)) + t2 = asyncio.create_task(connection1.send(messages.Request("hello", args=(1, 3)))) + + print(await asyncio.gather(t1, t2, return_exceptions=True)) + await connection1.close() + server.close() + await server.wait_closed() + +async def client(): + connection = await c.connect(8181, "") + request = await connection.receive("hello") + print(request.message.content) + request2 = await connection.receive() + print(request2.message.content) + + await request.reply("hi") + + # await connection.close() + +async def main(): + s = asyncio.create_task(server()) + c = asyncio.create_task(client()) + await asyncio.gather(s, c) + +asyncio.run(main(), debug=False) diff --git a/examples/code/protocol.py b/examples/code/protocol.py index be574cf..0e31ea4 100644 --- a/examples/code/protocol.py +++ b/examples/code/protocol.py @@ -13,9 +13,25 @@ logging.basicConfig( async def server(): loop = asyncio.get_event_loop() + clients = asyncio.Queue() - server = await loop.create_server(lambda: p.PeerProtocol(None, None), host='', port=8181) + def factory(*args): + protocol = p.PeerProtocol() + clients.put_nowait(protocol) + return protocol + server = await loop.create_server(factory, host='', port=8181) + + client1 = await clients.get() + await client1.connected + msg = await client1.receive_message() + print(msg.header, msg.content) + + msg = messages.PendingMessage(msg.content + " from server") + await msg.send_to(client1) + + await client1.closed + server.close() await server.wait_closed() async def client(): @@ -23,9 +39,11 @@ async def client(): transport, protocol = await loop.create_connection(lambda: p.PeerProtocol(None, None), host='', port=8181) - await protocol.send(messages.Request("greetings")) - await protocol.closed + await protocol.send(messages.PendingMessage("greetings")) + print((await protocol.receive_message()).content) + + await protocol.close() async def main(): s = asyncio.create_task(server()) diff --git a/lib/config.py b/lib/config.py index f898fb1..55676bb 100644 --- a/lib/config.py +++ b/lib/config.py @@ -14,10 +14,7 @@ def modify_filename(path, pattern): # TODO move to core def parent_path(path, levels=1): - for i in range(levels): - path = os.path.abspath(os.path.join(path, os.pardir)) - return path - + return os.path.realpath(os.path.join(path, *[os.pardir for _ in range(levels)])) def parent_dir(path): return os.path.basename(os.path.normpath(path)) diff --git a/lib/connections.py b/lib/connections.py index 3429cdb..031a83e 100644 --- a/lib/connections.py +++ b/lib/connections.py @@ -10,10 +10,11 @@ import traceback import messages import exceptions -from messages import MessageDecoder, Request +from messages import MessageDecoder, Request, ReceivedRequest from protocols import PeerProtocol +from utils import Callback, KeyQueue -logger = logging.getLogger(__name__) +logger = logging.getLogger("connections") class CallbackManager: def __init__(self): @@ -113,24 +114,31 @@ class temp: # Sending api functions class Connection: - def __init__(self, callbacks): - self.callbacks = callbacks + def __init__(self, request_callback=None): + self.request_callback = Callback(request_callback) self.protocol: PeerProtocol = None self.transport = None self._sent_requests = dict() # holds requests awaiting reply - self._recv_requests = asyncio.Queue() + self._recv_requests = KeyQueue() - def connect(self, protocol: PeerProtocol, transport): - self.protocol = protocol + def connect(self, transport, protocol: PeerProtocol): self.transport = transport + self.protocol = protocol + + self.protocol.message_callback.add(self._receive) + self.protocol.disconnected_callback.add(self._on_disconnect) + + def _on_disconnect(self, protocol): + self._recv_requests.cancel() async def close(self): self.transport.close() await self.protocol.closed - async def send(self, request: Request): + async def send(self, request: Request) -> messages.MessageDecoder: + request.set_chain_id() self._sent_requests[request.chain_id] = request try: @@ -144,40 +152,94 @@ class Connection: else: return response - async def send_reply(self, response: messages.Response): - await self.protocol.send(response) + async def send_request(self, name, args=(), kwargs=None, callback=None): + request = Request(name, args, kwargs, callback) + return await self.send(request) - async def receive(self): - if self._recv_waiter is not None: - raise RuntimeError("receive() is already being awaited") - - self._recv_waiter = asyncio.create_task(self._recv_requests.get()) + async def send_response(self, request, data=None, progress=None, err=None): + pass + async def receive(self, name: str =None) -> messages.ReceivedRequest: try: - return await self._recv_waiter + return await self._recv_requests.get(key=name) except asyncio.CancelledError: raise exceptions.ConnectionClosedError - finally: - self._recv_waiter = None - async def _receive(self): - msg = await self.protocol.receive_message() + def _receive(self, protocol, msg): if msg.message_type == messages.MessageTypes.RESPONSE: - await self._process_response(msg) + self._process_response(msg) elif msg.message_type == messages.MessageTypes.REQUEST: - pass + self._process_request(msg) + else: + logger.warning(f"Got unknown message type {msg.message_type} for message {msg}") + return True # put msg in protocol rcv queue + return False # don't put msg in protocol rcv queue + + def _process_response(self, response: MessageDecoder): + logger.debug(f"Received response {response} with id {response.chain_id}") - async def _process_response(self, msg: MessageDecoder): try: - request = self._sent_requests[msg.chain_id] + request = self._sent_requests.pop(response.chain_id) except KeyError: - logger.error(f"Unknown response {msg} with id {msg.chain_id}") + logger.error(f"Unexpected response {response} with id {response.chain_id}") return - request.response.set_result(msg) + try: + response_type = response.header["response-type"] + except KeyError: + logger.error("Missing 'response-type' header in response") + return + + if response_type == messages.ResponseTypes.RESULT: + result = response.content + request.response.set_result(result) + elif response_type == messages.ResponseTypes.ERROR: # TODOOOO + error = None + request.response.set_result(error) + elif response_type == messages.ResponseTypes.STATUS: + pass + else: + logger.error(f"Unknown response type '{response_type}' in response {response}") + + def _process_request(self, request_msg): + request = ReceivedRequest(self, request_msg) + logger.debug(f"Received request {request_msg} with id {request_msg.chain_id} for {request.name}") + + put_msg = self.request_callback.gather_bool(self, request) + if put_msg: + self._recv_requests.put_nowait(request, key=request.name) + +async def connect(port, host, connection_factory=Connection, protocol_factory=PeerProtocol, **kwargs): + loop = asyncio.get_event_loop() + + transport, protocol = await loop.create_connection(protocol_factory, host=host, port=port, **kwargs) + connection = connection_factory() + connection.connect(transport, protocol) + return connection + +async def create_server(port, host="", connection_factory=Connection, protocol_factory=PeerProtocol, + reuse_connections=True, **kwargs): + loop = asyncio.get_event_loop() + new_connections = asyncio.Queue() + + def factory(addr=None): + nonlocal new_connections + protocol = protocol_factory() + + def callback(p): + nonlocal new_connections + connection = connection_factory() + connection.connect(p.transport, p) + new_connections.put_nowait(connection) + + protocol.connected_callback.add(callback) + + return protocol + + server = await loop.create_server(factory, host=host, port=port, **kwargs) + + return server, new_connections.get() - async def _process_request(self): - pass class LegacyConnectionManager: """ @@ -206,14 +268,6 @@ class LegacyConnectionManager: def _process_response(self, message): - request_id, requested_value = message.jsonheader["request_id"], message.jsonheader["requested_value"] - - with self._request_lock: - request = self._request_queue.pop(request_id, None) - if (request is None) or (request.requested_value != requested_value): - logger.warning("Unexpected response!") - return - if requested_value == "filetransfer": value = True self._process_filetransfer(message.content, request.callback_kwargs["filepath"]) @@ -283,17 +337,6 @@ class LegacyConnectionManager: self.get_response("filetransfer", callback, request_kwargs=request_kwargs, callback_args=callback_args, callback_kwargs=callback_kwargs) - def _resend_requests(self): - for request_id, request in self._request_queue.items(): # TODO filter - if request.resend: - self._send(Message.create_request( - request.requested_value, request_id, request.request_kwargs.update(resend=request.resend)) - ) - request.resend = False - - def _send_response(self, requested_value, request_id, value, filetransfer=False): - self._send(Message.create_response(requested_value, request_id, value, filetransfer)) - def send_file(self, filepath, dest_filepath): # clever_restart=False """ Sends to peer a file from `filepath` to write it on `dest_filepath`. diff --git a/lib/lib.py b/lib/lib.py deleted file mode 100644 index d3bfc11..0000000 --- a/lib/lib.py +++ /dev/null @@ -1,3 +0,0 @@ - -def b_partial(func, *args, **kwargs): # call argument blocker partial - return lambda *a: func(*args, **kwargs) diff --git a/lib/messages.py b/lib/messages.py index d8c3acd..0de5858 100644 --- a/lib/messages.py +++ b/lib/messages.py @@ -45,6 +45,7 @@ class ContentTypes: ENCODED = "encoded" class MessageTypes: + MESSAGE = "message" REQUEST = "request" RESPONSE = "response" @@ -159,11 +160,14 @@ class MessageDecoder: class MessageEncoder: def __init__(self, codec=default_codec): self.chain_id = None - self.codec = codec - def encode(self, *args, **kwargs): - return self.encode_raw_message(*args, **kwargs) + def set_chain_id(self, chain_id=None): + if chain_id is None: + if self.chain_id is None: + self.chain_id = uuid.uuid4().hex + else: + self.chain_id = chain_id def encode_raw_message(self, content: bytes, content_type, message_type, chain_id=None, additional_headers=None): """Returns encoded message in bytes. It is recommended use other encoding functions for general purposes. @@ -177,19 +181,21 @@ class MessageEncoder: Returns: bytes: encoded message """ + # if chain_id is None: + # if self.chain_id is not None: + # chain_id = self.chain_id + # else: + # chain_id = uuid.uuid4().hex + # else: + # self.chain_id = chain_id + self.set_chain_id(chain_id) - if chain_id is None: - chain_id = uuid.uuid4().int - elif self.chain_id is not None: - chain_id = self.chain_id - else: - self.chain_id = chain_id header = { "content-length": len(content), "content-type": content_type, "message-type": message_type, - "chain-id": chain_id + "chain-id": self.chain_id } if additional_headers: header.update(additional_headers) @@ -256,7 +262,7 @@ class MessageEncoder: # "response", additional_headers=headers) # return message -class PendingMessage(MessageEncoder): +class AbstractPendingMessage(MessageEncoder): def __init__(self, codec=default_codec): super().__init__(codec) self._sent = asyncio.Future() @@ -265,24 +271,44 @@ class PendingMessage(MessageEncoder): def sent(self): return self._sent - async def send(self, connection): - if self._sent: + async def send_to(self, connection): + if self._sent.done(): raise RuntimeError("This message was already sent, create another one") - return connection.send(self) - -class Response(PendingMessage): - def __init__(self, chain_id, result, type, codec=default_codec): - super().__init__(codec) - - self._chain_id = chain_id - self._result = result + return await connection.send(self) def encode(self): - contents = {"value": self._result} - return self.encode_message(contents, MessageTypes.RESPONSE, chain_id=self._chain_id) + raise NotImplementedError + +class PendingMessage(AbstractPendingMessage): + def __init__(self, content, encode_content=True, codec=default_codec): + super().__init__(codec) + self._content = content + self.encode_content = encode_content + + def encode(self): + if self.encode_content: + content = self.codec.encode(self._content) + content_type = ContentTypes.ENCODED + else: + content = self._content + content_type = ContentTypes.BINARY + return self.encode_raw_message(content, content_type, MessageTypes.MESSAGE) -class Request(PendingMessage): +class Response(AbstractPendingMessage): + def __init__(self, chain_id, result, response_type, codec=default_codec): + super().__init__(codec) + + self.chain_id = chain_id + self._result = result + self._type = response_type + + def encode(self): + return self.encode_message(self._result, MessageTypes.RESPONSE, # chain_id=self.chain_id, + additional_headers={"response-type": self._type}) + + +class Request(AbstractPendingMessage): def __init__(self, name, args=(), kwargs=None, callback=None, codec=default_codec): super().__init__(codec) @@ -295,11 +321,24 @@ class Request(PendingMessage): self.callback = callback self._response = asyncio.Future() + #self.responses ? + + self._progress = float('nan') + self._got_progress = asyncio.Future() @property def response(self): return self._response + @property + async def progress(self): + await self._got_progress + return self._progress + + @progress.setter + def progress(self, value): + pass + def encode(self): contents = {"name": self._name, "args": self._args, @@ -325,7 +364,7 @@ class RequestBatch: async def send(self): for connection, request in self._request_dict.items(): - connection.send(request) + connection.send_to(request) @property def request_dict(self): @@ -345,11 +384,40 @@ class ReceivedRequest: self.connection = connection self.message: MessageDecoder = message + @property + def name(self): + return self.message.content["name"] + + @property + def args(self): + return self.message.content.get("args", list()) + + @property + def kwargs(self): + return self.message.content.get("kwargs", dict()) + + async def send(self, reply: Response): + await self.connection.protocol.send(reply) + await reply.sent + async def reply(self, data): - reply = Response(self.message.chain_id) + reply = Response(self.message.chain_id, data, ResponseTypes.RESULT) + + await self.send(reply) + return reply async def reply_processing(self, progress: float=0): - progress = max(0, min(1, progress)) + progress = max(0.0, min(1.0, progress)) + #contents = {"progress": progress} + reply = Response(self.message.chain_id, progress, ResponseTypes.RESULT) + + await self.send(reply) + return reply async def reply_error(self, err: Exception): - pass + #contents = {"error": err} + reply = Response(self.message.chain_id, err, ResponseTypes.ERROR) + + await self.send(reply) + return reply + diff --git a/lib/protocols.py b/lib/protocols.py index 63c3469..a2fdc09 100644 --- a/lib/protocols.py +++ b/lib/protocols.py @@ -5,8 +5,9 @@ import messages import exceptions from network import str_peername +from utils import Callback -logger = logging.getLogger(__name__) +logger = logging.getLogger("protocols") class BroadcastProtocol: @@ -29,7 +30,7 @@ class BroadcastProtocol: self._closed.set() def datagram_received(self, data, addr): - message = message.MessageDecoder(data) + message = messages.MessageDecoder(data) message.process_message() content = message.content @@ -46,8 +47,10 @@ class BroadcastProtocol: logger.warning(f"Error on broadcast connection received: {exc}") class PeerProtocol(asyncio.Protocol): - def __init__(self, connected_callback, callbacks): - self._callbacks = callbacks + def __init__(self, connected_callback=None, disconnected_callback=None, message_callback=None): + self.connected_callback = Callback(connected_callback) + self.disconnected_callback = Callback(disconnected_callback) + self.message_callback = Callback(message_callback) self.transport: asyncio.Transport = None @@ -83,6 +86,10 @@ class PeerProtocol(asyncio.Protocol): def closed(self): return self._closed.wait() + async def close(self): + self.transport.close() + await self.closed + # Drain control def pause_writing(self) -> None: @@ -94,7 +101,7 @@ class PeerProtocol(asyncio.Protocol): async def drain(self) -> None: await self._can_write.wait() - async def send(self, msg: messages.PendingMessage): + async def send(self, msg: messages.AbstractPendingMessage): if not self.is_connected: # and not send_disconnected msg.sent.cancel("Peer is disconnected, can't send") raise RuntimeError("Peer is disconnected, can't send") @@ -110,7 +117,7 @@ class PeerProtocol(asyncio.Protocol): while self.is_connected: msg = None try: - msg: messages.PendingMessage = await self._send_queue.get() + msg: messages.AbstractPendingMessage = await self._send_queue.get() self.transport.write(msg.encode()) await self.drain() except asyncio.CancelledError: @@ -130,6 +137,8 @@ class PeerProtocol(asyncio.Protocol): self._can_write.set() + self.connected_callback(self) + def connection_lost(self, exc): logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}") self._connected.clear() @@ -149,6 +158,8 @@ class PeerProtocol(asyncio.Protocol): self._recv_buffer = bytearray() self._current_msg = None + self.disconnected_callback(self) + def error_received(self, exc): logger.warning(f"Error on {str_peername(self.peername)} connection received: {exc}") @@ -171,8 +182,11 @@ class PeerProtocol(asyncio.Protocol): self._recv_buffer = self._current_msg.get_buffer() # self._current_msg.reset_buffer() if self._current_msg.processed: - logger.debug(f"Recieved message {self._current_msg.content} from {str_peername(self.peername)}") - await self._recv_queue.put(self._current_msg) + logger.debug(f"Received message {self._current_msg.content} from {str_peername(self.peername)}") + + put_msg = self.message_callback.gather_bool(self, self._current_msg) + if put_msg: + await self._recv_queue.put(self._current_msg) self._current_msg = None else: await asyncio.sleep(0) diff --git a/lib/utils.py b/lib/utils.py new file mode 100644 index 0000000..03b038c --- /dev/null +++ b/lib/utils.py @@ -0,0 +1,277 @@ +import collections +import asyncio + +logger_format = "%(asctime)s [%(name)-11.11s] [%(levelname)-7.7s] %(message)s" + +def b_partial(func, *args, **kwargs): # call argument blocker partial + return lambda *a: func(*args, **kwargs) + + +class Callback: + def __init__(self, *callbacks): + self._callbacks = set() + self._add_callbacks(callbacks) + + @property + def all(self): + return self._callbacks.copy() + + def add(self, callback): + self._add_callback(callback) + + def remove(self, callback): + self._callbacks.remove(callback) + + def _add_callbacks(self, callbacks): + for callback in callbacks: + self._add_callback(callback) + + def _add_callback(self, callback): + if callback is None: + return + elif isinstance(callback, Callback): + self._callbacks.update(callback.all) + elif isinstance(callback, (list, set)): + self._add_callbacks(callback) + elif callable(callback): + self._callbacks.add(callback) + else: + raise ValueError(f"Callback {callback} is not callable object!") + + def __call__(self, *args, **kwargs): + return list(self.iter(*args, **kwargs)) + + def iter(self, *args, **kwargs): + for callback in self._callbacks: + yield callback(*args, **kwargs) + + def gather_bool(self, *args, **kwargs): + if self.all: + return all(self(*args, **kwargs)) + return True + +class KeyQueue: + """A queue, useful for coordinating producer and consumer coroutines. + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "await put()" will block when the + queue reaches maxsize, until an item is removed by get(). + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded asyncio application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._loop = asyncio.events.get_event_loop() + + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Futures. + self._putters = collections.deque() + self._unfinished_tasks = 0 + self._finished = asyncio.Event(loop=self._loop) + self._finished.set() + self._init(maxsize) + + # These three are overridable in subclasses. + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self, key=None): + if key is None: + return self._queue.popleft()[1] + + for pair in self._queue: + k, item = pair + if k == key: + self._queue.remove(pair) + return item + else: + raise KeyError(f"No item with key {key} in queue") + + def _put(self, item, key=None): + self._queue.append((key, item)) + + # End of the overridable methods. + + def _wakeup_next(self, waiters): + # Wake up the next waiter (if any) that isn't cancelled. + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + break + + def _wakeup_get(self, key=None): + if key is None: + self._wakeup_next(self._getters) + + for pair in self._getters: + k, waiter = pair + if k == key and not waiter.done(): + waiter.set_result(None) + self._getters.remove(pair) + break + + def __repr__(self): + return f'<{type(self).__name__} at {id(self):#x} {self._format()}>' + + def __str__(self): + return f'<{type(self).__name__} {self._format()}>' + + def __class_getitem__(cls, type): + return cls + + def _format(self): + result = f'maxsize={self._maxsize!r}' + if getattr(self, '_queue', None): + result += f' _queue={list(self._queue)!r}' + if self._getters: + result += f' _getters[{len(self._getters)}]' + if self._putters: + result += f' _putters[{len(self._putters)}]' + if self._unfinished_tasks: + result += f' tasks={self._unfinished_tasks}' + return result + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() >= self._maxsize + + async def put(self, item, key=None): + """Put an item into the queue. + + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + """ + while self.full(): + putter = self._loop.create_future() + self._putters.append(putter) + try: + await putter + except: + putter.cancel() # Just in case putter is not done yet. + try: + # Clean self._putters from canceled putters. + self._putters.remove(putter) + except ValueError: + # The putter could be removed from self._putters by a + # previous get_nowait call. + pass + if not self.full() and not putter.cancelled(): + # We were woken up by get_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_next(self._putters) + raise + return self.put_nowait(item, key) + + def put_nowait(self, item, key=None): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + """ + if self.full(): + raise asyncio.QueueFull + self._put(item, key) + self._unfinished_tasks += 1 + self._finished.clear() + self._wakeup_get(key) + + async def get(self, key=None): + """Remove and return an item from the queue. + + If queue is empty, wait until an item is available. + """ + while self.empty(): + getter = self._loop.create_future() + key_getter = key, getter + self._getters.append(key_getter) + try: + await getter + except: + getter.cancel() # Just in case getter is not done yet. + try: + # Clean self._getters from canceled getters. + self._getters.remove(key_getter) + except ValueError: + # The getter could be removed from self._getters by a + # previous put_nowait call. + pass + if not self.empty() and not getter.cancelled(): + # We were woken up by put_nowait(), but can't take + # the call. Wake up the next in line. + self._wakeup_get(key) + raise + return self.get_nowait(key) + + def get_nowait(self, key=None): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + """ + if self.empty(): + raise asyncio.QueueEmpty + item = self._get(key) + self._wakeup_next(self._putters) + return item + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + async def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + await self._finished.wait() + + def cancel(self): + # self._finished.set() + # self._unfinished_tasks = 0 + while self._getters: + pair = self._getters.popleft() + key, getter = pair + getter.cancel() diff --git a/tests/messaging_test.py b/tests/messaging_test.py index 5a9edaf..485a897 100644 --- a/tests/messaging_test.py +++ b/tests/messaging_test.py @@ -3,13 +3,13 @@ import sys import pytest current_dir = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, os.path.realpath(os.path.join(current_dir, os.pardir))) - -from lib.messaging import CallbackManager - -def test_callback_registration(): - def f(arg): - return arg - c = CallbackManager() - c.action_callback("123")(f) - t = "test" - assert t == c.action_callbacks["123"](t) +# +# from lib.mes import CallbackManager +# +# def test_callback_registration(): +# def f(arg): +# return arg +# c = CallbackManager() +# c.action_callback("123")(f) +# t = "test" +# assert t == c.action_callbacks["123"](t) diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 0000000..968c7e8 --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,28 @@ +import asyncio + +import pytest + +from lib.utils import KeyQueue + +def test_queue_sync(): + q = KeyQueue() + val1 = "test1" + val2 = "test2" + key1 = "key1" + q.put_nowait(val1) + q.put_nowait(val2, key1) + + assert q.get_nowait() == val1 + assert q.get_nowait(key1) == val2 + +@pytest.mark.asyncio +async def test_queue_putting_async(): + q = KeyQueue() + val1 = "test1" + val2 = "test2" + key1 = "key1" + await q.put(val1) + await q.put(val2, key1) + + assert await q.get(key1) == val2 + assert await q.get() == val1