Conversation API: add cache support, add huggingface+mistral models (#3567)

Signed-off-by: yaron2 <schneider.yaron@live.com>
Yaron Schneider 2024-10-15 21:30:26 -07:00, committed by GitHub
parent 1cbedb3c0e
commit 4ca04dbb61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 449 additions and 99 deletions

View File

@@ -28,16 +28,11 @@ import (
)
type Anthropic struct {
llm *anthropic.LLM
llm llms.Model
logger logger.Logger
}
type AnthropicMetadata struct {
Key string `json:"key"`
Model string `json:"model"`
}
func NewAnthropic(logger logger.Logger) conversation.Conversation {
a := &Anthropic{
logger: logger,
@@ -49,7 +44,7 @@ func NewAnthropic(logger logger.Logger) conversation.Conversation {
const defaultModel = "claude-3-5-sonnet-20240620"
func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error {
m := AnthropicMetadata{}
m := conversation.LangchainMetadata{}
err := kmeta.DecodeMetadata(meta.Properties, &m)
if err != nil {
return err
@@ -69,11 +64,21 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error
}
a.llm = llm
if m.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.llm)
if cacheErr != nil {
return cacheErr
}
a.llm = cachedModel
}
return nil
}
func (a *Anthropic) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := AnthropicMetadata{}
metadataStruct := conversation.LangchainMetadata{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
return
}

View File

@@ -27,3 +27,9 @@ metadata:
The Anthropic LLM to use. Defaults to claude-3-5-sonnet-20240620
type: string
example: 'claude-3-5-sonnet-20240620'
- name: cacheTTL
required: false
description: |
A time-to-live after which cached prompt responses expire. Uses Go duration strings.
type: string
example: '10m'
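
The cacheTTL value is parsed with Go's time.ParseDuration (see the CacheModel helper later in this commit), so any standard Go duration string is accepted; a quick sketch:

package main

import (
	"fmt"
	"time"
)

func main() {
	// cacheTTL accepts Go duration syntax: "90s", "10m", "1h30m", ...
	d, err := time.ParseDuration("10m")
	if err != nil {
		panic(err) // an invalid TTL surfaces as an Init error in the components
	}
	fmt.Println(d) // 10m0s
}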

View File

@@ -31,7 +31,7 @@ import (
type AWSBedrock struct {
model string
llm *bedrock.LLM
llm llms.Model
logger logger.Logger
}
@@ -43,6 +43,7 @@ type AWSBedrockMetadata struct {
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
Model string `json:"model"`
CacheTTL string `json:"cacheTTL"`
}
func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
@@ -81,6 +82,15 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error
}
b.llm = llm
if m.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.llm)
if cacheErr != nil {
return cacheErr
}
b.llm = cachedModel
}
return nil
}

View File

@@ -24,3 +24,9 @@ metadata:
The LLM to use. Defaults to Bedrock's default provider model from Amazon.
type: string
example: 'amazon.titan-text-express-v1'
- name: cacheTTL
required: false
description: |
A time-to-live after which cached prompt responses expire. Uses Go duration strings.
type: string
example: '10m'

View File

@@ -0,0 +1,129 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package huggingface
import (
"context"
"reflect"
"github.com/dapr/components-contrib/conversation"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
kmeta "github.com/dapr/kit/metadata"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/huggingface"
)
type Huggingface struct {
llm llms.Model
logger logger.Logger
}
func NewHuggingface(logger logger.Logger) conversation.Conversation {
h := &Huggingface{
logger: logger,
}
return h
}
const defaultModel = "meta-llama/Meta-Llama-3-8B"
func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) error {
m := conversation.LangchainMetadata{}
err := kmeta.DecodeMetadata(meta.Properties, &m)
if err != nil {
return err
}
model := defaultModel
if m.Model != "" {
model = m.Model
}
llm, err := huggingface.New(
huggingface.WithModel(model),
huggingface.WithToken(m.Key),
)
if err != nil {
return err
}
h.llm = llm
if m.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, h.llm)
if cacheErr != nil {
return cacheErr
}
h.llm = cachedModel
}
return nil
}
func (h *Huggingface) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := conversation.LangchainMetadata{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
return
}
func (h *Huggingface) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
messages := make([]llms.MessageContent, 0, len(r.Inputs))
for _, input := range r.Inputs {
role := conversation.ConvertLangchainRole(input.Role)
messages = append(messages, llms.MessageContent{
Role: role,
Parts: []llms.ContentPart{
llms.TextPart(input.Message),
},
})
}
opts := []llms.CallOption{}
if r.Temperature > 0 {
opts = append(opts, conversation.LangchainTemperature(r.Temperature))
}
resp, err := h.llm.GenerateContent(ctx, messages, opts...)
if err != nil {
return nil, err
}
outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
for i := range resp.Choices {
outputs = append(outputs, conversation.ConversationResult{
Result: resp.Choices[i].Content,
Parameters: r.Parameters,
})
}
res = &conversation.ConversationResponse{
Outputs: outputs,
}
return res, nil
}
func (h *Huggingface) Close() error {
return nil
}
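
A minimal usage sketch for the new component (not part of the commit; it assumes the package lives at github.com/dapr/components-contrib/conversation/huggingface like the other conversation components, and uses a placeholder API token):

package main

import (
	"context"
	"fmt"

	"github.com/dapr/components-contrib/conversation"
	"github.com/dapr/components-contrib/conversation/huggingface"
	"github.com/dapr/components-contrib/metadata"
	"github.com/dapr/kit/logger"
)

func main() {
	hf := huggingface.NewHuggingface(logger.NewLogger("huggingface"))

	// "key", "model", and "cacheTTL" map onto conversation.LangchainMetadata.
	err := hf.Init(context.Background(), conversation.Metadata{
		Base: metadata.Base{Properties: map[string]string{
			"key":      "hf_...", // placeholder token
			"model":    "meta-llama/Meta-Llama-3-8B",
			"cacheTTL": "10m",
		}},
	})
	if err != nil {
		panic(err)
	}

	resp, err := hf.Converse(context.Background(), &conversation.ConversationRequest{
		Inputs: []conversation.ConversationInput{
			{Message: "What is Dapr?", Role: conversation.RoleUser},
		},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Outputs[0].Result)
}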

View File

@@ -0,0 +1,35 @@
# yaml-language-server: $schema=../../../component-metadata-schema.json
schemaVersion: v1
type: conversation
name: huggingface
version: v1
status: alpha
title: "Huggingface"
urls:
- title: Reference
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-huggingface/
authenticationProfiles:
- title: "API Key"
description: "Authenticate using an API key"
metadata:
- name: key
type: string
required: true
sensitive: true
description: |
API key for Huggingface.
example: "**********"
default: ""
metadata:
- name: model
required: false
description: |
The Huggingface LLM to use. Defaults to meta-llama/Meta-Llama-3-8B
type: string
example: 'meta-llama/Meta-Llama-3-8B'
- name: cacheTTL
required: false
description: |
A time-to-live after which cached prompt responses expire. Uses Go duration strings.
type: string
example: '10m'

View File

@@ -20,3 +20,10 @@ import "github.com/dapr/components-contrib/metadata"
type Metadata struct {
metadata.Base `json:",inline"`
}
// LangchainMetadata is a common metadata structure for langchain supported implementations.
type LangchainMetadata struct {
Key string `json:"key"`
Model string `json:"model"`
CacheTTL string `json:"cacheTTL"`
}
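
For illustration (a sketch, not part of the commit), this is how the shared struct is populated from a component's properties, mirroring what the component Init functions in this commit do:

package main

import (
	"fmt"

	"github.com/dapr/components-contrib/conversation"
	kmeta "github.com/dapr/kit/metadata"
)

func main() {
	props := map[string]string{
		"key":      "api-key",
		"model":    "open-mistral-7b",
		"cacheTTL": "10m",
	}

	// DecodeMetadata maps the property keys onto the struct's json tags.
	m := conversation.LangchainMetadata{}
	if err := kmeta.DecodeMetadata(props, &m); err != nil {
		panic(err)
	}
	fmt.Println(m.Model, m.CacheTTL) // open-mistral-7b 10m
}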

View File

@@ -0,0 +1,35 @@
# yaml-language-server: $schema=../../../component-metadata-schema.json
schemaVersion: v1
type: conversation
name: mistral
version: v1
status: alpha
title: "Mistral"
urls:
- title: Reference
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-mistral/
authenticationProfiles:
- title: "API Key"
description: "Authenticate using an API key"
metadata:
- name: key
type: string
required: true
sensitive: true
description: |
API key for Mistral.
example: "**********"
default: ""
metadata:
- name: model
required: false
description: |
The Mistral LLM to use. Defaults to open-mistral-7b
type: string
example: 'open-mistral-7b'
- name: cacheTTL
required: false
description: |
A time-to-live after which cached prompt responses expire. Uses Go duration strings.
type: string
example: '10m'

View File

@@ -0,0 +1,128 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mistral
import (
"context"
"reflect"
"github.com/dapr/components-contrib/conversation"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
kmeta "github.com/dapr/kit/metadata"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/mistral"
)
type Mistral struct {
llm llms.Model
logger logger.Logger
}
func NewMistral(logger logger.Logger) conversation.Conversation {
m := &Mistral{
logger: logger,
}
return m
}
const defaultModel = "open-mistral-7b"
func (m *Mistral) Init(ctx context.Context, meta conversation.Metadata) error {
md := conversation.LangchainMetadata{}
err := kmeta.DecodeMetadata(meta.Properties, &md)
if err != nil {
return err
}
model := defaultModel
if md.Model != "" {
model = md.Model
}
llm, err := mistral.New(
mistral.WithModel(model),
mistral.WithAPIKey(md.Key),
)
if err != nil {
return err
}
m.llm = llm
if md.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, m.llm)
if cacheErr != nil {
return cacheErr
}
m.llm = cachedModel
}
return nil
}
func (m *Mistral) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := conversation.LangchainMetadata{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
return
}
func (m *Mistral) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
messages := make([]llms.MessageContent, 0, len(r.Inputs))
for _, input := range r.Inputs {
role := conversation.ConvertLangchainRole(input.Role)
messages = append(messages, llms.MessageContent{
Role: role,
Parts: []llms.ContentPart{
llms.TextPart(input.Message),
},
})
}
opts := []llms.CallOption{}
if r.Temperature > 0 {
opts = append(opts, conversation.LangchainTemperature(r.Temperature))
}
resp, err := m.llm.GenerateContent(ctx, messages, opts...)
if err != nil {
return nil, err
}
outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
for i := range resp.Choices {
outputs = append(outputs, conversation.ConversationResult{
Result: resp.Choices[i].Content,
Parameters: r.Parameters,
})
}
res = &conversation.ConversationResponse{
Outputs: outputs,
}
return res, nil
}
func (m *Mistral) Close() error {
return nil
}

View File

@@ -27,3 +27,9 @@ metadata:
The OpenAI LLM to use. Defaults to gpt-4o
type: string
example: 'gpt-4-turbo'
- name: cacheTTL
required: false
description: |
A time-to-live after which cached prompt responses expire. Uses Go duration strings.
type: string
example: '10m'

View File

@@ -23,14 +23,12 @@ import (
"github.com/dapr/kit/logger"
kmeta "github.com/dapr/kit/metadata"
openai "github.com/sashabaranov/go-openai"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
)
const defaultModel = "gpt-4o"
type OpenAI struct {
client *openai.Client
model string
llm llms.Model
logger logger.Logger
}
@@ -43,83 +41,69 @@ func NewOpenAI(logger logger.Logger) conversation.Conversation {
return o
}
const defaultModel = "gpt-4o"
func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
r := &conversation.ConversationRequest{}
err := kmeta.DecodeMetadata(meta.Properties, r)
md := conversation.LangchainMetadata{}
err := kmeta.DecodeMetadata(meta.Properties, &md)
if err != nil {
return err
}
o.client = openai.NewClient(r.Key)
o.model = r.Model
if o.model == "" {
o.model = defaultModel
model := defaultModel
if md.Model != "" {
model = md.Model
}
llm, err := openai.New(
openai.WithModel(model),
openai.WithToken(md.Key),
)
if err != nil {
return err
}
o.llm = llm
if md.CacheTTL != "" {
cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, o.llm)
if cacheErr != nil {
return cacheErr
}
o.llm = cachedModel
}
return nil
}
func (o *OpenAI) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := conversation.ConversationRequest{}
metadataStruct := conversation.LangchainMetadata{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
return
}
func convertRole(role conversation.Role) string {
switch role {
case conversation.RoleSystem:
return string(openai.ChatMessageRoleSystem)
case conversation.RoleUser:
return string(openai.ChatMessageRoleUser)
case conversation.RoleAssistant:
return string(openai.ChatMessageRoleAssistant)
case conversation.RoleTool:
return string(openai.ChatMessageRoleTool)
case conversation.RoleFunction:
return string(openai.ChatMessageRoleFunction)
default:
return string(openai.ChatMessageRoleUser)
}
}
func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
// Note: OPENAI does not support load balance
messages := make([]openai.ChatCompletionMessage, 0, len(r.Inputs))
var systemPrompt string
messages := make([]llms.MessageContent, 0, len(r.Inputs))
for _, input := range r.Inputs {
role := convertRole(input.Role)
if role == openai.ChatMessageRoleSystem {
systemPrompt = input.Message
continue
}
role := conversation.ConvertLangchainRole(input.Role)
messages = append(messages, openai.ChatCompletionMessage{
Role: role,
Content: input.Message,
messages = append(messages, llms.MessageContent{
Role: role,
Parts: []llms.ContentPart{
llms.TextPart(input.Message),
},
})
}
// OpenAI needs system prompts to be added last in the array to function properly
if systemPrompt != "" {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: systemPrompt,
})
opts := []llms.CallOption{}
if r.Temperature > 0 {
opts = append(opts, conversation.LangchainTemperature(r.Temperature))
}
req := openai.ChatCompletionRequest{
Model: o.model,
Messages: messages,
Temperature: float32(r.Temperature),
}
// TODO: support ConversationContext
resp, err := o.client.CreateChatCompletion(ctx, req)
resp, err := o.llm.GenerateContent(ctx, messages, opts...)
if err != nil {
o.logger.Error(err)
return nil, err
}
@@ -127,14 +111,13 @@ func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationReque
for i := range resp.Choices {
outputs = append(outputs, conversation.ConversationResult{
Result: resp.Choices[i].Message.Content,
Result: resp.Choices[i].Content,
Parameters: r.Parameters,
})
}
res = &conversation.ConversationResponse{
ConversationContext: resp.ID,
Outputs: outputs,
Outputs: outputs,
}
return res, nil

View File

@@ -1,26 +0,0 @@
package openai
import (
"testing"
"github.com/dapr/components-contrib/conversation"
"github.com/stretchr/testify/assert"
openai "github.com/sashabaranov/go-openai"
)
func TestConvertRole(t *testing.T) {
roles := map[string]string{
conversation.RoleSystem: string(openai.ChatMessageRoleSystem),
conversation.RoleAssistant: string(openai.ChatMessageRoleAssistant),
conversation.RoleFunction: string(openai.ChatMessageRoleFunction),
conversation.RoleUser: string(openai.ChatMessageRoleUser),
conversation.RoleTool: string(openai.ChatMessageRoleTool),
}
for k, v := range roles {
r := convertRole(conversation.Role(k))
assert.Equal(t, v, r)
}
}

View File

@@ -14,9 +14,32 @@ limitations under the License.
*/
package conversation
import "github.com/tmc/langchaingo/llms"
import (
"context"
"fmt"
"time"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/cache"
"github.com/tmc/langchaingo/llms/cache/inmemory"
)
// LangchainTemperature returns a langchain-compliant LLM temperature call option
func LangchainTemperature(temperature float64) llms.CallOption {
return llms.WithTemperature(temperature)
}
// CacheModel creates a prompt query cache with a configured TTL
func CacheModel(ctx context.Context, ttl string, model llms.Model) (llms.Model, error) {
d, err := time.ParseDuration(ttl)
if err != nil {
return model, fmt.Errorf("failed to parse cacheTTL duration: %s", err)
}
mem, err := inmemory.New(ctx, inmemory.WithExpiration(d))
if err != nil {
return model, fmt.Errorf("failed to create llm cache: %s", err)
}
return cache.New(model, mem), nil
}
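
A minimal sketch of the new helper in isolation (not part of the commit), using a stub llms.Model to show the wrapper serving a repeated prompt from the in-memory cache:

package main

import (
	"context"
	"fmt"

	"github.com/dapr/components-contrib/conversation"
	"github.com/tmc/langchaingo/llms"
)

// countingModel is an illustrative stub that counts backend hits.
type countingModel struct{ calls int }

func (c *countingModel) GenerateContent(ctx context.Context, msgs []llms.MessageContent, opts ...llms.CallOption) (*llms.ContentResponse, error) {
	c.calls++
	return &llms.ContentResponse{Choices: []*llms.ContentChoice{{Content: "hello"}}}, nil
}

func (c *countingModel) Call(ctx context.Context, prompt string, opts ...llms.CallOption) (string, error) {
	return "hello", nil // deprecated interface method, unused here
}

func main() {
	ctx := context.Background()
	base := &countingModel{}

	cached, err := conversation.CacheModel(ctx, "10m", base)
	if err != nil {
		panic(err)
	}

	msgs := []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "hi")}
	for i := 0; i < 2; i++ {
		if _, err := cached.GenerateContent(ctx, msgs); err != nil {
			panic(err)
		}
	}
	// Expected to print 1: the second identical call should be served from the cache.
	fmt.Println(base.calls)
}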

go.mod (3 changes)
View File

@@ -102,7 +102,6 @@ require (
github.com/rabbitmq/amqp091-go v1.8.1
github.com/redis/go-redis/v9 v9.2.1
github.com/riferrei/srclient v0.6.0
github.com/sashabaranov/go-openai v1.27.1
github.com/sendgrid/sendgrid-go v3.13.0+incompatible
github.com/sijms/go-ora/v2 v2.7.18
github.com/spf13/cast v1.5.1
@@ -155,6 +154,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/Code-Hex/go-generics-cache v1.3.1 // indirect
github.com/DataDog/zstd v1.5.2 // indirect
github.com/OneOfOne/xxhash v1.2.8 // indirect
github.com/RoaringBitmap/roaring v1.1.0 // indirect
@@ -224,6 +224,7 @@ require (
github.com/fatih/color v1.17.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gage-technologies/mistral-go v1.0.0 // indirect
github.com/gavv/httpexpect v2.0.0+incompatible // indirect
github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 // indirect
github.com/go-ini/ini v1.67.0 // indirect

go.sum (6 changes)
View File

@@ -113,6 +113,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I=
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Code-Hex/go-generics-cache v1.3.1 h1:i8rLwyhoyhaerr7JpjtYjJZUcCbWOdiYO3fZXLiEC4g=
github.com/Code-Hex/go-generics-cache v1.3.1/go.mod h1:qxcC9kRVrct9rHeiYpFWSoW1vxyillCVzX13KZG8dl4=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
@@ -577,6 +579,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8=
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
github.com/getkin/kin-openapi v0.94.0/go.mod h1:LWZfzOd7PRy8GJ1dJ6mCU6tNdSfOwRac1BUPam4aw6Q=
@@ -1438,8 +1442,6 @@ github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIH
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/santhosh-tekuri/jsonschema/v5 v5.0.0 h1:TToq11gyfNlrMFZiYujSekIsPd9AmsA2Bj/iv+s4JHE=
github.com/santhosh-tekuri/jsonschema/v5 v5.0.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0=
github.com/sashabaranov/go-openai v1.27.1 h1:7Nx6db5NXbcoutNmAUQulEQZEpHG/SkzfexP2X5RWMk=
github.com/sashabaranov/go-openai v1.27.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=