diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index d0c2972d..2912ccb3 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -136,14 +136,12 @@ module DiscourseAi
 
         def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
 
-          # half a line sent here
          return if !parsed
 
          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
 
          @has_function_call ||= response_h.dig(:tool_calls).present?
-
          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
        end
 
@@ -172,8 +170,11 @@ module DiscourseAi
          function_buffer.at("tool_name").content = f_name if f_name
          function_buffer.at("tool_id").content = partial[:id] if partial[:id]
 
-          if partial.dig(:function, :arguments).present?
-            @args_buffer << partial.dig(:function, :arguments)
+          args = partial.dig(:function, :arguments)
+
+          # allow for SPACE within arguments
+          if args && args != ""
+            @args_buffer << args
 
            begin
              json_args = JSON.parse(@args_buffer, symbolize_names: true)
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 1817b6e2..1a25e79e 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -53,6 +53,13 @@ class OpenAiMock < EndpointMock
     }.to_json
   end
 
+  def stub_raw(chunks)
+    WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
+      status: 200,
+      body: chunks,
+    )
+  end
+
   def stub_streamed_response(prompt, deltas, tool_call: false)
     chunks =
       deltas.each_with_index.map do |_, index|
@@ -69,6 +76,8 @@ class OpenAiMock < EndpointMock
       .stub_request(:post, "https://api.openai.com/v1/chat/completions")
      .with(body: request_body(prompt, stream: true, tool_call: tool_call))
      .to_return(status: 200, body: chunks)
+
+    yield if block_given?
   end
 
   def tool_deltas
@@ -168,14 +177,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
        end
 
        it "will automatically recover from a bad payload" do
+          called = false
+
          # this should not happen, but lets ensure nothing bad happens
          # the row with test1 is invalid json
          raw_data = <<~TEXT.strip
            d|a|t|a|:| |{|"choices":[{"delta":{"content":"test,"}}]}
 
-            data: {"choices":[{"delta":{"content":"test1,"}}]
+            data: {"choices":[{"delta":{"content":"test|1| |,"}}]
 
-            data: {"choices":[{"delta":|{"content":"test2,"}}]}
+            data: {"choices":[{"delta":|{"content":"test2 ,"}}]}
 
            data: {"choices":[{"delta":{"content":"test3,"}}]|}
 
@@ -187,16 +198,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
          chunks = raw_data.split("|")
 
          open_ai_mock.with_chunk_array_support do
-            open_ai_mock.stub_streamed_response(compliance.dialect.translate, chunks) do
-              partials = []
+            open_ai_mock.stub_raw(chunks)
 
-              endpoint.perform_completion!(compliance.dialect, user) do |partial|
-                partials << partial
-              end
+            partials = []
 
-              expect(partials.join).to eq("test,test1,test2,test3,test4")
-            end
+            endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
+
+            called = true
+            expect(partials.join).to eq("test,test2 ,test3,test4")
          end
+          expect(called).to be(true)
        end
      end
 
@@ -204,6 +215,65 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
        it "returns a function invocation" do
          compliance.streaming_mode_tools(open_ai_mock)
        end
+
+        it "properly handles spaces in tools payload" do
+          raw_data = <<~TEXT.strip
+            data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"google","arguments":""}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]}
+
+            data: {"ch|oices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "query"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "\\":\\""}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "Ad"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "a|b"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "as"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": |"| "}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "9"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "."}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"argume|nts": "1"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "\\"}"}}]}}]}
+
+            data: {"choices": [{"index": 0, "delta": {"tool_calls": []}}]}
+
+            data: [D|ONE]
+          TEXT
+
+          chunks = raw_data.split("|")
+
+          open_ai_mock.with_chunk_array_support do
+            open_ai_mock.stub_raw(chunks)
+            partials = []
+
+            endpoint.perform_completion!(compliance.dialect, user) do |partial, x, y|
+              partials << partial
+            end
+
+            expect(partials.length).to eq(1)
+
+            function_call = (<<~TXT).strip
+              <function_calls>
+              <invoke>
+              <tool_name>google</tool_name>
+              <tool_id>func_id</tool_id>
+              <parameters>
+              <query>Adabas 9.1</query>
+              </parameters>
+              </invoke>
+              </function_calls>
+            TXT
+
+            expect(partials[0].strip).to eq(function_call)
+          end
+        end
      end
    end
  end
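
For context, the core of the fix is swapping the `present?` guard for an explicit empty-string check: a streamed argument fragment consisting of a single space is blank in ActiveSupport terms, so the old guard dropped it and the reassembled JSON lost the space. Below is a minimal, self-contained sketch (not the plugin's actual classes; it only assumes ActiveSupport's blank/present core extension and reuses the argument fragments from the spec above) showing the difference between the two guards:

# Standalone illustration of why whitespace-only fragments were lost.
require "json"
require "active_support/core_ext/object/blank"

# Argument fragments as streamed in the spec: together they spell {"query":"Adabas 9.1"}.
fragments = ["{\"", "query", "\":\"", "Ad", "ab", "as", " ", "9", ".", "1", "\"}"]

# Old guard: `fragment.present?` is false for the lone space, so it never
# reaches the buffer and the parsed value comes back as "Adabas9.1".
old_buffer = +""
fragments.each { |fragment| old_buffer << fragment if fragment.present? }

# New guard: only nil/empty fragments are skipped, so whitespace survives.
new_buffer = +""
fragments.each { |fragment| new_buffer << fragment if fragment && fragment != "" }

puts JSON.parse(old_buffer)["query"] # => "Adabas9.1" (space lost)
puts JSON.parse(new_buffer)["query"] # => "Adabas 9.1"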