Fix lint failure for openai binding (#3000)
Signed-off-by: Shivam Kumar Singh <shivamhere247@gmail.com> Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
608e4cb8a9
commit
9957d6969d
|
@ -26,7 +26,6 @@ import (
|
||||||
"github.com/dapr/components-contrib/bindings"
|
"github.com/dapr/components-contrib/bindings"
|
||||||
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
||||||
"github.com/dapr/components-contrib/metadata"
|
"github.com/dapr/components-contrib/metadata"
|
||||||
"github.com/dapr/kit/config"
|
|
||||||
"github.com/dapr/kit/logger"
|
"github.com/dapr/kit/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,6 +52,7 @@ const (
|
||||||
type AzOpenAI struct {
|
type AzOpenAI struct {
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
client *azopenai.Client
|
client *azopenai.Client
|
||||||
|
deploymentID string
|
||||||
}
|
}
|
||||||
|
|
||||||
type openAIMetadata struct {
|
type openAIMetadata struct {
|
||||||
|
@ -64,15 +64,6 @@ type openAIMetadata struct {
|
||||||
Endpoint string `mapstructure:"endpoint"`
|
Endpoint string `mapstructure:"endpoint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatSettings struct {
|
|
||||||
Temperature float32 `mapstructure:"temperature"`
|
|
||||||
MaxTokens int32 `mapstructure:"maxTokens"`
|
|
||||||
TopP float32 `mapstructure:"topP"`
|
|
||||||
N int32 `mapstructure:"n"`
|
|
||||||
PresencePenalty float32 `mapstructure:"presencePenalty"`
|
|
||||||
FrequencyPenalty float32 `mapstructure:"frequencyPenalty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatMessages type for chat completion API.
|
// ChatMessages type for chat completion API.
|
||||||
type ChatMessages struct {
|
type ChatMessages struct {
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
|
@ -82,6 +73,7 @@ type ChatMessages struct {
|
||||||
N int32 `json:"n"`
|
N int32 `json:"n"`
|
||||||
PresencePenalty float32 `json:"presencePenalty"`
|
PresencePenalty float32 `json:"presencePenalty"`
|
||||||
FrequencyPenalty float32 `json:"frequencyPenalty"`
|
FrequencyPenalty float32 `json:"frequencyPenalty"`
|
||||||
|
Stop []string `json:"stop"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message type stores the messages for bot conversation.
|
// Message type stores the messages for bot conversation.
|
||||||
|
@ -99,6 +91,7 @@ type Prompt struct {
|
||||||
N int32 `json:"n"`
|
N int32 `json:"n"`
|
||||||
PresencePenalty float32 `json:"presencePenalty"`
|
PresencePenalty float32 `json:"presencePenalty"`
|
||||||
FrequencyPenalty float32 `json:"frequencyPenalty"`
|
FrequencyPenalty float32 `json:"frequencyPenalty"`
|
||||||
|
Stop []string `json:"stop"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAI returns a new OpenAI output binding.
|
// NewOpenAI returns a new OpenAI output binding.
|
||||||
|
@ -130,7 +123,7 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
|
||||||
return fmt.Errorf("error getting credentials object: %w", err)
|
return fmt.Errorf("error getting credentials object: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, m.DeploymentID, nil)
|
p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
|
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -146,11 +139,12 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
|
||||||
return fmt.Errorf("error getting token credential: %w", innerErr)
|
return fmt.Errorf("error getting token credential: %w", innerErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.client, err = azopenai.NewClient(m.Endpoint, token, m.DeploymentID, nil)
|
p.client, err = azopenai.NewClient(m.Endpoint, token, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
|
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
p.deploymentID = m.DeploymentID
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -208,10 +202,6 @@ func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (res
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ChatSettings) Decode(in any) error {
|
|
||||||
return config.Decode(in, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) {
|
func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) {
|
||||||
prompt := Prompt{
|
prompt := Prompt{
|
||||||
Temperature: 1.0,
|
Temperature: 1.0,
|
||||||
|
@ -230,12 +220,18 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
|
||||||
return nil, fmt.Errorf("prompt is required for completion operation")
|
return nil, fmt.Errorf("prompt is required for completion operation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(prompt.Stop) == 0 {
|
||||||
|
prompt.Stop = nil
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
|
resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
|
||||||
Prompt: []*string{&prompt.Prompt},
|
DeploymentID: p.deploymentID,
|
||||||
|
Prompt: []string{prompt.Prompt},
|
||||||
MaxTokens: &prompt.MaxTokens,
|
MaxTokens: &prompt.MaxTokens,
|
||||||
Temperature: &prompt.Temperature,
|
Temperature: &prompt.Temperature,
|
||||||
TopP: &prompt.TopP,
|
TopP: &prompt.TopP,
|
||||||
N: &prompt.N,
|
N: &prompt.N,
|
||||||
|
Stop: prompt.Stop,
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting completion api: %w", err)
|
return nil, fmt.Errorf("error getting completion api: %w", err)
|
||||||
|
@ -249,7 +245,7 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
|
||||||
choices := resp.Completions.Choices
|
choices := resp.Completions.Choices
|
||||||
response = make([]azopenai.Choice, len(choices))
|
response = make([]azopenai.Choice, len(choices))
|
||||||
for i, c := range choices {
|
for i, c := range choices {
|
||||||
response[i] = *c
|
response[i] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
|
@ -272,9 +268,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
|
||||||
return nil, fmt.Errorf("messages are required for chat-completion operation")
|
return nil, fmt.Errorf("messages are required for chat-completion operation")
|
||||||
}
|
}
|
||||||
|
|
||||||
messageReq := make([]*azopenai.ChatMessage, len(messages.Messages))
|
if len(messages.Stop) == 0 {
|
||||||
|
messages.Stop = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
messageReq := make([]azopenai.ChatMessage, len(messages.Messages))
|
||||||
for i, m := range messages.Messages {
|
for i, m := range messages.Messages {
|
||||||
messageReq[i] = &azopenai.ChatMessage{
|
messageReq[i] = azopenai.ChatMessage{
|
||||||
Role: to.Ptr(azopenai.ChatRole(m.Role)),
|
Role: to.Ptr(azopenai.ChatRole(m.Role)),
|
||||||
Content: to.Ptr(m.Message),
|
Content: to.Ptr(m.Message),
|
||||||
}
|
}
|
||||||
|
@ -286,11 +286,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
|
res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
|
||||||
|
DeploymentID: p.deploymentID,
|
||||||
MaxTokens: maxTokens,
|
MaxTokens: maxTokens,
|
||||||
Temperature: &messages.Temperature,
|
Temperature: &messages.Temperature,
|
||||||
TopP: &messages.TopP,
|
TopP: &messages.TopP,
|
||||||
N: &messages.N,
|
N: &messages.N,
|
||||||
Messages: messageReq,
|
Messages: messageReq,
|
||||||
|
Stop: messages.Stop,
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting chat completion api: %w", err)
|
return nil, fmt.Errorf("error getting chat completion api: %w", err)
|
||||||
|
@ -304,7 +306,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
|
||||||
choices := res.ChatCompletions.Choices
|
choices := res.ChatCompletions.Choices
|
||||||
response = make([]azopenai.ChatChoice, len(choices))
|
response = make([]azopenai.ChatChoice, len(choices))
|
||||||
for i, c := range choices {
|
for i, c := range choices {
|
||||||
response[i] = *c
|
response[i] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -10,7 +10,7 @@ require (
|
||||||
dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
|
dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5
|
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0
|
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5
|
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1
|
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.0.1
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -422,8 +422,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5 h1:DQCZXtoCPuwBMlAa2aC+B3CfpE6xz2xe1jqdqt8nIJY=
|
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0 h1:lkflJSWI6jicmEBImjpliUOWCr1PdJO/GcZj3bWx19Q=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.0.0-20230705184009-934612c4f2b5/go.mod h1:GQSjs1n073tbMa3e76+STZkyFb+NcEA4N7OB5vNvB3E=
|
github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai v0.1.0/go.mod h1:NwVkXm5Ty88Xd7cx6b53fGNeGG3W3ZDXgOXBNHLUy84=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 h1:OrKZybbyagpgJiREiIVzH5mV/z9oS4rXqdX7i31DSF0=
|
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0 h1:OrKZybbyagpgJiREiIVzH5mV/z9oS4rXqdX7i31DSF0=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0/go.mod h1:p74+tP95m8830ypJk53L93+BEsjTKY4SKQ75J2NmS5U=
|
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v0.5.0/go.mod h1:p74+tP95m8830ypJk53L93+BEsjTKY4SKQ75J2NmS5U=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 h1:qS0Bp4do0cIvnuQgSGeO6ZCu/q/HlRKl4NPfv1eJ2p0=
|
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.5 h1:qS0Bp4do0cIvnuQgSGeO6ZCu/q/HlRKl4NPfv1eJ2p0=
|
||||||
|
|
Loading…
Reference in New Issue