mirror of https://github.com/dapr/dapr-agents.git
style: appease linter
Signed-off-by: Samantha Coyle <sam@diagrid.io>
This commit is contained in:
parent
728be553ee
commit
8601054a0b
|
|
@ -55,7 +55,6 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
description="Metadata about the agent, including name, role, goal, instructions, and topic name.",
|
||||
)
|
||||
|
||||
|
||||
@model_validator(mode="before")
|
||||
def set_agent_and_topic_name(cls, values: dict):
|
||||
# Set name to role if name is not provided
|
||||
|
|
@ -85,17 +84,21 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
self.state["instances"] = {}
|
||||
if "chat_history" not in self.state:
|
||||
self.state["chat_history"] = []
|
||||
|
||||
|
||||
# Load the current workflow instance ID from state if it exists
|
||||
logger.info(f"State after loading: {self.state}")
|
||||
if self.state and "instances" in self.state:
|
||||
logger.info(f"Found {len(self.state['instances'])} instances in state")
|
||||
for instance_id, instance_data in self.state["instances"].items():
|
||||
stored_workflow_name = instance_data.get("workflow_name")
|
||||
logger.info(f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}")
|
||||
logger.info(
|
||||
f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}"
|
||||
)
|
||||
if stored_workflow_name == self._workflow_name:
|
||||
self.workflow_instance_id = instance_id
|
||||
logger.info(f"Loaded current workflow instance ID from state: {instance_id}")
|
||||
logger.info(
|
||||
f"Loaded current workflow instance ID from state: {instance_id}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info("No instances found in state or state is empty")
|
||||
|
|
@ -112,7 +115,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
}
|
||||
|
||||
self.register_agentic_system()
|
||||
|
||||
|
||||
# Start the runtime if it's not already running
|
||||
if not self.wf_runtime_is_running:
|
||||
logger.info("Starting workflow runtime...")
|
||||
|
|
@ -137,20 +140,28 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
# Check for existing incomplete workflows before starting a new one
|
||||
existing_instance_id = await self._find_incomplete_workflow(input_data)
|
||||
if existing_instance_id:
|
||||
logger.info(f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it.")
|
||||
logger.info(
|
||||
f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it."
|
||||
)
|
||||
logger.info("Monitoring the resumed workflow until completion...")
|
||||
|
||||
|
||||
# Monitor the resumed workflow until completion
|
||||
try:
|
||||
result = await self.monitor_workflow_state(existing_instance_id)
|
||||
if result and result.serialized_output:
|
||||
logger.info(f"Resumed workflow {existing_instance_id} completed successfully!")
|
||||
logger.info(
|
||||
f"Resumed workflow {existing_instance_id} completed successfully!"
|
||||
)
|
||||
return result.serialized_output
|
||||
else:
|
||||
logger.warning(f"Resumed workflow {existing_instance_id} completed but had no output")
|
||||
logger.warning(
|
||||
f"Resumed workflow {existing_instance_id} completed but had no output"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring resumed workflow {existing_instance_id}: {e}")
|
||||
logger.error(
|
||||
f"Error monitoring resumed workflow {existing_instance_id}: {e}"
|
||||
)
|
||||
# Fall through to start a new workflow if monitoring fails
|
||||
|
||||
# Prepare input payload for workflow
|
||||
|
|
@ -173,15 +184,17 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
if self.wf_runtime_is_running:
|
||||
self.stop_runtime()
|
||||
|
||||
async def _find_incomplete_workflow(self, input_data: Union[str, Dict[str, Any]]) -> Optional[str]:
|
||||
async def _find_incomplete_workflow(
|
||||
self, input_data: Union[str, Dict[str, Any]]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Find an existing incomplete workflow instance that should be resumed.
|
||||
Uses Dapr WorkflowState to determine if workflow is actually complete.
|
||||
Only resumes workflows that match the current input.
|
||||
|
||||
|
||||
Args:
|
||||
input_data: The input for the new workflow request
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str]: The instance ID of an incomplete workflow with matching input, or None if none found.
|
||||
"""
|
||||
|
|
@ -191,64 +204,81 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
if not instances:
|
||||
logger.debug("No instances found in state")
|
||||
return None
|
||||
|
||||
|
||||
# Normalize input for comparison
|
||||
if isinstance(input_data, dict):
|
||||
current_input = input_data.get("task", str(input_data))
|
||||
else:
|
||||
current_input = str(input_data)
|
||||
|
||||
|
||||
for instance_id, instance_data in instances.items():
|
||||
workflow_name = instance_data.get("workflow_name")
|
||||
end_time = instance_data.get("end_time")
|
||||
stored_input = instance_data.get("input", "")
|
||||
|
||||
|
||||
# Only consider workflows that match our current workflow name
|
||||
if workflow_name != self._workflow_name:
|
||||
continue
|
||||
|
||||
|
||||
if end_time is None:
|
||||
# Only consider workflows that match the current input
|
||||
if str(stored_input) != str(current_input):
|
||||
continue
|
||||
|
||||
|
||||
# Verify the workflow still exists in Dapr and check its actual status
|
||||
# as dapr is our source of truth.
|
||||
try:
|
||||
workflow_state = self.get_workflow_state(instance_id)
|
||||
if workflow_state:
|
||||
runtime_status = workflow_state.runtime_status.name
|
||||
logger.debug(f"Workflow {instance_id} Dapr status: {runtime_status}")
|
||||
logger.debug(
|
||||
f"Workflow {instance_id} Dapr status: {runtime_status}"
|
||||
)
|
||||
|
||||
# If workflow is still running, return it for resumption
|
||||
# Handle case mismatch: Dapr returns 'RUNNING' but enum is 'running'
|
||||
if runtime_status.upper() in [DaprWorkflowStatus.RUNNING.value.upper(), DaprWorkflowStatus.PENDING.value.upper()]:
|
||||
if runtime_status.upper() in [
|
||||
DaprWorkflowStatus.RUNNING.value.upper(),
|
||||
DaprWorkflowStatus.PENDING.value.upper(),
|
||||
]:
|
||||
return instance_id
|
||||
elif runtime_status.upper() in [DaprWorkflowStatus.UNKNOWN.value.upper()]:
|
||||
logger.debug(f"Workflow {instance_id} Dapr status is unknown, skipping")
|
||||
elif runtime_status.upper() in [
|
||||
DaprWorkflowStatus.UNKNOWN.value.upper()
|
||||
]:
|
||||
logger.debug(
|
||||
f"Workflow {instance_id} Dapr status is unknown, skipping"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# This is for COMPLETED, FAILED, TERMINATED
|
||||
self._mark_workflow_completed(instance_id, runtime_status)
|
||||
self._mark_workflow_completed(
|
||||
instance_id, runtime_status
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Workflow {instance_id} no longer exists in Dapr")
|
||||
logger.debug(
|
||||
f"Workflow {instance_id} no longer exists in Dapr"
|
||||
)
|
||||
# Mark as completed in our state since it's no longer in Dapr
|
||||
self._mark_workflow_completed(instance_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not verify workflow {instance_id} in Dapr: {e}")
|
||||
logger.warning(
|
||||
f"Could not verify workflow {instance_id} in Dapr: {e}"
|
||||
)
|
||||
return instance_id
|
||||
|
||||
|
||||
logger.debug("No incomplete workflows found")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding incomplete workflows: {e}")
|
||||
return None
|
||||
|
||||
def _mark_workflow_completed(self, instance_id: str, status: str = "completed") -> None:
|
||||
def _mark_workflow_completed(
|
||||
self, instance_id: str, status: str = "completed"
|
||||
) -> None:
|
||||
"""
|
||||
Mark a workflow as completed in our state.
|
||||
|
||||
|
||||
Args:
|
||||
instance_id: The workflow instance ID to mark as completed
|
||||
status: The completion status (default: "completed")
|
||||
|
|
@ -284,33 +314,45 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
metadata = message.get("_message_metadata", {}) or {}
|
||||
# Extract workflow_instance_id from TriggerAction if present from orchestrator
|
||||
if "workflow_instance_id" in message:
|
||||
metadata["triggering_workflow_instance_id"] = message["workflow_instance_id"]
|
||||
metadata["triggering_workflow_instance_id"] = message[
|
||||
"workflow_instance_id"
|
||||
]
|
||||
else:
|
||||
task = getattr(message, "task", None)
|
||||
metadata = getattr(message, "_message_metadata", {}) or {}
|
||||
# Extract workflow_instance_id from TriggerAction if present from orchestrator
|
||||
if hasattr(message, "workflow_instance_id"):
|
||||
metadata["triggering_workflow_instance_id"] = getattr(message, "workflow_instance_id")
|
||||
metadata["triggering_workflow_instance_id"] = getattr(
|
||||
message, "workflow_instance_id"
|
||||
)
|
||||
|
||||
workflow_instance_id = ctx.instance_id
|
||||
triggering_workflow_instance_id = metadata.get("triggering_workflow_instance_id")
|
||||
triggering_workflow_instance_id = metadata.get(
|
||||
"triggering_workflow_instance_id"
|
||||
)
|
||||
source = metadata.get("source")
|
||||
|
||||
|
||||
# Set default source if not provided (for direct run() calls)
|
||||
if not source:
|
||||
source = "direct"
|
||||
|
||||
print(f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
|
||||
print(f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}")
|
||||
|
||||
print(
|
||||
f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
|
||||
)
|
||||
print(
|
||||
f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}"
|
||||
)
|
||||
print(f"DEBUG: Current self.state at start of workflow: {self.state}")
|
||||
logger.info(f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
|
||||
logger.info(
|
||||
f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
|
||||
)
|
||||
logger.info(f"Current self.state at start of workflow: {self.state}")
|
||||
|
||||
|
||||
# Store the instance ID from workflow context as the source of truth
|
||||
# This will be used by graceful shutdown logic to save incomplete instances
|
||||
logger.info(f"Workflow context provided instance ID: {workflow_instance_id}")
|
||||
print(f"DEBUG: Workflow context provided instance ID: {workflow_instance_id}")
|
||||
|
||||
|
||||
# Check if this instance already exists in state (from previous runs)
|
||||
if workflow_instance_id not in self.state.get("instances", {}):
|
||||
# This is a new instance, create a minimal entry
|
||||
|
|
@ -324,19 +366,29 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"workflow_instance_id": workflow_instance_id,
|
||||
"triggering_workflow_instance_id": triggering_workflow_instance_id,
|
||||
"workflow_name": self._workflow_name,
|
||||
"status": "running" # Mark as running
|
||||
"status": "running", # Mark as running
|
||||
}
|
||||
self.state.setdefault("instances", {})[workflow_instance_id] = instance_entry
|
||||
self.state.setdefault("instances", {})[
|
||||
workflow_instance_id
|
||||
] = instance_entry
|
||||
logger.info(f"Created new instance entry: {workflow_instance_id}")
|
||||
print(f"DEBUG: Created new instance entry: {workflow_instance_id}")
|
||||
|
||||
|
||||
# Immediately save state so graceful shutdown can capture this instance
|
||||
self.save_state()
|
||||
logger.info(f"Saved state immediately after creating instance entry for {workflow_instance_id}")
|
||||
print(f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}")
|
||||
logger.info(
|
||||
f"Saved state immediately after creating instance entry for {workflow_instance_id}"
|
||||
)
|
||||
print(
|
||||
f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Found existing instance entry for workflow {workflow_instance_id}")
|
||||
print(f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}")
|
||||
logger.info(
|
||||
f"Found existing instance entry for workflow {workflow_instance_id}"
|
||||
)
|
||||
print(
|
||||
f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}"
|
||||
)
|
||||
final_message: Optional[Dict[str, Any]] = None
|
||||
|
||||
if not ctx.is_replaying:
|
||||
|
|
@ -371,7 +423,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
# Step 5: Add the assistant's response message to the chat history
|
||||
yield ctx.call_activity(
|
||||
self.append_assistant_message,
|
||||
input={"instance_id": workflow_instance_id, "message": response_message},
|
||||
input={
|
||||
"instance_id": workflow_instance_id,
|
||||
"message": response_message,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 6: Handle tool calls response
|
||||
|
|
@ -391,7 +446,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
for tr in tool_results:
|
||||
yield ctx.call_activity(
|
||||
self.append_tool_message,
|
||||
input={"instance_id": workflow_instance_id, "tool_result": tr},
|
||||
input={
|
||||
"instance_id": workflow_instance_id,
|
||||
"tool_result": tr,
|
||||
},
|
||||
)
|
||||
# 🔴 If this was the last turn, stop here—even though there were tool calls
|
||||
if turn == self.max_iterations:
|
||||
|
|
@ -512,7 +570,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"""
|
||||
# Load state to ensure we have the latest data
|
||||
self.load_state()
|
||||
|
||||
|
||||
workflow_entry = self.state.get("instances", {}).get(instance_id)
|
||||
if workflow_entry is not None:
|
||||
return {
|
||||
|
|
@ -536,7 +594,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
# }
|
||||
# self.state.setdefault("instances", {})[instance_id] = minimal_entry
|
||||
# self.save_state()
|
||||
|
||||
|
||||
# return {
|
||||
# "source": minimal_entry.get("source"),
|
||||
# "source_workflow_instance_id": minimal_entry.get("source_workflow_instance_id"),
|
||||
|
|
@ -560,7 +618,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"""
|
||||
# Construct messages using instance-specific chat history instead of global memory
|
||||
# This ensures proper message sequence for tool calls
|
||||
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(instance_id, task or {})
|
||||
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(
|
||||
instance_id, task or {}
|
||||
)
|
||||
user_message = self.get_last_message_if_user(messages)
|
||||
|
||||
# Always work with a copy of the user message for safety
|
||||
|
|
@ -582,7 +642,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"source": "user_input",
|
||||
"workflow_instance_id": instance_id,
|
||||
"triggering_workflow_instance_id": None,
|
||||
"workflow_name": self._workflow_name
|
||||
"workflow_name": self._workflow_name,
|
||||
}
|
||||
inst: dict = self.state["instances"][instance_id]
|
||||
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
|
||||
|
|
@ -622,15 +682,19 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
error_msg = str(e)
|
||||
|
||||
logger.error(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}")
|
||||
|
||||
logger.error(
|
||||
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
|
||||
)
|
||||
logger.error(f"Task: {task}")
|
||||
logger.error(f"Messages count: {len(messages)}")
|
||||
logger.error(f"Tools available: {len(self.get_llm_tools())}")
|
||||
logger.error("Full error details:", exc_info=True)
|
||||
|
||||
raise AgentError(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}") from e
|
||||
|
||||
raise AgentError(
|
||||
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
|
||||
) from e
|
||||
|
||||
@task
|
||||
async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -733,7 +797,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"source": "user_input",
|
||||
"workflow_instance_id": instance_id,
|
||||
"triggering_workflow_instance_id": None,
|
||||
"workflow_name": self._workflow_name
|
||||
"workflow_name": self._workflow_name,
|
||||
}
|
||||
inst: dict = self.state["instances"][instance_id]
|
||||
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
|
||||
|
|
@ -774,7 +838,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"source": "user_input",
|
||||
"workflow_instance_id": instance_id,
|
||||
"triggering_workflow_instance_id": None,
|
||||
"workflow_name": self._workflow_name
|
||||
"workflow_name": self._workflow_name,
|
||||
}
|
||||
inst: dict = self.state["instances"][instance_id]
|
||||
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
|
||||
|
|
@ -808,7 +872,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
"source": "user_input",
|
||||
"workflow_instance_id": instance_id,
|
||||
"triggering_workflow_instance_id": None,
|
||||
"workflow_name": self._workflow_name
|
||||
"workflow_name": self._workflow_name,
|
||||
}
|
||||
inst: dict = self.state["instances"][instance_id]
|
||||
inst["output"] = final_output
|
||||
|
|
@ -870,26 +934,30 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
except Exception as e:
|
||||
logger.error(f"Error processing broadcast message: {e}", exc_info=True)
|
||||
|
||||
def _construct_messages_with_instance_history(self, instance_id: str, input_data: Union[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _construct_messages_with_instance_history(
|
||||
self, instance_id: str, input_data: Union[str, Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Construct messages using instance-specific chat history instead of global memory.
|
||||
This ensures proper message sequence for tool calls and prevents OpenAI API errors
|
||||
in the event an app gets terminated or restarts while the workflow is running.
|
||||
|
||||
|
||||
Args:
|
||||
instance_id: The workflow instance ID
|
||||
input_data: User input, either as a string or dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
List of formatted messages with proper sequence
|
||||
"""
|
||||
if not self.prompt_template:
|
||||
raise ValueError("Prompt template must be initialized before constructing messages.")
|
||||
|
||||
raise ValueError(
|
||||
"Prompt template must be initialized before constructing messages."
|
||||
)
|
||||
|
||||
# Get instance-specific chat history instead of global memory
|
||||
instance_data = self.state.get("instances", {}).get(instance_id, {})
|
||||
instance_messages = instance_data.get("messages", [])
|
||||
|
||||
|
||||
# Convert instance messages to the format expected by prompt template
|
||||
chat_history = []
|
||||
for msg in instance_messages:
|
||||
|
|
@ -897,10 +965,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
|
|||
chat_history.append(msg)
|
||||
else:
|
||||
# Convert DurableAgentMessage to dict if needed
|
||||
chat_history.append(msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg))
|
||||
|
||||
chat_history.append(
|
||||
msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
|
||||
)
|
||||
|
||||
if isinstance(input_data, str):
|
||||
formatted_messages = self.prompt_template.format_prompt(chat_history=chat_history)
|
||||
formatted_messages = self.prompt_template.format_prompt(
|
||||
chat_history=chat_history
|
||||
)
|
||||
if isinstance(formatted_messages, list):
|
||||
user_message = {"role": "user", "content": input_data}
|
||||
return formatted_messages + [user_message]
|
||||
|
|
|
|||
|
|
@ -225,9 +225,8 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
|
|||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
error_msg = str(e)
|
||||
|
||||
|
||||
logger.error(f"OpenAI ChatCompletion API error: {error_type} - {error_msg}")
|
||||
logger.error("Full error details:", exc_info=True)
|
||||
|
||||
raise ValueError(f"OpenAI API error ({error_type}): {error_msg}") from e
|
||||
|
||||
|
|
|
|||
|
|
@ -95,33 +95,38 @@ class WorkflowContextStorage:
|
|||
logger.warning(f"⚠️ No context found for instance {instance_id}")
|
||||
return context
|
||||
|
||||
def create_resumed_workflow_context(self, instance_id: str, agent_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
def create_resumed_workflow_context(
|
||||
self, instance_id: str, agent_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new trace context for a resumed workflow after app restart.
|
||||
|
||||
|
||||
When an app restarts, the in-memory context storage is lost. This method
|
||||
creates a new trace context for resumed workflows so they can still be
|
||||
traced, even though they won't be connected to the original trace.
|
||||
|
||||
|
||||
Args:
|
||||
instance_id (str): Unique workflow instance ID
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: New W3C context data for the resumed workflow
|
||||
"""
|
||||
try:
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from opentelemetry.trace.propagation.tracecontext import (
|
||||
TraceContextTextMapPropagator,
|
||||
)
|
||||
|
||||
# Create a new trace for the resumed workflow with proper AGENT span
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
# Create AGENT span with proper agent name for resumed workflow
|
||||
agent_display_name = agent_name or "DurableAgent"
|
||||
span_name = f"{agent_display_name}.ToolCallingWorkflow"
|
||||
with tracer.start_as_current_span(span_name) as span:
|
||||
# Set AGENT span attributes
|
||||
from .constants import OPENINFERENCE_SPAN_KIND
|
||||
|
||||
span.set_attribute(OPENINFERENCE_SPAN_KIND, "AGENT")
|
||||
span.set_attribute("workflow.instance_id", instance_id)
|
||||
span.set_attribute("workflow.resumed", True)
|
||||
|
|
@ -130,29 +135,33 @@ class WorkflowContextStorage:
|
|||
propagator = TraceContextTextMapPropagator()
|
||||
carrier = {}
|
||||
propagator.inject(carrier)
|
||||
|
||||
|
||||
context_data = {
|
||||
"traceparent": carrier.get("traceparent"),
|
||||
"tracestate": carrier.get("tracestate"),
|
||||
"instance_id": instance_id,
|
||||
"resumed": True,
|
||||
"debug_info": f"New trace created for resumed workflow {instance_id}"
|
||||
"debug_info": f"New trace created for resumed workflow {instance_id}",
|
||||
}
|
||||
|
||||
|
||||
# Store the new context
|
||||
self.store_context(instance_id, context_data)
|
||||
logger.info(f"Created new trace context for resumed workflow {instance_id}")
|
||||
|
||||
logger.info(
|
||||
f"Created new trace context for resumed workflow {instance_id}"
|
||||
)
|
||||
|
||||
return context_data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create resumed workflow context for {instance_id}: {e}")
|
||||
logger.error(
|
||||
f"Failed to create resumed workflow context for {instance_id}: {e}"
|
||||
)
|
||||
return {
|
||||
"traceparent": None,
|
||||
"tracestate": None,
|
||||
"instance_id": instance_id,
|
||||
"resumed": True,
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def cleanup_context(self, instance_id: str) -> None:
|
||||
|
|
|
|||
|
|
@ -424,28 +424,28 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
|
|||
from dapr_agents.workflow.base import WorkflowApp
|
||||
from dapr_agents.workflow.task import WorkflowTask
|
||||
import wrapt
|
||||
|
||||
|
||||
# Workflow monitoring wrapper - creates AGENT spans
|
||||
wrapt.wrap_function_wrapper(
|
||||
WorkflowApp.__module__,
|
||||
f'{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}',
|
||||
WorkflowMonitorWrapper(self._tracer)
|
||||
f"{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}",
|
||||
WorkflowMonitorWrapper(self._tracer),
|
||||
)
|
||||
|
||||
|
||||
# Workflow run wrapper - creates workflow spans
|
||||
wrapt.wrap_function_wrapper(
|
||||
WorkflowApp.__module__,
|
||||
f'{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}',
|
||||
WorkflowRunWrapper(self._tracer)
|
||||
f"{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}",
|
||||
WorkflowRunWrapper(self._tracer),
|
||||
)
|
||||
|
||||
|
||||
# WorkflowTask call wrapper
|
||||
# This ensures child spans (LLM/TOOL) are properly linked to parent AGENT spans,
|
||||
# and is necessary due to the async nature of the WorkflowTask.__call__ method.
|
||||
wrapt.wrap_function_wrapper(
|
||||
WorkflowTask.__module__,
|
||||
f'{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}',
|
||||
WorkflowTaskWrapper(self._tracer)
|
||||
f"{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}",
|
||||
WorkflowTaskWrapper(self._tracer),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying workflow wrappers: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -341,11 +341,17 @@ class WorkflowMonitorWrapper:
|
|||
from ..context_storage import store_workflow_context
|
||||
|
||||
captured_context = extract_otel_context()
|
||||
if captured_context.get("traceparent") and captured_context.get("trace_id") != "00000000000000000000000000000000":
|
||||
if (
|
||||
captured_context.get("traceparent")
|
||||
and captured_context.get("trace_id")
|
||||
!= "00000000000000000000000000000000"
|
||||
):
|
||||
logger.debug(
|
||||
f"Captured traceparent: {captured_context.get('traceparent')}"
|
||||
)
|
||||
store_workflow_context("__global_workflow_context__", captured_context)
|
||||
store_workflow_context(
|
||||
"__global_workflow_context__", captured_context
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Invalid or empty trace context captured: {captured_context}"
|
||||
|
|
|
|||
|
|
@ -85,14 +85,14 @@ class WorkflowTaskWrapper:
|
|||
# Determine task details
|
||||
logger.debug(f"WorkflowTaskWrapper: instance type = {type(instance)}")
|
||||
logger.debug(f"WorkflowTaskWrapper: instance attributes = {dir(instance)}")
|
||||
if hasattr(instance, 'func'):
|
||||
if hasattr(instance, "func"):
|
||||
logger.debug(f"WorkflowTaskWrapper: instance.func = {instance.func}")
|
||||
else:
|
||||
logger.debug("WorkflowTaskWrapper: instance has no 'func' attribute")
|
||||
|
||||
|
||||
task_name = (
|
||||
getattr(instance.func, "__name__", "unknown_task")
|
||||
if hasattr(instance, 'func') and instance.func
|
||||
if hasattr(instance, "func") and instance.func
|
||||
else "workflow_task"
|
||||
)
|
||||
span_kind = self._determine_span_kind(instance, task_name)
|
||||
|
|
@ -245,7 +245,13 @@ class WorkflowTaskWrapper:
|
|||
return attributes
|
||||
|
||||
def _handle_async_execution(
|
||||
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
|
||||
self,
|
||||
wrapped: Any,
|
||||
instance: Any,
|
||||
args: Any,
|
||||
kwargs: Any,
|
||||
span_name: str,
|
||||
attributes: dict,
|
||||
) -> Any:
|
||||
"""
|
||||
Handle asynchronous workflow task execution with OpenTelemetry context restoration.
|
||||
|
|
@ -276,32 +282,48 @@ class WorkflowTaskWrapper:
|
|||
from ..context_storage import get_workflow_context
|
||||
|
||||
otel_context = get_workflow_context(instance_id)
|
||||
|
||||
|
||||
# If no context found for specific instance, try global context as fallback
|
||||
if otel_context is None:
|
||||
logger.debug(f"No context found for instance {instance_id}, trying global context")
|
||||
logger.debug(
|
||||
f"No context found for instance {instance_id}, trying global context"
|
||||
)
|
||||
otel_context = get_workflow_context("__global_workflow_context__")
|
||||
|
||||
|
||||
if otel_context:
|
||||
# Store the global context with the specific instance ID for future use
|
||||
from ..context_storage import store_workflow_context
|
||||
|
||||
store_workflow_context(instance_id, otel_context)
|
||||
logger.debug(f"Copied global context to instance {instance_id}")
|
||||
else:
|
||||
# If still no context found (e.g., after app restart), create a new one for resumed workflows
|
||||
logger.debug(f"No context found for instance {instance_id} - creating new context for resumed workflow")
|
||||
|
||||
logger.debug(
|
||||
f"No context found for instance {instance_id} - creating new context for resumed workflow"
|
||||
)
|
||||
|
||||
# Try to get agent name from the task instance
|
||||
agent_name = None
|
||||
if hasattr(instance, 'agent') and instance.agent and hasattr(instance.agent, 'name'):
|
||||
if (
|
||||
hasattr(instance, "agent")
|
||||
and instance.agent
|
||||
and hasattr(instance.agent, "name")
|
||||
):
|
||||
agent_name = instance.agent.name
|
||||
elif hasattr(instance, 'func') and instance.func and hasattr(instance.func, '__self__'):
|
||||
elif (
|
||||
hasattr(instance, "func")
|
||||
and instance.func
|
||||
and hasattr(instance.func, "__self__")
|
||||
):
|
||||
agent_instance = instance.func.__self__
|
||||
if hasattr(agent_instance, 'name'):
|
||||
if hasattr(agent_instance, "name"):
|
||||
agent_name = agent_instance.name
|
||||
|
||||
|
||||
from ..context_storage import _context_storage
|
||||
otel_context = _context_storage.create_resumed_workflow_context(instance_id, agent_name)
|
||||
|
||||
otel_context = _context_storage.create_resumed_workflow_context(
|
||||
instance_id, agent_name
|
||||
)
|
||||
|
||||
# Create span with restored context if available
|
||||
from ..context_propagation import create_child_span_with_context
|
||||
|
|
@ -348,7 +370,13 @@ class WorkflowTaskWrapper:
|
|||
return async_wrapper()
|
||||
|
||||
def _handle_sync_execution(
|
||||
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
|
||||
self,
|
||||
wrapped: Any,
|
||||
instance: Any,
|
||||
args: Any,
|
||||
kwargs: Any,
|
||||
span_name: str,
|
||||
attributes: dict,
|
||||
) -> Any:
|
||||
"""
|
||||
Handle synchronous workflow task execution with OpenTelemetry context restoration.
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class FastAPIServerBase(APIServerBase, SignalHandlingMixin):
|
|||
Perform graceful shutdown operations for the FastAPI server.
|
||||
"""
|
||||
await self.stop()
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop the FastAPI server gracefully.
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ class ElevenLabsClientConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class NVIDIAClientConfig(BaseModel):
|
||||
base_url: Optional[str] = Field(
|
||||
"https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA API"
|
||||
|
|
@ -27,7 +26,6 @@ class NVIDIAClientConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class DaprInferenceClientConfig:
|
||||
pass
|
||||
|
||||
|
|
@ -73,7 +71,6 @@ class HFInferenceClientConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class OpenAIClientConfig(BaseModel):
|
||||
base_url: Optional[str] = Field(None, description="Base URL for the OpenAI API")
|
||||
api_key: Optional[str] = Field(
|
||||
|
|
@ -112,7 +109,6 @@ class AzureOpenAIClientConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class OpenAIModelConfig(OpenAIClientConfig):
|
||||
type: Literal["openai"] = Field(
|
||||
"openai", description="Type of the model, must always be 'openai'"
|
||||
|
|
@ -170,7 +166,6 @@ class OpenAIParamsBase(BaseModel):
|
|||
stream: Optional[bool] = Field(False, description="Whether to stream responses")
|
||||
|
||||
|
||||
|
||||
class OpenAITextCompletionParams(OpenAIParamsBase):
|
||||
"""
|
||||
Specific configs for the text completions endpoint.
|
||||
|
|
@ -186,7 +181,6 @@ class OpenAITextCompletionParams(OpenAIParamsBase):
|
|||
suffix: Optional[str] = Field(None, description="Suffix to append to the prompt")
|
||||
|
||||
|
||||
|
||||
class OpenAIChatCompletionParams(OpenAIParamsBase):
|
||||
"""
|
||||
Specific settings for the Chat Completion endpoint.
|
||||
|
|
@ -222,7 +216,6 @@ class OpenAIChatCompletionParams(OpenAIParamsBase):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class HFHubChatCompletionParams(BaseModel):
|
||||
"""
|
||||
Specific settings for Hugging Face Hub Chat Completion endpoint.
|
||||
|
|
@ -285,7 +278,6 @@ class HFHubChatCompletionParams(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class NVIDIAChatCompletionParams(OpenAIParamsBase):
|
||||
"""
|
||||
Specific settings for the Chat Completion endpoint.
|
||||
|
|
@ -311,7 +303,6 @@ class NVIDIAChatCompletionParams(OpenAIParamsBase):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class PromptyModelConfig(BaseModel):
|
||||
api: Literal["chat", "completion"] = Field(
|
||||
"chat", description="The API to use, either 'chat' or 'completion'"
|
||||
|
|
@ -330,7 +321,6 @@ class PromptyModelConfig(BaseModel):
|
|||
description="Determines if full response or just the first one is returned",
|
||||
)
|
||||
|
||||
|
||||
@model_validator(mode="before")
|
||||
def sync_model_name(cls, values: dict):
|
||||
"""
|
||||
|
|
@ -405,7 +395,6 @@ class PromptyDefinition(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
|
||||
class AudioSpeechRequest(BaseModel):
|
||||
model: Optional[Literal["tts-1", "tts-1-hd"]] = Field(
|
||||
"tts-1", description="TTS model to use. Defaults to 'tts-1'."
|
||||
|
|
|
|||
|
|
@ -0,0 +1,121 @@
|
|||
"""
|
||||
Reusable signal handling mixin for graceful shutdown across different service types.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from dapr_agents.utils import add_signal_handlers_cross_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SignalHandlingMixin:
    """
    Mixin providing reusable signal handling for graceful shutdown.

    This mixin can be used by any class that needs to handle shutdown signals
    (SIGINT, SIGTERM) gracefully. It provides a consistent interface for:
    - Setting up signal handlers
    - Managing shutdown events
    - Triggering graceful shutdown logic

    Every method tolerates ``__init__`` of this mixin not having run (the
    private attributes are looked up defensively), so the mixin also works
    when a host class bypasses the cooperative ``__init__`` chain.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._shutdown_event: Optional[asyncio.Event] = None
        self._signal_handlers_setup = False
        # Strong references to in-flight shutdown tasks (see _schedule_shutdown).
        self._shutdown_tasks: set = set()

    def setup_signal_handlers(self) -> None:
        """
        Set up signal handlers for graceful shutdown.

        This method should be called during initialization or startup
        to enable graceful shutdown handling. Calling it again after a
        successful setup is a no-op.
        """
        if getattr(self, "_signal_handlers_setup", False):
            logger.debug("Signal handlers already set up")
            return

        # Create the shutdown event lazily so it exists before any handler fires.
        if getattr(self, "_shutdown_event", None) is None:
            self._shutdown_event = asyncio.Event()

        # Register our handler with the current event loop.
        loop = asyncio.get_event_loop()
        add_signal_handlers_cross_platform(loop, self._handle_shutdown_signal)

        self._signal_handlers_setup = True
        logger.debug("Signal handlers set up for graceful shutdown")

    def _schedule_shutdown(self, coro) -> "asyncio.Task":
        """
        Schedule ``coro`` as a task and keep a strong reference to it.

        The event loop only keeps weak references to tasks, so a task whose
        result is discarded may be garbage-collected before it completes.
        The task is stored on the instance and dropped once it is done.

        Args:
            coro: The coroutine object to run on the running event loop.

        Returns:
            asyncio.Task: The scheduled task.
        """
        pending = getattr(self, "_shutdown_tasks", None)
        if pending is None:
            pending = set()
            self._shutdown_tasks = pending

        task = asyncio.create_task(coro)
        pending.add(task)
        task.add_done_callback(pending.discard)
        return task

    def _handle_shutdown_signal(self, sig: int) -> None:
        """
        Internal signal handler that triggers graceful shutdown.

        Args:
            sig: The received signal number
        """
        logger.debug(f"Shutdown signal {sig} received. Triggering graceful shutdown...")

        # Set the shutdown event so pollers/waiters observe the request.
        shutdown_event = getattr(self, "_shutdown_event", None)
        if shutdown_event is not None:
            shutdown_event.set()

        # graceful_shutdown() is always available (this mixin defines it) and
        # itself falls back to stop() or logs a warning when neither exists.
        self._schedule_shutdown(self.graceful_shutdown())

    async def graceful_shutdown(self) -> None:
        """
        Perform graceful shutdown operations.

        This method should be overridden by classes that use this mixin
        to implement their specific shutdown logic.

        Default implementation calls stop() if it exists.
        """
        if hasattr(self, "stop"):
            await self.stop()
        else:
            logger.warning(
                "No stop() method found. Override graceful_shutdown() to implement shutdown logic."
            )

    def is_shutdown_requested(self) -> bool:
        """
        Check if a shutdown has been requested.

        Returns:
            bool: True if shutdown has been requested, False otherwise
        """
        shutdown_event = getattr(self, "_shutdown_event", None)
        return shutdown_event is not None and shutdown_event.is_set()

    async def wait_for_shutdown(self, check_interval: float = 1.0) -> None:
        """
        Wait for a shutdown signal to be received.

        The shutdown event is polled every ``check_interval`` seconds until
        it is set.

        Args:
            check_interval: How often to check for shutdown (in seconds)

        Raises:
            RuntimeError: If the shutdown event has not been created yet,
                i.e. setup_signal_handlers() was never called.
        """
        shutdown_event = getattr(self, "_shutdown_event", None)
        if shutdown_event is None:
            raise RuntimeError(
                "Signal handlers not set up. Call setup_signal_handlers() first."
            )

        while not shutdown_event.is_set():
            await asyncio.sleep(check_interval)
|
||||
|
|
@ -556,7 +556,10 @@ class WorkflowApp(BaseModel):
|
|||
f"Output: {json.dumps(state.serialized_output, indent=2)}"
|
||||
)
|
||||
|
||||
elif workflow_status.upper() in (DaprWorkflowStatus.FAILED.value.upper(), "ABORTED"):
|
||||
elif workflow_status.upper() in (
|
||||
DaprWorkflowStatus.FAILED.value.upper(),
|
||||
"ABORTED",
|
||||
):
|
||||
# Ensure `failure_details` exists before accessing attributes
|
||||
error_type = getattr(failure_details, "error_type", "Unknown")
|
||||
message = getattr(failure_details, "message", "No message provided")
|
||||
|
|
@ -611,7 +614,9 @@ class WorkflowApp(BaseModel):
|
|||
# Off-load the potentially blocking run_workflow call to a thread.
|
||||
instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
|
||||
|
||||
logger.debug(f"Workflow '{workflow}' started with instance ID: {instance_id}")
|
||||
logger.debug(
|
||||
f"Workflow '{workflow}' started with instance ID: {instance_id}"
|
||||
)
|
||||
|
||||
# Await the asynchronous monitoring of the workflow state.
|
||||
state = await self.monitor_workflow_state(instance_id)
|
||||
|
|
|
|||
|
|
@ -138,30 +138,44 @@ class ServiceMixin(SignalHandlingMixin):
|
|||
|
||||
# Save state before shutting down to ensure persistence and agent durability to properly rerun after being stopped
|
||||
try:
|
||||
if hasattr(self, 'save_state') and hasattr(self, 'state'):
|
||||
if hasattr(self, "save_state") and hasattr(self, "state"):
|
||||
# Graceful shutdown compensation: Save incomplete instance if it exists
|
||||
if hasattr(self, 'workflow_instance_id') and self.workflow_instance_id:
|
||||
if hasattr(self, "workflow_instance_id") and self.workflow_instance_id:
|
||||
if self.workflow_instance_id not in self.state.get("instances", {}):
|
||||
# This instance was never saved, add it as incomplete
|
||||
from datetime import datetime, timezone
|
||||
|
||||
incomplete_entry = {
|
||||
"messages": [],
|
||||
"start_time": datetime.now(timezone.utc).isoformat(),
|
||||
"source": "graceful_shutdown",
|
||||
"source_workflow_instance_id": None,
|
||||
"workflow_name": getattr(self, '_workflow_name', 'Unknown'),
|
||||
"workflow_name": getattr(self, "_workflow_name", "Unknown"),
|
||||
"dapr_status": DaprWorkflowStatus.PENDING,
|
||||
"suspended_reason": "app_terminated"
|
||||
"suspended_reason": "app_terminated",
|
||||
}
|
||||
self.state.setdefault("instances", {})[self.workflow_instance_id] = incomplete_entry
|
||||
logger.info(f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown")
|
||||
self.state.setdefault("instances", {})[
|
||||
self.workflow_instance_id
|
||||
] = incomplete_entry
|
||||
logger.info(
|
||||
f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown"
|
||||
)
|
||||
else:
|
||||
# Mark existing instance as suspended due to app termination
|
||||
if "instances" in self.state and self.workflow_instance_id in self.state["instances"]:
|
||||
self.state["instances"][self.workflow_instance_id]["dapr_status"] = DaprWorkflowStatus.SUSPENDED
|
||||
self.state["instances"][self.workflow_instance_id]["suspended_reason"] = "app_terminated"
|
||||
logger.info(f"Marked instance {self.workflow_instance_id} as suspended due to app termination")
|
||||
|
||||
if (
|
||||
"instances" in self.state
|
||||
and self.workflow_instance_id in self.state["instances"]
|
||||
):
|
||||
self.state["instances"][self.workflow_instance_id][
|
||||
"dapr_status"
|
||||
] = DaprWorkflowStatus.SUSPENDED
|
||||
self.state["instances"][self.workflow_instance_id][
|
||||
"suspended_reason"
|
||||
] = "app_terminated"
|
||||
logger.info(
|
||||
f"Marked instance {self.workflow_instance_id} as suspended due to app termination"
|
||||
)
|
||||
|
||||
self.save_state()
|
||||
logger.debug("Workflow state saved successfully.")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -106,16 +106,16 @@ class StateManagementMixin:
|
|||
raise TypeError(
|
||||
f"Invalid state type retrieved: {type(state_data)}. Expected dict."
|
||||
)
|
||||
|
||||
|
||||
# Set self.state to the loaded data
|
||||
if self.state_format:
|
||||
loaded_state = self.validate_state(state_data)
|
||||
else:
|
||||
loaded_state = state_data
|
||||
|
||||
|
||||
self.state = loaded_state
|
||||
logger.debug(f"Set self.state to loaded data: {self.state}")
|
||||
|
||||
|
||||
return loaded_state
|
||||
|
||||
logger.info(
|
||||
|
|
|
|||
Loading…
Reference in New Issue