Conversation API: add cache support, add huggingface+mistral models (#3567)
Signed-off-by: yaron2 <schneider.yaron@live.com>
parent 1cbedb3c0e
commit 4ca04dbb61
@@ -28,16 +28,11 @@ import (
 )
 
 type Anthropic struct {
-    llm *anthropic.LLM
+    llm llms.Model
 
     logger logger.Logger
 }
 
-type AnthropicMetadata struct {
-    Key   string `json:"key"`
-    Model string `json:"model"`
-}
-
 func NewAnthropic(logger logger.Logger) conversation.Conversation {
     a := &Anthropic{
         logger: logger,
@@ -49,7 +44,7 @@ func NewAnthropic(logger logger.Logger) conversation.Conversation {
 const defaultModel = "claude-3-5-sonnet-20240620"
 
 func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error {
-    m := AnthropicMetadata{}
+    m := conversation.LangchainMetadata{}
     err := kmeta.DecodeMetadata(meta.Properties, &m)
     if err != nil {
         return err
@@ -69,11 +64,21 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error
     }
 
     a.llm = llm
 
+    if m.CacheTTL != "" {
+        cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.llm)
+        if cacheErr != nil {
+            return cacheErr
+        }
+
+        a.llm = cachedModel
+    }
+
     return nil
 }
 
 func (a *Anthropic) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
-    metadataStruct := AnthropicMetadata{}
+    metadataStruct := conversation.LangchainMetadata{}
     metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
     return
 }
@@ -27,3 +27,9 @@ metadata:
       The Anthropic LLM to use. Defaults to claude-3-5-sonnet-20240620
     type: string
     example: 'claude-3-5-sonnet-20240620'
+  - name: cacheTTL
+    required: false
+    description: |
+      A time-to-live value for a prompt cache to expire. Uses Golang durations
+    type: string
+    example: '10m'
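For illustration (not part of this commit): cacheTTL values are parsed with Go's standard time.ParseDuration, so any duration string it accepts is valid. A minimal sketch:

package main

import (
    "fmt"
    "time"
)

func main() {
    // Valid Go duration strings combine numbers and units
    // "ns", "us", "ms", "s", "m", "h" (e.g. "90s", "1.5h", "10m").
    for _, ttl := range []string{"10m", "1.5h", "300s"} {
        d, err := time.ParseDuration(ttl)
        if err != nil {
            fmt.Println(ttl, "invalid:", err)
            continue
        }
        fmt.Println(ttl, "=", d)
    }
}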
@@ -31,7 +31,7 @@ import (
 
 type AWSBedrock struct {
     model string
-    llm   *bedrock.LLM
+    llm   llms.Model
 
     logger logger.Logger
 }
@@ -43,6 +43,7 @@ type AWSBedrockMetadata struct {
     SecretKey    string `json:"secretKey"`
     SessionToken string `json:"sessionToken"`
     Model        string `json:"model"`
+    CacheTTL     string `json:"cacheTTL"`
 }
 
 func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
@@ -81,6 +82,15 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error
     }
 
     b.llm = llm
 
+    if m.CacheTTL != "" {
+        cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.llm)
+        if cacheErr != nil {
+            return cacheErr
+        }
+
+        b.llm = cachedModel
+    }
     return nil
 }
@@ -24,3 +24,9 @@ metadata:
      The LLM to use. Defaults to Bedrock's default provider model from Amazon.
    type: string
    example: 'amazon.titan-text-express-v1'
+  - name: cacheTTL
+    required: false
+    description: |
+      A time-to-live value for a prompt cache to expire. Uses Golang durations
+    type: string
+    example: '10m'
@@ -0,0 +1,129 @@
+/*
+Copyright 2024 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package huggingface
+
+import (
+    "context"
+    "reflect"
+
+    "github.com/dapr/components-contrib/conversation"
+    "github.com/dapr/components-contrib/metadata"
+    "github.com/dapr/kit/logger"
+    kmeta "github.com/dapr/kit/metadata"
+
+    "github.com/tmc/langchaingo/llms"
+    "github.com/tmc/langchaingo/llms/huggingface"
+)
+
+type Huggingface struct {
+    llm llms.Model
+
+    logger logger.Logger
+}
+
+func NewHuggingface(logger logger.Logger) conversation.Conversation {
+    h := &Huggingface{
+        logger: logger,
+    }
+
+    return h
+}
+
+const defaultModel = "meta-llama/Meta-Llama-3-8B"
+
+func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) error {
+    m := conversation.LangchainMetadata{}
+    err := kmeta.DecodeMetadata(meta.Properties, &m)
+    if err != nil {
+        return err
+    }
+
+    model := defaultModel
+    if m.Model != "" {
+        model = m.Model
+    }
+
+    llm, err := huggingface.New(
+        huggingface.WithModel(model),
+        huggingface.WithToken(m.Key),
+    )
+    if err != nil {
+        return err
+    }
+
+    h.llm = llm
+
+    if m.CacheTTL != "" {
+        cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, h.llm)
+        if cacheErr != nil {
+            return cacheErr
+        }
+
+        h.llm = cachedModel
+    }
+
+    return nil
+}
+
+func (h *Huggingface) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
+    metadataStruct := conversation.LangchainMetadata{}
+    metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
+    return
+}
+
+func (h *Huggingface) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
+    messages := make([]llms.MessageContent, 0, len(r.Inputs))
+
+    for _, input := range r.Inputs {
+        role := conversation.ConvertLangchainRole(input.Role)
+
+        messages = append(messages, llms.MessageContent{
+            Role: role,
+            Parts: []llms.ContentPart{
+                llms.TextPart(input.Message),
+            },
+        })
+    }
+
+    opts := []llms.CallOption{}
+
+    if r.Temperature > 0 {
+        opts = append(opts, conversation.LangchainTemperature(r.Temperature))
+    }
+
+    resp, err := h.llm.GenerateContent(ctx, messages, opts...)
+    if err != nil {
+        return nil, err
+    }
+
+    outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
+
+    for i := range resp.Choices {
+        outputs = append(outputs, conversation.ConversationResult{
+            Result:     resp.Choices[i].Content,
+            Parameters: r.Parameters,
+        })
+    }
+
+    res = &conversation.ConversationResponse{
+        Outputs: outputs,
+    }
+
+    return res, nil
+}
+
+func (h *Huggingface) Close() error {
+    return nil
+}
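For illustration (not part of this commit): a minimal usage sketch of the new component. The package path and the ConversationRequest/ConversationInput field names are assumed from the hunks above, and metadata.Base carrying a Properties map is an assumption about the contrib metadata package:

package main

import (
    "context"
    "fmt"
    "log"
    "os"

    "github.com/dapr/components-contrib/conversation"
    "github.com/dapr/components-contrib/conversation/huggingface"
    "github.com/dapr/components-contrib/metadata"
    "github.com/dapr/kit/logger"
)

func main() {
    ctx := context.Background()

    comp := huggingface.NewHuggingface(logger.NewLogger("conversation.huggingface"))
    err := comp.Init(ctx, conversation.Metadata{Base: metadata.Base{Properties: map[string]string{
        "key":      os.Getenv("HUGGINGFACE_API_KEY"),
        "model":    "meta-llama/Meta-Llama-3-8B",
        "cacheTTL": "10m", // enables the prompt cache added in this commit
    }}})
    if err != nil {
        log.Fatal(err)
    }

    resp, err := comp.Converse(ctx, &conversation.ConversationRequest{
        Inputs: []conversation.ConversationInput{
            {Message: "What is Dapr?", Role: conversation.RoleUser},
        },
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(resp.Outputs[0].Result)
}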
@@ -0,0 +1,35 @@
+# yaml-language-server: $schema=../../../component-metadata-schema.json
+schemaVersion: v1
+type: conversation
+name: huggingface
+version: v1
+status: alpha
+title: "Huggingface"
+urls:
+  - title: Reference
+    url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-huggingface/
+authenticationProfiles:
+  - title: "API Key"
+    description: "Authenticate using an API key"
+    metadata:
+      - name: key
+        type: string
+        required: true
+        sensitive: true
+        description: |
+          API key for Huggingface.
+        example: "**********"
+        default: ""
+metadata:
+  - name: model
+    required: false
+    description: |
+      The Huggingface LLM to use. Defaults to meta-llama/Meta-Llama-3-8B
+    type: string
+    example: 'meta-llama/Meta-Llama-3-8B'
+  - name: cacheTTL
+    required: false
+    description: |
+      A time-to-live value for a prompt cache to expire. Uses Golang durations
+    type: string
+    example: '10m'
@@ -20,3 +20,10 @@ import "github.com/dapr/components-contrib/metadata"
 type Metadata struct {
     metadata.Base `json:",inline"`
 }
+
+// LangchainMetadata is a common metadata structure for langchain supported implementations.
+type LangchainMetadata struct {
+    Key      string `json:"key"`
+    Model    string `json:"model"`
+    CacheTTL string `json:"cacheTTL"`
+}
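For illustration (not part of this commit): every langchain-backed component now decodes the same three properties into this shared struct, matching the kmeta.DecodeMetadata calls in the Init functions above. A minimal sketch with hypothetical property values:

package main

import (
    "fmt"
    "log"

    "github.com/dapr/components-contrib/conversation"
    kmeta "github.com/dapr/kit/metadata"
)

func main() {
    // Component YAML metadata arrives as a flat string map;
    // DecodeMetadata matches keys to the struct's json tags.
    props := map[string]string{
        "key":      "sk-...",
        "model":    "open-mistral-7b",
        "cacheTTL": "10m",
    }

    m := conversation.LangchainMetadata{}
    if err := kmeta.DecodeMetadata(props, &m); err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%+v\n", m) // {Key:sk-... Model:open-mistral-7b CacheTTL:10m}
}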
@@ -0,0 +1,35 @@
+# yaml-language-server: $schema=../../../component-metadata-schema.json
+schemaVersion: v1
+type: conversation
+name: mistral
+version: v1
+status: alpha
+title: "Mistral"
+urls:
+  - title: Reference
+    url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-mistral/
+authenticationProfiles:
+  - title: "API Key"
+    description: "Authenticate using an API key"
+    metadata:
+      - name: key
+        type: string
+        required: true
+        sensitive: true
+        description: |
+          API key for Mistral.
+        example: "**********"
+        default: ""
+metadata:
+  - name: model
+    required: false
+    description: |
+      The Mistral LLM to use. Defaults to open-mistral-7b
+    type: string
+    example: 'open-mistral-7b'
+  - name: cacheTTL
+    required: false
+    description: |
+      A time-to-live value for a prompt cache to expire. Uses Golang durations
+    type: string
+    example: '10m'
@@ -0,0 +1,128 @@
+/*
+Copyright 2024 The Dapr Authors
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package mistral
+
+import (
+    "context"
+    "reflect"
+
+    "github.com/dapr/components-contrib/conversation"
+    "github.com/dapr/components-contrib/metadata"
+    "github.com/dapr/kit/logger"
+    kmeta "github.com/dapr/kit/metadata"
+
+    "github.com/tmc/langchaingo/llms"
+    "github.com/tmc/langchaingo/llms/mistral"
+)
+
+type Mistral struct {
+    llm llms.Model
+
+    logger logger.Logger
+}
+
+func NewMistral(logger logger.Logger) conversation.Conversation {
+    m := &Mistral{
+        logger: logger,
+    }
+
+    return m
+}
+
+const defaultModel = "open-mistral-7b"
+
+func (m *Mistral) Init(ctx context.Context, meta conversation.Metadata) error {
+    md := conversation.LangchainMetadata{}
+    err := kmeta.DecodeMetadata(meta.Properties, &md)
+    if err != nil {
+        return err
+    }
+
+    model := defaultModel
+    if md.Model != "" {
+        model = md.Model
+    }
+
+    llm, err := mistral.New(
+        mistral.WithModel(model),
+        mistral.WithAPIKey(md.Key),
+    )
+    if err != nil {
+        return err
+    }
+
+    m.llm = llm
+
+    if md.CacheTTL != "" {
+        cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, m.llm)
+        if cacheErr != nil {
+            return cacheErr
+        }
+
+        m.llm = cachedModel
+    }
+    return nil
+}
+
+func (m *Mistral) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
+    metadataStruct := conversation.LangchainMetadata{}
+    metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
+    return
+}
+
+func (m *Mistral) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
+    messages := make([]llms.MessageContent, 0, len(r.Inputs))
+
+    for _, input := range r.Inputs {
+        role := conversation.ConvertLangchainRole(input.Role)
+
+        messages = append(messages, llms.MessageContent{
+            Role: role,
+            Parts: []llms.ContentPart{
+                llms.TextPart(input.Message),
+            },
+        })
+    }
+
+    opts := []llms.CallOption{}
+
+    if r.Temperature > 0 {
+        opts = append(opts, conversation.LangchainTemperature(r.Temperature))
+    }
+
+    resp, err := m.llm.GenerateContent(ctx, messages, opts...)
+    if err != nil {
+        return nil, err
+    }
+
+    outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
+
+    for i := range resp.Choices {
+        outputs = append(outputs, conversation.ConversationResult{
+            Result:     resp.Choices[i].Content,
+            Parameters: r.Parameters,
+        })
+    }
+
+    res = &conversation.ConversationResponse{
+        Outputs: outputs,
+    }
+
+    return res, nil
+}
+
+func (m *Mistral) Close() error {
+    return nil
+}
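For illustration (not part of this commit): because Huggingface and Mistral share the LangchainMetadata/Converse shape, both constructors return the same conversation.Conversation interface and a caller can pick a provider by name. Package paths below are assumed from the component layout:

package main

import (
    "github.com/dapr/components-contrib/conversation"
    "github.com/dapr/components-contrib/conversation/huggingface"
    "github.com/dapr/components-contrib/conversation/mistral"
    "github.com/dapr/kit/logger"
)

// newConversation resolves a component constructor by name; both new
// providers satisfy the same conversation.Conversation interface, so
// switching providers is a configuration change only.
func newConversation(name string, log logger.Logger) conversation.Conversation {
    switch name {
    case "huggingface":
        return huggingface.NewHuggingface(log)
    case "mistral":
        return mistral.NewMistral(log)
    default:
        return nil
    }
}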
@@ -27,3 +27,9 @@ metadata:
      The OpenAI LLM to use. Defaults to gpt-4o
    type: string
    example: 'gpt-4-turbo'
+  - name: cacheTTL
+    required: false
+    description: |
+      A time-to-live value for a prompt cache to expire. Uses Golang durations
+    type: string
+    example: '10m'
@@ -23,14 +23,12 @@ import (
     "github.com/dapr/kit/logger"
     kmeta "github.com/dapr/kit/metadata"
 
-    openai "github.com/sashabaranov/go-openai"
+    "github.com/tmc/langchaingo/llms"
+    "github.com/tmc/langchaingo/llms/openai"
 )
 
-const defaultModel = "gpt-4o"
-
 type OpenAI struct {
-    client *openai.Client
-    model  string
+    llm llms.Model
 
     logger logger.Logger
 }
@@ -43,83 +41,69 @@ func NewOpenAI(logger logger.Logger) conversation.Conversation {
     return o
 }
 
+const defaultModel = "gpt-4o"
+
 func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error {
-    r := &conversation.ConversationRequest{}
-    err := kmeta.DecodeMetadata(meta.Properties, r)
+    md := conversation.LangchainMetadata{}
+    err := kmeta.DecodeMetadata(meta.Properties, &md)
     if err != nil {
         return err
     }
 
-    o.client = openai.NewClient(r.Key)
-    o.model = r.Model
-
-    if o.model == "" {
-        o.model = defaultModel
+    model := defaultModel
+    if md.Model != "" {
+        model = md.Model
     }
 
+    llm, err := openai.New(
+        openai.WithModel(model),
+        openai.WithToken(md.Key),
+    )
+    if err != nil {
+        return err
+    }
+
+    o.llm = llm
+
+    if md.CacheTTL != "" {
+        cachedModel, cacheErr := conversation.CacheModel(ctx, md.CacheTTL, o.llm)
+        if cacheErr != nil {
+            return cacheErr
+        }
+
+        o.llm = cachedModel
+    }
     return nil
 }
 
 func (o *OpenAI) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
-    metadataStruct := conversation.ConversationRequest{}
+    metadataStruct := conversation.LangchainMetadata{}
     metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
     return
 }
 
-func convertRole(role conversation.Role) string {
-    switch role {
-    case conversation.RoleSystem:
-        return string(openai.ChatMessageRoleSystem)
-    case conversation.RoleUser:
-        return string(openai.ChatMessageRoleUser)
-    case conversation.RoleAssistant:
-        return string(openai.ChatMessageRoleAssistant)
-    case conversation.RoleTool:
-        return string(openai.ChatMessageRoleTool)
-    case conversation.RoleFunction:
-        return string(openai.ChatMessageRoleFunction)
-    default:
-        return string(openai.ChatMessageRoleUser)
-    }
-}
-
 func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
-    // Note: OPENAI does not support load balance
-    messages := make([]openai.ChatCompletionMessage, 0, len(r.Inputs))
-
-    var systemPrompt string
+    messages := make([]llms.MessageContent, 0, len(r.Inputs))
 
     for _, input := range r.Inputs {
-        role := convertRole(input.Role)
-        if role == openai.ChatMessageRoleSystem {
-            systemPrompt = input.Message
-            continue
-        }
+        role := conversation.ConvertLangchainRole(input.Role)
 
-        messages = append(messages, openai.ChatCompletionMessage{
-            Role:    role,
-            Content: input.Message,
+        messages = append(messages, llms.MessageContent{
+            Role: role,
+            Parts: []llms.ContentPart{
+                llms.TextPart(input.Message),
+            },
         })
     }
 
-    // OpenAI needs system prompts to be added last in the array to function properly
-    if systemPrompt != "" {
-        messages = append(messages, openai.ChatCompletionMessage{
-            Role:    openai.ChatMessageRoleSystem,
-            Content: systemPrompt,
-        })
+    opts := []llms.CallOption{}
+
+    if r.Temperature > 0 {
+        opts = append(opts, conversation.LangchainTemperature(r.Temperature))
     }
 
-    req := openai.ChatCompletionRequest{
-        Model:       o.model,
-        Messages:    messages,
-        Temperature: float32(r.Temperature),
-    }
-
-    // TODO: support ConversationContext
-    resp, err := o.client.CreateChatCompletion(ctx, req)
+    resp, err := o.llm.GenerateContent(ctx, messages, opts...)
     if err != nil {
-        o.logger.Error(err)
         return nil, err
     }
 
@@ -127,14 +111,13 @@ func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
 
     for i := range resp.Choices {
         outputs = append(outputs, conversation.ConversationResult{
-            Result:     resp.Choices[i].Message.Content,
+            Result:     resp.Choices[i].Content,
             Parameters: r.Parameters,
         })
     }
 
     res = &conversation.ConversationResponse{
-        ConversationContext: resp.ID,
-        Outputs:             outputs,
+        Outputs: outputs,
     }
 
     return res, nil
@@ -1,26 +0,0 @@
-package openai
-
-import (
-    "testing"
-
-    "github.com/dapr/components-contrib/conversation"
-
-    "github.com/stretchr/testify/assert"
-
-    openai "github.com/sashabaranov/go-openai"
-)
-
-func TestConvertRole(t *testing.T) {
-    roles := map[string]string{
-        conversation.RoleSystem:    string(openai.ChatMessageRoleSystem),
-        conversation.RoleAssistant: string(openai.ChatMessageRoleAssistant),
-        conversation.RoleFunction:  string(openai.ChatMessageRoleFunction),
-        conversation.RoleUser:      string(openai.ChatMessageRoleUser),
-        conversation.RoleTool:      string(openai.ChatMessageRoleTool),
-    }
-
-    for k, v := range roles {
-        r := convertRole(conversation.Role(k))
-        assert.Equal(t, v, r)
-    }
-}
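For illustration (not part of this commit): the deleted test covered the go-openai role mapping, which the langchaingo migration makes obsolete. An equivalent sketch for the new path, assuming conversation.ConvertLangchainRole returns llms.ChatMessageType values with the conventional system/human/ai mapping:

package conversation_test

import (
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/tmc/langchaingo/llms"

    "github.com/dapr/components-contrib/conversation"
)

func TestConvertLangchainRole(t *testing.T) {
    // The mapping below is an assumption; align it with the real implementation.
    assert.Equal(t, llms.ChatMessageTypeSystem, conversation.ConvertLangchainRole(conversation.RoleSystem))
    assert.Equal(t, llms.ChatMessageTypeHuman, conversation.ConvertLangchainRole(conversation.RoleUser))
    assert.Equal(t, llms.ChatMessageTypeAI, conversation.ConvertLangchainRole(conversation.RoleAssistant))
}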
@@ -14,9 +14,32 @@ limitations under the License.
 */
 package conversation
 
-import "github.com/tmc/langchaingo/llms"
+import (
+    "context"
+    "fmt"
+    "time"
+
+    "github.com/tmc/langchaingo/llms"
+    "github.com/tmc/langchaingo/llms/cache"
+    "github.com/tmc/langchaingo/llms/cache/inmemory"
+)
 
 // LangchainTemperature returns a langchain compliant LLM temperature
 func LangchainTemperature(temperature float64) llms.CallOption {
     return llms.WithTemperature(temperature)
 }
+
+// CacheModel creates a prompt query cache with a configured TTL
+func CacheModel(ctx context.Context, ttl string, model llms.Model) (llms.Model, error) {
+    d, err := time.ParseDuration(ttl)
+    if err != nil {
+        return model, fmt.Errorf("failed to parse cacheTTL duration: %s", err)
+    }
+
+    mem, err := inmemory.New(ctx, inmemory.WithExpiration(d))
+    if err != nil {
+        return model, fmt.Errorf("failed to create llm cache: %s", err)
+    }
+
+    return cache.New(model, mem), nil
+}
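For illustration (not part of this commit): CacheModel wraps any llms.Model, so a repeated identical prompt inside the TTL is answered from the in-memory cache instead of the provider. A minimal sketch, assuming langchaingo's GenerateFromSinglePrompt helper:

package main

import (
    "context"
    "fmt"
    "log"
    "os"

    "github.com/tmc/langchaingo/llms"
    "github.com/tmc/langchaingo/llms/openai"

    "github.com/dapr/components-contrib/conversation"
)

func main() {
    ctx := context.Background()

    llm, err := openai.New(
        openai.WithModel("gpt-4o"),
        openai.WithToken(os.Getenv("OPENAI_API_KEY")),
    )
    if err != nil {
        log.Fatal(err)
    }

    // Wrap the model, mirroring the Init paths above; identical prompts
    // within 10 minutes are served from the in-memory cache.
    model, err := conversation.CacheModel(ctx, "10m", llm)
    if err != nil {
        log.Fatal(err)
    }

    out, err := llms.GenerateFromSinglePrompt(ctx, model, "What is Dapr?")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(out)
}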
go.mod

@@ -102,7 +102,6 @@ require (
     github.com/rabbitmq/amqp091-go v1.8.1
     github.com/redis/go-redis/v9 v9.2.1
     github.com/riferrei/srclient v0.6.0
-    github.com/sashabaranov/go-openai v1.27.1
     github.com/sendgrid/sendgrid-go v3.13.0+incompatible
     github.com/sijms/go-ora/v2 v2.7.18
     github.com/spf13/cast v1.5.1

@@ -155,6 +154,7 @@ require (
     github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect
     github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 // indirect
     github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
+    github.com/Code-Hex/go-generics-cache v1.3.1 // indirect
     github.com/DataDog/zstd v1.5.2 // indirect
     github.com/OneOfOne/xxhash v1.2.8 // indirect
     github.com/RoaringBitmap/roaring v1.1.0 // indirect

@@ -224,6 +224,7 @@ require (
     github.com/fatih/color v1.17.0 // indirect
     github.com/felixge/httpsnoop v1.0.4 // indirect
     github.com/fsnotify/fsnotify v1.7.0 // indirect
+    github.com/gage-technologies/mistral-go v1.0.0 // indirect
     github.com/gavv/httpexpect v2.0.0+incompatible // indirect
     github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 // indirect
     github.com/go-ini/ini v1.67.0 // indirect
go.sum

@@ -113,6 +113,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
 github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I=
 github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/Code-Hex/go-generics-cache v1.3.1 h1:i8rLwyhoyhaerr7JpjtYjJZUcCbWOdiYO3fZXLiEC4g=
+github.com/Code-Hex/go-generics-cache v1.3.1/go.mod h1:qxcC9kRVrct9rHeiYpFWSoW1vxyillCVzX13KZG8dl4=
 github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
 github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
 github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=

@@ -577,6 +579,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
 github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU=
 github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
 github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
+github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
+github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
 github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8=
 github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
 github.com/getkin/kin-openapi v0.94.0/go.mod h1:LWZfzOd7PRy8GJ1dJ6mCU6tNdSfOwRac1BUPam4aw6Q=

@@ -1438,8 +1442,6 @@ github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIH
 github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
 github.com/santhosh-tekuri/jsonschema/v5 v5.0.0 h1:TToq11gyfNlrMFZiYujSekIsPd9AmsA2Bj/iv+s4JHE=
 github.com/santhosh-tekuri/jsonschema/v5 v5.0.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0=
-github.com/sashabaranov/go-openai v1.27.1 h1:7Nx6db5NXbcoutNmAUQulEQZEpHG/SkzfexP2X5RWMk=
-github.com/sashabaranov/go-openai v1.27.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
 github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
 github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I=
 github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=