import json
import logging
import os
import tempfile
import threading
from typing import Optional, Union

from pydantic import BaseModel, ValidationError

logger = logging.getLogger(__name__)

state_lock = threading.Lock()


class StateManagementMixin:
    """
    Mixin providing workflow state initialization, validation, and persistence.
    """

    def initialize_state(self) -> None:
        """
        Initialize workflow state from provided value or storage.

        Raises:
            RuntimeError: If state initialization or loading from storage fails.
        """
        try:
            if self.state is None:
                logger.info("No user-provided state. Attempting to load from storage.")
                self.state = self.load_state()

            if isinstance(self.state, BaseModel):
                logger.debug(
                    "User provided a state as a Pydantic model. Converting to dict."
                )
                self.state = self.state.model_dump()

            if not isinstance(self.state, dict):
                raise TypeError(
                    f"Invalid state type: {type(self.state)}. Expected dict or Pydantic model."
                )

            logger.debug(f"Workflow state initialized with {len(self.state)} key(s).")
            self.save_state()
        except Exception as e:
            raise RuntimeError(f"Error initializing workflow state: {e}") from e
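
    # Illustrative usage sketch (not part of this module): the mixin assumes its host
    # class provides the attributes read above (state, state_format, _state_store_client,
    # state_store_name, state_key, local_state_path, save_state_locally, name).
    # With such a host, initialization looks roughly like:
    #
    #   app = MyWorkflowApp(...)          # hypothetical subclass of StateManagementMixin
    #   app.initialize_state()            # loads persisted state, or starts from {} and saves it
    #   assert isinstance(app.state, dict)
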
    def validate_state(self, state_data: dict) -> dict:
        """
        Validate the workflow state against ``state_format`` if provided.

        Args:
            state_data: The raw state data to validate.

        Returns:
            dict: The validated and structured state.

        Raises:
            ValidationError: If the state data does not conform to the expected schema.
        """
        try:
            if not self.state_format:
                logger.warning(
                    "No schema (state_format) provided; returning state as-is."
                )
                return state_data

            logger.info("Validating workflow state against schema.")
            validated_state: BaseModel = self.state_format(**state_data)
            return validated_state.model_dump()
        except ValidationError as e:
            # Pydantic's ValidationError cannot be constructed from a plain message,
            # so log the details and re-raise the original exception.
            logger.error(f"Invalid workflow state: {e.errors()}")
            raise
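
    # Illustrative sketch (assumes a host instance `app` as above): with a Pydantic
    # model assigned to state_format, validate_state() normalizes raw dicts into the
    # schema's shape; without one, the input is returned unchanged.
    #
    #   class WorkflowState(BaseModel):            # hypothetical schema
    #       instances: dict = {}
    #
    #   app.state_format = WorkflowState
    #   app.validate_state({"instances": {}})      # -> {"instances": {}}
    #   app.validate_state({"instances": "oops"})  # raises ValidationError
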
    def load_state(self) -> dict:
        """
        Load the workflow state from the configured Dapr state store.

        Returns:
            dict: The loaded and optionally validated state.

        Raises:
            RuntimeError: If the state store is not properly configured.
            TypeError: If the retrieved state is not a dictionary.
            ValidationError: If state schema validation fails.
        """
        try:
            if (
                not self._state_store_client
                or not self.state_store_name
                or not self.state_key
            ):
                logger.error("State store is not configured. Cannot load state.")
                raise RuntimeError(
                    "State store is not configured. Please provide 'state_store_name' and 'state_key'."
                )

            # For durable agents, always load from database to ensure it's the source of truth
            has_state, state_data = self._state_store_client.try_get_state(
                self.state_key
            )
            if has_state and state_data:
                logger.info(
                    f"Existing state found for key '{self.state_key}'. Validating it."
                )
                if not isinstance(state_data, dict):
                    raise TypeError(
                        f"Invalid state type retrieved: {type(state_data)}. Expected dict."
                    )

                # Set self.state to the loaded data
                if self.state_format:
                    loaded_state = self.validate_state(state_data)
                else:
                    loaded_state = state_data

                self.state = loaded_state
                logger.debug(f"Set self.state to loaded data: {self.state}")

                return loaded_state

            logger.info(
                f"No existing state found for key '{self.state_key}'. Initializing empty state."
            )
            return {}
        except Exception as e:
            logger.error(f"Failed to load state for key '{self.state_key}': {e}")
            raise RuntimeError(f"Error loading workflow state: {e}") from e
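
    # Illustrative sketch (assumes a host instance `app`): load_state() always reads
    # through the configured Dapr state store so the store stays the source of truth.
    # The component and key names below are placeholders.
    #
    #   app.state_store_name = "workflowstatestore"
    #   app.state_key = "workflow_state"
    #   data = app.load_state()    # {} on first run, the validated dict afterwards
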
    def get_local_state_file_path(self) -> str:
        """
        Return the file path for saving the local state.

        Returns:
            str: The absolute path to the local state file.
        """
        directory = self.local_state_path or os.getcwd()
        os.makedirs(directory, exist_ok=True)
        return os.path.join(directory, f"{self.state_key}.json")
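
    # Illustrative sketch (assumes a host instance `app`): the local snapshot path is
    # derived from local_state_path (or the current working directory) plus state_key.
    #
    #   app.local_state_path = "/tmp/agent-state"   # hypothetical directory
    #   app.state_key = "workflow_state"
    #   app.get_local_state_file_path()             # -> "/tmp/agent-state/workflow_state.json"
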
    def save_state_to_disk(
        self, state_data: Union[str, dict], filename: Optional[str] = None
    ) -> None:
        """
        Safely save the workflow state to a local JSON file.

        Args:
            state_data: The state data to save (as JSON string or dict).
            filename: Optional filename for the state file.

        Raises:
            RuntimeError: If saving to disk fails.
        """
        try:
            save_directory = self.local_state_path or os.getcwd()
            os.makedirs(save_directory, exist_ok=True)
            filename = filename or f"{self.name}_state.json"
            file_path = os.path.join(save_directory, filename)

            # Create the temporary file in the same directory so the final
            # os.replace() below is an atomic rename on the same filesystem.
            with tempfile.NamedTemporaryFile(
                "w", dir=save_directory, delete=False
            ) as tmp_file:
                temp_path = tmp_file.name

            with state_lock:
                existing_state = {}
                if os.path.exists(file_path):
                    with open(file_path, "r", encoding="utf-8") as file:
                        try:
                            existing_state = json.load(file)
                        except json.JSONDecodeError:
                            logger.warning(
                                "Existing state file is corrupt or empty. Overwriting."
                            )

                # Merge the new state on top of whatever is already on disk.
                new_state = (
                    json.loads(state_data)
                    if isinstance(state_data, str)
                    else state_data
                )
                merged_state = {**existing_state, **new_state}

                with open(temp_path, "w", encoding="utf-8") as file:
                    json.dump(merged_state, file, indent=4)

                os.replace(temp_path, file_path)

            logger.debug(f"Workflow state saved locally at '{file_path}'.")
        except Exception as e:
            logger.error(f"Failed to save workflow state to disk: {e}")
            raise RuntimeError(f"Error saving workflow state to disk: {e}") from e
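
    # Illustrative sketch (assumes a host instance `app`): repeated calls merge keys
    # into the same JSON file, and the temp-file + os.replace() pattern keeps each
    # write atomic even if the process is interrupted.
    #
    #   app.save_state_to_disk('{"instances": {}}')
    #   app.save_state_to_disk('{"phase": "completed"}')
    #   # <local_state_path>/<name>_state.json now contains both keys
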
    def save_state(
        self,
        state: Optional[Union[dict, BaseModel, str]] = None,
        force_reload: bool = False,
    ) -> None:
        """
        Save the current workflow state to Dapr and optionally to disk.

        Args:
            state: The new state to save. If not provided, saves the existing state.
            force_reload: If True, reloads the state from the store after saving.

        Raises:
            RuntimeError: If the state store is not configured.
            TypeError: If the provided state is not a supported type.
            ValueError: If the provided state is a string but not valid JSON.
        """
        try:
            if (
                not self._state_store_client
                or not self.state_store_name
                or not self.state_key
            ):
                logger.error("State store is not configured. Cannot save state.")
                raise RuntimeError(
                    "State store is not configured. Please provide 'state_store_name' and 'state_key'."
                )

            self.state = state or self.state
            if not self.state:
                logger.warning("Skipping state save: Empty state.")
                return

            if isinstance(self.state, BaseModel):
                state_to_save = self.state.model_dump_json()
            elif isinstance(self.state, dict):
                state_to_save = json.dumps(self.state)
            elif isinstance(self.state, str):
                try:
                    json.loads(self.state)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON string provided as state: {e}")
                state_to_save = self.state
            else:
                raise TypeError(
                    f"Invalid state type: {type(self.state)}. Expected dict, BaseModel, or JSON string."
                )

            self._state_store_client.save_state(self.state_key, state_to_save)
            logger.debug(f"Successfully saved state for key '{self.state_key}'.")

            if self.save_state_locally:
                self.save_state_to_disk(state_data=state_to_save)

            if force_reload:
                self.state = self.load_state()
                logger.debug(f"State reloaded after saving for key '{self.state_key}'.")
        except Exception as e:
            logger.error(f"Failed to save state for key '{self.state_key}': {e}")
            raise
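
    # Illustrative sketch (assumes a host instance `app`): a typical save writes to the
    # Dapr state store and, when save_state_locally is enabled, mirrors the state to disk.
    #
    #   app.save_state_locally = True
    #   app.save_state({"instances": {"wf-1": {"status": "running"}}})
    #   app.save_state(force_reload=True)   # persists the current state, then reloads it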