mirror of https://github.com/dapr/dapr-agents.git

Compare commits

27 Commits

Author | SHA1 | Date
---|---|---
 | c2eff2b971 |
 | c4b1f7c441 |
 | f87e27f450 |
 | 1e5275834d |
 | b7b4a9891e |
 | 2fd44b3ecc |
 | 2757aab5b6 |
 | d86a4c5a70 |
 | 83fc449e39 |
 | 94bf5d2a38 |
 | 8741289e7d |
 | 41faa4f5b7 |
 | 76ad962b69 |
 | 28ac198055 |
 | e27f5befb0 |
 | 889b7bf7ef |
 | 4dce1c0300 |
 | 53c1c9ffde |
 | 6f20c0d9a0 |
 | 6823cd633d |
 | a878e76ec1 |
 | 75274ac607 |
 | f129754486 |
 | c31e985d81 |
 | f9eb48c02c |
 | 6f0cfc8818 |
 | fd28b02935 |
@@ -0,0 +1,65 @@
name: Lint and Build

on:
  push:
    branches:
      - feature/*
      - feat/*
      - bugfix/*
      - hotfix/*
      - fix/*
  pull_request:
    branches:
      - main
      - feature/*
      - release-*
  workflow_dispatch:

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel tox
      - name: Run Autoformatter
        run: |
          tox -e ruff
          statusResult=$(git status -u --porcelain)
          if [ -z "$statusResult" ]
          then
            exit 0
          else
            echo "Source files are not formatted correctly. Run 'tox -e ruff' to autoformat."
            exit 1
          fi
      - name: Run Linter
        run: |
          tox -e flake8

  build:
    needs: lint
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python_ver: ["3.10", "3.11", "3.12", "3.13"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python_ver }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_ver }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel tox
      - name: Check Typing
        run: |
          tox -e type
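The autoformat gate above can presumably be reproduced locally by running `tox -e ruff` and then `git status -u --porcelain`: the job fails whenever the formatter modifies any tracked file. The `ruff`, `flake8`, and `type` tox environments are assumed to be defined in the repository's `tox.ini`, which this diff does not show.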
@@ -4,12 +4,12 @@ on:
    branches:
      - main
    paths:
-     - docs
+     - docs/**
  pull_request:
    branches:
      - main
    paths:
-     - docs
+     - docs/**
  workflow_dispatch:
permissions:
  contents: write

@@ -18,7 +18,7 @@ jobs:
    runs-on: ubuntu-latest
    name: Review changed files
    outputs:
-     docs_any_changed: NaN
+     docs_any_changed: ${{ steps.changed-files.outputs.docs_any_changed }}
    steps:
      - uses: actions/checkout@v4
      - name: Get changed files

@@ -42,10 +42,16 @@ jobs:
      - name: Remove plugins from mkdocs configuration
        run: |
          sed -i '/^plugins:/,/^[^ ]/d' mkdocs.yml
-     - name: Run MkDocs build
-       uses: Kjuly/mkdocs-page-builder@main
+     - name: Install Python dependencies
+       run: |
+         pip install mkdocs-material
+         pip install .[recommended,git,imaging]
+         pip install mkdocs-jupyter
+     - name: Validate build
+       run: mkdocs build

  deploy:
    if: github.ref == 'refs/heads/main'
    runs-on: ubuntu-latest
    needs: documentation_validation
    steps:
@@ -165,3 +165,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea
+
+.ruff_cache/
+
+quickstarts/05-multi-agent-workflow-dapr-workflows/services/**/*_state.json
@@ -1,86 +0,0 @@
# Code of Conduct

We are committed to fostering a welcoming, inclusive, and respectful environment for everyone involved in this project. This Code of Conduct outlines the expected behaviors within our community and the steps for reporting unacceptable actions. By participating, you agree to uphold these standards, helping to create a positive and collaborative space.

---

## Our Pledge

As members, contributors, and leaders of this community, we pledge to:

* Ensure participation in our project is free from harassment, discrimination, or exclusion.
* Treat everyone with respect and empathy, regardless of factors such as age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity or expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual orientation.
* Act in ways that contribute to a safe, welcoming, and supportive environment for all participants.

---

## Our Standards

We strive to create an environment where all members can thrive. Examples of positive behaviors include:

* Showing kindness, empathy, and consideration for others.
* Being respectful of differing opinions, experiences, and perspectives.
* Providing constructive feedback in a supportive manner.
* Taking responsibility for mistakes, apologizing when necessary, and learning from experiences.
* Prioritizing the success and well-being of the entire community over individual gains.

The following behaviors are considered unacceptable:

* Using sexualized language or imagery, or engaging in inappropriate sexual attention or advances.
* Making insulting, derogatory, or inflammatory comments, including trolling or personal attacks.
* Engaging in harassment, whether public or private.
* Publishing private or sensitive information about others without explicit consent.
* Engaging in behavior that disrupts discussions, events, or contributions in a negative way.
* Any conduct that could reasonably be deemed unprofessional or harmful to others.

---

## Scope

This Code of Conduct applies to all areas of interaction within the community, including but not limited to:

* Discussions on forums, repositories, or other official communication channels.
* Contributions made to the project, such as code, documentation, or issues.
* Public representation of the community, such as through official social media accounts or at events.

It also applies to actions outside these spaces if they negatively impact the health, safety, or inclusivity of the community.

---

## Enforcement Responsibilities

Community leaders are responsible for ensuring that this Code of Conduct is upheld. They may take appropriate and fair corrective actions in response to any behavior that violates these standards, including:

* Removing, editing, or rejecting comments, commits, issues, or other contributions not aligned with the Code of Conduct.
* Temporarily or permanently banning individuals for repeated or severe violations.

Leaders will always strive to communicate their decisions clearly and fairly.

---

## Reporting Issues

If you experience or witness unacceptable behavior, please report it to the project's owner [Roberto Rodriguez](https://www.linkedin.com/in/cyb3rward0g/). Your report will be handled with sensitivity, and we will respect your privacy and confidentiality while addressing the issue.

When reporting, please include:

* A description of the incident.
* When and where it occurred.
* Any additional context or supporting evidence, if available.

---

## Enforcement Process

We encourage resolving issues through dialogue when possible, but community leaders will intervene when necessary. Actions may include warnings, temporary bans, or permanent removal from the community, depending on the severity of the behavior.

---

## Attribution

This Code of Conduct is inspired by the [Contributor Covenant, version 2.0](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html) and has drawn inspiration from open source community guidelines by Microsoft, Mozilla, and others.

For further context on best practices for open source codes of conduct, see the [Contributor Covenant FAQ](https://www.contributor-covenant.org/faq).

---

Thank you for helping to create a positive environment! ❤️
README.md

@@ -1,5 +1,13 @@
# Dapr Agents: A Framework for Agentic AI Systems

+[PyPI version](https://pypi.org/project/dapr-agents/)
+[PyPI downloads](https://pypi.org/project/dapr-agents/)
+[Build status](https://github.com/dapr/dapr-agents/actions/workflows/build.yaml)
+[License](https://github.com/dapr/dapr-agents/blob/main/LICENSE)
+[Discord](http://bit.ly/dapr-discord)
+[YouTube](https://youtube.com/@daprdev)
+[X (Twitter)](https://twitter.com/daprdev)
+
Dapr Agents is a developer framework designed to build production-grade resilient AI agent systems that operate at scale. Built on top of the battle-tested Dapr project, it enables software developers to create AI agents that reason, act, and collaborate using Large Language Models (LLMs), while leveraging built-in observability and stateful workflow execution to guarantee agentic workflows complete successfully, no matter how complex.

@@ -60,8 +68,8 @@ As a part of **CNCF**, Dapr Agents is vendor-neutral, eliminating concerns about

Here are some of the major features we're working on for the current quarter:

-### Q2 2024
-- **MCP Support** - Integration with Anthropic's MCP platform ([#50](https://github.com/dapr/dapr-agents/issues/50))
+### Q2 2025
+- **MCP Support** - Integration with Anthropic's MCP platform ([#50](https://github.com/dapr/dapr-agents/issues/50) ✅ )
- **Agent Interaction Tracing** - Enhanced observability of agent interactions with LLMs and tools ([#79](https://github.com/dapr/dapr-agents/issues/79))
- **Streaming LLM Output** - Real-time streaming capabilities for LLM responses ([#80](https://github.com/dapr/dapr-agents/issues/80))
- **HTTP Endpoint Tools** - Support for using Dapr's HTTP endpoint capabilities for tool calling ([#81](https://github.com/dapr/dapr-agents/issues/81))
@@ -4,9 +4,10 @@ from datetime import datetime
import requests
import time


class WeatherForecast(AgentTool):
-    name: str = 'WeatherForecast'
-    description: str = 'A tool for retrieving the weather/temperature for a given city.'
+    name: str = "WeatherForecast"
+    description: str = "A tool for retrieving the weather/temperature for a given city."

    # Default user agent
    user_agent: str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15"

@@ -23,7 +24,9 @@ class WeatherForecast(AgentTool):
            f"No data found during {stage}. URL: {url}. Response: {response.text}"
        )

-    def _run(self, city: str, state: Optional[str] = None, country: Optional[str] = "usa") -> dict:
+    def _run(
+        self, city: str, state: Optional[str] = None, country: Optional[str] = "usa"
+    ) -> dict:
        """
        Retrieves weather data by first fetching geocode data for the city and then fetching weather data.

@@ -35,12 +38,12 @@ class WeatherForecast(AgentTool):
        Returns:
            dict: A dictionary containing the city, state, country, and current temperature.
        """
-        headers = {
-            "User-Agent": self.user_agent
-        }
+        headers = {"User-Agent": self.user_agent}

        # Construct the geocode URL, conditionally including the state if it's provided
-        geocode_url = f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        geocode_url = (
+            f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        )
        if state:
            geocode_url += f"&state={state}"
        geocode_url += "&limit=1&format=jsonv2"

@@ -81,7 +84,7 @@ class WeatherForecast(AgentTool):
                "state": state,
                "country": country,
                "temperature": today_forecast["temperature"],
-                "unit": "Fahrenheit"
+                "unit": "Fahrenheit",
            }

        else:

@@ -91,8 +94,12 @@ class WeatherForecast(AgentTool):
            self.handle_error(weather_response, met_no_url, "Met.no weather lookup")

            weather_data = weather_response.json()
-            temperature_unit = weather_data["properties"]["meta"]["units"]["air_temperature"]
-            today_forecast = weather_data["properties"]["timeseries"][0]["data"]["instant"]["details"]["air_temperature"]
+            temperature_unit = weather_data["properties"]["meta"]["units"][
+                "air_temperature"
+            ]
+            today_forecast = weather_data["properties"]["timeseries"][0]["data"][
+                "instant"
+            ]["details"]["air_temperature"]

            # Return the weather data along with the city, state, and country
            return {

@@ -100,12 +107,15 @@ class WeatherForecast(AgentTool):
                "state": state,
                "country": country,
                "temperature": today_forecast,
-                "unit": temperature_unit
+                "unit": temperature_unit,
            }


class HistoricalWeather(AgentTool):
-    name: str = 'HistoricalWeather'
-    description: str = 'A tool for retrieving historical weather data (temperature) for a given city.'
+    name: str = "HistoricalWeather"
+    description: str = (
+        "A tool for retrieving historical weather data (temperature) for a given city."
+    )

    # Default user agent
    user_agent: str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15"

@@ -122,7 +132,14 @@ class HistoricalWeather(AgentTool):
            f"No data found during {stage}. URL: {url}. Response: {response.text}"
        )

-    def _run(self, city: str, state: Optional[str] = None, country: Optional[str] = "usa", start_date: Optional[str] = None, end_date: Optional[str] = None) -> dict:
+    def _run(
+        self,
+        city: str,
+        state: Optional[str] = None,
+        country: Optional[str] = "usa",
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+    ) -> dict:
        """
        Retrieves historical weather data for the city by first fetching geocode data and then historical weather data.

@@ -136,20 +153,27 @@ class HistoricalWeather(AgentTool):
        Returns:
            dict: A dictionary containing the city, state, country, and historical temperature data.
        """
-        headers = {
-            "User-Agent": self.user_agent
-        }
+        headers = {"User-Agent": self.user_agent}

        # Validate dates
-        current_date = datetime.now().strftime('%Y-%m-%d')
+        current_date = datetime.now().strftime("%Y-%m-%d")
        if start_date >= current_date or end_date >= current_date:
-            raise ValueError("Both start_date and end_date must be earlier than the current date.")
+            raise ValueError(
+                "Both start_date and end_date must be earlier than the current date."
+            )

-        if (datetime.strptime(end_date, "%Y-%m-%d") - datetime.strptime(start_date, "%Y-%m-%d")).days > 30:
-            raise ValueError("The time span between start_date and end_date cannot exceed 30 days.")
+        if (
+            datetime.strptime(end_date, "%Y-%m-%d")
+            - datetime.strptime(start_date, "%Y-%m-%d")
+        ).days > 30:
+            raise ValueError(
+                "The time span between start_date and end_date cannot exceed 30 days."
+            )

        # Construct the geocode URL, conditionally including the state if it's provided
-        geocode_url = f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        geocode_url = (
+            f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        )
        if state:
            geocode_url += f"&state={state}"
        geocode_url += "&limit=1&format=jsonv2"

@@ -167,7 +191,9 @@ class HistoricalWeather(AgentTool):
        # Historical weather request
        historical_weather_url = f"https://archive-api.open-meteo.com/v1/archive?latitude={lat}&longitude={lon}&start_date={start_date}&end_date={end_date}&hourly=temperature_2m"
        weather_response = requests.get(historical_weather_url, headers=headers)
-        self.handle_error(weather_response, historical_weather_url, "historical weather lookup")
+        self.handle_error(
+            weather_response, historical_weather_url, "historical weather lookup"
+        )

        weather_data = weather_response.json()

@@ -177,7 +203,9 @@ class HistoricalWeather(AgentTool):
        temperature_unit = weather_data["hourly_units"]["temperature_2m"]

        # Combine timestamps and temperatures into a dictionary
-        temperature_data = {timestamps[i]: temperatures[i] for i in range(len(timestamps))}
+        temperature_data = {
+            timestamps[i]: temperatures[i] for i in range(len(timestamps))
+        }

        # Return the structured weather data along with the city, state, country
        return {

@@ -187,5 +215,5 @@ class HistoricalWeather(AgentTool):
            "start_date": start_date,
            "end_date": end_date,
            "temperature_data": temperature_data,
-            "unit": temperature_unit
+            "unit": temperature_unit,
        }
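As a usage sketch (not part of the diff): the `_run` signatures above suggest the tools return a plain dict, though whether `AgentTool` subclasses are meant to be invoked directly rather than through an agent is an assumption here.

    # Hypothetical usage sketch; WeatherForecast is the class from the diff above,
    # but calling _run directly (instead of via an agent) is an assumption.
    tool = WeatherForecast()
    result = tool._run(city="Seattle", state="WA")
    print(result)  # e.g. {"city": "Seattle", ..., "temperature": ..., "unit": "Fahrenheit"}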
@@ -0,0 +1,501 @@
# Executor: LocalCodeExecutor Basic Examples

This notebook shows how to execute Python and shell snippets in **isolated, cached virtual environments**.

## Install Required Libraries

Before starting, ensure the required libraries are installed:

    !pip install dapr-agents

## Setup

    import logging

    from dapr_agents.executors.local import LocalCodeExecutor
    from dapr_agents.types.executor import CodeSnippet, ExecutionRequest
    from rich.console import Console
    from rich.ansi import AnsiDecoder
    import shutil

    logging.basicConfig(level=logging.INFO)

    executor = LocalCodeExecutor()
    console = Console()
    decoder = AnsiDecoder()

## Running a basic Python Code Snippet

    code = """
    from rich import print
    print("[bold green]Hello executor![/bold green]")
    """

    request = ExecutionRequest(snippets=[
        CodeSnippet(language='python', code=code, timeout=10)
    ])

    results = await executor.execute(request)
    results[0]  # raw result

    # pretty-print with Rich
    console.print(*decoder.decode(results[0].output))

Output:

    INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt
    INFO:dapr_agents.executors.local:Created a new virtual environment
    INFO:dapr_agents.executors.local:Installing print, rich
    INFO:dapr_agents.executors.local:Snippet 1 finished in 2.442s
    Hello executor!

## Run a Shell Snippet

    shell_request = ExecutionRequest(snippets=[
        CodeSnippet(language='sh', code='echo $((2+2))', timeout=5)
    ])

    await executor.execute(shell_request)

Output:

    INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt
    INFO:dapr_agents.executors.local:Snippet 1 finished in 0.019s
    [ExecutionResult(status='success', output='4\n', exit_code=0)]

## Reuse the cached virtual environment

    # Re-running the same Python request will reuse the cached venv, so it is faster
    await executor.execute(request)

Output:

    INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt
    INFO:dapr_agents.executors.local:Reusing cached virtual environment.
    INFO:dapr_agents.executors.local:Installing print, rich
    INFO:dapr_agents.executors.local:Snippet 1 finished in 0.297s
    [ExecutionResult(status='success', output='\x1b[1;32mHello executor!\x1b[0m\n', exit_code=0)]

## Inject Helper Functions

    def fancy_sum(a: int, b: int) -> int:
        return a + b

    executor.user_functions.append(fancy_sum)

    helper_request = ExecutionRequest(snippets=[
        CodeSnippet(language='python', code='print(fancy_sum(40, 2))', timeout=5)
    ])

    await executor.execute(helper_request)

Output:

    INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt
    INFO:dapr_agents.executors.local:Created a new virtual environment
    INFO:dapr_agents.executors.local:Snippet 1 finished in 1.408s
    [ExecutionResult(status='success', output='42\n', exit_code=0)]

## Clean Up

    shutil.rmtree(executor.cache_dir, ignore_errors=True)
    print("Cache directory removed ✅")

## Package-manager detection & automatic bootstrap

    from dapr_agents.executors.utils import package_manager as pm
    import pathlib, tempfile

### Create a throw-away project

    tmp_proj = pathlib.Path(tempfile.mkdtemp())
    (tmp_proj / "requirements.txt").write_text("rich==13.7.0\n")
    print("tmp project:", tmp_proj)

Output:

    tmp project: /var/folders/9z/8xhqw8x1611fcbhzl339yrs40000gn/T/tmpmssk0m2b

### Show what the helper detects

    print("detect_package_managers ->",
          [m.name for m in pm.detect_package_managers(tmp_proj)])
    print("get_install_command ->",
          pm.get_install_command(tmp_proj))

Output:

    detect_package_managers -> [<PackageManagerType.PIP: 'pip'>]
    get_install_command -> pip install -r requirements.txt

### Point the executor at that directory

    import os
    from contextlib import contextmanager, ExitStack

    @contextmanager
    def chdir(path):
        """
        Temporarily change the process CWD to *path*.

        Works on every CPython ≥ 3.6 (and PyPy) and restores the old directory
        even if an exception is raised inside the block.
        """
        old_cwd = os.getcwd()
        os.chdir(path)
        try:
            yield
        finally:
            os.chdir(old_cwd)

    with ExitStack() as stack:
        # keep a directory handle open (optional but handy if you'll delete tmp_proj later)
        stack.enter_context(os.scandir(tmp_proj))

        # <-- our portable replacement for contextlib.chdir()
        stack.enter_context(chdir(tmp_proj))

        # run a trivial snippet; executor will bootstrap because it now "sees"
        # requirements.txt in the current working directory
        out = await executor.execute(
            ExecutionRequest(snippets=[
                CodeSnippet(language="python", code="print('bootstrap OK')", timeout=5)
            ])
        )
        console.print(out[0].output)

Output:

    INFO:dapr_agents.executors.local:bootstrapping python project with 'pip install -r requirements.txt'
    INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt
    INFO:dapr_agents.executors.local:Created a new virtual environment
    INFO:dapr_agents.executors.local:Snippet 1 finished in 1.433s
    bootstrap OK

### Clean Up the throw-away project

    shutil.rmtree(executor.cache_dir, ignore_errors=True)
    print("Cache directory removed ✅")
    shutil.rmtree(tmp_proj, ignore_errors=True)
    print("temporary project removed ✅")
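One side note on the `chdir` helper defined near the end of the notebook: since Python 3.11 the standard library ships an equivalent context manager, so on newer interpreters the hand-rolled version could be replaced with:

    # Python 3.11+ only; the notebook's helper exists to support older versions.
    from contextlib import chdir

    with chdir(tmp_proj):
        ...  # cwd is tmp_proj inside the block, restored afterwards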
@@ -12,11 +12,11 @@ from tools import mcp
# Logging Configuration
# ─────────────────────────────────────────────
logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s"
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("mcp-server")


# ─────────────────────────────────────────────
# Starlette App Factory
# ─────────────────────────────────────────────

@@ -29,27 +29,44 @@ def create_starlette_app():

    async def handle_sse(request: Request) -> None:
        logger.info("🔌 SSE connection established")
-        async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
+        async with sse.connect_sse(request.scope, request.receive, request._send) as (
+            read_stream,
+            write_stream,
+        ):
            logger.debug("Starting MCP server run loop over SSE")
-            await mcp._mcp_server.run(read_stream, write_stream, mcp._mcp_server.create_initialization_options())
+            await mcp._mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp._mcp_server.create_initialization_options(),
+            )
            logger.debug("MCP run loop completed")

    return Starlette(
        debug=False,
        routes=[
            Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message)
-        ]
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
    )


# ─────────────────────────────────────────────
# CLI Entrypoint
# ─────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser(description="Run an MCP tool server.")
-    parser.add_argument("--server_type", choices=["stdio", "sse"], default="stdio", help="Transport to use")
-    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to (SSE only)")
-    parser.add_argument("--port", type=int, default=8000, help="Port to bind to (SSE only)")
+    parser.add_argument(
+        "--server_type",
+        choices=["stdio", "sse"],
+        default="stdio",
+        help="Transport to use",
+    )
+    parser.add_argument(
+        "--host", default="127.0.0.1", help="Host to bind to (SSE only)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to bind to (SSE only)"
+    )
    args = parser.parse_args()

    logger.info(f"🚀 Starting MCP server in {args.server_type.upper()} mode")

@@ -61,5 +78,6 @@ def main():
        logger.info(f"🌐 Running SSE server on {args.host}:{args.port}")
        uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
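Given the argparse options in this hunk, the server can presumably be started in SSE mode with something like `python server.py --server_type sse --host 127.0.0.1 --port 8000` (the script filename is an assumption; it is not shown in the capture); with the default `--server_type stdio`, no host or port is needed.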
@@ -3,12 +3,14 @@ import random
mcp = FastMCP("TestServer")


@mcp.tool()
async def get_weather(location: str) -> str:
    """Get weather information for a specific location."""
    temperature = random.randint(60, 80)
    return f"{location}: {temperature}F."


@mcp.tool()
async def jump(distance: str) -> str:
    """Simulate a jump of a given distance."""
@@ -5,6 +5,7 @@ from dotenv import load_dotenv
from dapr_agents import AssistantAgent
from dapr_agents.tool.mcp import MCPClient


async def main():
    try:
        # Load MCP tools from server (stdio or sse)

@@ -38,6 +39,7 @@ async def main():
    except Exception as e:
        logging.exception("Error starting weather agent service", exc_info=e)


if __name__ == "__main__":
    load_dotenv()
    logging.basicConfig(level=logging.INFO)
@@ -23,7 +23,7 @@ if __name__ == "__main__":
        print(f"Request failed: {e}")

        attempt += 1
-        print(f"Waiting 5s seconds before next health checkattempt...")
+        print("Waiting 5 seconds before next health check attempt...")
        time.sleep(5)

    if not healthy:

@@ -48,10 +48,10 @@ if __name__ == "__main__":
        print(f"Request failed: {e}")

        attempt += 1
-        print(f"Waiting 1s seconds before next attempt...")
+        print("Waiting 1 second before next attempt...")
        time.sleep(1)

-    print(f"Maximum attempts (10) reached without success.")
+    print("Maximum attempts (10) reached without success.")

    print("Failed to get successful response")
    sys.exit(1)
@@ -12,11 +12,11 @@ from tools import mcp
# Logging Configuration
# ─────────────────────────────────────────────
logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s"
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("mcp-server")


# ─────────────────────────────────────────────
# Starlette App Factory
# ─────────────────────────────────────────────

@@ -29,27 +29,44 @@ def create_starlette_app():

    async def handle_sse(request: Request) -> None:
        logger.info("🔌 SSE connection established")
-        async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
+        async with sse.connect_sse(request.scope, request.receive, request._send) as (
+            read_stream,
+            write_stream,
+        ):
            logger.debug("Starting MCP server run loop over SSE")
-            await mcp._mcp_server.run(read_stream, write_stream, mcp._mcp_server.create_initialization_options())
+            await mcp._mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp._mcp_server.create_initialization_options(),
+            )
            logger.debug("MCP run loop completed")

    return Starlette(
        debug=False,
        routes=[
            Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message)
-        ]
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
    )


# ─────────────────────────────────────────────
# CLI Entrypoint
# ─────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser(description="Run an MCP tool server.")
-    parser.add_argument("--server_type", choices=["stdio", "sse"], default="stdio", help="Transport to use")
-    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to (SSE only)")
-    parser.add_argument("--port", type=int, default=8000, help="Port to bind to (SSE only)")
+    parser.add_argument(
+        "--server_type",
+        choices=["stdio", "sse"],
+        default="stdio",
+        help="Transport to use",
+    )
+    parser.add_argument(
+        "--host", default="127.0.0.1", help="Host to bind to (SSE only)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to bind to (SSE only)"
+    )
    args = parser.parse_args()

    logger.info(f"🚀 Starting MCP server in {args.server_type.upper()} mode")

@@ -61,5 +78,6 @@ def main():
        logger.info(f"🌐 Running SSE server on {args.host}:{args.port}")
        uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
@@ -3,12 +3,14 @@ import random
mcp = FastMCP("TestServer")


@mcp.tool()
async def get_weather(location: str) -> str:
    """Get weather information for a specific location."""
    temperature = random.randint(60, 80)
    return f"{location}: {temperature}F."


@mcp.tool()
async def jump(distance: str) -> str:
    """Simulate a jump of a given distance."""
@@ -3,39 +3,46 @@ import dapr.ext.workflow as wf
wfr = wf.WorkflowRuntime()

-@wfr.workflow(name='random_workflow')
+
+@wfr.workflow(name="random_workflow")
def task_chain_workflow(ctx: wf.DaprWorkflowContext, x: int):
    result1 = yield ctx.call_activity(step1, input=x)
    result2 = yield ctx.call_activity(step2, input=result1)
    result3 = yield ctx.call_activity(step3, input=result2)
    return [result1, result2, result3]


@wfr.activity
def step1(ctx, activity_input):
-    print(f'Step 1: Received input: {activity_input}.')
+    print(f"Step 1: Received input: {activity_input}.")
    # Do some work
    return activity_input + 1


@wfr.activity
def step2(ctx, activity_input):
-    print(f'Step 2: Received input: {activity_input}.')
+    print(f"Step 2: Received input: {activity_input}.")
    # Do some work
    return activity_input * 2


@wfr.activity
def step3(ctx, activity_input):
-    print(f'Step 3: Received input: {activity_input}.')
+    print(f"Step 3: Received input: {activity_input}.")
    # Do some work
    return activity_input ^ 2

-if __name__ == '__main__':
+
+if __name__ == "__main__":
    wfr.start()
    sleep(5)  # wait for workflow runtime to start

    wf_client = wf.DaprWorkflowClient()
-    instance_id = wf_client.schedule_new_workflow(workflow=task_chain_workflow, input=10)
-    print(f'Workflow started. Instance ID: {instance_id}')
+    instance_id = wf_client.schedule_new_workflow(
+        workflow=task_chain_workflow, input=10
+    )
+    print(f"Workflow started. Instance ID: {instance_id}")
    state = wf_client.wait_for_workflow_completion(instance_id)
-    print(f'Workflow completed! Status: {state.runtime_status}')
+    print(f"Workflow completed! Status: {state.runtime_status}")

    wfr.shutdown()
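Note that `^` in `step3` is Python's bitwise XOR, not exponentiation (`**`). With the scheduled input of 10, this chain therefore yields step1: 10 + 1 = 11, step2: 11 * 2 = 22, step3: 22 ^ 2 = 20 (0b10110 XOR 0b00010 = 0b10100), so the workflow returns `[11, 22, 20]`. The same arithmetic applies to the `task_chain_workflow` variants in the following hunks.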
@@ -1,34 +1,39 @@
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext

-@workflow(name='random_workflow')
+
+@workflow(name="random_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
    result1 = yield ctx.call_activity(step1, input=input)
    result2 = yield ctx.call_activity(step2, input=result1)
    result3 = yield ctx.call_activity(step3, input=result2)
    return [result1, result2, result3]


@task
def step1(activity_input):
-    print(f'Step 1: Received input: {activity_input}.')
+    print(f"Step 1: Received input: {activity_input}.")
    # Do some work
    return activity_input + 1


@task
def step2(activity_input):
-    print(f'Step 2: Received input: {activity_input}.')
+    print(f"Step 2: Received input: {activity_input}.")
    # Do some work
    return activity_input * 2


@task
def step3(activity_input):
-    print(f'Step 3: Received input: {activity_input}.')
+    print(f"Step 3: Received input: {activity_input}.")
    # Do some work
    return activity_input ^ 2

-if __name__ == '__main__':
+
+if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    wfapp = WorkflowApp()
@@ -2,7 +2,8 @@ import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext


@workflow(name="random_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext, input: int):

@@ -11,30 +12,32 @@ def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
    result3 = yield ctx.call_activity(step3, input=result2)
    return [result1, result2, result3]


@task
def step1(activity_input: int) -> int:
    print(f"Step 1: Received input: {activity_input}.")
    return activity_input + 1


@task
def step2(activity_input: int) -> int:
    print(f"Step 2: Received input: {activity_input}.")
    return activity_input * 2


@task
def step3(activity_input: int) -> int:
    print(f"Step 3: Received input: {activity_input}.")
    return activity_input ^ 2


async def main():
    logging.basicConfig(level=logging.INFO)
    wfapp = WorkflowApp()

-    result = await wfapp.run_and_monitor_workflow_async(
-        task_chain_workflow,
-        input=10
-    )
+    result = await wfapp.run_and_monitor_workflow_async(task_chain_workflow, input=10)
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,27 +1,35 @@
from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext
from dotenv import load_dotenv
import logging


# Define Workflow logic
-@workflow(name='lotr_workflow')
+@workflow(name="lotr_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext):
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input={"character": result1})
    return result2

-@task(description="""
+
+@task(
+    description="""
    Pick a random character from The Lord of the Rings\n
    and respond with the character's name ONLY
-""")
+"""
+)
def get_character() -> str:
    pass

-@task(description="What is a famous line by {character}",)
+
+@task(
+    description="What is a famous line by {character}",
+)
def get_line(character: str) -> str:
    pass

-if __name__ == '__main__':
+
+if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Load environment variables
@@ -2,27 +2,35 @@ import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext
from dotenv import load_dotenv


# Define Workflow logic
-@workflow(name='lotr_workflow')
+@workflow(name="lotr_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext):
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input={"character": result1})
    return result2

-@task(description="""
+
+@task(
+    description="""
    Pick a random character from The Lord of the Rings\n
    and respond with the character's name ONLY
-""")
+"""
+)
def get_character() -> str:
    pass

-@task(description="What is a famous line by {character}",)
+
+@task(
+    description="What is a famous line by {character}",
+)
def get_line(character: str) -> str:
    pass


async def main():
    logging.basicConfig(level=logging.INFO)

@@ -36,5 +44,6 @@ async def main():
    result = await wfapp.run_and_monitor_workflow_async(task_chain_workflow)
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,33 +1,34 @@
import logging
from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext
from pydantic import BaseModel
from dotenv import load_dotenv


@workflow
def question(ctx: DaprWorkflowContext, input: int):
    step1 = yield ctx.call_activity(ask, input=input)
    return step1


class Dog(BaseModel):
    name: str
    bio: str
    breed: str


@task("Who was {name}?")
def ask(name: str) -> Dog:
    pass

-if __name__ == '__main__':
+
+if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    load_dotenv()

    wfapp = WorkflowApp()

-    results = wfapp.run_and_monitor_workflow_sync(
-        workflow=question,
-        input="Scooby Doo"
-    )
+    results = wfapp.run_and_monitor_workflow_sync(workflow=question, input="Scooby Doo")

    print(results)
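The body of `ask` is intentionally empty: a `@task` whose description is a prompt template delegates execution to an LLM, and the `-> Dog` annotation suggests the response is parsed into the Pydantic model, so the printed result should resemble (illustrative values only):

    # Hypothetical output shape; the field values are invented for illustration.
    Dog(name="Scooby Doo", bio="A cowardly but lovable talking dog ...", breed="Great Dane")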
@@ -2,24 +2,28 @@ import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext
from pydantic import BaseModel
from dotenv import load_dotenv


@workflow
def question(ctx: DaprWorkflowContext, input: int):
    step1 = yield ctx.call_activity(ask, input=input)
    return step1


class Dog(BaseModel):
    name: str
    bio: str
    breed: str


@task("Who was {name}?")
def ask(name: str) -> Dog:
    pass


async def main():
    logging.basicConfig(level=logging.INFO)

@@ -31,10 +35,10 @@ async def main():

    # Run workflow
    result = await wfapp.run_and_monitor_workflow_async(
-        workflow=question,
-        input="Scooby Doo"
+        workflow=question, input="Scooby Doo"
    )
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -9,55 +9,54 @@ load_dotenv()
# Initialize Workflow Instance
wfr = wf.WorkflowRuntime()


# Define Workflow logic
-@wfr.workflow(name='lotr_workflow')
+@wfr.workflow(name="lotr_workflow")
def task_chain_workflow(ctx: wf.DaprWorkflowContext):
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input=result1)
    return result2


# Activity 1
-@wfr.activity(name='step1')
+@wfr.activity(name="step1")
def get_character(ctx):
    client = OpenAI()
    response = client.chat.completions.create(
        messages=[
            {
                "role": "user",
-                "content": "Pick a random character from The Lord of the Rings and respond with the character name only"
+                "content": "Pick a random character from The Lord of the Rings and respond with the character name only",
            }
        ],
-        model = 'gpt-4o'
+        model="gpt-4o",
    )
    character = response.choices[0].message.content
    print(f"Character: {character}")
    return character


# Activity 2
-@wfr.activity(name='step2')
+@wfr.activity(name="step2")
def get_line(ctx, character: str):
    client = OpenAI()
    response = client.chat.completions.create(
-        messages = [
-            {
-                "role": "user",
-                "content": f"What is a famous line by {character}"
-            }
-        ],
-        model = 'gpt-4o'
+        messages=[{"role": "user", "content": f"What is a famous line by {character}"}],
+        model="gpt-4o",
    )
    line = response.choices[0].message.content
    print(f"Line: {line}")
    return line

-if __name__ == '__main__':
+
+if __name__ == "__main__":
    wfr.start()
    sleep(5)  # wait for workflow runtime to start

    wf_client = wf.DaprWorkflowClient()
    instance_id = wf_client.schedule_new_workflow(workflow=task_chain_workflow)
-    print(f'Workflow started. Instance ID: {instance_id}')
+    print(f"Workflow started. Instance ID: {instance_id}")
    state = wf_client.wait_for_workflow_completion(instance_id)
-    print(f'Workflow completed! Status: {state.runtime_status}')
+    print(f"Workflow completed! Status: {state.runtime_status}")

    wfr.shutdown()
@@ -1,5 +1,5 @@
from dapr_agents.document.reader.pdf.pypdf import PyPDFReader
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents import WorkflowApp
from urllib.parse import urlparse, unquote
from dotenv import load_dotenv

@@ -22,16 +22,19 @@ load_dotenv()
# Initialize the WorkflowApp
wfapp = WorkflowApp()


# Define structured output models
class SpeakerEntry(BaseModel):
    name: str
    text: str


class PodcastDialogue(BaseModel):
    participants: List[SpeakerEntry]


# Define Workflow logic
-@wfapp.workflow(name='doc2podcast')
+@wfapp.workflow(name="doc2podcast")
def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
    # Extract pre-validated input
    podcast_name = input["podcast_name"]

@@ -44,10 +47,13 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
    audio_model = input["audio_model"]

    # Step 1: Assign voices to the team
-    team_config = yield ctx.call_activity(assign_podcast_voices, input={
-        "host_config": host_config,
-        "participant_configs": participant_configs,
-    })
+    team_config = yield ctx.call_activity(
+        assign_podcast_voices,
+        input={
+            "host_config": host_config,
+            "participant_configs": participant_configs,
+        },
+    )

    # Step 2: Read PDF and get documents
    file_path = yield ctx.call_activity(download_pdf, input=file_input)

@@ -67,7 +73,9 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
            "context": accumulated_context,
            "participants": [p["name"] for p in team_config["participants"]],
        }
-        generated_prompt = yield ctx.call_activity(generate_prompt, input=document_with_context)
+        generated_prompt = yield ctx.call_activity(
+            generate_prompt, input=document_with_context
+        )

        # Use the prompt to generate the structured dialogue
        prompt_parameters = {

@@ -76,7 +84,9 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
            "prompt": generated_prompt,
            "max_rounds": max_rounds,
        }
-        dialogue_entry = yield ctx.call_activity(generate_transcript, input=prompt_parameters)
+        dialogue_entry = yield ctx.call_activity(
+            generate_transcript, input=prompt_parameters
+        )

        # Update context and transcript parts
        conversations = dialogue_entry["participants"]

@@ -85,18 +95,30 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
            transcript_parts.append(participant)

    # Step 4: Write the final transcript to a file
-    yield ctx.call_activity(write_transcript_to_file, input={"podcast_dialogue": transcript_parts, "output_path": output_transcript_path})
+    yield ctx.call_activity(
+        write_transcript_to_file,
+        input={
+            "podcast_dialogue": transcript_parts,
+            "output_path": output_transcript_path,
+        },
+    )

    # Step 5: Convert transcript to audio using team_config
-    yield ctx.call_activity(convert_transcript_to_audio, input={
-        "transcript_parts": transcript_parts,
-        "output_path": output_audio_path,
-        "voices": team_config,
-        "model": audio_model,
-    })
+    yield ctx.call_activity(
+        convert_transcript_to_audio,
+        input={
+            "transcript_parts": transcript_parts,
+            "output_path": output_audio_path,
+            "voices": team_config,
+            "model": audio_model,
+        },
+    )


@wfapp.task
-def assign_podcast_voices(host_config: Dict[str, Any], participant_configs: List[Dict[str, Any]]) -> Dict[str, Any]:
+def assign_podcast_voices(
+    host_config: Dict[str, Any], participant_configs: List[Dict[str, Any]]
+) -> Dict[str, Any]:
    """
    Assign voices to the podcast host and participants.

@@ -112,7 +134,9 @@ def assign_podcast_voices(host_config: Dict[str, Any], participant_configs: List

    # Assign voice to the host if not already specified
    if "voice" not in host_config:
-        host_config["voice"] = next(voice for voice in allowed_voices if voice not in assigned_voices)
+        host_config["voice"] = next(
+            voice for voice in allowed_voices if voice not in assigned_voices
+        )
    assigned_voices.add(host_config["voice"])

    # Assign voices to participants, ensuring no duplicates

@@ -131,6 +155,7 @@ def assign_podcast_voices(host_config: Dict[str, Any], participant_configs: List
        "participants": updated_participants,
    }


@wfapp.task
def download_pdf(pdf_url: str, local_directory: str = ".") -> str:
    """

@@ -163,6 +188,7 @@ def download_pdf(pdf_url: str, local_directory: str = ".") -> str:
        logger.error(f"Error downloading PDF: {e}")
        raise


@wfapp.task
def read_pdf(file_path: str) -> List[dict]:
    """

@@ -176,8 +202,15 @@ def read_pdf(file_path: str) -> List[dict]:
        logger.error(f"Error reading document: {e}")
        raise


@wfapp.task
-def generate_prompt(text: str, iteration_index: int, total_iterations: int, context: str, participants: List[str]) -> str:
+def generate_prompt(
+    text: str,
+    iteration_index: int,
+    total_iterations: int,
+    context: str,
+    participants: List[str],
+) -> str:
    """
    Generate a prompt dynamically for the chunk.
    """

@@ -189,7 +222,7 @@ def generate_prompt(text: str, iteration_index: int, total_iterations: int, cont
    """

    if participants:
-        participant_names = ', '.join(participants)
+        participant_names = ", ".join(participants)
        instructions += f"\nPARTICIPANTS: {participant_names}"
    else:
        instructions += "\nPARTICIPANTS: None (Host-only conversation)"

@@ -214,7 +247,7 @@ def generate_prompt(text: str, iteration_index: int, total_iterations: int, cont
    - Follow up on the previous discussion points and introduce the next topic naturally.
    """

-    instructions += f"""
+    instructions += """
    TASK:
    - Use the provided TEXT to guide this part of the conversation.
    - Alternate between speakers, ensuring a natural conversational flow.

@@ -222,7 +255,9 @@ def generate_prompt(text: str, iteration_index: int, total_iterations: int, cont
    """
    return f"{instructions}\nTEXT:\n{text.strip()}"

-@wfapp.task("""
+
+@wfapp.task(
+    """
    Generate a structured podcast dialogue based on the context and text provided.
    The podcast is titled '{podcast_name}' and is hosted by {host_name}.
    If participants are available, each speaker is limited to a maximum of {max_rounds} turns per iteration.

@@ -231,26 +266,39 @@ def generate_prompt(text: str, iteration_index: int, total_iterations: int, cont
    If participants are not available, the host drives the conversation alone.
    Keep the dialogue concise and ensure a natural conversational flow.
    {prompt}
-""")
-def generate_transcript(podcast_name: str, host_name: str, prompt: str, max_rounds: int) -> PodcastDialogue:
-    """
+"""
+)
+def generate_transcript(
+    podcast_name: str, host_name: str, prompt: str, max_rounds: int
+) -> PodcastDialogue:
    pass


@wfapp.task
-def write_transcript_to_file(podcast_dialogue: List[Dict[str, Any]], output_path: str) -> None:
+def write_transcript_to_file(
+    podcast_dialogue: List[Dict[str, Any]], output_path: str
+) -> None:
    """
    Write the final structured transcript to a file.
    """
    try:
        with open(output_path, "w", encoding="utf-8") as file:
            import json

            json.dump(podcast_dialogue, file, ensure_ascii=False, indent=4)
        logger.info(f"Podcast dialogue successfully written to {output_path}")
    except Exception as e:
        logger.error(f"Error writing podcast dialogue to file: {e}")
        raise


@wfapp.task
-def convert_transcript_to_audio(transcript_parts: List[Dict[str, Any]], output_path: str, voices: Dict[str, Any], model: str = "tts-1") -> None:
+def convert_transcript_to_audio(
+    transcript_parts: List[Dict[str, Any]],
+    output_path: str,
+    voices: Dict[str, Any],
+    model: str = "tts-1",
+) -> None:
    """
    Converts a transcript into a single audio file using the OpenAI Audio Client and pydub for concatenation.

@@ -271,24 +319,30 @@ def convert_transcript_to_audio(transcript_parts: List[Dict[str, Any]], output_p
    for part in transcript_parts:
        speaker_name = part["name"]
        speaker_text = part["text"]
-        assigned_voice = voice_mapping.get(speaker_name, "alloy")  # Default to "alloy" if not found
+        assigned_voice = voice_mapping.get(
+            speaker_name, "alloy"
+        )  # Default to "alloy" if not found
||||
|
||||
# Log assigned voice for debugging
|
||||
logger.info(f"Generating audio for {speaker_name} using voice '{assigned_voice}'.")
|
||||
logger.info(
|
||||
f"Generating audio for {speaker_name} using voice '{assigned_voice}'."
|
||||
)
|
||||
|
||||
# Create TTS request
|
||||
tts_request = AudioSpeechRequest(
|
||||
model=model,
|
||||
input=speaker_text,
|
||||
voice=assigned_voice,
|
||||
response_format="mp3"
|
||||
response_format="mp3",
|
||||
)
|
||||
|
||||
# Generate the audio
|
||||
audio_bytes = client.create_speech(request=tts_request)
|
||||
|
||||
# Create an AudioSegment from the audio bytes
|
||||
audio_chunk = AudioSegment.from_file(io.BytesIO(audio_bytes), format=tts_request.response_format)
|
||||
audio_chunk = AudioSegment.from_file(
|
||||
io.BytesIO(audio_bytes), format=tts_request.response_format
|
||||
)
|
||||
|
||||
# Append the audio to the combined segment
|
||||
combined_audio += audio_chunk + AudioSegment.silent(duration=300)
|
||||
|
@ -301,17 +355,18 @@ def convert_transcript_to_audio(transcript_parts: List[Dict[str, Any]], output_p
|
|||
logger.error(f"Error during audio generation: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import json
|
||||
import yaml
|
||||
|
||||
def load_config(file_path: str) -> dict:
|
||||
"""Load configuration from a JSON or YAML file."""
|
||||
with open(file_path, 'r') as file:
|
||||
if file_path.endswith('.yaml') or file_path.endswith('.yml'):
|
||||
with open(file_path, "r") as file:
|
||||
if file_path.endswith(".yaml") or file_path.endswith(".yml"):
|
||||
return yaml.safe_load(file)
|
||||
elif file_path.endswith('.json'):
|
||||
elif file_path.endswith(".json"):
|
||||
return json.load(file)
|
||||
else:
|
||||
raise ValueError("Unsupported file format. Use JSON or YAML.")
|
||||
|
@ -323,11 +378,21 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--podcast_name", type=str, help="Name of the podcast.")
|
||||
parser.add_argument("--host_name", type=str, help="Name of the host.")
|
||||
parser.add_argument("--host_voice", type=str, help="Voice for the host.")
|
||||
parser.add_argument("--participants", type=str, nargs='+', help="List of participant names.")
|
||||
parser.add_argument("--max_rounds", type=int, default=4, help="Number of turns per round.")
|
||||
parser.add_argument("--output_transcript_path", type=str, help="Path to save the output transcript.")
|
||||
parser.add_argument("--output_audio_path", type=str, help="Path to save the final audio file.")
|
||||
parser.add_argument("--audio_model", type=str, default="tts-1", help="Audio model for TTS.")
|
||||
parser.add_argument(
|
||||
"--participants", type=str, nargs="+", help="List of participant names."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_rounds", type=int, default=4, help="Number of turns per round."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_transcript_path", type=str, help="Path to save the output transcript."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_audio_path", type=str, help="Path to save the final audio file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio_model", type=str, default="tts-1", help="Audio model for TTS."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -337,15 +402,18 @@ if __name__ == '__main__':
|
|||
# Merge CLI and Config inputs
|
||||
user_input = {
|
||||
"pdf_url": args.pdf_url or config.get("pdf_url"),
|
||||
"podcast_name": args.podcast_name or config.get("podcast_name", "Default Podcast"),
|
||||
"podcast_name": args.podcast_name
|
||||
or config.get("podcast_name", "Default Podcast"),
|
||||
"host": {
|
||||
"name": args.host_name or config.get("host", {}).get("name", "Host"),
|
||||
"voice": args.host_voice or config.get("host", {}).get("voice", "alloy"),
|
||||
},
|
||||
"participants": config.get("participants", []),
|
||||
"max_rounds": args.max_rounds or config.get("max_rounds", 4),
|
||||
"output_transcript_path": args.output_transcript_path or config.get("output_transcript_path", "podcast_dialogue.json"),
|
||||
"output_audio_path": args.output_audio_path or config.get("output_audio_path", "final_podcast.mp3"),
|
||||
"output_transcript_path": args.output_transcript_path
|
||||
or config.get("output_transcript_path", "podcast_dialogue.json"),
|
||||
"output_audio_path": args.output_audio_path
|
||||
or config.get("output_audio_path", "final_podcast.mp3"),
|
||||
"audio_model": args.audio_model or config.get("audio_model", "tts-1"),
|
||||
}
|
||||
|
||||
|
|
|
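For orientation, the workflow reformatted above chains its activities with yield ctx.call_activity(...), so each step's output is durably checkpointed by the Dapr Workflow runtime before the next step runs. A minimal sketch of that pattern, assuming the same WorkflowApp decorators this file already uses; the workflow and task names below are illustrative, not part of the diff:

    from typing import Any, Dict

    from dapr.ext.workflow import DaprWorkflowContext
    from dapr_agents.workflow import WorkflowApp

    wfapp = WorkflowApp()

    @wfapp.workflow(name="two_step")
    def two_step(ctx: DaprWorkflowContext, wf_input: Dict[str, Any]):
        # Each yield is a durable checkpoint: on replay after a restart,
        # recorded results are reused instead of re-running the activity.
        upper = yield ctx.call_activity(shout, input=wf_input["text"])
        result = yield ctx.call_activity(wrap, input=upper)
        return result

    @wfapp.task
    def shout(text: str) -> str:
        return text.upper()

    @wfapp.task
    def wrap(text: str) -> str:
        return f"processed: {text}"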
@@ -1,5 +1,5 @@
 from dapr_agents import OpenAIChatClient, NVIDIAChatClient
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext
 from dapr_agents.workflow import WorkflowApp, task, workflow
 from dotenv import load_dotenv
 import os
@@ -8,8 +8,7 @@ import logging
 load_dotenv()

 nvidia_llm = NVIDIAChatClient(
-    model="meta/llama-3.1-8b-instruct",
-    api_key=os.getenv("NVIDIA_API_KEY")
+    model="meta/llama-3.1-8b-instruct", api_key=os.getenv("NVIDIA_API_KEY")
 )

 oai_llm = OpenAIChatClient(
@@ -22,7 +21,7 @@ azoai_llm = OpenAIChatClient(
     api_key=os.getenv("AZURE_OPENAI_API_KEY"),
     azure_deployment="gpt-4o-mini",
     azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
-    azure_api_version="2024-12-01-preview"
+    azure_api_version="2024-12-01-preview",
 )


@@ -36,20 +35,29 @@ def test_workflow(ctx: DaprWorkflowContext):
     nvidia_results = yield ctx.call_activity(invoke_nvidia, input=azoai_results)
     return nvidia_results

-@task(description="What is the name of the capital of {country}?. Reply with just the name.", llm=oai_llm)
+
+@task(
+    description="What is the name of the capital of {country}?. Reply with just the name.",
+    llm=oai_llm,
+)
 def invoke_oai(country: str) -> str:
     pass

+
 @task(description="What is a famous thing about {capital}?", llm=azoai_llm)
 def invoke_azoai(capital: str) -> str:
     pass

-@task(description="Context: {context}. From the previous context. Pick one thing to do.", llm=nvidia_llm)
+
+@task(
+    description="Context: {context}. From the previous context. Pick one thing to do.",
+    llm=nvidia_llm,
+)
 def invoke_nvidia(context: str) -> str:
     pass

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)

     wfapp = WorkflowApp()
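Worth noting about the file above: the LLM-backed task bodies are intentionally pass. The description string on @task(...) acts as the prompt template, rendered with the call arguments and sent to the bound llm client, and the model's reply becomes the task's return value. A one-task sketch of the same pattern, reusing the oai_llm client defined above (the prompt text is illustrative):

    @task(description="Summarize the following text in one sentence: {text}", llm=oai_llm)
    def summarize(text: str) -> str:
        # No body needed: the decorator routes the rendered description to oai_llm.
        pass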
@@ -31,12 +31,9 @@ def sub(a: float, b: float) -> float:


 async def main():
-
-
     calculator_agent = Agent(
         name="MathematicsAgent",
         role="Calculator Assistant",
-
         goal="Assist Humans with calculation tasks.",
         instructions=[
             "Get accurate calculation results",
@@ -7,6 +7,7 @@ from dapr.clients import DaprClient
 # Default Pub/Sub component
 PUBSUB_NAME = "pubsub"

+
 def main(orchestrator_topic, max_attempts=10, retry_delay=1):
     """
     Publishes a task to a specified Dapr Pub/Sub topic with retries.
@@ -26,7 +27,9 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):

     while attempt <= max_attempts:
         try:
-            print(f"📢 Attempt {attempt}: Publishing to topic '{orchestrator_topic}'...")
+            print(
+                f"📢 Attempt {attempt}: Publishing to topic '{orchestrator_topic}'..."
+            )

             with DaprClient() as client:
                 client.publish_event(
@@ -36,7 +39,7 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):
                     data_content_type="application/json",
                     publish_metadata={
                         "cloudevent.type": "TriggerAction",
-                    }
+                    },
                 )

             print(f"✅ Successfully published request to '{orchestrator_topic}'")
@@ -52,8 +55,8 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):
     print(f"❌ Maximum attempts ({max_attempts}) reached without success.")
     sys.exit(1)

-if __name__ == "__main__":
-
-    orchestrator_topic = 'LLMOrchestrator'
+
+if __name__ == "__main__":
    orchestrator_topic = "LLMOrchestrator"

     main(orchestrator_topic)
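Stripped of the retry loop above, a single publish is one DaprClient.publish_event call; the cloudevent.type metadata is what lets subscribers route the payload as a TriggerAction. A minimal sketch with an illustrative topic and payload:

    import json

    from dapr.clients import DaprClient

    with DaprClient() as client:
        client.publish_event(
            pubsub_name="pubsub",
            topic_name="LLMOrchestrator",
            data=json.dumps({"task": "Plan the journey"}),
            data_content_type="application/json",
            # Subscribers match on the CloudEvent type to pick a handler.
            publish_metadata={"cloudevent.type": "TriggerAction"},
        )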
@@ -6,10 +6,8 @@ import logging

 async def main():
     try:
-
         workflow_service = LLMOrchestrator(
             name="LLMOrchestrator",
-
             message_bus_name="pubsub",
             state_store_name="workflowstatestore",
             state_key="workflow_state",
@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,8 +16,8 @@ async def main():
             "Be swift, silent, and precise, moving effortlessly across any terrain.",
             "Use superior vision and heightened senses to scout ahead and detect threats.",
             "Excel in ranged combat, delivering pinpoint arrow strikes from great distances.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
-        ]
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
+        ],
     )

     # Expose Agent as an Actor over a Service
@@ -32,6 +33,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,8 +16,8 @@ async def main():
             "Endure hardships and temptations, staying true to the mission even when faced with doubt.",
             "Seek guidance and trust allies, but bear the ultimate burden alone when necessary.",
             "Move carefully through enemy-infested lands, avoiding unnecessary risks.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
-        ]
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
+        ],
     )

     # Expose Agent as an Actor over a Service
@@ -25,13 +26,14 @@ async def main():
         message_bus_name="messagepubsub",
         agents_registry_store_name="agentsregistrystore",
         agents_registry_key="agents_registry",
-        service_port=8001
+        service_port=8001,
     )

         await hobbit_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,8 +16,8 @@ async def main():
             "Provide strategic counsel, always considering the long-term consequences of actions.",
             "Use magic sparingly, applying it when necessary to guide or protect.",
             "Encourage allies to find strength within themselves rather than relying solely on your power.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
-        ]
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
+        ],
     )

     # Expose Agent as an Actor over a Service
@@ -25,13 +26,14 @@ async def main():
         message_bus_name="messagepubsub",
         agents_registry_store_name="agentsregistrystore",
         agents_registry_key="agents_registry",
-        service_port=8002
+        service_port=8002,
     )

         await wizard_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         agentic_orchestrator = LLMOrchestrator(
@@ -12,13 +13,14 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=25
+            max_iterations=25,
         ).as_service(port=8004)

         await agentic_orchestrator.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         random_workflow_service = RandomOrchestrator(
@@ -12,13 +13,14 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await random_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         roundrobin_workflow_service = RoundRobinOrchestrator(
@@ -12,13 +13,14 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await roundrobin_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
             "Be strong-willed, fiercely loyal, and protective of companions.",
             "Excel in close combat and battlefield tactics, favoring axes and brute strength.",
             "Navigate caves, tunnels, and ancient stonework with expert knowledge.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
         ],
         message_bus_name="messagepubsub",
         state_store_name="agenticworkflowstate",
@@ -28,6 +29,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Eagle Agent
@@ -16,7 +17,7 @@ async def main():
             "Provide swift and strategic transport for those on critical journeys.",
             "Offer aerial insights, spotting dangers, tracking movements, and scouting strategic locations.",
             "Speak with wisdom and authority, as one of the ancient and noble Great Eagles.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
         ],
         message_bus_name="messagepubsub",
         state_store_name="agenticworkflowstate",
@@ -29,6 +30,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
             "Be swift, silent, and precise, moving effortlessly across any terrain.",
             "Use superior vision and heightened senses to scout ahead and detect threats.",
             "Excel in ranged combat, delivering pinpoint arrow strikes from great distances.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
         ],
         message_bus_name="messagepubsub",
         state_store_name="agenticworkflowstate",
@@ -28,6 +29,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
             "Endure hardships and temptations, staying true to the mission even when faced with doubt.",
             "Seek guidance and trust allies, but bear the ultimate burden alone when necessary.",
             "Move carefully through enemy-infested lands, avoiding unnecessary risks.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
         ],
         message_bus_name="messagepubsub",
         state_store_name="agenticworkflowstate",
@@ -28,6 +29,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         agentic_orchestrator = LLMOrchestrator(
@@ -12,13 +13,14 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await agentic_orchestrator.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
             "Lead by example, inspiring courage and loyalty in allies.",
             "Navigate wilderness with expert tracking and survival skills.",
             "Master both swordplay and battlefield strategy, excelling in one-on-one combat and large-scale warfare.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
         ],
         message_bus_name="messagepubsub",
         state_store_name="agenticworkflowstate",
@@ -28,6 +29,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
             "Provide strategic counsel, always considering the long-term consequences of actions.",
             "Use magic sparingly, applying it when necessary to guide or protect.",
             "Encourage allies to find strength within themselves rather than relying solely on your power.",
-            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+            "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
         ],
         message_bus_name="messagepubsub",
         state_store_name="agenticworkflowstate",
@@ -28,6 +29,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         random_workflow_service = RandomOrchestrator(
@@ -12,13 +13,14 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await random_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         roundrobin_workflow_service = RoundRobinOrchestrator(
@@ -12,13 +13,14 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await roundrobin_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

+
 async def main():
     try:
         # Create the Weather Agent using those tools
@@ -27,6 +28,7 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

+
 if __name__ == "__main__":
     load_dotenv()

@@ -23,7 +23,7 @@ if __name__ == "__main__":
             print(f"Request failed: {e}")

         attempt += 1
-        print(f"Waiting 5s seconds before next health checkattempt...")
+        print("Waiting 5s seconds before next health checkattempt...")
         time.sleep(5)

     if not healthy:
@@ -48,10 +48,10 @@ if __name__ == "__main__":
             print(f"Request failed: {e}")

         attempt += 1
-        print(f"Waiting 1s seconds before next attempt...")
+        print("Waiting 1s seconds before next attempt...")
         time.sleep(1)

-    print(f"Maximum attempts (10) reached without success.")
+    print("Maximum attempts (10) reached without success.")

     print("Failed to get successful response")
     sys.exit(1)
@@ -1,12 +1,25 @@
-from dapr_agents.agent import Agent, AgentActor, ReActAgent, ToolCallAgent, OpenAPIReActAgent
-from dapr_agents.llm.openai import OpenAIChatClient, OpenAIAudioClient, OpenAIEmbeddingClient
+from dapr_agents.agent import (
+    Agent,
+    AgentActor,
+    ReActAgent,
+    ToolCallAgent,
+    OpenAPIReActAgent,
+)
+from dapr_agents.llm.openai import (
+    OpenAIChatClient,
+    OpenAIAudioClient,
+    OpenAIEmbeddingClient,
+)
 from dapr_agents.llm.huggingface import HFHubChatClient
 from dapr_agents.llm.nvidia import NVIDIAChatClient, NVIDIAEmbeddingClient
 from dapr_agents.llm.elevenlabs import ElevenLabsSpeechClient
 from dapr_agents.tool import AgentTool, tool
 from dapr_agents.workflow import (
-    WorkflowApp, AgenticWorkflow,
-    LLMOrchestrator, RandomOrchestrator, RoundRobinOrchestrator,
-    AssistantAgent
+    WorkflowApp,
+    AgenticWorkflow,
+    LLMOrchestrator,
+    RandomOrchestrator,
+    RoundRobinOrchestrator,
+    AssistantAgent,
 )
 from dapr_agents.executors import LocalCodeExecutor, DockerCodeExecutor
|
|||
import logging
|
||||
from dapr_agents.agent.actor.schemas import AgentTaskResponse, TriggerAction, BroadcastMessage
|
||||
from dapr_agents.agent.actor.schemas import (
|
||||
AgentTaskResponse,
|
||||
TriggerAction,
|
||||
BroadcastMessage,
|
||||
)
|
||||
from dapr_agents.agent.actor.service import AgentActorService
|
||||
from dapr_agents.types.agent import AgentActorMessage
|
||||
from dapr_agents.workflow.messaging.decorator import message_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentActor(AgentActorService):
|
||||
"""
|
||||
A Pydantic-based class for managing services and exposing FastAPI routes with Dapr pub/sub and actor support.
|
||||
|
@ -35,10 +40,16 @@ class AgentActor(AgentActorService):
|
|||
response = await self.invoke_task(task)
|
||||
|
||||
# Check if the response exists
|
||||
content = response.body.decode() if response and response.body else "Task completed but no response generated."
|
||||
content = (
|
||||
response.body.decode()
|
||||
if response and response.body
|
||||
else "Task completed but no response generated."
|
||||
)
|
||||
|
||||
# Broadcast result
|
||||
response_message = BroadcastMessage(name=self.agent.name, role="user", content=content)
|
||||
response_message = BroadcastMessage(
|
||||
name=self.agent.name, role="user", content=content
|
||||
)
|
||||
await self.broadcast_message(message=response_message)
|
||||
|
||||
# Update response
|
||||
|
@ -60,22 +71,30 @@ class AgentActor(AgentActorService):
|
|||
metadata = message.pop("_message_metadata", {})
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
logger.warning(f"{getattr(self, 'name', 'agent')} received a broadcast with invalid metadata. Ignoring.")
|
||||
logger.warning(
|
||||
f"{getattr(self, 'name', 'agent')} received a broadcast with invalid metadata. Ignoring."
|
||||
)
|
||||
return
|
||||
|
||||
source = metadata.get("source", "unknown_source")
|
||||
message_type = metadata.get("type", "unknown_type")
|
||||
message_content = message.get("content", "No content")
|
||||
|
||||
logger.info(f"{self.agent.name} received broadcast message of type '{message_type}' from '{source}'.")
|
||||
logger.info(
|
||||
f"{self.agent.name} received broadcast message of type '{message_type}' from '{source}'."
|
||||
)
|
||||
|
||||
# Ignore messages sent by this agent
|
||||
if source == self.agent.name:
|
||||
logger.info(f"{self.agent.name} ignored its own broadcast message of type '{message_type}'.")
|
||||
logger.info(
|
||||
f"{self.agent.name} ignored its own broadcast message of type '{message_type}'."
|
||||
)
|
||||
return
|
||||
|
||||
# Log and process the valid broadcast message
|
||||
logger.debug(f"{self.agent.name} is processing broadcast message of type '{message_type}' from '{source}'.")
|
||||
logger.debug(
|
||||
f"{self.agent.name} is processing broadcast message of type '{message_type}' from '{source}'."
|
||||
)
|
||||
logger.debug(f"Message content: {message_content}")
|
||||
|
||||
# Add the message to the agent's memory
|
||||
|
|
|
@@ -16,6 +16,7 @@ from pydantic import ValidationError

 logger = logging.getLogger(__name__)

+
 class AgentActorBase(Actor, AgentActorInterface):
     """Base class for all agent actors, including task execution and agent state management."""

@@ -30,13 +31,17 @@ class AgentActorBase(Actor, AgentActorInterface):
         Called when the actor is activated. Initializes the agent's state if not present.
         """
         logger.info(f"Activating actor with ID: {self.actor_id}")
-        has_state, state_data = await self._state_manager.try_get_state(self.agent_state_key)
+        has_state, state_data = await self._state_manager.try_get_state(
+            self.agent_state_key
+        )

         if not has_state:
             # Initialize state with default values if it doesn't exist
             logger.info(f"Initializing state for {self.actor_id}")
             self.state = AgentActorState(overall_status=AgentStatus.IDLE)
-            await self._state_manager.set_state(self.agent_state_key, self.state.model_dump())
+            await self._state_manager.set_state(
+                self.agent_state_key, self.state.model_dump()
+            )
             await self._state_manager.save_state()
         else:
             # Load existing state
@@ -48,14 +53,18 @@ class AgentActorBase(Actor, AgentActorInterface):
         """
         Called when the actor is deactivated.
         """
-        logger.info(f"Deactivate {self.__class__.__name__} actor with ID: {self.actor_id}.")
+        logger.info(
+            f"Deactivate {self.__class__.__name__} actor with ID: {self.actor_id}."
+        )

     async def set_status(self, status: AgentStatus) -> None:
         """
         Sets the current operational status of the agent and saves the state.
         """
         self.state.overall_status = status
-        await self._state_manager.set_state(self.agent_state_key, self.state.model_dump())
+        await self._state_manager.set_state(
+            self.agent_state_key, self.state.model_dump()
+        )
         await self._state_manager.save_state()

     async def invoke_task(self, task: Optional[str] = None) -> str:
@@ -76,7 +85,9 @@ class AgentActorBase(Actor, AgentActorInterface):
             # Look for the last message in the conversation history
             last_message = messages[-1]
             default_task = last_message.get("content")
-            logger.debug(f"Default task entry input derived from last message: {default_task}")
+            logger.debug(
+                f"Default task entry input derived from last message: {default_task}"
+            )

         # Prepare the input for task entry
         task_entry_input = task or default_task or "Triggered without a specific task"
@@ -93,7 +104,9 @@ class AgentActorBase(Actor, AgentActorInterface):
         self.state.task_history.append(task_entry)

         # Save initial task state with IN_PROGRESS status
-        await self._state_manager.set_state(self.agent_state_key, self.state.model_dump())
+        await self._state_manager.set_state(
+            self.agent_state_key, self.state.model_dump()
+        )
         await self._state_manager.save_state()

         try:
@@ -120,7 +133,9 @@ class AgentActorBase(Actor, AgentActorInterface):

         finally:
             # Ensure the final state of the task is saved
-            await self._state_manager.set_state(self.agent_state_key, self.state.model_dump())
+            await self._state_manager.set_state(
+                self.agent_state_key, self.state.model_dump()
+            )
             await self._state_manager.save_state()
             # Revert the agent's status to idle
             await self.set_status(AgentStatus.IDLE)
@@ -141,7 +156,9 @@ class AgentActorBase(Actor, AgentActorInterface):
         self.state.message_count += 1

         # Save state back to Dapr
-        await self._state_manager.set_state(self.agent_state_key, self.state.model_dump())
+        await self._state_manager.set_state(
+            self.agent_state_key, self.state.model_dump()
+        )
         await self._state_manager.save_state()

     async def get_messages(self) -> List[dict]:
@@ -149,7 +166,9 @@ class AgentActorBase(Actor, AgentActorInterface):
         Retrieves the messages from the actor's state, validates it using Pydantic,
         and returns a list of dictionaries if valid.
         """
-        has_state, state_data = await self._state_manager.try_get_state(self.agent_state_key)
+        has_state, state_data = await self._state_manager.try_get_state(
+            self.agent_state_key
+        )

         if has_state:
             try:
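Every method in the hunks above performs the same state round-trip: read with try_get_state, mutate, write back with set_state, and flush with save_state. A condensed sketch of the load-or-initialize step, assuming the same actor _state_manager API; the key and default values are illustrative:

    async def _on_activate(self) -> None:
        has_state, state_data = await self._state_manager.try_get_state("agent_state")
        if not has_state:
            # First activation: persist a default state so later reads succeed.
            await self._state_manager.set_state("agent_state", {"overall_status": "idle"})
            await self._state_manager.save_state()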
@@ -3,9 +3,10 @@ from typing import List, Optional, Union
 from dapr.actor import ActorInterface, actormethod
 from dapr_agents.types.agent import AgentActorMessage, AgentStatus

+
 class AgentActorInterface(ActorInterface):
     @abstractmethod
-    @actormethod(name='InvokeTask')
+    @actormethod(name="InvokeTask")
     async def invoke_task(self, task: Optional[str] = None) -> str:
         """
         Invoke a task and returns the result as a string.
@@ -13,7 +14,7 @@ class AgentActorInterface(ActorInterface):
         pass

     @abstractmethod
-    @actormethod(name='AddMessage')
+    @actormethod(name="AddMessage")
     async def add_message(self, message: Union[AgentActorMessage, dict]) -> None:
         """
         Adds a message to the conversation history in the actor's state.
@@ -21,7 +22,7 @@ class AgentActorInterface(ActorInterface):
         pass

     @abstractmethod
-    @actormethod(name='GetMessages')
+    @actormethod(name="GetMessages")
     async def get_messages(self) -> List[dict]:
         """
         Retrieves the conversation history from the actor's state.
@@ -29,7 +30,7 @@ class AgentActorInterface(ActorInterface):
         pass

     @abstractmethod
-    @actormethod(name='SetStatus')
+    @actormethod(name="SetStatus")
     async def set_status(self, status: AgentStatus) -> None:
         """
         Sets the current operational status of the agent.
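The name passed to @actormethod above is the method name used on the wire, so callers invoke it through an ActorProxy by that exact string (the service code further down builds its proxy the same way). A hedged sketch; the actor type and ID are illustrative:

    from dapr.actor import ActorId, ActorProxy

    proxy = ActorProxy.create("HobbitActor", ActorId("Frodo"), AgentActorInterface)
    result = await proxy.InvokeTask("Scout the road ahead")  # matches name="InvokeTask"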
@@ -2,19 +2,31 @@ from typing import Optional
 from pydantic import BaseModel, Field
 from dapr_agents.types.message import BaseMessage

+
 class AgentTaskResponse(BaseMessage):
     """
     Represents a response message from an agent after completing a task.
     """
-    workflow_instance_id: Optional[str] = Field(default=None, description="Dapr workflow instance id from source if available")
+
+    workflow_instance_id: Optional[str] = Field(
+        default=None, description="Dapr workflow instance id from source if available"
+    )

+
 class TriggerAction(BaseModel):
     """
     Represents a message used to trigger an agent's activity within the workflow.
     """
-    task: Optional[str] = Field(None, description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.")
+
+    task: Optional[str] = Field(
+        None,
+        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
+    )
     iteration: Optional[int] = Field(0, description="")
-    workflow_instance_id: Optional[str] = Field(default=None, description="Dapr workflow instance id from source if available")
+    workflow_instance_id: Optional[str] = Field(
+        default=None, description="Dapr workflow instance id from source if available"
+    )


 class BroadcastMessage(BaseMessage):
     """
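These Pydantic models are the payloads that travel over pub/sub: publishers serialize them, and subscribers validate incoming dicts against them. A brief usage sketch with illustrative field values:

    trigger = TriggerAction(task="Summarize the last meeting")  # iteration defaults to 0
    payload = trigger.model_dump()  # dict ready for json.dumps(...) and publish_event(...)
    restored = TriggerAction.model_validate(payload)  # subscriber-side validation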
@ -18,7 +18,10 @@ from dapr.actor.runtime.config import (
|
|||
)
|
||||
from dapr.actor.runtime.runtime import ActorRuntime
|
||||
from dapr.clients import DaprClient
|
||||
from dapr.clients.grpc._request import TransactionOperationType, TransactionalStateOperation
|
||||
from dapr.clients.grpc._request import (
|
||||
TransactionOperationType,
|
||||
TransactionalStateOperation,
|
||||
)
|
||||
from dapr.clients.grpc._response import StateResponse
|
||||
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
|
||||
from dapr.ext.fastapi import DaprActor
|
||||
|
@ -34,22 +37,56 @@ from dapr_agents.workflow.messaging.routing import MessageRoutingMixin
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
||||
agent: AgentBase
|
||||
name: Optional[str] = Field(default=None, description="Name of the agent actor, derived from the agent if not provided.")
|
||||
agent_topic_name: Optional[str] = Field(None, description="The topic name dedicated to this specific agent, derived from the agent's name if not provided.")
|
||||
broadcast_topic_name: str = Field("beacon_channel", description="The default topic used for broadcasting messages to all agents.")
|
||||
agents_registry_store_name: str = Field(..., description="The name of the Dapr state store component used to store and share agent metadata centrally.")
|
||||
agents_registry_key: str = Field(default="agents_registry", description="Dapr state store key for agentic workflow state.")
|
||||
service_port: Optional[int] = Field(default=None, description="The port number to run the API server on.")
|
||||
service_host: Optional[str] = Field(default="0.0.0.0", description="Host address for the API server.")
|
||||
name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of the agent actor, derived from the agent if not provided.",
|
||||
)
|
||||
agent_topic_name: Optional[str] = Field(
|
||||
None,
|
||||
description="The topic name dedicated to this specific agent, derived from the agent's name if not provided.",
|
||||
)
|
||||
broadcast_topic_name: str = Field(
|
||||
"beacon_channel",
|
||||
description="The default topic used for broadcasting messages to all agents.",
|
||||
)
|
||||
agents_registry_store_name: str = Field(
|
||||
...,
|
||||
description="The name of the Dapr state store component used to store and share agent metadata centrally.",
|
||||
)
|
||||
agents_registry_key: str = Field(
|
||||
default="agents_registry",
|
||||
description="Dapr state store key for agentic workflow state.",
|
||||
)
|
||||
service_port: Optional[int] = Field(
|
||||
default=None, description="The port number to run the API server on."
|
||||
)
|
||||
service_host: Optional[str] = Field(
|
||||
default="0.0.0.0", description="Host address for the API server."
|
||||
)
|
||||
|
||||
# Fields initialized in model_post_init
|
||||
actor: Optional[DaprActor] = Field(default=None, init=False, description="DaprActor for actor lifecycle support.")
|
||||
actor_name: Optional[str] = Field(default=None, init=False, description="Actor name")
|
||||
actor_proxy: Optional[ActorProxy] = Field(default=None, init=False, description="Proxy for invoking methods on the agent's actor.")
|
||||
actor_class: Optional[type] = Field(default=None, init=False, description="Dynamically created actor class for the agent")
|
||||
agent_metadata: Optional[dict] = Field(default=None, init=False, description="Agent's metadata")
|
||||
actor: Optional[DaprActor] = Field(
|
||||
default=None, init=False, description="DaprActor for actor lifecycle support."
|
||||
)
|
||||
actor_name: Optional[str] = Field(
|
||||
default=None, init=False, description="Actor name"
|
||||
)
|
||||
actor_proxy: Optional[ActorProxy] = Field(
|
||||
default=None,
|
||||
init=False,
|
||||
description="Proxy for invoking methods on the agent's actor.",
|
||||
)
|
||||
actor_class: Optional[type] = Field(
|
||||
default=None,
|
||||
init=False,
|
||||
description="Dynamically created actor class for the agent",
|
||||
)
|
||||
agent_metadata: Optional[dict] = Field(
|
||||
default=None, init=False, description="Agent's metadata"
|
||||
)
|
||||
|
||||
# Private internal attributes (not schema/validated)
|
||||
_http_server: Optional[Any] = PrivateAttr(default=None)
|
||||
|
@ -57,7 +94,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
_dapr_client: Optional[DaprClient] = PrivateAttr(default=None)
|
||||
_is_running: bool = PrivateAttr(default=False)
|
||||
_subscriptions: Dict[str, Callable] = PrivateAttr(default_factory=dict)
|
||||
_topic_handlers: Dict[Tuple[str, str], Dict[Type[BaseModel], Callable]] = PrivateAttr(default_factory=dict)
|
||||
_topic_handlers: Dict[
|
||||
Tuple[str, str], Dict[Type[BaseModel], Callable]
|
||||
] = PrivateAttr(default_factory=dict)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
@ -80,10 +119,16 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
actor_class_name = f"{self.agent.name}Actor"
|
||||
|
||||
# Create the actor class dynamically using the 'type' function
|
||||
self.actor_class = type(actor_class_name, (AgentActorBase,), {
|
||||
'__init__': lambda self, ctx, actor_id: AgentActorBase.__init__(self, ctx, actor_id),
|
||||
'agent': self.agent
|
||||
})
|
||||
self.actor_class = type(
|
||||
actor_class_name,
|
||||
(AgentActorBase,),
|
||||
{
|
||||
"__init__": lambda self, ctx, actor_id: AgentActorBase.__init__(
|
||||
self, ctx, actor_id
|
||||
),
|
||||
"agent": self.agent,
|
||||
},
|
||||
)
|
||||
|
||||
# Prepare agent metadata
|
||||
self.agent_metadata = {
|
||||
|
@ -92,12 +137,14 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
"goal": self.agent.goal,
|
||||
"topic_name": self.agent_topic_name,
|
||||
"pubsub_name": self.message_bus_name,
|
||||
"orchestrator": False
|
||||
"orchestrator": False,
|
||||
}
|
||||
|
||||
# Proxy for actor methods
|
||||
self.actor_name = self.actor_class.__name__
|
||||
self.actor_proxy = ActorProxy.create(self.actor_name, ActorId(self.agent.name), AgentActorInterface)
|
||||
self.actor_proxy = ActorProxy.create(
|
||||
self.actor_name, ActorId(self.agent.name), AgentActorInterface
|
||||
)
|
||||
|
||||
# Initialize Sync Dapr Client
|
||||
self._dapr_client = DaprClient()
|
||||
|
@ -106,7 +153,7 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
self._http_server: FastAPIServerBase = FastAPIServerBase(
|
||||
service_name=self.agent.name,
|
||||
service_port=self.service_port,
|
||||
service_host=self.service_host
|
||||
service_host=self.service_host,
|
||||
)
|
||||
self._http_server.app.router.lifespan_context = self.lifespan
|
||||
|
||||
|
@ -133,15 +180,18 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
async def lifespan(self, app: FastAPI):
|
||||
# Register actor
|
||||
actor_runtime_config = ActorRuntimeConfig()
|
||||
actor_runtime_config.update_actor_type_configs([
|
||||
actor_runtime_config.update_actor_type_configs(
|
||||
[
|
||||
ActorTypeConfig(
|
||||
actor_type=self.actor_class.__name__,
|
||||
actor_idle_timeout=timedelta(hours=1),
|
||||
actor_scan_interval=timedelta(seconds=30),
|
||||
drain_ongoing_call_timeout=timedelta(minutes=1),
|
||||
drain_rebalanced_actors=True,
|
||||
reentrancy=ActorReentrancyConfig(enabled=True))
|
||||
])
|
||||
reentrancy=ActorReentrancyConfig(enabled=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
ActorRuntime.set_actor_config(actor_runtime_config)
|
||||
|
||||
await self.actor.register_actor(self.actor_class)
|
||||
|
@ -158,7 +208,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
|
||||
async def start(self):
|
||||
if self._is_running:
|
||||
logger.warning("Service is already running. Ignoring duplicate start request.")
|
||||
logger.warning(
|
||||
"Service is already running. Ignoring duplicate start request."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("Starting Agent Actor Service...")
|
||||
|
@ -176,7 +228,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
|
||||
for (pubsub_name, topic_name), close_fn in self._subscriptions.items():
|
||||
try:
|
||||
logger.info(f"Unsubscribing from pubsub '{pubsub_name}' topic '{topic_name}'")
|
||||
logger.info(
|
||||
f"Unsubscribing from pubsub '{pubsub_name}' topic '{topic_name}'"
|
||||
)
|
||||
close_fn()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to unsubscribe from topic '{topic_name}': {e}")
|
||||
|
@ -197,15 +251,21 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
Optional[dict]: The data stored under the specified key if found; otherwise, None.
|
||||
"""
|
||||
try:
|
||||
response: StateResponse = self._dapr_client.get_state(store_name=store_name, key=key)
|
||||
response: StateResponse = self._dapr_client.get_state(
|
||||
store_name=store_name, key=key
|
||||
)
|
||||
data = response.data
|
||||
|
||||
return json.loads(data) if data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error retrieving data for key '{key}' from store '{store_name}'")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Error retrieving data for key '{key}' from store '{store_name}'"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_agents_metadata(self, exclude_self: bool = True, exclude_orchestrator: bool = False) -> dict:
|
||||
def get_agents_metadata(
|
||||
self, exclude_self: bool = True, exclude_orchestrator: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieves metadata for all registered agents while ensuring orchestrators do not interact with other orchestrators.
|
||||
|
||||
|
@ -221,17 +281,28 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
"""
|
||||
try:
|
||||
# Fetch agent metadata from the registry
|
||||
agents_metadata = self.get_data_from_store(self.agents_registry_store_name, self.agents_registry_key) or {}
|
||||
agents_metadata = (
|
||||
self.get_data_from_store(
|
||||
self.agents_registry_store_name, self.agents_registry_key
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
if agents_metadata:
|
||||
logger.info(f"Agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'.")
|
||||
logger.info(
|
||||
f"Agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'."
|
||||
)
|
||||
|
||||
# Filter based on exclusion rules
|
||||
filtered_metadata = {
|
||||
name: metadata
|
||||
for name, metadata in agents_metadata.items()
|
||||
if not (exclude_self and name == self.agent.name) # Exclude self if requested
|
||||
and not (exclude_orchestrator and metadata.get("orchestrator", False)) # Exclude all orchestrators if exclude_orchestrator=True
|
||||
if not (
|
||||
exclude_self and name == self.agent.name
|
||||
) # Exclude self if requested
|
||||
and not (
|
||||
exclude_orchestrator and metadata.get("orchestrator", False)
|
||||
) # Exclude all orchestrators if exclude_orchestrator=True
|
||||
}
|
||||
|
||||
if not filtered_metadata:
|
||||
|
@ -239,7 +310,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
|
||||
return filtered_metadata
|
||||
|
||||
logger.info(f"No agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'.")
|
||||
logger.info(
|
||||
f"No agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'."
|
||||
)
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve agents metadata: {e}", exc_info=True)
|
||||
|
@ -255,14 +328,20 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
store_name=self.agents_registry_store_name,
|
||||
store_key=self.agents_registry_key,
|
||||
agent_name=self.name,
|
||||
agent_metadata=self.agent_metadata
|
||||
agent_metadata=self.agent_metadata,
|
||||
)
|
||||
logger.info(
|
||||
f"{self.name} registered its metadata under key '{self.agents_registry_key}'"
|
||||
)
|
||||
logger.info(f"{self.name} registered its metadata under key '{self.agents_registry_key}'")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register metadata for agent {self.agent.name}: {e}")
|
||||
logger.error(
|
||||
f"Failed to register metadata for agent {self.agent.name}: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
def register_agent(self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict) -> None:
|
||||
def register_agent(
|
||||
self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict
|
||||
) -> None:
|
||||
"""
|
||||
Merges the existing data with the new data and updates the store.
|
||||
|
||||
|
@ -274,7 +353,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
# retry the entire operation up to ten times sleeping 1 second between each attempt
|
||||
for attempt in range(1, 11):
|
||||
try:
|
||||
response: StateResponse = self._dapr_client.get_state(store_name=store_name, key=store_key)
|
||||
response: StateResponse = self._dapr_client.get_state(
|
||||
store_name=store_name, key=store_key
|
||||
)
|
||||
if not response.etag:
|
||||
# if there is no etag the following transaction won't work as expected
|
||||
# so we need to save an empty object with a strong consistency to force the etag to be created
|
||||
|
@ -283,7 +364,10 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
key=store_key,
|
||||
value=json.dumps({}),
|
||||
state_metadata={"contentType": "application/json"},
|
||||
options=StateOptions(concurrency=Concurrency.first_write, consistency=Consistency.strong)
|
||||
options=StateOptions(
|
||||
concurrency=Concurrency.first_write,
|
||||
consistency=Consistency.strong,
|
||||
),
|
||||
)
|
||||
# raise an exception to retry the entire operation
|
||||
raise Exception(f"No etag found for key: {store_key}")
|
||||
|
@ -303,19 +387,21 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
|
|||
key=store_key,
|
||||
data=json.dumps(merged_data),
|
||||
etag=response.etag,
|
||||
operation_type=TransactionOperationType.upsert
|
||||
operation_type=TransactionOperationType.upsert,
|
||||
)
|
||||
],
|
||||
transactional_metadata={"contentType": "application/json"}
|
||||
transactional_metadata={"contentType": "application/json"},
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error on transaction attempt: {attempt}: {e}")
|
||||
logger.debug(f"Sleeping for 1 second before retrying transaction...")
|
||||
logger.debug("Sleeping for 1 second before retrying transaction...")
|
||||
time.sleep(1)
|
||||
-        raise Exception(f"Failed to update state store key: {store_key} after 10 attempts.")
+        raise Exception(
+            f"Failed to update state store key: {store_key} after 10 attempts."
+        )

    async def invoke_task(self, task: Optional[str]) -> Response:
        """
@@ -332,7 +418,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
            return Response(content=response, status_code=status.HTTP_200_OK)
        except Exception as e:
            logger.error(f"Failed to run task for {self.actor_name}: {e}")
-            raise HTTPException(status_code=500, detail=f"Error invoking task: {str(e)}")
+            raise HTTPException(
+                status_code=500, detail=f"Error invoking task: {str(e)}"
+            )

    async def add_message(self, message: AgentActorMessage) -> None:
        """
@@ -349,12 +437,21 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        """
        try:
            messages = await self.actor_proxy.GetMessages()
-            return JSONResponse(content=jsonable_encoder(messages), status_code=status.HTTP_200_OK)
+            return JSONResponse(
+                content=jsonable_encoder(messages), status_code=status.HTTP_200_OK
+            )
        except Exception as e:
            logger.error(f"Failed to retrieve messages for {self.actor_name}: {e}")
-            raise HTTPException(status_code=500, detail=f"Error retrieving messages: {str(e)}")
+            raise HTTPException(
+                status_code=500, detail=f"Error retrieving messages: {str(e)}"
+            )

-    async def broadcast_message(self, message: Union[BaseModel, dict], exclude_orchestrator: bool = False, **kwargs) -> None:
+    async def broadcast_message(
+        self,
+        message: Union[BaseModel, dict],
+        exclude_orchestrator: bool = False,
+        **kwargs,
+    ) -> None:
        """
        Sends a message to all agents (or only to non-orchestrator agents if exclude_orchestrator=True).
@@ -365,7 +462,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        """
        try:
            # Retrieve agents metadata while respecting the exclude_orchestrator flag
-            agents_metadata = self.get_agents_metadata(exclude_orchestrator=exclude_orchestrator)
+            agents_metadata = self.get_agents_metadata(
+                exclude_orchestrator=exclude_orchestrator
+            )

            if not agents_metadata:
                logger.warning("No agents available for broadcast.")
@@ -385,7 +484,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        except Exception as e:
            logger.error(f"Failed to broadcast message: {e}", exc_info=True)

-    async def send_message_to_agent(self, name: str, message: Union[BaseModel, dict], **kwargs) -> None:
+    async def send_message_to_agent(
+        self, name: str, message: Union[BaseModel, dict], **kwargs
+    ) -> None:
        """
        Sends a message to a specific agent.
@@ -398,7 +499,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        agents_metadata = self.get_agents_metadata()

        if name not in agents_metadata:
-            logger.warning(f"Target '{name}' is not registered as an agent. Skipping message send.")
+            logger.warning(
+                f"Target '{name}' is not registered as an agent. Skipping message send."
+            )
            return  # Do not raise an error—just warn and move on.

        agent_metadata = agents_metadata[name]
@@ -414,4 +517,6 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):

            logger.debug(f"{self.name} sent message to agent '{name}'.")
        except Exception as e:
-            logger.error(f"Failed to send message to agent '{name}': {e}", exc_info=True)
+            logger.error(
+                f"Failed to send message to agent '{name}': {e}", exc_info=True
+            )

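For orientation, a minimal sketch of how the service methods above might be driven; the `service` instance and the `TaskUpdate` payload are assumptions for illustration, not taken from the diff:

```python
from pydantic import BaseModel


class TaskUpdate(BaseModel):
    # Hypothetical payload; broadcast_message accepts any BaseModel or dict.
    task: str
    status: str


async def notify(service) -> None:
    # Send to every registered agent except the orchestrator.
    await service.broadcast_message(
        TaskUpdate(task="summarize", status="started"),
        exclude_orchestrator=True,
    )
    # Send to one agent by registered name; per the code above, an unknown
    # name only logs a warning and returns instead of raising.
    await service.send_message_to_agent("WeatherExpert", {"task": "forecast"})
```
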
@@ -1,4 +1,8 @@
-from dapr_agents.memory import MemoryBase, ConversationListMemory, ConversationVectorMemory
+from dapr_agents.memory import (
+    MemoryBase,
+    ConversationListMemory,
+    ConversationVectorMemory,
+)
from dapr_agents.agent.utils.text_printer import ColorTextFormatter
from dapr_agents.types import MessageContent, MessagePlaceHolder
from dapr_agents.tool.executor import AgentToolExecutor
@@ -14,26 +18,59 @@ import logging

logger = logging.getLogger(__name__)

+
class AgentBase(BaseModel, ABC):
    """
    Base class for agents that interact with language models and manage tools for task execution.
    """

-    name: Optional[str] = Field(default=None, description="The agent's name, defaulting to the role if not provided.")
-    role: Optional[str] = Field(default="Assistant", description="The agent's role in the interaction (e.g., 'Weather Expert').")
-    goal: Optional[str] = Field(default="Help humans", description="The agent's main objective (e.g., 'Provide Weather information').")
-    instructions: Optional[List[str]] = Field(default=None, description="Instructions guiding the agent's tasks.")
-    system_prompt: Optional[str] = Field(default=None, description="A custom system prompt, overriding name, role, goal, and instructions.")
-    llm: LLMClientBase = Field(default_factory=OpenAIChatClient, description="Language model client for generating responses.")
-    prompt_template: Optional[PromptTemplateBase] = Field(default=None, description="The prompt template for the agent.")
-    tools: List[Union[AgentTool, Callable]] = Field(default_factory=list, description="Tools available for the agent to assist with tasks.")
-    max_iterations: int = Field(default=10, description="Max iterations for conversation cycles.")
-    memory: MemoryBase = Field(default_factory=ConversationListMemory, description="Handles conversation history and context storage.")
-    template_format: Literal["f-string", "jinja2"] = Field(default="jinja2", description="The format used for rendering the prompt template.")
+    name: Optional[str] = Field(
+        default=None,
+        description="The agent's name, defaulting to the role if not provided.",
+    )
+    role: Optional[str] = Field(
+        default="Assistant",
+        description="The agent's role in the interaction (e.g., 'Weather Expert').",
+    )
+    goal: Optional[str] = Field(
+        default="Help humans",
+        description="The agent's main objective (e.g., 'Provide Weather information').",
+    )
+    instructions: Optional[List[str]] = Field(
+        default=None, description="Instructions guiding the agent's tasks."
+    )
+    system_prompt: Optional[str] = Field(
+        default=None,
+        description="A custom system prompt, overriding name, role, goal, and instructions.",
+    )
+    llm: LLMClientBase = Field(
+        default_factory=OpenAIChatClient,
+        description="Language model client for generating responses.",
+    )
+    prompt_template: Optional[PromptTemplateBase] = Field(
+        default=None, description="The prompt template for the agent."
+    )
+    tools: List[Union[AgentTool, Callable]] = Field(
+        default_factory=list,
+        description="Tools available for the agent to assist with tasks.",
+    )
+    max_iterations: int = Field(
+        default=10, description="Max iterations for conversation cycles."
+    )
+    memory: MemoryBase = Field(
+        default_factory=ConversationListMemory,
+        description="Handles conversation history and context storage.",
+    )
+    template_format: Literal["f-string", "jinja2"] = Field(
+        default="jinja2",
+        description="The format used for rendering the prompt template.",
+    )

    # Private attributes
    _tool_executor: AgentToolExecutor = PrivateAttr()
-    _text_formatter: ColorTextFormatter = PrivateAttr(default_factory=ColorTextFormatter)
+    _text_formatter: ColorTextFormatter = PrivateAttr(
+        default_factory=ColorTextFormatter
+    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -97,7 +134,9 @@ class AgentBase(BaseModel, ABC):

        # If the agent's prompt_template is provided, use it and skip further configuration
        if self.prompt_template:
-            logger.info("Using the provided agent prompt_template. Skipping system prompt construction.")
+            logger.info(
+                "Using the provided agent prompt_template. Skipping system prompt construction."
+            )
            self.llm.prompt_template = self.prompt_template

        # If the LLM client already has a prompt template, sync it and prefill/validate as needed
@@ -145,27 +184,43 @@ class AgentBase(BaseModel, ABC):
            prefill_data["instructions"] = "\n".join(self.instructions)

        # Collect attributes set but not in input_variables for informational logging
-        set_attributes = {"name": self.name, "role": self.role, "goal": self.goal, "instructions": self.instructions}
+        set_attributes = {
+            "name": self.name,
+            "role": self.role,
+            "goal": self.goal,
+            "instructions": self.instructions,
+        }

        # Use Pydantic's model_fields_set to detect if attributes were user-set
-        user_set_attributes = {attr for attr in set_attributes if attr in self.model_fields_set}
+        user_set_attributes = {
+            attr for attr in set_attributes if attr in self.model_fields_set
+        }

        ignored_attributes = [
-            attr for attr in set_attributes
-            if attr not in self.prompt_template.input_variables and set_attributes[attr] is not None and attr in user_set_attributes
+            attr
+            for attr in set_attributes
+            if attr not in self.prompt_template.input_variables
+            and set_attributes[attr] is not None
+            and attr in user_set_attributes
        ]

        # Apply pre-filled data only for attributes that are in input_variables
        if prefill_data:
-            self.prompt_template = self.prompt_template.pre_fill_variables(**prefill_data)
-            logger.info(f"Pre-filled prompt template with attributes: {list(prefill_data.keys())}")
+            self.prompt_template = self.prompt_template.pre_fill_variables(
+                **prefill_data
+            )
+            logger.info(
+                f"Pre-filled prompt template with attributes: {list(prefill_data.keys())}"
+            )
        elif ignored_attributes:
            raise ValueError(
                f"The following agent attributes were explicitly set by the user but are not considered by the prompt template: {', '.join(ignored_attributes)}. "
                "Please ensure that these attributes are included in the prompt template's input variables if they are needed."
            )
        else:
-            logger.info("No agent attributes were pre-filled, as the template did not require any.")
+            logger.info(
+                "No agent attributes were pre-filled, as the template did not require any."
+            )

    def construct_system_prompt(self) -> str:
        """
@@ -206,13 +261,15 @@ class AgentBase(BaseModel, ABC):
        # Create the template with placeholders for system message and chat history
        return ChatPromptTemplate.from_messages(
            messages=[
-                ('system', system_prompt),
-                MessagePlaceHolder(variable_name="chat_history")
+                ("system", system_prompt),
+                MessagePlaceHolder(variable_name="chat_history"),
            ],
-            template_format=self.template_format
+            template_format=self.template_format,
        )

-    def construct_messages(self, input_data: Union[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
+    def construct_messages(
+        self, input_data: Union[str, Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
        """
        Constructs and formats initial messages based on input type, pre-filling chat history as needed.

@@ -255,7 +312,9 @@ class AgentBase(BaseModel, ABC):
        chat_history = self.chat_history
        return chat_history[-1] if chat_history else None

-    def get_last_user_message(self, messages: List[Dict[str, Any]]) -> Optional[MessageContent]:
+    def get_last_user_message(
+        self, messages: List[Dict[str, Any]]
+    ) -> Optional[MessageContent]:
        """
        Retrieves the last user message in a list of messages.

@@ -286,7 +345,9 @@ class AgentBase(BaseModel, ABC):
        - This method does not affect the `chat_history` which is dynamically updated.
        """
        if not self.prompt_template:
-            raise ValueError("Prompt template must be initialized before pre-filling variables.")
+            raise ValueError(
+                "Prompt template must be initialized before pre-filling variables."
+            )

        self.prompt_template = self.prompt_template.pre_fill_variables(**kwargs)
        logger.debug(f"Pre-filled prompt template with variables: {kwargs.keys()}")

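A minimal construction sketch for the fields above, assuming a concrete `AgentBase` subclass named `Agent` (the class name and import path are illustrative). As the prefill logic shows, name, role, goal, and instructions are filled into the prompt template when it declares them as input variables; explicitly setting an attribute the template cannot use raises the `ValueError` above.

```python
from dapr_agents import Agent  # assumed import path

agent = Agent(
    name="Weather Agent",
    role="Weather Expert",
    goal="Provide Weather information",
    instructions=["Answer briefly.", "Use tools when they help."],
    max_iterations=5,
)
# memory defaults to ConversationListMemory and llm to OpenAIChatClient,
# per the Field(default_factory=...) declarations above.
```
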
@@ -8,15 +8,18 @@ import logging

logger = logging.getLogger(__name__)

+
class OpenAPIReActAgent(ReActAgent):
    """
    Extends ReActAgent with OpenAPI handling capabilities, including tools for managing API calls.
    """

-    role: str = Field(default="OpenAPI Expert", description="The agent's role in the interaction.")
+    role: str = Field(
+        default="OpenAPI Expert", description="The agent's role in the interaction."
+    )
    goal: str = Field(
        default="Help users work with OpenAPI specifications and API integrations.",
-        description="The main objective of the agent."
+        description="The main objective of the agent.",
    )
    instructions: List[str] = Field(
        default=[
@@ -25,15 +28,23 @@ class OpenAPIReActAgent(ReActAgent):
            "You must first help users explore potential APIs by analyzing OpenAPI definitions, then assist in making authenticated API requests.",
            "Ensure that all API calls are executed with the correct parameters, authentication, and methods.",
            "Your responses should be concise, clear, and focus on guiding the user through the steps of working with APIs, including retrieving API definitions, understanding endpoint parameters, and handling errors.",
-            "You only respond to questions directly related to your role."
+            "You only respond to questions directly related to your role.",
        ],
-        description="Instructions to guide the agent's behavior."
+        description="Instructions to guide the agent's behavior.",
    )
-    spec_parser: OpenAPISpecParser = Field(..., description="Parser for handling OpenAPI specifications.")
-    api_vector_store: VectorStoreBase = Field(..., description="Vector store for storing API definitions.")
-    auth_header: Optional[Dict] = Field(None, description="Authentication headers for executing API calls.")
+    spec_parser: OpenAPISpecParser = Field(
+        ..., description="Parser for handling OpenAPI specifications."
+    )
+    api_vector_store: VectorStoreBase = Field(
+        ..., description="Vector store for storing API definitions."
+    )
+    auth_header: Optional[Dict] = Field(
+        None, description="Authentication headers for executing API calls."
+    )

-    tool_vector_store: Optional[VectorToolStore] = Field(default=None, init=False, description="Internal vector store for OpenAPI tools.")
+    tool_vector_store: Optional[VectorToolStore] = Field(
+        default=None, init=False, description="Internal vector store for OpenAPI tools."
+    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -52,9 +63,10 @@ class OpenAPIReActAgent(ReActAgent):

        # Generate OpenAPI-specific tools
        from .tools import generate_api_call_executor, generate_get_openapi_definition

        openapi_tools = [
            generate_get_openapi_definition(self.tool_vector_store),
-            generate_api_call_executor(self.spec_parser, self.auth_header)
+            generate_api_call_executor(self.spec_parser, self.auth_header),
        ]

        # Extend tools with OpenAPI tools

@@ -1,4 +1,6 @@
-import json, logging, requests
+import json
+import logging
+import requests
from urllib.parse import urlparse
from typing import Any, Dict, Optional, List

@@ -42,7 +44,10 @@ def _fmt_candidate(doc: str, meta: Dict[str, Any]) -> str:

class GetDefinitionInput(BaseModel):
    """Free-form query describing *one* desired operation (e.g. "multiply two numbers")."""
-    user_input: str = Field(..., description="Natural-language description of ONE desired API operation.")
+
+    user_input: str = Field(
+        ..., description="Natural-language description of ONE desired API operation."
+    )


def generate_get_openapi_definition(store: VectorToolStore):
@@ -65,17 +70,29 @@ def generate_get_openapi_definition(store: VectorToolStore):


class OpenAPIExecutorInput(BaseModel):
-    path_template: str = Field(..., description="Path template, may contain `{placeholder}` segments.")
+    path_template: str = Field(
+        ..., description="Path template, may contain `{placeholder}` segments."
+    )
    method: str = Field(..., description="HTTP verb, upper‑case.")
-    path_params: Dict[str, Any] = Field(default_factory=dict, description="Replacements for path placeholders.")
-    data: Dict[str, Any] = Field(default_factory=dict, description="JSON body for POST/PUT/PATCH.")
-    headers: Optional[Dict[str, Any]] = Field(default=None, description="Extra request headers.")
-    params: Optional[Dict[str, Any]] = Field(default=None, description="Query params (?key=value).")
+    path_params: Dict[str, Any] = Field(
+        default_factory=dict, description="Replacements for path placeholders."
+    )
+    data: Dict[str, Any] = Field(
+        default_factory=dict, description="JSON body for POST/PUT/PATCH."
+    )
+    headers: Optional[Dict[str, Any]] = Field(
+        default=None, description="Extra request headers."
+    )
+    params: Optional[Dict[str, Any]] = Field(
+        default=None, description="Query params (?key=value)."
+    )

    model_config = ConfigDict(extra="allow")


-def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None):
+def generate_api_call_executor(
+    spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None
+):
    base_url = spec.spec.servers[0].url  # assumes at least one server entry

    @tool(args_model=OpenAPIExecutorInput)
@@ -106,8 +123,10 @@ def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None):
        final_headers.update(headers)

        # redact auth key in debug logs
-        safe_hdrs = {k: ("***" if "auth" in k.lower() or "key" in k.lower() else v)
-                     for k, v in final_headers.items()}
+        safe_hdrs = {
+            k: ("***" if "auth" in k.lower() or "key" in k.lower() else v)
+            for k, v in final_headers.items()
+        }

        # Only convert data to JSON if we're doing a request that requires a body
        # and there's actually data to send
@@ -116,9 +135,14 @@ def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None):
            body = json.dumps(data)

        # Add more detailed logging similar to old implementation
-        logger.debug("→ %s %s | headers=%s params=%s data=%s",
-                     method, url, safe_hdrs, params,
-                     "***" if body else None)
+        logger.debug(
+            "→ %s %s | headers=%s params=%s data=%s",
+            method,
+            url,
+            safe_hdrs,
+            params,
+            "***" if body else None,
+        )

        # For debugging purposes, similar to the old implementation
        print(f"Base Url: {base_url}")
@@ -126,8 +150,9 @@ def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None):
        print(f"Requested Method: {method}")
        print(f"Requested Parameters: {params}")

-        resp = requests.request(method, url, headers=final_headers,
-                                params=params, data=body, **req_kwargs)
+        resp = requests.request(
+            method, url, headers=final_headers, params=params, data=body, **req_kwargs
+        )
        resp.raise_for_status()
        return resp.json()

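To make the executor's argument model concrete, here is a sketch of an `OpenAPIExecutorInput` payload; the endpoint, parameter names, and values are invented for illustration:

```python
payload = OpenAPIExecutorInput(
    path_template="/pets/{petId}",  # `{petId}` is filled in from path_params
    method="GET",
    path_params={"petId": 42},
    params={"verbose": "true"},  # sent as the query string (?verbose=true)
    headers={"X-Request-Id": "demo"},
    data={},  # JSON-encoded only for POST/PUT/PATCH with non-empty data
)
```
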
@@ -14,15 +14,25 @@ from dapr_agents.types import AgentError, AssistantMessage, ChatCompletion

logger = logging.getLogger(__name__)


class ReActAgent(AgentBase):
    """
    Agent implementing the ReAct (Reasoning-Action) framework for dynamic, few-shot problem-solving by leveraging
    contextual reasoning, actions, and observations in a conversation flow.
    """

-    stop_at_token: List[str] = Field(default=["\nObservation:"], description="Token(s) signaling the LLM to stop generation.")
-    tools: List[Union[AgentTool, Callable]] = Field(default_factory=list, description="Tools available for the agent, including final_answer.")
-    template_format: Literal["f-string", "jinja2"] = Field(default="jinja2", description="The format used for rendering the prompt template.")
+    stop_at_token: List[str] = Field(
+        default=["\nObservation:"],
+        description="Token(s) signaling the LLM to stop generation.",
+    )
+    tools: List[Union[AgentTool, Callable]] = Field(
+        default_factory=list,
+        description="Tools available for the agent, including final_answer.",
+    )
+    template_format: Literal["f-string", "jinja2"] = Field(
+        default="jinja2",
+        description="The format used for rendering the prompt template.",
+    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -51,11 +61,16 @@ class ReActAgent(AgentBase):
        # Tools section with schema details
        tools_section = "## Tools\nYou have access ONLY to the following tools:\n"
        for tool in self.tools:
-            tools_section += f"{tool.name}: {tool.description}. Args schema: {tool.args_schema}\n"
-        prompt_parts.append(tools_section.rstrip())  # Trim any trailing newlines from tools_section
+            tools_section += (
+                f"{tool.name}: {tool.description}. Args schema: {tool.args_schema}\n"
+            )
+        prompt_parts.append(
+            tools_section.rstrip()
+        )  # Trim any trailing newlines from tools_section

        # Additional Guidelines
-        additional_guidelines = textwrap.dedent("""
+        additional_guidelines = textwrap.dedent(
+            """
        If you think about using tool, it must use the correct tool JSON blob format as shown below:
        ```
        {
@@ -63,11 +78,13 @@ class ReActAgent(AgentBase):
        "arguments": $INPUT
        }
        ```
-        """).strip()
+        """
+        ).strip()
        prompt_parts.append(additional_guidelines)

        # ReAct specific guidelines
-        react_guidelines = textwrap.dedent("""
+        react_guidelines = textwrap.dedent(
+            """
        ## ReAct Format
        Thought: Reflect on the current state of the conversation or task. If additional information is needed, determine if using a tool is necessary. When a tool is required, briefly explain why it is needed for the specific step at hand, and immediately follow this with an `Action:` statement to address that specific requirement. Avoid combining multiple tool requests in a single `Thought`. If no tools are needed, proceed directly to an `Answer:` statement.
        Action:
@@ -105,7 +122,8 @@ class ReActAgent(AgentBase):

        ## Chat History
        The chat history is provided to avoid repeating information and to ensure accurate references when summarizing past interactions.
-        """).strip()
+        """
+        ).strip()
        prompt_parts.append(react_guidelines)

        return "\n\n".join(prompt_parts)
@@ -123,7 +141,9 @@ class ReActAgent(AgentBase):
        Raises:
            AgentError: If LLM fails or tool execution encounters issues.
        """
-        logger.debug(f"Agent run started with input: {input_data or 'Using memory context'}")
+        logger.debug(
+            f"Agent run started with input: {input_data or 'Using memory context'}"
+        )

        # Format messages; construct_messages already includes chat history.
        messages = self.construct_messages(input_data or {})
@@ -161,8 +181,12 @@ class ReActAgent(AgentBase):
                    break
            else:
                # Append react_loop to the last message if no user message is found
-                logger.warning("No user message found in the current messages; appending react_loop to the last message.")
-                iteration_messages[-1]["content"] += f"\n{react_loop}"  # Append react_loop to the last message
+                logger.warning(
+                    "No user message found in the current messages; appending react_loop to the last message."
+                )
+                iteration_messages[-1][
+                    "content"
+                ] += f"\n{react_loop}"  # Append react_loop to the last message

            try:
                response: ChatCompletion = self.llm.generate(
@@ -179,13 +203,17 @@ class ReActAgent(AgentBase):
                    assistant_final = AssistantMessage(final_answer)
                    self.memory.add_message(assistant_final)
                    self.text_formatter.print_separator()
-                    self.text_formatter.print_message(assistant_final, include_separator=False)
+                    self.text_formatter.print_message(
+                        assistant_final, include_separator=False
+                    )
                    logger.info("Agent provided a direct final answer.")
                    return final_answer

                # If there's no action, update the loop and continue reasoning
                if not action:
-                    logger.info("No action specified; continuing with further reasoning.")
+                    logger.info(
+                        "No action specified; continuing with further reasoning."
+                    )
                    react_loop += f"Thought:{thought_action}\n"
                    continue  # Proceed to the next iteration

@@ -211,8 +239,9 @@ class ReActAgent(AgentBase):

        logger.info("Max iterations reached. Agent has stopped.")

-    def parse_response(self, response: ChatCompletion) -> Tuple[str, Optional[dict], Optional[str]]:
+    def parse_response(
+        self, response: ChatCompletion
+    ) -> Tuple[str, Optional[dict], Optional[str]]:
        """
        Parses a ReAct-style LLM response into a Thought, optional Action (JSON blob), and optional Final Answer.

@@ -225,16 +254,18 @@ class ReActAgent(AgentBase):
        - Parsed Action dictionary, if present.
        - Final Answer string, if present.
        """
-        pattern = r'\{(?:[^{}]|(?R))*\}'  # Recursive pattern to match nested JSON blobs
+        pattern = r"\{(?:[^{}]|(?R))*\}"  # Recursive pattern to match nested JSON blobs
        content = response.get_content()

        # Compile reusable regex patterns
-        action_split_regex = regex.compile(r'action:\s*', flags=regex.IGNORECASE)
-        final_answer_regex = regex.compile(r'answer:\s*(.*)', flags=regex.IGNORECASE | regex.DOTALL)
-        thought_label_regex = regex.compile(r'thought:\s*', flags=regex.IGNORECASE)
+        action_split_regex = regex.compile(r"action:\s*", flags=regex.IGNORECASE)
+        final_answer_regex = regex.compile(
+            r"answer:\s*(.*)", flags=regex.IGNORECASE | regex.DOTALL
+        )
+        thought_label_regex = regex.compile(r"thought:\s*", flags=regex.IGNORECASE)

        # Strip leading "Thought:" labels (they get repeated a lot)
-        content = thought_label_regex.sub('', content).strip()
+        content = thought_label_regex.sub("", content).strip()

        # Check if there's a final answer present
        if final_match := final_answer_regex.search(content):
@@ -247,22 +278,32 @@ class ReActAgent(AgentBase):
            thought_part, action_block = action_split_regex.split(content, 1)
            thought_part = thought_part.strip()
            logger.debug(f"[parse_response] Thought extracted: {thought_part}")
-            logger.debug(f"[parse_response] Action block to parse: {action_block.strip()}")
+            logger.debug(
+                f"[parse_response] Action block to parse: {action_block.strip()}"
+            )
        else:
-            logger.debug(f"[parse_response] No action or answer found. Returning content as Thought: {content}")
+            logger.debug(
+                f"[parse_response] No action or answer found. Returning content as Thought: {content}"
+            )
            return content, None, None

        # Attempt to extract the first valid JSON blob from the action block
        for match in regex.finditer(pattern, action_block, flags=regex.DOTALL):
            try:
                action_dict = json.loads(match.group())
-                logger.debug(f"[parse_response] Successfully parsed action: {action_dict}")
+                logger.debug(
+                    f"[parse_response] Successfully parsed action: {action_dict}"
+                )
                return thought_part, action_dict, None
            except json.JSONDecodeError as e:
-                logger.debug(f"[parse_response] Failed to parse action JSON blob: {match.group()} — {e}")
+                logger.debug(
+                    f"[parse_response] Failed to parse action JSON blob: {match.group()} — {e}"
+                )
                continue

-        logger.debug(f"[parse_response] No valid action JSON found. Returning Thought only.")
+        logger.debug(
+            "[parse_response] No valid action JSON found. Returning Thought only."
+        )
        return thought_part, None, None

    async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:

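As a concrete illustration of what `parse_response` consumes, a completion in the ReAct format described above might look like this; the tool name and arguments are invented:

```python
llm_output = """Thought: I need the current weather before I can answer.
Action:
{
  "name": "get_weather",
  "arguments": {"city": "Berlin"}
}
"""
# parse_response on a completion with this content would yield roughly:
#   thought -> "I need the current weather before I can answer."
#   action  -> {"name": "get_weather", "arguments": {"city": "Berlin"}}
#   answer  -> None
# A completion containing "Answer: ..." instead yields the final answer string.
```
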
@@ -6,14 +6,20 @@ import logging

logger = logging.getLogger(__name__)


class ToolCallAgent(AgentBase):
    """
    Agent that manages tool calls and conversations using a language model.
    It integrates tools and processes them based on user inputs and task orchestration.
    """
+
-    tool_history: List[ToolMessage] = Field(default_factory=list, description="Executed tool calls during the conversation.")
-    tool_choice: Optional[str] = Field(default=None, description="Strategy for selecting tools ('auto', 'required', 'none'). Defaults to 'auto' if tools are provided.")
+    tool_history: List[ToolMessage] = Field(
+        default_factory=list, description="Executed tool calls during the conversation."
+    )
+    tool_choice: Optional[str] = Field(
+        default=None,
+        description="Strategy for selecting tools ('auto', 'required', 'none'). Defaults to 'auto' if tools are provided.",
+    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -22,7 +28,7 @@ class ToolCallAgent(AgentBase):
        Initialize the agent's settings, such as tool choice and parent setup.
        Sets the tool choice strategy based on provided tools.
        """
-        self.tool_choice = self.tool_choice or ('auto' if self.tools else None)
+        self.tool_choice = self.tool_choice or ("auto" if self.tools else None)

        # Proceed with base model setup
        super().model_post_init(__context)
@@ -40,7 +46,9 @@ class ToolCallAgent(AgentBase):
        Raises:
            AgentError: If user input is invalid or tool execution fails.
        """
-        logger.debug(f"Agent run started with input: {input_data if input_data else 'Using memory context'}")
+        logger.debug(
+            f"Agent run started with input: {input_data if input_data else 'Using memory context'}"
+        )

        # Format messages; construct_messages already includes chat history.
        messages = self.construct_messages(input_data or {})
@@ -70,9 +78,15 @@ class ToolCallAgent(AgentBase):
        for tool in tool_calls:
            function_name = tool.function.name
            try:
-                logger.info(f"Executing {function_name} with arguments {tool.function.arguments}")
-                result = await self.tool_executor.run_tool(function_name, **tool.function.arguments_dict)
-                tool_message = ToolMessage(tool_call_id=tool.id, name=function_name, content=str(result))
+                logger.info(
+                    f"Executing {function_name} with arguments {tool.function.arguments}"
+                )
+                result = await self.tool_executor.run_tool(
+                    function_name, **tool.function.arguments_dict
+                )
+                tool_message = ToolMessage(
+                    tool_call_id=tool.id, name=function_name, content=str(result)
+                )
                self.text_formatter.print_message(tool_message)
                self.tool_history.append(tool_message)
            except Exception as e:

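The `tool_choice` default resolves in `model_post_init`; a sketch of the observable behavior, assuming `my_tool` is a registered `AgentTool` and the LLM client configuration is available:

```python
agent_with_tools = ToolCallAgent(tools=[my_tool])
assert agent_with_tools.tool_choice == "auto"  # tools present -> "auto"

agent_without_tools = ToolCallAgent()
assert agent_without_tools.tool_choice is None  # no tools -> None
```
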
@@ -0,0 +1 @@
+from .otel import DaprAgentsOTel

@@ -0,0 +1,144 @@
+from logging import Logger
+from typing import Union
+
+from opentelemetry._logs import set_logger_provider
+from opentelemetry.metrics import set_meter_provider
+from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
+from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
+from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+from opentelemetry.sdk.resources import Resource, SERVICE_NAME
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.trace import set_tracer_provider
+from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
+from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+
+
+class DaprAgentsOTel:
+    """
+    OpenTelemetry configuration for Dapr agents.
+    """
+
+    def __init__(self, service_name: str = "", otlp_endpoint: str = ""):
+        # Configure OpenTelemetry
+        self.service_name = service_name
+        self.otlp_endpoint = otlp_endpoint
+
+        self.setup_resources()
+
+    def setup_resources(self):
+        """
+        Set up the resource for OpenTelemetry.
+        """
+
+        self._resource = Resource.create(
+            attributes={
+                SERVICE_NAME: str(self.service_name),
+            }
+        )
+
+    def create_and_instrument_meter_provider(
+        self,
+        otlp_endpoint: str = "",
+    ) -> MeterProvider:
+        """
+        Returns a `MeterProvider` that is configured to export metrics using the `PeriodicExportingMetricReader`
+        which means that metrics are exported periodically in the background. The interval can be set by
+        the environment variable `OTEL_METRIC_EXPORT_INTERVAL`. The default value is 60000ms (1 minute).
+
+        Also sets the global OpenTelemetry meter provider to the returned meter provider.
+        """
+
+        # Ensure the endpoint is set correctly
+        endpoint = self._endpoint_validator(
+            endpoint=self.otlp_endpoint if otlp_endpoint == "" else otlp_endpoint,
+            telemetry_type="metrics",
+        )
+
+        metric_exporter = OTLPMetricExporter(endpoint=str(endpoint))
+        metric_reader = PeriodicExportingMetricReader(metric_exporter)
+        meter_provider = MeterProvider(
+            resource=self._resource, metric_readers=[metric_reader]
+        )
+        set_meter_provider(meter_provider)
+        return meter_provider
+
+    def create_and_instrument_tracer_provider(
+        self,
+        otlp_endpoint: str = "",
+    ) -> TracerProvider:
+        """
+        Returns a `TracerProvider` that is configured to export traces using the `BatchSpanProcessor`
+        which means that traces are exported in batches. The batch size can be set by
+        the environment variable `OTEL_TRACES_EXPORT_BATCH_SIZE`. The default value is 512.
+        Also sets the global OpenTelemetry tracer provider to the returned tracer provider.
+        """
+
+        # Ensure the endpoint is set correctly
+        endpoint = self._endpoint_validator(
+            endpoint=self.otlp_endpoint if otlp_endpoint == "" else otlp_endpoint,
+            telemetry_type="traces",
+        )
+
+        trace_exporter = OTLPSpanExporter(endpoint=str(endpoint))
+        tracer_processor = BatchSpanProcessor(trace_exporter)
+        tracer_provider = TracerProvider(resource=self._resource)
+        tracer_provider.add_span_processor(tracer_processor)
+        set_tracer_provider(tracer_provider)
+        return tracer_provider
+
+    def create_and_instrument_logging_provider(
+        self,
+        logger: Logger,
+        otlp_endpoint: str = "",
+    ) -> LoggerProvider:
+        """
+        Returns a `LoggingProvider` that is configured to export logs using the `BatchLogProcessor`
+        which means that logs are exported in batches. The batch size can be set by
+        the environment variable `OTEL_LOGS_EXPORT_BATCH_SIZE`. The default value is 512.
+        Also sets the global OpenTelemetry logging provider to the returned logging provider.
+        """
+
+        # Ensure the endpoint is set correctly
+        endpoint = self._endpoint_validator(
+            endpoint=self.otlp_endpoint if otlp_endpoint == "" else otlp_endpoint,
+            telemetry_type="logs",
+        )
+
+        log_exporter = OTLPLogExporter(endpoint=str(endpoint))
+        logging_provider = LoggerProvider(resource=self._resource)
+        logging_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
+        set_logger_provider(logging_provider)
+
+        handler = LoggingHandler(logger_provider=logging_provider)
+        logger.addHandler(handler)
+        return logging_provider
+
+    def _endpoint_validator(
+        self,
+        endpoint: str,
+        telemetry_type: str,
+    ) -> Union[str | Exception]:
+        """
+        Validates the endpoint and method.
+        """
+
+        if endpoint == "":
+            raise ValueError(
+                "OTLP endpoint must be set either in the environment variable OTEL_EXPORTER_OTLP_ENDPOINT or in the constructor."
+            )
+        if endpoint.startswith("https://"):
+            raise NotImplementedError(
+                "OTLP over HTTPS is not supported. Please use HTTP."
+            )
+
+        endpoint = (
+            endpoint
+            if endpoint.endswith(f"/v1/{telemetry_type}")
+            else f"{endpoint}/v1/{telemetry_type}"
+        )
+        endpoint = endpoint if endpoint.startswith("http://") else f"http://{endpoint}"
+
+        return endpoint

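A minimal usage sketch for the class above. The package path is assumed from the accompanying `__init__.py`, and the endpoint value is illustrative; per `_endpoint_validator`, it must be plain HTTP and the `/v1/<signal>` suffix is appended automatically:

```python
import logging

from dapr_agents.observability import DaprAgentsOTel  # assumed package path

logger = logging.getLogger(__name__)

otel = DaprAgentsOTel(service_name="weather-agent", otlp_endpoint="localhost:4318")
otel.create_and_instrument_tracer_provider()
otel.create_and_instrument_meter_provider()
otel.create_and_instrument_logging_provider(logger=logger)
```
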
@@ -1,7 +1,8 @@
import requests
import os

-def construct_auth_headers(auth_url, grant_type='client_credentials', **kwargs):
+
+def construct_auth_headers(auth_url, grant_type="client_credentials", **kwargs):
    """
    Construct authorization headers for API requests.

@@ -14,15 +15,19 @@ def construct_auth_headers(auth_url, grant_type='client_credentials', **kwargs):

    # Define default parameters based on the grant_type
    data = {
-        'grant_type': grant_type,
+        "grant_type": grant_type,
    }

    # Defaults for client_credentials grant type
-    if grant_type == 'client_credentials':
-        data.update({
-            'client_id': kwargs.get('client_id', os.getenv('CLIENT_ID')),
-            'client_secret': kwargs.get('client_secret', os.getenv('CLIENT_SECRET')),
-        })
+    if grant_type == "client_credentials":
+        data.update(
+            {
+                "client_id": kwargs.get("client_id", os.getenv("CLIENT_ID")),
+                "client_secret": kwargs.get(
+                    "client_secret", os.getenv("CLIENT_SECRET")
+                ),
+            }
+        )

    # Add any additional data passed in kwargs
    data.update(kwargs)
@@ -37,7 +42,7 @@ def construct_auth_headers(auth_url, grant_type='client_credentials', **kwargs):
    auth_response_data = auth_response.json()

    # Extract the access token
-    access_token = auth_response_data.get('access_token')
+    access_token = auth_response_data.get("access_token")

    if not access_token:
        raise ValueError("No access token found in the response")

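A hedged example of calling the helper above with the default client_credentials flow; the URL and credentials are placeholders, and the CLIENT_ID/CLIENT_SECRET environment variables are used when the kwargs are omitted. The exact shape of the returned headers is not shown in these hunks, but it is built from the extracted access_token:

```python
headers = construct_auth_headers(
    auth_url="https://id.example.com/oauth2/token",  # placeholder token endpoint
    client_id="my-client-id",          # falls back to os.getenv("CLIENT_ID")
    client_secret="my-client-secret",  # falls back to os.getenv("CLIENT_SECRET")
)
```
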
@@ -8,16 +8,18 @@ from dapr_agents.memory import MemoryBase
from dapr_agents.tool import AgentTool
from typing import Optional, List, Union, Type, TypeVar

-T = TypeVar('T', ToolCallAgent, ReActAgent, OpenAPIReActAgent)
+T = TypeVar("T", ToolCallAgent, ReActAgent, OpenAPIReActAgent)


class AgentFactory:
    """
    Returns agent classes based on the provided pattern.
    """

    AGENT_PATTERNS = {
        "react": ReActAgent,
        "toolcalling": ToolCallAgent,
-        "openapireact": OpenAPIReActAgent
+        "openapireact": OpenAPIReActAgent,
    }

    @staticmethod
@@ -54,7 +56,7 @@ class Agent(AgentBase):
        llm: Optional[LLMClientBase] = None,
        memory: Optional[MemoryBase] = None,
        tools: Optional[List[AgentTool]] = [],
-        **kwargs
+        **kwargs,
    ) -> Union[ToolCallAgent, ReActAgent, OpenAPIReActAgent]:
        """
        Creates and returns an instance of the selected agent class.
@@ -77,11 +79,21 @@ class Agent(AgentBase):
        memory = memory or ConversationListMemory()

        if pattern == "openapireact":
-            kwargs.update({
-                "spec_parser": kwargs.get('spec_parser', OpenAPISpecParser()),
-                "auth_header": kwargs.get('auth_header', {})
-            })
+            kwargs.update(
+                {
+                    "spec_parser": kwargs.get("spec_parser", OpenAPISpecParser()),
+                    "auth_header": kwargs.get("auth_header", {}),
+                }
+            )

        instance = super().__new__(agent_class)
-        agent_class.__init__(instance, role=role, name=name, llm=llm, memory=memory, tools=tools, **kwargs)
+        agent_class.__init__(
+            instance,
+            role=role,
+            name=name,
+            llm=llm,
+            memory=memory,
+            tools=tools,
+            **kwargs,
+        )
        return instance

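A sketch of selecting an implementation through the factory. The `pattern` argument name is inferred from the branch above and is not visible in the truncated signature, so treat it as an assumption:

```python
agent = Agent(
    role="Weather Expert",
    pattern="react",  # assumed kwarg; one of "react", "toolcalling", "openapireact"
)
# "openapireact" additionally injects spec_parser and auth_header defaults,
# as the kwargs.update(...) branch above shows.
```
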
@@ -2,6 +2,7 @@ from dapr_agents.types import BaseMessage
from typing import List
from pydantic import ValidationError

+
def messages_to_string(messages: List[BaseMessage]) -> str:
    """
    Converts messages into a single string with roles and content.

@@ -4,14 +4,15 @@ from colorama import Style

# Define your custom colors as a dictionary
COLORS = {
-    "dapr_agents_teal": '\033[38;2;147;191;183m',
-    "dapr_agents_mustard": '\033[38;2;242;182;128m',
-    "dapr_agents_red": '\033[38;2;217;95;118m',
-    "dapr_agents_pink": '\033[38;2;191;69;126m',
-    "dapr_agents_purple": '\033[38;2;146;94;130m',
-    "reset": Style.RESET_ALL
+    "dapr_agents_teal": "\033[38;2;147;191;183m",
+    "dapr_agents_mustard": "\033[38;2;242;182;128m",
+    "dapr_agents_red": "\033[38;2;217;95;118m",
+    "dapr_agents_pink": "\033[38;2;191;69;126m",
+    "dapr_agents_purple": "\033[38;2;146;94;130m",
+    "reset": Style.RESET_ALL,
}


class ColorTextFormatter:
    """
    A flexible text formatter class to print colored text dynamically.
@@ -56,7 +57,7 @@ class ColorTextFormatter:
            formatted_line = self.format_text(line, color)
            print(formatted_line, end="\n" if i < len(lines) - 1 else "")

-        print(COLORS['reset'])  # Ensure terminal color is reset at the end
+        print(COLORS["reset"])  # Ensure terminal color is reset at the end

    def print_separator(self):
        """
@@ -65,7 +66,11 @@ class ColorTextFormatter:
        separator = "-" * 80
        self.print_colored_text([(f"\n{separator}\n", "reset")])

-    def print_message(self, message: Union[BaseMessage, Dict[str, Any]], include_separator: bool = True):
+    def print_message(
+        self,
+        message: Union[BaseMessage, Dict[str, Any]],
+        include_separator: bool = True,
+    ):
        """
        Prints messages with colored formatting based on the role and message content.

@@ -91,7 +96,7 @@ class ColorTextFormatter:
            "user": "dapr_agents_mustard",
            "assistant": "dapr_agents_teal",
            "tool_calls": "dapr_agents_red",
-            "tool": "dapr_agents_pink"
+            "tool": "dapr_agents_pink",
        }

        # Handle tool calls
@@ -103,7 +108,10 @@ class ColorTextFormatter:
                tool_id = tool_call["id"]
                tool_call_text = [
                    (f"{formatted_role}:\n", color_map["tool_calls"]),
-                    (f"Function name: {function_name} (Call Id: {tool_id})\n", color_map["tool_calls"]),
+                    (
+                        f"Function name: {function_name} (Call Id: {tool_id})\n",
+                        color_map["tool_calls"],
+                    ),
                    (f"Arguments: {arguments}", color_map["tool_calls"]),
                ]
                self.print_colored_text(tool_call_text)
@@ -142,7 +150,7 @@ class ColorTextFormatter:
        color_map = {
            "Thought": "dapr_agents_red",
            "Action": "dapr_agents_pink",
-            "Observation": "dapr_agents_purple"
+            "Observation": "dapr_agents_purple",
        }

        # Get the color for the part type, defaulting to reset if not found

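For illustration, the formatter can be driven directly with the color keys defined in `COLORS`; a small sketch with invented text:

```python
formatter = ColorTextFormatter()
formatter.print_colored_text(
    [
        ("Thought: check the forecast\n", "dapr_agents_red"),
        ("Action: get_weather\n", "dapr_agents_pink"),
        ("Observation: 18C and cloudy", "dapr_agents_purple"),
    ]
)
formatter.print_separator()
```
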
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List, Any

+
class EmbedderBase(BaseModel, ABC):
    """
    Abstract base class for Embedders.

@@ -7,6 +7,7 @@ import logging

logger = logging.getLogger(__name__)

+
class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
    """
    NVIDIA-based embedder for generating text embeddings with support for indexing (passage) and querying.
@@ -17,10 +18,16 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        normalize (bool): Whether to normalize embeddings. Defaults to True.
    """

-    chunk_size: int = Field(default=1000, description="Batch size for embedding requests.")
-    normalize: bool = Field(default=True, description="Whether to normalize embeddings.")
+    chunk_size: int = Field(
+        default=1000, description="Batch size for embedding requests."
+    )
+    normalize: bool = Field(
+        default=True, description="Whether to normalize embeddings."
+    )

-    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) for indexing with default input_type set to 'passage'.

@@ -37,7 +44,9 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        """
        return self._generate_embeddings(input, input_type="passage")

-    def embed_query(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed_query(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) for querying with input_type set to 'query'.

@@ -54,7 +63,9 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        """
        return self._generate_embeddings(input, input_type="query")

-    def _generate_embeddings(self, input: Union[str, List[str]], input_type: str) -> Union[List[float], List[List[float]]]:
+    def _generate_embeddings(
+        self, input: Union[str, List[str]], input_type: str
+    ) -> Union[List[float], List[List[float]]]:
        """
        Helper function to generate embeddings for given input text(s) with specified input_type.

@@ -82,7 +93,8 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
            # Normalize embeddings if required
            if self.normalize:
                normalized_embeddings = [
-                    (embedding / np.linalg.norm(embedding)).tolist() for embedding in chunk_embeddings
+                    (embedding / np.linalg.norm(embedding)).tolist()
+                    for embedding in chunk_embeddings
                ]
            else:
                normalized_embeddings = chunk_embeddings
@@ -90,7 +102,9 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        # Return a single embedding if the input was a single string; otherwise, return a list
        return normalized_embeddings[0] if single_input else normalized_embeddings

-    def __call__(self, input: Union[str, List[str]], query: bool = False) -> Union[List[float], List[List[float]]]:
+    def __call__(
+        self, input: Union[str, List[str]], query: bool = False
+    ) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).

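A usage sketch for the passage/query split above; any credentials or model configuration required by `NVIDIAEmbeddingClient` are assumed to be supplied elsewhere:

```python
embedder = NVIDIAEmbedder(chunk_size=500, normalize=True)

doc_vectors = embedder.embed(["first passage", "second passage"])   # input_type="passage"
query_vector = embedder.embed_query("what does the first passage say?")  # input_type="query"

# __call__ routes on the query flag:
same_as_query = embedder("what does the first passage say?", query=True)
```
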
@@ -7,17 +7,28 @@ import logging

logger = logging.getLogger(__name__)

+
class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
    """
    OpenAI-based embedder for generating text embeddings with handling for long inputs.
    Inherits functionality from OpenAIEmbeddingClient for API interactions.
    """

-    max_tokens: int = Field(default=8191, description="Maximum tokens allowed per input.")
-    chunk_size: int = Field(default=1000, description="Batch size for embedding requests.")
-    normalize: bool = Field(default=True, description="Whether to normalize embeddings.")
-    encoding_name: Optional[str] = Field(default=None, description="Token encoding name (if provided).")
-    encoder: Optional[Any] = Field(default=None, init=False, description="TikToken Encoder")
+    max_tokens: int = Field(
+        default=8191, description="Maximum tokens allowed per input."
+    )
+    chunk_size: int = Field(
+        default=1000, description="Batch size for embedding requests."
+    )
+    normalize: bool = Field(
+        default=True, description="Whether to normalize embeddings."
+    )
+    encoding_name: Optional[str] = Field(
+        default=None, description="Token encoding name (if provided)."
+    )
+    encoder: Optional[Any] = Field(
+        default=None, init=False, description="TikToken Encoder"
+    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -59,9 +70,13 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):

    def _chunk_tokens(self, tokens: List[int], chunk_length: int) -> List[List[int]]:
        """Splits tokens into chunks of the specified length."""
-        return [tokens[i:i + chunk_length] for i in range(0, len(tokens), chunk_length)]
+        return [
+            tokens[i : i + chunk_length] for i in range(0, len(tokens), chunk_length)
+        ]

-    def _process_embeddings(self, embeddings: List[List[float]], weights: List[int]) -> List[float]:
+    def _process_embeddings(
+        self, embeddings: List[List[float]], weights: List[int]
+    ) -> List[float]:
        """Combines embeddings using weighted averaging."""
        weighted_avg = np.average(embeddings, axis=0, weights=weights)
        if self.normalize:
@@ -69,7 +84,9 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
            return (weighted_avg / norm).tolist()
        return weighted_avg.tolist()

-    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) with support for both single and multiple inputs, handling long texts via chunking and batching.

@@ -133,13 +150,17 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
                results.append(embeddings[0])
            else:
                # Combine chunk embeddings using weighted averaging
-                weights = [len(chunk) for chunk in self._chunk_tokens(tokens, self.max_tokens)]
+                weights = [
+                    len(chunk) for chunk in self._chunk_tokens(tokens, self.max_tokens)
+                ]
                results.append(self._process_embeddings(embeddings, weights))

        # Return a single embedding if the input was a single string; otherwise, return a list
        return results[0] if single_input else results

-    def __call__(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def __call__(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).

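The chunk-combination step in `_process_embeddings` is a token-count-weighted mean followed by optional L2 normalization; the same arithmetic in isolation, with invented values:

```python
import numpy as np

# Two chunk embeddings whose source chunks were 8191 and 1000 tokens long.
chunk_embeddings = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
weights = [8191, 1000]

weighted_avg = np.average(chunk_embeddings, axis=0, weights=weights)
normalized = (weighted_avg / np.linalg.norm(weighted_avg)).tolist()
# The longer chunk dominates the direction of the combined vector.
```
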
@@ -6,19 +6,33 @@ import os

logger = logging.getLogger(__name__)

+
class SentenceTransformerEmbedder(EmbedderBase):
    """
    SentenceTransformer-based embedder for generating text embeddings.
    Supports multi-process encoding for large datasets.
    """

-    model: str = Field(default="all-MiniLM-L6-v2", description="Name of the SentenceTransformer model to use.")
-    device: Literal["cpu", "cuda", "mps", "npu"] = Field(default="cpu", description="Device for computation.")
-    normalize_embeddings: bool = Field(default=False, description="Whether to normalize embeddings.")
-    multi_process: bool = Field(default=False, description="Whether to use multi-process encoding.")
-    cache_dir: Optional[str] = Field(default=None, description="Directory to cache or load the model.")
+    model: str = Field(
+        default="all-MiniLM-L6-v2",
+        description="Name of the SentenceTransformer model to use.",
+    )
+    device: Literal["cpu", "cuda", "mps", "npu"] = Field(
+        default="cpu", description="Device for computation."
+    )
+    normalize_embeddings: bool = Field(
+        default=False, description="Whether to normalize embeddings."
+    )
+    multi_process: bool = Field(
+        default=False, description="Whether to use multi-process encoding."
+    )
+    cache_dir: Optional[str] = Field(
+        default=None, description="Directory to cache or load the model."
+    )

-    client: Optional[Any] = Field(default=None, init=False, description="Loaded SentenceTransformer model.")
+    client: Optional[Any] = Field(
+        default=None, init=False, description="Loaded SentenceTransformer model."
+    )

    def model_post_init(self, __context: Any) -> None:
        """
@@ -35,26 +49,40 @@ class SentenceTransformerEmbedder(EmbedderBase):
            )

        # Determine whether to load from cache or download
-        model_path = self.cache_dir if self.cache_dir and os.path.exists(self.cache_dir) else self.model
+        model_path = (
+            self.cache_dir
+            if self.cache_dir and os.path.exists(self.cache_dir)
+            else self.model
+        )
        # Attempt to load the model
        try:
            if os.path.exists(model_path):
-                logger.info(f"Loading SentenceTransformer model from local path: {model_path}")
+                logger.info(
+                    f"Loading SentenceTransformer model from local path: {model_path}"
+                )
            else:
                logger.info(f"Downloading SentenceTransformer model: {self.model}")
                if self.cache_dir:
                    logger.info(f"Model will be cached to: {self.cache_dir}")
-            self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=model_path, device=self.device)
+            self.client: SentenceTransformer = SentenceTransformer(
+                model_name_or_path=model_path, device=self.device
+            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer model: {e}")
            raise
        # Save to cache directory if downloaded
-        if model_path == self.model and self.cache_dir and not os.path.exists(self.cache_dir):
+        if (
+            model_path == self.model
+            and self.cache_dir
+            and not os.path.exists(self.cache_dir)
+        ):
            logger.info(f"Saving the downloaded model to: {self.cache_dir}")
            self.client.save(self.cache_dir)

-    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Generate embeddings for input text(s).

@@ -82,7 +110,7 @@ class SentenceTransformerEmbedder(EmbedderBase):
                embeddings = self.client.encode_multi_process(
                    input_strings,
                    pool=pool,
-                    normalize_embeddings=self.normalize_embeddings
+                    normalize_embeddings=self.normalize_embeddings,
                )
            finally:
                logger.info("Stopping multi-process pool.")
@@ -91,14 +119,16 @@ class SentenceTransformerEmbedder(EmbedderBase):
            embeddings = self.client.encode(
                input_strings,
                convert_to_numpy=True,
-                normalize_embeddings=self.normalize_embeddings
+                normalize_embeddings=self.normalize_embeddings,
            )

        if single_input:
            return embeddings[0].tolist()
        return embeddings.tolist()

-    def __call__(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def __call__(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).

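A usage sketch for the embedder above; the model name is the declared default and the cache directory is optional (reused on later runs when it exists):

```python
embedder = SentenceTransformerEmbedder(
    model="all-MiniLM-L6-v2",
    device="cpu",
    normalize_embeddings=True,
    cache_dir="./models",
)

vector = embedder.embed("hello world")  # single input -> List[float]
vectors = embedder(["hello", "world"])  # __call__ delegates to embed
```
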
@ -25,7 +25,7 @@ class ArxivFetcher(FetcherBase):
|
|||
download: bool = False,
|
||||
dirpath: Path = Path("./"),
|
||||
include_summary: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Union[List[Dict], List["Document"]]:
|
||||
"""
|
||||
Search for papers on arXiv and optionally download them.
|
||||
|
@ -69,7 +69,9 @@ class ArxivFetcher(FetcherBase):
|
|||
|
||||
# Enforce that both from_date and to_date are provided if one is specified
|
||||
if (from_date and not to_date) or (to_date and not from_date):
|
||||
raise ValueError("Both 'from_date' and 'to_date' must be specified if one is provided.")
|
||||
raise ValueError(
|
||||
"Both 'from_date' and 'to_date' must be specified if one is provided."
|
||||
)
|
||||
|
||||
# Add date filter if both from_date and to_date are provided
|
||||
if from_date and to_date:
|
||||
|
@ -94,7 +96,7 @@ class ArxivFetcher(FetcherBase):
|
|||
content_id: str,
|
||||
download: bool = False,
|
||||
dirpath: Path = Path("./"),
|
||||
include_summary: bool = False
|
||||
include_summary: bool = False,
|
||||
) -> Union[Optional[Dict], Optional[Document]]:
|
||||
"""
|
||||
Search for a specific paper by its arXiv ID and optionally download it.
|
||||
|
@ -133,17 +135,15 @@ class ArxivFetcher(FetcherBase):
|
|||
logger.warning(f"No result found for ID: {content_id}")
|
||||
return None
|
||||
|
||||
return self._process_results([result], download, dirpath, include_summary)[0]
|
||||
return self._process_results([result], download, dirpath, include_summary)[
|
||||
0
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching result for ID {content_id}: {e}")
|
||||
return None
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
results: List[Any],
|
||||
download: bool,
|
||||
dirpath: Path,
|
||||
include_summary: bool
|
||||
self, results: List[Any], download: bool, dirpath: Path, include_summary: bool
|
||||
) -> Union[List[Dict], List["Document"]]:
|
||||
"""
|
||||
Process arXiv search results.
|
||||
|
@ -162,12 +162,18 @@ class ArxivFetcher(FetcherBase):
|
|||
metadata_list = []
|
||||
for result in results:
|
||||
file_path = self._download_result(result, dirpath)
|
||||
metadata_list.append(self._format_result_metadata(result, file_path=file_path, include_summary=include_summary))
|
||||
metadata_list.append(
|
||||
self._format_result_metadata(
|
||||
result, file_path=file_path, include_summary=include_summary
|
||||
)
|
||||
)
|
||||
return metadata_list
|
||||
else:
|
||||
documents = []
|
||||
for result in results:
|
||||
metadata = self._format_result_metadata(result, include_summary=include_summary)
|
||||
metadata = self._format_result_metadata(
|
||||
result, include_summary=include_summary
|
||||
)
|
||||
text = result.summary.strip()
|
||||
documents.append(Document(text=text, metadata=metadata))
|
||||
return documents
|
||||
|
@ -194,7 +200,12 @@ class ArxivFetcher(FetcherBase):
|
|||
logger.error(f"Failed to download paper {result.title}: {e}")
|
||||
return None
|
||||
|
||||
def _format_result_metadata(self, result: Any, file_path: Optional[str] = None, include_summary: bool = False) -> Dict:
|
||||
def _format_result_metadata(
|
||||
self,
|
||||
result: Any,
|
||||
file_path: Optional[str] = None,
|
||||
include_summary: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Format metadata from an arXiv result, optionally including file path and summary.
|
||||
|
||||
|
@ -219,12 +230,14 @@ class ArxivFetcher(FetcherBase):
|
|||
}
|
||||
|
||||
if self.include_full_metadata:
|
||||
metadata.update({
|
||||
metadata.update(
|
||||
{
|
||||
"links": result.links,
|
||||
"authors_comment": result.comment,
|
||||
"DOI": result.doi,
|
||||
"journal_reference": result.journal_ref,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
if include_summary:
|
||||
metadata["summary"] = result.summary.strip()
|
||||
|
@ -262,7 +275,9 @@ class ArxivFetcher(FetcherBase):
|
|||
if isinstance(date, str):
|
||||
# Check if the string matches the basic format
|
||||
if not re.fullmatch(r"^\d{8}(\d{4})?$", date):
|
||||
raise ValueError(f"Invalid date format: {date}. Use 'YYYYMMDD' or 'YYYYMMDDHHMM'.")
|
||||
raise ValueError(
|
||||
f"Invalid date format: {date}. Use 'YYYYMMDD' or 'YYYYMMDDHHMM'."
|
||||
)
|
||||
|
||||
# Validate that it is a real date
|
||||
try:
|
||||
|
@ -277,4 +292,6 @@ class ArxivFetcher(FetcherBase):
|
|||
elif isinstance(date, datetime):
|
||||
return date.strftime("%Y%m%d%H%M")
|
||||
else:
|
||||
raise ValueError("Invalid date input. Provide a string in 'YYYYMMDD', 'YYYYMMDDHHMM' format, or a datetime object.")
|
||||
raise ValueError(
|
||||
"Invalid date input. Provide a string in 'YYYYMMDD', 'YYYYMMDDHHMM' format, or a datetime object."
|
||||
)
|
||||
|
|
|
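A minimal standalone sketch of the date normalization the fetcher enforces above. The helper name normalize_arxiv_date is illustrative (not part of the module), and the real method additionally verifies the string is a valid calendar date:

import re
from datetime import datetime
from typing import Union

def normalize_arxiv_date(date: Union[str, datetime]) -> str:
    """Normalize a date to the 'YYYYMMDDHHMM' form used in arXiv queries."""
    if isinstance(date, str):
        if not re.fullmatch(r"\d{8}(\d{4})?", date):
            raise ValueError(
                f"Invalid date format: {date}. Use 'YYYYMMDD' or 'YYYYMMDDHHMM'."
            )
        # Pad date-only input with midnight so both forms normalize identically.
        return date if len(date) == 12 else date + "0000"
    if isinstance(date, datetime):
        return date.strftime("%Y%m%d%H%M")
    raise ValueError(
        "Invalid date input. Provide 'YYYYMMDD', 'YYYYMMDDHHMM', or a datetime."
    )

print(normalize_arxiv_date("20240131"))                    # 202401310000
print(normalize_arxiv_date(datetime(2024, 1, 31, 9, 30)))  # 202401310930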
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class FetcherBase(BaseModel, ABC):
|
||||
"""
|
||||
Abstract base class for fetchers.
|
||||
|
|
|
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
|
||||
class ReaderBase(BaseModel, ABC):
|
||||
"""
|
||||
Abstract base class for file readers.
|
||||
|
|
|
@ -9,7 +9,9 @@ class PyMuPDFReader(ReaderBase):
|
|||
Reader for PDF documents using PyMuPDF.
|
||||
"""
|
||||
|
||||
def load(self, file_path: Path, additional_metadata: Optional[Dict] = None) -> List[Document]:
|
||||
def load(
|
||||
self, file_path: Path, additional_metadata: Optional[Dict] = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load content from a PDF file using PyMuPDF.
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@ class PyPDFReader(ReaderBase):
|
|||
Reader for PDF documents using PyPDF.
|
||||
"""
|
||||
|
||||
def load(self, file_path: Path, additional_metadata: Optional[Dict] = None) -> List[Document]:
|
||||
def load(
|
||||
self, file_path: Path, additional_metadata: Optional[Dict] = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Load content from a PDF file using PyPDF.
|
||||
|
||||
|
|
|
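Both PDF readers above share the same load(file_path, additional_metadata) contract and differ only in backend. A hedged sketch of that contract using pypdf directly; plain dicts stand in for the package's Document type, and the file path is illustrative:

from pathlib import Path
from typing import Dict, List, Optional

from pypdf import PdfReader  # pip install pypdf

def load(file_path: Path, additional_metadata: Optional[Dict] = None) -> List[Dict]:
    reader = PdfReader(str(file_path))
    documents = []
    for page_number, page in enumerate(reader.pages, start=1):
        # One document per page, with caller-supplied metadata merged in.
        metadata = {"file_path": str(file_path), "page_number": page_number}
        if additional_metadata:
            metadata.update(additional_metadata)
        documents.append({"text": page.extract_text() or "", "metadata": metadata})
    return documents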
@ -4,6 +4,7 @@ from pathlib import Path
|
|||
from typing import List
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class TextLoader(ReaderBase):
|
||||
"""
|
||||
Loader for plain text files.
|
||||
|
@ -11,6 +12,7 @@ class TextLoader(ReaderBase):
|
|||
Attributes:
|
||||
encoding (str): The text file encoding. Defaults to 'utf-8'.
|
||||
"""
|
||||
|
||||
encoding: str = Field(default="utf-8", description="Encoding of the text file.")
|
||||
|
||||
def load(self, file_path: Path) -> List[Document]:
|
||||
|
|
|
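A hedged sketch of the TextLoader contract above: read the whole file with the configured encoding and wrap it in a single document (a plain dict stands in for the package's Document type):

from pathlib import Path
from typing import Dict, List

def load_text(file_path: Path, encoding: str = "utf-8") -> List[Dict]:
    # Mirrors TextLoader: one document per file, encoding configurable.
    text = file_path.read_text(encoding=encoding)
    return [{"text": text, "metadata": {"file_path": str(file_path)}}]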
@ -7,6 +7,7 @@ import logging
|
|||
|
||||
try:
|
||||
from nltk.tokenize import sent_tokenize
|
||||
|
||||
NLTK_AVAILABLE = True
|
||||
except ImportError:
|
||||
sent_tokenize = None
|
||||
|
@ -14,6 +15,7 @@ except ImportError:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SplitterBase(BaseModel, ABC):
|
||||
"""
|
||||
Base class for defining text splitting strategies.
|
||||
|
@ -21,13 +23,34 @@ class SplitterBase(BaseModel, ABC):
|
|||
based on separators, regex patterns, or sentence-based splitting.
|
||||
"""
|
||||
|
||||
chunk_size: int = Field(default=4000, description="Maximum size of chunks (in characters or tokens).", gt=0)
|
||||
chunk_overlap: int = Field(default=200, description="Overlap size between chunks for context continuity.", ge=0)
|
||||
chunk_size_function: Callable[[str], int] = Field(default=len, description="Function to calculate chunk size (e.g., by characters or tokens).")
|
||||
separator: Optional[str] = Field(default="\n\n", description="Primary separator for splitting text.")
|
||||
fallback_separators: List[str] = Field(default_factory=lambda: ["\n", " "], description="Fallback separators if the primary separator fails.")
|
||||
fallback_regex: str = Field(default=r"[^,.;。?!]+[,.;。?!]", description="Improved regex pattern for fallback splitting.")
|
||||
reserved_metadata_size: int = Field(default=0, description="Tokens reserved for metadata.", ge=0)
|
||||
chunk_size: int = Field(
|
||||
default=4000,
|
||||
description="Maximum size of chunks (in characters or tokens).",
|
||||
gt=0,
|
||||
)
|
||||
chunk_overlap: int = Field(
|
||||
default=200,
|
||||
description="Overlap size between chunks for context continuity.",
|
||||
ge=0,
|
||||
)
|
||||
chunk_size_function: Callable[[str], int] = Field(
|
||||
default=len,
|
||||
description="Function to calculate chunk size (e.g., by characters or tokens).",
|
||||
)
|
||||
separator: Optional[str] = Field(
|
||||
default="\n\n", description="Primary separator for splitting text."
|
||||
)
|
||||
fallback_separators: List[str] = Field(
|
||||
default_factory=lambda: ["\n", " "],
|
||||
description="Fallback separators if the primary separator fails.",
|
||||
)
|
||||
fallback_regex: str = Field(
|
||||
default=r"[^,.;。?!]+[,.;。?!]",
|
||||
description="Improved regex pattern for fallback splitting.",
|
||||
)
|
||||
reserved_metadata_size: int = Field(
|
||||
default=0, description="Tokens reserved for metadata.", ge=0
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
@ -88,7 +111,9 @@ class SplitterBase(BaseModel, ABC):
|
|||
chunks.append(full_chunk)
|
||||
|
||||
# Logging information for overlap and chunk size
|
||||
logger.debug(f"Chunk {len(chunks)} finalized. Size: {current_size}. Overlap size: {self.chunk_overlap}")
|
||||
logger.debug(
|
||||
f"Chunk {len(chunks)} finalized. Size: {current_size}. Overlap size: {self.chunk_overlap}"
|
||||
)
|
||||
|
||||
# Create an overlap using sentences from the current chunk
|
||||
overlap = []
|
||||
|
@ -223,13 +248,15 @@ class SplitterBase(BaseModel, ABC):
|
|||
end_index = start_index + self._get_chunk_size(chunk)
|
||||
|
||||
metadata = doc.metadata.copy() if doc.metadata else {}
|
||||
metadata.update({
|
||||
metadata.update(
|
||||
{
|
||||
"chunk_number": chunk_num + 1,
|
||||
"total_chunks": len(text_chunks),
|
||||
"start_index": start_index,
|
||||
"end_index": end_index,
|
||||
"chunk_length": self._get_chunk_size(chunk),
|
||||
})
|
||||
}
|
||||
)
|
||||
chunked_documents.append(Document(metadata=metadata, text=chunk))
|
||||
previous_end = end_index
|
||||
|
||||
|
|
|
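The splitter fields above define a fallback chain: primary separator, then fallback separators, then a sentence-boundary regex. A hedged plain-function sketch of the arithmetic and regex behavior those fields imply (the class itself is abstract; constants mirror the field defaults in the diff):

import re
from typing import List

CHUNK_SIZE = 4000       # max characters (or tokens) per chunk
CHUNK_OVERLAP = 200     # characters repeated between neighboring chunks
RESERVED_METADATA = 0   # budget withheld for metadata
FALLBACK_REGEX = r"[^,.;。?!]+[,.;。?!]"

def effective_chunk_size() -> int:
    # Room actually available for text once reserved metadata is deducted.
    return CHUNK_SIZE - RESERVED_METADATA

def regex_fallback_split(text: str) -> List[str]:
    # Last-resort split on sentence-like boundaries, as in fallback_regex.
    return [m.group(0) for m in re.finditer(FALLBACK_REGEX, text)]

print(effective_chunk_size())               # 4000
print(regex_fallback_split("A b. C d! E"))  # ['A b.', ' C d!'] (no trailing match for 'E')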
@ -4,6 +4,7 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextSplitter(SplitterBase):
|
||||
"""
|
||||
Concrete implementation of the SplitterBase class.
|
||||
|
@ -33,7 +34,9 @@ class TextSplitter(SplitterBase):
|
|||
|
||||
# Step 2: Short-circuit for small texts
|
||||
if self._get_chunk_size(text) <= effective_chunk_size:
|
||||
logger.debug("Text size is smaller than effective chunk size. Returning as a single chunk.")
|
||||
logger.debug(
|
||||
"Text size is smaller than effective chunk size. Returning as a single chunk."
|
||||
)
|
||||
return [text]
|
||||
|
||||
# Step 3: Use adaptive splitting strategy
|
||||
|
|
|
@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
|||
from pydantic import BaseModel
|
||||
from typing import List, ClassVar
|
||||
|
||||
|
||||
class CodeExecutorBase(BaseModel, ABC):
|
||||
"""Abstract base class for executing code in different environments."""
|
||||
|
||||
|
|
|
@ -11,26 +11,57 @@ import os
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DockerCodeExecutor(CodeExecutorBase):
|
||||
"""Executes code securely inside a persistent Docker container with dynamic volume updates."""
|
||||
|
||||
image: Optional[str] = Field("python:3.9", description="Docker image used for execution.")
|
||||
container_name: Optional[str] = Field("dapr_agents_code_executor", description="Name of the Docker container.")
|
||||
disable_network_access: bool = Field(default=True, description="Disable network access inside the container.")
|
||||
execution_timeout: int = Field(default=60, description="Max execution time (seconds).")
|
||||
execution_mode: str = Field("detached", description="Execution mode: 'interactive' or 'detached'.")
|
||||
restart_policy: str = Field("no", description="Container restart policy: 'no', 'on-failure', 'always'.")
|
||||
image: Optional[str] = Field(
|
||||
"python:3.9", description="Docker image used for execution."
|
||||
)
|
||||
container_name: Optional[str] = Field(
|
||||
"dapr_agents_code_executor", description="Name of the Docker container."
|
||||
)
|
||||
disable_network_access: bool = Field(
|
||||
default=True, description="Disable network access inside the container."
|
||||
)
|
||||
execution_timeout: int = Field(
|
||||
default=60, description="Max execution time (seconds)."
|
||||
)
|
||||
execution_mode: str = Field(
|
||||
"detached", description="Execution mode: 'interactive' or 'detached'."
|
||||
)
|
||||
restart_policy: str = Field(
|
||||
"no", description="Container restart policy: 'no', 'on-failure', 'always'."
|
||||
)
|
||||
max_memory: str = Field("500m", description="Max memory for execution.")
|
||||
cpu_quota: int = Field(50000, description="CPU quota limit.")
|
||||
runtime: Optional[str] = Field(default=None, description="Container runtime (e.g., 'nvidia').")
|
||||
auto_remove: bool = Field(default=False, description="Keep container running to reuse it.")
|
||||
auto_cleanup: bool = Field(default=False, description="Automatically clean up the workspace after execution.")
|
||||
volume_access_mode: Literal["ro", "rw"] = Field(default="ro", description="Access mode for the workspace volume.")
|
||||
host_workspace: Optional[str] = Field(default=None, description="Custom workspace on host. If None, defaults to system temp dir.")
|
||||
runtime: Optional[str] = Field(
|
||||
default=None, description="Container runtime (e.g., 'nvidia')."
|
||||
)
|
||||
auto_remove: bool = Field(
|
||||
default=False, description="Keep container running to reuse it."
|
||||
)
|
||||
auto_cleanup: bool = Field(
|
||||
default=False,
|
||||
description="Automatically clean up the workspace after execution.",
|
||||
)
|
||||
volume_access_mode: Literal["ro", "rw"] = Field(
|
||||
default="ro", description="Access mode for the workspace volume."
|
||||
)
|
||||
host_workspace: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Custom workspace on host. If None, defaults to system temp dir.",
|
||||
)
|
||||
|
||||
docker_client: Optional[Any] = Field(default=None, init=False, description="Docker client instance.")
|
||||
execution_container: Optional[Any] = Field(default=None, init=False, description="Persistent Docker container.")
|
||||
container_workspace: Optional[str] = Field(default="/workspace", init=False, description="Mounted workspace in container.")
|
||||
docker_client: Optional[Any] = Field(
|
||||
default=None, init=False, description="Docker client instance."
|
||||
)
|
||||
execution_container: Optional[Any] = Field(
|
||||
default=None, init=False, description="Persistent Docker container."
|
||||
)
|
||||
container_workspace: Optional[str] = Field(
|
||||
default="/workspace", init=False, description="Mounted workspace in container."
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initializes the Docker client and ensures a reusable execution container is ready."""
|
||||
|
@ -38,7 +69,9 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
from docker import DockerClient
|
||||
from docker.errors import DockerException
|
||||
except ImportError as e:
|
||||
raise ImportError("Install 'docker' package with 'pip install docker'.") from e
|
||||
raise ImportError(
|
||||
"Install 'docker' package with 'pip install docker'."
|
||||
) from e
|
||||
|
||||
try:
|
||||
self.docker_client: DockerClient = DockerClient.from_env()
|
||||
|
@ -47,9 +80,13 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
|
||||
# Validate or Set the Host Workspace
|
||||
if self.host_workspace:
|
||||
self.host_workspace = os.path.abspath(self.host_workspace) # Ensure absolute path
|
||||
self.host_workspace = os.path.abspath(
|
||||
self.host_workspace
|
||||
) # Ensure absolute path
|
||||
else:
|
||||
self.host_workspace = os.path.join(tempfile.gettempdir(), "dapr_agents_executor_workspace")
|
||||
self.host_workspace = os.path.join(
|
||||
tempfile.gettempdir(), "dapr_agents_executor_workspace"
|
||||
)
|
||||
|
||||
# Ensure the directory exists
|
||||
os.makedirs(self.host_workspace, exist_ok=True)
|
||||
|
@ -66,10 +103,14 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
try:
|
||||
from docker.errors import NotFound
|
||||
except ImportError as e:
|
||||
raise ImportError("Install 'docker' package with 'pip install docker'.") from e
|
||||
raise ImportError(
|
||||
"Install 'docker' package with 'pip install docker'."
|
||||
) from e
|
||||
|
||||
try:
|
||||
self.execution_container = self.docker_client.containers.get(self.container_name)
|
||||
self.execution_container = self.docker_client.containers.get(
|
||||
self.container_name
|
||||
)
|
||||
logger.info(f"Reusing existing container: {self.container_name}")
|
||||
except NotFound:
|
||||
logger.info(f"Creating a new container: {self.container_name}")
|
||||
|
@ -82,7 +123,9 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
try:
|
||||
from docker.errors import DockerException, APIError
|
||||
except ImportError as e:
|
||||
raise ImportError("Install 'docker' package with 'pip install docker'.") from e
|
||||
raise ImportError(
|
||||
"Install 'docker' package with 'pip install docker'."
|
||||
) from e
|
||||
try:
|
||||
self.execution_container = self.docker_client.containers.create(
|
||||
self.image,
|
||||
|
@ -99,13 +142,22 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
restart_policy={"Name": self.restart_policy},
|
||||
runtime=self.runtime,
|
||||
working_dir=self.container_workspace,
|
||||
volumes={self.host_workspace: {"bind": self.container_workspace, "mode": self.volume_access_mode}},
|
||||
volumes={
|
||||
self.host_workspace: {
|
||||
"bind": self.container_workspace,
|
||||
"mode": self.volume_access_mode,
|
||||
}
|
||||
},
|
||||
)
|
||||
except (DockerException, APIError) as e:
|
||||
logger.error(f"Failed to create the execution container: {str(e)}")
|
||||
raise RuntimeError(f"Failed to create the execution container: {str(e)}") from e
|
||||
raise RuntimeError(
|
||||
f"Failed to create the execution container: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def execute(self, request: Union[ExecutionRequest, dict]) -> List[ExecutionResult]:
|
||||
async def execute(
|
||||
self, request: Union[ExecutionRequest, dict]
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Executes code inside the persistent Docker container.
|
||||
The code is written to a shared volume instead of stopping & starting the container.
|
||||
|
@ -127,7 +179,9 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
if snippet.language == "python":
|
||||
required_packages = self._extract_imports(snippet.code)
|
||||
if required_packages:
|
||||
logger.info(f"Installing missing dependencies: {required_packages}")
|
||||
logger.info(
|
||||
f"Installing missing dependencies: {required_packages}"
|
||||
)
|
||||
await self._install_missing_packages(required_packages)
|
||||
|
||||
script_filename = f"script.{snippet.language}"
|
||||
|
@ -138,17 +192,24 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
with open(script_path_host, "w", encoding="utf-8") as script_file:
|
||||
script_file.write(snippet.code)
|
||||
|
||||
cmd = f"timeout {self.execution_timeout} python3 {script_path_container}" \
|
||||
if snippet.language == "python" else f"timeout {self.execution_timeout} sh {script_path_container}"
|
||||
cmd = (
|
||||
f"timeout {self.execution_timeout} python3 {script_path_container}"
|
||||
if snippet.language == "python"
|
||||
else f"timeout {self.execution_timeout} sh {script_path_container}"
|
||||
)
|
||||
|
||||
# Run command dynamically inside the running container
|
||||
exec_result = await asyncio.to_thread(self.execution_container.exec_run, cmd)
|
||||
exec_result = await asyncio.to_thread(
|
||||
self.execution_container.exec_run, cmd
|
||||
)
|
||||
|
||||
exit_code = exec_result.exit_code
|
||||
logs = exec_result.output.decode("utf-8", errors="ignore").strip()
|
||||
status = "success" if exit_code == 0 else "error"
|
||||
|
||||
results.append(ExecutionResult(status=status, output=logs, exit_code=exit_code))
|
||||
results.append(
|
||||
ExecutionResult(status=status, output=logs, exit_code=exit_code)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logs = self.get_container_logs()
|
||||
|
@ -159,7 +220,9 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
if self.auto_cleanup:
|
||||
if os.path.exists(self.host_workspace):
|
||||
shutil.rmtree(self.host_workspace, ignore_errors=True)
|
||||
logger.info(f"Temporary workspace {self.host_workspace} cleaned up.")
|
||||
logger.info(
|
||||
f"Temporary workspace {self.host_workspace} cleaned up."
|
||||
)
|
||||
|
||||
if self.auto_remove:
|
||||
self.execution_container.stop()
|
||||
|
@ -191,9 +254,9 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
for node in ast.walk(parsed_code):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
modules.add(alias.name.split('.')[0])
|
||||
modules.add(alias.name.split(".")[0])
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
modules.add(node.module.split('.')[0])
|
||||
modules.add(node.module.split(".")[0])
|
||||
|
||||
return list(modules)
|
||||
|
||||
|
@ -231,7 +294,9 @@ class DockerCodeExecutor(CodeExecutorBase):
|
|||
Exception: If log retrieval fails, an error message is logged.
|
||||
"""
|
||||
try:
|
||||
logs = self.execution_container.logs(stdout=True, stderr=True).decode("utf-8")
|
||||
logs = self.execution_container.logs(stdout=True, stderr=True).decode(
|
||||
"utf-8"
|
||||
)
|
||||
return logs
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve container logs: {str(e)}")
|
||||
|
|
|
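A hedged usage sketch for the Docker executor above. It assumes a local Docker daemon, and the module path and the snippet-dict shape are inferred from the diff rather than confirmed (the diff only shows snippet.code and snippet.language):

import asyncio

from dapr_agents.executors.docker import DockerCodeExecutor  # module path assumed
from dapr_agents.types.executor import ExecutionRequest

async def main() -> None:
    executor = DockerCodeExecutor(
        image="python:3.9",
        disable_network_access=True,  # default: no network inside the container
        volume_access_mode="rw",      # let snippets write into the shared workspace
    )
    request = ExecutionRequest(
        snippets=[{"code": "print('hello from docker')", "language": "python"}]
    )
    for result in await executor.execute(request):
        print(result.status, result.exit_code, result.output)

asyncio.run(main())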
@ -1,227 +1,332 @@
|
|||
from dapr_agents.executors import CodeExecutorBase
|
||||
from dapr_agents.types.executor import ExecutionRequest, ExecutionResult
|
||||
from typing import List, Union, Any, Callable
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
"""Local executor that runs Python or shell snippets in cached virtual-envs."""
|
||||
|
||||
import asyncio
|
||||
import venv
|
||||
import logging
|
||||
import ast
|
||||
import hashlib
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
import ast
|
||||
import venv
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Sequence, Union
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from dapr_agents.executors import CodeExecutorBase
|
||||
from dapr_agents.executors.sandbox import detect_backend, wrap_command, SandboxType
|
||||
from dapr_agents.executors.utils.package_manager import (
|
||||
get_install_command,
|
||||
get_project_type,
|
||||
)
|
||||
from dapr_agents.types.executor import ExecutionRequest, ExecutionResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LocalCodeExecutor(CodeExecutorBase):
|
||||
"""Executes code locally in an optimized virtual environment with caching,
|
||||
user-defined functions, and enhanced security.
|
||||
|
||||
Supports Python and shell execution with real-time logging,
|
||||
efficient dependency management, and reduced file I/O.
|
||||
class LocalCodeExecutor(CodeExecutorBase):
|
||||
"""
|
||||
Run snippets locally with **optional OS-level sandboxing** and
|
||||
per-snippet virtual-env caching.
|
||||
"""
|
||||
|
||||
cache_dir: Path = Field(default_factory=lambda: Path.cwd() / ".dapr_agents_cached_envs", description="Directory for cached virtual environments and execution artifacts.")
|
||||
user_functions: List[Callable] = Field(default_factory=list, description="List of user-defined functions available during execution.")
|
||||
cleanup_threshold: int = Field(default=604800, description="Time (in seconds) before cached virtual environments are considered for cleanup.")
|
||||
cache_dir: Path = Field(
|
||||
default_factory=lambda: Path.cwd() / ".dapr_agents_cached_envs",
|
||||
description="Directory that stores cached virtual environments.",
|
||||
)
|
||||
user_functions: List[Callable] = Field(
|
||||
default_factory=list,
|
||||
description="Functions whose source is prepended to every Python snippet.",
|
||||
)
|
||||
sandbox: SandboxType = Field(
|
||||
default="auto",
|
||||
description="'seatbelt' | 'firejail' | 'none' | 'auto' (best available)",
|
||||
)
|
||||
writable_paths: List[Path] = Field(
|
||||
default_factory=list,
|
||||
description="Extra paths the sandboxed process may write to.",
|
||||
)
|
||||
cleanup_threshold: int = Field(
|
||||
default=604_800, # one week
|
||||
description="Seconds before a cached venv is considered stale.",
|
||||
)
|
||||
|
||||
_env_lock = asyncio.Lock()
|
||||
_env_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
_bootstrapped_root: Path | None = PrivateAttr(default=None)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Ensures the cache directory is created after model initialization."""
|
||||
def model_post_init(self, __context: Any) -> None: # noqa: D401
|
||||
"""Create ``cache_dir`` after pydantic instantiation."""
|
||||
super().model_post_init(__context)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info("Cache directory set.")
|
||||
logger.debug(f"{self.cache_dir}")
|
||||
logger.debug("venv cache directory: %s", self.cache_dir)
|
||||
|
||||
async def execute(self, request: Union[ExecutionRequest, dict]) -> List[ExecutionResult]:
|
||||
"""Executes Python or shell code securely in a persistent virtual environment with caching and real-time logging.
|
||||
async def execute(
|
||||
self, request: Union[ExecutionRequest, dict]
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Run the snippets in *request* and return their results.
|
||||
|
||||
Args:
|
||||
request (Union[ExecutionRequest, dict]): The execution request containing code snippets.
|
||||
request: ``ExecutionRequest`` instance or a raw mapping that can
|
||||
be unpacked into one.
|
||||
|
||||
Returns:
|
||||
List[ExecutionResult]: A list of execution results for each snippet.
|
||||
A list with one ``ExecutionResult`` for every snippet in the
|
||||
original request.
|
||||
"""
|
||||
if isinstance(request, dict):
|
||||
request = ExecutionRequest(**request)
|
||||
|
||||
await self._bootstrap_project()
|
||||
self.validate_snippets(request.snippets)
|
||||
results = []
|
||||
|
||||
for snippet in request.snippets:
|
||||
start_time = time.time()
|
||||
|
||||
if snippet.language == "python":
|
||||
required_packages = self._extract_imports(snippet.code)
|
||||
logger.info(f"Packages Required: {required_packages}")
|
||||
venv_path = await self._get_or_create_cached_env(required_packages)
|
||||
|
||||
# Load user-defined functions dynamically in memory
|
||||
function_code = "\n".join(inspect.getsource(f) for f in self.user_functions) if self.user_functions else ""
|
||||
exec_script = f"{function_code}\n{snippet.code}" if function_code else snippet.code
|
||||
|
||||
python_executable = venv_path / "bin" / "python3"
|
||||
command = [str(python_executable), "-c", exec_script]
|
||||
# Resolve sandbox once
|
||||
eff_backend: SandboxType = (
|
||||
detect_backend() if self.sandbox == "auto" else self.sandbox
|
||||
)
|
||||
if eff_backend != "none":
|
||||
logger.info(
|
||||
"Sandbox backend enabled: %s%s",
|
||||
eff_backend,
|
||||
f" (writable: {', '.join(map(str, self.writable_paths))})"
|
||||
if self.writable_paths
|
||||
else "",
|
||||
)
|
||||
else:
|
||||
command = ["sh", "-c", snippet.code]
|
||||
logger.info("Sandbox disabled - running commands directly.")
|
||||
|
||||
logger.info("Executing command")
|
||||
logger.debug(f"{' '.join(command)}")
|
||||
# Main loop
|
||||
results: list[ExecutionResult] = []
|
||||
for snip_idx, snippet in enumerate(request.snippets, start=1):
|
||||
start = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Start subprocess execution with explicit timeout
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
close_fds=True
|
||||
# Assemble the *raw* command
|
||||
if snippet.language == "python":
|
||||
env = await self._prepare_python_env(snippet.code)
|
||||
python_bin = env / "bin" / "python3"
|
||||
prelude = "\n".join(inspect.getsource(fn) for fn in self.user_functions)
|
||||
script = f"{prelude}\n{snippet.code}" if prelude else snippet.code
|
||||
raw_cmd: Sequence[str] = [str(python_bin), "-c", script]
|
||||
else:
|
||||
raw_cmd = ["sh", "-c", snippet.code]
|
||||
|
||||
# Wrap for sandbox
|
||||
final_cmd = wrap_command(raw_cmd, eff_backend, self.writable_paths)
|
||||
logger.debug(
|
||||
"Snippet %s - launch command: %s",
|
||||
snip_idx,
|
||||
" ".join(final_cmd),
|
||||
)
|
||||
|
||||
# Wait for completion with timeout enforcement
|
||||
stdout_output, stderr_output = await asyncio.wait_for(process.communicate(), timeout=request.timeout)
|
||||
# Run it
|
||||
snip_timeout = getattr(snippet, "timeout", request.timeout)
|
||||
results.append(await self._run_subprocess(final_cmd, snip_timeout))
|
||||
|
||||
status = "success" if process.returncode == 0 else "error"
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
logger.info(f"Execution completed in {execution_time:.2f} seconds.")
|
||||
if stderr_output:
|
||||
logger.error(f"STDERR: {stderr_output.decode()}")
|
||||
|
||||
results.append(ExecutionResult(
|
||||
status=status,
|
||||
output=stdout_output.decode(),
|
||||
exit_code=process.returncode
|
||||
))
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
process.terminate() # Ensure subprocess is killed if it times out
|
||||
results.append(ExecutionResult(status="error", output="Execution timed out", exit_code=1))
|
||||
except Exception as e:
|
||||
results.append(ExecutionResult(status="error", output=str(e), exit_code=1))
|
||||
logger.info(
|
||||
"Snippet %s finished in %.3fs",
|
||||
snip_idx,
|
||||
time.perf_counter() - start,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _extract_imports(self, code: str) -> List[str]:
|
||||
"""Parses a Python script and extracts top-level module imports.
|
||||
async def _bootstrap_project(self) -> None:
|
||||
"""Install top-level dependencies once per executor instance."""
|
||||
cwd = Path.cwd().resolve()
|
||||
if self._bootstrapped_root == cwd:
|
||||
return
|
||||
|
||||
install_cmd = get_install_command(str(cwd))
|
||||
if install_cmd:
|
||||
logger.info(
|
||||
"bootstrapping %s project with '%s'",
|
||||
get_project_type(str(cwd)).value,
|
||||
install_cmd,
|
||||
)
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
install_cmd,
|
||||
cwd=cwd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, err = await proc.communicate()
|
||||
if proc.returncode:
|
||||
logger.warning(
|
||||
"bootstrap failed (%d): %s", proc.returncode, err.decode().strip()
|
||||
)
|
||||
|
||||
self._bootstrapped_root = cwd
|
||||
|
||||
async def _prepare_python_env(self, code: str) -> Path:
|
||||
"""
|
||||
Ensure a virtual-env exists that satisfies *code* imports.
|
||||
|
||||
Args:
|
||||
code (str): The Python code snippet to analyze.
|
||||
code: User-supplied Python source.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of imported module names found in the code.
|
||||
Path to the virtual-env directory.
|
||||
"""
|
||||
imports = self._extract_imports(code)
|
||||
env = await self._get_or_create_cached_env(imports)
|
||||
missing = await self._get_missing_packages(imports, env)
|
||||
if missing:
|
||||
await self._install_missing_packages(missing, env)
|
||||
return env
|
||||
|
||||
@staticmethod
|
||||
def _extract_imports(code: str) -> List[str]:
|
||||
"""
|
||||
Return all top-level imported module names in *code*.
|
||||
|
||||
Args:
|
||||
code: Python source to scan.
|
||||
|
||||
Returns:
|
||||
Unique list of first-segment module names.
|
||||
|
||||
Raises:
|
||||
SyntaxError: If the code has invalid syntax and cannot be parsed.
|
||||
SyntaxError: If *code* cannot be parsed.
|
||||
"""
|
||||
try:
|
||||
parsed_code = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
logger.error(f"Syntax error while parsing code: {e}")
|
||||
tree = ast.parse(code)
|
||||
except SyntaxError:
|
||||
logger.error("cannot parse user code, assuming no imports")
|
||||
return []
|
||||
|
||||
modules = set()
|
||||
for node in ast.walk(parsed_code):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
modules.add(alias.name.split('.')[0]) # Get the top-level package
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
modules.add(node.module.split('.')[0])
|
||||
names = {
|
||||
alias.name.partition(".")[0]
|
||||
for node in ast.walk(tree)
|
||||
for alias in getattr(node, "names", [])
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom))
|
||||
}
|
||||
if any(
|
||||
isinstance(node, ast.ImportFrom) and node.module for node in ast.walk(tree)
|
||||
):
|
||||
names |= {
|
||||
node.module.partition(".")[0]
|
||||
for node in ast.walk(tree)
|
||||
if isinstance(node, ast.ImportFrom) and node.module
|
||||
}
|
||||
return sorted(names)
|
||||
|
||||
return list(modules)
|
||||
|
||||
async def _get_missing_packages(self, packages: List[str], env_path: Path) -> List[str]:
|
||||
"""Determines which packages are missing inside a given virtual environment.
|
||||
|
||||
Args:
|
||||
packages (List[str]): A list of package names to check.
|
||||
env_path (Path): Path to the virtual environment.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of packages that are missing from the virtual environment.
|
||||
async def _get_missing_packages(
|
||||
self, packages: List[str], env_path: Path
|
||||
) -> List[str]:
|
||||
"""
|
||||
python_bin = env_path / "bin" / "python3"
|
||||
|
||||
async def check_package(pkg):
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
str(python_bin), "-c", f"import {pkg}",
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.DEVNULL
|
||||
)
|
||||
await process.wait()
|
||||
return pkg if process.returncode != 0 else None # Return package name if missing
|
||||
|
||||
tasks = [check_package(pkg) for pkg in packages]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [pkg for pkg in results if pkg] # Filter out installed packages
|
||||
|
||||
|
||||
async def _get_or_create_cached_env(self, dependencies: List[str]) -> Path:
|
||||
"""Creates or retrieves a cached virtual environment based on dependencies.
|
||||
|
||||
This function checks if a suitable cached virtual environment exists.
|
||||
If it does not, it creates a new one and installs missing dependencies.
|
||||
Identify which *packages* are not importable from *env_path*.
|
||||
|
||||
Args:
|
||||
dependencies (List[str]): List of required package names.
|
||||
packages: Candidate import names.
|
||||
env_path: Path to the virtual-env.
|
||||
|
||||
Returns:
|
||||
Path: Path to the virtual environment directory.
|
||||
Subset of *packages* that need installation.
|
||||
"""
|
||||
python = env_path / "bin" / "python3"
|
||||
|
||||
async def probe(pkg: str) -> str | None:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(python),
|
||||
"- <<PY\nimport importlib.util, sys;"
|
||||
f"sys.exit(importlib.util.find_spec('{pkg}') is None)\nPY",
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.DEVNULL,
|
||||
)
|
||||
await proc.wait()
|
||||
return pkg if proc.returncode else None
|
||||
|
||||
missing = await asyncio.gather(*(probe(p) for p in packages))
|
||||
return [m for m in missing if m]
|
||||
|
||||
async def _get_or_create_cached_env(self, deps: List[str]) -> Path:
|
||||
"""
|
||||
Return a cached venv path keyed by the sorted list *deps*.
|
||||
|
||||
Args:
|
||||
deps: Import names required by user code.
|
||||
|
||||
Returns:
|
||||
Path to the virtual-env directory.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If virtual environment creation or package installation fails.
|
||||
RuntimeError: If venv creation fails.
|
||||
"""
|
||||
async with self._env_lock:
|
||||
env_hash = hashlib.md5(",".join(sorted(dependencies)).encode()).hexdigest()
|
||||
env_path = self.cache_dir / f"env_{env_hash}"
|
||||
digest = hashlib.sha1(",".join(sorted(deps)).encode()).hexdigest()
|
||||
env_path = self.cache_dir / f"env_{digest}"
|
||||
|
||||
async with self._env_lock:
|
||||
if env_path.exists():
|
||||
logger.info("Reusing cached virtual environment.")
|
||||
else:
|
||||
logger.info("Setting up a new virtual environment.")
|
||||
try:
|
||||
venv.create(str(env_path), with_pip=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create virtual environment: {e}")
|
||||
raise RuntimeError(f"Virtual environment creation failed: {e}")
|
||||
|
||||
# Identify missing packages
|
||||
missing_packages = await self._get_missing_packages(dependencies, env_path)
|
||||
|
||||
if missing_packages:
|
||||
await self._install_missing_packages(missing_packages, env_path)
|
||||
|
||||
venv.create(env_path, with_pip=True)
|
||||
logger.info("Created a new virtual environment")
|
||||
logger.debug("venv %s created", env_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise RuntimeError("virtual-env creation failed") from exc
|
||||
return env_path
|
||||
|
||||
|
||||
async def _install_missing_packages(self, packages: List[str], env_dir: Path):
|
||||
"""Installs missing Python packages inside the virtual environment.
|
||||
async def _install_missing_packages(
|
||||
self, packages: List[str], env_dir: Path
|
||||
) -> None:
|
||||
"""
|
||||
``pip install`` *packages* inside *env_dir*.
|
||||
|
||||
Args:
|
||||
packages (List[str]): A list of package names to install.
|
||||
env_dir (Path): Path to the virtual environment where packages should be installed.
|
||||
packages: Package names to install.
|
||||
env_dir: Target virtual-env directory.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the package installation process fails.
|
||||
RuntimeError: If installation returns non-zero exit code.
|
||||
"""
|
||||
if not packages:
|
||||
return
|
||||
python = env_dir / "bin" / "python3"
|
||||
cmd = [str(python), "-m", "pip", "install", *packages]
|
||||
logger.info("Installing %s", ", ".join(packages))
|
||||
|
||||
python_bin = env_dir / "bin" / "python3"
|
||||
command = [str(python_bin), "-m", "pip", "install", *packages]
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=asyncio.subprocess.DEVNULL, # Suppresses stdout since it's not used
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
close_fds=True
|
||||
)
|
||||
_, stderr = await process.communicate() # Capture only stderr
|
||||
_, err = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
msg = err.decode().strip()
|
||||
logger.error("pip install failed: %s", msg)
|
||||
raise RuntimeError(msg)
|
||||
logger.debug("Installed %d package(s)", len(packages))
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = stderr.decode().strip()
|
||||
logger.error(f"Package installation failed: {error_msg}")
|
||||
raise RuntimeError(f"Package installation failed: {error_msg}")
|
||||
async def _run_subprocess(
|
||||
self, cmd: Sequence[str], timeout: int
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Run *cmd* with *timeout* seconds.
|
||||
|
||||
logger.info(f"Installed dependencies: {', '.join(packages)}")
|
||||
Args:
|
||||
cmd: Command list to execute.
|
||||
timeout: Maximum runtime in seconds.
|
||||
|
||||
Returns:
|
||||
``ExecutionResult`` with captured output.
|
||||
"""
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
out, err = await asyncio.wait_for(proc.communicate(), timeout)
|
||||
status = "success" if proc.returncode == 0 else "error"
|
||||
if err:
|
||||
logger.debug("stderr: %s", err.decode().strip())
|
||||
return ExecutionResult(
|
||||
status=status, output=out.decode(), exit_code=proc.returncode
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
return ExecutionResult(
|
||||
status="error", output="execution timed out", exit_code=1
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
return ExecutionResult(status="error", output=str(exc), exit_code=1)
|
||||
|
|
|
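The venv cache above is keyed by a digest of the sorted import list, so snippets with the same dependency set share one environment regardless of import order. A standalone sketch of that keying scheme, matching the sha1 digest and env_ prefix in the diff:

import hashlib
from pathlib import Path
from typing import List

CACHE_DIR = Path.cwd() / ".dapr_agents_cached_envs"

def cached_env_path(deps: List[str]) -> Path:
    # Sorting first means any ordering of the same set yields the same key.
    digest = hashlib.sha1(",".join(sorted(deps)).encode()).hexdigest()
    return CACHE_DIR / f"env_{digest}"

assert cached_env_path(["numpy", "requests"]) == cached_env_path(["requests", "numpy"])
print(cached_env_path(["numpy", "requests"]))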
@ -0,0 +1,200 @@
|
|||
"""Light-weight cross-platform sandbox helpers."""
|
||||
|
||||
import platform
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Sequence
|
||||
|
||||
SandboxType = Literal["none", "seatbelt", "firejail", "auto"]
|
||||
|
||||
_READ_ONLY_SEATBELT_POLICY = r"""
|
||||
(version 1)
|
||||
|
||||
; ---------------- default = deny everything -----------------
|
||||
(deny default)
|
||||
|
||||
; ---------------- read-only FS access -----------------------
|
||||
(allow file-read*)
|
||||
|
||||
; ---------------- minimal process mgmt ----------------------
|
||||
(allow process-exec)
|
||||
(allow process-fork)
|
||||
(allow signal (target self))
|
||||
|
||||
; ---------------- write-only to /dev/null -------------------
|
||||
(allow file-write-data
|
||||
(require-all
|
||||
(path "/dev/null")
|
||||
(vnode-type CHARACTER-DEVICE)))
|
||||
|
||||
; ---------------- harmless sysctls --------------------------
|
||||
(allow sysctl-read
|
||||
(sysctl-name "hw.activecpu")
|
||||
(sysctl-name "hw.busfrequency_compat")
|
||||
(sysctl-name "hw.byteorder")
|
||||
(sysctl-name "hw.cacheconfig")
|
||||
(sysctl-name "hw.cachelinesize_compat")
|
||||
(sysctl-name "hw.cpufamily")
|
||||
(sysctl-name "hw.cpufrequency_compat")
|
||||
(sysctl-name "hw.cputype")
|
||||
(sysctl-name "hw.l1dcachesize_compat")
|
||||
(sysctl-name "hw.l1icachesize_compat")
|
||||
(sysctl-name "hw.l2cachesize_compat")
|
||||
(sysctl-name "hw.l3cachesize_compat")
|
||||
(sysctl-name "hw.logicalcpu_max")
|
||||
(sysctl-name "hw.machine")
|
||||
(sysctl-name "hw.ncpu")
|
||||
(sysctl-name "hw.nperflevels")
|
||||
(sysctl-name "hw.memsize")
|
||||
(sysctl-name "hw.pagesize")
|
||||
(sysctl-name "hw.packages")
|
||||
(sysctl-name "hw.physicalcpu_max")
|
||||
(sysctl-name "kern.hostname")
|
||||
(sysctl-name "kern.osrelease")
|
||||
(sysctl-name "kern.ostype")
|
||||
(sysctl-name "kern.osversion")
|
||||
(sysctl-name "kern.version")
|
||||
(sysctl-name-prefix "hw.perflevel")
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def detect_backend() -> SandboxType: # noqa: D401
|
||||
"""Return the best-effort sandbox backend for the current host."""
|
||||
system = platform.system()
|
||||
if system == "Darwin" and shutil.which("sandbox-exec"):
|
||||
return "seatbelt"
|
||||
if system == "Linux" and shutil.which("firejail"):
|
||||
return "firejail"
|
||||
return "none"
|
||||
|
||||
|
||||
def _seatbelt_cmd(cmd: Sequence[str], writable_paths: List[Path]) -> List[str]:
|
||||
"""
|
||||
Construct a **macOS seatbelt** command line.
|
||||
|
||||
The resulting list can be passed directly to `asyncio.create_subprocess_exec`.
|
||||
It launches the target *cmd* under **sandbox-exec** with an
|
||||
*initially-read-only* profile; every directory in *writable_paths* is added
|
||||
as an explicit “write-allowed sub-path”.
|
||||
|
||||
Args:
|
||||
cmd:
|
||||
The *raw* command (program + args) that should run inside the sandbox.
|
||||
writable_paths:
|
||||
Absolute paths that the child process must be able to modify
|
||||
(e.g. a temporary working directory).
|
||||
Each entry becomes a param `-D WR<i>=<path>` and a corresponding
|
||||
``file-write*`` rule in the generated profile.
|
||||
|
||||
Returns:
|
||||
list[str]
|
||||
A fully-assembled ``sandbox-exec`` invocation:
|
||||
``['sandbox-exec', '-p', <profile>, …, '--', *cmd]``.
|
||||
"""
|
||||
policy = _READ_ONLY_SEATBELT_POLICY
|
||||
params: list[str] = []
|
||||
|
||||
if writable_paths:
|
||||
# Build parameter substitutions and the matching `(allow file-write*)` stanza.
|
||||
write_terms: list[str] = []
|
||||
for idx, path in enumerate(writable_paths):
|
||||
param = f"WR{idx}"
|
||||
params.extend(["-D", f"{param}={path}"])
|
||||
write_terms.append(f'(subpath (param "{param}"))')
|
||||
|
||||
policy += f"\n(allow file-write*\n {' '.join(write_terms)}\n)"
|
||||
|
||||
return [
|
||||
"sandbox-exec",
|
||||
"-p",
|
||||
policy,
|
||||
*params,
|
||||
"--",
|
||||
*cmd,
|
||||
]
|
||||
|
||||
|
||||
def _firejail_cmd(cmd: Sequence[str], writable_paths: List[Path]) -> List[str]:
|
||||
"""
|
||||
Build a **Firejail** command line (Linux only).
|
||||
|
||||
The wrapper enables seccomp, disables sound and networking, and whitelists
|
||||
the provided *writable_paths* so the child process can persist data there.
|
||||
|
||||
Args:
|
||||
cmd:
|
||||
The command (program + args) to execute.
|
||||
writable_paths:
|
||||
Directories that must remain writable inside the Firejail sandbox.
|
||||
|
||||
Returns:
|
||||
list[str]
|
||||
A Firejail-prefixed command suitable for
|
||||
``asyncio.create_subprocess_exec``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
If *writable_paths* contains non-absolute paths.
|
||||
"""
|
||||
for p in writable_paths:
|
||||
if not p.is_absolute():
|
||||
raise ValueError(f"Firejail whitelist paths must be absolute: {p}")
|
||||
|
||||
rw_flags = sum([["--whitelist", str(p)] for p in writable_paths], [])
|
||||
return [
|
||||
"firejail",
|
||||
"--quiet", # suppress banner
|
||||
"--seccomp", # enable seccomp filter
|
||||
"--nosound",
|
||||
"--net=none",
|
||||
*rw_flags,
|
||||
"--",
|
||||
*cmd,
|
||||
]
|
||||
|
||||
|
||||
def wrap_command(
|
||||
cmd: Sequence[str],
|
||||
backend: SandboxType,
|
||||
writable_paths: List[Path] | None = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Produce a sandbox-wrapped command according to *backend*.
|
||||
|
||||
This is the single public helper used by the executors: it hides the
|
||||
platform-specific details of **seatbelt** and **Firejail** while providing
|
||||
a graceful fallback to “no sandbox”.
|
||||
|
||||
Args:
|
||||
cmd:
|
||||
The raw command (program + args) to execute.
|
||||
backend:
|
||||
One of ``'seatbelt'``, ``'firejail'``, ``'none'`` or ``'auto'``.
|
||||
When ``'auto'`` is supplied the caller should already have resolved the
|
||||
platform with :func:`detect_backend`; the value is treated as ``'none'``.
|
||||
writable_paths:
|
||||
Extra directories that must remain writable inside the sandbox.
|
||||
Ignored when *backend* is ``'none'`` / ``'auto'``.
|
||||
|
||||
Returns:
|
||||
list[str]
|
||||
The command list ready for ``asyncio.create_subprocess_exec``.
|
||||
If sandboxing is disabled, this is simply ``list(cmd)``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
If an unrecognised *backend* value is given.
|
||||
"""
|
||||
if backend in ("none", "auto"):
|
||||
return list(cmd)
|
||||
|
||||
writable_paths = writable_paths or []
|
||||
|
||||
if backend == "seatbelt":
|
||||
return _seatbelt_cmd(cmd, writable_paths)
|
||||
|
||||
if backend == "firejail":
|
||||
return _firejail_cmd(cmd, writable_paths)
|
||||
|
||||
raise ValueError(f"Unknown sandbox backend: {backend!r}")
|
|
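A hedged sketch of how these helpers compose with the executors: resolve a backend once, then wrap each raw command. The expected output shown is for the Firejail path on Linux; on macOS the wrapper would be sandbox-exec instead:

from pathlib import Path

from dapr_agents.executors.sandbox import detect_backend, wrap_command

backend = detect_backend()  # 'seatbelt' on macOS, 'firejail' on Linux, else 'none'
cmd = wrap_command(
    ["python3", "-c", "print('sandboxed')"],
    backend,
    writable_paths=[Path("/tmp/scratch")],  # must be absolute for Firejail
)
print(cmd)
# e.g. ['firejail', '--quiet', '--seccomp', '--nosound', '--net=none',
#       '--whitelist', '/tmp/scratch', '--', 'python3', '-c', "print('sandboxed')"]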
@ -0,0 +1,309 @@
|
|||
from __future__ import annotations
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set
|
||||
from functools import lru_cache
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PackageManagerType(str, Enum):
|
||||
"""Types of package managers that can be detected."""
|
||||
|
||||
PIP = "pip"
|
||||
POETRY = "poetry"
|
||||
PIPENV = "pipenv"
|
||||
CONDA = "conda"
|
||||
NPM = "npm"
|
||||
YARN = "yarn"
|
||||
PNPM = "pnpm"
|
||||
BUN = "bun"
|
||||
CARGO = "cargo"
|
||||
GO = "go"
|
||||
MAVEN = "maven"
|
||||
GRADLE = "gradle"
|
||||
COMPOSER = "composer"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ProjectType(str, Enum):
|
||||
"""Types of projects that can be detected."""
|
||||
|
||||
PYTHON = "python"
|
||||
NODE = "node"
|
||||
RUST = "rust"
|
||||
GO = "go"
|
||||
JAVA = "java"
|
||||
PHP = "php"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class PackageManager:
|
||||
"""Information about a package manager and its commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: PackageManagerType,
|
||||
project_type: ProjectType,
|
||||
install_cmd: str,
|
||||
add_cmd: Optional[str] = None,
|
||||
remove_cmd: Optional[str] = None,
|
||||
update_cmd: Optional[str] = None,
|
||||
markers: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a package manager.
|
||||
|
||||
Args:
|
||||
name: Package manager identifier.
|
||||
project_type: Type of project this manager serves.
|
||||
install_cmd: Command to install project dependencies.
|
||||
add_cmd: Command to add a single package.
|
||||
remove_cmd: Command to remove a package.
|
||||
update_cmd: Command to update packages.
|
||||
markers: Filenames indicating this manager in a project.
|
||||
"""
|
||||
self.name = name
|
||||
self.project_type = project_type
|
||||
self.install_cmd = install_cmd
|
||||
self.add_cmd = add_cmd or f"{name.value} install"
|
||||
self.remove_cmd = remove_cmd or f"{name.value} remove"
|
||||
self.update_cmd = update_cmd or f"{name.value} update"
|
||||
self.markers = markers or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name.value
|
||||
|
||||
|
||||
# Known package managers
|
||||
PACKAGE_MANAGERS: dict[PackageManagerType, PackageManager] = {
|
||||
# Python package managers
|
||||
PackageManagerType.PIP: PackageManager(
|
||||
name=PackageManagerType.PIP,
|
||||
project_type=ProjectType.PYTHON,
|
||||
install_cmd="pip install -r requirements.txt",
|
||||
add_cmd="pip install",
|
||||
remove_cmd="pip uninstall",
|
||||
update_cmd="pip install --upgrade",
|
||||
markers=["requirements.txt", "setup.py", "setup.cfg"],
|
||||
),
|
||||
PackageManagerType.POETRY: PackageManager(
|
||||
name=PackageManagerType.POETRY,
|
||||
project_type=ProjectType.PYTHON,
|
||||
install_cmd="poetry install",
|
||||
add_cmd="poetry add",
|
||||
remove_cmd="poetry remove",
|
||||
update_cmd="poetry update",
|
||||
markers=["pyproject.toml", "poetry.lock"],
|
||||
),
|
||||
PackageManagerType.PIPENV: PackageManager(
|
||||
name=PackageManagerType.PIPENV,
|
||||
project_type=ProjectType.PYTHON,
|
||||
install_cmd="pipenv install",
|
||||
add_cmd="pipenv install",
|
||||
remove_cmd="pipenv uninstall",
|
||||
update_cmd="pipenv update",
|
||||
markers=["Pipfile", "Pipfile.lock"],
|
||||
),
|
||||
PackageManagerType.CONDA: PackageManager(
|
||||
name=PackageManagerType.CONDA,
|
||||
project_type=ProjectType.PYTHON,
|
||||
install_cmd="conda env update -f environment.yml",
|
||||
add_cmd="conda install",
|
||||
remove_cmd="conda remove",
|
||||
update_cmd="conda update",
|
||||
markers=["environment.yml", "environment.yaml"],
|
||||
),
|
||||
# JavaScript package managers
|
||||
PackageManagerType.NPM: PackageManager(
|
||||
name=PackageManagerType.NPM,
|
||||
project_type=ProjectType.NODE,
|
||||
install_cmd="npm install",
|
||||
add_cmd="npm install",
|
||||
remove_cmd="npm uninstall",
|
||||
update_cmd="npm update",
|
||||
markers=["package.json", "package-lock.json"],
|
||||
),
|
||||
PackageManagerType.YARN: PackageManager(
|
||||
name=PackageManagerType.YARN,
|
||||
project_type=ProjectType.NODE,
|
||||
install_cmd="yarn install",
|
||||
add_cmd="yarn add",
|
||||
remove_cmd="yarn remove",
|
||||
update_cmd="yarn upgrade",
|
||||
markers=["package.json", "yarn.lock"],
|
||||
),
|
||||
PackageManagerType.PNPM: PackageManager(
|
||||
name=PackageManagerType.PNPM,
|
||||
project_type=ProjectType.NODE,
|
||||
install_cmd="pnpm install",
|
||||
add_cmd="pnpm add",
|
||||
remove_cmd="pnpm remove",
|
||||
update_cmd="pnpm update",
|
||||
markers=["package.json", "pnpm-lock.yaml"],
|
||||
),
|
||||
PackageManagerType.BUN: PackageManager(
|
||||
name=PackageManagerType.BUN,
|
||||
project_type=ProjectType.NODE,
|
||||
install_cmd="bun install",
|
||||
add_cmd="bun add",
|
||||
remove_cmd="bun remove",
|
||||
update_cmd="bun update",
|
||||
markers=["package.json", "bun.lockb"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def is_installed(name: str) -> bool:
|
||||
"""
|
||||
Check if a given command exists on PATH.
|
||||
|
||||
Args:
|
||||
name: Command name to check.
|
||||
|
||||
Returns:
|
||||
True if the command is available, False otherwise.
|
||||
"""
|
||||
return shutil.which(name) is not None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def detect_package_managers(directory: str) -> List[PackageManager]:
|
||||
"""
|
||||
Detect all installed package managers by looking for marker files.
|
||||
|
||||
Args:
|
||||
directory: Path to the project root.
|
||||
|
||||
Returns:
|
||||
A list of PackageManager instances found in the directory.
|
||||
"""
|
||||
dir_path = Path(directory)
|
||||
if not dir_path.is_dir():
|
||||
return []
|
||||
|
||||
found: List[PackageManager] = []
|
||||
for pm in PACKAGE_MANAGERS.values():
|
||||
for marker in pm.markers:
|
||||
if (dir_path / marker).exists() and is_installed(pm.name.value):
|
||||
found.append(pm)
|
||||
break
|
||||
return found
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_primary_package_manager(directory: str) -> Optional[PackageManager]:
|
||||
"""
|
||||
Determine the primary package manager using lockfile heuristics.
|
||||
|
||||
Args:
|
||||
directory: Path to the project root.
|
||||
|
||||
Returns:
|
||||
The chosen PackageManager or None if none detected.
|
||||
"""
|
||||
managers = detect_package_managers(directory)
|
||||
if not managers:
|
||||
return None
|
||||
if len(managers) == 1:
|
||||
return managers[0]
|
||||
|
||||
dir_path = Path(directory)
|
||||
# Prefer lockfiles over others
|
||||
lock_priority = [
|
||||
(PackageManagerType.POETRY, "poetry.lock"),
|
||||
(PackageManagerType.PIPENV, "Pipfile.lock"),
|
||||
(PackageManagerType.PNPM, "pnpm-lock.yaml"),
|
||||
(PackageManagerType.YARN, "yarn.lock"),
|
||||
(PackageManagerType.BUN, "bun.lockb"),
|
||||
(PackageManagerType.NPM, "package-lock.json"),
|
||||
]
|
||||
for pm_type, lock in lock_priority:
|
||||
if (dir_path / lock).exists() and PACKAGE_MANAGERS.get(pm_type) in managers:
|
||||
return PACKAGE_MANAGERS[pm_type]
|
||||
|
||||
return managers[0]
|
||||
|
||||
|
||||
def get_install_command(directory: str) -> Optional[str]:
|
||||
"""
|
||||
Get the shell command to install project dependencies.
|
||||
|
||||
Args:
|
||||
directory: Path to the project root.
|
||||
|
||||
Returns:
|
||||
A shell command string or None if no manager detected.
|
||||
"""
|
||||
pm = get_primary_package_manager(directory)
|
||||
return pm.install_cmd if pm else None
|
||||
|
||||
|
||||
def get_add_command(directory: str, package: str, dev: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Get the shell command to add a package to the project.
|
||||
|
||||
Args:
|
||||
directory: Path to the project root.
|
||||
package: Package name to add.
|
||||
dev: Whether to add as a development dependency.
|
||||
|
||||
Returns:
|
||||
A shell command string or None if no manager detected.
|
||||
"""
|
||||
pm = get_primary_package_manager(directory)
|
||||
if not pm:
|
||||
return None
|
||||
base = pm.add_cmd
|
||||
if dev and pm.name in {
|
||||
PackageManagerType.PIP,
|
||||
PackageManagerType.POETRY,
|
||||
PackageManagerType.NPM,
|
||||
PackageManagerType.YARN,
|
||||
PackageManagerType.PNPM,
|
||||
PackageManagerType.BUN,
|
||||
PackageManagerType.COMPOSER,
|
||||
}:
|
||||
flag = (
|
||||
"--dev"
|
||||
if pm.name in {PackageManagerType.PIP, PackageManagerType.POETRY}
|
||||
else "--save-dev"
|
||||
)
|
||||
return f"{base} {package} {flag}"
|
||||
return f"{base} {package}"
|
||||
|
||||
|
||||
def get_project_type(directory: str) -> ProjectType:
|
||||
"""
|
||||
Infer project type from the primary package manager or file extensions.
|
||||
|
||||
Args:
|
||||
directory: Path to the project root.
|
||||
|
||||
Returns:
|
||||
The detected ProjectType.
|
||||
"""
|
||||
pm = get_primary_package_manager(directory)
|
||||
if pm:
|
||||
return pm.project_type
|
||||
|
||||
# Fallback by extension scanning
|
||||
exts: Set[str] = set()
|
||||
for path in Path(directory).rglob("*"):
|
||||
if path.is_file():
|
||||
exts.add(path.suffix.lower())
|
||||
if len(exts) > 50:
|
||||
break
|
||||
if ".py" in exts:
|
||||
return ProjectType.PYTHON
|
||||
if {".js", ".ts"} & exts:
|
||||
return ProjectType.NODE
|
||||
if ".rs" in exts:
|
||||
return ProjectType.RUST
|
||||
if ".go" in exts:
|
||||
return ProjectType.GO
|
||||
if ".java" in exts:
|
||||
return ProjectType.JAVA
|
||||
if ".php" in exts:
|
||||
return ProjectType.PHP
|
||||
return ProjectType.UNKNOWN
|
|
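A hedged end-to-end sketch of the detection flow above. The directory is illustrative, and results depend on which marker files and binaries exist on the host:

from dapr_agents.executors.utils.package_manager import (
    get_add_command,
    get_install_command,
    get_primary_package_manager,
    get_project_type,
)

root = "."  # project to inspect
pm = get_primary_package_manager(root)
if pm:
    print(f"manager: {pm}, project type: {get_project_type(root).value}")
    print("install deps:", get_install_command(root))  # e.g. 'poetry install'
    print("add package:", get_add_command(root, "httpx", dev=True))
else:
    print("no package manager detected")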
@ -2,10 +2,12 @@ from pydantic import BaseModel, PrivateAttr
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LLMClientBase(BaseModel, ABC):
|
||||
"""
|
||||
Abstract base class for LLM models.
|
||||
"""
|
||||
|
||||
# Private attributes for provider and api
|
||||
_provider: str = PrivateAttr()
|
||||
_api: str = PrivateAttr()
|
||||
|
|
|
@ -5,17 +5,27 @@ from pydantic import BaseModel, Field
|
|||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ChatClientBase(BaseModel, ABC):
|
||||
"""
|
||||
Base class for chat-specific functionality.
|
||||
Handles Prompty integration and provides abstract methods for chat client configuration.
|
||||
"""
|
||||
prompty: Optional[Prompty] = Field(default=None, description="Instance of the Prompty object (optional).")
|
||||
prompt_template: Optional[PromptTemplateBase] = Field(default=None, description="Prompt template for rendering (optional).")
|
||||
|
||||
prompty: Optional[Prompty] = Field(
|
||||
default=None, description="Instance of the Prompty object (optional)."
|
||||
)
|
||||
prompt_template: Optional[PromptTemplateBase] = Field(
|
||||
default=None, description="Prompt template for rendering (optional)."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'ChatClientBase':
|
||||
def from_prompty(
|
||||
cls,
|
||||
prompty_source: Union[str, Path],
|
||||
timeout: Union[int, float, Dict[str, Any]] = 1500,
|
||||
) -> "ChatClientBase":
|
||||
"""
|
||||
Abstract method to load a Prompty source and configure the chat client.
|
||||
|
||||
|
@ -31,13 +41,15 @@ class ChatClientBase(BaseModel, ABC):
|
|||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
messages: Union[str, Dict[str, Any], BaseModel, Iterable[Union[Dict[str, Any], BaseModel]]] = None,
|
||||
messages: Union[
|
||||
str, Dict[str, Any], BaseModel, Iterable[Union[Dict[str, Any], BaseModel]]
|
||||
] = None,
|
||||
input_data: Optional[Dict[str, Any]] = None,
|
||||
model: Optional[str] = None,
|
||||
tools: Optional[List[Union[Dict[str, Any]]]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
structured_mode: Optional[str] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
|
||||
"""
|
||||
Abstract method to generate chat completions.
|
||||
|
|
|
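A hedged sketch of what the from_prompty contract above promises callers, shown with the concrete client from the next diff. The import path and the .prompty file are assumptions, and generate() additionally requires a running Dapr sidecar plus the DAPR_LLM_COMPONENT_DEFAULT environment variable:

from pathlib import Path

from dapr_agents.llm.dapr import DaprChatClient  # concrete subclass; module path assumed

# One call loads the Prompty source and configures the client's prompt template.
client = DaprChatClient.from_prompty(Path("prompts/assistant.prompty"), timeout=1500)
response = client.generate(input_data={"question": "What is Dapr?"})
print(response)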
@ -5,7 +5,18 @@ from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from dapr.clients.grpc._request import ConversationInput
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type, Literal, ClassVar
from typing import (
    Union,
    Optional,
    Iterable,
    Dict,
    Any,
    List,
    Iterator,
    Type,
    Literal,
    ClassVar,
)
from pydantic import BaseModel
from pathlib import Path
import logging

@ -14,6 +25,7 @@ import time

logger = logging.getLogger(__name__)


class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
    """
    Concrete class for Dapr's chat completion API using the Inference API.

@ -28,12 +40,16 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
        """
        # Set the private provider and api attributes
        self._api = "chat"
        self._llm_component = os.environ['DAPR_LLM_COMPONENT_DEFAULT']
        self._llm_component = os.environ["DAPR_LLM_COMPONENT_DEFAULT"]

        return super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'DaprChatClient':
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "DaprChatClient":
        """
        Initializes a DaprChatClient using a Prompty source, which can be a file path or inline content.

@ -52,11 +68,13 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
        prompt_template = Prompty.to_prompt_template(prompty_instance)

        # Initialize the DaprChatClient based on the Prompty model configuration
        return cls.model_validate({
            'timeout': timeout,
            'prompty': prompty_instance,
            'prompt_template': prompt_template,
        })
        return cls.model_validate(
            {
                "timeout": timeout,
                "prompty": prompty_instance,
                "prompt_template": prompt_template,
            }
        )

    def translate_response(self, response: dict, model: str) -> dict:
        """Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""

@ -64,11 +82,8 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
            {
                "finish_reason": "stop",
                "index": i,
                "message": {
                    "content": output["result"],
                    "role": "assistant"
                },
                "logprobs": None
                "message": {"content": output["result"], "role": "assistant"},
                "logprobs": None,
            }
            for i, output in enumerate(response.get("outputs", []))
        ]

@ -78,22 +93,29 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
            "created": int(time.time()),
            "model": model,
            "object": "chat.completion",
            "usage": {"total_tokens": "-1"}
            "usage": {"total_tokens": "-1"},
        }

    def convert_to_conversation_inputs(self, inputs: List[Dict[str, Any]]) -> List[ConversationInput]:
    def convert_to_conversation_inputs(
        self, inputs: List[Dict[str, Any]]
    ) -> List[ConversationInput]:
        return [
            ConversationInput(
                content=item["content"],
                role=item.get("role"),
                scrub_pii=item.get("scrubPII") == "true"
                scrub_pii=item.get("scrubPII") == "true",
            )
            for item in inputs
        ]

    def generate(
        self,
        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
        messages: Union[
            str,
            Dict[str, Any],
            BaseMessage,
            Iterable[Union[Dict[str, Any], BaseMessage]],
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        llm_component: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,

@ -101,7 +123,7 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
        structured_mode: Literal["function_call"] = "function_call",
        scrubPII: Optional[bool] = False,
        temperature: Optional[float] = None,
        **kwargs
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

@ -120,12 +142,16 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
            Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
        """
        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}.")
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
            )

        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

@ -135,7 +161,7 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {'inputs': RequestHandler.normalize_chat_messages(messages)}
        params = {"inputs": RequestHandler.normalize_chat_messages(messages)}
        # Merge Prompty parameters if available, then override with any explicit kwargs
        if self.prompty:
            params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}

@ -148,13 +174,18 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
            llm_provider=self.provider,
            tools=tools,
            response_format=response_format,
            structured_mode=structured_mode
            structured_mode=structured_mode,
        )
        inputs = self.convert_to_conversation_inputs(params['inputs'])
        inputs = self.convert_to_conversation_inputs(params["inputs"])

        try:
            logger.info("Invoking the Dapr Conversation API.")
            response = self.client.chat_completion(llm=llm_component or self._llm_component, conversation_inputs=inputs, scrub_pii=scrubPII, temperature=temperature)
            response = self.client.chat_completion(
                llm=llm_component or self._llm_component,
                conversation_inputs=inputs,
                scrub_pii=scrubPII,
                temperature=temperature,
            )
            transposed_response = self.translate_response(response, self._llm_component)
            logger.info("Chat completion retrieved successfully.")

@ -163,8 +194,10 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
                llm_provider=self.provider,
                response_format=response_format,
                structured_mode=structured_mode,
                stream=params.get('stream', False)
                stream=params.get("stream", False),
            )
        except Exception as e:
            logger.error(f"An error occurred during the Dapr Conversation API call: {e}")
            logger.error(
                f"An error occurred during the Dapr Conversation API call: {e}"
            )
            raise
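A minimal usage sketch for the DaprChatClient above (illustrative only: it assumes a running Dapr sidecar, a conversation component named "echo", and that the import path matches this repository's layout):

    import os

    os.environ["DAPR_LLM_COMPONENT_DEFAULT"] = "echo"  # read in model_post_init

    from dapr_agents.llm.dapr.chat import DaprChatClient  # assumed module path

    client = DaprChatClient()
    result = client.generate(messages="What is Dapr?")
    # The exact return shape depends on ResponseHandler.process_response.
    print(result)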
@ -10,6 +10,7 @@ import logging

logger = logging.getLogger(__name__)


class DaprInferenceClient:
    def __init__(self):
        self.dapr_client = DaprClient()

@ -26,8 +27,19 @@ class DaprInferenceClient:

        return response_dict

    def chat_completion(self, llm: str, conversation_inputs: List[ConversationInput], scrub_pii: bool | None = None, temperature: float | None = None) -> Any:
        response = self.dapr_client.converse_alpha1(name=llm, inputs=conversation_inputs, scrub_pii=scrub_pii, temperature=temperature)
    def chat_completion(
        self,
        llm: str,
        conversation_inputs: List[ConversationInput],
        scrub_pii: bool | None = None,
        temperature: float | None = None,
    ) -> Any:
        response = self.dapr_client.converse_alpha1(
            name=llm,
            inputs=conversation_inputs,
            scrub_pii=scrub_pii,
            temperature=temperature,
        )
        output = self.translate_to_json(response)

        return output

@ -38,6 +50,7 @@ class DaprInferenceClientBase(LLMClientBase):
    Base class for managing Dapr Inference API clients.
    Handles client initialization, configuration, and shared logic.
    """

    @model_validator(mode="before")
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        return values

@ -66,7 +79,9 @@ class DaprInferenceClientBase(LLMClientBase):
        return DaprInferenceClient()

    @classmethod
    def from_config(cls, client_options: DaprInferenceClientConfig, timeout: float = 1500):
    def from_config(
        cls, client_options: DaprInferenceClientConfig, timeout: float = 1500
    ):
        """
        Initializes the DaprInferenceClientBase using DaprInferenceClientConfig.
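The low-level client can also be exercised directly; a hedged sketch (the component name "openai" is a placeholder):

    from dapr.clients.grpc._request import ConversationInput

    client = DaprInferenceClient()
    inputs = [ConversationInput(content="Hello!", role="user", scrub_pii=False)]
    output = client.chat_completion(llm="openai", conversation_inputs=inputs)
    # output is the converse_alpha1 response translated to a JSON-style dict.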
@ -7,14 +7,21 @@ import logging

logger = logging.getLogger(__name__)


class ElevenLabsClientBase(LLMClientBase):
    """
    Base class for managing ElevenLabs LLM clients.
    Handles client initialization, configuration, and shared logic specific to the ElevenLabs API.
    """

    api_key: Optional[str] = Field(default=None, description="API key for authenticating with the ElevenLabs API. Defaults to environment variables 'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY'.")
    base_url: Optional[str] = Field(default="https://api.elevenlabs.io", description="Base URL for the ElevenLabs API endpoints.")
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authenticating with the ElevenLabs API. Defaults to environment variables 'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY'.",
    )
    base_url: Optional[str] = Field(
        default="https://api.elevenlabs.io",
        description="Base URL for the ElevenLabs API endpoints.",
    )

    def model_post_init(self, __context: Any) -> None:
        """

@ -26,10 +33,14 @@ class ElevenLabsClientBase(LLMClientBase):

        # Use environment variable if `api_key` is not explicitly provided
        if self.api_key is None:
            self.api_key = os.getenv("ELEVENLABS_API_KEY") or os.getenv("ELEVEN_API_KEY")
            self.api_key = os.getenv("ELEVENLABS_API_KEY") or os.getenv(
                "ELEVEN_API_KEY"
            )

        if self.api_key is None:
            raise ValueError("API key is required. Set it explicitly or in the 'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY' environment variable.")
            raise ValueError(
                "API key is required. Set it explicitly or in the 'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY' environment variable."
            )

        # Initialize configuration and client
        self._config = self.get_config()

@ -42,10 +53,7 @@ class ElevenLabsClientBase(LLMClientBase):
        """
        Returns the configuration object for the ElevenLabs API client.
        """
        return ElevenLabsClientConfig(
            api_key=self.api_key,
            base_url=self.base_url
        )
        return ElevenLabsClientConfig(api_key=self.api_key, base_url=self.base_url)

    def get_client(self) -> Any:
        """

@ -63,10 +71,7 @@ class ElevenLabsClientBase(LLMClientBase):
        config = self.config

        logger.info("Initializing ElevenLabs API client...")
        return ElevenLabs(
            api_key=config.api_key,
            base_url=config.base_url
        )
        return ElevenLabs(api_key=config.api_key, base_url=config.base_url)

    @property
    def config(self) -> ElevenLabsClientConfig:
@ -5,17 +5,32 @@ import logging

logger = logging.getLogger(__name__)


class ElevenLabsSpeechClient(ElevenLabsClientBase):
    """
    Client for ElevenLabs speech generation functionality.
    Handles text-to-speech conversions with customizable options.
    """

    voice: Optional[Any] = Field(default=None, description="Default voice (ID, name, or object) for speech generation.")
    model: Optional[str] = Field(default="eleven_multilingual_v2", description="Default model for speech generation.")
    output_format: Optional[str] = Field(default="mp3_44100_128", description="Default audio output format.")
    optimize_streaming_latency: Optional[int] = Field(default=0, description="Default latency optimization level (0 means no optimizations).")
    voice_settings: Optional[Any] = Field(default=None, description="Default voice settings (stability, similarity boost, etc.).")
    voice: Optional[Any] = Field(
        default=None,
        description="Default voice (ID, name, or object) for speech generation.",
    )
    model: Optional[str] = Field(
        default="eleven_multilingual_v2",
        description="Default model for speech generation.",
    )
    output_format: Optional[str] = Field(
        default="mp3_44100_128", description="Default audio output format."
    )
    optimize_streaming_latency: Optional[int] = Field(
        default=0,
        description="Default latency optimization level (0 means no optimizations).",
    )
    voice_settings: Optional[Any] = Field(
        default=None,
        description="Default voice settings (stability, similarity boost, etc.).",
    )

    def model_post_init(self, __context: Any) -> None:
        """

@ -71,7 +86,9 @@ class ElevenLabsSpeechClient(ElevenLabsClientBase):
        voice = voice or self.voice
        model = model or self.model
        output_format = output_format or self.output_format
        optimize_streaming_latency = optimize_streaming_latency or self.optimize_streaming_latency
        optimize_streaming_latency = (
            optimize_streaming_latency or self.optimize_streaming_latency
        )
        voice_settings = voice_settings or self.voice_settings

        logger.info(f"Generating speech with voice '{voice}', model '{model}'.")
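A short construction sketch for the speech client above (illustrative; only the field defaults are taken from the diff, and the speech-generation call itself is omitted because its method name is not visible in these hunks):

    import os

    os.environ.setdefault("ELEVENLABS_API_KEY", "<your-key>")
    client = ElevenLabsSpeechClient()
    # Defaults per the diff: model "eleven_multilingual_v2",
    # output_format "mp3_44100_128", optimize_streaming_latency 0.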
@ -4,13 +4,25 @@ from dapr_agents.prompt.prompty import Prompty
from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type, Literal, ClassVar
from typing import (
    Union,
    Optional,
    Iterable,
    Dict,
    Any,
    List,
    Iterator,
    Type,
    Literal,
    ClassVar,
)
from pydantic import BaseModel
from pathlib import Path
import logging

logger = logging.getLogger(__name__)


class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
    """
    Concrete class for the Hugging Face Hub's chat completion API using the Inference API.

@ -28,7 +40,11 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
        return super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'HFHubChatClient':
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "HFHubChatClient":
        """
        Initializes an HFHubChatClient client using a Prompty source, which can be a file path or inline content.

@ -50,27 +66,34 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
        model_config = prompty_instance.model

        # Initialize the HFHubChatClient based on the Prompty model configuration
        return cls.model_validate({
            'model': model_config.configuration.name,
            'api_key': model_config.configuration.api_key,
            'base_url': model_config.configuration.base_url,
            'headers': model_config.configuration.headers,
            'cookies': model_config.configuration.cookies,
            'proxies': model_config.configuration.proxies,
            'timeout': timeout,
            'prompty': prompty_instance,
            'prompt_template': prompt_template,
        })
        return cls.model_validate(
            {
                "model": model_config.configuration.name,
                "api_key": model_config.configuration.api_key,
                "base_url": model_config.configuration.base_url,
                "headers": model_config.configuration.headers,
                "cookies": model_config.configuration.cookies,
                "proxies": model_config.configuration.proxies,
                "timeout": timeout,
                "prompty": prompty_instance,
                "prompt_template": prompt_template,
            }
        )

    def generate(
        self,
        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
        messages: Union[
            str,
            Dict[str, Any],
            BaseMessage,
            Iterable[Union[Dict[str, Any], BaseMessage]],
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        structured_mode: Literal["function_call"] = "function_call",
        **kwargs
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

@ -89,12 +112,16 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
        """

        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}.")
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
            )

        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

@ -104,7 +131,7 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {'messages': RequestHandler.normalize_chat_messages(messages)}
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}

        # Merge Prompty parameters if available, then override with any explicit kwargs
        if self.prompty:

@ -113,7 +140,7 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
        params.update(kwargs)

        # If a model is provided, override the default model
        params['model'] = model or self.model
        params["model"] = model or self.model

        # Prepare request parameters
        params = RequestHandler.process_params(

@ -121,7 +148,7 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
            llm_provider=self.provider,
            tools=tools,
            response_format=response_format,
            structured_mode=structured_mode
            structured_mode=structured_mode,
        )

        try:

@ -134,7 +161,7 @@ class HFHubChatClient(HFHubInferenceClientBase, ChatClientBase):
                llm_provider=self.provider,
                response_format=response_format,
                structured_mode=structured_mode,
                stream=params.get('stream', False)
                stream=params.get("stream", False),
            )
        except Exception as e:
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
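A minimal usage sketch for the Hugging Face chat client above (illustrative; the model ID is an example and HUGGINGFACE_API_KEY is assumed to be set):

    client = HFHubChatClient(model="HuggingFaceH4/zephyr-7b-beta")
    response = client.generate(messages="Say hello in one sentence.")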
@ -8,19 +8,43 @@ import logging

logger = logging.getLogger(__name__)


class HFHubInferenceClientBase(LLMClientBase):
    """
    Base class for managing Hugging Face Inference API clients.
    Handles client initialization, configuration, and shared logic.
    """
    model: Optional[str] = Field(default=None, description="Model ID or URL for the Hugging Face API. Cannot be used with `base_url`. If set, the client will infer a model-specific endpoint.")
    token: Optional[Union[str, bool]] = Field(default=None, description="Hugging Face token. Defaults to the locally saved token if not provided. Pass `False` to disable authentication.")
    api_key: Optional[Union[str, bool]] = Field(default=None, description="Alias for `token` for compatibility with OpenAI's client. Cannot be used if `token` is set.")
    base_url: Optional[str] = Field(default=None, description="Base URL to run inference. Alias for `model`. Cannot be used if `model` is set.")
    headers: Optional[Dict[str, str]] = Field(default=None, description="Additional headers to send to the server. Overrides the default authorization and user-agent headers.")
    cookies: Optional[Dict[str, str]] = Field(default=None, description="Additional cookies to send to the server.")
    proxies: Optional[Any] = Field(default=None, description="Proxies to use for the request.")
    timeout: Optional[float] = Field(default=None, description="The maximum number of seconds to wait for a response from the server. Loading a new model in Inference API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.")

    model: Optional[str] = Field(
        default=None,
        description="Model ID or URL for the Hugging Face API. Cannot be used with `base_url`. If set, the client will infer a model-specific endpoint.",
    )
    token: Optional[Union[str, bool]] = Field(
        default=None,
        description="Hugging Face token. Defaults to the locally saved token if not provided. Pass `False` to disable authentication.",
    )
    api_key: Optional[Union[str, bool]] = Field(
        default=None,
        description="Alias for `token` for compatibility with OpenAI's client. Cannot be used if `token` is set.",
    )
    base_url: Optional[str] = Field(
        default=None,
        description="Base URL to run inference. Alias for `model`. Cannot be used if `model` is set.",
    )
    headers: Optional[Dict[str, str]] = Field(
        default=None,
        description="Additional headers to send to the server. Overrides the default authorization and user-agent headers.",
    )
    cookies: Optional[Dict[str, str]] = Field(
        default=None, description="Additional cookies to send to the server."
    )
    proxies: Optional[Any] = Field(
        default=None, description="Proxies to use for the request."
    )
    timeout: Optional[float] = Field(
        default=None,
        description="The maximum number of seconds to wait for a response from the server. Loading a new model in Inference API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.",
    )

    @model_validator(mode="before")
    def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:

@ -29,28 +53,32 @@ class HFHubInferenceClientBase(LLMClientBase):
        - Normalizes 'token' and 'api_key' to a single field.
        - Validates exclusivity of 'model' and 'base_url'.
        """
        token = values.get('token')
        api_key = values.get('api_key')
        model = values.get('model')
        base_url = values.get('base_url')
        token = values.get("token")
        api_key = values.get("api_key")
        model = values.get("model")
        base_url = values.get("base_url")

        # Ensure mutual exclusivity of `token` and `api_key`
        if token is not None and api_key is not None:
            raise ValueError("Provide only one of 'api_key' or 'token'. They are aliases and cannot coexist.")
            raise ValueError(
                "Provide only one of 'api_key' or 'token'. They are aliases and cannot coexist."
            )

        # Normalize `token` to `api_key`
        if token is not None:
            values['api_key'] = token
            values.pop('token', None)  # Remove `token` for consistency
            values["api_key"] = token
            values.pop("token", None)  # Remove `token` for consistency

        # Use environment variable if `api_key` is not explicitly provided
        if api_key is None:
            api_key = os.environ.get("HUGGINGFACE_API_KEY")

        if api_key is None:
            raise ValueError("API key is required. Set it explicitly or in the 'HUGGINGFACE_API_KEY' environment variable.")
            raise ValueError(
                "API key is required. Set it explicitly or in the 'HUGGINGFACE_API_KEY' environment variable."
            )

        values['api_key'] = api_key
        values["api_key"] = api_key

        # mutual-exclusivity
        if model is not None and base_url is not None:

@ -92,7 +120,7 @@ class HFHubInferenceClientBase(LLMClientBase):
            headers=self.headers,
            cookies=self.cookies,
            proxies=self.proxies,
            timeout=self.timeout
            timeout=self.timeout,
        )

    def get_client(self) -> InferenceClient:

@ -107,11 +135,13 @@ class HFHubInferenceClientBase(LLMClientBase):
            headers=config.headers,
            cookies=config.cookies,
            proxies=config.proxies,
            timeout=self.timeout
            timeout=self.timeout,
        )

    @classmethod
    def from_config(cls, client_options: HFInferenceClientConfig, timeout: float = 1500):
    def from_config(
        cls, client_options: HFInferenceClientConfig, timeout: float = 1500
    ):
        """
        Initializes the HFHubInferenceClientBase using HFInferenceClientConfig.
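The validator above normalizes token into api_key and enforces exclusivity; a sketch of the observable behavior (illustrative):

    HFHubChatClient(model="gpt2", token="hf_xxx")  # ok: token is renamed to api_key
    HFHubChatClient(model="gpt2", token="hf_xxx", api_key="hf_xxx")  # ValueError: aliases cannot coexist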
@ -4,7 +4,18 @@ from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type, Literal, ClassVar
from typing import (
    Union,
    Optional,
    Iterable,
    Dict,
    Any,
    List,
    Iterator,
    Type,
    Literal,
    ClassVar,
)
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel, Field
from pathlib import Path

@ -12,14 +23,23 @@ import logging

logger = logging.getLogger(__name__)


class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
    """
    Chat client for NVIDIA chat models.
    Combines NVIDIA client management with Prompty-specific functionality for handling chat completions.
    """

    model: str = Field(default='meta/llama3-8b-instruct', description="Model name to use. Defaults to 'meta/llama3-8b-instruct'.")
    max_tokens: Optional[int] = Field(default=1024,description=("The maximum number of tokens to generate in any given call. Must be an integer ≥ 1. Defaults to 1024."))
    model: str = Field(
        default="meta/llama3-8b-instruct",
        description="Model name to use. Defaults to 'meta/llama3-8b-instruct'.",
    )
    max_tokens: Optional[int] = Field(
        default=1024,
        description=(
            "The maximum number of tokens to generate in any given call. Must be an integer ≥ 1. Defaults to 1024."
        ),
    )

    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"function_call"}

@ -34,7 +54,7 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
        super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path]) -> 'NVIDIAChatClient':
    def from_prompty(cls, prompty_source: Union[str, Path]) -> "NVIDIAChatClient":
        """
        Initializes an NVIDIAChatClient client using a Prompty source, which can be a file path or inline content.

@ -55,24 +75,31 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
        model_config = prompty_instance.model

        # Initialize the NVIDIAChatClient instance using model_validate
        return cls.model_validate({
            'model': model_config.configuration.name,
            'api_key': model_config.configuration.api_key,
            'base_url': model_config.configuration.base_url,
            'prompty': prompty_instance,
            'prompt_template': prompt_template,
        })
        return cls.model_validate(
            {
                "model": model_config.configuration.name,
                "api_key": model_config.configuration.api_key,
                "base_url": model_config.configuration.base_url,
                "prompty": prompty_instance,
                "prompt_template": prompt_template,
            }
        )

    def generate(
        self,
        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
        messages: Union[
            str,
            Dict[str, Any],
            BaseMessage,
            Iterable[Union[Dict[str, Any], BaseMessage]],
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        max_tokens: Optional[int] = None,
        structured_mode: Literal["function_call"] = "function_call",
        **kwargs
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

@ -92,12 +119,16 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
        """

        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}.")
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
            )

        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

@ -107,7 +138,7 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {'messages': RequestHandler.normalize_chat_messages(messages)}
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}

        # Merge prompty parameters if available, then override with any explicit kwargs
        if self.prompty:

@ -116,10 +147,10 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
        params.update(kwargs)

        # If a model is provided, override the default model
        params['model'] = model or self.model
        params["model"] = model or self.model

        # Apply max_tokens if provided
        params['max_tokens'] = max_tokens or self.max_tokens
        params["max_tokens"] = max_tokens or self.max_tokens

        # Prepare request parameters
        params = RequestHandler.process_params(

@ -127,13 +158,15 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
            llm_provider=self.provider,
            tools=tools,
            response_format=response_format,
            structured_mode=structured_mode
            structured_mode=structured_mode,
        )

        try:
            logger.info("Invoking ChatCompletion API.")
            logger.debug(f"ChatCompletion API Parameters: {params}")
            response: ChatCompletionMessage = self.client.chat.completions.create(**params)
            response: ChatCompletionMessage = self.client.chat.completions.create(
                **params
            )
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(

@ -141,7 +174,7 @@ class NVIDIAChatClient(NVIDIAClientBase, ChatClientBase):
                llm_provider=self.provider,
                response_format=response_format,
                structured_mode=structured_mode,
                stream=params.get('stream', False)
                stream=params.get("stream", False),
            )
        except Exception as e:
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
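A short usage sketch for the NVIDIA chat client above (illustrative; assumes NVIDIA_API_KEY is set):

    client = NVIDIAChatClient()  # defaults to meta/llama3-8b-instruct, max_tokens=1024
    response = client.generate(messages="Summarize Dapr in one line.", max_tokens=128)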
@ -8,14 +8,21 @@ import logging

logger = logging.getLogger(__name__)


class NVIDIAClientBase(LLMClientBase):
    """
    Base class for managing NVIDIA LLM clients.
    Handles client initialization, configuration, and shared logic specific to NVIDIA's API.
    """

    api_key: Optional[str] = Field(default=None, description="API key for authenticating with the NVIDIA LLM API. If not provided, it will be sourced from the 'NVIDIA_API_KEY' environment variable.")
    base_url: Optional[str] = Field(default="https://integrate.api.nvidia.com/v1", description="Base URL for the NVIDIA LLM API endpoints.")
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authenticating with the NVIDIA LLM API. If not provided, it will be sourced from the 'NVIDIA_API_KEY' environment variable.",
    )
    base_url: Optional[str] = Field(
        default="https://integrate.api.nvidia.com/v1",
        description="Base URL for the NVIDIA LLM API endpoints.",
    )

    def model_post_init(self, __context: Any) -> None:
        """

@ -33,7 +40,9 @@ class NVIDIAClientBase(LLMClientBase):
            self.api_key = os.environ.get("NVIDIA_API_KEY")

        if self.api_key is None:
            raise ValueError("API key is required. Set it explicitly or in the 'NVIDIA_API_KEY' environment variable.")
            raise ValueError(
                "API key is required. Set it explicitly or in the 'NVIDIA_API_KEY' environment variable."
            )

        # Set up the private config and client attributes
        self._config: NVIDIAClientConfig = self.get_config()

@ -49,10 +58,7 @@ class NVIDIAClientBase(LLMClientBase):
        Returns:
            NVIDIAClientConfig: Configuration object containing API credentials and endpoint details.
        """
        return NVIDIAClientConfig(
            api_key=self.api_key,
            base_url=self.base_url
        )
        return NVIDIAClientConfig(api_key=self.api_key, base_url=self.base_url)

    def get_client(self) -> OpenAI:
        """

@ -66,10 +72,7 @@ class NVIDIAClientBase(LLMClientBase):
        config = self.config

        logger.info("Initializing NVIDIA API client...")
        return OpenAI(
            api_key=config.api_key,
            base_url=config.base_url
        )
        return OpenAI(api_key=config.api_key, base_url=config.base_url)

    @property
    def config(self) -> NVIDIAClientConfig:
@ -6,6 +6,7 @@ import logging

logger = logging.getLogger(__name__)


class NVIDIAEmbeddingClient(NVIDIAClientBase):
    """
    Client for handling NVIDIA's embedding functionalities.

@ -19,11 +20,24 @@ class NVIDIAEmbeddingClient(NVIDIAClientBase):
            'passage' for generating embeddings during indexing.
        truncate (Optional[Literal["NONE", "START", "END"]]): Specifies handling for inputs exceeding the model's max token length. Defaults to 'NONE'.
    """
    model: str = Field("nvidia/nv-embedqa-e5-v5", description="ID of the model to use for embedding.")
    encoding_format: Optional[Literal["float", "base64"]] = Field("float", description="Format for the embeddings. Defaults to 'float'.")
    dimensions: Optional[int] = Field(None, description="Number of dimensions for the output embeddings. Not supported by all models.")
    input_type: Optional[Literal["query", "passage"]] = Field("passage", description="Mode of operation: 'query' or 'passage'.")
    truncate: Optional[Literal["NONE", "START", "END"]] = Field("NONE", description="Handling for inputs exceeding max token length. Defaults to 'NONE'.")

    model: str = Field(
        "nvidia/nv-embedqa-e5-v5", description="ID of the model to use for embedding."
    )
    encoding_format: Optional[Literal["float", "base64"]] = Field(
        "float", description="Format for the embeddings. Defaults to 'float'."
    )
    dimensions: Optional[int] = Field(
        None,
        description="Number of dimensions for the output embeddings. Not supported by all models.",
    )
    input_type: Optional[Literal["query", "passage"]] = Field(
        "passage", description="Mode of operation: 'query' or 'passage'."
    )
    truncate: Optional[Literal["NONE", "START", "END"]] = Field(
        "NONE",
        description="Handling for inputs exceeding max token length. Defaults to 'NONE'.",
    )

    def model_post_init(self, __context: Any) -> None:
        """

@ -45,7 +59,7 @@ class NVIDIAEmbeddingClient(NVIDIAClientBase):
        truncate: Optional[Literal["NONE", "START", "END"]] = None,
        encoding_format: Optional[Literal["float", "base64"]] = None,
        dimensions: Optional[int] = None,
        extra_body: Optional[Dict[str, Any]] = None
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> CreateEmbeddingResponse:
        """
        Generate embeddings for the given input text(s).

@ -77,7 +91,7 @@ class NVIDIAEmbeddingClient(NVIDIAClientBase):
            "model": model,
            "input": input,
            "encoding_format": encoding_format or self.encoding_format,
            "extra_body": extra_body or {}
            "extra_body": extra_body or {},
        }

        # Add optional parameters if provided
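An illustrative call against the embedding client above (the method name is not visible in this hunk, so create_embedding is an assumption; NVIDIA_API_KEY must be set):

    client = NVIDIAEmbeddingClient()  # nvidia/nv-embedqa-e5-v5, input_type="passage"
    response = client.create_embedding(input=["Dapr is a distributed application runtime."])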
@ -1,14 +1,18 @@
from dapr_agents.llm.openai.client.base import OpenAIClientBase
from dapr_agents.llm.utils import RequestHandler
from dapr_agents.types.llm import (
    AudioSpeechRequest, AudioTranscriptionRequest,
    AudioTranslationRequest, AudioTranscriptionResponse, AudioTranslationResponse,
    AudioSpeechRequest,
    AudioTranscriptionRequest,
    AudioTranslationRequest,
    AudioTranscriptionResponse,
    AudioTranslationResponse,
)
from typing import Union, Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)


class OpenAIAudioClient(OpenAIClientBase):
    """
    Client for handling OpenAI's audio functionalities, including speech generation, transcription, and translation.

@ -22,7 +26,11 @@ class OpenAIAudioClient(OpenAIClientBase):
        self._api = "audio"
        super().model_post_init(__context)

    def create_speech(self, request: Union[AudioSpeechRequest, Dict[str, Any]], file_name: Optional[str] = None) -> Union[bytes, None]:
    def create_speech(
        self,
        request: Union[AudioSpeechRequest, Dict[str, Any]],
        file_name: Optional[str] = None,
    ) -> Union[bytes, None]:
        """
        Generate speech audio from text and optionally save it to a file.

@ -34,7 +42,9 @@ class OpenAIAudioClient(OpenAIClientBase):
            Union[bytes, None]: The generated audio content as bytes if no file_name is provided, otherwise None.
        """
        # Transform dictionary to Pydantic object if needed
        validated_request: AudioSpeechRequest = RequestHandler.validate_request(request, AudioSpeechRequest)
        validated_request: AudioSpeechRequest = RequestHandler.validate_request(
            request, AudioSpeechRequest
        )

        logger.info(f"Using model '{validated_request.model}' for speech generation.")

@ -43,7 +53,9 @@ class OpenAIAudioClient(OpenAIClientBase):
        max_chunk_size = 4096

        if len(input_text) > max_chunk_size:
            logger.info(f"Input exceeds {max_chunk_size} characters. Splitting into smaller chunks.")
            logger.info(
                f"Input exceeds {max_chunk_size} characters. Splitting into smaller chunks."
            )

            # Split input text into manageable chunks
            def split_text(text, max_size):

@ -62,7 +74,9 @@ class OpenAIAudioClient(OpenAIClientBase):
        try:
            for chunk in text_chunks:
                validated_request.input = chunk
                with self.client.with_streaming_response.audio.speech.create(**validated_request.model_dump()) as response:
                with self.client.with_streaming_response.audio.speech.create(
                    **validated_request.model_dump()
                ) as response:
                    if file_name:
                        # Write each chunk incrementally to the file
                        logger.info(f"Saving audio chunk to file: {file_name}")

@ -83,7 +97,9 @@ class OpenAIAudioClient(OpenAIClientBase):
            logger.error(f"Failed to create or save speech: {e}")
            raise ValueError(f"An error occurred during speech generation: {e}")

    def create_transcription(self, request: Union[AudioTranscriptionRequest, Dict[str, Any]]) -> AudioTranscriptionResponse:
    def create_transcription(
        self, request: Union[AudioTranscriptionRequest, Dict[str, Any]]
    ) -> AudioTranscriptionResponse:
        """
        Transcribe audio to text.

@ -93,17 +109,21 @@ class OpenAIAudioClient(OpenAIClientBase):
        Returns:
            AudioTranscriptionResponse: The transcription result.
        """
        validated_request: AudioTranscriptionRequest = RequestHandler.validate_request(request, AudioTranscriptionRequest)
        validated_request: AudioTranscriptionRequest = RequestHandler.validate_request(
            request, AudioTranscriptionRequest
        )

        logger.info(f"Using model '{validated_request.model}' for transcription.")

        response = self.client.audio.transcriptions.create(
            file=validated_request.file,
            **validated_request.model_dump(exclude={"file"})
            **validated_request.model_dump(exclude={"file"}),
        )
        return response

    def create_translation(self, request: Union[AudioTranslationRequest, Dict[str, Any]]) -> AudioTranslationResponse:
    def create_translation(
        self, request: Union[AudioTranslationRequest, Dict[str, Any]]
    ) -> AudioTranslationResponse:
        """
        Translate audio to English.

@ -113,12 +133,14 @@ class OpenAIAudioClient(OpenAIClientBase):
        Returns:
            AudioTranslationResponse: The translation result.
        """
        validated_request: AudioTranslationRequest = RequestHandler.validate_request(request, AudioTranslationRequest)
        validated_request: AudioTranslationRequest = RequestHandler.validate_request(
            request, AudioTranslationRequest
        )

        logger.info(f"Using model '{validated_request.model}' for translation.")

        response = self.client.audio.translations.create(
            file=validated_request.file,
            **validated_request.model_dump(exclude={"file"})
            **validated_request.model_dump(exclude={"file"}),
        )
        return response
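A hedged usage sketch for the audio client above (illustrative; the request keys are assumptions based on common OpenAI TTS usage, not confirmed by this diff):

    client = OpenAIAudioClient()
    request = {"model": "tts-1", "input": "Hello from dapr-agents!", "voice": "alloy"}
    audio_bytes = client.create_speech(request)           # returns bytes
    client.create_speech(request, file_name="hello.mp3")  # or write to a file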
@ -5,7 +5,18 @@ from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.prompt.prompty import Prompty
from dapr_agents.tool import AgentTool
from typing import Union, Optional, Iterable, Dict, Any, List, Iterator, Type, Literal, ClassVar
from typing import (
    Union,
    Optional,
    Iterable,
    Dict,
    Any,
    List,
    Iterator,
    Type,
    Literal,
    ClassVar,
)
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel, Field, model_validator
from pathlib import Path

@ -13,11 +24,13 @@ import logging

logger = logging.getLogger(__name__)


class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
    """
    Chat client for OpenAI models.
    Combines OpenAI client management with Prompty-specific functionality.
    """

    model: str = Field(default=None, description="Model name to use, e.g., 'gpt-4'.")

    SUPPORTED_STRUCTURED_MODES: ClassVar[set] = {"json", "function_call"}

@ -28,8 +41,8 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
        Ensures the 'model' is set during validation.
        Uses 'azure_deployment' if no model is specified, defaults to 'gpt-4o'.
        """
        if 'model' not in values or values['model'] is None:
            values['model'] = values.get('azure_deployment', 'gpt-4o')
        if "model" not in values or values["model"] is None:
            values["model"] = values.get("azure_deployment", "gpt-4o")
        return values

    def model_post_init(self, __context: Any) -> None:

@ -40,7 +53,11 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
        super().model_post_init(__context)

    @classmethod
    def from_prompty(cls, prompty_source: Union[str, Path], timeout: Union[int, float, Dict[str, Any]] = 1500) -> 'OpenAIChatClient':
    def from_prompty(
        cls,
        prompty_source: Union[str, Path],
        timeout: Union[int, float, Dict[str, Any]] = 1500,
    ) -> "OpenAIChatClient":
        """
        Initializes an OpenAIChatClient client using a Prompty source, which can be a file path or inline content.

@ -63,43 +80,54 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):

        # Initialize the OpenAIChatClient instance using model_validate
        if isinstance(model_config.configuration, OpenAIModelConfig):
            return cls.model_validate({
                'model': model_config.configuration.name,
                'api_key': model_config.configuration.api_key,
                'base_url': model_config.configuration.base_url,
                'organization': model_config.configuration.organization,
                'project': model_config.configuration.project,
                'timeout': timeout,
                'prompty': prompty_instance,
                'prompt_template': prompt_template,
            })
            return cls.model_validate(
                {
                    "model": model_config.configuration.name,
                    "api_key": model_config.configuration.api_key,
                    "base_url": model_config.configuration.base_url,
                    "organization": model_config.configuration.organization,
                    "project": model_config.configuration.project,
                    "timeout": timeout,
                    "prompty": prompty_instance,
                    "prompt_template": prompt_template,
                }
            )
        elif isinstance(model_config.configuration, AzureOpenAIModelConfig):
            return cls.model_validate({
                'model': model_config.configuration.azure_deployment,
                'api_key': model_config.configuration.api_key,
                'azure_endpoint': model_config.configuration.azure_endpoint,
                'azure_deployment': model_config.configuration.azure_deployment,
                'api_version': model_config.configuration.api_version,
                'organization': model_config.configuration.organization,
                'project': model_config.configuration.project,
                'azure_ad_token': model_config.configuration.azure_ad_token,
                'azure_client_id': model_config.configuration.azure_client_id,
                'timeout': timeout,
                'prompty': prompty_instance,
                'prompt_template': prompt_template,
            })
            return cls.model_validate(
                {
                    "model": model_config.configuration.azure_deployment,
                    "api_key": model_config.configuration.api_key,
                    "azure_endpoint": model_config.configuration.azure_endpoint,
                    "azure_deployment": model_config.configuration.azure_deployment,
                    "api_version": model_config.configuration.api_version,
                    "organization": model_config.configuration.organization,
                    "project": model_config.configuration.project,
                    "azure_ad_token": model_config.configuration.azure_ad_token,
                    "azure_client_id": model_config.configuration.azure_client_id,
                    "timeout": timeout,
                    "prompty": prompty_instance,
                    "prompt_template": prompt_template,
                }
            )
        else:
            raise ValueError(f"Unsupported model configuration type: {type(model_config.configuration)}")
            raise ValueError(
                f"Unsupported model configuration type: {type(model_config.configuration)}"
            )

    def generate(
        self,
        messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]] = None,
        messages: Union[
            str,
            Dict[str, Any],
            BaseMessage,
            Iterable[Union[Dict[str, Any], BaseMessage]],
        ] = None,
        input_data: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        structured_mode: Literal["json", "function_call"] = "json",
        **kwargs
        **kwargs,
    ) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
        """
        Generate chat completions based on provided messages or input_data for prompt templates.

@ -118,12 +146,16 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
        """

        if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
            raise ValueError(f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}.")
            raise ValueError(
                f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
            )

        # If input_data is provided, check for a prompt_template
        if input_data:
            if not self.prompt_template:
                raise ValueError("Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data.")
                raise ValueError(
                    "Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
                )

            logger.info("Using prompt template to generate messages.")
            messages = self.prompt_template.format_prompt(**input_data)

@ -133,7 +165,7 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
            raise ValueError("Either 'messages' or 'input_data' must be provided.")

        # Process and normalize the messages
        params = {'messages': RequestHandler.normalize_chat_messages(messages)}
        params = {"messages": RequestHandler.normalize_chat_messages(messages)}

        # Merge prompty parameters if available, then override with any explicit kwargs
        if self.prompty:

@ -142,7 +174,7 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
        params.update(kwargs)

        # If a model is provided, override the default model
        params['model'] = model or self.model
        params["model"] = model or self.model

        # Prepare request parameters
        params = RequestHandler.process_params(

@ -150,13 +182,15 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
            llm_provider=self.provider,
            tools=tools,
            response_format=response_format,
            structured_mode=structured_mode
            structured_mode=structured_mode,
        )

        try:
            logger.info("Invoking ChatCompletion API.")
            logger.debug(f"ChatCompletion API Parameters: {params}")
            response: ChatCompletionMessage = self.client.chat.completions.create(**params, timeout=self.timeout)
            response: ChatCompletionMessage = self.client.chat.completions.create(
                **params, timeout=self.timeout
            )
            logger.info("Chat completion retrieved successfully.")

            return ResponseHandler.process_response(

@ -164,7 +198,7 @@ class OpenAIChatClient(OpenAIClientBase, ChatClientBase):
                llm_provider=self.provider,
                response_format=response_format,
                structured_mode=structured_mode,
                stream=params.get('stream', False)
                stream=params.get("stream", False),
            )
        except Exception as e:
            logger.error(f"An error occurred during the ChatCompletion API call: {e}")
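A sketch of structured output with the OpenAI chat client above (illustrative; the schema is an example):

    from pydantic import BaseModel

    class City(BaseModel):
        name: str
        country: str

    client = OpenAIChatClient(model="gpt-4o")
    city = client.generate(
        messages="Name one city in Japan.",
        response_format=City,
        structured_mode="json",
    )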
@ -1,4 +1,8 @@
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential, get_bearer_token_provider
from azure.identity import (
    DefaultAzureCredential,
    ManagedIdentityCredential,
    get_bearer_token_provider,
)
from dapr_agents.types.llm import AzureOpenAIClientConfig
from dapr_agents.llm.utils import HTTPHelper
from openai import AzureOpenAI

@ -8,6 +12,7 @@ os

logger = logging.getLogger(__name__)


class AzureOpenAIClient:
    """
    Client for Azure OpenAI language models, handling API communication and authentication.

@ -23,7 +28,7 @@ class AzureOpenAIClient:
        azure_endpoint: Optional[str] = None,
        azure_deployment: Optional[str] = None,
        azure_client_id: Optional[str] = None,
        timeout: Union[int, float, dict] = 1500
        timeout: Union[int, float, dict] = 1500,
    ):
        """
        Initializes the client with API key or Azure AD credentials.

@ -50,7 +55,9 @@ class AzureOpenAIClient:
        self.azure_client_id = azure_client_id or os.getenv("AZURE_CLIENT_ID")

        if not self.azure_endpoint or not self.azure_deployment:
            raise ValueError("Azure OpenAI endpoint and deployment must be provided, either via arguments or environment variables.")
            raise ValueError(
                "Azure OpenAI endpoint and deployment must be provided, either via arguments or environment variables."
            )

        self.timeout = HTTPHelper.configure_timeout(timeout)

@ -74,7 +81,9 @@ class AzureOpenAIClient:
            return self._create_client(azure_ad_token=self.azure_ad_token)

        # Case 3: Use Azure Identity Credentials
        logger.info("No API key or Azure AD token provided, attempting to use Azure Identity credentials.")
        logger.info(
            "No API key or Azure AD token provided, attempting to use Azure Identity credentials."
        )
        try:
            credential = (
                ManagedIdentityCredential(client_id=self.azure_client_id)

@ -87,7 +96,9 @@ class AzureOpenAIClient:
            return self._create_client(azure_ad_token_provider=azure_ad_token_provider)
        except Exception as e:
            logger.error(f"Failed to initialize Azure Identity credentials: {e}")
            raise ValueError("Unable to authenticate using Azure Identity credentials. Check your setup.") from e
            raise ValueError(
                "Unable to authenticate using Azure Identity credentials. Check your setup."
            ) from e

    def _create_client(self, **kwargs) -> AzureOpenAI:
        """

@ -98,7 +109,7 @@ class AzureOpenAIClient:
            azure_deployment=self.azure_deployment,
            api_version=self.api_version,
            timeout=self.timeout,
            **kwargs
            **kwargs,
        )

    @classmethod

@ -106,7 +117,7 @@ class AzureOpenAIClient:
        cls,
        client_options: AzureOpenAIClientConfig,
        azure_client_id: Optional[str] = None,
        timeout: Union[int, float, dict] = 1500
        timeout: Union[int, float, dict] = 1500,
    ):
        """
        Initialize AzureOpenAIClient using AzureOpenAIClientOptions.

@ -128,5 +139,5 @@ class AzureOpenAIClient:
            azure_endpoint=client_options.azure_endpoint,
            azure_deployment=client_options.azure_deployment,
            azure_client_id=azure_client_id,
            timeout=timeout
            timeout=timeout,
        )
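The client above resolves credentials in three ordered cases: an explicit api_key, then azure_ad_token, then Azure Identity (ManagedIdentityCredential when azure_client_id is set, otherwise DefaultAzureCredential). An illustrative construction with placeholder values:

    client = AzureOpenAIClient(
        azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
        azure_deployment="gpt-4o",                              # placeholder
    ).get_client()  # no key/token given, so Azure Identity is attempted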

@@ -8,21 +8,43 @@ import logging

 logger = logging.getLogger(__name__)

+
 class OpenAIClientBase(LLMClientBase):
     """
     Base class for managing OpenAI and Azure OpenAI clients.
     Handles client initialization, configuration, and shared logic.
     """
-    api_key: Optional[str] = Field(default=None, description="API key for OpenAI or Azure OpenAI.")
-    base_url: Optional[str] = Field(default=None, description="Base URL for OpenAI API (OpenAI-specific).")
-    azure_endpoint: Optional[str] = Field(default=None, description="Azure endpoint URL (Azure OpenAI-specific).")
-    azure_deployment: Optional[str] = Field(default=None, description="Azure deployment name (Azure OpenAI-specific).")
-    api_version: Optional[str] = Field(default=None, description="Azure API version (Azure OpenAI-specific).")
-    organization: Optional[str] = Field(default=None, description="Organization for OpenAI or Azure OpenAI.")
-    project: Optional[str] = Field(default=None, description="Project for OpenAI or Azure OpenAI.")
-    azure_ad_token: Optional[str] = Field(default=None, description="Azure AD token for authentication (Azure-specific).")
-    azure_client_id: Optional[str] = Field(default=None, description="Client ID for Managed Identity (Azure-specific).")
-    timeout: Union[int, float, Dict[str, Any]] = Field(default=1500, description="Timeout for requests in seconds.")
+
+    api_key: Optional[str] = Field(
+        default=None, description="API key for OpenAI or Azure OpenAI."
+    )
+    base_url: Optional[str] = Field(
+        default=None, description="Base URL for OpenAI API (OpenAI-specific)."
+    )
+    azure_endpoint: Optional[str] = Field(
+        default=None, description="Azure endpoint URL (Azure OpenAI-specific)."
+    )
+    azure_deployment: Optional[str] = Field(
+        default=None, description="Azure deployment name (Azure OpenAI-specific)."
+    )
+    api_version: Optional[str] = Field(
+        default=None, description="Azure API version (Azure OpenAI-specific)."
+    )
+    organization: Optional[str] = Field(
+        default=None, description="Organization for OpenAI or Azure OpenAI."
+    )
+    project: Optional[str] = Field(
+        default=None, description="Project for OpenAI or Azure OpenAI."
+    )
+    azure_ad_token: Optional[str] = Field(
+        default=None, description="Azure AD token for authentication (Azure-specific)."
+    )
+    azure_client_id: Optional[str] = Field(
+        default=None, description="Client ID for Managed Identity (Azure-specific)."
+    )
+    timeout: Union[int, float, Dict[str, Any]] = Field(
+        default=1500, description="Timeout for requests in seconds."
+    )

     def model_post_init(self, __context: Any) -> None:
         """
@@ -31,7 +53,9 @@ class OpenAIClientBase(LLMClientBase):
         self._provider = "openai"

         # Set up the private config and client attributes
-        self._config: Union[AzureOpenAIClientConfig, OpenAIClientConfig] = self.get_config()
+        self._config: Union[
+            AzureOpenAIClientConfig, OpenAIClientConfig
+        ] = self.get_config()
         self._client: Union[AzureOpenAI, OpenAI] = self.get_client()
         return super().model_post_init(__context)

@@ -49,14 +73,14 @@ class OpenAIClientBase(LLMClientBase):
                 azure_ad_token=self.azure_ad_token,
                 azure_endpoint=self.azure_endpoint,
                 azure_deployment=self.azure_deployment,
-                api_version=self.api_version
+                api_version=self.api_version,
             )
         else:
             return OpenAIClientConfig(
                 api_key=self.api_key,
                 base_url=self.base_url,
                 organization=self.organization,
-                project=self.project
+                project=self.project,
             )

     def get_client(self) -> Union[AzureOpenAI, OpenAI]:
@@ -77,7 +101,7 @@ class OpenAIClientBase(LLMClientBase):
                 organization=config.organization,
                 project=config.project,
                 azure_client_id=self.azure_client_id,
-                timeout=timeout
+                timeout=timeout,
             ).get_client()

         logger.info("Initializing OpenAI client...")
@@ -86,7 +110,7 @@ class OpenAIClientBase(LLMClientBase):
             base_url=config.base_url,
             organization=config.organization,
             project=config.project,
-            timeout=timeout
+            timeout=timeout,
         ).get_client()

     @property
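As a toy illustration of the selection rule in get_config() above: a set azure_endpoint picks the Azure-flavored config, otherwise the plain OpenAI config is used. The *Sketch classes below are stand-ins for illustration, not the real AzureOpenAIClientConfig/OpenAIClientConfig models:

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class OpenAIConfigSketch:
    api_key: Optional[str] = None
    base_url: Optional[str] = None

@dataclass
class AzureOpenAIConfigSketch:
    api_key: Optional[str] = None
    azure_endpoint: Optional[str] = None
    azure_deployment: Optional[str] = None

def pick_config(
    api_key: Optional[str],
    azure_endpoint: Optional[str] = None,
    azure_deployment: Optional[str] = None,
) -> Union[AzureOpenAIConfigSketch, OpenAIConfigSketch]:
    # Same branch shape as get_config(): the endpoint decides the provider.
    if azure_endpoint:
        return AzureOpenAIConfigSketch(api_key, azure_endpoint, azure_deployment)
    return OpenAIConfigSketch(api_key)

assert isinstance(pick_config("key", "https://x.openai.azure.com"), AzureOpenAIConfigSketch)
assert isinstance(pick_config("key"), OpenAIConfigSketch)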

@@ -6,6 +6,7 @@ import logging

 logger = logging.getLogger(__name__)

+
 class OpenAIClient:
     """
     Client for interfacing with OpenAI's language models.
@@ -18,7 +19,7 @@ class OpenAIClient:
         base_url: Optional[str] = None,
         organization: Optional[str] = None,
         project: Optional[str] = None,
-        timeout: Union[int, float, dict] = 1500
+        timeout: Union[int, float, dict] = 1500,
     ):
         """
         Initializes the OpenAI client with API key, base URL, and organization.
@@ -48,11 +49,13 @@ class OpenAIClient:
             base_url=self.base_url,
             organization=self.organization,
             project=self.project,
-            timeout=self.timeout
+            timeout=self.timeout,
         )

     @classmethod
-    def from_config(cls, client_options: OpenAIClientConfig, timeout: Union[int, float, dict] = 1500):
+    def from_config(
+        cls, client_options: OpenAIClientConfig, timeout: Union[int, float, dict] = 1500
+    ):
         """
         Initialize OpenAIBaseClient using OpenAIClientConfig.

@@ -68,5 +71,5 @@ class OpenAIClient:
             base_url=client_options.base_url,
             organization=client_options.organization,
             project=client_options.project,
-            timeout=timeout
+            timeout=timeout,
         )

@@ -6,6 +6,7 @@ import logging

 logger = logging.getLogger(__name__)

+
 class OpenAIEmbeddingClient(OpenAIClientBase):
     """
     Client for handling OpenAI's embedding functionalities, supporting both OpenAI and Azure OpenAI configurations.
@@ -16,10 +17,20 @@ class OpenAIEmbeddingClient(OpenAIClientBase):
         dimensions (Optional[int]): Number of dimensions for the output embeddings. Only supported in specific models like `text-embedding-3`.
         user (Optional[str]): A unique identifier representing the end-user.
     """
-    model: str = Field(default=None, description="ID of the model to use for embedding.")
-    encoding_format: Optional[Literal["float", "base64"]] = Field("float", description="Format for the embeddings. Defaults to 'float'.")
-    dimensions: Optional[int] = Field(None, description="Number of dimensions for the output embeddings. Supported in text-embedding-3 and later models.")
-    user: Optional[str] = Field(None, description="Unique identifier representing the end-user.")
+
+    model: str = Field(
+        default=None, description="ID of the model to use for embedding."
+    )
+    encoding_format: Optional[Literal["float", "base64"]] = Field(
+        "float", description="Format for the embeddings. Defaults to 'float'."
+    )
+    dimensions: Optional[int] = Field(
+        None,
+        description="Number of dimensions for the output embeddings. Supported in text-embedding-3 and later models.",
+    )
+    user: Optional[str] = Field(
+        None, description="Unique identifier representing the end-user."
+    )

     @model_validator(mode="before")
     def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -33,8 +44,8 @@ class OpenAIEmbeddingClient(OpenAIClientBase):
         Returns:
             Dict[str, Any]: Updated dictionary of validated attributes.
         """
-        if 'model' not in values or values['model'] is None:
-            values['model'] = values.get('azure_deployment', 'text-embedding-ada-002')
+        if "model" not in values or values["model"] is None:
+            values["model"] = values.get("azure_deployment", "text-embedding-ada-002")
         return values

     def model_post_init(self, __context: Any) -> None:
@@ -49,7 +60,11 @@ class OpenAIEmbeddingClient(OpenAIClientBase):
         self._api = "embeddings"
         return super().model_post_init(__context)

-    def create_embedding(self, input: Union[str, List[Union[str, List[int]]]], model: Optional[str] = None) -> CreateEmbeddingResponse:
+    def create_embedding(
+        self,
+        input: Union[str, List[Union[str, List[int]]]],
+        model: Optional[str] = None,
+    ) -> CreateEmbeddingResponse:
         """
         Generate embeddings for the given input text(s).

@@ -75,6 +90,6 @@ class OpenAIEmbeddingClient(OpenAIClientBase):
             input=input,
             encoding_format=self.encoding_format,
             dimensions=self.dimensions,
-            user=self.user
+            user=self.user,
         )
         return response
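For orientation, the reformatted create_embedding() is typically driven as below. The import path is an assumption about this repo's layout, the model name simply mirrors the text-embedding-ada-002 default visible in the validator above, and an OPENAI_API_KEY environment variable is assumed:

from dapr_agents.llm.openai import OpenAIEmbeddingClient  # assumed import path

embedder = OpenAIEmbeddingClient(model="text-embedding-ada-002")
response = embedder.create_embedding(input="Dapr makes distributed apps easier.")
# CreateEmbeddingResponse is the openai SDK type: vectors live under .data
print(len(response.data[0].embedding))  # dimensionality of the returned vector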

@@ -1,10 +1,12 @@
 from typing import Union
 import httpx

+
 class HTTPHelper:
     """
     HTTP operations helper.
     """
+
     @staticmethod
     def configure_timeout(timeout: Union[int, float, dict]) -> httpx.Timeout:
         """
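HTTPHelper.configure_timeout() accepts an int, float, or dict and, per its signature above, returns an httpx.Timeout. A sketch of the likely normalization, assuming the dict keys follow httpx.Timeout's keyword arguments; this is illustrative, not the helper's actual body:

import httpx
from typing import Union

def configure_timeout_sketch(timeout: Union[int, float, dict]) -> httpx.Timeout:
    if isinstance(timeout, dict):
        # e.g. {"connect": 5.0, "read": 60.0, "write": 10.0, "pool": 5.0}
        return httpx.Timeout(**timeout)
    # A scalar applies the same limit to connect/read/write/pool.
    return httpx.Timeout(timeout)

print(configure_timeout_sketch(1500))
print(configure_timeout_sketch({"connect": 5.0, "read": 60.0, "write": 10.0, "pool": 5.0}))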

@@ -9,13 +9,16 @@ import logging

 logger = logging.getLogger(__name__)

+
 class RequestHandler:
     """
     Handles the preparation of requests for language models.
     """
+
     @staticmethod
-    def process_prompty_messages(prompty: Prompty, inputs: Dict[str, Any] = {}) -> List[Dict[str, Any]]:
+    def process_prompty_messages(
+        prompty: Prompty, inputs: Dict[str, Any] = {}
+    ) -> List[Dict[str, Any]]:
         """
         Process and format messages based on Prompty template and provided inputs.

@@ -28,13 +31,24 @@ class RequestHandler:
         """
         # Prepare inputs and generate messages from Prompty content
        api_type = prompty.model.api
-        prepared_inputs = PromptyHelper.prepare_inputs(inputs, prompty.inputs, prompty.sample)
-        messages = PromptyHelper.to_prompt(prompty.content, prepared_inputs, api_type=api_type)
+        prepared_inputs = PromptyHelper.prepare_inputs(
+            inputs, prompty.inputs, prompty.sample
+        )
+        messages = PromptyHelper.to_prompt(
+            prompty.content, prepared_inputs, api_type=api_type
+        )

         return messages

     @staticmethod
-    def normalize_chat_messages(messages: Union[str, Dict[str, Any], BaseMessage, Iterable[Union[Dict[str, Any], BaseMessage]]]) -> List[Dict[str, Any]]:
+    def normalize_chat_messages(
+        messages: Union[
+            str,
+            Dict[str, Any],
+            BaseMessage,
+            Iterable[Union[Dict[str, Any], BaseMessage]],
+        ],
+    ) -> List[Dict[str, Any]]:
         """
         Normalize and validate the input messages into a list of dictionaries.

@@ -63,7 +77,9 @@ class RequestHandler:
             elif isinstance(msg, dict):
                 role = msg.get("role")
                 if role not in {"user", "assistant", "tool", "system"}:
-                    raise ValueError(f"Unrecognized role '{role}'. Supported roles are 'user', 'assistant', 'tool', or 'system'.")
+                    raise ValueError(
+                        f"Unrecognized role '{role}'. Supported roles are 'user', 'assistant', 'tool', or 'system'."
+                    )
                 normalized_messages.append(msg)
             elif isinstance(msg, Iterable) and not isinstance(msg, (str, dict)):
                 queue.extend(msg)
@@ -96,7 +112,9 @@ class RequestHandler:
         """
         if tools:
             logger.info("Tools are available in the request.")
-            params['tools'] = [ToolHelper.format_tool(tool, tool_format=llm_provider) for tool in tools]
+            params["tools"] = [
+                ToolHelper.format_tool(tool, tool_format=llm_provider) for tool in tools
+            ]

         if response_format:
             logger.info(f"Structured Mode Activated! Mode={structured_mode}.")
@@ -104,13 +122,15 @@ class RequestHandler:
                 response_format=response_format,
                 llm_provider=llm_provider,
                 structured_mode=structured_mode,
-                **params
+                **params,
             )

         return params

     @staticmethod
-    def validate_request(request: Union[BaseModel, Dict[str, Any]], request_class: Type[BaseModel]) -> BaseModel:
+    def validate_request(
+        request: Union[BaseModel, Dict[str, Any]], request_class: Type[BaseModel]
+    ) -> BaseModel:
         """
         Validate and transform a dictionary into a Pydantic object.

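The normalize_chat_messages() hunks re-wrap a queue-based flattener: strings become user messages, dicts are role-checked, and nested iterables are pushed back onto the queue. An illustrative re-implementation of that contract, with the BaseMessage branch omitted for brevity; note the str check must come before the Iterable check, since strings are iterable:

from collections import deque
from collections.abc import Iterable
from typing import Any, Dict, List

def normalize_sketch(messages) -> List[Dict[str, Any]]:
    normalized: List[Dict[str, Any]] = []
    queue = deque([messages])
    while queue:
        msg = queue.popleft()
        if isinstance(msg, str):
            # Bare strings are promoted to user messages.
            normalized.append({"role": "user", "content": msg})
        elif isinstance(msg, dict):
            if msg.get("role") not in {"user", "assistant", "tool", "system"}:
                raise ValueError(f"Unrecognized role '{msg.get('role')}'")
            normalized.append(msg)
        elif isinstance(msg, Iterable):
            # Nested lists/tuples are expanded back into the queue.
            queue.extend(msg)
    return normalized

print(normalize_sketch(["hi", {"role": "assistant", "content": "hello"}]))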

@@ -9,6 +9,7 @@ from dapr_agents.types import ChatCompletion

 logger = logging.getLogger(__name__)

+
 class ResponseHandler:
     """
     Handles the processing of responses from language models.
@@ -50,11 +51,15 @@ class ResponseHandler:
             )

             # Normalize format and resolve actual model class
-            normalized_format = StructureHandler.normalize_iterable_format(response_format)
+            normalized_format = StructureHandler.normalize_iterable_format(
+                response_format
+            )
             model_cls = StructureHandler.resolve_response_model(normalized_format)

             if not model_cls:
-                raise TypeError(f"Could not resolve a valid Pydantic model from response_format: {response_format}")
+                raise TypeError(
+                    f"Could not resolve a valid Pydantic model from response_format: {response_format}"
+                )

             structured_response_instance = StructureHandler.validate_response(
                 structured_response_json, normalized_format
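A condensed sketch of the validation flow this hunk reformats: normalize a List[Model] annotation, resolve the concrete model, then validate the JSON payload against it. The StructureHandler calls are replaced here with plain pydantic v2 equivalents, so this is an analogy, not the handler's code:

from typing import List
from pydantic import BaseModel, TypeAdapter

class Step(BaseModel):
    description: str

# response_format = List[Step]: resolve the inner model, validate a JSON array.
payload = '[{"description": "plan"}, {"description": "act"}]'
steps = TypeAdapter(List[Step]).validate_json(payload)
print([s.description for s in steps])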

@@ -1,4 +1,14 @@
-from typing import Dict, Any, Iterator, Type, TypeVar, Union, Optional, Iterable, get_args
+from typing import (
+    Dict,
+    Any,
+    Iterator,
+    Type,
+    TypeVar,
+    Union,
+    Optional,
+    Iterable,
+    get_args,
+)
 from dapr_agents.llm.utils import StructureHandler
 from dapr_agents.types import ToolCall
 from openai.types.chat import ChatCompletionChunk
@@ -9,6 +19,7 @@ logger = logging.getLogger(__name__)

 T = TypeVar("T", bound=BaseModel)

+
 class StreamHandler:
     """
     Handles streaming of chat completion responses, processing tool calls and content responses.
@@ -34,7 +45,7 @@ class StreamHandler:
         logger.info("Streaming response enabled.")

         try:
-            if llm_provider == 'openai':
+            if llm_provider == "openai":
                 yield from StreamHandler._process_openai_stream(stream, response_format)
             else:
                 yield from stream
@@ -45,7 +56,7 @@ class StreamHandler:
     @staticmethod
     def _process_openai_stream(
         stream: Iterator[Dict[str, Any]],
-        response_format: Optional[Union[Type[T], Type[Iterable[T]]]] = None
+        response_format: Optional[Union[Type[T], Type[Iterable[T]]]] = None,
     ) -> Iterator[Dict[str, Any]]:
         """
         Process OpenAI stream for chat completion.
@@ -82,34 +93,45 @@ class StreamHandler:
                         tool_calls.setdefault(
                             tool_call_index,
                             {
-                                'id': tool_call_id,
-                                'type': tool_chunk["type"],
-                                'function': {
-                                    'name': tool_call_function["name"],
-                                    'arguments': tool_call_arguments
-                                }
-                            }
+                                "id": tool_call_id,
+                                "type": tool_chunk["type"],
+                                "function": {
+                                    "name": tool_call_function["name"],
+                                    "arguments": tool_call_arguments,
+                                },
+                            },
                         )

                         # Add tool call arguments to current tool calls
-                        tool_calls[tool_call_index]["function"]["arguments"] += tool_call_arguments
+                        tool_calls[tool_call_index]["function"][
+                            "arguments"
+                        ] += tool_call_arguments

                         # Process Iterable model if provided
-                        if response_format and isinstance(response_format, Iterable) is True:
+                        if (
+                            response_format
+                            and isinstance(response_format, Iterable) is True
+                        ):
                             trimmed_character = tool_call_arguments.strip()
                             # Check beginning of List
                             if trimmed_character == "[" and json_extraction_active is False:
                                 json_extraction_active = True
                             # Check beginning of a JSON object
-                            elif trimmed_character == "{" and json_extraction_active is True:
+                            elif (
+                                trimmed_character == "{" and json_extraction_active is True
+                            ):
                                 json_brace_level += 1
                                 json_string_buffer += trimmed_character
                             # Check the end of a JSON object
-                            elif "}" in trimmed_character and json_extraction_active is True:
+                            elif (
+                                "}" in trimmed_character and json_extraction_active is True
+                            ):
                                 json_brace_level -= 1
-                                json_string_buffer += trimmed_character.rstrip(',')
+                                json_string_buffer += trimmed_character.rstrip(",")
                                 if json_brace_level == 0:
-                                    yield from StreamHandler._validate_json_object(response_format, json_string_buffer)
+                                    yield from StreamHandler._validate_json_object(
+                                        response_format, json_string_buffer
+                                    )
                                     # Reset buffers and counts
                                     json_string_buffer = ""
                             elif json_extraction_active is True:
@@ -145,15 +167,27 @@ class StreamHandler:

             # Process tool calls
             if delta.get("tool_calls"):
-                return {"type": "tool_calls", "data": delta["tool_calls"], "chunk": chunk}
+                return {
+                    "type": "tool_calls",
+                    "data": delta["tool_calls"],
+                    "chunk": chunk,
+                }

             # Process function calls
             if delta.get("function_call"):
-                return {"type": "function_call", "data": delta["function_call"], "chunk": chunk}
+                return {
+                    "type": "function_call",
+                    "data": delta["function_call"],
+                    "chunk": chunk,
+                }

             # Process finish reason
             if choice.get("finish_reason"):
-                return {"type": "finish", "data": choice["finish_reason"], "chunk": chunk}
+                return {
+                    "type": "finish",
+                    "data": choice["finish_reason"],
+                    "chunk": chunk,
+                }

             return {}
         except Exception as e:
@@ -163,22 +197,26 @@ class StreamHandler:
     @staticmethod
     def _validate_json_object(
         response_format: Optional[Union[Type[T], Type[Iterable[T]]]],
-        json_string_buffer: str
+        json_string_buffer: str,
     ):
         try:
             model_class = get_args(response_format)[0]
             # Return current tool call
-            structured_output = StructureHandler.validate_response(json_string_buffer, model_class)
+            structured_output = StructureHandler.validate_response(
+                json_string_buffer, model_class
+            )
             if isinstance(structured_output, model_class):
                 logger.info("Structured output was successfully validated.")
                 yield {"type": "structured_output", "data": structured_output}
         except ValidationError as validation_error:
-            logger.error(f"Validation error: {validation_error} with JSON: {json_string_buffer}")
+            logger.error(
+                f"Validation error: {validation_error} with JSON: {json_string_buffer}"
+            )

     @staticmethod
     def _get_final_tool_calls(
         tool_calls: Dict[int, Any],
-        response_format: Optional[Union[Type[T], Type[Iterable[T]]]]
+        response_format: Optional[Union[Type[T], Type[Iterable[T]]]],
     ) -> Iterator[Dict[str, Any]]:
         """
         Yield final tool calls after processing.
@@ -192,7 +230,9 @@ class StreamHandler:
         """
         for tool in tool_calls.values():
             if response_format and isinstance(response_format, Iterable) is False:
-                structured_output = StructureHandler.validate_response(tool["function"]["arguments"], response_format)
+                structured_output = StructureHandler.validate_response(
+                    tool["function"]["arguments"], response_format
+                )
                 if isinstance(structured_output, response_format):
                     logger.info("Structured output was successfully validated.")
                     yield {"type": "structured_output", "data": structured_output}
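The _process_openai_stream() hunks reformat a small streaming parser: while the response_format describes an iterable, tool-call argument chunks are brace-counted and each completed JSON object is validated and yielded as soon as its braces balance. A self-contained sketch of that extraction; the single-token chunking below is artificial, and unlike the real handler it parses with json.loads instead of pydantic validation:

import json

def extract_objects(chunks):
    active, level, buf = False, 0, ""
    for token in chunks:
        t = token.strip()
        if t == "[" and not active:
            active = True          # entering the streamed JSON array
        elif t == "{" and active:
            level += 1             # object opens: start/extend the buffer
            buf += t
        elif "}" in t and active:
            level -= 1
            buf += t.rstrip(",")   # drop the separator after a closing brace
            if level == 0:
                yield json.loads(buf)  # braces balanced: emit one object
                buf = ""
        elif active:
            buf += token           # interior tokens accumulate verbatim

chunks = ['[', '{', '"name"', ':', '"a"', '},', '{', '"name"', ':', '"b"', '}', ']']
print(list(extract_objects(chunks)))  # [{'name': 'a'}, {'name': 'b'}]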

@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)

 T = TypeVar("T", bound=BaseModel)

+
 class StructureHandler:
     @staticmethod
     def is_json_string(input_string: str) -> bool:
@@ -50,7 +51,9 @@ class StructureHandler:
         if origin in (list, List, tuple, Iterable) and args:
             item_type = args[0]
             if isinstance(item_type, type) and issubclass(item_type, BaseModel):
-                logger.debug("Detected iterable of BaseModel. Wrapping in generated Pydantic model.")
+                logger.debug(
+                    "Detected iterable of BaseModel. Wrapping in generated Pydantic model."
+                )
                 return StructureHandler.create_iterable_model(item_type)

         return tp
@@ -60,7 +63,7 @@ class StructureHandler:
         response_format: Union[Type[T], Dict[str, Any], Iterable[Type[T]]],
         llm_provider: str,
         structured_mode: Literal["json", "function_call"] = "json",
-        **params
+        **params,
     ) -> Dict[str, Any]:
         """
         Generates a structured request that conforms to a specified API format using the given Pydantic model.
@@ -93,11 +96,15 @@ class StructureHandler:
         if structured_mode == "function_call":
             model_cls = StructureHandler.resolve_response_model(response_format)
             if not model_cls:
-                raise TypeError("function_call mode requires a single, unambiguous Pydantic model.")
+                raise TypeError(
+                    "function_call mode requires a single, unambiguous Pydantic model."
+                )

             name = model_cls.__name__
             description = model_cls.__doc__ or ""
-            model_tool_format = to_function_call_definition(name, description, model_cls, llm_provider)
+            model_tool_format = to_function_call_definition(
+                name, description, model_cls, llm_provider
+            )

             params["tools"] = [model_tool_format]
             params["tool_choice"] = {
@@ -108,13 +115,17 @@ class StructureHandler:

         elif structured_mode == "json":
             try:
-                logger.debug(f"generate_request called with type={type(response_format)}, mode={structured_mode}, provider={llm_provider}")
+                logger.debug(
+                    f"generate_request called with type={type(response_format)}, mode={structured_mode}, provider={llm_provider}"
+                )
                 # If it's a dict, assume it's already a JSON schema; otherwise, try to create from model
                 if isinstance(response_format, dict):
                     raw_schema = response_format
                     name = response_format.get("name", "custom_schema")
                     description = response_format.get("description")
-                elif isinstance(response_format, type) and issubclass(response_format, BaseModel):
+                elif isinstance(response_format, type) and issubclass(
+                    response_format, BaseModel
+                ):
                     raw_schema = response_format.model_json_schema()
                     name = response_format.__name__
                     description = response_format.__doc__
@@ -130,16 +141,22 @@ class StructureHandler:
                     name=name,
                     description=description,
                     schema_=strict_schema,
-                    strict=True
+                    strict=True,
                 )

                 # Wrap it in the top-level response format object
-                response_format_obj = OAIResponseFormatSchema(json_schema=json_schema_obj)
+                response_format_obj = OAIResponseFormatSchema(
+                    json_schema=json_schema_obj
+                )

-                logger.debug(f"Generated JSON schema: {response_format_obj.model_dump()}")
+                logger.debug(
+                    f"Generated JSON schema: {response_format_obj.model_dump()}"
+                )

                 # Use model_dump() to serialize the response format into a dictionary
-                params["response_format"] = response_format_obj.model_dump(by_alias=True)
+                params["response_format"] = response_format_obj.model_dump(
+                    by_alias=True
+                )

             except ValidationError as e:
                 logger.error(f"Validation error in JSON schema: {e}")
@@ -148,13 +165,15 @@ class StructureHandler:
             return params

         else:
-            raise ValueError(f"Unsupported structured_mode: {structured_mode}. Must be 'json' or 'function_call'.")
+            raise ValueError(
+                f"Unsupported structured_mode: {structured_mode}. Must be 'json' or 'function_call'."
+            )

     @staticmethod
     def create_iterable_model(
         model: Type[BaseModel],
         model_name: Optional[str] = None,
-        model_description: Optional[str] = None
+        model_description: Optional[str] = None,
     ) -> Type[BaseModel]:
         """
         Constructs an iterable Pydantic model for a given Pydantic model.
@@ -172,13 +191,11 @@ class StructureHandler:

         objects_field = (
             List[model],
-            Field(..., description=f"A list of `{model_name}` objects")
+            Field(..., description=f"A list of `{model_name}` objects"),
         )

         iterable_model = create_model(
-            iterable_model_name,
-            objects=objects_field,
-            __base__=(BaseModel,)
+            iterable_model_name, objects=objects_field, __base__=(BaseModel,)
         )

         iterable_model.__doc__ = (
@@ -193,7 +210,7 @@ class StructureHandler:
     def extract_structured_response(
         response: Any,
         llm_provider: str,
-        structured_mode: Literal["json", "function_call"] = "json"
+        structured_mode: Literal["json", "function_call"] = "json",
     ) -> Union[str, Dict[str, Any]]:
         """
         Extracts the structured JSON string or content from the response.
@@ -228,7 +245,9 @@ class StructureHandler:
                     function = getattr(tool_calls[0], "function", None)
                     if function and hasattr(function, "arguments"):
                         extracted_response = function.arguments
-                        logger.debug(f"Extracted function-call response: {extracted_response}")
+                        logger.debug(
+                            f"Extracted function-call response: {extracted_response}"
+                        )
                         return extracted_response
                 raise StructureError("No tool_calls found for function_call mode.")
@@ -237,7 +256,9 @@ class StructureHandler:
                 refusal = getattr(message, "refusal", None)

                 if refusal:
-                    logger.warning(f"Model refused to fulfill the request: {refusal}")
+                    logger.warning(
+                        f"Model refused to fulfill the request: {refusal}"
+                    )
                     raise StructureError(f"Request refused by the model: {refusal}")

                 if not content:
@@ -247,7 +268,9 @@ class StructureHandler:
                 return content

             else:
-                raise ValueError(f"Unsupported structured_mode: {structured_mode}. Must be 'json' or 'function_call'.")
+                raise ValueError(
+                    f"Unsupported structured_mode: {structured_mode}. Must be 'json' or 'function_call'."
+                )
         else:
             raise StructureError(f"Unsupported LLM provider: {llm_provider}")
     except Exception as e:
@@ -309,13 +332,18 @@ class StructureHandler:
            raise ValueError(f"Reference '{ref_name}' not found in $defs.")

            # Merge the referenced schema with the current part, resolving nested $refs
-           merged = {**defs_section[ref_name], **{k: v for k, v in part.items() if k != "$ref"}}
+           merged = {
+               **defs_section[ref_name],
+               **{k: v for k, v in part.items() if k != "$ref"},
+           }
            return StructureHandler.expand_local_refs(merged, root)

        # Process objects and their properties
        if part.get("type") == "object" and "properties" in part:
            for key, value in part["properties"].items():
-               part["properties"][key] = StructureHandler.expand_local_refs(value, root)
+               part["properties"][key] = StructureHandler.expand_local_refs(
+                   value, root
+               )

        # Process arrays and their items
        if part.get("type") == "array" and "items" in part:
@@ -324,7 +352,10 @@ class StructureHandler:
        # Process anyOf and allOf schemas
        for key in ("anyOf", "allOf"):
            if key in part and isinstance(part[key], list):
-               part[key] = [StructureHandler.expand_local_refs(subschema, root) for subschema in part[key]]
+               part[key] = [
+                   StructureHandler.expand_local_refs(subschema, root)
+                   for subschema in part[key]
+               ]

        return part
@@ -357,7 +388,9 @@ class StructureHandler:
        required_fields = set(schema.get("required", []))

        for key, value in schema.get("properties", {}).items():
-           schema["properties"][key] = StructureHandler.enforce_strict_json_schema(value)
+           schema["properties"][key] = StructureHandler.enforce_strict_json_schema(
+               value
+           )

            # Remove default values (not allowed by OpenAI)
            schema["properties"][key].pop("default", None)
@@ -366,22 +399,34 @@ class StructureHandler:
            if key not in required_fields:
                field_type = schema["properties"][key].get("type")

-               if field_type and not isinstance(field_type, list):  # Ensure it's not already `anyOf`
+               if field_type and not isinstance(
+                   field_type, list
+               ):  # Ensure it's not already `anyOf`
                    if field_type in ["string", "integer", "number"]:
-                       schema["properties"][key]["anyOf"] = [{"type": field_type}, {"type": "null"}]
-                       schema["properties"][key].pop("type", None)  # Remove direct "type" field
+                       schema["properties"][key]["anyOf"] = [
+                           {"type": field_type},
+                           {"type": "null"},
+                       ]
+                       schema["properties"][key].pop(
+                           "type", None
+                       )  # Remove direct "type" field

                # Ensure field is included in "required" (even if it allows null)
                required_fields.add(key)

            # Handle optional arrays inside object properties
-           if schema["properties"][key].get("anyOf") and isinstance(schema["properties"][key]["anyOf"], list):
+           if schema["properties"][key].get("anyOf") and isinstance(
+               schema["properties"][key]["anyOf"], list
+           ):
                for subschema in schema["properties"][key]["anyOf"]:
                    if subschema.get("type") == "array":
                        schema["properties"][key] = {
                            "anyOf": [
-                               {"type": "array", "items": subschema.get("items", {})},
-                               {"type": "null"}
+                               {
+                                   "type": "array",
+                                   "items": subschema.get("items", {}),
+                               },
+                               {"type": "null"},
                            ]
                        }
@@ -394,28 +439,39 @@ class StructureHandler:
        if "items" not in schema:
            raise ValueError(f"Array schema missing 'items': {schema}")

-       schema["items"] = StructureHandler.enforce_strict_json_schema(schema["items"])
+       schema["items"] = StructureHandler.enforce_strict_json_schema(
+           schema["items"]
+       )

        # Convert optional arrays from `anyOf` to `anyOf: [{"type": "array", "items": T}, {"type": "null"}]`
        if "anyOf" in schema and isinstance(schema["anyOf"], list):
-           if any(subschema.get("type") == "array" for subschema in schema["anyOf"]):
+           if any(
+               subschema.get("type") == "array" for subschema in schema["anyOf"]
+           ):
                schema["anyOf"] = [
                    {"type": "array", "items": schema["items"]},
-                   {"type": "null"}
+                   {"type": "null"},
                ]
                schema.pop("type", None)  # Remove direct "type" field
-               schema.pop("minItems", None)  # Remove `minItems`, not needed with null
+               schema.pop(
+                   "minItems", None
+               )  # Remove `minItems`, not needed with null

        # Process $defs and remove after expansion
        if "$defs" in schema:
            for def_name, def_schema in schema["$defs"].items():
-               schema["$defs"][def_name] = StructureHandler.enforce_strict_json_schema(def_schema)
+               schema["$defs"][def_name] = StructureHandler.enforce_strict_json_schema(
+                   def_schema
+               )
            schema.pop("$defs", None)

        # Process anyOf and allOf schemas recursively
        for key in ("anyOf", "allOf"):
            if key in schema and isinstance(schema[key], list):
-               schema[key] = [StructureHandler.enforce_strict_json_schema(subschema) for subschema in schema[key]]
+               schema[key] = [
+                   StructureHandler.enforce_strict_json_schema(subschema)
+                   for subschema in schema[key]
+               ]

        return schema
@@ -453,7 +509,9 @@ class StructureHandler:
            except TypeError:
                pass
            else:
-               logger.debug(f"[resolve] Skipping non-class inner: {inner} ({type(inner)})")
+               logger.debug(
+                   f"[resolve] Skipping non-class inner: {inner} ({type(inner)})"
+               )

        if origin is Union:
            for arg in args:
@@ -492,7 +550,9 @@ class StructureHandler:
        elif len(models) == 0:
            return None  # No model = primitive or unsupported type → silently skip
        else:
-           logger.warning(f"Ambiguous model resolution: found multiple models in {tp}. Returning None.")
+           logger.warning(
+               f"Ambiguous model resolution: found multiple models in {tp}. Returning None."
+           )
            return None

     @staticmethod
@@ -523,7 +583,10 @@ class StructureHandler:
        for model_cls in models:
            try:
                if isinstance(result, list):
-                   return [StructureHandler.validate_response(item, model_cls).model_dump() for item in result]
+                   return [
+                       StructureHandler.validate_response(item, model_cls).model_dump()
+                       for item in result
+                   ]
                else:
                    validated = StructureHandler.validate_response(result, model_cls)
                    return validated.model_dump()
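create_iterable_model(), re-wrapped above, builds a wrapper model whose objects field is a List[model]; that single schema is what lets the stream handler validate "many of X" responses. A hedged reconstruction using plain pydantic v2, with Person as a stand-in inner model:

from typing import List
from pydantic import BaseModel, Field, create_model

class Person(BaseModel):
    name: str
    age: int

# Mirrors the objects_field tuple and create_model() call in the hunk above.
IterablePerson = create_model(
    "IterablePerson",
    objects=(List[Person], Field(..., description="A list of `Person` objects")),
    __base__=(BaseModel,),
)

batch = IterablePerson.model_validate(
    {"objects": [{"name": "Ada", "age": 36}, {"name": "Alan", "age": 41}]}
)
print([p.name for p in batch.objects])  # ['Ada', 'Alan']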

@@ -3,6 +3,7 @@ from pydantic import BaseModel, ConfigDict
 from abc import ABC, abstractmethod
 from typing import List

+
 class MemoryBase(BaseModel, ABC):
     """
     Abstract base class for managing message memory. This class defines a standard interface for memory operations,
@@ -38,7 +39,9 @@ class MemoryBase(BaseModel, ABC):
         pass

     @abstractmethod
-    def add_interaction(self, user_message: BaseMessage, assistant_message: BaseMessage):
+    def add_interaction(
+        self, user_message: BaseMessage, assistant_message: BaseMessage
+    ):
         """
         Adds a user-assistant interaction to the memory storage.

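Finally, the MemoryBase hunk only fixes formatting, but a concrete implementation of its interface is short. Everything below beyond add_interaction's two-message signature is illustrative (the field and helper names are assumptions, and dicts stand in for BaseMessage):

from typing import Any, Dict, List
from pydantic import BaseModel, Field

class ListMemorySketch(BaseModel):
    # A plain in-process list as the backing store.
    messages: List[Dict[str, Any]] = Field(default_factory=list)

    def add_interaction(self, user_message, assistant_message) -> None:
        # Store the user/assistant pair in conversation order.
        self.messages.extend([user_message, assistant_message])

    def get_messages(self) -> List[Dict[str, Any]]:
        return list(self.messages)

    def reset_memory(self) -> None:
        self.messages.clear()

memory = ListMemorySketch()
memory.add_interaction(
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
)
print(memory.get_messages())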