diff --git a/dapr_agents/agents/durableagent/agent.py b/dapr_agents/agents/durableagent/agent.py
index bcc5905..dda7482 100644
--- a/dapr_agents/agents/durableagent/agent.py
+++ b/dapr_agents/agents/durableagent/agent.py
@@ -55,7 +55,6 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         description="Metadata about the agent, including name, role, goal, instructions, and topic name.",
     )

-
     @model_validator(mode="before")
     def set_agent_and_topic_name(cls, values: dict):
         # Set name to role if name is not provided
@@ -85,17 +84,21 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             self.state["instances"] = {}
         if "chat_history" not in self.state:
             self.state["chat_history"] = []
-        
+
         # Load the current workflow instance ID from state if it exists
         logger.info(f"State after loading: {self.state}")
         if self.state and "instances" in self.state:
             logger.info(f"Found {len(self.state['instances'])} instances in state")
             for instance_id, instance_data in self.state["instances"].items():
                 stored_workflow_name = instance_data.get("workflow_name")
-                logger.info(f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}")
+                logger.info(
+                    f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}"
+                )
                 if stored_workflow_name == self._workflow_name:
                     self.workflow_instance_id = instance_id
-                    logger.info(f"Loaded current workflow instance ID from state: {instance_id}")
+                    logger.info(
+                        f"Loaded current workflow instance ID from state: {instance_id}"
+                    )
                     break
         else:
             logger.info("No instances found in state or state is empty")
@@ -112,7 +115,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         }

         self.register_agentic_system()
-        
+
         # Start the runtime if it's not already running
         if not self.wf_runtime_is_running:
             logger.info("Starting workflow runtime...")
@@ -137,20 +140,28 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         # Check for existing incomplete workflows before starting a new one
         existing_instance_id = await self._find_incomplete_workflow(input_data)
         if existing_instance_id:
-            logger.info(f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it.")
+            logger.info(
+                f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it."
+            )
             logger.info("Monitoring the resumed workflow until completion...")
-            
+
             # Monitor the resumed workflow until completion
             try:
                 result = await self.monitor_workflow_state(existing_instance_id)
                 if result and result.serialized_output:
-                    logger.info(f"Resumed workflow {existing_instance_id} completed successfully!")
+                    logger.info(
+                        f"Resumed workflow {existing_instance_id} completed successfully!"
+                    )
                     return result.serialized_output
                 else:
-                    logger.warning(f"Resumed workflow {existing_instance_id} completed but had no output")
+                    logger.warning(
+                        f"Resumed workflow {existing_instance_id} completed but had no output"
+                    )
                     return None
             except Exception as e:
-                logger.error(f"Error monitoring resumed workflow {existing_instance_id}: {e}")
+                logger.error(
+                    f"Error monitoring resumed workflow {existing_instance_id}: {e}"
+                )
                 # Fall through to start a new workflow if monitoring fails

         # Prepare input payload for workflow
@@ -173,15 +184,17 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             if self.wf_runtime_is_running:
                 self.stop_runtime()

-    async def _find_incomplete_workflow(self, input_data: Union[str, Dict[str, Any]]) -> Optional[str]:
+    async def _find_incomplete_workflow(
+        self, input_data: Union[str, Dict[str, Any]]
+    ) -> Optional[str]:
         """
         Find an existing incomplete workflow instance that should be resumed.
         Uses Dapr WorkflowState to determine if a workflow is actually complete.
         Only resumes workflows that match the current input.
-        
+
         Args:
             input_data: The input for the new workflow request
-        
+
         Returns:
             Optional[str]: The instance ID of an incomplete workflow with matching input, or None if none found.
         """
@@ -191,64 +204,81 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             if not instances:
                 logger.debug("No instances found in state")
                 return None
-            
+
             # Normalize input for comparison
             if isinstance(input_data, dict):
                 current_input = input_data.get("task", str(input_data))
             else:
                 current_input = str(input_data)
-            
+
             for instance_id, instance_data in instances.items():
                 workflow_name = instance_data.get("workflow_name")
                 end_time = instance_data.get("end_time")
                 stored_input = instance_data.get("input", "")
-                
+
                 # Only consider workflows that match our current workflow name
                 if workflow_name != self._workflow_name:
                     continue
-                    
+
                 if end_time is None:
                     # Only consider workflows that match the current input
                     if str(stored_input) != str(current_input):
                         continue
-                    
+
                     # Verify the workflow still exists in Dapr and check its actual status,
                     # as Dapr is our source of truth.
                    try:
                         workflow_state = self.get_workflow_state(instance_id)
                         if workflow_state:
                             runtime_status = workflow_state.runtime_status.name
-                            logger.debug(f"Workflow {instance_id} Dapr status: {runtime_status}")
+                            logger.debug(
+                                f"Workflow {instance_id} Dapr status: {runtime_status}"
+                            )

                             # If workflow is still running, return it for resumption
                             # Handle case mismatch: Dapr returns 'RUNNING' but enum is 'running'
-                            if runtime_status.upper() in [DaprWorkflowStatus.RUNNING.value.upper(), DaprWorkflowStatus.PENDING.value.upper()]:
+                            if runtime_status.upper() in [
+                                DaprWorkflowStatus.RUNNING.value.upper(),
+                                DaprWorkflowStatus.PENDING.value.upper(),
+                            ]:
                                 return instance_id
-                            elif runtime_status.upper() in [DaprWorkflowStatus.UNKNOWN.value.upper()]:
-                                logger.debug(f"Workflow {instance_id} Dapr status is unknown, skipping")
+                            elif runtime_status.upper() in [
+                                DaprWorkflowStatus.UNKNOWN.value.upper()
+                            ]:
+                                logger.debug(
+                                    f"Workflow {instance_id} Dapr status is unknown, skipping"
+                                )
                                 continue
                             else:
                                 # This is for COMPLETED, FAILED, TERMINATED
-                                self._mark_workflow_completed(instance_id, runtime_status)
+                                self._mark_workflow_completed(
+                                    instance_id, runtime_status
+                                )
                         else:
-                            logger.debug(f"Workflow {instance_id} no longer exists in Dapr")
+                            logger.debug(
+                                f"Workflow {instance_id} no longer exists in Dapr"
+                            )
                             # Mark as completed in our state since it's no longer in Dapr
                             self._mark_workflow_completed(instance_id)
                     except Exception as e:
-                        logger.warning(f"Could not verify workflow {instance_id} in Dapr: {e}")
+                        logger.warning(
+                            f"Could not verify workflow {instance_id} in Dapr: {e}"
+                        )
                         return instance_id
-            
+
             logger.debug("No incomplete workflows found")
             return None
-            
+
         except Exception as e:
             logger.error(f"Error finding incomplete workflows: {e}")
             return None

-    def _mark_workflow_completed(self, instance_id: str, status: str = "completed") -> None:
+    def _mark_workflow_completed(
+        self, instance_id: str, status: str = "completed"
+    ) -> None:
         """
         Mark a workflow as completed in our state.
-        
+
         Args:
             instance_id: The workflow instance ID to mark as completed
             status: The completion status (default: "completed")
@@ -284,33 +314,45 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             metadata = message.get("_message_metadata", {}) or {}
             # Extract workflow_instance_id from TriggerAction if present (sent by the orchestrator)
             if "workflow_instance_id" in message:
-                metadata["triggering_workflow_instance_id"] = message["workflow_instance_id"]
+                metadata["triggering_workflow_instance_id"] = message[
+                    "workflow_instance_id"
+                ]
         else:
             task = getattr(message, "task", None)
             metadata = getattr(message, "_message_metadata", {}) or {}
             # Extract workflow_instance_id from TriggerAction if present (sent by the orchestrator)
             if hasattr(message, "workflow_instance_id"):
-                metadata["triggering_workflow_instance_id"] = getattr(message, "workflow_instance_id")
+                metadata["triggering_workflow_instance_id"] = getattr(
+                    message, "workflow_instance_id"
+                )

         workflow_instance_id = ctx.instance_id
-        triggering_workflow_instance_id = metadata.get("triggering_workflow_instance_id")
+        triggering_workflow_instance_id = metadata.get(
+            "triggering_workflow_instance_id"
+        )
         source = metadata.get("source")
-        
+
         # Set default source if not provided (for direct run() calls)
         if not source:
             source = "direct"
-        
-        print(f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
-        print(f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}")
+
+        print(
+            f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
+        )
+        print(
+            f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}"
+        )
         print(f"DEBUG: Current self.state at start of workflow: {self.state}")
-        logger.info(f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
+        logger.info(
+            f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
+        )
         logger.info(f"Current self.state at start of workflow: {self.state}")
-        
+
         # Store the instance ID from workflow context as the source of truth
         # This will be used by graceful shutdown logic to save incomplete instances
         logger.info(f"Workflow context provided instance ID: {workflow_instance_id}")
         print(f"DEBUG: Workflow context provided instance ID: {workflow_instance_id}")
-        
+
         # Check if this instance already exists in state (from previous runs)
         if workflow_instance_id not in self.state.get("instances", {}):
             # This is a new instance, so create a minimal entry
@@ -324,19 +366,29 @@ class DurableAgent(AgenticWorkflow, AgentBase):
                 "workflow_instance_id": workflow_instance_id,
                 "triggering_workflow_instance_id": triggering_workflow_instance_id,
                 "workflow_name": self._workflow_name,
-                "status": "running"  # Mark as running
+                "status": "running",  # Mark as running
             }
-            self.state.setdefault("instances", {})[workflow_instance_id] = instance_entry
+            self.state.setdefault("instances", {})[
+                workflow_instance_id
+            ] = instance_entry
             logger.info(f"Created new instance entry: {workflow_instance_id}")
             print(f"DEBUG: Created new instance entry: {workflow_instance_id}")
-            
+
             # Immediately save state so graceful shutdown can capture this instance
             self.save_state()
-            logger.info(f"Saved state immediately after creating instance entry for {workflow_instance_id}")
-            print(f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}")
+            logger.info(
+                f"Saved state immediately after creating instance entry for {workflow_instance_id}"
+            )
+            print(
+                f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}"
+            )
         else:
-            logger.info(f"Found existing instance entry for workflow {workflow_instance_id}")
-            print(f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}")
+            logger.info(
+                f"Found existing instance entry for workflow {workflow_instance_id}"
+            )
+            print(
+                f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}"
+            )

         final_message: Optional[Dict[str, Any]] = None
         if not ctx.is_replaying:
@@ -371,7 +423,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             # Step 5: Add the assistant's response message to the chat history
             yield ctx.call_activity(
                 self.append_assistant_message,
-                input={"instance_id": workflow_instance_id, "message": response_message},
+                input={
+                    "instance_id": workflow_instance_id,
+                    "message": response_message,
+                },
             )

             # Step 6: Handle tool calls response
@@ -391,7 +446,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
                     for tr in tool_results:
                         yield ctx.call_activity(
                             self.append_tool_message,
-                            input={"instance_id": workflow_instance_id, "tool_result": tr},
+                            input={
+                                "instance_id": workflow_instance_id,
+                                "tool_result": tr,
+                            },
                         )
                 # 🔴 If this was the last turn, stop here—even though there were tool calls
                 if turn == self.max_iterations:
@@ -512,7 +570,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         """
         # Load state to ensure we have the latest data
         self.load_state()
-        
+
         workflow_entry = self.state.get("instances", {}).get(instance_id)
         if workflow_entry is not None:
             return {
@@ -536,7 +594,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         #     }
        #     self.state.setdefault("instances", {})[instance_id] = minimal_entry
        #     self.save_state()
-        
+
        #     return {
        #         "source": minimal_entry.get("source"),
        #         "source_workflow_instance_id": minimal_entry.get("source_workflow_instance_id"),
@@ -560,7 +618,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         """
         # Construct messages using instance-specific chat history instead of global memory
         # This ensures proper message sequence for tool calls
-        messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(instance_id, task or {})
+        messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(
+            instance_id, task or {}
+        )
         user_message = self.get_last_message_if_user(messages)

         # Always work with a copy of the user message for safety
@@ -582,7 +642,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             "source": "user_input",
             "workflow_instance_id": instance_id,
             "triggering_workflow_instance_id": None,
-            "workflow_name": self._workflow_name
+            "workflow_name": self._workflow_name,
         }
         inst: dict = self.state["instances"][instance_id]
         inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@@ -622,15 +682,19 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         except Exception as e:
             error_type = type(e).__name__
             error_msg = str(e)
-            
-            logger.error(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}")
+
+            logger.error(
+                f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
+            )
             logger.error(f"Task: {task}")
             logger.error(f"Messages count: {len(messages)}")
             logger.error(f"Tools available: {len(self.get_llm_tools())}")
             logger.error("Full error details:", exc_info=True)
-            raise AgentError(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}") from e
-        
+            raise AgentError(
+                f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
+            ) from e
+
     @task
     async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -733,7 +797,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             "source": "user_input",
             "workflow_instance_id": instance_id,
             "triggering_workflow_instance_id": None,
-            "workflow_name": self._workflow_name
+            "workflow_name": self._workflow_name,
         }
         inst: dict = self.state["instances"][instance_id]
         inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@@ -774,7 +838,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             "source": "user_input",
             "workflow_instance_id": instance_id,
             "triggering_workflow_instance_id": None,
-            "workflow_name": self._workflow_name
+            "workflow_name": self._workflow_name,
         }
         inst: dict = self.state["instances"][instance_id]
         inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@@ -808,7 +872,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
             "source": "user_input",
             "workflow_instance_id": instance_id,
             "triggering_workflow_instance_id": None,
-            "workflow_name": self._workflow_name
+            "workflow_name": self._workflow_name,
         }
         inst: dict = self.state["instances"][instance_id]
         inst["output"] = final_output
@@ -870,26 +934,30 @@ class DurableAgent(AgenticWorkflow, AgentBase):
         except Exception as e:
             logger.error(f"Error processing broadcast message: {e}", exc_info=True)

-    def _construct_messages_with_instance_history(self, instance_id: str, input_data: Union[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
+    def _construct_messages_with_instance_history(
+        self, instance_id: str, input_data: Union[str, Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
         """
         Construct messages using instance-specific chat history instead of global memory.
         This ensures proper message sequence for tool calls and prevents OpenAI API errors
         in the event an app gets terminated or restarts while the workflow is running.
-        
+
         Args:
             instance_id: The workflow instance ID
             input_data: User input, either as a string or dictionary
-        
+
         Returns:
             List of formatted messages with proper sequence
         """
         if not self.prompt_template:
-            raise ValueError("Prompt template must be initialized before constructing messages.")
-        
+            raise ValueError(
+                "Prompt template must be initialized before constructing messages."
+            )
+
         # Get instance-specific chat history instead of global memory
         instance_data = self.state.get("instances", {}).get(instance_id, {})
         instance_messages = instance_data.get("messages", [])
-        
+
         # Convert instance messages to the format expected by prompt template
         chat_history = []
         for msg in instance_messages:
@@ -897,10 +965,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
                 chat_history.append(msg)
             else:
                 # Convert DurableAgentMessage to dict if needed
-                chat_history.append(msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg))
-            
+                chat_history.append(
+                    msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
+                )
+
         if isinstance(input_data, str):
-            formatted_messages = self.prompt_template.format_prompt(chat_history=chat_history)
+            formatted_messages = self.prompt_template.format_prompt(
+                chat_history=chat_history
+            )
             if isinstance(formatted_messages, list):
                 user_message = {"role": "user", "content": input_data}
                 return formatted_messages + [user_message]
diff --git a/dapr_agents/llm/openai/chat.py b/dapr_agents/llm/openai/chat.py
index 3953fa3..ad7fa90 100644
--- a/dapr_agents/llm/openai/chat.py
+++ b/dapr_agents/llm/openai/chat.py
@@ -225,9 +225,8 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
         except Exception as e:
             error_type = type(e).__name__
             error_msg = str(e)
-            
+
             logger.error(f"OpenAI ChatCompletion API error: {error_type} - {error_msg}")
             logger.error("Full error details:", exc_info=True)
             raise ValueError(f"OpenAI API error ({error_type}): {error_msg}") from e
-
diff --git a/dapr_agents/observability/context_storage.py b/dapr_agents/observability/context_storage.py
index d7ad8ef..edcbf2e 100644
--- a/dapr_agents/observability/context_storage.py
+++ b/dapr_agents/observability/context_storage.py
@@ -95,33 +95,38 @@ class WorkflowContextStorage:
             logger.warning(f"⚠️ No context found for instance {instance_id}")
         return context

-    def create_resumed_workflow_context(self, instance_id: str, agent_name: Optional[str] = None) -> Dict[str, Any]:
+    def create_resumed_workflow_context(
+        self, instance_id: str, agent_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         """
         Create a new trace context for a resumed workflow after app restart.
-        
+
         When an app restarts, the in-memory context storage is lost. This method
         creates a new trace context for resumed workflows so they can still be
         traced, even though they won't be connected to the original trace.
-        
+
         Args:
             instance_id (str): Unique workflow instance ID
-        
+
         Returns:
             Dict[str, Any]: New W3C context data for the resumed workflow
         """
         try:
             from opentelemetry import trace
-            from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-            
+            from opentelemetry.trace.propagation.tracecontext import (
+                TraceContextTextMapPropagator,
+            )
+
             # Create a new trace for the resumed workflow with proper AGENT span
             tracer = trace.get_tracer(__name__)
-            
+
             # Create AGENT span with proper agent name for resumed workflow
             agent_display_name = agent_name or "DurableAgent"
             span_name = f"{agent_display_name}.ToolCallingWorkflow"
             with tracer.start_as_current_span(span_name) as span:
                 # Set AGENT span attributes
                 from .constants import OPENINFERENCE_SPAN_KIND
+
                 span.set_attribute(OPENINFERENCE_SPAN_KIND, "AGENT")
                 span.set_attribute("workflow.instance_id", instance_id)
                 span.set_attribute("workflow.resumed", True)
@@ -130,29 +135,33 @@ class WorkflowContextStorage:
                 propagator = TraceContextTextMapPropagator()
                 carrier = {}
                 propagator.inject(carrier)
-                
+
                 context_data = {
                     "traceparent": carrier.get("traceparent"),
                     "tracestate": carrier.get("tracestate"),
                     "instance_id": instance_id,
                     "resumed": True,
-                    "debug_info": f"New trace created for resumed workflow {instance_id}"
+                    "debug_info": f"New trace created for resumed workflow {instance_id}",
                 }
-                
+
                 # Store the new context
                 self.store_context(instance_id, context_data)
-                logger.info(f"Created new trace context for resumed workflow {instance_id}")
-                
+                logger.info(
+                    f"Created new trace context for resumed workflow {instance_id}"
+                )
+
                 return context_data
-                
+
         except Exception as e:
-            logger.error(f"Failed to create resumed workflow context for {instance_id}: {e}")
+            logger.error(
+                f"Failed to create resumed workflow context for {instance_id}: {e}"
+            )
             return {
                 "traceparent": None,
                 "tracestate": None,
                 "instance_id": instance_id,
                 "resumed": True,
-                "error": str(e)
+                "error": str(e),
             }

     def cleanup_context(self, instance_id: str) -> None:
diff --git a/dapr_agents/observability/instrumentor.py b/dapr_agents/observability/instrumentor.py
index 374fdd8..a7ddf86 100644
--- a/dapr_agents/observability/instrumentor.py
+++ b/dapr_agents/observability/instrumentor.py
@@ -424,28 +424,28 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
             from dapr_agents.workflow.base import WorkflowApp
             from dapr_agents.workflow.task import WorkflowTask
             import wrapt
-            
+
             # Workflow monitoring wrapper - creates AGENT spans
             wrapt.wrap_function_wrapper(
                 WorkflowApp.__module__,
-                f'{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}',
-                WorkflowMonitorWrapper(self._tracer)
+                f"{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}",
+                WorkflowMonitorWrapper(self._tracer),
             )
-            
+
             # Workflow run wrapper - creates workflow spans
             wrapt.wrap_function_wrapper(
                 WorkflowApp.__module__,
-                f'{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}',
-                WorkflowRunWrapper(self._tracer)
+                f"{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}",
+                WorkflowRunWrapper(self._tracer),
             )
-            
+
             # WorkflowTask call wrapper
             # This ensures child spans (LLM/TOOL) are properly linked to parent AGENT spans,
             # and is necessary due to the async nature of the WorkflowTask.__call__ method.
            wrapt.wrap_function_wrapper(
                 WorkflowTask.__module__,
-                f'{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}',
-                WorkflowTaskWrapper(self._tracer)
+                f"{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}",
+                WorkflowTaskWrapper(self._tracer),
             )
         except Exception as e:
             logger.error(f"Error applying workflow wrappers: {e}", exc_info=True)
diff --git a/dapr_agents/observability/wrappers/workflow.py b/dapr_agents/observability/wrappers/workflow.py
index 062deda..d23172a 100644
--- a/dapr_agents/observability/wrappers/workflow.py
+++ b/dapr_agents/observability/wrappers/workflow.py
@@ -341,11 +341,17 @@ class WorkflowMonitorWrapper:
             from ..context_storage import store_workflow_context

             captured_context = extract_otel_context()
-            if captured_context.get("traceparent") and captured_context.get("trace_id") != "00000000000000000000000000000000":
+            if (
+                captured_context.get("traceparent")
+                and captured_context.get("trace_id")
+                != "00000000000000000000000000000000"
+            ):
                 logger.debug(
                     f"Captured traceparent: {captured_context.get('traceparent')}"
                 )
-                store_workflow_context("__global_workflow_context__", captured_context)
+                store_workflow_context(
+                    "__global_workflow_context__", captured_context
+                )
             else:
                 logger.debug(
                     f"Invalid or empty trace context captured: {captured_context}"
                 )
diff --git a/dapr_agents/observability/wrappers/workflow_task.py b/dapr_agents/observability/wrappers/workflow_task.py
index f48b737..7f65f4c 100644
--- a/dapr_agents/observability/wrappers/workflow_task.py
+++ b/dapr_agents/observability/wrappers/workflow_task.py
@@ -85,14 +85,14 @@ class WorkflowTaskWrapper:
         # Determine task details
         logger.debug(f"WorkflowTaskWrapper: instance type = {type(instance)}")
         logger.debug(f"WorkflowTaskWrapper: instance attributes = {dir(instance)}")
-        if hasattr(instance, 'func'):
+        if hasattr(instance, "func"):
             logger.debug(f"WorkflowTaskWrapper: instance.func = {instance.func}")
         else:
             logger.debug("WorkflowTaskWrapper: instance has no 'func' attribute")
-        
+
         task_name = (
             getattr(instance.func, "__name__", "unknown_task")
-            if hasattr(instance, 'func') and instance.func
+            if hasattr(instance, "func") and instance.func
             else "workflow_task"
         )
         span_kind = self._determine_span_kind(instance, task_name)
@@ -245,7 +245,13 @@ class WorkflowTaskWrapper:
         return attributes

     def _handle_async_execution(
-        self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
+        self,
+        wrapped: Any,
+        instance: Any,
+        args: Any,
+        kwargs: Any,
+        span_name: str,
+        attributes: dict,
     ) -> Any:
         """
         Handle asynchronous workflow task execution with OpenTelemetry context restoration.
@@ -276,32 +282,48 @@ class WorkflowTaskWrapper:
             from ..context_storage import get_workflow_context

             otel_context = get_workflow_context(instance_id)
-            
+
             # If no context found for specific instance, try global context as fallback
             if otel_context is None:
-                logger.debug(f"No context found for instance {instance_id}, trying global context")
+                logger.debug(
+                    f"No context found for instance {instance_id}, trying global context"
+                )
                 otel_context = get_workflow_context("__global_workflow_context__")
-                
+
                 if otel_context:
                     # Store the global context with the specific instance ID for future use
                     from ..context_storage import store_workflow_context
+
                     store_workflow_context(instance_id, otel_context)
                     logger.debug(f"Copied global context to instance {instance_id}")
                 else:
                     # If still no context found (e.g., after app restart), create a new one for resumed workflows
-                    logger.debug(f"No context found for instance {instance_id} - creating new context for resumed workflow")
-                    
+                    logger.debug(
+                        f"No context found for instance {instance_id} - creating new context for resumed workflow"
+                    )
+
                     # Try to get agent name from the task instance
                     agent_name = None
-                    if hasattr(instance, 'agent') and instance.agent and hasattr(instance.agent, 'name'):
+                    if (
+                        hasattr(instance, "agent")
+                        and instance.agent
+                        and hasattr(instance.agent, "name")
+                    ):
                         agent_name = instance.agent.name
-                    elif hasattr(instance, 'func') and instance.func and hasattr(instance.func, '__self__'):
+                    elif (
+                        hasattr(instance, "func")
+                        and instance.func
+                        and hasattr(instance.func, "__self__")
+                    ):
                         agent_instance = instance.func.__self__
-                        if hasattr(agent_instance, 'name'):
+                        if hasattr(agent_instance, "name"):
                             agent_name = agent_instance.name
-                    
+
                     from ..context_storage import _context_storage
-                    otel_context = _context_storage.create_resumed_workflow_context(instance_id, agent_name)
+
+                    otel_context = _context_storage.create_resumed_workflow_context(
+                        instance_id, agent_name
+                    )

             # Create span with restored context if available
             from ..context_propagation import create_child_span_with_context
@@ -348,7 +370,13 @@ class WorkflowTaskWrapper:
         return async_wrapper()

     def _handle_sync_execution(
-        self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
+        self,
+        wrapped: Any,
+        instance: Any,
+        args: Any,
+        kwargs: Any,
+        span_name: str,
+        attributes: dict,
     ) -> Any:
         """
         Handle synchronous workflow task execution with OpenTelemetry context restoration.
diff --git a/dapr_agents/service/fastapi/base.py b/dapr_agents/service/fastapi/base.py
index 980775f..0dbcd7b 100644
--- a/dapr_agents/service/fastapi/base.py
+++ b/dapr_agents/service/fastapi/base.py
@@ -128,7 +128,7 @@ class FastAPIServerBase(APIServerBase, SignalHandlingMixin):
         Perform graceful shutdown operations for the FastAPI server.
         """
         await self.stop()
-        
+
     async def stop(self):
         """
         Stop the FastAPI server gracefully.
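
Editor's note for reviewers: the context-restoration logic in workflow_task.py above is split across several hunks, so the intended order of its three fallbacks (per-instance context, the global context captured at workflow start, and a freshly created context for workflows resumed after an app restart) is easy to lose. The following is a minimal, self-contained Python sketch of that lookup chain; the in-memory store and the resolve_context helper are illustrative stand-ins, not code from this patch.

# Sketch of the lookup chain in WorkflowTaskWrapper; the dict below is a
# stand-in for the real context_storage.py module.
from typing import Any, Dict, Optional

_contexts: Dict[str, Dict[str, Any]] = {}
GLOBAL_KEY = "__global_workflow_context__"

def get_workflow_context(key: str) -> Optional[Dict[str, Any]]:
    return _contexts.get(key)

def store_workflow_context(key: str, ctx: Dict[str, Any]) -> None:
    _contexts[key] = ctx

def create_resumed_workflow_context(instance_id: str) -> Dict[str, Any]:
    # The real method also starts a new AGENT span; here we only fabricate
    # the context record it would store.
    ctx = {"instance_id": instance_id, "resumed": True}
    store_workflow_context(instance_id, ctx)
    return ctx

def resolve_context(instance_id: str) -> Dict[str, Any]:
    # 1. Context captured for this exact instance (the normal path).
    ctx = get_workflow_context(instance_id)
    if ctx is None:
        # 2. Global context captured by run_and_monitor_workflow_async.
        ctx = get_workflow_context(GLOBAL_KEY)
        if ctx is not None:
            store_workflow_context(instance_id, ctx)  # cache for later tasks
        else:
            # 3. App restarted: in-memory storage is gone, so start a new
            #    trace for the resumed workflow, disconnected from the original.
            ctx = create_resumed_workflow_context(instance_id)
    return ctx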
diff --git a/dapr_agents/types/llm.py b/dapr_agents/types/llm.py
index 305cb93..65c0336 100644
--- a/dapr_agents/types/llm.py
+++ b/dapr_agents/types/llm.py
@@ -17,7 +17,6 @@ class ElevenLabsClientConfig(BaseModel):
     )


-
 class NVIDIAClientConfig(BaseModel):
     base_url: Optional[str] = Field(
         "https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA API"
@@ -27,7 +26,6 @@ class NVIDIAClientConfig(BaseModel):
     )


-
 class DaprInferenceClientConfig:
     pass

@@ -73,7 +71,6 @@ class HFInferenceClientConfig(BaseModel):
     )


-
 class OpenAIClientConfig(BaseModel):
     base_url: Optional[str] = Field(None, description="Base URL for the OpenAI API")
     api_key: Optional[str] = Field(
@@ -112,7 +109,6 @@ class AzureOpenAIClientConfig(BaseModel):
     )


-
 class OpenAIModelConfig(OpenAIClientConfig):
     type: Literal["openai"] = Field(
         "openai", description="Type of the model, must always be 'openai'"
@@ -170,7 +166,6 @@ class OpenAIParamsBase(BaseModel):
     stream: Optional[bool] = Field(False, description="Whether to stream responses")


-
 class OpenAITextCompletionParams(OpenAIParamsBase):
     """
     Specific configs for the text completions endpoint.
@@ -186,7 +181,6 @@ class OpenAITextCompletionParams(OpenAIParamsBase):
     suffix: Optional[str] = Field(None, description="Suffix to append to the prompt")


-
 class OpenAIChatCompletionParams(OpenAIParamsBase):
     """
     Specific settings for the Chat Completion endpoint.
@@ -222,7 +216,6 @@ class OpenAIChatCompletionParams(OpenAIParamsBase):
     )


-
 class HFHubChatCompletionParams(BaseModel):
     """
     Specific settings for Hugging Face Hub Chat Completion endpoint.
@@ -285,7 +278,6 @@ class HFHubChatCompletionParams(BaseModel):
     )


-
 class NVIDIAChatCompletionParams(OpenAIParamsBase):
     """
     Specific settings for the Chat Completion endpoint.
@@ -311,7 +303,6 @@ class NVIDIAChatCompletionParams(OpenAIParamsBase):
     )


-
 class PromptyModelConfig(BaseModel):
     api: Literal["chat", "completion"] = Field(
         "chat", description="The API to use, either 'chat' or 'completion'"
@@ -330,7 +321,6 @@ class PromptyModelConfig(BaseModel):
         description="Determines if full response or just the first one is returned",
     )

-
     @model_validator(mode="before")
     def sync_model_name(cls, values: dict):
         """
@@ -405,7 +395,6 @@ class PromptyDefinition(BaseModel):
     )


-
 class AudioSpeechRequest(BaseModel):
     model: Optional[Literal["tts-1", "tts-1-hd"]] = Field(
         "tts-1", description="TTS model to use. Defaults to 'tts-1'."
diff --git a/dapr_agents/utils/signal_mixin.py b/dapr_agents/utils/signal_mixin.py
new file mode 100644
index 0000000..11d01a9
--- /dev/null
+++ b/dapr_agents/utils/signal_mixin.py
@@ -0,0 +1,121 @@
+"""
+Reusable signal handling mixin for graceful shutdown across different service types.
+"""
+import asyncio
+import logging
+from typing import Optional
+
+from dapr_agents.utils import add_signal_handlers_cross_platform
+
+logger = logging.getLogger(__name__)
+
+
+class SignalHandlingMixin:
+    """
+    Mixin providing reusable signal handling for graceful shutdown.
+
+    This mixin can be used by any class that needs to handle shutdown signals
+    (SIGINT, SIGTERM) gracefully. It provides a consistent interface for:
+    - Setting up signal handlers
+    - Managing shutdown events
+    - Triggering graceful shutdown logic
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._shutdown_event: Optional[asyncio.Event] = None
+        self._signal_handlers_setup = False
+
+    def setup_signal_handlers(self) -> None:
+        """
+        Set up signal handlers for graceful shutdown.
+
+        This method should be called during initialization or startup
+        to enable graceful shutdown handling.
+        """
+        # Initialize the attribute if it doesn't exist
+        if not hasattr(self, "_signal_handlers_setup"):
+            self._signal_handlers_setup = False
+
+        if self._signal_handlers_setup:
+            logger.debug("Signal handlers already set up")
+            return
+
+        # Initialize shutdown event if it doesn't exist
+        if not hasattr(self, "_shutdown_event") or self._shutdown_event is None:
+            self._shutdown_event = asyncio.Event()
+
+        # Set up signal handlers
+        loop = asyncio.get_event_loop()
+        add_signal_handlers_cross_platform(loop, self._handle_shutdown_signal)
+
+        self._signal_handlers_setup = True
+        logger.debug("Signal handlers set up for graceful shutdown")
+
+    def _handle_shutdown_signal(self, sig: int) -> None:
+        """
+        Internal signal handler that triggers graceful shutdown.
+
+        Args:
+            sig: The received signal number
+        """
+        logger.debug(f"Shutdown signal {sig} received. Triggering graceful shutdown...")
+
+        # Set the shutdown event
+        if self._shutdown_event:
+            self._shutdown_event.set()
+
+        # Call the graceful shutdown method if it exists
+        if hasattr(self, "graceful_shutdown"):
+            asyncio.create_task(self.graceful_shutdown())
+        elif hasattr(self, "stop"):
+            # Fallback to stop() method if graceful_shutdown doesn't exist
+            asyncio.create_task(self.stop())
+        else:
+            logger.warning(
+                "No graceful shutdown method found. Implement graceful_shutdown() or stop() method."
+            )
+
+    async def graceful_shutdown(self) -> None:
+        """
+        Perform graceful shutdown operations.
+
+        This method should be overridden by classes that use this mixin
+        to implement their specific shutdown logic.
+
+        Default implementation calls stop() if it exists.
+        """
+        if hasattr(self, "stop"):
+            await self.stop()
+        else:
+            logger.warning(
+                "No stop() method found. Override graceful_shutdown() to implement shutdown logic."
+            )
+
+    def is_shutdown_requested(self) -> bool:
+        """
+        Check if a shutdown has been requested.
+
+        Returns:
+            bool: True if shutdown has been requested, False otherwise
+        """
+        return (
+            hasattr(self, "_shutdown_event")
+            and self._shutdown_event is not None
+            and self._shutdown_event.is_set()
+        )
+
+    async def wait_for_shutdown(self, check_interval: float = 1.0) -> None:
+        """
+        Wait for a shutdown signal to be received.
+
+        Args:
+            check_interval: How often to check for shutdown (in seconds)
+        """
+        if not hasattr(self, "_shutdown_event") or self._shutdown_event is None:
+            raise RuntimeError(
+                "Signal handlers not set up. Call setup_signal_handlers() first."
+            )
+
+        while not self._shutdown_event.is_set():
+            await asyncio.sleep(check_interval)
diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py
index d3db359..1afddc3 100644
--- a/dapr_agents/workflow/base.py
+++ b/dapr_agents/workflow/base.py
@@ -556,7 +556,10 @@ class WorkflowApp(BaseModel):
                 f"Output: {json.dumps(state.serialized_output, indent=2)}"
             )

-        elif workflow_status.upper() in (DaprWorkflowStatus.FAILED.value.upper(), "ABORTED"):
+        elif workflow_status.upper() in (
+            DaprWorkflowStatus.FAILED.value.upper(),
+            "ABORTED",
+        ):
             # Ensure `failure_details` exists before accessing attributes
             error_type = getattr(failure_details, "error_type", "Unknown")
             message = getattr(failure_details, "message", "No message provided")
@@ -611,7 +614,9 @@ class WorkflowApp(BaseModel):
         # Off-load the potentially blocking run_workflow call to a thread.
        instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
-        logger.debug(f"Workflow '{workflow}' started with instance ID: {instance_id}")
+        logger.debug(
+            f"Workflow '{workflow}' started with instance ID: {instance_id}"
+        )

         # Await the asynchronous monitoring of the workflow state.
         state = await self.monitor_workflow_state(instance_id)
diff --git a/dapr_agents/workflow/mixins/service.py b/dapr_agents/workflow/mixins/service.py
index 934b86e..467baed 100644
--- a/dapr_agents/workflow/mixins/service.py
+++ b/dapr_agents/workflow/mixins/service.py
@@ -138,30 +138,44 @@ class ServiceMixin(SignalHandlingMixin):

         # Save state before shutting down to ensure persistence and agent durability, so the agent can properly rerun after being stopped
         try:
-            if hasattr(self, 'save_state') and hasattr(self, 'state'):
+            if hasattr(self, "save_state") and hasattr(self, "state"):
                 # Graceful shutdown compensation: Save incomplete instance if it exists
-                if hasattr(self, 'workflow_instance_id') and self.workflow_instance_id:
+                if hasattr(self, "workflow_instance_id") and self.workflow_instance_id:
                     if self.workflow_instance_id not in self.state.get("instances", {}):
                         # This instance was never saved, add it as incomplete
                         from datetime import datetime, timezone
+
                         incomplete_entry = {
                             "messages": [],
                             "start_time": datetime.now(timezone.utc).isoformat(),
                             "source": "graceful_shutdown",
                             "source_workflow_instance_id": None,
-                            "workflow_name": getattr(self, '_workflow_name', 'Unknown'),
+                            "workflow_name": getattr(self, "_workflow_name", "Unknown"),
                             "dapr_status": DaprWorkflowStatus.PENDING,
-                            "suspended_reason": "app_terminated"
+                            "suspended_reason": "app_terminated",
                         }
-                        self.state.setdefault("instances", {})[self.workflow_instance_id] = incomplete_entry
-                        logger.info(f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown")
+                        self.state.setdefault("instances", {})[
+                            self.workflow_instance_id
+                        ] = incomplete_entry
+                        logger.info(
+                            f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown"
+                        )
                     else:
                         # Mark existing instance as suspended due to app termination
-                        if "instances" in self.state and self.workflow_instance_id in self.state["instances"]:
-                            self.state["instances"][self.workflow_instance_id]["dapr_status"] = DaprWorkflowStatus.SUSPENDED
-                            self.state["instances"][self.workflow_instance_id]["suspended_reason"] = "app_terminated"
-                            logger.info(f"Marked instance {self.workflow_instance_id} as suspended due to app termination")
-                        
+                        if (
+                            "instances" in self.state
+                            and self.workflow_instance_id in self.state["instances"]
+                        ):
+                            self.state["instances"][self.workflow_instance_id][
+                                "dapr_status"
+                            ] = DaprWorkflowStatus.SUSPENDED
+                            self.state["instances"][self.workflow_instance_id][
+                                "suspended_reason"
+                            ] = "app_terminated"
+                            logger.info(
+                                f"Marked instance {self.workflow_instance_id} as suspended due to app termination"
+                            )
+
                 self.save_state()
                 logger.debug("Workflow state saved successfully.")
         except Exception as e:
diff --git a/dapr_agents/workflow/mixins/state.py b/dapr_agents/workflow/mixins/state.py
index 6314ab4..3ade63f 100644
--- a/dapr_agents/workflow/mixins/state.py
+++ b/dapr_agents/workflow/mixins/state.py
@@ -106,16 +106,16 @@ class StateManagementMixin:
                 raise TypeError(
                     f"Invalid state type retrieved: {type(state_data)}. Expected dict."
                )
-        
+
             # Set self.state to the loaded data
             if self.state_format:
                 loaded_state = self.validate_state(state_data)
             else:
                 loaded_state = state_data
-        
+
             self.state = loaded_state
             logger.debug(f"Set self.state to loaded data: {self.state}")
-        
+
             return loaded_state

         logger.info(
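
For reviewers evaluating the new SignalHandlingMixin, a minimal usage sketch follows. EchoService, its stop() body, and the run loop are hypothetical; only the mixin interface (setup_signal_handlers, graceful_shutdown falling back to stop, and wait_for_shutdown) comes from dapr_agents/utils/signal_mixin.py in this patch.

import asyncio
import logging

from dapr_agents.utils.signal_mixin import SignalHandlingMixin

logging.basicConfig(level=logging.INFO)


class EchoService(SignalHandlingMixin):
    """Hypothetical service used only to demonstrate the mixin."""

    async def stop(self) -> None:
        # Release resources here; the mixin's default graceful_shutdown()
        # falls back to this method when no override is provided.
        logging.info("EchoService stopped cleanly")

    async def run(self) -> None:
        # Must be called from a running event loop so the mixin can
        # register SIGINT/SIGTERM handlers on it.
        self.setup_signal_handlers()
        # Block until a shutdown signal sets the internal event.
        await self.wait_for_shutdown()


if __name__ == "__main__":
    asyncio.run(EchoService().run())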