Fix reading from socket extracted from HTTPResponse

HTTPResponse access the underlying socket via an io.BufferedReader
object, that may still have "unread" bytes in its internal buffer when
the socket is extracted from HTTPResponse.

Preserve those bytes in an attribute of the socket object at the time
of the extraction so that they are not lost.

Signed-off-by: Sergei Trofimov <sergei.trofimov@arm.com>
This commit is contained in:
Sergei Trofimov 2019-11-21 06:52:46 +00:00
parent a0b9c3d0b3
commit b22baa0c19
2 changed files with 30 additions and 0 deletions

View File

@ -329,6 +329,12 @@ class APIClient(
# fine because we won't be doing TLS over them # fine because we won't be doing TLS over them
pass pass
if six.PY3:
# Preserve the io.BufferedReader buffer as part of the socket so
# that it may be read ahead of attempting futher reads from the
# socket -- see ..utils.socket.read() implementation.
setattr(sock, '_buffer', memoryview(response.raw._fp.fp.peek()))
return sock return sock
def _stream_helper(self, response, decode=False): def _stream_helper(self, response, decode=False):

View File

@ -25,6 +25,30 @@ def read(socket, n=4096):
Reads at most n bytes from socket Reads at most n bytes from socket
""" """
# This socket may have been extracted from an HTTPResponse. That would have
# wrapped the socked with an io.BufferedReader, which may still have some
# "unread" bytes in its internal buffer. If that were the case, those would
# be preserved inside _buffer attribute of the socket (see
# ..api.client.APIClient._get_raw_response_socket() implementation). These
# must be returned firsts, before attemting any further reads from the
# socket.
if not hasattr(socket, '_buffer') or not socket._buffer:
return read_from_socket(socket, n)
ret, socket._buffer = socket._buffer[:n], socket._buffer[n:]
still_to_read = n - len(ret)
if still_to_read:
return ret.tobytes() + read_from_socket(socket, still_to_read)
else:
return ret.tobytes()
def read_from_socket(socket, n=4096):
"""
Reads at most n bytes from socket
"""
recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK) recoverable_errors = (errno.EINTR, errno.EDEADLK, errno.EWOULDBLOCK)
if six.PY3 and not isinstance(socket, NpipeSocket): if six.PY3 and not isinstance(socket, NpipeSocket):