add tests for _read_from_socket

Check that the return value against the various combination of
parameters this function can take (tty, stream, and demux).

This commit also fixes a bug that the tests uncovered a bug in
consume_socket_output.

Signed-off-by: Corentin Henry <corentinhenry@gmail.com>
This commit is contained in:
Corentin Henry 2018-11-27 17:01:06 -08:00
parent 5f157bbaca
commit 6540900dae
2 changed files with 90 additions and 22 deletions

View File

@ -136,15 +136,17 @@ def consume_socket_output(frames, demux=False):
# we just need to concatenate.
return six.binary_type().join(frames)
# If the streams are demultiplexed, the generator returns tuples
# (stdin, stdout, stderr)
# If the streams are demultiplexed, the generator yields tuples
# (stdout, stderr)
out = [six.binary_type(), six.binary_type()]
for frame in frames:
for stream_id in [STDOUT, STDERR]:
# It is guaranteed that for each frame, one and only one stream
# is not None.
if frame[stream_id] is not None:
out[stream_id] += frame[stream_id]
# It is guaranteed that for each frame, one and only one stream
# is not None.
assert frame != (None, None)
if frame[0] is not None:
out[0] += frame[0]
else:
out[1] += frame[1]
return tuple(out)

View File

@ -15,6 +15,7 @@ from docker.api import APIClient
import requests
from requests.packages import urllib3
import six
import struct
from . import fake_api
@ -467,24 +468,25 @@ class UnixSocketStreamTest(unittest.TestCase):
class TCPSocketStreamTest(unittest.TestCase):
text_data = b'''
stdout_data = b'''
Now, those children out there, they're jumping through the
flames in the hope that the god of the fire will make them fruitful.
Really, you can't blame them. After all, what girl would not prefer the
child of a god to that of some acne-scarred artisan?
'''
stderr_data = b'''
And what of the true God? To whose glory churches and monasteries have been
built on these islands for generations past? Now shall what of Him?
'''
def setUp(self):
self.server = six.moves.socketserver.ThreadingTCPServer(
('', 0), self.get_handler_class()
)
('', 0), self.get_handler_class())
self.thread = threading.Thread(target=self.server.serve_forever)
self.thread.setDaemon(True)
self.thread.start()
self.address = 'http://{}:{}'.format(
socket.gethostname(), self.server.server_address[1]
)
socket.gethostname(), self.server.server_address[1])
def tearDown(self):
self.server.shutdown()
@ -492,31 +494,95 @@ class TCPSocketStreamTest(unittest.TestCase):
self.thread.join()
def get_handler_class(self):
text_data = self.text_data
stdout_data = self.stdout_data
stderr_data = self.stderr_data
class Handler(six.moves.BaseHTTPServer.BaseHTTPRequestHandler, object):
def do_POST(self):
resp_data = self.get_resp_data()
self.send_response(101)
self.send_header(
'Content-Type', 'application/vnd.docker.raw-stream'
)
'Content-Type', 'application/vnd.docker.raw-stream')
self.send_header('Connection', 'Upgrade')
self.send_header('Upgrade', 'tcp')
self.end_headers()
self.wfile.flush()
time.sleep(0.2)
self.wfile.write(text_data)
self.wfile.write(resp_data)
self.wfile.flush()
def get_resp_data(self):
path = self.path.split('/')[-1]
if path == 'tty':
return stdout_data + stderr_data
elif path == 'no-tty':
data = b''
data += self.frame_header(1, stdout_data)
data += stdout_data
data += self.frame_header(2, stderr_data)
data += stderr_data
return data
else:
raise Exception('Unknown path {0}'.format(path))
@staticmethod
def frame_header(stream, data):
return struct.pack('>BxxxL', stream, len(data))
return Handler
def test_read_from_socket(self):
def request(self, stream=None, tty=None, demux=None):
assert stream is not None and tty is not None and demux is not None
with APIClient(base_url=self.address) as client:
resp = client._post(client._url('/dummy'), stream=True)
data = client._read_from_socket(resp, stream=True, tty=True)
results = b''.join(data)
if tty:
url = client._url('/tty')
else:
url = client._url('/no-tty')
resp = client._post(url, stream=True)
return client._read_from_socket(
resp, stream=stream, tty=tty, demux=demux)
assert results == self.text_data
def test_read_from_socket_1(self):
res = self.request(stream=True, tty=True, demux=False)
assert next(res) == self.stdout_data + self.stderr_data
with self.assertRaises(StopIteration):
next(res)
def test_read_from_socket_2(self):
res = self.request(stream=True, tty=True, demux=True)
assert next(res) == (self.stdout_data + self.stderr_data, None)
with self.assertRaises(StopIteration):
next(res)
def test_read_from_socket_3(self):
res = self.request(stream=True, tty=False, demux=False)
assert next(res) == self.stdout_data
assert next(res) == self.stderr_data
with self.assertRaises(StopIteration):
next(res)
def test_read_from_socket_4(self):
res = self.request(stream=True, tty=False, demux=True)
assert (self.stdout_data, None) == next(res)
assert (None, self.stderr_data) == next(res)
with self.assertRaises(StopIteration):
next(res)
def test_read_from_socket_5(self):
res = self.request(stream=False, tty=True, demux=False)
assert res == self.stdout_data + self.stderr_data
def test_read_from_socket_6(self):
res = self.request(stream=False, tty=True, demux=True)
assert res == (self.stdout_data + self.stderr_data, b'')
def test_read_from_socket_7(self):
res = self.request(stream=False, tty=False, demux=False)
res == self.stdout_data + self.stderr_data
def test_read_from_socket_8(self):
res = self.request(stream=False, tty=False, demux=True)
assert res == (self.stdout_data, self.stderr_data)
class UserAgentTest(unittest.TestCase):