diff --git a/conversation/deepseek/deepseek.go b/conversation/deepseek/deepseek.go
index d0f234af1..1ed6c075f 100644
--- a/conversation/deepseek/deepseek.go
+++ b/conversation/deepseek/deepseek.go
@@ -59,6 +59,10 @@ func (d *Deepseek) Init(ctx context.Context, meta conversation.Metadata) error {
 		model = md.Model
 	}
 
+	if md.Endpoint == "" {
+		md.Endpoint = defaultEndpoint
+	}
+
 	options := []openai.Option{
 		openai.WithModel(model),
 		openai.WithToken(md.Key),
diff --git a/conversation/echo/echo.go b/conversation/echo/echo.go
index bb97e567f..79814fa1a 100644
--- a/conversation/echo/echo.go
+++ b/conversation/echo/echo.go
@@ -18,6 +18,8 @@ import (
 	"context"
 	"fmt"
 	"reflect"
+	"strconv"
+	"strings"
 
 	"github.com/tmc/langchaingo/llms"
@@ -61,67 +63,96 @@ func (e *Echo) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
 // Converse returns one output per input message.
 func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
-	if r.Message == nil {
+	if r == nil || r.Message == nil {
 		return &conversation.Response{
 			ConversationContext: r.ConversationContext,
 			Outputs:             []conversation.Result{},
 		}, nil
 	}
 
-	outputs := make([]conversation.Result, 0, len(*r.Message))
+	// if tools are provided, respond with a tool call for each tool
+	var toolCalls []llms.ToolCall
+	if r.Tools != nil && len(*r.Tools) > 0 {
+		// create tool calls for each tool
+		toolCalls = make([]llms.ToolCall, 0, len(*r.Tools))
+		for id, tool := range *r.Tools {
+			// extract argument names from parameters.properties
+			if tool.Function == nil || tool.Function.Parameters == nil {
+				continue // skip if no function or parameters
+			}
+			// ensure the parameters are a map
+			m, ok := tool.Function.Parameters.(map[string]interface{})
+			if !ok {
+				return nil, fmt.Errorf("tool function parameters must be a map, got %T", tool.Function.Parameters)
+			}
+			if len(m) == 0 {
+				return nil, fmt.Errorf("tool function parameters map cannot be empty for tool ID %d", id)
+			}
+			properties, ok := m["properties"]
+			if !ok {
+				continue // skip if no properties
+			}
+			propMap, ok := properties.(map[string]interface{})
+			if !ok {
+				return nil, fmt.Errorf("tool function properties must be a map, got %T", properties)
+			}
+			if len(propMap) == 0 {
+				continue // skip if no properties
+			}
+			// get argument names
+			argNames := make([]string, 0, len(propMap))
+			for argName := range propMap {
+				argNames = append(argNames, argName)
+			}
+			toolCalls = append(toolCalls, llms.ToolCall{
+				ID: strconv.Itoa(id),
+				FunctionCall: &llms.FunctionCall{
+					Name:      tool.Function.Name,
+					Arguments: strings.Join(argNames, ","),
+				},
+			})
+		}
+	}
+	// iterate over each message in the request and echo its content back in the response; the response carries the accumulated content of the message parts and tool responses
+	contentFromMessaged := make([]string, 0, len(*r.Message))
 	for _, message := range *r.Message {
-		var content string
-		var toolCalls []llms.ToolCall
-
-		for i, part := range message.Parts {
+		for _, part := range message.Parts {
 			switch p := part.(type) {
 			case llms.TextContent:
 				// end with space if not the first part
-				if i > 0 && content != "" {
-					content += " "
-				}
-				content += p.Text
+				contentFromMessaged = append(contentFromMessaged, p.Text)
 			case llms.ToolCall:
+				// explicit tool calls may appear on the request, as in multi-turn conversations; append them in addition to the tool calls created for each tool defined in the request
 				toolCalls = append(toolCalls, p)
 			case llms.ToolCallResponse:
-				content = p.Content
-				toolCalls = append(toolCalls, llms.ToolCall{
-					ID:   p.ToolCallID,
-					Type: "function",
-					FunctionCall: &llms.FunctionCall{
-						Name:      p.Name,
-						Arguments: p.Content,
-					},
-				})
+				// echo tool responses included in the request, as in multi-turn conversations
+				contentFromMessaged = append(contentFromMessaged, fmt.Sprintf("Tool Response for tool ID '%s' with name '%s': %s", p.ToolCallID, p.Name, p.Content))
 			default:
 				return nil, fmt.Errorf("found invalid content type as input for %v", p)
 			}
 		}
+	}
+	choice := conversation.Choice{
+		FinishReason: "stop",
+		Index:        0,
+		Message: conversation.Message{
+			Content: strings.Join(contentFromMessaged, "\n"),
+		},
+	}
 
-		choice := conversation.Choice{
-			FinishReason: "stop",
-			Index:        0,
-			Message: conversation.Message{
-				Content: content,
-			},
-		}
+	if len(toolCalls) > 0 {
+		choice.Message.ToolCallRequest = &toolCalls
+	}
 
-		if len(toolCalls) > 0 {
-			choice.Message.ToolCallRequest = &toolCalls
-		}
-
-		output := conversation.Result{
-			StopReason: "stop",
-			Choices:    []conversation.Choice{choice},
-		}
-
-		outputs = append(outputs, output)
+	output := conversation.Result{
+		StopReason: "stop",
+		Choices:    []conversation.Choice{choice},
 	}
 
 	res = &conversation.Response{
 		ConversationContext: r.ConversationContext,
-		Outputs:             outputs,
+		Outputs:             []conversation.Result{output},
 	}
 
 	return res, nil
diff --git a/conversation/echo/echo_test.go b/conversation/echo/echo_test.go
index 2eed047db..44150030b 100644
--- a/conversation/echo/echo_test.go
+++ b/conversation/echo/echo_test.go
@@ -97,19 +97,7 @@ func TestConverse(t *testing.T) {
 							FinishReason: "stop",
 							Index:        0,
 							Message: conversation.Message{
-								Content: "first message second message",
-							},
-						},
-					},
-				},
-				{
-					StopReason: "stop",
-					Choices: []conversation.Choice{
-						{
-							FinishReason: "stop",
-							Index:        0,
-							Message: conversation.Message{
-								Content: "third message",
+								Content: "first message\nsecond message\nthird message",
 							},
 						},
 					},
@@ -127,7 +115,7 @@ func TestConverse(t *testing.T) {
 				Message: &tt.inputs,
 			})
 			require.NoError(t, err)
-			assert.Len(t, r.Outputs, len(tt.expected.Outputs))
+			assert.Len(t, r.Outputs, 1)
 			assert.Equal(t, tt.expected.Outputs, r.Outputs)
 		})
 	}
diff --git a/conversation/mistral/mistral.go b/conversation/mistral/mistral.go
index 314452508..a93fcb73d 100644
--- a/conversation/mistral/mistral.go
+++ b/conversation/mistral/mistral.go
@@ -17,6 +17,7 @@ package mistral
 import (
 	"context"
 	"reflect"
+	"strings"
 
 	"github.com/dapr/components-contrib/conversation"
 	"github.com/dapr/components-contrib/conversation/langchaingokit"
@@ -110,13 +111,34 @@ func CreateToolCallPart(toolCall *llms.ToolCall) llms.ContentPart {
 // using the human role specifically otherwise mistral will reject the tool response message.
 // Most LLM providers can handle tool call responses using the tool call response object;
 // however, mistral requires it as text in conversation history.
-func CreateToolResponseMessage(response llms.ToolCallResponse) llms.MessageContent {
-	return llms.MessageContent{
+func CreateToolResponseMessage(responses ...llms.ContentPart) llms.MessageContent {
+	msg := llms.MessageContent{
 		Role: llms.ChatMessageTypeHuman,
-		Parts: []llms.ContentPart{
-			llms.TextContent{
-				Text: "Tool response [ID: " + response.ToolCallID + ", Name: " + response.Name + "]: " + response.Content,
-			},
-		},
 	}
+	if len(responses) == 0 {
+		return msg
+	}
+	toolID := ""
+	name := ""
+
+	mistralContentParts := make([]string, 0, len(responses))
+	for _, response := range responses {
+		if resp, ok := response.(llms.ToolCallResponse); ok {
+			if toolID == "" {
+				toolID = resp.ToolCallID
+			}
+			if name == "" {
+				name = resp.Name
+			}
+			mistralContentParts = append(mistralContentParts, resp.Content)
+		}
+	}
+	if len(mistralContentParts) > 0 {
+		msg.Parts = []llms.ContentPart{
+			llms.TextContent{
+				Text: "Tool response [ID: " + toolID + ", Name: " + name + "]: " + strings.Join(mistralContentParts, "\n"),
+			},
+		}
+	}
+	return msg
 }
diff --git a/tests/config/conversation/huggingface/huggingface.yml b/tests/config/conversation/huggingface/huggingface.yml
index 4af48ad7c..c4ca9f2fe 100644
--- a/tests/config/conversation/huggingface/huggingface.yml
+++ b/tests/config/conversation/huggingface/huggingface.yml
@@ -9,4 +9,4 @@ spec:
   - name: key
     value: "${{HUGGINGFACE_API_KEY}}"
   - name: model
-    value: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
\ No newline at end of file
+    value: "HuggingFaceTB/SmolLM3-3B"
\ No newline at end of file
diff --git a/tests/conformance/conversation/conversation.go b/tests/conformance/conversation/conversation.go
index 99ff508ea..07a64ef68 100644
--- a/tests/conformance/conversation/conversation.go
+++ b/tests/conformance/conversation/conversation.go
@@ -46,6 +46,8 @@ func NewTestConfig(componentName string) TestConfig {
 }
 
 func ConformanceTests(t *testing.T, props map[string]string, conv conversation.Conversation, component string) {
+	var providerStopReasons = []string{"stop", "end_turn", "FinishReasonStop"}
+
 	t.Run("init", func(t *testing.T) {
 		ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
 		defer cancel()
@@ -104,7 +106,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 		assert.Len(t, resp.Outputs, 1)
 		assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
 		// anthropic responds with end_turn but other llm providers return with stop
-		assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
+		assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
 		assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
 	})
 	t.Run("test system message type", func(t *testing.T) {
@@ -133,20 +135,11 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 		resp, err := conv.Converse(ctx, req)
 		require.NoError(t, err)
 
-		// Echo component returns one output per message, other components return one output
-		if component == "echo" {
-			assert.Len(t, resp.Outputs, 2)
-			// Check the last output - system message
-			assert.NotEmpty(t, resp.Outputs[1].Choices[0].Message.Content)
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[1].StopReason))
-			assert.Empty(t, resp.Outputs[1].Choices[0].Message.ToolCallRequest)
-		} else {
-			assert.Len(t, resp.Outputs, 1)
-			assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
-			// anthropic responds with end_turn but other llm providers return with stop
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
-			assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
-		}
+		assert.Len(t, resp.Outputs, 1)
+		assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
+		// anthropic responds with end_turn but other llm providers return with stop
+		assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason), resp.Outputs[0].StopReason)
+		assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
 	})
 	t.Run("test assistant message type", func(t *testing.T) {
 		ctx, cancel := context.WithTimeout(t.Context(), 25*time.Second)
@@ -234,25 +227,14 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 		require.NoError(t, err)
 
 		// Echo component returns one output per message, other components return one output
-		if component == "echo" {
-			assert.Len(t, resp.Outputs, 4)
-			// Check the last output - human message
-			assert.NotEmpty(t, resp.Outputs[3].Choices[0].Message.Content)
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[3].StopReason))
-			// Check the tool call output - second output
-			if resp.Outputs[1].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[1].Choices[0].Message.ToolCallRequest) > 0 {
-				assert.NotEmpty(t, resp.Outputs[1].Choices[0].Message.ToolCallRequest)
-				require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[1].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
-			}
-		} else {
-			assert.Len(t, resp.Outputs, 1)
-			assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
-			// anthropic responds with end_turn but other llm providers return with stop
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
-			if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest) > 0 {
-				assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
-				require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
-			}
+
+		assert.Len(t, resp.Outputs, 1)
+		assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
+		// anthropic responds with end_turn but other llm providers return with stop
+		assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
+		if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest) > 0 {
+			assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
+			require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
 		}
 	})
 
@@ -277,13 +259,13 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 		assert.Len(t, resp.Outputs, 1)
 		assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
 		// anthropic responds with end_turn but other llm providers return with stop
-		assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
+		assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
 		if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil {
 			assert.Empty(t, *resp.Outputs[0].Choices[0].Message.ToolCallRequest)
 		}
 	})
 
-	t.Run("test tool message type - confirming active tool calling capability", func(t *testing.T) {
+	t.Run("test tool message type - confirming active tool calling capability (empty tool choice)", func(t *testing.T) {
 		ctx, cancel := context.WithTimeout(t.Context(), 25*time.Second)
 		defer cancel()
 
@@ -384,10 +366,10 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 			require.NoError(t, err2)
 			assert.Len(t, resp2.Outputs, 1)
 			assert.NotEmpty(t, resp2.Outputs[0].Choices[0].Message.Content)
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp2.Outputs[0].StopReason))
+			assert.True(t, slices.Contains(providerStopReasons, resp2.Outputs[0].StopReason))
 		} else {
 			assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
+			assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
 		}
 	})
 
@@ -456,7 +438,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 		// check if we got a tool call request
 		if found {
 			assert.Equal(t, "retrieve_payment_status", toolCall.FunctionCall.Name)
-			assert.Contains(t, toolCall.FunctionCall.Arguments, "T1001")
+			assert.Contains(t, toolCall.FunctionCall.Arguments, "transaction_id")
 
 			toolResponse := llms.ToolCallResponse{
 				ToolCallID: toolCall.ID,
@@ -508,11 +490,11 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
 			require.NoError(t, err)
 			assert.Len(t, resp2.Outputs, 1)
 			assert.NotEmpty(t, resp2.Outputs[0].Choices[0].Message.Content)
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp2.Outputs[0].StopReason))
+			assert.True(t, slices.Contains(providerStopReasons, resp2.Outputs[0].StopReason))
 		} else {
 			// it is valid too if no tool call was generated
 			assert.NotEmpty(t, resp1.Outputs[0].Choices[0].Message.Content)
-			assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp1.Outputs[0].StopReason))
+			assert.True(t, slices.Contains(providerStopReasons, resp1.Outputs[0].StopReason))
 		}
 	})
 })