Merge bf1323b6c0 into 20d73459f7
This commit is contained in:
commit
94ba312dd8
|
|
@ -59,6 +59,10 @@ func (d *Deepseek) Init(ctx context.Context, meta conversation.Metadata) error {
|
|||
model = md.Model
|
||||
}
|
||||
|
||||
if md.Endpoint == "" {
|
||||
md.Endpoint = defaultEndpoint
|
||||
}
|
||||
|
||||
options := []openai.Option{
|
||||
openai.WithModel(model),
|
||||
openai.WithToken(md.Key),
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
|
||||
|
|
@ -61,67 +63,96 @@ func (e *Echo) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
|
|||
|
||||
// Converse returns one output per input message.
|
||||
func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
|
||||
if r.Message == nil {
|
||||
if r == nil || r.Message == nil {
|
||||
return &conversation.Response{
|
||||
ConversationContext: r.ConversationContext,
|
||||
Outputs: []conversation.Result{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
outputs := make([]conversation.Result, 0, len(*r.Message))
|
||||
// if we get tools, respond with tool calls for each tool
|
||||
var toolCalls []llms.ToolCall
|
||||
if r.Tools != nil && len(*r.Tools) > 0 {
|
||||
// create tool calls for each tool
|
||||
toolCalls = make([]llms.ToolCall, 0, len(*r.Tools))
|
||||
for id, tool := range *r.Tools {
|
||||
// extract argument names from parameters.properties
|
||||
if tool.Function == nil || tool.Function.Parameters == nil {
|
||||
continue // skip if no function or parameters
|
||||
}
|
||||
// ensure parameters are a map
|
||||
m, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool function parameters must be a map, got %T", tool.Function.Parameters)
|
||||
}
|
||||
if len(m) == 0 {
|
||||
return nil, fmt.Errorf("tool function parameters map cannot be empty for tool ID %d", id)
|
||||
}
|
||||
properties, ok := m["properties"]
|
||||
if !ok {
|
||||
continue // skip if no properties
|
||||
}
|
||||
propMap, ok := properties.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tool function properties must be a map, got %T", properties)
|
||||
}
|
||||
if len(propMap) == 0 {
|
||||
continue // skip if no properties
|
||||
}
|
||||
// get argument names
|
||||
argNames := make([]string, 0, len(propMap))
|
||||
for argName := range propMap {
|
||||
argNames = append(argNames, argName)
|
||||
}
|
||||
toolCalls = append(toolCalls, llms.ToolCall{
|
||||
ID: strconv.Itoa(id),
|
||||
FunctionCall: &llms.FunctionCall{
|
||||
Name: tool.Function.Name,
|
||||
Arguments: strings.Join(argNames, ","),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// iterate over each message in the request to echo back the content in the response. We respond with the acummulated content of the message parts and tool responses
|
||||
contentFromMessaged := make([]string, 0, len(*r.Message))
|
||||
for _, message := range *r.Message {
|
||||
var content string
|
||||
var toolCalls []llms.ToolCall
|
||||
|
||||
for i, part := range message.Parts {
|
||||
for _, part := range message.Parts {
|
||||
switch p := part.(type) {
|
||||
case llms.TextContent:
|
||||
// end with space if not the first part
|
||||
if i > 0 && content != "" {
|
||||
content += " "
|
||||
}
|
||||
content += p.Text
|
||||
contentFromMessaged = append(contentFromMessaged, p.Text)
|
||||
case llms.ToolCall:
|
||||
// in case we added explicit tool calls on the request like on multi-turn conversations. We still append tool calls for each tool defined in the request.
|
||||
toolCalls = append(toolCalls, p)
|
||||
case llms.ToolCallResponse:
|
||||
content = p.Content
|
||||
toolCalls = append(toolCalls, llms.ToolCall{
|
||||
ID: p.ToolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: &llms.FunctionCall{
|
||||
Name: p.Name,
|
||||
Arguments: p.Content,
|
||||
},
|
||||
})
|
||||
// show tool responses on the request like on multi-turn conversations
|
||||
contentFromMessaged = append(contentFromMessaged, fmt.Sprintf("Tool Response for tool ID '%s' with name '%s': %s", p.ToolCallID, p.Name, p.Content))
|
||||
default:
|
||||
return nil, fmt.Errorf("found invalid content type as input for %v", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
choice := conversation.Choice{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Message: conversation.Message{
|
||||
Content: strings.Join(contentFromMessaged, "\n"),
|
||||
},
|
||||
}
|
||||
|
||||
choice := conversation.Choice{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Message: conversation.Message{
|
||||
Content: content,
|
||||
},
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
choice.Message.ToolCallRequest = &toolCalls
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
choice.Message.ToolCallRequest = &toolCalls
|
||||
}
|
||||
|
||||
output := conversation.Result{
|
||||
StopReason: "stop",
|
||||
Choices: []conversation.Choice{choice},
|
||||
}
|
||||
|
||||
outputs = append(outputs, output)
|
||||
output := conversation.Result{
|
||||
StopReason: "stop",
|
||||
Choices: []conversation.Choice{choice},
|
||||
}
|
||||
|
||||
res = &conversation.Response{
|
||||
ConversationContext: r.ConversationContext,
|
||||
Outputs: outputs,
|
||||
Outputs: []conversation.Result{output},
|
||||
}
|
||||
|
||||
return res, nil
|
||||
|
|
|
|||
|
|
@ -97,19 +97,7 @@ func TestConverse(t *testing.T) {
|
|||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Message: conversation.Message{
|
||||
Content: "first message second message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
StopReason: "stop",
|
||||
Choices: []conversation.Choice{
|
||||
{
|
||||
FinishReason: "stop",
|
||||
Index: 0,
|
||||
Message: conversation.Message{
|
||||
Content: "third message",
|
||||
Content: "first message\nsecond message\nthird message",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -127,7 +115,7 @@ func TestConverse(t *testing.T) {
|
|||
Message: &tt.inputs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, r.Outputs, len(tt.expected.Outputs))
|
||||
assert.Len(t, r.Outputs, 1)
|
||||
assert.Equal(t, tt.expected.Outputs, r.Outputs)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package mistral
|
|||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/dapr/components-contrib/conversation"
|
||||
"github.com/dapr/components-contrib/conversation/langchaingokit"
|
||||
|
|
@ -110,13 +111,34 @@ func CreateToolCallPart(toolCall *llms.ToolCall) llms.ContentPart {
|
|||
// using the human role specifically otherwise mistral will reject the tool response message.
|
||||
// Most LLM providers can handle tool call responses using the tool call response object;
|
||||
// however, mistral requires it as text in conversation history.
|
||||
func CreateToolResponseMessage(response llms.ToolCallResponse) llms.MessageContent {
|
||||
return llms.MessageContent{
|
||||
func CreateToolResponseMessage(responses ...llms.ContentPart) llms.MessageContent {
|
||||
msg := llms.MessageContent{
|
||||
Role: llms.ChatMessageTypeHuman,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextContent{
|
||||
Text: "Tool response [ID: " + response.ToolCallID + ", Name: " + response.Name + "]: " + response.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
if len(responses) == 0 {
|
||||
return msg
|
||||
}
|
||||
toolID := ""
|
||||
name := ""
|
||||
|
||||
mistralContentParts := make([]string, 0, len(responses))
|
||||
for _, response := range responses {
|
||||
if resp, ok := response.(llms.ToolCallResponse); ok {
|
||||
if toolID == "" {
|
||||
toolID = resp.ToolCallID
|
||||
}
|
||||
if name == "" {
|
||||
name = resp.Name
|
||||
}
|
||||
mistralContentParts = append(mistralContentParts, resp.Content)
|
||||
}
|
||||
}
|
||||
if len(mistralContentParts) > 0 {
|
||||
msg.Parts = []llms.ContentPart{
|
||||
llms.TextContent{
|
||||
Text: "Tool response [ID: " + toolID + ", Name: " + name + "]: " + strings.Join(mistralContentParts, "\n"),
|
||||
},
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,4 +9,4 @@ spec:
|
|||
- name: key
|
||||
value: "${{HUGGINGFACE_API_KEY}}"
|
||||
- name: model
|
||||
value: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||
value: "HuggingFaceTB/SmolLM3-3B"
|
||||
|
|
@ -46,6 +46,8 @@ func NewTestConfig(componentName string) TestConfig {
|
|||
}
|
||||
|
||||
func ConformanceTests(t *testing.T, props map[string]string, conv conversation.Conversation, component string) {
|
||||
var providerStopReasons = []string{"stop", "end_turn", "FinishReasonStop"}
|
||||
|
||||
t.Run("init", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
|
@ -104,7 +106,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
assert.Len(t, resp.Outputs, 1)
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
// anthropic responds with end_turn but other llm providers return with stop
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
|
||||
assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
|
||||
})
|
||||
t.Run("test system message type", func(t *testing.T) {
|
||||
|
|
@ -133,20 +135,11 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
resp, err := conv.Converse(ctx, req)
|
||||
|
||||
require.NoError(t, err)
|
||||
// Echo component returns one output per message, other components return one output
|
||||
if component == "echo" {
|
||||
assert.Len(t, resp.Outputs, 2)
|
||||
// Check the last output - system message
|
||||
assert.NotEmpty(t, resp.Outputs[1].Choices[0].Message.Content)
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[1].StopReason))
|
||||
assert.Empty(t, resp.Outputs[1].Choices[0].Message.ToolCallRequest)
|
||||
} else {
|
||||
assert.Len(t, resp.Outputs, 1)
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
// anthropic responds with end_turn but other llm providers return with stop
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
|
||||
assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
|
||||
}
|
||||
assert.Len(t, resp.Outputs, 1)
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
// anthropic responds with end_turn but other llm providers return with stop
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason), resp.Outputs[0].StopReason)
|
||||
assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
|
||||
})
|
||||
t.Run("test assistant message type", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 25*time.Second)
|
||||
|
|
@ -234,25 +227,14 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
|
||||
require.NoError(t, err)
|
||||
// Echo component returns one output per message, other components return one output
|
||||
if component == "echo" {
|
||||
assert.Len(t, resp.Outputs, 4)
|
||||
// Check the last output - human message
|
||||
assert.NotEmpty(t, resp.Outputs[3].Choices[0].Message.Content)
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[3].StopReason))
|
||||
// Check the tool call output - second output
|
||||
if resp.Outputs[1].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[1].Choices[0].Message.ToolCallRequest) > 0 {
|
||||
assert.NotEmpty(t, resp.Outputs[1].Choices[0].Message.ToolCallRequest)
|
||||
require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[1].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
|
||||
}
|
||||
} else {
|
||||
assert.Len(t, resp.Outputs, 1)
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
// anthropic responds with end_turn but other llm providers return with stop
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
|
||||
if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest) > 0 {
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
|
||||
require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
|
||||
}
|
||||
|
||||
assert.Len(t, resp.Outputs, 1)
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
// anthropic responds with end_turn but other llm providers return with stop
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
|
||||
if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest) > 0 {
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
|
||||
require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
@ -277,13 +259,13 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
assert.Len(t, resp.Outputs, 1)
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
// anthropic responds with end_turn but other llm providers return with stop
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
|
||||
if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil {
|
||||
assert.Empty(t, *resp.Outputs[0].Choices[0].Message.ToolCallRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test tool message type - confirming active tool calling capability", func(t *testing.T) {
|
||||
t.Run("test tool message type - confirming active tool calling capability (empty tool choice)", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 25*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
|
@ -384,10 +366,10 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
require.NoError(t, err2)
|
||||
assert.Len(t, resp2.Outputs, 1)
|
||||
assert.NotEmpty(t, resp2.Outputs[0].Choices[0].Message.Content)
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp2.Outputs[0].StopReason))
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp2.Outputs[0].StopReason))
|
||||
} else {
|
||||
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp.Outputs[0].StopReason))
|
||||
}
|
||||
})
|
||||
|
||||
|
|
@ -456,7 +438,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
// check if we got a tool call request
|
||||
if found {
|
||||
assert.Equal(t, "retrieve_payment_status", toolCall.FunctionCall.Name)
|
||||
assert.Contains(t, toolCall.FunctionCall.Arguments, "T1001")
|
||||
assert.Contains(t, toolCall.FunctionCall.Arguments, "transaction_id")
|
||||
|
||||
toolResponse := llms.ToolCallResponse{
|
||||
ToolCallID: toolCall.ID,
|
||||
|
|
@ -508,11 +490,11 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
|
|||
require.NoError(t, err)
|
||||
assert.Len(t, resp2.Outputs, 1)
|
||||
assert.NotEmpty(t, resp2.Outputs[0].Choices[0].Message.Content)
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp2.Outputs[0].StopReason))
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp2.Outputs[0].StopReason))
|
||||
} else {
|
||||
// it is valid too if no tool call was generated
|
||||
assert.NotEmpty(t, resp1.Outputs[0].Choices[0].Message.Content)
|
||||
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp1.Outputs[0].StopReason))
|
||||
assert.True(t, slices.Contains(providerStopReasons, resp1.Outputs[0].StopReason))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue