mirror of https://github.com/dapr/dapr-agents.git
186 lines
7.4 KiB
Python
186 lines
7.4 KiB
Python
from dapr_agents.storage.daprstores.statestore import DaprStateStore
|
|
from dapr_agents.types import BaseMessage
|
|
from dapr_agents.memory import MemoryBase
|
|
from typing import List, Union, Optional, Dict, Any
|
|
from pydantic import Field, model_validator
|
|
from datetime import datetime
|
|
import json
|
|
import uuid
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def generate_numeric_session_id() -> int:
    """
    Build a random, purely numeric session ID from a fresh UUID4.

    The UUID is rendered as text, every non-digit character (hex letters and
    hyphens) is discarded, and the surviving digits are parsed as an integer.
    Because the result goes through ``int``, any leading zeros are dropped.

    Returns:
        int: A numeric session ID.
    """
    uuid_text = str(uuid.uuid4())
    digits_only = "".join(ch for ch in uuid_text if ch.isdigit())
    return int(digits_only)
|
|
|
|
class ConversationDaprStateMemory(MemoryBase):
    """
    Manages conversation memory stored in a Dapr state store. Each message in the conversation is saved
    individually with a unique key and includes a session ID and timestamp for querying and retrieval.
    """

    store_name: str = Field(default="statestore", description="The name of the Dapr state store.")
    session_id: Optional[Union[str, int]] = Field(default=None, description="Unique identifier for the conversation session.")
    address: Optional[str] = Field(default=None, description="The full address of the Dapr sidecar (host:port).")
    host: Optional[str] = Field(default=None, description="The host of the Dapr sidecar.")
    port: Optional[str] = Field(default=None, description="The port of the Dapr sidecar.")
    query_index_name: Optional[str] = Field(default=None, description="The index name for querying state.")

    # Holds the DaprStateStore client created in model_post_init. Note: this is a
    # regular model field (declared with Field), not a pydantic private attribute.
    dapr_store: Optional[DaprStateStore] = Field(default=None, init=False, description="Dapr State Store.")

    @model_validator(mode="before")
    @classmethod
    def set_session_id(cls, values: Any) -> Any:
        """
        Sets a numeric session ID if none (or a falsy one) is provided.

        Args:
            values (Any): The raw input passed to the model before validation. Usually a dict of
                attribute values; non-dict input is returned unchanged.

        Returns:
            Any: A copy of the input dict with a generated ``session_id`` when it was missing,
                otherwise the input as received.
        """
        # In "before" mode the raw input is not guaranteed to be a dict, so only
        # patch dicts. Build a new dict rather than mutating the caller's mapping.
        if isinstance(values, dict) and not values.get("session_id"):
            values = {**values, "session_id": generate_numeric_session_id()}
        return values

    def model_post_init(self, __context: Any) -> None:
        """
        Initializes the Dapr state store after validation, allowing optional address/host/port configuration.

        Args:
            __context (Any): The pydantic post-init context, forwarded to the parent implementation.
        """
        self.dapr_store = DaprStateStore(
            store_name=self.store_name,
            address=self.address,
            host=self.host,
            port=self.port,
        )
        logger.info("ConversationDaprStateMemory initialized with session ID: %s", self.session_id)

        # Complete post-initialization
        super().model_post_init(__context)

    def _get_message_key(self, message_id: str) -> str:
        """
        Generates a unique key for each message using session_id and message_id.

        Args:
            message_id (str): A unique identifier for the message.

        Returns:
            str: A composite key of the form ``"<session_id>:<message_id>"``.
        """
        return f"{self.session_id}:{message_id}"

    def _build_session_query(self, limit: Optional[int] = None) -> str:
        """
        Builds a JSON state-store query that matches only this session's messages.

        Args:
            limit (Optional[int]): Page size to request. When None, no ``page`` clause is added.

        Returns:
            str: The query serialized as a JSON string.
        """
        # Messages are tagged with "sessionId" in add_message, so filtering on it
        # scopes the query to the current session.
        query: Dict[str, Any] = {"filter": {"EQ": {"sessionId": self.session_id}}}
        if limit is not None:
            query["page"] = {"limit": limit}
        return json.dumps(query)

    def add_message(self, message: Union[Dict, BaseMessage]):
        """
        Adds a single message to the memory and saves it to the Dapr state store.

        The stored payload is the message fields plus ``sessionId`` and a UTC ``createdAt``
        timestamp, saved as JSON under a fresh ``"<session_id>:<uuid4>"`` key.

        Args:
            message (Union[Dict, BaseMessage]): The message to add to the memory. A dict argument
                is copied and is not modified.
        """
        from datetime import timezone

        if isinstance(message, BaseMessage):
            message = message.model_dump()
        else:
            # Work on a copy so the caller's dict does not gain sessionId/createdAt.
            message = dict(message)

        message_id = str(uuid.uuid4())
        message_key = self._get_message_key(message_id)
        message.update({
            "sessionId": self.session_id,
            # Use an aware UTC time so the trailing "Z" (UTC designator) is accurate.
            "createdAt": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        })

        logger.info("Adding message with key %s to session %s", message_key, self.session_id)
        self.dapr_store.save_state(message_key, json.dumps(message), {"contentType": "application/json"})

    def add_messages(self, messages: List[Union[Dict, BaseMessage]]):
        """
        Adds multiple messages to the memory and saves each one individually to the Dapr state store.

        Args:
            messages (List[Union[Dict, BaseMessage]]): A list of messages to add to the memory.
        """
        logger.info("Adding %d messages to session %s", len(messages), self.session_id)
        # add_message already handles BaseMessage -> dict conversion.
        for message in messages:
            self.add_message(message)

    def add_interaction(self, user_message: BaseMessage, assistant_message: BaseMessage):
        """
        Adds a user-assistant interaction to the memory storage and saves it to the state store.

        Args:
            user_message (BaseMessage): The user message.
            assistant_message (BaseMessage): The assistant message.
        """
        self.add_messages([user_message, assistant_message])

    def _decode_message(self, message_data: Union[bytes, str]) -> dict:
        """
        Decodes the message data if it's in bytes, then parses it as a JSON string.

        Args:
            message_data (Union[bytes, str]): The message data to decode (UTF-8 when bytes).

        Returns:
            dict: The decoded message as a dictionary.
        """
        if isinstance(message_data, bytes):
            message_data = message_data.decode("utf-8")
        return json.loads(message_data)

    def get_messages(self, limit: int = 100) -> List[Dict[str, str]]:
        """
        Retrieves messages stored in the state store for the current session_id, with an optional limit.

        Args:
            limit (int): The maximum number of messages to retrieve. Defaults to 100.

        Returns:
            List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
        """
        query_response = self.query_messages(query=self._build_session_query(limit))
        messages = [
            {"content": msg.get("content"), "role": msg.get("role")}
            for msg in (self._decode_message(result.value) for result in query_response.results)
        ]

        logger.info("Retrieved %d messages for session %s", len(messages), self.session_id)
        return messages

    def query_messages(self, query: Optional[str] = json.dumps({})) -> Any:
        """
        Queries messages from the state store based on a pre-constructed query string.

        Args:
            query (Optional[str]): A JSON-formatted query string to be executed. The default
                (``"{}"``) applies no filter.

        Returns:
            Any: The raw response from ``DaprStateStore.query_state``. Callers in this class read
                its ``results`` items via ``.key`` and ``.value``.
        """
        logger.debug("Executing query for session %s: %s", self.session_id, query)
        states_metadata = {"contentType": "application/json"}
        if self.query_index_name:
            states_metadata["queryIndexName"] = self.query_index_name

        return self.dapr_store.query_state(query=query, states_metadata=states_metadata)

    def reset_memory(self):
        """
        Clears all messages stored for the current session by deleting their keys from the state store.

        Only messages tagged with this instance's ``sessionId`` are deleted; other sessions sharing
        the same store are left untouched.
        """
        # NOTE(review): only the results returned by a single query are deleted; if the
        # store paginates query results, remaining messages may survive — confirm.
        query_response = self.query_messages(query=self._build_session_query())
        keys = [result.key for result in query_response.results]
        for key in keys:
            self.dapr_store.delete_state(key)
            logger.debug("Deleted state with key: %s", key)

        logger.info("Memory reset for session %s completed. Deleted %d messages.", self.session_id, len(keys))