dapr-agents/dapr_agents/memory/daprstatestore.py

186 lines
7.4 KiB
Python

from dapr_agents.storage.daprstores.statestore import DaprStateStore
from dapr_agents.types import BaseMessage
from dapr_agents.memory import MemoryBase
from typing import List, Union, Optional, Dict, Any
from pydantic import Field, model_validator
from datetime import datetime
import json
import uuid
import logging
logger = logging.getLogger(__name__)
def generate_numeric_session_id() -> int:
"""
Generates a random numeric session ID by extracting digits from a UUID.
Returns:
int: A numeric session ID.
"""
return int(''.join(filter(str.isdigit, str(uuid.uuid4()))))
class ConversationDaprStateMemory(MemoryBase):
"""
Manages conversation memory stored in a Dapr state store. Each message in the conversation is saved
individually with a unique key and includes a session ID and timestamp for querying and retrieval.
"""
store_name: str = Field(default="statestore", description="The name of the Dapr state store.")
session_id: Optional[Union[str, int]] = Field(default=None, description="Unique identifier for the conversation session.")
address: Optional[str] = Field(default=None, description="The full address of the Dapr sidecar (host:port).")
host: Optional[str] = Field(default=None, description="The host of the Dapr sidecar.")
port: Optional[str] = Field(default=None, description="The port of the Dapr sidecar.")
query_index_name: Optional[str] = Field(default=None, description="The index name for querying state.")
# Private attribute to hold the initialized DaprStateStore
dapr_store: Optional[DaprStateStore] = Field(default=None, init=False, description="Dapr State Store.")
@model_validator(mode="before")
def set_session_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Sets a numeric session ID if none is provided.
Args:
values (Dict[str, Any]): The dictionary of attribute values before initialization.
Returns:
Dict[str, Any]: Updated values including the generated session ID if not provided.
"""
if not values.get("session_id"):
values["session_id"] = generate_numeric_session_id()
return values
def model_post_init(self, __context: Any) -> None:
"""
Initializes the Dapr state store after validation, allowing optional host and port configuration.
"""
self.dapr_store = DaprStateStore(
store_name=self.store_name,
address=self.address,
host=self.host,
port=self.port
)
logger.info(f"ConversationDaprStateMemory initialized with session ID: {self.session_id}")
# Complete post-initialization
super().model_post_init(__context)
def _get_message_key(self, message_id: str) -> str:
"""
Generates a unique key for each message using session_id and message_id.
Args:
message_id (str): A unique identifier for the message.
Returns:
str: A composite key for storing individual messages.
"""
return f"{self.session_id}:{message_id}"
def add_message(self, message: Union[Dict, BaseMessage]):
"""
Adds a single message to the memory and saves it to the Dapr state store.
Args:
message (Union[Dict, BaseMessage]): The message to add to the memory.
"""
if isinstance(message, BaseMessage):
message = message.model_dump()
message_id = str(uuid.uuid4())
message_key = self._get_message_key(message_id)
message.update({
"sessionId": self.session_id,
"createdAt": datetime.now().isoformat() + "Z"
})
logger.info(f"Adding message with key {message_key} to session {self.session_id}")
self.dapr_store.save_state(message_key, json.dumps(message), {"contentType": "application/json"})
def add_messages(self, messages: List[Union[Dict, BaseMessage]]):
"""
Adds multiple messages to the memory and saves each one individually to the Dapr state store.
Args:
messages (List[Union[Dict, BaseMessage]]): A list of messages to add to the memory.
"""
logger.info(f"Adding {len(messages)} messages to session {self.session_id}")
for message in messages:
if isinstance(message, BaseMessage):
message = message.model_dump()
self.add_message(message)
def add_interaction(self, user_message: BaseMessage, assistant_message: BaseMessage):
"""
Adds a user-assistant interaction to the memory storage and saves it to the state store.
Args:
user_message (BaseMessage): The user message.
assistant_message (BaseMessage): The assistant message.
"""
self.add_messages([user_message, assistant_message])
def _decode_message(self, message_data: Union[bytes, str]) -> dict:
"""
Decodes the message data if it's in bytes, otherwise parses it as a JSON string.
Args:
message_data (Union[bytes, str]): The message data to decode.
Returns:
dict: The decoded message as a dictionary.
"""
if isinstance(message_data, bytes):
message_data = message_data.decode("utf-8")
return json.loads(message_data)
def get_messages(self, limit: int = 100) -> List[Dict[str, str]]:
"""
Retrieves messages stored in the state store for the current session_id, with an optional limit.
Args:
limit (int): The maximum number of messages to retrieve. Defaults to 100.
Returns:
List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
"""
query = json.dumps({
"filter": {"EQ": {"sessionId": self.session_id}},
"page": {"limit": limit}
})
query_response = self.query_messages(query=query)
messages = [{"content": msg.get("content"), "role": msg.get("role")}
for msg in (self._decode_message(result.value) for result in query_response.results)]
logger.info(f"Retrieved {len(messages)} messages for session {self.session_id}")
return messages
def query_messages(self, query: Optional[str] = json.dumps({})) -> List[Dict[str, str]]:
"""
Queries messages from the state store based on a pre-constructed query string.
Args:
query (Optional[str]): A JSON-formatted query string to be executed.
Returns:
List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
"""
logger.debug(f"Executing query for session {self.session_id}: {query}")
states_metadata = {"contentType": "application/json"}
if self.query_index_name:
states_metadata["queryIndexName"] = self.query_index_name
response = self.dapr_store.query_state(query=query, states_metadata=states_metadata)
return response
def reset_memory(self):
"""
Clears all messages stored in the memory and resets the state store for the current session.
"""
query_response = self.query_messages()
keys = [result.key for result in query_response.results]
for key in keys:
self.dapr_store.delete_state(key)
logger.debug(f"Deleted state with key: {key}")
logger.info(f"Memory reset for session {self.session_id} completed. Deleted {len(keys)} messages.")