mirror of https://github.com/dapr/dapr-agents.git
refactor dapr memory (#94)
Signed-off-by: yaron2 <schneider.yaron@live.com>
This commit is contained in:
parent
bd0859d181
commit
356a25f281
|
|
@ -27,11 +27,9 @@ class ConversationDaprStateMemory(MemoryBase):
|
|||
|
||||
store_name: str = Field(default="statestore", description="The name of the Dapr state store.")
|
||||
session_id: Optional[Union[str, int]] = Field(default=None, description="Unique identifier for the conversation session.")
|
||||
query_index_name: Optional[str] = Field(default=None, description="The index name for querying state.")
|
||||
|
||||
# Private attribute to hold the initialized DaprStateStore
|
||||
dapr_store: Optional[DaprStateStore] = Field(default=None, init=False, description="Dapr State Store.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def set_session_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -76,6 +74,7 @@ class ConversationDaprStateMemory(MemoryBase):
|
|||
Args:
|
||||
message (Union[Dict, BaseMessage]): The message to add to the memory.
|
||||
"""
|
||||
|
||||
if isinstance(message, BaseMessage):
|
||||
message = message.model_dump()
|
||||
|
||||
|
|
@ -86,9 +85,12 @@ class ConversationDaprStateMemory(MemoryBase):
|
|||
"createdAt": datetime.now().isoformat() + "Z"
|
||||
})
|
||||
|
||||
logger.info(f"Adding message with key {message_key} to session {self.session_id}")
|
||||
self.dapr_store.save_state(message_key, json.dumps(message), {"contentType": "application/json"})
|
||||
existing = self.get_messages()
|
||||
existing.append(message)
|
||||
|
||||
logger.debug(f"Adding message with key {message_key} to session {self.session_id}")
|
||||
self.dapr_store.save_state(self.session_id, json.dumps(existing), {"contentType": "application/json"})
|
||||
|
||||
def add_messages(self, messages: List[Union[Dict, BaseMessage]]):
|
||||
"""
|
||||
Adds multiple messages to the memory and saves each one individually to the Dapr state store.
|
||||
|
|
@ -136,18 +138,19 @@ class ConversationDaprStateMemory(MemoryBase):
|
|||
Returns:
|
||||
List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
|
||||
"""
|
||||
query = json.dumps({
|
||||
"filter": {"EQ": {"sessionId": self.session_id}},
|
||||
"page": {"limit": limit}
|
||||
})
|
||||
query_response = self.query_messages(query=query)
|
||||
messages = [{"content": msg.get("content"), "role": msg.get("role")}
|
||||
for msg in (self._decode_message(result.value) for result in query_response.results)]
|
||||
|
||||
logger.info(f"Retrieved {len(messages)} messages for session {self.session_id}")
|
||||
return messages
|
||||
response = self.query_messages(session_id=self.session_id)
|
||||
if response and response.data:
|
||||
raw_messages = json.loads(response.data)
|
||||
if raw_messages:
|
||||
messages = [{"content": msg.get("content"), "role": msg.get("role")}
|
||||
for msg in raw_messages]
|
||||
|
||||
logger.info(f"Retrieved {len(messages)} messages for session {self.session_id}")
|
||||
return messages
|
||||
|
||||
def query_messages(self, query: Optional[str] = json.dumps({})) -> List[Dict[str, str]]:
|
||||
return []
|
||||
|
||||
def query_messages(self, session_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Queries messages from the state store based on a pre-constructed query string.
|
||||
|
||||
|
|
@ -157,22 +160,14 @@ class ConversationDaprStateMemory(MemoryBase):
|
|||
Returns:
|
||||
List[Dict[str, str]]: A list containing the 'content' and 'role' fields of the messages.
|
||||
"""
|
||||
logger.debug(f"Executing query for session {self.session_id}: {query}")
|
||||
logger.debug(f"Executing query for session {self.session_id}")
|
||||
states_metadata = {"contentType": "application/json"}
|
||||
if self.query_index_name:
|
||||
states_metadata["queryIndexName"] = self.query_index_name
|
||||
|
||||
response = self.dapr_store.query_state(query=query, states_metadata=states_metadata)
|
||||
response = self.dapr_store.get_state(session_id, state_metadata=states_metadata)
|
||||
return response
|
||||
|
||||
def reset_memory(self):
|
||||
"""
|
||||
Clears all messages stored in the memory and resets the state store for the current session.
|
||||
"""
|
||||
query_response = self.query_messages()
|
||||
keys = [result.key for result in query_response.results]
|
||||
for key in keys:
|
||||
self.dapr_store.delete_state(key)
|
||||
logger.debug(f"Deleted state with key: {key}")
|
||||
|
||||
logger.info(f"Memory reset for session {self.session_id} completed. Deleted {len(keys)} messages.")
|
||||
self.dapr_store.delete_state(self.session_id)
|
||||
logger.info(f"Memory reset for session {self.session_id} completed.")
|
||||
|
|
@ -172,6 +172,68 @@ print(AIAgent.chat_history) # Should be empty now
|
|||
```
|
||||
This will show agent interaction history growth and reset.
|
||||
|
||||
### Persistent Agent Memory
|
||||
|
||||
Dapr Agents allows for agents to retain long-term memory by providing automatic state management of the history. The state can be saved into a wide variety of [Dapr supported state stores](https://docs.dapr.io/reference/components-reference/supported-state-stores/).
|
||||
|
||||
To configure persistent agent memory, follow these steps:
|
||||
|
||||
1. Set up the state store configuration. Here's an example of working with local Redis:
|
||||
|
||||
```yaml
|
||||
apiVersion: dapr.io/v1alpha1
|
||||
kind: Component
|
||||
metadata:
|
||||
name: historystore
|
||||
spec:
|
||||
type: state.redis
|
||||
version: v1
|
||||
metadata:
|
||||
- name: redisHost
|
||||
value: localhost:6379
|
||||
- name: redisPassword
|
||||
value: ""
|
||||
```
|
||||
|
||||
Save the file in a `./components` dir.
|
||||
|
||||
2. Enable Dapr memory in code
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from weather_tools import tools
|
||||
from dapr_agents import Agent
|
||||
from dotenv import load_dotenv
|
||||
from dapr_agents.memory import ConversationDaprStateMemory
|
||||
|
||||
load_dotenv()
|
||||
|
||||
AIAgent = Agent(
|
||||
name="Stevie",
|
||||
role="Weather Assistant",
|
||||
goal="Assist Humans with weather related tasks.",
|
||||
instructions=[
|
||||
"Get accurate weather information",
|
||||
"From time to time, you can also jump after answering the weather question."
|
||||
],
|
||||
memory=ConversationDaprStateMemory(store_name="historystore", session_id="some-id"),
|
||||
tools=tools
|
||||
)
|
||||
|
||||
# Wrap your async call
|
||||
async def main():
|
||||
await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
3. Run the agent with Dapr
|
||||
|
||||
```bash
|
||||
dapr run --app-id weatheragent --resources-path ./components -- python weather_agent.py
|
||||
```
|
||||
|
||||
## Available Agent Types
|
||||
|
||||
Dapr Agents provides several agent implementations, each designed for different use cases:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
apiVersion: dapr.io/v1alpha1
|
||||
kind: Component
|
||||
metadata:
|
||||
name: historystore
|
||||
spec:
|
||||
type: state.redis
|
||||
version: v1
|
||||
metadata:
|
||||
- name: redisHost
|
||||
value: localhost:6379
|
||||
- name: redisPassword
|
||||
value: ""
|
||||
Loading…
Reference in New Issue