"""TCP test fixtures (module ``pyicat_plus.tests.fixtures.tcp``)."""
import time
import json
import socket
from contextlib import contextmanager
import xmltodict
from .misc import eprint
from ...concurrency import Queue
from ...concurrency import spawn
from ...concurrency import GEVENT_PATCHED
[docs]
def get_open_port():
    """Return a port number that the operating system reports as free.

    A throw-away socket is bound to port ``0`` so the OS chooses an unused
    port; the chosen number is read back and the socket is closed again.
    The port is therefore only *likely* to still be free when the caller
    uses it.
    """
    with socket.socket() as probe:
        probe.bind(("", 0))
        free_port = probe.getsockname()[1]
    return free_port
[docs]
def wait_tcp_online(host, port, timeout=5):
    """Wait until a TCP connection to ``(host, port)`` can be established.

    Connection attempts are repeated until one succeeds or until `timeout`
    seconds have elapsed in total.

    :param host: host name or address to connect to.
    :param port: TCP port to connect to.
    :param timeout: overall time budget in seconds; ``None`` waits forever.
    :raises socket.timeout: when the port did not accept a connection in time.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        if deadline is None:
            remaining = None
        else:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise socket.timeout(
                    f"TCP port {host}:{port} not online after {timeout} seconds"
                )
        # A socket whose connect() failed is not portably reusable, so every
        # attempt gets a fresh one.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Never let a single attempt outlive the overall deadline.
            sock.settimeout(remaining)
            sock.connect((host, port))
            return
        except ConnectionError:
            # Nobody is listening (yet): retry until the deadline expires.
            pass
        finally:
            sock.close()
        # Short pause so that refused connections do not become a busy loop.
        time.sleep(0.05)
[docs]
@contextmanager
def tcp_message_server(data_parser=None, validate_all=True, timeout=5):
    """Start a TCP server and yield its port and a queue of received messages.

    The server listens on ``localhost`` on a port chosen by the operating
    system and accepts a single client connection. Data packages are
    separated by newline characters. Each package is decoded as follows:

    * ``data_parser == "json"``: parsed with `json.loads`;
    * otherwise, when the package contains the ICAT XML namespace
      declaration: parsed with `xmltodict.parse`;
    * otherwise: decoded to `str` (UTF8).

    :param data_parser: ``"json"`` to parse every package as JSON.
    :param validate_all: on exit, print every message still in the queue
        as an "Unvalidated message".
    :param timeout: socket timeout in seconds used while waiting for a
        client and while waiting for data.
    :yields: tuple ``(port, messages)``.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
        # Bind to port 0 and read the port back: this avoids the race of
        # probing for a free port with one socket and binding another.
        sock.bind(("localhost", 0))
        port = sock.getsockname()[1]
        sock.listen()

        messages = Queue()
        stop = False

        def parse(bdata):
            # Decode one newline-delimited package.
            if data_parser == "json":
                return json.loads(bdata)
            if b'xmlns:tns="http://www.esrf.fr/icat"' in bdata:
                return xmltodict.parse(
                    bdata.decode(),
                    process_namespaces=True,
                    namespaces={"http://www.esrf.fr/icat": None},
                )
            return bdata.decode()

        def listener():
            buffer = b""
            conn = None
            try:
                # Wait for the client, giving up as soon as the server is
                # asked to stop (otherwise this would never end when no
                # client ever connects).
                while conn is None:
                    if stop:
                        return
                    try:
                        conn, _ = sock.accept()
                    except socket.timeout:
                        time.sleep(0.1)
                # The accepted socket does not inherit the timeout of the
                # listening socket; without it recv() could block forever.
                conn.settimeout(timeout)
                while True:
                    peer_gone = False
                    try:
                        chunk = conn.recv(16384)
                    except ConnectionResetError:
                        peer_gone = True
                    except socket.timeout:
                        pass
                    else:
                        if chunk:
                            buffer += chunk
                        else:
                            # Empty read: the client closed the connection.
                            peer_gone = True
                    if buffer:
                        # Keep the incomplete trailing package in the buffer.
                        out, sep, buffer = buffer.rpartition(b"\n")
                        if sep:
                            for bdata in out.split(b"\n"):
                                messages.put(parse(bdata))
                    if stop or peer_gone:
                        return
                    time.sleep(0.1)
            finally:
                if conn is not None:
                    conn.close()

        glistener = spawn(listener)
        try:
            yield port, messages
        finally:
            # Sentinel marking the end of the messages to report below.
            messages.put(StopIteration)
            stop = True
            try:
                if validate_all:
                    for msg in iter(messages.get, StopIteration):
                        eprint(f"Unvalidated message: {msg}")
            finally:
                if GEVENT_PATCHED:
                    glistener.kill()
                else:
                    glistener.join()
    finally:
        sock.close()