Add component for Anthropic (#3564)
Signed-off-by: yaron2 <schneider.yaron@live.com>
This commit is contained in:
parent
69119d6f6c
commit
c53499343a
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/dapr/components-contrib/conversation"
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/kit/logger"
|
||||
kmeta "github.com/dapr/kit/metadata"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"github.com/tmc/langchaingo/llms/anthropic"
|
||||
)
|
||||
|
||||
type Anthropic struct {
|
||||
llm *anthropic.LLM
|
||||
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
type AnthropicMetadata struct {
|
||||
Key string `json:"key"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func NewAnthropic(logger logger.Logger) conversation.Conversation {
|
||||
a := &Anthropic{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
const defaultModel = "claude-3-5-sonnet-20240620"
|
||||
|
||||
func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error {
|
||||
m := AnthropicMetadata{}
|
||||
err := kmeta.DecodeMetadata(meta.Properties, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model := defaultModel
|
||||
if m.Model != "" {
|
||||
model = m.Model
|
||||
}
|
||||
|
||||
llm, err := anthropic.New(
|
||||
anthropic.WithModel(model),
|
||||
anthropic.WithToken(m.Key),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.llm = llm
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Anthropic) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
|
||||
metadataStruct := AnthropicMetadata{}
|
||||
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Anthropic) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
|
||||
messages := make([]llms.MessageContent, 0, len(r.Inputs))
|
||||
|
||||
for _, input := range r.Inputs {
|
||||
role := conversation.ConvertLangchainRole(input.Role)
|
||||
|
||||
messages = append(messages, llms.MessageContent{
|
||||
Role: role,
|
||||
Parts: []llms.ContentPart{
|
||||
llms.TextPart(input.Message),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
resp, err := a.llm.GenerateContent(ctx, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
|
||||
|
||||
for i := range resp.Choices {
|
||||
outputs = append(outputs, conversation.ConversationResult{
|
||||
Result: resp.Choices[i].Content,
|
||||
Parameters: r.Parameters,
|
||||
})
|
||||
}
|
||||
|
||||
res = &conversation.ConversationResponse{
|
||||
Outputs: outputs,
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (a *Anthropic) Close() error {
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
# yaml-language-server: $schema=../../../component-metadata-schema.json
|
||||
schemaVersion: v1
|
||||
type: conversation
|
||||
name: anthropic
|
||||
version: v1
|
||||
status: alpha
|
||||
title: "Anthropic"
|
||||
urls:
|
||||
- title: Reference
|
||||
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-anthropic/
|
||||
authenticationProfiles:
|
||||
- title: "API Key"
|
||||
description: "Authenticate using an API key"
|
||||
metadata:
|
||||
- name: key
|
||||
type: string
|
||||
required: true
|
||||
sensitive: true
|
||||
description: |
|
||||
API key for Anthropic.
|
||||
example: "**********"
|
||||
default: ""
|
||||
metadata:
|
||||
- name: model
|
||||
required: false
|
||||
description: |
|
||||
The Anthropic LLM to use. Defaults to claude-3-5-sonnet-20240620
|
||||
type: string
|
||||
example: 'claude-3-5-sonnet-20240620'
|
|
@ -53,23 +53,6 @@ func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
|
|||
return b
|
||||
}
|
||||
|
||||
func convertRole(role conversation.Role) llms.ChatMessageType {
|
||||
switch role {
|
||||
case conversation.RoleSystem:
|
||||
return llms.ChatMessageTypeSystem
|
||||
case conversation.RoleUser:
|
||||
return llms.ChatMessageTypeHuman
|
||||
case conversation.RoleAssistant:
|
||||
return llms.ChatMessageTypeAI
|
||||
case conversation.RoleTool:
|
||||
return llms.ChatMessageTypeTool
|
||||
case conversation.RoleFunction:
|
||||
return llms.ChatMessageTypeFunction
|
||||
default:
|
||||
return llms.ChatMessageTypeHuman
|
||||
}
|
||||
}
|
||||
|
||||
func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error {
|
||||
m := AWSBedrockMetadata{}
|
||||
err := kmeta.DecodeMetadata(meta.Properties, &m)
|
||||
|
@ -111,7 +94,7 @@ func (b *AWSBedrock) Converse(ctx context.Context, r *conversation.ConversationR
|
|||
messages := make([]llms.MessageContent, 0, len(r.Inputs))
|
||||
|
||||
for _, input := range r.Inputs {
|
||||
role := convertRole(input.Role)
|
||||
role := conversation.ConvertLangchainRole(input.Role)
|
||||
|
||||
messages = append(messages, llms.MessageContent{
|
||||
Role: role,
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
package bedrock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/dapr/components-contrib/conversation"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
)
|
||||
|
||||
func TestConvertRole(t *testing.T) {
|
||||
roles := map[string]string{
|
||||
conversation.RoleSystem: string(llms.ChatMessageTypeSystem),
|
||||
conversation.RoleAssistant: string(llms.ChatMessageTypeAI),
|
||||
conversation.RoleFunction: string(llms.ChatMessageTypeFunction),
|
||||
conversation.RoleUser: string(llms.ChatMessageTypeHuman),
|
||||
conversation.RoleTool: string(llms.ChatMessageTypeTool),
|
||||
}
|
||||
|
||||
for k, v := range roles {
|
||||
r := convertRole(conversation.Role(k))
|
||||
assert.Equal(t, v, string(r))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package conversation
|
||||
|
||||
import "github.com/tmc/langchaingo/llms"
|
||||
|
||||
func ConvertLangchainRole(role Role) llms.ChatMessageType {
|
||||
switch role {
|
||||
case RoleSystem:
|
||||
return llms.ChatMessageTypeSystem
|
||||
case RoleUser:
|
||||
return llms.ChatMessageTypeHuman
|
||||
case RoleAssistant:
|
||||
return llms.ChatMessageTypeAI
|
||||
case RoleTool:
|
||||
return llms.ChatMessageTypeTool
|
||||
case RoleFunction:
|
||||
return llms.ChatMessageTypeFunction
|
||||
default:
|
||||
return llms.ChatMessageTypeHuman
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
Copyright 2024 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConvertLangchainRole(t *testing.T) {
|
||||
roles := map[string]string{
|
||||
RoleSystem: string(llms.ChatMessageTypeSystem),
|
||||
RoleAssistant: string(llms.ChatMessageTypeAI),
|
||||
RoleFunction: string(llms.ChatMessageTypeFunction),
|
||||
RoleUser: string(llms.ChatMessageTypeHuman),
|
||||
RoleTool: string(llms.ChatMessageTypeTool),
|
||||
}
|
||||
|
||||
for k, v := range roles {
|
||||
r := ConvertLangchainRole(Role(k))
|
||||
assert.Equal(t, v, string(r))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue