Request\response processing

This commit is contained in:
Artem30801
2020-12-02 09:49:44 +03:00
parent 1cef868938
commit 70ec36b992
11 changed files with 590 additions and 104 deletions

View File

@@ -7,7 +7,7 @@ import logging
logging.basicConfig( # TODO all prints as logs
level=logging.DEBUG, # INFO
stream=sys.stdout,
# stream=sys.stdout,
format="%(asctime)s [%(name)-7.7s] [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s",
handlers=[
logging.StreamHandler(sys.stdout),

View File

@@ -0,0 +1,44 @@
import asyncio
import logging
import lib.connections as c
import lib.messages as messages
import lib.utils as utils
logging.basicConfig(
level=logging.DEBUG,
format=utils.logger_format,
handlers=[
logging.StreamHandler()
])
async def server():
server, new_connections = await c.create_server(8181)
connection1: c.Connection = await new_connections
#await connection1.send(messages.Request("greetings"))
await asyncio.sleep(2)
t1 = asyncio.create_task(asyncio.wait_for(connection1.send(messages.Request("greetings")), 5))
t2 = asyncio.create_task(connection1.send(messages.Request("hello", args=(1, 3))))
print(await asyncio.gather(t1, t2, return_exceptions=True))
await connection1.close()
server.close()
await server.wait_closed()
async def client():
connection = await c.connect(8181, "")
request = await connection.receive("hello")
print(request.message.content)
request2 = await connection.receive()
print(request2.message.content)
await request.reply("hi")
# await connection.close()
async def main():
s = asyncio.create_task(server())
c = asyncio.create_task(client())
await asyncio.gather(s, c)
asyncio.run(main(), debug=False)

View File

@@ -13,9 +13,25 @@ logging.basicConfig(
async def server():
loop = asyncio.get_event_loop()
clients = asyncio.Queue()
server = await loop.create_server(lambda: p.PeerProtocol(None, None), host='', port=8181)
def factory(*args):
protocol = p.PeerProtocol()
clients.put_nowait(protocol)
return protocol
server = await loop.create_server(factory, host='', port=8181)
client1 = await clients.get()
await client1.connected
msg = await client1.receive_message()
print(msg.header, msg.content)
msg = messages.PendingMessage(msg.content + " from server")
await msg.send_to(client1)
await client1.closed
server.close()
await server.wait_closed()
async def client():
@@ -23,9 +39,11 @@ async def client():
transport, protocol = await loop.create_connection(lambda: p.PeerProtocol(None, None),
host='', port=8181)
await protocol.send(messages.Request("greetings"))
await protocol.closed
await protocol.send(messages.PendingMessage("greetings"))
print((await protocol.receive_message()).content)
await protocol.close()
async def main():
s = asyncio.create_task(server())

View File

@@ -14,10 +14,7 @@ def modify_filename(path, pattern): # TODO move to core
def parent_path(path, levels=1):
for i in range(levels):
path = os.path.abspath(os.path.join(path, os.pardir))
return path
return os.path.realpath(os.path.join(path, *[os.pardir for _ in range(levels)]))
def parent_dir(path):
return os.path.basename(os.path.normpath(path))

View File

@@ -10,10 +10,11 @@ import traceback
import messages
import exceptions
from messages import MessageDecoder, Request
from messages import MessageDecoder, Request, ReceivedRequest
from protocols import PeerProtocol
from utils import Callback, KeyQueue
logger = logging.getLogger(__name__)
logger = logging.getLogger("connections")
class CallbackManager:
def __init__(self):
@@ -113,24 +114,31 @@ class temp:
# Sending api functions
class Connection:
def __init__(self, callbacks):
self.callbacks = callbacks
def __init__(self, request_callback=None):
self.request_callback = Callback(request_callback)
self.protocol: PeerProtocol = None
self.transport = None
self._sent_requests = dict() # holds requests awaiting reply
self._recv_requests = asyncio.Queue()
self._recv_requests = KeyQueue()
def connect(self, protocol: PeerProtocol, transport):
self.protocol = protocol
def connect(self, transport, protocol: PeerProtocol):
self.transport = transport
self.protocol = protocol
self.protocol.message_callback.add(self._receive)
self.protocol.disconnected_callback.add(self._on_disconnect)
def _on_disconnect(self, protocol):
self._recv_requests.cancel()
async def close(self):
self.transport.close()
await self.protocol.closed
async def send(self, request: Request):
async def send(self, request: Request) -> messages.MessageDecoder:
request.set_chain_id()
self._sent_requests[request.chain_id] = request
try:
@@ -144,40 +152,94 @@ class Connection:
else:
return response
async def send_reply(self, response: messages.Response):
await self.protocol.send(response)
async def send_request(self, name, args=(), kwargs=None, callback=None):
request = Request(name, args, kwargs, callback)
return await self.send(request)
async def receive(self):
if self._recv_waiter is not None:
raise RuntimeError("receive() is already being awaited")
self._recv_waiter = asyncio.create_task(self._recv_requests.get())
async def send_response(self, request, data=None, progress=None, err=None):
pass
async def receive(self, name: str =None) -> messages.ReceivedRequest:
try:
return await self._recv_waiter
return await self._recv_requests.get(key=name)
except asyncio.CancelledError:
raise exceptions.ConnectionClosedError
finally:
self._recv_waiter = None
async def _receive(self):
msg = await self.protocol.receive_message()
def _receive(self, protocol, msg):
if msg.message_type == messages.MessageTypes.RESPONSE:
await self._process_response(msg)
self._process_response(msg)
elif msg.message_type == messages.MessageTypes.REQUEST:
pass
self._process_request(msg)
else:
logger.warning(f"Got unknown message type {msg.message_type} for message {msg}")
return True # put msg in protocol rcv queue
return False # don't put msg in protocol rcv queue
def _process_response(self, response: MessageDecoder):
logger.debug(f"Received response {response} with id {response.chain_id}")
async def _process_response(self, msg: MessageDecoder):
try:
request = self._sent_requests[msg.chain_id]
request = self._sent_requests.pop(response.chain_id)
except KeyError:
logger.error(f"Unknown response {msg} with id {msg.chain_id}")
logger.error(f"Unexpected response {response} with id {response.chain_id}")
return
request.response.set_result(msg)
try:
response_type = response.header["response-type"]
except KeyError:
logger.error("Missing 'response-type' header in response")
return
async def _process_request(self):
if response_type == messages.ResponseTypes.RESULT:
result = response.content
request.response.set_result(result)
elif response_type == messages.ResponseTypes.ERROR: # TODOOOO
error = None
request.response.set_result(error)
elif response_type == messages.ResponseTypes.STATUS:
pass
else:
logger.error(f"Unknown response type '{response_type}' in response {response}")
def _process_request(self, request_msg):
request = ReceivedRequest(self, request_msg)
logger.debug(f"Received request {request_msg} with id {request_msg.chain_id} for {request.name}")
put_msg = self.request_callback.gather_bool(self, request)
if put_msg:
self._recv_requests.put_nowait(request, key=request.name)
async def connect(port, host, connection_factory=Connection, protocol_factory=PeerProtocol, **kwargs):
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_connection(protocol_factory, host=host, port=port, **kwargs)
connection = connection_factory()
connection.connect(transport, protocol)
return connection
async def create_server(port, host="", connection_factory=Connection, protocol_factory=PeerProtocol,
reuse_connections=True, **kwargs):
loop = asyncio.get_event_loop()
new_connections = asyncio.Queue()
def factory(addr=None):
nonlocal new_connections
protocol = protocol_factory()
def callback(p):
nonlocal new_connections
connection = connection_factory()
connection.connect(p.transport, p)
new_connections.put_nowait(connection)
protocol.connected_callback.add(callback)
return protocol
server = await loop.create_server(factory, host=host, port=port, **kwargs)
return server, new_connections.get()
class LegacyConnectionManager:
"""
@@ -206,14 +268,6 @@ class LegacyConnectionManager:
def _process_response(self, message):
request_id, requested_value = message.jsonheader["request_id"], message.jsonheader["requested_value"]
with self._request_lock:
request = self._request_queue.pop(request_id, None)
if (request is None) or (request.requested_value != requested_value):
logger.warning("Unexpected response!")
return
if requested_value == "filetransfer":
value = True
self._process_filetransfer(message.content, request.callback_kwargs["filepath"])
@@ -283,17 +337,6 @@ class LegacyConnectionManager:
self.get_response("filetransfer", callback, request_kwargs=request_kwargs,
callback_args=callback_args, callback_kwargs=callback_kwargs)
def _resend_requests(self):
for request_id, request in self._request_queue.items(): # TODO filter
if request.resend:
self._send(Message.create_request(
request.requested_value, request_id, request.request_kwargs.update(resend=request.resend))
)
request.resend = False
def _send_response(self, requested_value, request_id, value, filetransfer=False):
self._send(Message.create_response(requested_value, request_id, value, filetransfer))
def send_file(self, filepath, dest_filepath): # clever_restart=False
"""
Sends to peer a file from `filepath` to write it on `dest_filepath`.

View File

@@ -1,3 +0,0 @@
def b_partial(func, *args, **kwargs): # call argument blocker partial
return lambda *a: func(*args, **kwargs)

View File

@@ -45,6 +45,7 @@ class ContentTypes:
ENCODED = "encoded"
class MessageTypes:
MESSAGE = "message"
REQUEST = "request"
RESPONSE = "response"
@@ -159,11 +160,14 @@ class MessageDecoder:
class MessageEncoder:
def __init__(self, codec=default_codec):
self.chain_id = None
self.codec = codec
def encode(self, *args, **kwargs):
return self.encode_raw_message(*args, **kwargs)
def set_chain_id(self, chain_id=None):
if chain_id is None:
if self.chain_id is None:
self.chain_id = uuid.uuid4().hex
else:
self.chain_id = chain_id
def encode_raw_message(self, content: bytes, content_type, message_type, chain_id=None, additional_headers=None):
"""Returns encoded message in bytes. It is recommended use other encoding functions for general purposes.
@@ -177,19 +181,21 @@ class MessageEncoder:
Returns:
bytes: encoded message
"""
# if chain_id is None:
# if self.chain_id is not None:
# chain_id = self.chain_id
# else:
# chain_id = uuid.uuid4().hex
# else:
# self.chain_id = chain_id
self.set_chain_id(chain_id)
if chain_id is None:
chain_id = uuid.uuid4().int
elif self.chain_id is not None:
chain_id = self.chain_id
else:
self.chain_id = chain_id
header = {
"content-length": len(content),
"content-type": content_type,
"message-type": message_type,
"chain-id": chain_id
"chain-id": self.chain_id
}
if additional_headers:
header.update(additional_headers)
@@ -256,7 +262,7 @@ class MessageEncoder:
# "response", additional_headers=headers)
# return message
class PendingMessage(MessageEncoder):
class AbstractPendingMessage(MessageEncoder):
def __init__(self, codec=default_codec):
super().__init__(codec)
self._sent = asyncio.Future()
@@ -265,24 +271,44 @@ class PendingMessage(MessageEncoder):
def sent(self):
return self._sent
async def send(self, connection):
if self._sent:
async def send_to(self, connection):
if self._sent.done():
raise RuntimeError("This message was already sent, create another one")
return connection.send(self)
class Response(PendingMessage):
def __init__(self, chain_id, result, type, codec=default_codec):
super().__init__(codec)
self._chain_id = chain_id
self._result = result
return await connection.send(self)
def encode(self):
contents = {"value": self._result}
return self.encode_message(contents, MessageTypes.RESPONSE, chain_id=self._chain_id)
raise NotImplementedError
class PendingMessage(AbstractPendingMessage):
def __init__(self, content, encode_content=True, codec=default_codec):
super().__init__(codec)
self._content = content
self.encode_content = encode_content
def encode(self):
if self.encode_content:
content = self.codec.encode(self._content)
content_type = ContentTypes.ENCODED
else:
content = self._content
content_type = ContentTypes.BINARY
return self.encode_raw_message(content, content_type, MessageTypes.MESSAGE)
class Request(PendingMessage):
class Response(AbstractPendingMessage):
def __init__(self, chain_id, result, response_type, codec=default_codec):
super().__init__(codec)
self.chain_id = chain_id
self._result = result
self._type = response_type
def encode(self):
return self.encode_message(self._result, MessageTypes.RESPONSE, # chain_id=self.chain_id,
additional_headers={"response-type": self._type})
class Request(AbstractPendingMessage):
def __init__(self, name, args=(), kwargs=None, callback=None, codec=default_codec):
super().__init__(codec)
@@ -295,11 +321,24 @@ class Request(PendingMessage):
self.callback = callback
self._response = asyncio.Future()
#self.responses ?
self._progress = float('nan')
self._got_progress = asyncio.Future()
@property
def response(self):
return self._response
@property
async def progress(self):
await self._got_progress
return self._progress
@progress.setter
def progress(self, value):
pass
def encode(self):
contents = {"name": self._name,
"args": self._args,
@@ -325,7 +364,7 @@ class RequestBatch:
async def send(self):
for connection, request in self._request_dict.items():
connection.send(request)
connection.send_to(request)
@property
def request_dict(self):
@@ -345,11 +384,40 @@ class ReceivedRequest:
self.connection = connection
self.message: MessageDecoder = message
@property
def name(self):
return self.message.content["name"]
@property
def args(self):
return self.message.content.get("args", list())
@property
def kwargs(self):
return self.message.content.get("kwargs", dict())
async def send(self, reply: Response):
await self.connection.protocol.send(reply)
await reply.sent
async def reply(self, data):
reply = Response(self.message.chain_id)
reply = Response(self.message.chain_id, data, ResponseTypes.RESULT)
await self.send(reply)
return reply
async def reply_processing(self, progress: float=0):
progress = max(0, min(1, progress))
progress = max(0.0, min(1.0, progress))
#contents = {"progress": progress}
reply = Response(self.message.chain_id, progress, ResponseTypes.RESULT)
await self.send(reply)
return reply
async def reply_error(self, err: Exception):
pass
#contents = {"error": err}
reply = Response(self.message.chain_id, err, ResponseTypes.ERROR)
await self.send(reply)
return reply

View File

@@ -5,8 +5,9 @@ import messages
import exceptions
from network import str_peername
from utils import Callback
logger = logging.getLogger(__name__)
logger = logging.getLogger("protocols")
class BroadcastProtocol:
@@ -29,7 +30,7 @@ class BroadcastProtocol:
self._closed.set()
def datagram_received(self, data, addr):
message = message.MessageDecoder(data)
message = messages.MessageDecoder(data)
message.process_message()
content = message.content
@@ -46,8 +47,10 @@ class BroadcastProtocol:
logger.warning(f"Error on broadcast connection received: {exc}")
class PeerProtocol(asyncio.Protocol):
def __init__(self, connected_callback, callbacks):
self._callbacks = callbacks
def __init__(self, connected_callback=None, disconnected_callback=None, message_callback=None):
self.connected_callback = Callback(connected_callback)
self.disconnected_callback = Callback(disconnected_callback)
self.message_callback = Callback(message_callback)
self.transport: asyncio.Transport = None
@@ -83,6 +86,10 @@ class PeerProtocol(asyncio.Protocol):
def closed(self):
return self._closed.wait()
async def close(self):
self.transport.close()
await self.closed
# Drain control
def pause_writing(self) -> None:
@@ -94,7 +101,7 @@ class PeerProtocol(asyncio.Protocol):
async def drain(self) -> None:
await self._can_write.wait()
async def send(self, msg: messages.PendingMessage):
async def send(self, msg: messages.AbstractPendingMessage):
if not self.is_connected: # and not send_disconnected
msg.sent.cancel("Peer is disconnected, can't send")
raise RuntimeError("Peer is disconnected, can't send")
@@ -110,7 +117,7 @@ class PeerProtocol(asyncio.Protocol):
while self.is_connected:
msg = None
try:
msg: messages.PendingMessage = await self._send_queue.get()
msg: messages.AbstractPendingMessage = await self._send_queue.get()
self.transport.write(msg.encode())
await self.drain()
except asyncio.CancelledError:
@@ -130,6 +137,8 @@ class PeerProtocol(asyncio.Protocol):
self._can_write.set()
self.connected_callback(self)
def connection_lost(self, exc):
logger.info(f"Lost connection to {str_peername(self.peername)}: {'closed' if exc is None else exc}")
self._connected.clear()
@@ -149,6 +158,8 @@ class PeerProtocol(asyncio.Protocol):
self._recv_buffer = bytearray()
self._current_msg = None
self.disconnected_callback(self)
def error_received(self, exc):
logger.warning(f"Error on {str_peername(self.peername)} connection received: {exc}")
@@ -171,7 +182,10 @@ class PeerProtocol(asyncio.Protocol):
self._recv_buffer = self._current_msg.get_buffer()
# self._current_msg.reset_buffer()
if self._current_msg.processed:
logger.debug(f"Recieved message {self._current_msg.content} from {str_peername(self.peername)}")
logger.debug(f"Received message {self._current_msg.content} from {str_peername(self.peername)}")
put_msg = self.message_callback.gather_bool(self, self._current_msg)
if put_msg:
await self._recv_queue.put(self._current_msg)
self._current_msg = None
else:

277
lib/utils.py Normal file
View File

@@ -0,0 +1,277 @@
import collections
import asyncio
logger_format = "%(asctime)s [%(name)-11.11s] [%(levelname)-7.7s] %(message)s"
def b_partial(func, *args, **kwargs): # call argument blocker partial
return lambda *a: func(*args, **kwargs)
class Callback:
def __init__(self, *callbacks):
self._callbacks = set()
self._add_callbacks(callbacks)
@property
def all(self):
return self._callbacks.copy()
def add(self, callback):
self._add_callback(callback)
def remove(self, callback):
self._callbacks.remove(callback)
def _add_callbacks(self, callbacks):
for callback in callbacks:
self._add_callback(callback)
def _add_callback(self, callback):
if callback is None:
return
elif isinstance(callback, Callback):
self._callbacks.update(callback.all)
elif isinstance(callback, (list, set)):
self._add_callbacks(callback)
elif callable(callback):
self._callbacks.add(callback)
else:
raise ValueError(f"Callback {callback} is not callable object!")
def __call__(self, *args, **kwargs):
return list(self.iter(*args, **kwargs))
def iter(self, *args, **kwargs):
for callback in self._callbacks:
yield callback(*args, **kwargs)
def gather_bool(self, *args, **kwargs):
if self.all:
return all(self(*args, **kwargs))
return True
class KeyQueue:
"""A queue, useful for coordinating producer and consumer coroutines.
If maxsize is less than or equal to zero, the queue size is infinite. If it
is an integer greater than 0, then "await put()" will block when the
queue reaches maxsize, until an item is removed by get().
Unlike the standard library Queue, you can reliably know this Queue's size
with qsize(), since your single-threaded asyncio application won't be
interrupted between calling qsize() and doing an operation on the Queue.
"""
def __init__(self, maxsize=0):
self._loop = asyncio.events.get_event_loop()
self._maxsize = maxsize
# Futures.
self._getters = collections.deque()
# Futures.
self._putters = collections.deque()
self._unfinished_tasks = 0
self._finished = asyncio.Event(loop=self._loop)
self._finished.set()
self._init(maxsize)
# These three are overridable in subclasses.
def _init(self, maxsize):
self._queue = collections.deque()
def _get(self, key=None):
if key is None:
return self._queue.popleft()[1]
for pair in self._queue:
k, item = pair
if k == key:
self._queue.remove(pair)
return item
else:
raise KeyError(f"No item with key {key} in queue")
def _put(self, item, key=None):
self._queue.append((key, item))
# End of the overridable methods.
def _wakeup_next(self, waiters):
# Wake up the next waiter (if any) that isn't cancelled.
while waiters:
waiter = waiters.popleft()
if not waiter.done():
waiter.set_result(None)
break
def _wakeup_get(self, key=None):
if key is None:
self._wakeup_next(self._getters)
for pair in self._getters:
k, waiter = pair
if k == key and not waiter.done():
waiter.set_result(None)
self._getters.remove(pair)
break
def __repr__(self):
return f'<{type(self).__name__} at {id(self):#x} {self._format()}>'
def __str__(self):
return f'<{type(self).__name__} {self._format()}>'
def __class_getitem__(cls, type):
return cls
def _format(self):
result = f'maxsize={self._maxsize!r}'
if getattr(self, '_queue', None):
result += f' _queue={list(self._queue)!r}'
if self._getters:
result += f' _getters[{len(self._getters)}]'
if self._putters:
result += f' _putters[{len(self._putters)}]'
if self._unfinished_tasks:
result += f' tasks={self._unfinished_tasks}'
return result
def qsize(self):
"""Number of items in the queue."""
return len(self._queue)
@property
def maxsize(self):
"""Number of items allowed in the queue."""
return self._maxsize
def empty(self):
"""Return True if the queue is empty, False otherwise."""
return not self._queue
def full(self):
"""Return True if there are maxsize items in the queue.
Note: if the Queue was initialized with maxsize=0 (the default),
then full() is never True.
"""
if self._maxsize <= 0:
return False
else:
return self.qsize() >= self._maxsize
async def put(self, item, key=None):
"""Put an item into the queue.
Put an item into the queue. If the queue is full, wait until a free
slot is available before adding item.
"""
while self.full():
putter = self._loop.create_future()
self._putters.append(putter)
try:
await putter
except:
putter.cancel() # Just in case putter is not done yet.
try:
# Clean self._putters from canceled putters.
self._putters.remove(putter)
except ValueError:
# The putter could be removed from self._putters by a
# previous get_nowait call.
pass
if not self.full() and not putter.cancelled():
# We were woken up by get_nowait(), but can't take
# the call. Wake up the next in line.
self._wakeup_next(self._putters)
raise
return self.put_nowait(item, key)
def put_nowait(self, item, key=None):
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise QueueFull.
"""
if self.full():
raise asyncio.QueueFull
self._put(item, key)
self._unfinished_tasks += 1
self._finished.clear()
self._wakeup_get(key)
async def get(self, key=None):
"""Remove and return an item from the queue.
If queue is empty, wait until an item is available.
"""
while self.empty():
getter = self._loop.create_future()
key_getter = key, getter
self._getters.append(key_getter)
try:
await getter
except:
getter.cancel() # Just in case getter is not done yet.
try:
# Clean self._getters from canceled getters.
self._getters.remove(key_getter)
except ValueError:
# The getter could be removed from self._getters by a
# previous put_nowait call.
pass
if not self.empty() and not getter.cancelled():
# We were woken up by put_nowait(), but can't take
# the call. Wake up the next in line.
self._wakeup_get(key)
raise
return self.get_nowait(key)
def get_nowait(self, key=None):
"""Remove and return an item from the queue.
Return an item if one is immediately available, else raise QueueEmpty.
"""
if self.empty():
raise asyncio.QueueEmpty
item = self._get(key)
self._wakeup_next(self._putters)
return item
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each get() used to fetch a task,
a subsequent call to task_done() tells the queue that the processing
on the task is complete.
If a join() is currently blocking, it will resume when all items have
been processed (meaning that a task_done() call was received for every
item that had been put() into the queue).
Raises ValueError if called more times than there were items placed in
the queue.
"""
if self._unfinished_tasks <= 0:
raise ValueError('task_done() called too many times')
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
async def join(self):
"""Block until all items in the queue have been gotten and processed.
The count of unfinished tasks goes up whenever an item is added to the
queue. The count goes down whenever a consumer calls task_done() to
indicate that the item was retrieved and all work on it is complete.
When the count of unfinished tasks drops to zero, join() unblocks.
"""
if self._unfinished_tasks > 0:
await self._finished.wait()
def cancel(self):
# self._finished.set()
# self._unfinished_tasks = 0
while self._getters:
pair = self._getters.popleft()
key, getter = pair
getter.cancel()

View File

@@ -3,13 +3,13 @@ import sys
import pytest
current_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.realpath(os.path.join(current_dir, os.pardir)))
from lib.messaging import CallbackManager
def test_callback_registration():
def f(arg):
return arg
c = CallbackManager()
c.action_callback("123")(f)
t = "test"
assert t == c.action_callbacks["123"](t)
#
# from lib.mes import CallbackManager
#
# def test_callback_registration():
# def f(arg):
# return arg
# c = CallbackManager()
# c.action_callback("123")(f)
# t = "test"
# assert t == c.action_callbacks["123"](t)

28
tests/utils_test.py Normal file
View File

@@ -0,0 +1,28 @@
import asyncio
import pytest
from lib.utils import KeyQueue
def test_queue_sync():
q = KeyQueue()
val1 = "test1"
val2 = "test2"
key1 = "key1"
q.put_nowait(val1)
q.put_nowait(val2, key1)
assert q.get_nowait() == val1
assert q.get_nowait(key1) == val2
@pytest.mark.asyncio
async def test_queue_putting_async():
q = KeyQueue()
val1 = "test1"
val2 = "test2"
key1 = "key1"
await q.put(val1)
await q.put(val2, key1)
assert await q.get(key1) == val2
assert await q.get() == val1