style: appease linter

Signed-off-by: Samantha Coyle <sam@diagrid.io>
This commit is contained in:
Samantha Coyle 2025-09-08 13:58:29 -05:00
parent 728be553ee
commit 8601054a0b
No known key found for this signature in database
12 changed files with 382 additions and 139 deletions

View File

@ -55,7 +55,6 @@ class DurableAgent(AgenticWorkflow, AgentBase):
description="Metadata about the agent, including name, role, goal, instructions, and topic name.", description="Metadata about the agent, including name, role, goal, instructions, and topic name.",
) )
@model_validator(mode="before") @model_validator(mode="before")
def set_agent_and_topic_name(cls, values: dict): def set_agent_and_topic_name(cls, values: dict):
# Set name to role if name is not provided # Set name to role if name is not provided
@ -92,10 +91,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
logger.info(f"Found {len(self.state['instances'])} instances in state") logger.info(f"Found {len(self.state['instances'])} instances in state")
for instance_id, instance_data in self.state["instances"].items(): for instance_id, instance_data in self.state["instances"].items():
stored_workflow_name = instance_data.get("workflow_name") stored_workflow_name = instance_data.get("workflow_name")
logger.info(f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}") logger.info(
f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}"
)
if stored_workflow_name == self._workflow_name: if stored_workflow_name == self._workflow_name:
self.workflow_instance_id = instance_id self.workflow_instance_id = instance_id
logger.info(f"Loaded current workflow instance ID from state: {instance_id}") logger.info(
f"Loaded current workflow instance ID from state: {instance_id}"
)
break break
else: else:
logger.info("No instances found in state or state is empty") logger.info("No instances found in state or state is empty")
@ -137,20 +140,28 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Check for existing incomplete workflows before starting a new one # Check for existing incomplete workflows before starting a new one
existing_instance_id = await self._find_incomplete_workflow(input_data) existing_instance_id = await self._find_incomplete_workflow(input_data)
if existing_instance_id: if existing_instance_id:
logger.info(f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it.") logger.info(
f"Found existing incomplete workflow: {existing_instance_id}. The workflow runtime will automatically resume it."
)
logger.info("Monitoring the resumed workflow until completion...") logger.info("Monitoring the resumed workflow until completion...")
# Monitor the resumed workflow until completion # Monitor the resumed workflow until completion
try: try:
result = await self.monitor_workflow_state(existing_instance_id) result = await self.monitor_workflow_state(existing_instance_id)
if result and result.serialized_output: if result and result.serialized_output:
logger.info(f"Resumed workflow {existing_instance_id} completed successfully!") logger.info(
f"Resumed workflow {existing_instance_id} completed successfully!"
)
return result.serialized_output return result.serialized_output
else: else:
logger.warning(f"Resumed workflow {existing_instance_id} completed but had no output") logger.warning(
f"Resumed workflow {existing_instance_id} completed but had no output"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error monitoring resumed workflow {existing_instance_id}: {e}") logger.error(
f"Error monitoring resumed workflow {existing_instance_id}: {e}"
)
# Fall through to start a new workflow if monitoring fails # Fall through to start a new workflow if monitoring fails
# Prepare input payload for workflow # Prepare input payload for workflow
@ -173,7 +184,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
if self.wf_runtime_is_running: if self.wf_runtime_is_running:
self.stop_runtime() self.stop_runtime()
async def _find_incomplete_workflow(self, input_data: Union[str, Dict[str, Any]]) -> Optional[str]: async def _find_incomplete_workflow(
self, input_data: Union[str, Dict[str, Any]]
) -> Optional[str]:
""" """
Find an existing incomplete workflow instance that should be resumed. Find an existing incomplete workflow instance that should be resumed.
Uses Dapr WorkflowState to determine if workflow is actually complete. Uses Dapr WorkflowState to determine if workflow is actually complete.
@ -218,24 +231,39 @@ class DurableAgent(AgenticWorkflow, AgentBase):
workflow_state = self.get_workflow_state(instance_id) workflow_state = self.get_workflow_state(instance_id)
if workflow_state: if workflow_state:
runtime_status = workflow_state.runtime_status.name runtime_status = workflow_state.runtime_status.name
logger.debug(f"Workflow {instance_id} Dapr status: {runtime_status}") logger.debug(
f"Workflow {instance_id} Dapr status: {runtime_status}"
)
# If workflow is still running, return it for resumption # If workflow is still running, return it for resumption
# Handle case mismatch: Dapr returns 'RUNNING' but enum is 'running' # Handle case mismatch: Dapr returns 'RUNNING' but enum is 'running'
if runtime_status.upper() in [DaprWorkflowStatus.RUNNING.value.upper(), DaprWorkflowStatus.PENDING.value.upper()]: if runtime_status.upper() in [
DaprWorkflowStatus.RUNNING.value.upper(),
DaprWorkflowStatus.PENDING.value.upper(),
]:
return instance_id return instance_id
elif runtime_status.upper() in [DaprWorkflowStatus.UNKNOWN.value.upper()]: elif runtime_status.upper() in [
logger.debug(f"Workflow {instance_id} Dapr status is unknown, skipping") DaprWorkflowStatus.UNKNOWN.value.upper()
]:
logger.debug(
f"Workflow {instance_id} Dapr status is unknown, skipping"
)
continue continue
else: else:
# This is for COMPLETED, FAILED, TERMINATED # This is for COMPLETED, FAILED, TERMINATED
self._mark_workflow_completed(instance_id, runtime_status) self._mark_workflow_completed(
instance_id, runtime_status
)
else: else:
logger.debug(f"Workflow {instance_id} no longer exists in Dapr") logger.debug(
f"Workflow {instance_id} no longer exists in Dapr"
)
# Mark as completed in our state since it's no longer in Dapr # Mark as completed in our state since it's no longer in Dapr
self._mark_workflow_completed(instance_id) self._mark_workflow_completed(instance_id)
except Exception as e: except Exception as e:
logger.warning(f"Could not verify workflow {instance_id} in Dapr: {e}") logger.warning(
f"Could not verify workflow {instance_id} in Dapr: {e}"
)
return instance_id return instance_id
logger.debug("No incomplete workflows found") logger.debug("No incomplete workflows found")
@ -245,7 +273,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
logger.error(f"Error finding incomplete workflows: {e}") logger.error(f"Error finding incomplete workflows: {e}")
return None return None
def _mark_workflow_completed(self, instance_id: str, status: str = "completed") -> None: def _mark_workflow_completed(
self, instance_id: str, status: str = "completed"
) -> None:
""" """
Mark a workflow as completed in our state. Mark a workflow as completed in our state.
@ -284,26 +314,38 @@ class DurableAgent(AgenticWorkflow, AgentBase):
metadata = message.get("_message_metadata", {}) or {} metadata = message.get("_message_metadata", {}) or {}
# Extract workflow_instance_id from TriggerAction if present from orchestrator # Extract workflow_instance_id from TriggerAction if present from orchestrator
if "workflow_instance_id" in message: if "workflow_instance_id" in message:
metadata["triggering_workflow_instance_id"] = message["workflow_instance_id"] metadata["triggering_workflow_instance_id"] = message[
"workflow_instance_id"
]
else: else:
task = getattr(message, "task", None) task = getattr(message, "task", None)
metadata = getattr(message, "_message_metadata", {}) or {} metadata = getattr(message, "_message_metadata", {}) or {}
# Extract workflow_instance_id from TriggerAction if present from orchestrator # Extract workflow_instance_id from TriggerAction if present from orchestrator
if hasattr(message, "workflow_instance_id"): if hasattr(message, "workflow_instance_id"):
metadata["triggering_workflow_instance_id"] = getattr(message, "workflow_instance_id") metadata["triggering_workflow_instance_id"] = getattr(
message, "workflow_instance_id"
)
workflow_instance_id = ctx.instance_id workflow_instance_id = ctx.instance_id
triggering_workflow_instance_id = metadata.get("triggering_workflow_instance_id") triggering_workflow_instance_id = metadata.get(
"triggering_workflow_instance_id"
)
source = metadata.get("source") source = metadata.get("source")
# Set default source if not provided (for direct run() calls) # Set default source if not provided (for direct run() calls)
if not source: if not source:
source = "direct" source = "direct"
print(f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}") print(
print(f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}") f"DEBUG: tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
)
print(
f"DEBUG: triggering_workflow_instance_id: {triggering_workflow_instance_id}"
)
print(f"DEBUG: Current self.state at start of workflow: {self.state}") print(f"DEBUG: Current self.state at start of workflow: {self.state}")
logger.info(f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}") logger.info(
f"tool_calling_workflow started with workflow_instance_id: {workflow_instance_id}"
)
logger.info(f"Current self.state at start of workflow: {self.state}") logger.info(f"Current self.state at start of workflow: {self.state}")
# Store the instance ID from workflow context as the source of truth # Store the instance ID from workflow context as the source of truth
@ -324,19 +366,29 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"workflow_instance_id": workflow_instance_id, "workflow_instance_id": workflow_instance_id,
"triggering_workflow_instance_id": triggering_workflow_instance_id, "triggering_workflow_instance_id": triggering_workflow_instance_id,
"workflow_name": self._workflow_name, "workflow_name": self._workflow_name,
"status": "running" # Mark as running "status": "running", # Mark as running
} }
self.state.setdefault("instances", {})[workflow_instance_id] = instance_entry self.state.setdefault("instances", {})[
workflow_instance_id
] = instance_entry
logger.info(f"Created new instance entry: {workflow_instance_id}") logger.info(f"Created new instance entry: {workflow_instance_id}")
print(f"DEBUG: Created new instance entry: {workflow_instance_id}") print(f"DEBUG: Created new instance entry: {workflow_instance_id}")
# Immediately save state so graceful shutdown can capture this instance # Immediately save state so graceful shutdown can capture this instance
self.save_state() self.save_state()
logger.info(f"Saved state immediately after creating instance entry for {workflow_instance_id}") logger.info(
print(f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}") f"Saved state immediately after creating instance entry for {workflow_instance_id}"
)
print(
f"DEBUG: Saved state immediately after creating instance entry for {workflow_instance_id}"
)
else: else:
logger.info(f"Found existing instance entry for workflow {workflow_instance_id}") logger.info(
print(f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}") f"Found existing instance entry for workflow {workflow_instance_id}"
)
print(
f"DEBUG: Found existing instance entry for workflow {workflow_instance_id}"
)
final_message: Optional[Dict[str, Any]] = None final_message: Optional[Dict[str, Any]] = None
if not ctx.is_replaying: if not ctx.is_replaying:
@ -371,7 +423,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
# Step 5: Add the assistant's response message to the chat history # Step 5: Add the assistant's response message to the chat history
yield ctx.call_activity( yield ctx.call_activity(
self.append_assistant_message, self.append_assistant_message,
input={"instance_id": workflow_instance_id, "message": response_message}, input={
"instance_id": workflow_instance_id,
"message": response_message,
},
) )
# Step 6: Handle tool calls response # Step 6: Handle tool calls response
@ -391,7 +446,10 @@ class DurableAgent(AgenticWorkflow, AgentBase):
for tr in tool_results: for tr in tool_results:
yield ctx.call_activity( yield ctx.call_activity(
self.append_tool_message, self.append_tool_message,
input={"instance_id": workflow_instance_id, "tool_result": tr}, input={
"instance_id": workflow_instance_id,
"tool_result": tr,
},
) )
# 🔴 If this was the last turn, stop here—even though there were tool calls # 🔴 If this was the last turn, stop here—even though there were tool calls
if turn == self.max_iterations: if turn == self.max_iterations:
@ -560,7 +618,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
""" """
# Construct messages using instance-specific chat history instead of global memory # Construct messages using instance-specific chat history instead of global memory
# This ensures proper message sequence for tool calls # This ensures proper message sequence for tool calls
messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(instance_id, task or {}) messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history(
instance_id, task or {}
)
user_message = self.get_last_message_if_user(messages) user_message = self.get_last_message_if_user(messages)
# Always work with a copy of the user message for safety # Always work with a copy of the user message for safety
@ -582,7 +642,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input", "source": "user_input",
"workflow_instance_id": instance_id, "workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None, "triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name "workflow_name": self._workflow_name,
} }
inst: dict = self.state["instances"][instance_id] inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json")) inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@ -623,13 +683,17 @@ class DurableAgent(AgenticWorkflow, AgentBase):
error_type = type(e).__name__ error_type = type(e).__name__
error_msg = str(e) error_msg = str(e)
logger.error(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}") logger.error(
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
)
logger.error(f"Task: {task}") logger.error(f"Task: {task}")
logger.error(f"Messages count: {len(messages)}") logger.error(f"Messages count: {len(messages)}")
logger.error(f"Tools available: {len(self.get_llm_tools())}") logger.error(f"Tools available: {len(self.get_llm_tools())}")
logger.error("Full error details:", exc_info=True) logger.error("Full error details:", exc_info=True)
raise AgentError(f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}") from e raise AgentError(
f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}"
) from e
@task @task
async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: async def run_tool(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
@ -733,7 +797,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input", "source": "user_input",
"workflow_instance_id": instance_id, "workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None, "triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name "workflow_name": self._workflow_name,
} }
inst: dict = self.state["instances"][instance_id] inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json")) inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@ -774,7 +838,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input", "source": "user_input",
"workflow_instance_id": instance_id, "workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None, "triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name "workflow_name": self._workflow_name,
} }
inst: dict = self.state["instances"][instance_id] inst: dict = self.state["instances"][instance_id]
inst.setdefault("messages", []).append(msg_object.model_dump(mode="json")) inst.setdefault("messages", []).append(msg_object.model_dump(mode="json"))
@ -808,7 +872,7 @@ class DurableAgent(AgenticWorkflow, AgentBase):
"source": "user_input", "source": "user_input",
"workflow_instance_id": instance_id, "workflow_instance_id": instance_id,
"triggering_workflow_instance_id": None, "triggering_workflow_instance_id": None,
"workflow_name": self._workflow_name "workflow_name": self._workflow_name,
} }
inst: dict = self.state["instances"][instance_id] inst: dict = self.state["instances"][instance_id]
inst["output"] = final_output inst["output"] = final_output
@ -870,7 +934,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
except Exception as e: except Exception as e:
logger.error(f"Error processing broadcast message: {e}", exc_info=True) logger.error(f"Error processing broadcast message: {e}", exc_info=True)
def _construct_messages_with_instance_history(self, instance_id: str, input_data: Union[str, Dict[str, Any]]) -> List[Dict[str, Any]]: def _construct_messages_with_instance_history(
self, instance_id: str, input_data: Union[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
""" """
Construct messages using instance-specific chat history instead of global memory. Construct messages using instance-specific chat history instead of global memory.
This ensures proper message sequence for tool calls and prevents OpenAI API errors This ensures proper message sequence for tool calls and prevents OpenAI API errors
@ -884,7 +950,9 @@ class DurableAgent(AgenticWorkflow, AgentBase):
List of formatted messages with proper sequence List of formatted messages with proper sequence
""" """
if not self.prompt_template: if not self.prompt_template:
raise ValueError("Prompt template must be initialized before constructing messages.") raise ValueError(
"Prompt template must be initialized before constructing messages."
)
# Get instance-specific chat history instead of global memory # Get instance-specific chat history instead of global memory
instance_data = self.state.get("instances", {}).get(instance_id, {}) instance_data = self.state.get("instances", {}).get(instance_id, {})
@ -897,10 +965,14 @@ class DurableAgent(AgenticWorkflow, AgentBase):
chat_history.append(msg) chat_history.append(msg)
else: else:
# Convert DurableAgentMessage to dict if needed # Convert DurableAgentMessage to dict if needed
chat_history.append(msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg)) chat_history.append(
msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
)
if isinstance(input_data, str): if isinstance(input_data, str):
formatted_messages = self.prompt_template.format_prompt(chat_history=chat_history) formatted_messages = self.prompt_template.format_prompt(
chat_history=chat_history
)
if isinstance(formatted_messages, list): if isinstance(formatted_messages, list):
user_message = {"role": "user", "content": input_data} user_message = {"role": "user", "content": input_data}
return formatted_messages + [user_message] return formatted_messages + [user_message]

View File

@ -230,4 +230,3 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
logger.error("Full error details:", exc_info=True) logger.error("Full error details:", exc_info=True)
raise ValueError(f"OpenAI API error ({error_type}): {error_msg}") from e raise ValueError(f"OpenAI API error ({error_type}): {error_msg}") from e

View File

@ -95,7 +95,9 @@ class WorkflowContextStorage:
logger.warning(f"⚠️ No context found for instance {instance_id}") logger.warning(f"⚠️ No context found for instance {instance_id}")
return context return context
def create_resumed_workflow_context(self, instance_id: str, agent_name: Optional[str] = None) -> Dict[str, Any]: def create_resumed_workflow_context(
self, instance_id: str, agent_name: Optional[str] = None
) -> Dict[str, Any]:
""" """
Create a new trace context for a resumed workflow after app restart. Create a new trace context for a resumed workflow after app restart.
@ -111,7 +113,9 @@ class WorkflowContextStorage:
""" """
try: try:
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
# Create a new trace for the resumed workflow with proper AGENT span # Create a new trace for the resumed workflow with proper AGENT span
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -122,6 +126,7 @@ class WorkflowContextStorage:
with tracer.start_as_current_span(span_name) as span: with tracer.start_as_current_span(span_name) as span:
# Set AGENT span attributes # Set AGENT span attributes
from .constants import OPENINFERENCE_SPAN_KIND from .constants import OPENINFERENCE_SPAN_KIND
span.set_attribute(OPENINFERENCE_SPAN_KIND, "AGENT") span.set_attribute(OPENINFERENCE_SPAN_KIND, "AGENT")
span.set_attribute("workflow.instance_id", instance_id) span.set_attribute("workflow.instance_id", instance_id)
span.set_attribute("workflow.resumed", True) span.set_attribute("workflow.resumed", True)
@ -136,23 +141,27 @@ class WorkflowContextStorage:
"tracestate": carrier.get("tracestate"), "tracestate": carrier.get("tracestate"),
"instance_id": instance_id, "instance_id": instance_id,
"resumed": True, "resumed": True,
"debug_info": f"New trace created for resumed workflow {instance_id}" "debug_info": f"New trace created for resumed workflow {instance_id}",
} }
# Store the new context # Store the new context
self.store_context(instance_id, context_data) self.store_context(instance_id, context_data)
logger.info(f"Created new trace context for resumed workflow {instance_id}") logger.info(
f"Created new trace context for resumed workflow {instance_id}"
)
return context_data return context_data
except Exception as e: except Exception as e:
logger.error(f"Failed to create resumed workflow context for {instance_id}: {e}") logger.error(
f"Failed to create resumed workflow context for {instance_id}: {e}"
)
return { return {
"traceparent": None, "traceparent": None,
"tracestate": None, "tracestate": None,
"instance_id": instance_id, "instance_id": instance_id,
"resumed": True, "resumed": True,
"error": str(e) "error": str(e),
} }
def cleanup_context(self, instance_id: str) -> None: def cleanup_context(self, instance_id: str) -> None:

View File

@ -428,15 +428,15 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
# Workflow monitoring wrapper - creates AGENT spans # Workflow monitoring wrapper - creates AGENT spans
wrapt.wrap_function_wrapper( wrapt.wrap_function_wrapper(
WorkflowApp.__module__, WorkflowApp.__module__,
f'{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}', f"{WorkflowApp.__name__}.{WorkflowApp.run_and_monitor_workflow_async.__name__}",
WorkflowMonitorWrapper(self._tracer) WorkflowMonitorWrapper(self._tracer),
) )
# Workflow run wrapper - creates workflow spans # Workflow run wrapper - creates workflow spans
wrapt.wrap_function_wrapper( wrapt.wrap_function_wrapper(
WorkflowApp.__module__, WorkflowApp.__module__,
f'{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}', f"{WorkflowApp.__name__}.{WorkflowApp.run_workflow.__name__}",
WorkflowRunWrapper(self._tracer) WorkflowRunWrapper(self._tracer),
) )
# WorkflowTask call wrapper # WorkflowTask call wrapper
@ -444,8 +444,8 @@ class DaprAgentsInstrumentor(BaseInstrumentor):
# and is necessary due to the async nature of the WorkflowTask.__call__ method. # and is necessary due to the async nature of the WorkflowTask.__call__ method.
wrapt.wrap_function_wrapper( wrapt.wrap_function_wrapper(
WorkflowTask.__module__, WorkflowTask.__module__,
f'{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}', f"{WorkflowTask.__name__}.{WorkflowTask.__call__.__name__}",
WorkflowTaskWrapper(self._tracer) WorkflowTaskWrapper(self._tracer),
) )
except Exception as e: except Exception as e:
logger.error(f"Error applying workflow wrappers: {e}", exc_info=True) logger.error(f"Error applying workflow wrappers: {e}", exc_info=True)

View File

@ -341,11 +341,17 @@ class WorkflowMonitorWrapper:
from ..context_storage import store_workflow_context from ..context_storage import store_workflow_context
captured_context = extract_otel_context() captured_context = extract_otel_context()
if captured_context.get("traceparent") and captured_context.get("trace_id") != "00000000000000000000000000000000": if (
captured_context.get("traceparent")
and captured_context.get("trace_id")
!= "00000000000000000000000000000000"
):
logger.debug( logger.debug(
f"Captured traceparent: {captured_context.get('traceparent')}" f"Captured traceparent: {captured_context.get('traceparent')}"
) )
store_workflow_context("__global_workflow_context__", captured_context) store_workflow_context(
"__global_workflow_context__", captured_context
)
else: else:
logger.debug( logger.debug(
f"Invalid or empty trace context captured: {captured_context}" f"Invalid or empty trace context captured: {captured_context}"

View File

@ -85,14 +85,14 @@ class WorkflowTaskWrapper:
# Determine task details # Determine task details
logger.debug(f"WorkflowTaskWrapper: instance type = {type(instance)}") logger.debug(f"WorkflowTaskWrapper: instance type = {type(instance)}")
logger.debug(f"WorkflowTaskWrapper: instance attributes = {dir(instance)}") logger.debug(f"WorkflowTaskWrapper: instance attributes = {dir(instance)}")
if hasattr(instance, 'func'): if hasattr(instance, "func"):
logger.debug(f"WorkflowTaskWrapper: instance.func = {instance.func}") logger.debug(f"WorkflowTaskWrapper: instance.func = {instance.func}")
else: else:
logger.debug("WorkflowTaskWrapper: instance has no 'func' attribute") logger.debug("WorkflowTaskWrapper: instance has no 'func' attribute")
task_name = ( task_name = (
getattr(instance.func, "__name__", "unknown_task") getattr(instance.func, "__name__", "unknown_task")
if hasattr(instance, 'func') and instance.func if hasattr(instance, "func") and instance.func
else "workflow_task" else "workflow_task"
) )
span_kind = self._determine_span_kind(instance, task_name) span_kind = self._determine_span_kind(instance, task_name)
@ -245,7 +245,13 @@ class WorkflowTaskWrapper:
return attributes return attributes
def _handle_async_execution( def _handle_async_execution(
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict self,
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
span_name: str,
attributes: dict,
) -> Any: ) -> Any:
""" """
Handle asynchronous workflow task execution with OpenTelemetry context restoration. Handle asynchronous workflow task execution with OpenTelemetry context restoration.
@ -279,29 +285,45 @@ class WorkflowTaskWrapper:
# If no context found for specific instance, try global context as fallback # If no context found for specific instance, try global context as fallback
if otel_context is None: if otel_context is None:
logger.debug(f"No context found for instance {instance_id}, trying global context") logger.debug(
f"No context found for instance {instance_id}, trying global context"
)
otel_context = get_workflow_context("__global_workflow_context__") otel_context = get_workflow_context("__global_workflow_context__")
if otel_context: if otel_context:
# Store the global context with the specific instance ID for future use # Store the global context with the specific instance ID for future use
from ..context_storage import store_workflow_context from ..context_storage import store_workflow_context
store_workflow_context(instance_id, otel_context) store_workflow_context(instance_id, otel_context)
logger.debug(f"Copied global context to instance {instance_id}") logger.debug(f"Copied global context to instance {instance_id}")
else: else:
# If still no context found (e.g., after app restart), create a new one for resumed workflows # If still no context found (e.g., after app restart), create a new one for resumed workflows
logger.debug(f"No context found for instance {instance_id} - creating new context for resumed workflow") logger.debug(
f"No context found for instance {instance_id} - creating new context for resumed workflow"
)
# Try to get agent name from the task instance # Try to get agent name from the task instance
agent_name = None agent_name = None
if hasattr(instance, 'agent') and instance.agent and hasattr(instance.agent, 'name'): if (
hasattr(instance, "agent")
and instance.agent
and hasattr(instance.agent, "name")
):
agent_name = instance.agent.name agent_name = instance.agent.name
elif hasattr(instance, 'func') and instance.func and hasattr(instance.func, '__self__'): elif (
hasattr(instance, "func")
and instance.func
and hasattr(instance.func, "__self__")
):
agent_instance = instance.func.__self__ agent_instance = instance.func.__self__
if hasattr(agent_instance, 'name'): if hasattr(agent_instance, "name"):
agent_name = agent_instance.name agent_name = agent_instance.name
from ..context_storage import _context_storage from ..context_storage import _context_storage
otel_context = _context_storage.create_resumed_workflow_context(instance_id, agent_name)
otel_context = _context_storage.create_resumed_workflow_context(
instance_id, agent_name
)
# Create span with restored context if available # Create span with restored context if available
from ..context_propagation import create_child_span_with_context from ..context_propagation import create_child_span_with_context
@ -348,7 +370,13 @@ class WorkflowTaskWrapper:
return async_wrapper() return async_wrapper()
def _handle_sync_execution( def _handle_sync_execution(
self, wrapped: Any, instance: Any, args: Any, kwargs: Any, span_name: str, attributes: dict self,
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
span_name: str,
attributes: dict,
) -> Any: ) -> Any:
""" """
Handle synchronous workflow task execution with OpenTelemetry context restoration. Handle synchronous workflow task execution with OpenTelemetry context restoration.

View File

@ -17,7 +17,6 @@ class ElevenLabsClientConfig(BaseModel):
) )
class NVIDIAClientConfig(BaseModel): class NVIDIAClientConfig(BaseModel):
base_url: Optional[str] = Field( base_url: Optional[str] = Field(
"https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA API" "https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA API"
@ -27,7 +26,6 @@ class NVIDIAClientConfig(BaseModel):
) )
class DaprInferenceClientConfig: class DaprInferenceClientConfig:
pass pass
@ -73,7 +71,6 @@ class HFInferenceClientConfig(BaseModel):
) )
class OpenAIClientConfig(BaseModel): class OpenAIClientConfig(BaseModel):
base_url: Optional[str] = Field(None, description="Base URL for the OpenAI API") base_url: Optional[str] = Field(None, description="Base URL for the OpenAI API")
api_key: Optional[str] = Field( api_key: Optional[str] = Field(
@ -112,7 +109,6 @@ class AzureOpenAIClientConfig(BaseModel):
) )
class OpenAIModelConfig(OpenAIClientConfig): class OpenAIModelConfig(OpenAIClientConfig):
type: Literal["openai"] = Field( type: Literal["openai"] = Field(
"openai", description="Type of the model, must always be 'openai'" "openai", description="Type of the model, must always be 'openai'"
@ -170,7 +166,6 @@ class OpenAIParamsBase(BaseModel):
stream: Optional[bool] = Field(False, description="Whether to stream responses") stream: Optional[bool] = Field(False, description="Whether to stream responses")
class OpenAITextCompletionParams(OpenAIParamsBase): class OpenAITextCompletionParams(OpenAIParamsBase):
""" """
Specific configs for the text completions endpoint. Specific configs for the text completions endpoint.
@ -186,7 +181,6 @@ class OpenAITextCompletionParams(OpenAIParamsBase):
suffix: Optional[str] = Field(None, description="Suffix to append to the prompt") suffix: Optional[str] = Field(None, description="Suffix to append to the prompt")
class OpenAIChatCompletionParams(OpenAIParamsBase): class OpenAIChatCompletionParams(OpenAIParamsBase):
""" """
Specific settings for the Chat Completion endpoint. Specific settings for the Chat Completion endpoint.
@ -222,7 +216,6 @@ class OpenAIChatCompletionParams(OpenAIParamsBase):
) )
class HFHubChatCompletionParams(BaseModel): class HFHubChatCompletionParams(BaseModel):
""" """
Specific settings for Hugging Face Hub Chat Completion endpoint. Specific settings for Hugging Face Hub Chat Completion endpoint.
@ -285,7 +278,6 @@ class HFHubChatCompletionParams(BaseModel):
) )
class NVIDIAChatCompletionParams(OpenAIParamsBase): class NVIDIAChatCompletionParams(OpenAIParamsBase):
""" """
Specific settings for the Chat Completion endpoint. Specific settings for the Chat Completion endpoint.
@ -311,7 +303,6 @@ class NVIDIAChatCompletionParams(OpenAIParamsBase):
) )
class PromptyModelConfig(BaseModel): class PromptyModelConfig(BaseModel):
api: Literal["chat", "completion"] = Field( api: Literal["chat", "completion"] = Field(
"chat", description="The API to use, either 'chat' or 'completion'" "chat", description="The API to use, either 'chat' or 'completion'"
@ -330,7 +321,6 @@ class PromptyModelConfig(BaseModel):
description="Determines if full response or just the first one is returned", description="Determines if full response or just the first one is returned",
) )
@model_validator(mode="before") @model_validator(mode="before")
def sync_model_name(cls, values: dict): def sync_model_name(cls, values: dict):
""" """
@ -405,7 +395,6 @@ class PromptyDefinition(BaseModel):
) )
class AudioSpeechRequest(BaseModel): class AudioSpeechRequest(BaseModel):
model: Optional[Literal["tts-1", "tts-1-hd"]] = Field( model: Optional[Literal["tts-1", "tts-1-hd"]] = Field(
"tts-1", description="TTS model to use. Defaults to 'tts-1'." "tts-1", description="TTS model to use. Defaults to 'tts-1'."

View File

@ -0,0 +1,121 @@
"""
Reusable signal handling mixin for graceful shutdown across different service types.
"""
import asyncio
import logging
from typing import Optional
from dapr_agents.utils import add_signal_handlers_cross_platform
logger = logging.getLogger(__name__)
class SignalHandlingMixin:
    """
    Mixin providing reusable signal handling for graceful shutdown.

    This mixin can be used by any class that needs to handle shutdown signals
    (SIGINT, SIGTERM) gracefully. It provides a consistent interface for:
    - Setting up signal handlers
    - Managing shutdown events
    - Triggering graceful shutdown logic

    Every method tolerates ``__init__`` not having run (the private attributes
    are looked up defensively), so the mixin also works on host classes whose
    construction bypasses this initializer.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments up the MRO and reset the shutdown bookkeeping."""
        super().__init__(*args, **kwargs)
        self._shutdown_event: Optional[asyncio.Event] = None
        self._signal_handlers_setup = False

    def setup_signal_handlers(self) -> None:
        """
        Set up signal handlers for graceful shutdown.

        This method should be called during initialization or startup
        to enable graceful shutdown handling. Repeated calls are no-ops once
        the handlers have been registered.
        """
        # Initialize the attribute if it doesn't exist
        if not hasattr(self, "_signal_handlers_setup"):
            self._signal_handlers_setup = False

        if self._signal_handlers_setup:
            logger.debug("Signal handlers already set up")
            return

        # Initialize shutdown event if it doesn't exist
        if getattr(self, "_shutdown_event", None) is None:
            self._shutdown_event = asyncio.Event()

        # Set up signal handlers
        # NOTE(review): asyncio.get_event_loop() is deprecated on newer Python
        # versions when no loop is running; kept because this method may be
        # called before the loop starts -- confirm against the call sites.
        loop = asyncio.get_event_loop()
        add_signal_handlers_cross_platform(loop, self._handle_shutdown_signal)

        self._signal_handlers_setup = True
        logger.debug("Signal handlers set up for graceful shutdown")

    def _handle_shutdown_signal(self, sig: int) -> None:
        """
        Internal signal handler that triggers graceful shutdown.

        Sets the shutdown event (when one exists) and schedules
        ``graceful_shutdown()`` on the running event loop. The mixin always
        provides ``graceful_shutdown()``, whose default implementation falls
        back to ``stop()``, so no further fallback is needed here.

        Args:
            sig: The received signal number

        Raises:
            RuntimeError: If there is no running event loop to schedule the
                shutdown task on (raised by ``asyncio.create_task``).
        """
        logger.debug(f"Shutdown signal {sig} received. Triggering graceful shutdown...")

        # Set the shutdown event
        shutdown_event = getattr(self, "_shutdown_event", None)
        if shutdown_event is not None:
            shutdown_event.set()

        # The event loop only keeps weak references to tasks; hold a strong
        # reference until the task finishes so the shutdown coroutine cannot be
        # garbage-collected mid-execution.
        pending = getattr(self, "_shutdown_tasks", None)
        if pending is None:
            pending = set()
            self._shutdown_tasks = pending

        task = asyncio.create_task(self.graceful_shutdown())
        pending.add(task)
        task.add_done_callback(pending.discard)

    async def graceful_shutdown(self) -> None:
        """
        Perform graceful shutdown operations.

        This method should be overridden by classes that use this mixin
        to implement their specific shutdown logic.
        Default implementation calls stop() if it exists, and logs a warning
        otherwise.
        """
        if hasattr(self, "stop"):
            await self.stop()
        else:
            logger.warning(
                "No stop() method found. Override graceful_shutdown() to implement shutdown logic."
            )

    def is_shutdown_requested(self) -> bool:
        """
        Check if a shutdown has been requested.

        Returns:
            bool: True if shutdown has been requested, False otherwise
                (including when the shutdown event was never created)
        """
        shutdown_event = getattr(self, "_shutdown_event", None)
        return shutdown_event is not None and shutdown_event.is_set()

    async def wait_for_shutdown(self, check_interval: float = 1.0) -> None:
        """
        Wait for a shutdown signal to be received.

        Polls the shutdown event, sleeping ``check_interval`` seconds between
        checks, and returns once the event is set.

        Args:
            check_interval: How often to check for shutdown (in seconds)

        Raises:
            RuntimeError: If the shutdown event has not been created yet,
                i.e. setup_signal_handlers() was never called.
        """
        shutdown_event = getattr(self, "_shutdown_event", None)
        if shutdown_event is None:
            raise RuntimeError(
                "Signal handlers not set up. Call setup_signal_handlers() first."
            )

        while not shutdown_event.is_set():
            await asyncio.sleep(check_interval)

View File

@ -556,7 +556,10 @@ class WorkflowApp(BaseModel):
f"Output: {json.dumps(state.serialized_output, indent=2)}" f"Output: {json.dumps(state.serialized_output, indent=2)}"
) )
elif workflow_status.upper() in (DaprWorkflowStatus.FAILED.value.upper(), "ABORTED"): elif workflow_status.upper() in (
DaprWorkflowStatus.FAILED.value.upper(),
"ABORTED",
):
# Ensure `failure_details` exists before accessing attributes # Ensure `failure_details` exists before accessing attributes
error_type = getattr(failure_details, "error_type", "Unknown") error_type = getattr(failure_details, "error_type", "Unknown")
message = getattr(failure_details, "message", "No message provided") message = getattr(failure_details, "message", "No message provided")
@ -611,7 +614,9 @@ class WorkflowApp(BaseModel):
# Off-load the potentially blocking run_workflow call to a thread. # Off-load the potentially blocking run_workflow call to a thread.
instance_id = await asyncio.to_thread(self.run_workflow, workflow, input) instance_id = await asyncio.to_thread(self.run_workflow, workflow, input)
logger.debug(f"Workflow '{workflow}' started with instance ID: {instance_id}") logger.debug(
f"Workflow '{workflow}' started with instance ID: {instance_id}"
)
# Await the asynchronous monitoring of the workflow state. # Await the asynchronous monitoring of the workflow state.
state = await self.monitor_workflow_state(instance_id) state = await self.monitor_workflow_state(instance_id)

View File

@ -138,29 +138,43 @@ class ServiceMixin(SignalHandlingMixin):
# Save state before shutting down to ensure persistence and agent durability to properly rerun after being stoped # Save state before shutting down to ensure persistence and agent durability to properly rerun after being stoped
try: try:
if hasattr(self, 'save_state') and hasattr(self, 'state'): if hasattr(self, "save_state") and hasattr(self, "state"):
# Graceful shutdown compensation: Save incomplete instance if it exists # Graceful shutdown compensation: Save incomplete instance if it exists
if hasattr(self, 'workflow_instance_id') and self.workflow_instance_id: if hasattr(self, "workflow_instance_id") and self.workflow_instance_id:
if self.workflow_instance_id not in self.state.get("instances", {}): if self.workflow_instance_id not in self.state.get("instances", {}):
# This instance was never saved, add it as incomplete # This instance was never saved, add it as incomplete
from datetime import datetime, timezone from datetime import datetime, timezone
incomplete_entry = { incomplete_entry = {
"messages": [], "messages": [],
"start_time": datetime.now(timezone.utc).isoformat(), "start_time": datetime.now(timezone.utc).isoformat(),
"source": "graceful_shutdown", "source": "graceful_shutdown",
"source_workflow_instance_id": None, "source_workflow_instance_id": None,
"workflow_name": getattr(self, '_workflow_name', 'Unknown'), "workflow_name": getattr(self, "_workflow_name", "Unknown"),
"dapr_status": DaprWorkflowStatus.PENDING, "dapr_status": DaprWorkflowStatus.PENDING,
"suspended_reason": "app_terminated" "suspended_reason": "app_terminated",
} }
self.state.setdefault("instances", {})[self.workflow_instance_id] = incomplete_entry self.state.setdefault("instances", {})[
logger.info(f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown") self.workflow_instance_id
] = incomplete_entry
logger.info(
f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown"
)
else: else:
# Mark existing instance as suspended due to app termination # Mark existing instance as suspended due to app termination
if "instances" in self.state and self.workflow_instance_id in self.state["instances"]: if (
self.state["instances"][self.workflow_instance_id]["dapr_status"] = DaprWorkflowStatus.SUSPENDED "instances" in self.state
self.state["instances"][self.workflow_instance_id]["suspended_reason"] = "app_terminated" and self.workflow_instance_id in self.state["instances"]
logger.info(f"Marked instance {self.workflow_instance_id} as suspended due to app termination") ):
self.state["instances"][self.workflow_instance_id][
"dapr_status"
] = DaprWorkflowStatus.SUSPENDED
self.state["instances"][self.workflow_instance_id][
"suspended_reason"
] = "app_terminated"
logger.info(
f"Marked instance {self.workflow_instance_id} as suspended due to app termination"
)
self.save_state() self.save_state()
logger.debug("Workflow state saved successfully.") logger.debug("Workflow state saved successfully.")