style: appease linter

Signed-off-by: Samantha Coyle <sam@diagrid.io>
This commit is contained in:
Samantha Coyle 2025-09-08 13:58:29 -05:00
parent 728be553ee
commit 8601054a0b
No known key found for this signature in database
12 changed files with 382 additions and 139 deletions

View File

@ -55,7 +55,6 @@ class DurableAgent(AgenticWorkflow, AgentBase):
description="Metadata about the agent, including name, role, goal, instructions, and topic name.",
)
@model_validator(mode="before")
def set_agent_and_topic_name(cls, values: dict):
# Set name to role if name is not provided
@ -85,17 +84,21 @@ class DurableAgent(AgenticWorkflow, AgentBase):
self.state["instances"] = {}
if "chat_history" not in self.state:
self.state["chat_history"] = []
# Load the current workflow instance ID from state if it exists
logger.info(f"State after loading: {self.state}")
if self.state and "instances" in self.state:
logger.info(f"Found {len(self.state['instances'])} instances in state")
for instance_id, instance_data in self.state["instances"].items():
stored_workflow_name = instance_data.get("workflow_name")
logger.info(f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}")
logger.info(
f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}"
)
if stored_workflow_name == self._workflow_name:
self.workflow_instance_id = instance_id
logger.info(f"Loaded current workflow instance ID from state: {instance_id}")
logger.info(
f"Loaded current workflow instance ID from state: {instance_id}"
)
break
else:
logger.info("No instances found in state or state is empty")
@ -112,7 +115,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
}
self.register_agentic_system()
# Start the runtime if it's not already running
if not self.wf_runtime_is_running:
logger.info("Starting workflow runtime...")
@ -137,20 +140,28 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Check for existing incomplete workflows before starting a new one
existing_instance_id = await self._find_incomplete_workflow(input_data)
if existing_instance_id:
logger.info(f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it.")
logger.info(
f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it."
)
logger.info("Monitoring the resumed workflow until completion...")
# Monitor the resumed workflow until completion
try:
result = await self.monitor_workflow_state(existing_instance_id)
if result and result.serialized_output:
logger.info(f"Resumed workflow {existing_instance_id} completed successfully!")
logger.info(
f"Resumed workflow {existing_instance_id} completed successfully!"
)
return result.serialized_output
else:
logger.warning(f"Resumed workflow {existing_instance_id} completed but had no output")
logger.warning(
f"Resumed workflow {existing_instance_id} completed but had no output"
)
return None
except Exception as e:
logger.error(f"Error monitoring resumed workflow {existing_instance_id}: {e}")
logger.error(
f"Error monitoring resumed workflow {existing_instance_id}: {e}"
)
# Fall through to start a new workflow if monitoring fails
# Prepare input payload for workflow
@ -173,15 +184,17 @@ class DurableAgent(AgenticWorkflow, AgentBase):
if self.wf_runtime_is_running:
self.stop_runtime()
async def _find_incomplete_workflow(self, input_data: Union[str, Dict[str, Any]]) -> Optional[str]:
async def _find_incomplete_workflow(
self, input_data: Union[str, Dict[str, Any]]
) -> Optional[str]:
"""
Find an existing incomplete workflow instance that should be resumed.
Uses Dapr WorkflowState to determine if workflow is actually complete.
Only resumes workflows that match the current input.
Args:
input_data: The input for the new workflow request
Returns:
Optional[str]: The instance ID of an incomplete workflow with matching input, or None if none found.
"""
@ -191,64 +204,81 @@ class DurableAgent(AgenticWorkflow, AgentBase):
if not instances:
logger.debug("No instances found in state")
return None
# Normalize input for comparison
if isinstance(input_data, dict):
current_input = input_data.get("task", str(input_data))
else:
current_input = str(input_data)
for instance_id, instance_data in instances.items():
workflow_name = instance_data.get("workflow_name")
end_time = instance_data.get("end_time")
stored_input = instance_data.get("input", "")
# Only consider workflows that match our current workflow name
if workflow_name != self._workflow_name:
continue
if end_time is None:
# Only consider workflows that match the current input
if str(stored_input) != str(current_input):
continue
# Verify the workflow still exists in Dapr and check its actual status
# as dapr is our source of truth.
try:
workflow_state = self.get_workflow_state(instance_id)
if workflow_state:
runtime_status = workflow_state.runtime_status.name
logger.debug(f"Workflow {instance_id} Dapr status: {runtime_status}")
logger.debug(
f"Workflow {instance_id} Dapr status: {runtime_status}"
)
# If workflow is still running, return it for resumption
# Handle case mismatch: Dapr returns 'RUNNING' but enum is 'running'
if runtime_status.upper() in [DaprWorkflowStatus.RUNNING.value.upper(), DaprWorkflowStatus.PENDING.value.upper()]:
if runtime_status.upper() in [
DaprWorkflowStatus.RUNNING.value.upper(),
DaprWorkflowStatus.PENDING.value.upper(),
]:
return instance_id
elif runtime_status.upper() in [DaprWorkflowStatus.UNKNOWN.value.upper()]:
logger.debug(f"Workflow {instance_id} Dapr status is unknown, skipping")
elif runtime_status.upper() in [
DaprWorkflowStatus.UNKNOWN.value.upper()
]:
logger.debug(
f"Workflow {instance_id} Dapr status is unknown, skipping"
)
continue
else:
# This is for COMPLETED, FAILED, TERMINATED
self._mark_workflow_completed(instance_id, runtime_status)
self._mark_workflow_completed(
instance_id, runtime_status
)
else:
logger.debug(f"Workflow {instance_id} no longer exists in Dapr")
logger.debug(
f"Workflow {instance_id} no longer exists in Dapr"
)
# Mark as completed in our state since it's no longer in Dapr
self._mark_workflow_completed(instance_id)
except Exception as e:
logger.warning(f"Could not verify workflow {instance_id} in Dapr: {e}")
logger.warning(
f"Could not verify workflow {instance_id} in Dapr: {e}"
)
return instance_id
logger.debug("No incomplete workflows found")
return None
except Exception as e:
logger.error(f"Error finding incomplete workflows: {e}")
return None
def _mark_workflow_completed(self, instance_id: str, status: str = "completed") -> None:
def _mark_workflow_completed(
self, instance_id: str, status: str = "completed"
) -> None:
"""
Mark a workflow as completed in our state.
Args:
instance_id: The workflow instance ID to mark as completed
status: The completion status (default: "completed")
@ -284,33 +314,45 @@ class DurableAgent(AgenticWorkflow, AgentBase):
metadata = message.get("_message_metadata", {}) or {}
# Extract workflow_instance_id from TriggerAction if present from orchestrator
if "workflow_instance_id" in message:
metadata["triggering_workflow_instance_id"] = message["workflow_instance_id"]
metadata["triggering_workflow_instance_id"] = message[
"workflow_instance_id"
]
else:
task = getattr(message, "task", None)
metadata = getattr(message, "_message_metadata", {}) or {}
# Extract workflow_instance_id from TriggerAction if present from orchestrator
if hasattr(message, "workflow_instance_id"):
metadata["triggering_workflow_instance_id"] = getattr(message, "workflow_instance_id")
metadata["triggering_workflow_instance_id"] = getattr(
message, "workflow_instance_id"
)
workflow_instance_id = ctx.instance_id
triggering_workflow_instance_id = metadata.get("triggering_workflow_instance_id")
triggering_workflow_instance_id = metadata.get(
"triggering_workflow_instance_id"
)
source = metadata.get("source")
# Set default source if not provided (for direct run() calls)
if not source:
source = "direct"
print(f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
print(f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}")
print(
f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
)
print(
f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}"
)
print(f"DEBUG: Current self.state at start of workflow: {self.state}")
logger.info(f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}")
logger.info(
f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
)
logger.info(f"Current self.state at start of workflow: {self.state}")
# Store the instance ID from workflow context as the source of truth
# This will be used by graceful shutdown logic to save incomplete instances
logger.info(f"Workflow context provided instance ID: {workflow_instance_id}")
print(f"DEBUG: Workflow context provided instance ID: {workflow_instance_id}")
# Check if this instance already exists in state (from previous runs)
if workflow_instance_id not in self.state.get("instances", {}):
# This is a new instance, create a minimal entry
@ -324,19 +366,29 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"workflow_instance_id": workflow_instance_id,
"triggering_workflow_instance_id": triggering_workflow_instance_id,
"workflow_name": self._workflow_name,
"status": "running" # Mark as running
"status": "running", # Mark as running
}
self.state.setdefault("instances", {})[workflow_instance_id] = instance_entry
self.state.setdefault("instances", {})[
workflow_instance_id
] = instance_entry
logger.info(f"Created new instance entry: {workflow_instance_id}")
print(f"DEBUG: Created new instance entry: {workflow_instance_id}")
# Immediately save state so graceful shutdown can capture this instance
self.save_state()
logger.info(f"Saved state immediately after creating instance entry for {workflow_instance_id}")
print(f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}")
logger.info(
f"Saved state immediately after creating instance entry for {workflow_instance_id}"
)
print(
f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}"
)
else:
logger.info(f"Found existing instance entry for workflow {workflow_instance_id}")
print(f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}")
logger.info(
f"Found existing instance entry for workflow {workflow_instance_id}"
)
print(
f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}"
)
final_message: Optional[Dict[str, Any]] = None
if not ctx.is_replaying:
@ -371,7 +423,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Step 5: Add the assistant's response message to the chat history
yield ctx.call_activity(
self.append_assistant_message,
input={"instance_id": workflow_instance_id, "message": response_message},
input={
"instance_id": workflow_instance_id,
"message": response_message,
},
)
# Step 6: Handle tool calls response
@ -391,7 +446,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
for tr in tool_results:
yield ctx.call_activity(
self.append_tool_message,
input={"instance_id": workflow_instance_id, "tool_result": tr},
input={
"instance_id": workflow_instance_id,
"tool_result": tr,
},
)
# 🔴 If this was the last turn, stop here—even though there were tool calls
if turn == self.max_iterations:
@ -512,7 +570,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"""
# Load state to ensure we have the latest data
self.load_state()
workflow_entry = self.state.get("instances", {}).get(instance_id)
if workflow_entry is not None:
return {
@ -536,7 +594,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# }
# self.state.setdefault("instances", {})[instance_id] = minimal_entry
# self.save_state()
# return {
# "source": minimal_entry.get("source"),
# "source_workflow_instance_id": minimal_entry.get("source_workflow_instance_id"),
@ -560,7 +618,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"""
# Construct messages using instance-specific chat history instead of global memory
# This ensures proper message sequence for tool calls
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(instance_id, task or {})
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(
instance_id, task or {}
)
user_message = self.get_last_message_if_user(messages)
# Always work with a copy of the user message for safety
@ -582,7 +642,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@ -622,15 +682,19 @@ class DurableAgent(AgenticWorkflow, AgentBase):
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logger.error(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}")
logger.error(
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
)
logger.error(f"Task: {task}")
logger.error(f"Messages count: {len(messages)}")
logger.error(f"Tools available: {len(self.get_llm_tools())}")
logger.error("Full error details:", exc_info=True)
raise AgentError(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}") from e
raise AgentError(
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
) from e
@task
async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -733,7 +797,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@ -774,7 +838,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@ -808,7 +872,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input",
"workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name
"workflow_name": self._workflow_name,
}
inst: dict = self.state["instances"][instance_id]
inst["output"] = final_output
@ -870,26 +934,30 @@ class DurableAgent(AgenticWorkflow, AgentBase):
except Exception as e:
logger.error(f"Error processing broadcast message: {e}", exc_info=True)
def _construct_messages_with_instance_history(self, instance_id: str, input_data: Union[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
def _construct_messages_with_instance_history(
self, instance_id: str, input_data: Union[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Construct messages using instance-specific chat history instead of global memory.
This ensures proper message sequence for tool calls and prevents OpenAI API errors
in the event an app gets terminated or restarts while the workflow is running.
Args:
instance_id: The workflow instance ID
input_data: User input, either as a string or dictionary
Returns:
List of formatted messages with proper sequence
"""
if not self.prompt_template:
raise ValueError("Prompt template must be initialized before constructing messages.")
raise ValueError(
"Prompt template must be initialized before constructing messages."
)
# Get instance-specific chat history instead of global memory
instance_data = self.state.get("instances", {}).get(instance_id, {})
instance_messages = instance_data.get("messages", [])
# Convert instance messages to the format expected by prompt template
chat_history = []
for msg in instance_messages:
@ -897,10 +965,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
chat_history.append(msg)
else:
# Convert DurableAgentMessage to dict if needed
chat_history.append(msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg))
chat_history.append(
msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
)
if isinstance(input_data, str):
formatted_messages = self.prompt_template.format_prompt(chat_history=chat_history)
formatted_messages = self.prompt_template.format_prompt(
chat_history=chat_history
)
if isinstance(formatted_messages, list):
user_message = {"role": "user", "content": input_data}
return formatted_messages + [user_message]

View File

@ -225,9 +225,8 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logger.error(f"OpenAI ChatCompletion API error: {error_type} - {error_msg}")
logger.error("Full error details:", exc_info=True)
raise ValueError(f"OpenAI API error ({error_type}): {error_msg}") from e

View File

@ -95,33 +95,38 @@ class WorkflowContextStorage:
logger.warning(f"⚠️ No context found for instance {instance_id}")
return context
def create_resumed_workflow_context(self, instance_id: str, agent_name: Optional[str] = None) -> Dict[str, Any]:
def create_resumed_workflow_context(
self, instance_id: str, agent_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Create a new trace context for a resumed workflow after app restart.
When an app restarts, the in-memory context storage is lost. This method
creates a new trace context for resumed workflows so they can still be
traced, even though they won't be connected to the original trace.
Args:
instance_id (str): Unique workflow instance ID
Returns:
Dict[str, Any]: New W3C context data for the resumed workflow
"""
try:
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
# Create a new trace for the resumed workflow with proper AGENT span
tracer = trace.get_tracer(__name__)
# Create AGENT span with proper agent name for resumed workflow
agent_display_name = agent_name or "DurableAgent"
span_name = f"{agent_display_name}.ToolCallingWorkflow"
with tracer.start_as_current_span(span_name) as span:
# Set AGENT span attributes
from .constants import OPENINFERENCE_SPAN_KIND
span.set_attribute(OPENINFERENCE_SPAN_KIND, "AGENT")
span.set_attribute("workflow.instance_id", instance_id)
span.set_attribute("workflow.resumed", True)
@ -130,29 +135,33 @@ class WorkflowContextStorage:
propagator = TraceContextTextMapPropagator()
carrier = {}
propagator.inject(carrier)
context_data = {
"traceparent": carrier.get("traceparent"),
"tracestate": carrier.get("tracestate"),
"instance_id": instance_id,
"resumed": True,
"debug_info": f"New trace created for resumed workflow {instance_id}"
"debug_info": f"New trace created for resumed workflow {instance_id}",
}
# Store the new context
self.store_context(instance_id, context_data)
logger.info(f"Created new trace context for resumed workflow {instance_id}")
logger.info(
f"Created new trace context for resumed workflow {instance_id}"
)
return context_data
except Exception as e:
logger.error(f"Failed to create resumed workflow context for {instance_id}: {e}")
logger.error(
f"Failed to create resumed workflow context for {instance_id}: {e}"
)
return {
"traceparent": None,
"tracestate": None,
"instance_id": instance_id,
"resumed": True,
"error": str(e)
"error": str(e),
}
def cleanup_context(self, instance_id: str) -> None:

View File

@ -424,28 +424,28 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
from dapr_agents.workflow.base import WorkflowApp
from dapr_agents.workflow.task import WorkflowTask
import wrapt
# Workflow monitoring wrapper - creates AGENT spans
wrapt.wrap_function_wrapper(
WorkflowApp.__module__,
f'{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}',
WorkflowMonitorWrapper(self._tracer)
f"{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}",
WorkflowMonitorWrapper(self._tracer),
)
# Workflow run wrapper - creates workflow spans
wrapt.wrap_function_wrapper(
WorkflowApp.__module__,
f'{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}',
WorkflowRunWrapper(self._tracer)
f"{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}",
WorkflowRunWrapper(self._tracer),
)
# WorkflowTask call wrapper
# This ensures child spans (LLM/TOOL) are properly linked to parent AGENT spans,
# and is necessary due to the async nature of the WorkflowTask.__call__ method.
wrapt.wrap_function_wrapper(
WorkflowTask.__module__,
f'{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}',
WorkflowTaskWrapper(self._tracer)
f"{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}",
WorkflowTaskWrapper(self._tracer),
)
except Exception as e:
logger.error(f"Error applying workflow wrappers: {e}", exc_info=True)

View File

@ -341,11 +341,17 @@ class WorkflowMonitorWrapper:
from ..context_storage import store_workflow_context
captured_context = extract_otel_context()
if captured_context.get("traceparent") and captured_context.get("trace_id") != "00000000000000000000000000000000":
if (
captured_context.get("traceparent")
and captured_context.get("trace_id")
!= "00000000000000000000000000000000"
):
logger.debug(
f"Captured traceparent: {captured_context.get('traceparent')}"
)
store_workflow_context("__global_workflow_context__", captured_context)
store_workflow_context(
"__global_workflow_context__", captured_context
)
else:
logger.debug(
f"Invalid or empty trace context captured: {captured_context}"

View File

@ -85,14 +85,14 @@ class WorkflowTaskWrapper:
# Determine task details
logger.debug(f"WorkflowTaskWrapper: instance type = {type(instance)}")
logger.debug(f"WorkflowTaskWrapper: instance attributes = {dir(instance)}")
if hasattr(instance, 'func'):
if hasattr(instance, "func"):
logger.debug(f"WorkflowTaskWrapper: instance.func = {instance.func}")
else:
logger.debug("WorkflowTaskWrapper: instance has no 'func' attribute")
task_name = (
getattr(instance.func, "__name__", "unknown_task")
if hasattr(instance, 'func') and instance.func
if hasattr(instance, "func") and instance.func
else "workflow_task"
)
span_kind = self._determine_span_kind(instance, task_name)
@ -245,7 +245,13 @@ class WorkflowTaskWrapper:
return attributes
def _handle_async_execution(
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
self,
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
span_name: str,
attributes: dict,
) -> Any:
"""
Handle asynchronous workflow task execution with OpenTelemetry context restoration.
@ -276,32 +282,48 @@ class WorkflowTaskWrapper:
from ..context_storage import get_workflow_context
otel_context = get_workflow_context(instance_id)
# If no context found for specific instance, try global context as fallback
if otel_context is None:
logger.debug(f"No context found for instance {instance_id}, trying global context")
logger.debug(
f"No context found for instance {instance_id}, trying global context"
)
otel_context = get_workflow_context("__global_workflow_context__")
if otel_context:
# Store the global context with the specific instance ID for future use
from ..context_storage import store_workflow_context
store_workflow_context(instance_id, otel_context)
logger.debug(f"Copied global context to instance {instance_id}")
else:
# If still no context found (e.g., after app restart), create a new one for resumed workflows
logger.debug(f"No context found for instance {instance_id} - creating new context for resumed workflow")
logger.debug(
f"No context found for instance {instance_id} - creating new context for resumed workflow"
)
# Try to get agent name from the task instance
agent_name = None
if hasattr(instance, 'agent') and instance.agent and hasattr(instance.agent, 'name'):
if (
hasattr(instance, "agent")
and instance.agent
and hasattr(instance.agent, "name")
):
agent_name = instance.agent.name
elif hasattr(instance, 'func') and instance.func and hasattr(instance.func, '__self__'):
elif (
hasattr(instance, "func")
and instance.func
and hasattr(instance.func, "__self__")
):
agent_instance = instance.func.__self__
if hasattr(agent_instance, 'name'):
if hasattr(agent_instance, "name"):
agent_name = agent_instance.name
from ..context_storage import _context_storage
otel_context = _context_storage.create_resumed_workflow_context(instance_id, agent_name)
otel_context = _context_storage.create_resumed_workflow_context(
instance_id, agent_name
)
# Create span with restored context if available
from ..context_propagation import create_child_span_with_context
@ -348,7 +370,13 @@ class WorkflowTaskWrapper:
return async_wrapper()
def _handle_sync_execution(
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict
self,
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
span_name: str,
attributes: dict,
) -> Any:
"""
Handle synchronous workflow task execution with OpenTelemetry context restoration.

View File

@ -128,7 +128,7 @@ class FastAPIServerBase(APIServerBase, SignalHandlingMixin):
Perform graceful shutdown operations for the FastAPI server.
"""
await self.stop()
async def stop(self):
"""
Stop the FastAPI server gracefully.

View File

@ -17,7 +17,6 @@ class ElevenLabsClientConfig(BaseModel):
)
class NVIDIAClientConfig(BaseModel):
base_url: Optional[str] = Field(
"https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA API"
@ -27,7 +26,6 @@ class NVIDIAClientConfig(BaseModel):
)
class DaprInferenceClientConfig:
pass
@ -73,7 +71,6 @@ class HFInferenceClientConfig(BaseModel):
)
class OpenAIClientConfig(BaseModel):
base_url: Optional[str] = Field(None, description="Base URL for the OpenAI API")
api_key: Optional[str] = Field(
@ -112,7 +109,6 @@ class AzureOpenAIClientConfig(BaseModel):
)
class OpenAIModelConfig(OpenAIClientConfig):
type: Literal["openai"] = Field(
"openai", description="Type of the model, must always be 'openai'"
@ -170,7 +166,6 @@ class OpenAIParamsBase(BaseModel):
stream: Optional[bool] = Field(False, description="Whether to stream responses")
class OpenAITextCompletionParams(OpenAIParamsBase):
"""
Specific configs for the text completions endpoint.
@ -186,7 +181,6 @@ class OpenAITextCompletionParams(OpenAIParamsBase):
suffix: Optional[str] = Field(None, description="Suffix to append to the prompt")
class OpenAIChatCompletionParams(OpenAIParamsBase):
"""
Specific settings for the Chat Completion endpoint.
@ -222,7 +216,6 @@ class OpenAIChatCompletionParams(OpenAIParamsBase):
)
class HFHubChatCompletionParams(BaseModel):
"""
Specific settings for Hugging Face Hub Chat Completion endpoint.
@ -285,7 +278,6 @@ class HFHubChatCompletionParams(BaseModel):
)
class NVIDIAChatCompletionParams(OpenAIParamsBase):
"""
Specific settings for the Chat Completion endpoint.
@ -311,7 +303,6 @@ class NVIDIAChatCompletionParams(OpenAIParamsBase):
)
class PromptyModelConfig(BaseModel):
api: Literal["chat", "completion"] = Field(
"chat", description="The API to use, either 'chat' or 'completion'"
@ -330,7 +321,6 @@ class PromptyModelConfig(BaseModel):
description="Determines if full response or just the first one is returned",
)
@model_validator(mode="before")
def sync_model_name(cls, values: dict):
"""
@ -405,7 +395,6 @@ class PromptyDefinition(BaseModel):
)
class AudioSpeechRequest(BaseModel):
model: Optional[Literal["tts-1", "tts-1-hd"]] = Field(
"tts-1", description="TTS model to use. Defaults to 'tts-1'."

View File

@ -0,0 +1,121 @@
"""
Reusable signal handling mixin for graceful shutdown across different service types.
"""
import asyncio
import logging
from typing import Optional
from dapr_agents.utils import add_signal_handlers_cross_platform
logger = logging.getLogger(__name__)
class SignalHandlingMixin:
"""
Mixin providing reusable signal handling for graceful shutdown.
This mixin can be used by any class that needs to handle shutdown signals
(SIGINT, SIGTERM) gracefully. It provides a consistent interface for:
- Setting up signal handlers
- Managing shutdown events
- Triggering graceful shutdown logic
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._shutdown_event: Optional[asyncio.Event] = None
self._signal_handlers_setup = False
def setup_signal_handlers(self) -> None:
"""
Set up signal handlers for graceful shutdown.
This method should be called during initialization or startup
to enable graceful shutdown handling.
"""
# Initialize the attribute if it doesn't exist
if not hasattr(self, "_signal_handlers_setup"):
self._signal_handlers_setup = False
if self._signal_handlers_setup:
logger.debug("Signal handlers already set up")
return
# Initialize shutdown event if it doesn't exist
if not hasattr(self, "_shutdown_event") or self._shutdown_event is None:
self._shutdown_event = asyncio.Event()
# Set up signal handlers
loop = asyncio.get_event_loop()
add_signal_handlers_cross_platform(loop, self._handle_shutdown_signal)
self._signal_handlers_setup = True
logger.debug("Signal handlers set up for graceful shutdown")
def _handle_shutdown_signal(self, sig: int) -> None:
"""
Internal signal handler that triggers graceful shutdown.
Args:
sig: The received signal number
"""
logger.debug(f"Shutdown signal {sig} received. Triggering graceful shutdown...")
# Set the shutdown event
if self._shutdown_event:
self._shutdown_event.set()
# Call the graceful shutdown method if it exists
if hasattr(self, "graceful_shutdown"):
asyncio.create_task(self.graceful_shutdown())
elif hasattr(self, "stop"):
# Fallback to stop() method if graceful_shutdown doesn't exist
asyncio.create_task(self.stop())
else:
logger.warning(
"No graceful shutdown method found. Implement graceful_shutdown() or stop() method."
)
async def graceful_shutdown(self) -> None:
"""
Perform graceful shutdown operations.
This method should be overridden by classes that use this mixin
to implement their specific shutdown logic.
Default implementation calls stop() if it exists.
"""
if hasattr(self, "stop"):
await self.stop()
else:
logger.warning(
"No stop() method found. Override graceful_shutdown() to implement shutdown logic."
)
def is_shutdown_requested(self) -> bool:
"""
Check if a shutdown has been requested.
Returns:
bool: True if shutdown has been requested, False otherwise
"""
return (
hasattr(self, "_shutdown_event")
and self._shutdown_event is not None
and self._shutdown_event.is_set()
)
async def wait_for_shutdown(self, check_interval: float = 1.0) -> None:
"""
Wait for a shutdown signal to be received.
Args:
check_interval: How often to check for shutdown (in seconds)
"""
if not hasattr(self, "_shutdown_event") or self._shutdown_event is None:
raise RuntimeError(
"Signal handlers not set up. Call setup_signal_handlers() first."
)
while not self._shutdown_event.is_set():
await asyncio.sleep(check_interval)

View File

@ -556,7 +556,10 @@ class WorkflowApp(BaseModel):
f"Output: {json.dumps(state.serialized_output, indent=2)}"
)
elif workflow_status.upper() in (DaprWorkflowStatus.FAILED.value.upper(), "ABORTED"):
elif workflow_status.upper() in (
DaprWorkflowStatus.FAILED.value.upper(),
"ABORTED",
):
# Ensure `failure_details` exists before accessing attributes
error_type = getattr(failure_details, "error_type", "Unknown")
message = getattr(failure_details, "message", "No message provided")
@ -611,7 +614,9 @@ class WorkflowApp(BaseModel):
# Off-load the potentially blocking run_workflow call to a thread.
instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
logger.debug(f"Workflow '{workflow}' started with instance ID: {instance_id}")
logger.debug(
f"Workflow '{workflow}' started with instance ID: {instance_id}"
)
# Await the asynchronous monitoring of the workflow state.
state = await self.monitor_workflow_state(instance_id)

View File

@ -138,30 +138,44 @@ class ServiceMixin(SignalHandlingMixin):
# Save state before shutting down to ensure persistence and agent durability to properly rerun after being stoped
try:
if hasattr(self, 'save_state') and hasattr(self, 'state'):
if hasattr(self, "save_state") and hasattr(self, "state"):
# Graceful shutdown compensation: Save incomplete instance if it exists
if hasattr(self, 'workflow_instance_id') and self.workflow_instance_id:
if hasattr(self, "workflow_instance_id") and self.workflow_instance_id:
if self.workflow_instance_id not in self.state.get("instances", {}):
# This instance was never saved, add it as incomplete
from datetime import datetime, timezone
incomplete_entry = {
"messages": [],
"start_time": datetime.now(timezone.utc).isoformat(),
"source": "graceful_shutdown",
"source_workflow_instance_id": None,
"workflow_name": getattr(self, '_workflow_name', 'Unknown'),
"workflow_name": getattr(self, "_workflow_name", "Unknown"),
"dapr_status": DaprWorkflowStatus.PENDING,
"suspended_reason": "app_terminated"
"suspended_reason": "app_terminated",
}
self.state.setdefault("instances", {})[self.workflow_instance_id] = incomplete_entry
logger.info(f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown")
self.state.setdefault("instances", {})[
self.workflow_instance_id
] = incomplete_entry
logger.info(
f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown"
)
else:
# Mark existing instance as suspended due to app termination
if "instances" in self.state and self.workflow_instance_id in self.state["instances"]:
self.state["instances"][self.workflow_instance_id]["dapr_status"] = DaprWorkflowStatus.SUSPENDED
self.state["instances"][self.workflow_instance_id]["suspended_reason"] = "app_terminated"
logger.info(f"Marked instance {self.workflow_instance_id} as suspended due to app termination")
if (
"instances" in self.state
and self.workflow_instance_id in self.state["instances"]
):
self.state["instances"][self.workflow_instance_id][
"dapr_status"
] = DaprWorkflowStatus.SUSPENDED
self.state["instances"][self.workflow_instance_id][
"suspended_reason"
] = "app_terminated"
logger.info(
f"Marked instance {self.workflow_instance_id} as suspended due to app termination"
)
self.save_state()
logger.debug("Workflow state saved successfully.")
except Exception as e:

View File

@ -106,16 +106,16 @@ class StateManagementMixin:
raise TypeError(
f"Invalid state type retrieved: {type(state_data)}. Expected dict."
)
# Set self.state to the loaded data
if self.state_format:
loaded_state = self.validate_state(state_data)
else:
loaded_state = state_data
self.state = loaded_state
logger.debug(f"Set self.state to loaded data: {self.state}")
return loaded_state
logger.info(