374 lines
10 KiB
Go
374 lines
10 KiB
Go
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package openai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
|
|
|
"github.com/dapr/components-contrib/bindings"
|
|
azauth "github.com/dapr/components-contrib/internal/authentication/azure"
|
|
"github.com/dapr/components-contrib/metadata"
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
// List of operations.
|
|
const (
|
|
CompletionOperation bindings.OperationKind = "completion"
|
|
ChatCompletionOperation bindings.OperationKind = "chat-completion"
|
|
GetEmbeddingOperation bindings.OperationKind = "get-embedding"
|
|
|
|
APIKey = "apiKey"
|
|
DeploymentID = "deploymentID"
|
|
Endpoint = "endpoint"
|
|
MessagesKey = "messages"
|
|
Temperature = "temperature"
|
|
MaxTokens = "maxTokens"
|
|
TopP = "topP"
|
|
N = "n"
|
|
Stop = "stop"
|
|
FrequencyPenalty = "frequencyPenalty"
|
|
LogitBias = "logitBias"
|
|
User = "user"
|
|
)
|
|
|
|
// AzOpenAI represents OpenAI output binding.
|
|
type AzOpenAI struct {
|
|
logger logger.Logger
|
|
client *azopenai.Client
|
|
}
|
|
|
|
// openAIMetadata holds the component metadata decoded by Init.
type openAIMetadata struct {
	// APIKey is the API key for the Azure OpenAI API. When it is empty, Init
	// authenticates with Azure AD instead.
	APIKey string `mapstructure:"apiKey"`
	// Endpoint is the endpoint for the Azure OpenAI API. Init rejects an
	// empty value.
	Endpoint string `mapstructure:"endpoint"`
}
|
|
|
|
// ChatMessages type for chat completion API.
|
|
type ChatMessages struct {
|
|
DeploymentID string `json:"deploymentID"`
|
|
Messages []Message `json:"messages"`
|
|
Temperature float32 `json:"temperature"`
|
|
MaxTokens int32 `json:"maxTokens"`
|
|
TopP float32 `json:"topP"`
|
|
N int32 `json:"n"`
|
|
PresencePenalty float32 `json:"presencePenalty"`
|
|
FrequencyPenalty float32 `json:"frequencyPenalty"`
|
|
Stop []string `json:"stop"`
|
|
}
|
|
|
|
// Message is a single entry of a chat conversation.
//
// The JSON keys are declared explicitly so that marshalling uses the same
// lowercase naming as the other request types in this package. Decoding is
// still case-insensitive, so payloads using "Role"/"Message" keep working.
type Message struct {
	// Role is the chat role of the author; chatCompletion converts it to an
	// azopenai.ChatRole without validating it.
	Role string `json:"role"`
	// Message is the text content of the entry.
	Message string `json:"message"`
}
|
|
|
|
// Prompt is the JSON request payload for the completion operation.
type Prompt struct {
	// DeploymentID selects the model deployment; it must be non-empty.
	DeploymentID string `json:"deploymentID"`
	// Prompt is the text to complete; it must be non-empty.
	Prompt string `json:"prompt"`
	// Temperature is forwarded to the service as the sampling temperature.
	Temperature float32 `json:"temperature"`
	// MaxTokens is forwarded to the service as the token limit.
	MaxTokens int32 `json:"maxTokens"`
	// TopP is forwarded to the service as the top-p sampling value.
	TopP float32 `json:"topP"`
	// N is forwarded to the service as the number of choices requested.
	N int32 `json:"n"`
	// PresencePenalty is the presence penalty supplied by the caller.
	PresencePenalty float32 `json:"presencePenalty"`
	// FrequencyPenalty is the frequency penalty supplied by the caller.
	FrequencyPenalty float32 `json:"frequencyPenalty"`
	// Stop lists stop sequences; an empty list is sent as nil.
	Stop []string `json:"stop"`
}
|
|
|
|
// EmbeddingMessage is the JSON request payload for the get-embedding
// operation.
type EmbeddingMessage struct {
	// DeploymentID selects the model deployment; it must be non-empty.
	DeploymentID string `json:"deploymentID"`
	// Message is the text sent to the service as the embedding input.
	Message string `json:"message"`
}
|
|
|
|
// NewOpenAI returns a new OpenAI output binding.
|
|
func NewOpenAI(logger logger.Logger) bindings.OutputBinding {
|
|
return &AzOpenAI{
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Init initializes the OpenAI binding.
|
|
func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {
|
|
m := openAIMetadata{}
|
|
err := metadata.DecodeMetadata(meta.Properties, &m)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding metadata: %w", err)
|
|
}
|
|
if m.Endpoint == "" {
|
|
return fmt.Errorf("required metadata not set: %s", Endpoint)
|
|
}
|
|
|
|
if m.APIKey != "" {
|
|
// use API key authentication
|
|
var keyCredential azopenai.KeyCredential
|
|
keyCredential, err = azopenai.NewKeyCredential(m.APIKey)
|
|
if err != nil {
|
|
return fmt.Errorf("error getting credentials object: %w", err)
|
|
}
|
|
|
|
p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
|
|
}
|
|
} else {
|
|
// fallback to Azure AD authentication
|
|
settings, innerErr := azauth.NewEnvironmentSettings(meta.Properties)
|
|
if innerErr != nil {
|
|
return fmt.Errorf("error creating environment settings: %w", innerErr)
|
|
}
|
|
|
|
token, innerErr := settings.GetTokenCredential()
|
|
if innerErr != nil {
|
|
return fmt.Errorf("error getting token credential: %w", innerErr)
|
|
}
|
|
|
|
p.client, err = azopenai.NewClient(m.Endpoint, token, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating Azure OpenAI client: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Operations returns list of operations supported by OpenAI binding.
|
|
func (p *AzOpenAI) Operations() []bindings.OperationKind {
|
|
return []bindings.OperationKind{
|
|
ChatCompletionOperation,
|
|
CompletionOperation,
|
|
GetEmbeddingOperation,
|
|
}
|
|
}
|
|
|
|
// Invoke handles all invoke operations.
|
|
func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) {
|
|
if req == nil || len(req.Metadata) == 0 {
|
|
return nil, fmt.Errorf("invalid request: metadata is required")
|
|
}
|
|
|
|
startTime := time.Now().UTC()
|
|
resp = &bindings.InvokeResponse{
|
|
Metadata: map[string]string{
|
|
"operation": string(req.Operation),
|
|
"start-time": startTime.Format(time.RFC3339Nano),
|
|
},
|
|
}
|
|
|
|
switch req.Operation { //nolint:exhaustive
|
|
case CompletionOperation:
|
|
response, err := p.completion(ctx, req.Data, req.Metadata)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error performing completion: %w", err)
|
|
}
|
|
responseAsBytes, _ := json.Marshal(response)
|
|
resp.Data = responseAsBytes
|
|
|
|
case ChatCompletionOperation:
|
|
response, err := p.chatCompletion(ctx, req.Data, req.Metadata)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error performing chat completion: %w", err)
|
|
}
|
|
responseAsBytes, _ := json.Marshal(response)
|
|
resp.Data = responseAsBytes
|
|
|
|
case GetEmbeddingOperation:
|
|
response, err := p.getEmbedding(ctx, req.Data, req.Metadata)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error performing get embedding operation: %w", err)
|
|
}
|
|
responseAsBytes, _ := json.Marshal(response)
|
|
resp.Data = responseAsBytes
|
|
|
|
default:
|
|
return nil, fmt.Errorf(
|
|
"invalid operation type: %s. Expected %s, %s",
|
|
req.Operation, CompletionOperation, ChatCompletionOperation,
|
|
)
|
|
}
|
|
|
|
endTime := time.Now().UTC()
|
|
resp.Metadata["end-time"] = endTime.Format(time.RFC3339Nano)
|
|
resp.Metadata["duration"] = endTime.Sub(startTime).String()
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[string]string) (response []azopenai.Choice, err error) {
|
|
prompt := Prompt{
|
|
Temperature: 1.0,
|
|
TopP: 1.0,
|
|
MaxTokens: 16,
|
|
N: 1,
|
|
PresencePenalty: 0.0,
|
|
FrequencyPenalty: 0.0,
|
|
}
|
|
err = json.Unmarshal(message, &prompt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling the input object: %w", err)
|
|
}
|
|
|
|
if prompt.Prompt == "" {
|
|
return nil, fmt.Errorf("prompt is required for completion operation")
|
|
}
|
|
|
|
if prompt.DeploymentID == "" {
|
|
return nil, fmt.Errorf("required metadata not set: %s", DeploymentID)
|
|
}
|
|
|
|
if len(prompt.Stop) == 0 {
|
|
prompt.Stop = nil
|
|
}
|
|
|
|
resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
|
|
Deployment: prompt.DeploymentID,
|
|
Prompt: []string{prompt.Prompt},
|
|
MaxTokens: &prompt.MaxTokens,
|
|
Temperature: &prompt.Temperature,
|
|
TopP: &prompt.TopP,
|
|
N: &prompt.N,
|
|
Stop: prompt.Stop,
|
|
}, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting completion api: %w", err)
|
|
}
|
|
|
|
// No choices returned
|
|
if len(resp.Completions.Choices) == 0 {
|
|
return []azopenai.Choice{}, nil
|
|
}
|
|
|
|
choices := resp.Completions.Choices
|
|
response = make([]azopenai.Choice, len(choices))
|
|
for i, c := range choices {
|
|
response[i] = c
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, metadata map[string]string) (response []azopenai.ChatChoice, err error) {
|
|
messages := ChatMessages{
|
|
Temperature: 1.0,
|
|
TopP: 1.0,
|
|
N: 1,
|
|
PresencePenalty: 0.0,
|
|
FrequencyPenalty: 0.0,
|
|
}
|
|
err = json.Unmarshal(messageRequest, &messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling the input object: %w", err)
|
|
}
|
|
|
|
if len(messages.Messages) == 0 {
|
|
return nil, fmt.Errorf("messages are required for chat-completion operation")
|
|
}
|
|
|
|
if messages.DeploymentID == "" {
|
|
return nil, fmt.Errorf("required metadata not set: %s", DeploymentID)
|
|
}
|
|
|
|
if len(messages.Stop) == 0 {
|
|
messages.Stop = nil
|
|
}
|
|
|
|
messageReq := make([]azopenai.ChatMessage, len(messages.Messages))
|
|
for i, m := range messages.Messages {
|
|
messageReq[i] = azopenai.ChatMessage{
|
|
Role: to.Ptr(azopenai.ChatRole(m.Role)),
|
|
Content: to.Ptr(m.Message),
|
|
}
|
|
}
|
|
|
|
var maxTokens *int32
|
|
if messages.MaxTokens != 0 {
|
|
maxTokens = &messages.MaxTokens
|
|
}
|
|
|
|
res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
|
|
Deployment: messages.DeploymentID,
|
|
MaxTokens: maxTokens,
|
|
Temperature: &messages.Temperature,
|
|
TopP: &messages.TopP,
|
|
N: &messages.N,
|
|
Messages: messageReq,
|
|
Stop: messages.Stop,
|
|
}, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting chat completion api: %w", err)
|
|
}
|
|
|
|
// No choices returned.
|
|
if len(res.ChatCompletions.Choices) == 0 {
|
|
return []azopenai.ChatChoice{}, nil
|
|
}
|
|
|
|
choices := res.ChatCompletions.Choices
|
|
response = make([]azopenai.ChatChoice, len(choices))
|
|
for i, c := range choices {
|
|
response[i] = c
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func (p *AzOpenAI) getEmbedding(ctx context.Context, messageRequest []byte, metadata map[string]string) (response []float32, err error) {
|
|
message := EmbeddingMessage{}
|
|
err = json.Unmarshal(messageRequest, &message)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling the input object: %w", err)
|
|
}
|
|
|
|
if message.DeploymentID == "" {
|
|
return nil, fmt.Errorf("required metadata not set: %s", DeploymentID)
|
|
}
|
|
|
|
res, err := p.client.GetEmbeddings(ctx, azopenai.EmbeddingsOptions{
|
|
Deployment: message.DeploymentID,
|
|
Input: []string{message.Message},
|
|
}, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting embedding api: %w", err)
|
|
}
|
|
|
|
// No embedding returned.
|
|
if len(res.Data) == 0 {
|
|
return []float32{}, nil
|
|
}
|
|
|
|
response = res.Data[0].Embedding
|
|
return response, nil
|
|
}
|
|
|
|
// Close Az OpenAI instance.
|
|
func (p *AzOpenAI) Close() error {
|
|
p.client = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetComponentMetadata returns the metadata of the component.
|
|
func (p *AzOpenAI) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
|
|
metadataStruct := openAIMetadata{}
|
|
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.BindingType)
|
|
return
|
|
}
|