style: appease linter

Signed-off-by: Samantha Coyle <sam@diagrid.io>
Samantha Coyle 2025-09-08 13:58:29 -05:00
parent 728be553ee
commit 8601054a0b
12 changed files with 382 additions and 139 deletions

View File

@@ -55,7 +55,6 @@ class DurableAgent(AgenticWorkflow, AgentBase):
description="Metadata about the agent, including name, role, goal, instructions, and topic name.",
)
@model_validator(mode="before")
def set_agent_and_topic_name(cls, values: dict):
# Set name to role if name is not provided
@@ -92,10 +91,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
logger.info(f"Found {len(self.state['instances'])} instances in state")
for instance_id, instance_data in self.state["instances"].items():
stored_workflow_name = instance_data.get("workflow_name")
logger.info(f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}")
logger.info(
f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}"
)
if stored_workflow_name == self._workflow_name:
self.workflow_instance_id = instance_id
logger.info(f"Loaded current workflow instance ID from state: {instance_id}")
logger.info(
f"Loaded current workflow instance ID from state: {instance_id}"
)
break
else:
logger.info("No instances found in state or state is empty")
@@ -137,20 +140,28 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Check for existing incomplete workflows before starting a new one
existing_instance_id = await self._find_incomplete_workflow(input_data)
if existing_instance_id:
logger.info(f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it.")
logger.info(
f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it."
)
logger.info("Monitoring the resumed workflow until completion...")
# Monitor the resumed workflow until completion
try:
result = await self.monitor_workflow_state(existing_instance_id)
if result and result.serialized_output:
logger.info(f"Resumed workflow {existing_instance_id} completed successfully!")
logger.info(
f"Resumed workflow {existing_instance_id} completed successfully!"
)
return result.serialized_output
else:
logger.warning(f"Resumed workflow {existing_instance_id} completed but had no output")
logger.warning(
f"Resumed workflow {existing_instance_id} completed but had no output"
)
return None
except Exception as e:
logger.error(f"Error monitoring resumed workflow {existing_instance_id}: {e}")
logger.error(
f"Error monitoring resumed workflow {existing_instance_id}: {e}"
)
# Fall through to start a new workflow if monitoring fails
# Prepare input payload for workflow
@@ -173,7 +184,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
if self.wf_runtime_is_running:
self.stop_runtime()
async def _find_incomplete_workflow(self, input_data: Union[str, Dict[str, Any]]) -> Optional[str]:
async def _find_incomplete_workflow(
self, input_data: Union[str, Dict[str, Any]]
) -> Optional[str]:
"""
Find an existing incomplete workflow instance that should be resumed.
Uses Dapr WorkflowState to determine if workflow is actually complete.
@@ -218,24 +231,39 @@ class DurableAgent(AgenticWorkflow, AgentBase):
workflow_state = self.get_workflow_state(instance_id)
if workflow_state:
runtime_status = workflow_state.runtime_status.name
logger.debug(f"Workflow {instance_id} Dapr status: {runtime_status}")
logger.debug(
f"Workflow {instance_id} Dapr status: {runtime_status}"
)
# If workflow is still running, return it for resumption
# Handle case mismatch: Dapr returns 'RUNNING' but enum is 'running'
if runtime_status.upper() in [DaprWorkflowStatus.RUNNING.value.upper(), DaprWorkflowStatus.PENDING.value.upper()]:
if runtime_status.upper() in [
DaprWorkflowStatus.RUNNING.value.upper(),
DaprWorkflowStatus.PENDING.value.upper(),
]:
return instance_id
elif runtime_status.upper() in [DaprWorkflowStatus.UNKNOWN.value.upper()]:
logger.debug(f"Workflow {instance_id} Dapr status is unknown, skipping")
elif runtime_status.upper() in [
DaprWorkflowStatus.UNKNOWN.value.upper()
]:
logger.debug(
f"Workflow {instance_id} Dapr status is unknown, skipping"
)
continue
else:
# This is for COMPLETED, FAILED, TERMINATED
self._mark_workflow_completed(instance_id, runtime_status)
self._mark_workflow_completed(
instance_id, runtime_status
)
else:
logger.debug(f"Workflow {instance_id} no longer exists in Dapr")
logger.debug(
f"Workflow {instance_id} no longer exists in Dapr"
)
# Mark as completed in our state since it's no longer in Dapr
self._mark_workflow_completed(instance_id)
except Exception as e:
logger.warning(f"Could not verify workflow {instance_id} in Dapr: {e}")
logger.warning(
f"Could not verify workflow {instance_id} in Dapr: {e}"
)
return instance_id
logger.debug("No incomplete workflows found")
@@ -245,7 +273,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
logger.error(f"Error finding incomplete workflows: {e}")
return None
def _mark_workflow_completed(self, instance_id: str, status: str = "completed") -> None:
def _mark_workflow_completed(
self, instance_id: str, status: str = "completed"
) -> None:
"""
Mark a workflow as completed in our state.
@@ -284,26 +314,38 @@ class DurableAgent(AgenticWorkflow, AgentBase):
metadata = message.get("_message_metadata", {}) or {}
# Extract workflow_instance_id from TriggerAction if present from orchestrator
if "workflow_instance_id" in message:
metadata["triggering_workflow_instance_id"] = message["workflow_instance_id"]
metadata["triggering_workflow_instance_id"] = message[
"workflow_instance_id"
]
else:
task = getattr(message, "task", None)
metadata = getattr(message, "_message_metadata", {}) or {}
# Extract workflow_instance_id from TriggerAction if present from orchestrator
if hasattr(message, "workflow_instance_id"):
metadata["triggering_workflow_instance_id"] = getattr(message, "workflow_instance_id")
metadata["triggering_workflow_instance_id"] = getattr(
message, "workflow_instance_id"
)
workflow_instance_id = ctx.instance_id
triggering_workflow_instance_id = metadata.get("triggering_workflow_instance_id")
triggering_workflow_instance_id = metadata.get(
"triggering_workflow_instance_id"
)
source = metadata.get("source")
# Set default source if not provided (for direct run() calls)
if not source:
source = "direct"
print(f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
print(f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}")
print(
f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
)
print(
f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}"
)
print(f"DEBUG: Current self.state at start of workflow: {self.state}")
logger.info(f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
logger.info(
f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
)
logger.info(f"Current self.state at start of workflow: {self.state}")
# Store the instance ID from workflow context as the source of truth
@@ -324,19 +366,29 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"workflow_instance_id": workflow_instance_id,
"triggering_workflow_instance_id": triggering_workflow_instance_id,
"workflow_name": self._workflow_name,
"status": "running" # Mark as running
"status": "running", # Mark as running
}
self.state.setdefault("instances", {})[workflow_instance_id] = instance_entry
self.state.setdefault("instances", {})[
workflow_instance_id
] = instance_entry
logger.info(f"Created new instance entry: {workflow_instance_id}")
print(f"DEBUG: Created new instance entry: {workflow_instance_id}")
# Immediately save state so graceful shutdown can capture this instance
self.save_state()
logger.info(f"Saved state immediately after creating instance entry for {workflow_instance_id}")
print(f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}")
logger.info(
f"Saved state immediately after creating instance entry for {workflow_instance_id}"
)
print(
f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}"
)
else:
logger.info(f"Found existing instance entry for workflow {workflow_instance_id}")
print(f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}")
logger.info(
f"Found existing instance entry for workflow {workflow_instance_id}"
)
print(
f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}"
)
final_message: Optional[Dict[str, Any]] = None
if not ctx.is_replaying:
@@ -371,7 +423,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Step 5: Add the assistant's response message to the chat history
yield ctx.call_activity(
self.append_assistant_message,
input={"instance_id": workflow_instance_id, "message": response_message},
input={
"instance_id": workflow_instance_id,
"message": response_message,
},
)
# Step 6: Handle tool calls response
@@ -391,7 +446,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
for tr in tool_results:
yield ctx.call_activity(
self.append_tool_message,
input={"instance_id": workflow_instance_id, "tool_result": tr},
input={
"instance_id": workflow_instance_id,
"tool_result": tr,
},
)
# 🔴 If this was the last turn, stop here—even though there were tool calls
if turn == self.max_iterations:
@@ -560,7 +618,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"""
# Construct messages using instance-specific chat history instead of global memory
# This ensures proper message sequence for tool calls
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(instance_id, task or {})
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(
instance_id, task or {}
)
user_message = self.get_last_message_if_user(messages)
# Always work with a copy of the user message for safety
@@ -582,7 +642,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@@ -623,13 +683,17 @@ class DurableAgent(AgenticWorkflow, AgentBase):
error_type = type(e).__name__
error_msg = str(e)
logger.error(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}")
logger.error(
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
)
logger.error(f"Task: {task}")
logger.error(f"Messages count: {len(messages)}")
logger.error(f"Tools available: {len(self.get_llm_tools())}")
logger.error("Full error details:", exc_info=True)
raise AgentError(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}") from e
raise AgentError(
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
) from e
@task
async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
@@ -733,7 +797,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@@ -774,7 +838,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@@ -808,7 +872,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst["output"] = final_output
@@ -870,7 +934,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
except Exception as e:
logger.error(f"Error processing broadcast message: {e}", exc_info=True)
def _construct_messages_with_instance_history(self, instance_id: str, input_data: Union[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
def _construct_messages_with_instance_history(
self, instance_id: str, input_data: Union[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Construct messages using instance-specific chat history instead of global memory.
This ensures proper message sequence for tool calls and prevents OpenAI API errors
@@ -884,7 +950,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
List of formatted messages with proper sequence
"""
if not self.prompt_template:
raise ValueError("Prompt template must be initialized before constructing messages.")
raise ValueError(
"Prompt template must be initialized before constructing messages."
)
# Get instance-specific chat history instead of global memory
instance_data = self.state.get("instances", {}).get(instance_id, {})
@@ -897,10 +965,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
chat_history.append(msg)
else:
# Convert DurableAgentMessage to dict if needed
chat_history.append(msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg))
chat_history.append(
msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
)
if isinstance(input_data, str):
formatted_messages = self.prompt_template.format_prompt(chat_history=chat_history)
formatted_messages = self.prompt_template.format_prompt(
chat_history=chat_history
)
if isinstance(formatted_messages, list):
user_message = {"role": "user", "content": input_data}
return formatted_messages + [user_message]
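Note: the resumption check introduced in _find_incomplete_workflow above reduces to a case-insensitive comparison of Dapr's reported runtime status against the still-active statuses. A minimal standalone sketch of that pattern, with hypothetical helper names (pick_resumable, get_status) standing in for the agent's own methods:

from typing import Callable, Dict, Optional

ACTIVE = {"RUNNING", "PENDING"}  # Dapr statuses worth resuming

def pick_resumable(
    instances: Dict[str, dict], get_status: Callable[[str], Optional[str]]
) -> Optional[str]:
    """Return the first instance ID whose Dapr status says it is still active."""
    for instance_id in instances:
        status = get_status(instance_id)  # e.g. workflow_state.runtime_status.name
        if status is None:
            # Instance no longer known to Dapr; treat it as finished.
            continue
        # Dapr reports 'RUNNING' while DaprWorkflowStatus uses 'running',
        # hence the case-insensitive comparison mirrored from the diff.
        if status.upper() in ACTIVE:
            return instance_id
    return None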

View File

@@ -230,4 +230,3 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
logger.error("Full error details:", exc_info=True)
raise ValueError(f"OpenAI API error ({error_type}): {error_msg}") from e

View File

@@ -95,7 +95,9 @@ class WorkflowContextStorage:
logger.warning(f"⚠️ No context found for instance {instance_id}")
return context
def create_resumed_workflow_context(self, instance_id: str, agent_name: Optional[str] = None) -> Dict[str, Any]:
def create_resumed_workflow_context(
self, instance_id: str, agent_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Create a new trace context for a resumed workflow after app restart.
@@ -111,7 +113,9 @@ class WorkflowContextStorage:
"""
try:
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
# Create a new trace for the resumed workflow with proper AGENT span
tracer = trace.get_tracer(__name__)
@@ -122,6 +126,7 @@ class WorkflowContextStorage:
with tracer.start_as_current_span(span_name) as span:
# Set AGENT span attributes
from .constants import OPENINFERENCE_SPAN_KIND
span.set_attribute(OPENINFERENCE_SPAN_KIND, "AGENT")
span.set_attribute("workflow.instance_id", instance_id)
span.set_attribute("workflow.resumed", True)
@@ -136,23 +141,27 @@ class WorkflowContextStorage:
"tracestate": carrier.get("tracestate"),
"instance_id": instance_id,
"resumed": True,
"debug_info": f"New trace created for resumed workflow {instance_id}"
"debug_info": f"New trace created for resumed workflow {instance_id}",
}
# Store the new context
self.store_context(instance_id, context_data)
logger.info(f"Created new trace context for resumed workflow {instance_id}")
logger.info(
f"Created new trace context for resumed workflow {instance_id}"
)
return context_data
except Exception as e:
logger.error(f"Failed to create resumed workflow context for {instance_id}: {e}")
logger.error(
f"Failed to create resumed workflow context for {instance_id}: {e}"
)
return {
"traceparent": None,
"tracestate": None,
"instance_id": instance_id,
"resumed": True,
"error": str(e)
"error": str(e),
}
def cleanup_context(self, instance_id: str) -> None:

View File

@@ -428,15 +428,15 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
# Workflow monitoring wrapper - creates AGENT spans
wrapt.wrap_function_wrapper(
WorkflowApp.__module__,
f'{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}',
WorkflowMonitorWrapper(self._tracer)
f"{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}",
WorkflowMonitorWrapper(self._tracer),
)
# Workflow run wrapper - creates workflow spans
wrapt.wrap_function_wrapper(
WorkflowApp.__module__,
f'{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}',
WorkflowRunWrapper(self._tracer)
f"{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}",
WorkflowRunWrapper(self._tracer),
)
# WorkflowTask call wrapper
@@ -444,8 +444,8 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
# and is necessary due to the async nature of the WorkflowTask.__call__ method.
wrapt.wrap_function_wrapper(
WorkflowTask.__module__,
f'{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}',
WorkflowTaskWrapper(self._tracer)
f"{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}",
WorkflowTaskWrapper(self._tracer),
)
except Exception as e:
logger.error(f"Error applying workflow wrappers: {e}", exc_info=True)

View File

@@ -341,11 +341,17 @@ class WorkflowMonitorWrapper:
from ..context_storage import store_workflow_context
captured_context = extract_otel_context()
if captured_context.get("traceparent") and captured_context.get("trace_id") != "00000000000000000000000000000000":
if (
captured_context.get("traceparent")
and captured_context.get("trace_id")
!= "00000000000000000000000000000000"
):
logger.debug(
f"Captured traceparent: {captured_context.get('traceparent')}"
)
store_workflow_context("__global_workflow_context__", captured_context)
store_workflow_context(
"__global_workflow_context__", captured_context
)
else:
logger.debug(
f"Invalid or empty trace context captured: {captured_context}"

View File

@@ -85,14 +85,14 @@ class WorkflowTaskWrapper:
# Determine task details
logger.debug(f"WorkflowTaskWrapper: instance type = {type(instance)}")
logger.debug(f"WorkflowTaskWrapper: instance attributes = {dir(instance)}")
if hasattr(instance, 'func'):
if hasattr(instance, "func"):
logger.debug(f"WorkflowTaskWrapper: instance.func = {instance.func}")
else:
logger.debug("WorkflowTaskWrapper: instance has no 'func' attribute")
task_name = (
getattr(instance.func, "__name__", "unknown_task")
if hasattr(instance, 'func') and instance.func
if hasattr(instance, "func") and instance.func
else "workflow_task"
)
span_kind = self._determine_span_kind(instance, task_name)
@@ -245,7 +245,13 @@ class WorkflowTaskWrapper:
return attributes
def _handle_async_execution(
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
self,
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
span_name: str,
attributes: dict,
) -> Any:
"""
Handle asynchronous workflow task execution with OpenTelemetry context restoration.
@@ -279,29 +285,45 @@ class WorkflowTaskWrapper:
# If no context found for specific instance, try global context as fallback
if otel_context is None:
logger.debug(f"No context found for instance {instance_id}, trying global context")
logger.debug(
f"No context found for instance {instance_id}, trying global context"
)
otel_context = get_workflow_context("__global_workflow_context__")
if otel_context:
# Store the global context with the specific instance ID for future use
from ..context_storage import store_workflow_context
store_workflow_context(instance_id, otel_context)
logger.debug(f"Copied global context to instance {instance_id}")
else:
# If still no context found (e.g., after app restart), create a new one for resumed workflows
logger.debug(f"No context found for instance {instance_id} - creating new context for resumed workflow")
logger.debug(
f"No context found for instance {instance_id} - creating new context for resumed workflow"
)
# Try to get agent name from the task instance
agent_name = None
if hasattr(instance, 'agent') and instance.agent and hasattr(instance.agent, 'name'):
if (
hasattr(instance, "agent")
and instance.agent
and hasattr(instance.agent, "name")
):
agent_name = instance.agent.name
elif hasattr(instance, 'func') and instance.func and hasattr(instance.func, '__self__'):
elif (
hasattr(instance, "func")
and instance.func
and hasattr(instance.func, "__self__")
):
agent_instance = instance.func.__self__
if hasattr(agent_instance, 'name'):
if hasattr(agent_instance, "name"):
agent_name = agent_instance.name
from ..context_storage import _context_storage
otel_context = _context_storage.create_resumed_workflow_context(instance_id, agent_name)
otel_context = _context_storage.create_resumed_workflow_context(
instance_id, agent_name
)
# Create span with restored context if available
from ..context_propagation import create_child_span_with_context
@@ -348,7 +370,13 @@ class WorkflowTaskWrapper:
return async_wrapper()
def _handle_sync_execution(
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
self,
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
span_name: str,
attributes: dict,
) -> Any:
"""
Handle synchronous workflow task execution with OpenTelemetry context restoration.
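Note: the context restoration in _handle_async_execution above follows a three-step fallback: instance-specific context, then the global context (cached back onto the instance), then a freshly created context for workflows resumed after a restart. A compact sketch of that chain, with hypothetical callables in place of the storage helpers:

from typing import Callable, Optional

def resolve_otel_context(
    instance_id: str,
    get_ctx: Callable[[str], Optional[dict]],
    store_ctx: Callable[[str, dict], None],
    create_resumed_ctx: Callable[[str, Optional[str]], dict],
    agent_name: Optional[str] = None,
) -> dict:
    """Restore a trace context: instance-specific, then global, then brand new."""
    ctx = get_ctx(instance_id)
    if ctx is None:
        ctx = get_ctx("__global_workflow_context__")
        if ctx is not None:
            # Cache the global context under this instance for future lookups.
            store_ctx(instance_id, ctx)
        else:
            # Nothing stored anywhere, e.g. after an app restart mid-workflow.
            ctx = create_resumed_ctx(instance_id, agent_name)
    return ctx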

View File

@@ -17,7 +17,6 @@ class ElevenLabsClientConfig(BaseModel):
)
class NVIDIAClientConfig(BaseModel):
base_url: Optional[str] = Field(
"https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA API"
@@ -27,7 +26,6 @@ class NVIDIAClientConfig(BaseModel):
)
class DaprInferenceClientConfig:
pass
@@ -73,7 +71,6 @@ class HFInferenceClientConfig(BaseModel):
)
class OpenAIClientConfig(BaseModel):
base_url: Optional[str] = Field(None, description="Base URL for the OpenAI API")
api_key: Optional[str] = Field(
@@ -112,7 +109,6 @@ class AzureOpenAIClientConfig(BaseModel):
)
class OpenAIModelConfig(OpenAIClientConfig):
type: Literal["openai"] = Field(
"openai", description="Type of the model, must always be 'openai'"
@@ -170,7 +166,6 @@ class OpenAIParamsBase(BaseModel):
stream: Optional[bool] = Field(False, description="Whether to stream responses")
class OpenAITextCompletionParams(OpenAIParamsBase):
"""
Specific configs for the text completions endpoint.
@@ -186,7 +181,6 @@ class OpenAITextCompletionParams(OpenAIParamsBase):
suffix: Optional[str] = Field(None, description="Suffix to append to the prompt")
class OpenAIChatCompletionParams(OpenAIParamsBase):
"""
Specific settings for the Chat Completion endpoint.
@@ -222,7 +216,6 @@ class OpenAIChatCompletionParams(OpenAIParamsBase):
)
class HFHubChatCompletionParams(BaseModel):
"""
Specific settings for Hugging Face Hub Chat Completion endpoint.
@@ -285,7 +278,6 @@ class HFHubChatCompletionParams(BaseModel):
)
class NVIDIAChatCompletionParams(OpenAIParamsBase):
"""
Specific settings for the Chat Completion endpoint.
@@ -311,7 +303,6 @@ class NVIDIAChatCompletionParams(OpenAIParamsBase):
)
class PromptyModelConfig(BaseModel):
api: Literal["chat", "completion"] = Field(
"chat", description="The API to use, either 'chat' or 'completion'"
@@ -330,7 +321,6 @@ class PromptyModelConfig(BaseModel):
description="Determines if full response or just the first one is returned",
)
@model_validator(mode="before")
def sync_model_name(cls, values: dict):
"""
@@ -405,7 +395,6 @@ class PromptyDefinition(BaseModel):
)
class AudioSpeechRequest(BaseModel):
model: Optional[Literal["tts-1", "tts-1-hd"]] = Field(
"tts-1", description="TTS model to use. Defaults to 'tts-1'."

View File

@@ -0,0 +1,121 @@
"""
Reusable signal handling mixin for graceful shutdown across different service types.
"""
import asyncio
import logging
from typing import Optional
from dapr_agents.utils import add_signal_handlers_cross_platform
logger = logging.getLogger(__name__)
class SignalHandlingMixin:
"""
Mixin providing reusable signal handling for graceful shutdown.
This mixin can be used by any class that needs to handle shutdown signals
(SIGINT, SIGTERM) gracefully. It provides a consistent interface for:
- Setting up signal handlers
- Managing shutdown events
- Triggering graceful shutdown logic
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._shutdown_event: Optional[asyncio.Event] = None
self._signal_handlers_setup = False
def setup_signal_handlers(self) -> None:
"""
Set up signal handlers for graceful shutdown.
This method should be called during initialization or startup
to enable graceful shutdown handling.
"""
# Initialize the attribute if it doesn't exist
if not hasattr(self, "_signal_handlers_setup"):
self._signal_handlers_setup = False
if self._signal_handlers_setup:
logger.debug("Signal handlers already set up")
return
# Initialize shutdown event if it doesn't exist
if not hasattr(self, "_shutdown_event") or self._shutdown_event is None:
self._shutdown_event = asyncio.Event()
# Set up signal handlers
loop = asyncio.get_event_loop()
add_signal_handlers_cross_platform(loop, self._handle_shutdown_signal)
self._signal_handlers_setup = True
logger.debug("Signal handlers set up for graceful shutdown")
def _handle_shutdown_signal(self, sig: int) -> None:
"""
Internal signal handler that triggers graceful shutdown.
Args:
sig: The received signal number
"""
logger.debug(f"Shutdown signal {sig} received. Triggering graceful shutdown...")
# Set the shutdown event
if self._shutdown_event:
self._shutdown_event.set()
# Call the graceful shutdown method if it exists
if hasattr(self, "graceful_shutdown"):
asyncio.create_task(self.graceful_shutdown())
elif hasattr(self, "stop"):
# Fallback to stop() method if graceful_shutdown doesn't exist
asyncio.create_task(self.stop())
else:
logger.warning(
"No graceful shutdown method found. Implement graceful_shutdown() or stop() method."
)
async def graceful_shutdown(self) -> None:
"""
Perform graceful shutdown operations.
This method should be overridden by classes that use this mixin
to implement their specific shutdown logic.
Default implementation calls stop() if it exists.
"""
if hasattr(self, "stop"):
await self.stop()
else:
logger.warning(
"No stop() method found. Override graceful_shutdown() to implement shutdown logic."
)
def is_shutdown_requested(self) -> bool:
"""
Check if a shutdown has been requested.
Returns:
bool: True if shutdown has been requested, False otherwise
"""
return (
hasattr(self, "_shutdown_event")
and self._shutdown_event is not None
and self._shutdown_event.is_set()
)
async def wait_for_shutdown(self, check_interval: float = 1.0) -> None:
"""
Wait for a shutdown signal to be received.
Args:
check_interval: How often to check for shutdown (in seconds)
"""
if not hasattr(self, "_shutdown_event") or self._shutdown_event is None:
raise RuntimeError(
"Signal handlers not set up. Call setup_signal_handlers() first."
)
while not self._shutdown_event.is_set():
await asyncio.sleep(check_interval)
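Note: a minimal usage sketch for the new mixin. The import path and MyService are hypothetical (the mixin's actual module location is not shown in this diff); the example relies only on the methods defined above:

import asyncio

# Hypothetical import path for the mixin defined above.
from signal_handling import SignalHandlingMixin

class MyService(SignalHandlingMixin):
    async def stop(self) -> None:
        # Invoked via the mixin's default graceful_shutdown() fallback.
        print("Flushing state and stopping...")

    async def serve(self) -> None:
        self.setup_signal_handlers()    # registers SIGINT/SIGTERM on the loop
        await self.wait_for_shutdown()  # returns once a signal sets the event

asyncio.run(MyService().serve())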

View File

@@ -556,7 +556,10 @@ class WorkflowApp(BaseModel):
f"Output: {json.dumps(state.serialized_output, indent=2)}"
)
elif workflow_status.upper() in (DaprWorkflowStatus.FAILED.value.upper(), "ABORTED"):
elif workflow_status.upper() in (
DaprWorkflowStatus.FAILED.value.upper(),
"ABORTED",
):
# Ensure `failure_details` exists before accessing attributes
error_type = getattr(failure_details, "error_type", "Unknown")
message = getattr(failure_details, "message", "No message provided")
@@ -611,7 +614,9 @@ class WorkflowApp(BaseModel):
# Off-load the potentially blocking run_workflow call to a thread.
instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
logger.debug(f"Workflow '{workflow}' started with instance ID: {instance_id}")
logger.debug(
f"Workflow '{workflow}' started with instance ID: {instance_id}"
)
# Await the asynchronous monitoring of the workflow state.
state = await self.monitor_workflow_state(instance_id)

View File

@@ -138,29 +138,43 @@ class ServiceMixin(SignalHandlingMixin):
# Save state before shutting down to ensure persistence and agent durability, so the agent can properly rerun after being stopped
try:
if hasattr(self, 'save_state') and hasattr(self, 'state'):
if hasattr(self, "save_state") and hasattr(self, "state"):
# Graceful shutdown compensation: Save incomplete instance if it exists
if hasattr(self, 'workflow_instance_id') and self.workflow_instance_id:
if hasattr(self, "workflow_instance_id") and self.workflow_instance_id:
if self.workflow_instance_id not in self.state.get("instances", {}):
# This instance was never saved, add it as incomplete
from datetime import datetime, timezone
incomplete_entry = {
"messages": [],
"start_time": datetime.now(timezone.utc).isoformat(),
"source": "graceful_shutdown",
"source_workflow_instance_id": None,
"workflow_name": getattr(self, '_workflow_name', 'Unknown'),
"workflow_name": getattr(self, "_workflow_name", "Unknown"),
"dapr_status": DaprWorkflowStatus.PENDING,
"suspended_reason": "app_terminated"
"suspended_reason": "app_terminated",
}
self.state.setdefault("instances", {})[self.workflow_instance_id] = incomplete_entry
logger.info(f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown")
self.state.setdefault("instances", {})[
self.workflow_instance_id
] = incomplete_entry
logger.info(
f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown"
)
else:
# Mark existing instance as suspended due to app termination
if "instances" in self.state and self.workflow_instance_id in self.state["instances"]:
self.state["instances"][self.workflow_instance_id]["dapr_status"] = DaprWorkflowStatus.SUSPENDED
self.state["instances"][self.workflow_instance_id]["suspended_reason"] = "app_terminated"
logger.info(f"Marked instance {self.workflow_instance_id} as suspended due to app termination")
if (
"instances" in self.state
and self.workflow_instance_id in self.state["instances"]
):
self.state["instances"][self.workflow_instance_id][
"dapr_status"
] = DaprWorkflowStatus.SUSPENDED
self.state["instances"][self.workflow_instance_id][
"suspended_reason"
] = "app_terminated"
logger.info(
f"Marked instance {self.workflow_instance_id} as suspended due to app termination"
)
self.save_state()
logger.debug("Workflow state saved successfully.")