dapr-agents/dapr_agents/llm/openai/utils.py

import dataclasses
import logging
from typing import Any, Callable, Dict, Iterator, Optional

from openai.types.chat import ChatCompletion, ChatCompletionChunk

from dapr_agents.types.message import (
    AssistantMessage,
    FunctionCall,
    LLMChatCandidate,
    LLMChatCandidateChunk,
    LLMChatResponse,
    LLMChatResponseChunk,
    ToolCall,
    ToolCallChunk,
)

logger = logging.getLogger(__name__)


# Helper function to handle metadata extraction
def _get_packet_metadata(
    pkt: Dict[str, Any], enrich_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Extract metadata from an OpenAI packet and merge it with enrich_metadata.

    Args:
        pkt (Dict[str, Any]): The OpenAI packet from which to extract metadata.
        enrich_metadata (Optional[Dict[str, Any]]): Additional metadata to merge with the extracted metadata.

    Returns:
        Dict[str, Any]: The merged metadata dictionary.
    """
    try:
        return {
            "id": pkt.get("id"),
            "created": pkt.get("created"),
            "model": pkt.get("model"),
            "object": pkt.get("object"),
            "service_tier": pkt.get("service_tier"),
            "system_fingerprint": pkt.get("system_fingerprint"),
            **(enrich_metadata or {}),
        }
    except Exception as e:
        logger.error(f"Failed to parse packet: {e}", exc_info=True)
        return {}
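
# Illustrative only: a sketch of how packet fields merge with caller-supplied
# metadata (field values here are made up, not from a real response):
#
#   pkt = {"id": "chatcmpl-123", "model": "gpt-4o-mini", "object": "chat.completion.chunk"}
#   _get_packet_metadata(pkt, {"request_id": "r-1"})
#   # -> {"id": "chatcmpl-123", "model": "gpt-4o-mini", ..., "request_id": "r-1"}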


# Helper function to process each choice delta (content, function call, tool call, finish reason)
def _process_choice_delta(
    choice: Dict[str, Any],
    overall_meta: Dict[str, Any],
    on_chunk: Optional[Callable],
    first_chunk_flag: bool,
) -> Iterator[LLMChatResponseChunk]:
    """
    Process a single choice delta and yield the corresponding chunk.

    Args:
        choice (Dict[str, Any]): The choice delta from the OpenAI response.
        overall_meta (Dict[str, Any]): Overall metadata to include in chunks.
        on_chunk (Optional[Callable]): Callback fired for each chunk.
        first_chunk_flag (bool): Flag indicating whether this is the first chunk.

    Yields:
        LLMChatResponseChunk: The processed chunk with content, function call,
            tool calls, and finish reason.
    """
    # Make a per-chunk copy of the metadata so later mutations don't leak across chunks
    meta = {**overall_meta}

    # Mark first_chunk exactly once
    if first_chunk_flag and "first_chunk" not in meta:
        meta["first_chunk"] = True

    # Extract initial properties from the choice
    delta: dict = choice.get("delta", {})
    idx = choice.get("index")
    finish_reason = choice.get("finish_reason", None)
    logprobs = choice.get("logprobs", None)

    # Mark the last chunk when a terminal finish reason is reported
    if finish_reason in ("stop", "tool_calls"):
        meta["last_chunk"] = True

    # Extract the content delta
    content = delta.get("content", None)
    function_call = delta.get("function_call", None)
    refusal = delta.get("refusal", None)
    role = delta.get("role", None)

    # Process tool calls
    chunk_tool_calls = [ToolCallChunk(**tc) for tc in (delta.get("tool_calls") or [])]

    # Build the LLMChatResponseChunk
    response_chunk = LLMChatResponseChunk(
        result=LLMChatCandidateChunk(
            content=content,
            function_call=function_call,
            refusal=refusal,
            role=role,
            tool_calls=chunk_tool_calls,
            finish_reason=finish_reason,
            index=idx,
            logprobs=logprobs,
        ),
        metadata=meta,
    )

    # Fire the on_chunk callback, if provided
    if on_chunk:
        on_chunk(response_chunk)

    # Yield the LLMChatResponseChunk
    yield response_chunk
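
# Illustrative only: the shape of a choice dict this helper expects, following
# the OpenAI chunk schema handled above (values are made up):
#
#   choice = {
#       "index": 0,
#       "delta": {"role": "assistant", "content": "Hel"},
#       "finish_reason": None,
#       "logprobs": None,
#   }
#   chunk = next(_process_choice_delta(choice, {"model": "gpt-4o-mini"}, None, True))
#   # chunk.result.content == "Hel" and chunk.metadata["first_chunk"] is True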


# Main function to process an OpenAI streaming response
def process_openai_stream(
    raw_stream: Iterator[ChatCompletionChunk],
    *,
    enrich_metadata: Optional[Dict[str, Any]] = None,
    on_chunk: Optional[Callable] = None,
) -> Iterator[LLMChatResponseChunk]:
    """
    Normalize an OpenAI streaming chat response into LLMChatResponseChunk objects,
    yielding one chunk per delta (only the first choice of each packet is
    processed), followed by a final metadata-only chunk.

    Args:
        raw_stream: Iterator from client.chat.completions.create(..., stream=True).
        enrich_metadata: Extra key/value pairs to merge into each chunk's metadata.
        on_chunk: Callback fired on every partial delta (token, function, tool).

    Yields:
        LLMChatResponseChunk: One per partial delta, in stream order, plus a final
            terminating chunk that carries only metadata.
    """
    enrich_metadata = enrich_metadata or {}
    overall_meta: Dict[str, Any] = {}

    # Track whether we are on the first chunk
    first_chunk_flag = True

    for packet in raw_stream:
        # Convert Pydantic / OpenAIObject → plain dict
        if hasattr(packet, "model_dump"):
            pkt = packet.model_dump()
        elif hasattr(packet, "to_dict"):
            pkt = packet.to_dict()
        elif dataclasses.is_dataclass(packet):
            pkt = dataclasses.asdict(packet)
        else:
            raise TypeError(f"Cannot serialize packet of type {type(packet)}")

        # Capture overall metadata from the packet
        overall_meta = _get_packet_metadata(pkt, enrich_metadata)

        # Process the first choice in this packet, if any
        if choices := pkt.get("choices"):
            yield from _process_choice_delta(
                choices[0], overall_meta, on_chunk, first_chunk_flag
            )
            # Clear the flag once the first choice has been processed
            first_chunk_flag = False
        else:
            logger.warning(f"Received packet without 'choices': {pkt}")

    # Yield a final, metadata-only LLMChatResponseChunk to terminate the stream
    final_response_chunk = LLMChatResponseChunk(metadata=overall_meta)
    yield final_response_chunk
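
# Illustrative usage (a sketch, assuming an already-configured OpenAI client;
# the model name is an example, not a requirement):
#
#   from openai import OpenAI
#
#   client = OpenAI()
#   raw = client.chat.completions.create(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#       stream=True,
#   )
#   for chunk in process_openai_stream(raw, enrich_metadata={"request_id": "r-1"}):
#       # chunk.result is empty on the final, metadata-only chunk
#       print(chunk.result.content if chunk.result else chunk.metadata)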


def process_openai_chat_response(openai_response: ChatCompletion) -> LLMChatResponse:
    """
    Convert an OpenAI ChatCompletion into our unified LLMChatResponse.

    This function:
      - Safely extracts each choice (skipping malformed ones)
      - Builds an AssistantMessage with content/refusal/tool_calls/function_call
      - Wraps it into an LLMChatCandidate (including index & logprobs)
      - Collects provider metadata

    Args:
        openai_response: A Pydantic ChatCompletion from the OpenAI SDK.

    Returns:
        LLMChatResponse: Contains a list of candidates and a metadata dict.
    """
    # 1) Turn the response into a plain dict
    try:
        if hasattr(openai_response, "model_dump"):
            resp = openai_response.model_dump()
        elif hasattr(openai_response, "to_dict"):
            resp = openai_response.to_dict()
        elif dataclasses.is_dataclass(openai_response):
            resp = dataclasses.asdict(openai_response)
        else:
            resp = dict(openai_response)
    except Exception:
        logger.exception("Failed to serialize OpenAI chat response")
        resp = {}

    candidates = []
    for choice in resp.get("choices", []):
        if "message" not in choice:
            logger.warning(f"Skipping choice missing 'message': {choice}")
            continue
        msg = choice["message"]

        # 2) Build the tool_calls list, if present
        tool_calls = None
        if msg.get("tool_calls"):
            tool_calls = []
            for tc in msg["tool_calls"]:
                try:
                    tool_calls.append(
                        ToolCall(
                            id=tc["id"],
                            type=tc["type"],
                            function=FunctionCall(
                                name=tc["function"]["name"],
                                arguments=tc["function"]["arguments"],
                            ),
                        )
                    )
                except Exception as e:
                    logger.warning(f"Invalid tool_call entry {tc}: {e}")

        # 3) Build function_call, if present
        function_call = None
        if fc := msg.get("function_call"):
            function_call = FunctionCall(
                name=fc.get("name", ""),
                arguments=fc.get("arguments", ""),
            )

        # 4) Assemble the AssistantMessage
        assistant_message = AssistantMessage(
            content=msg.get("content"),
            refusal=msg.get("refusal"),
            tool_calls=tool_calls,
            function_call=function_call,
        )

        # 5) Build the candidate, including index & logprobs
        candidate = LLMChatCandidate(
            message=assistant_message,
            finish_reason=choice.get("finish_reason"),
            index=choice.get("index"),
            logprobs=choice.get("logprobs"),
        )
        candidates.append(candidate)

    # 6) Metadata: include the provider tag
    metadata: Dict[str, Any] = {
        "provider": "openai",
        "id": resp.get("id"),
        "model": resp.get("model"),
        "object": resp.get("object"),
        "usage": resp.get("usage"),
        "created": resp.get("created"),
    }
    return LLMChatResponse(results=candidates, metadata=metadata)
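

if __name__ == "__main__":
    # Illustrative smoke test, not part of the library API: a minimal sketch that
    # assumes the `openai` package is installed and OPENAI_API_KEY is set in the
    # environment. The model name below is an example, not a requirement.
    from openai import OpenAI

    client = OpenAI()
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    unified = process_openai_chat_response(completion)
    print(unified.results[0].message.content)
    print(unified.metadata)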