# mirror of https://github.com/dapr/dapr-agents.git
# 296 lines, 10 KiB, Python
import logging
import re
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from dapr_agents.types.message import (
    BaseMessage,
    SystemMessage,
    UserMessage,
    AssistantMessage,
    ToolMessage,
)
from dapr_agents.prompt.utils.jinja import (
    render_jinja_template,
    extract_jinja_variables,
)
from dapr_agents.prompt.utils.fstring import (
    render_fstring_template,
    extract_fstring_variables,
)

# Module-level logger, named after this module so log records can be filtered by origin.
logger = logging.getLogger(__name__)

# Maps a template-format name to the callable used to render content in that format.
# ChatPromptHelper.format_content looks the renderer up here and rejects unknown names.
DEFAULT_FORMATTER_MAPPING = {
    "f-string": render_fstring_template,
    "jinja2": render_jinja_template,
}

# Maps a template-format name to the callable that extracts variable names from a
# template written in that format. Uses the same keys as DEFAULT_FORMATTER_MAPPING.
DEFAULT_VARIABLE_EXTRACTOR_MAPPING = {
    "f-string": extract_fstring_variables,
    "jinja2": extract_jinja_variables,
}


class ChatPromptHelper:
    """
    Utility class for handling various operations on chat prompt messages, such as
    formatting, normalizing, and parsing role-annotated text.

    Attributes:
        _ROLE_MAP (Dict[str, Type[BaseMessage]]): A mapping of role names to message classes.
    """

    _ROLE_MAP: Dict[str, Type[BaseMessage]] = {
        "system": SystemMessage,
        "user": UserMessage,
        "assistant": AssistantMessage,
        "tool": ToolMessage,
    }

    @classmethod
    def normalize_chat_messages(cls, variable_value: Any) -> List[BaseMessage]:
        """
        Normalize the variable value into a list of BaseMessages.

        Accepted inputs are a string (becomes a user message), a dictionary with
        optional 'role' (default 'user') and 'content' (default '') keys, a
        BaseMessage (kept as-is), or a list whose items are any of those.

        Args:
            variable_value (Any): The value associated with a placeholder variable to normalize.

        Returns:
            List[BaseMessage]: A list of normalized BaseMessage instances.

        Raises:
            ValueError: If an unsupported type is encountered within the list or variable,
                or if a dictionary names a role that is not in `_ROLE_MAP`.
        """
        if isinstance(variable_value, list):
            return [
                cls._normalize_single(item, in_list=True) for item in variable_value
            ]
        return [cls._normalize_single(variable_value, in_list=False)]

    @classmethod
    def _normalize_single(cls, value: Any, *, in_list: bool) -> BaseMessage:
        """Convert one string, dict, or BaseMessage into a BaseMessage.

        `in_list` only selects which error message is raised for unsupported types.
        """
        if isinstance(value, str):
            return cls.create_message("user", value, {})
        if isinstance(value, BaseMessage):
            return value
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            # Checked here (in addition to create_message) so the error includes
            # the offending message payload.
            if role not in cls._ROLE_MAP:
                raise ValueError(f"Unrecognized role '{role}' in message: {value}")
            return cls.create_message(role, content, value)
        if in_list:
            raise ValueError(f"Unsupported type in list for variable: {type(value)}")
        raise ValueError(f"Unsupported type for variable: {type(value)}")

    @classmethod
    def format_message(
        cls,
        message: Union[Tuple[str, str], Dict[str, Any], BaseMessage],
        template_format: str,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Format a single message by replacing template variables based on the specified format.

        Args:
            message (Union[Tuple[str, str], Dict[str, Any], BaseMessage]): The message to format.
            template_format (str): The format for rendering ('f-string' or 'jinja2').
            **kwargs: Variables used to populate placeholders within the message.

        Returns:
            BaseMessage: The message with variables replaced as per the template format.

        Raises:
            ValueError: If the message shape, its role, or the template format is unsupported.
        """
        role, content = cls.extract_role_and_content(message)
        content = cls.format_content(content, template_format=template_format, **kwargs)
        # The original message is forwarded so tool messages can recover their
        # tool_call_id from it.
        return cls.create_message(role, content, message)

    @staticmethod
    def format_content(content: str, template_format: str, **kwargs: Any) -> str:
        """
        Apply template formatting to the content string using the specified format.

        Args:
            content (str): The content string to format.
            template_format (str): Template format ('f-string' or 'jinja2').
            **kwargs: Variables for populating placeholders within the content.

        Returns:
            str: The formatted content.

        Raises:
            ValueError: If `template_format` is not a key of DEFAULT_FORMATTER_MAPPING.
        """
        formatter = DEFAULT_FORMATTER_MAPPING.get(template_format)
        if formatter is None:
            raise ValueError(f"Unsupported template format: {template_format}")
        return formatter(content, **kwargs)

    @classmethod
    def extract_role_and_content(
        cls, message: Union[Tuple[str, str], Dict[str, Any], BaseMessage]
    ) -> Tuple[str, str]:
        """
        Extract role and content from a message.

        Args:
            message (Union[Tuple[str, str], Dict[str, Any], BaseMessage]): A message object in the form of a tuple, dictionary, or BaseMessage.

        Returns:
            Tuple[str, str]: Extracted role and content. For a dictionary without a
                'role' key the role is None; a missing 'content' key yields ''.

        Raises:
            ValueError: If the message is not in a supported format.
        """
        if isinstance(message, tuple) and len(message) == 2:
            return message[0], message[1]
        if isinstance(message, dict):
            return message.get("role"), message.get("content", "")
        if isinstance(message, BaseMessage):
            return message.role, message.content
        raise ValueError(
            "Message must be a tuple (role, content), a dict with 'role' and 'content', or a BaseMessage instance."
        )

    @classmethod
    def create_message(
        cls,
        role: str,
        content: str,
        message_data: Union[Dict[str, Any], Tuple[str, str], BaseMessage],
    ) -> BaseMessage:
        """
        Create a BaseMessage instance based on role.

        Args:
            role (str): Role of the message (system, user, assistant, tool).
            content (str): Message content.
            message_data: Source the message was built from. Only consulted for
                tool messages, to obtain `tool_call_id`: read with `.get` from a
                dictionary, otherwise as an attribute (None when absent).

        Returns:
            BaseMessage: Formatted message instance.

        Raises:
            ValueError: If the role is not recognized.
        """
        if role not in cls._ROLE_MAP:
            raise ValueError(f"Invalid message role: {role}")

        message_class = cls._ROLE_MAP[role]
        if role != "tool":
            return message_class(content=content)

        # format_message forwards the original message here, which may be a tuple
        # or a BaseMessage rather than a dict, so `.get` cannot be assumed.
        if isinstance(message_data, dict):
            tool_call_id = message_data.get("tool_call_id")
        else:
            tool_call_id = getattr(message_data, "tool_call_id", None)
        return message_class(content=content, tool_call_id=tool_call_id)

    @classmethod
    def get_message_class(cls, role: str) -> Optional[Type[BaseMessage]]:
        """Get the message class for a given role (case-insensitive), or None if unknown."""
        return cls._ROLE_MAP.get(role.lower())

    @staticmethod
    def parse_role_content(content: str) -> Tuple[List[str], Optional[str]]:
        """
        Parse the formatted content into role-based chunks and any remaining plain text.
        - `role_chunks`: List of chunks containing roles and their messages.
        - `plain_text`: Text that does not match any role pattern.

        A role header is a line of the form `role:` (optionally preceded by `#`,
        case-insensitive) followed by a newline.

        Returns:
            Tuple[List[str], Optional[str]]: Role-based chunks and any remaining plain text.
        """
        # Regex to split on role definitions, capturing the role name.
        role_pattern = (
            r"(?i)^\s*#?\s*("
            + "|".join(ChatPromptHelper._ROLE_MAP.keys())
            + r")\s*:\s*\n"
        )

        # With one capture group, re.split yields [leading, role, body, role, body, ...];
        # a single-element result means no role header matched at all.
        chunks = re.split(role_pattern, content, flags=re.MULTILINE)
        if len(chunks) == 1:
            return [], content.strip()

        # Extract plain text that precedes the first role header, if any.
        plain_text = None
        leading = chunks[0]
        if leading.strip() and leading.lower() not in ChatPromptHelper._ROLE_MAP:
            plain_text = chunks.pop(0).strip()

        # Filter out empty or whitespace-only chunks from role-based content.
        role_chunks = [chunk.strip() for chunk in chunks if chunk.strip()]

        return role_chunks, plain_text

    @staticmethod
    def to_message(role: str, content: str) -> BaseMessage:
        """
        Parse a single chunk of content into a message object.

        Args:
            role (str): Role name; surrounding whitespace and case are ignored.
            content (str): Message content; surrounding whitespace is stripped.

        Returns:
            BaseMessage: Instance of the message class mapped to `role`.

        Raises:
            ValueError: If the role is unknown or the stripped content is empty.
        """
        role = role.strip().lower()
        content = content.strip()

        logger.debug("Parsing role: '%s', content: '%s...'", role, content[:30])

        # Map role to a message class
        message_class = ChatPromptHelper.get_message_class(role)
        if not message_class:
            raise ValueError(f"Invalid message role: '{role}'")

        if not content:
            raise ValueError(f"Content missing for role: '{role}'")

        # NOTE(review): tool messages built here carry no tool_call_id — confirm
        # ToolMessage permits that.
        return message_class(content=content)

    @staticmethod
    def parse_as_messages(
        content: str,
    ) -> Tuple[Optional[List[BaseMessage]], Optional[str]]:
        """
        Parse the content into a list of role-based messages and any unstructured plain text.

        Returns:
            Tuple[List[BaseMessage], Optional[str]]: Parsed messages if role-based chunks are found,
            and any remaining plain text if detected.

        Raises:
            ValueError: If a content chunk appears without a preceding role.
        """
        role_chunks, plain_text = ChatPromptHelper.parse_role_content(content)

        # If there are no role-based chunks, return the plain_text directly.
        if not role_chunks:
            logger.debug("No role-based content found; returning plain text.")
            return [], plain_text

        # Walk the chunks pairing each role name with the chunk that follows it.
        messages = []
        role = None

        for chunk in role_chunks:
            if chunk.lower() in ChatPromptHelper._ROLE_MAP:
                role = chunk  # Assign role if chunk matches a defined role.
            elif role:
                # A role is pending, so this chunk is its content.
                messages.append(ChatPromptHelper.to_message(role, chunk))
                role = None
            else:
                raise ValueError(f"Unexpected content without a role: {chunk}")

        return messages, plain_text