refactor dapr memory (#94)

Signed-off-by: yaron2 <schneider.yaron@live.com>
This commit is contained in:
Yaron Schneider 2025-04-22 09:21:08 -07:00 committed by GitHub
parent bd0859d181
commit 356a25f281
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 96 additions and 27 deletions

View File

@ -27,11 +27,9 @@ class ConversationDaprStateMemory(MemoryBase):
store_name: str = Field(default="statestore", description="The name of the Dapr state store.")
session_id: Optional[Union[str, int]] = Field(default=None, description="Unique identifier for the conversation session.")
query_index_name: Optional[str] = Field(default=None, description="The index name for querying state.")
# Private attribute to hold the initialized DaprStateStore
dapr_store: Optional[DaprStateStore] = Field(default=None, init=False, description="Dapr State Store.")
@model_validator(mode="before")
def set_session_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -76,6 +74,7 @@ class ConversationDaprStateMemory(MemoryBase):
Args:
message (Union[Dict, BaseMessage]): The message to add to the memory.
"""
if isinstance(message, BaseMessage):
message = message.model_dump()
@ -86,9 +85,12 @@ class ConversationDaprStateMemory(MemoryBase):
"createdAt": datetime.now().isoformat() + "Z"
})
logger.info(f"Adding message with key {message_key} to session {self.session_id}")
self.dapr_store.save_state(message_key, json.dumps(message), {"contentType": "application/json"})
existing = self.get_messages()
existing.append(message)
logger.debug(f"Adding message with key {message_key} to session {self.session_id}")
self.dapr_store.save_state(self.session_id, json.dumps(existing), {"contentType": "application/json"})
def add_messages(self, messages: List[Union[Dict, BaseMessage]]):
"""
Adds multiple messages to the memory and saves each one individually to the Dapr state store.
@ -136,18 +138,19 @@ class ConversationDaprStateMemory(MemoryBase):
Returns:
List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
"""
query = json.dumps({
"filter": {"EQ": {"sessionId": self.session_id}},
"page": {"limit": limit}
})
query_response = self.query_messages(query=query)
messages = [{"content": msg.get("content"), "role": msg.get("role")}
for msg in (self._decode_message(result.value) for result in query_response.results)]
logger.info(f"Retrieved {len(messages)} messages for session {self.session_id}")
return messages
response = self.query_messages(session_id=self.session_id)
if response and response.data:
raw_messages = json.loads(response.data)
if raw_messages:
messages = [{"content": msg.get("content"), "role": msg.get("role")}
for msg in raw_messages]
logger.info(f"Retrieved {len(messages)} messages for session {self.session_id}")
return messages
def query_messages(self, query: Optional[str] = json.dumps({})) -> List[Dict[str, str]]:
return []
def query_messages(self, session_id: str) -> List[Dict[str, str]]:
"""
Queries messages from the state store based on a pre-constructed query string.
@ -157,22 +160,14 @@ class ConversationDaprStateMemory(MemoryBase):
Returns:
List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
"""
logger.debug(f"Executing query for session {self.session_id}: {query}")
logger.debug(f"Executing query for session {self.session_id}")
states_metadata = {"contentType": "application/json"}
if self.query_index_name:
states_metadata["queryIndexName"] = self.query_index_name
response = self.dapr_store.query_state(query=query, states_metadata=states_metadata)
response = self.dapr_store.get_state(session_id, state_metadata=states_metadata)
return response
def reset_memory(self):
"""
Clears all messages stored in the memory and resets the state store for the current session.
"""
query_response = self.query_messages()
keys = [result.key for result in query_response.results]
for key in keys:
self.dapr_store.delete_state(key)
logger.debug(f"Deleted state with key: {key}")
logger.info(f"Memory reset for session {self.session_id} completed. Deleted {len(keys)} messages.")
self.dapr_store.delete_state(self.session_id)
logger.info(f"Memory reset for session {self.session_id} completed.")

View File

@ -172,6 +172,68 @@ print(AIAgent.chat_history) # Should be empty now
```
This will show agent interaction history growth and reset.
### Persistent Agent Memory
Dapr Agents allows for agents to retain long-term memory by providing automatic state management of the history. The state can be saved into a wide variety of [Dapr supported state stores](https://docs.dapr.io/reference/components-reference/supported-state-stores/).
To configure persistent agent memory, follow these steps:
1. Set up the state store configuration. Here's an example of working with local Redis:
```yaml
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: historystore
spec:
type: state.redis
version: v1
metadata:
- name: redisHost
value: localhost:6379
- name: redisPassword
value: ""
```
Save the file in a `./components` dir.
2. Enable Dapr memory in code
```python
import asyncio
from weather_tools import tools
from dapr_agents import Agent
from dotenv import load_dotenv
from dapr_agents.memory import ConversationDaprStateMemory
load_dotenv()
AIAgent = Agent(
name="Stevie",
role="Weather Assistant",
goal="Assist Humans with weather related tasks.",
instructions=[
"Get accurate weather information",
"From time to time, you can also jump after answering the weather question."
],
memory=ConversationDaprStateMemory(store_name="historystore", session_id="some-id"),
tools=tools
)
# Wrap your async call
async def main():
await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
if __name__ == "__main__":
asyncio.run(main())
```
3. Run the agent with Dapr
```bash
dapr run --app-id weatheragent --resources-path ./components -- python weather_agent.py
```
## Available Agent Types
Dapr Agents provides several agent implementations, each designed for different use cases:

View File

@ -0,0 +1,12 @@
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: historystore
spec:
type: state.redis
version: v1
metadata:
- name: redisHost
value: localhost:6379
- name: redisPassword
value: ""