Simplify the json_stream() utility

The previous implementation was generic over the kind of decoder
and separator. However, its only use was with the JSON decoder and
newline-based splitting.

Signed-off-by: Francesco Zardi <frazar0@hotmail.it>
This commit is contained in:
Francesco Zardi 2024-09-14 17:16:05 +02:00
parent c7d4ec1421
commit bcf3e11daa
1 changed file with 9 additions and 21 deletions

View File

@ -1,4 +1,3 @@
import json
import json.decoder import json.decoder
from ..errors import StreamParseError from ..errors import StreamParseError
@ -37,30 +36,12 @@ def json_stream(stream):
This handles streams which are inconsistently buffered (some entries may This handles streams which are inconsistently buffered (some entries may
be newline delimited, and others are not). be newline delimited, and others are not).
""" """
return split_buffer(stream, json_splitter, json_decoder.decode)
def line_splitter(buffer, separator='\n'):
    """Split ``buffer`` just after the first occurrence of ``separator``.

    Args:
        buffer: Text to search.
        separator: Delimiter to look for. It is converted with ``str()``
            before searching.

    Returns:
        ``None`` when the separator does not occur in ``buffer``; otherwise
        a ``(chunk, rest)`` tuple where ``chunk`` ends with the complete
        separator and ``rest`` is everything after it.

    Raises:
        ValueError: If the separator is empty (raised by ``str.partition``).
    """
    head, found, rest = buffer.partition(str(separator))
    if not found:
        return None
    # Keep the *whole* separator on the chunk; slicing at ``index + 1`` would
    # cut a multi-character separator such as '\r\n' in half.
    return head + found, rest
def split_buffer(stream, splitter=None, decoder=lambda a: a):
"""Given a generator which yields strings and a splitter function,
joins all input, splits on the separator and yields each chunk.
Unlike string.split(), each chunk includes the trailing
separator, except for the last one if none was found on the end
of the input.
"""
splitter = splitter or line_splitter
buffered = '' buffered = ''
for data in stream_as_text(stream): for data in stream_as_text(stream):
buffered += data buffered += data
while True: while True:
buffer_split = splitter(buffered) buffer_split = json_splitter(buffered)
if buffer_split is None: if buffer_split is None:
break break
@ -69,6 +50,13 @@ def split_buffer(stream, splitter=None, decoder=lambda a: a):
if buffered: if buffered:
try: try:
yield decoder(buffered) yield json_decoder.decode(buffered)
except Exception as e: except Exception as e:
raise StreamParseError(e) from e raise StreamParseError(e) from e
def line_splitter(buffer: str, separator='\n'):
    """Split ``buffer`` just after the first occurrence of ``separator``.

    Args:
        buffer: Text to search.
        separator: Delimiter to look for. It is converted with ``str()``
            before searching.

    Returns:
        ``None`` when the separator does not occur in ``buffer``; otherwise
        a ``(chunk, rest)`` tuple where ``chunk`` ends with the complete
        separator and ``rest`` is everything after it.

    Raises:
        ValueError: If the separator is empty (raised by ``str.partition``).
    """
    head, found, rest = buffer.partition(str(separator))
    if not found:
        return None
    # Keep the *whole* separator on the chunk; slicing at ``index + 1`` would
    # cut a multi-character separator such as '\r\n' in half.
    return head + found, rest