Compare commits


27 Commits
v0.5.0 ... main

Author SHA1 Message Date
Filinto Duran c2eff2b971
use glob pattern for docs (#134)
* update, remove using library for mkdocs, use glob properly, do not run deploy on PR

Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>

* simplify if

Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>

---------

Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>
2025-06-12 13:45:31 -07:00
Sam c4b1f7c441
fix: some random quickstart issues (#127)
* fix(quickstarts): general fixes for DX on quickstarts

Signed-off-by: Samantha Coyle <sam@diagrid.io>

* style: clean this up for now

Signed-off-by: Samantha Coyle <sam@diagrid.io>

---------

Signed-off-by: Samantha Coyle <sam@diagrid.io>
2025-06-11 11:30:46 -07:00
Marc Duiker f87e27f450
Fix docs_any_changed (#123)
Signed-off-by: Marc Duiker <marcduiker@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
2025-05-29 13:22:51 -07:00
Bilgin Ibryam 1e5275834d
Dapr university preparation (#122)
* Created an example with dapr-based conversation history agent

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>

* Added a new example with AssistantAgent

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>

* Fix for linting errors

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>

---------

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>
Co-authored-by: Marc Duiker <marcduiker@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
2025-05-29 13:19:28 -07:00
Bilgin Ibryam b7b4a9891e
Reorganize MCP quickstart examples and add SSE implementation (#120)
Updated references to align with namechange



Fix code formatting with ruff

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
2025-05-29 13:16:16 -07:00
Yaron Schneider 2fd44b3ecc
fix code of conduct (#125)
Signed-off-by: yaron2 <schneider.yaron@live.com>
2025-05-23 12:35:31 -07:00
Yaron Schneider 2757aab5b6
Revert "Reorganize quickstart examples and added MCP implementation" (#124) 2025-05-23 08:13:19 -07:00
Marc Duiker d86a4c5a70
Merge pull request #112 from dapr/mcp-quickstart
Reorganize quickstart examples and added MCP implementation
2025-05-23 16:25:23 +02:00
Bilgin Ibryam 83fc449e39 Fix code formatting with ruff 2025-05-20 16:55:20 +01:00
Yaron Schneider 94bf5d2a38
Merge branch 'main' into mcp-quickstart 2025-05-12 08:45:41 -07:00
Bilgin Ibryam 8741289e7d
Added brief MCP support explanation (#116)
* Added brief MCP support explanation

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>

* Update docs/concepts/agents.md

Co-authored-by: Casper Nielsen <scni@novonordisk.com>
Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>

---------

Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>
Co-authored-by: Casper Nielsen <scni@novonordisk.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
2025-05-12 08:44:19 -07:00
Casper Nielsen 41faa4f5b7
Chore: Bump mcp == 1.7.1 (#114)
Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>
2025-05-08 15:27:21 -07:00
Bilgin Ibryam 76ad962b69
Updated references to align with namechange
Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>
2025-05-03 10:02:00 +01:00
Bilgin Ibryam 28ac198055
Reorganize MCP quickstart examples and add SSE implementation
Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>
2025-05-02 21:53:57 +01:00
Casper Nielsen e27f5befb0
Fix/81 dapr http endpoint (#107)
* Fix: #81 by implementing custom tool for wrapping HTTP calling

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Ensure _base_url is passed if FQDN

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Explicit exclude the other tools from type checking so we can check http tool

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct arg-type for url

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ruff formatting

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Update deps

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Feat: First draft impl. of OTel

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F811

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: import

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct import and pass var

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Instantiation instead of model

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Circular ref

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Remove specific tracer as lib should pick it up & add logger

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Attempt to aquire tracer and use logger

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Set the logger

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Allow passing custom endpoint per provider

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Set tmp correct logger port

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct sending json encoded

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Ensure requests client always run http

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct tmp port for logger

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Feat: Ensure /v1/[traces|metrics|logs] always in otlp_endpoint

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Don't capitalize

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Ensure we check for v1 and set if not

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ensure http always

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Transition otel to tools.utils

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct import

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Feat: Clean init of provider with validator func

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: conform to new validator

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Remove not needed endpoint pass to logger client

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F821

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Remove redundant check

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Improve validator func for less code

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Remove unused import

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Feat: Refine checks and ensure resiliency on url creation

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Add parsing of DAPR_AGENTS_OTEL_ENABLED to disable OTel from client

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Convert to reusable function passing the http verb

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct string to upper rather than capitalize

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Bump python version on build agent as 3.9 don't contain switch statement

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Pass version as str

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Move otel into dapr_agents.agent.telemetry

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Include errors from dapr_agents.agent.telemetry

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct import of DaprAgentsOTel

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Include type-check on http

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Proper name of DaprAgentsOTel

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct imports

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Raising ToolError results in workflow breaking rather than reiterating the tool request

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: .json() not guaranteed to hold a value

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Remove unused import

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Decorate the tool

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Feat: Better naming convention for docstring

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Feat: Export tool decoration for consumption

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Revert

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Import & formatting

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

---------

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
2025-05-02 09:04:29 -07:00
Bilgin Ibryam 889b7bf7ef
Fixed docs, and added build and social media links to the readme (#110)
Signed-off-by: Bilgin Ibryam <bibryam@gmail.com>
2025-05-01 09:08:45 -07:00
Yaron Schneider 4dce1c0300
Add quickstart for a knowledge base agent over Postgres + MCP + Chainlit (#103)
* initial commit

Signed-off-by: yaron2 <schneider.yaron@live.com>

* add quickstart for a postgres agent with mcp

Signed-off-by: yaron2 <schneider.yaron@live.com>

* linter

Signed-off-by: yaron2 <schneider.yaron@live.com>

* linter

Signed-off-by: yaron2 <schneider.yaron@live.com>

* review feedback

Signed-off-by: yaron2 <schneider.yaron@live.com>

* changed docker instructions

Signed-off-by: yaron2 <schneider.yaron@live.com>

* Update README.md

Signed-off-by: Yaron Schneider <schneider.yaron@live.com>

---------

Signed-off-by: yaron2 <schneider.yaron@live.com>
Signed-off-by: Yaron Schneider <schneider.yaron@live.com>
2025-04-30 08:14:43 -07:00
Casper Nielsen 53c1c9ffde
Fix: ref to 07 in compose file (#106)
Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>
2025-04-28 15:49:54 -07:00
Yaron Schneider 6f20c0d9a0
Update README.md (#102)
Signed-off-by: Yaron Schneider <schneider.yaron@live.com>
2025-04-26 06:45:47 -07:00
Casper Nielsen 6823cd633d
Feat/k8s deployment (#69)
* Updating requirements

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add docker compose file for image building

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add manifests for k8s deployment

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add dockerfiles - @jgmartinez

Co-authored-by: Juan González <38658722+jgmartinez@users.noreply.github.com>
Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add client for triggering

Co-authored-by: Juan González <38658722+jgmartinez@users.noreply.github.com>
Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add install script and readme

Co-authored-by: Juan González <38658722+jgmartinez@users.noreply.github.com>
Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add prerequisites for kind, docker and helm

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Add step for building

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* chore: Update to latest version of dapr_agents

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Remove f string where not needed

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: ruff formatting

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F841

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Rename quickstarter to 07-

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

---------

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>
Co-authored-by: Juan González <38658722+jgmartinez@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
2025-04-25 22:09:17 -07:00
Yaron Schneider a878e76ec1
Fix grpc connection lifecycle with daprclient (#101)
* fix grpc connection lifecycle with daprclient

Signed-off-by: yaron2 <schneider.yaron@live.com>

* fix linter

Signed-off-by: yaron2 <schneider.yaron@live.com>

---------

Signed-off-by: yaron2 <schneider.yaron@live.com>
2025-04-24 13:31:30 -07:00
Roberto Rodriguez 75274ac607
Adding DaprWorkflowContext from dapr.ext.workflow (#99) 2025-04-24 04:05:43 -07:00
Casper Nielsen f129754486
Fix/30 add linter action (#95)
* Fix: Fix Setup lint GitHub action #30

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Remove branch filter on PR and remove on push

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Remove on mergequeue

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Add tox.ini file

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Return on push

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: tox -e ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Ignore .ruff_cache

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Update tox file

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Add mypy.ini

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Ignore if line is too long

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Set the ignore in command instead

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: W503

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: 541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: W503

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ignore F401, unused imports as __init__ files has them

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Return linebreak as tox -e ruff yields that

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Ignore W503 as ruff introduces it

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F841

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: E203

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: W293

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: W291

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: E203

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: E203

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: W291

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F811

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F841

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F811

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F841

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F811

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: W291

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F811

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Ruff want's the space before :

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Ignore space before :

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: E291

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Add dev-requirements.txt

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Correct python version

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Ref dev-requirements.txt

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Add mypy cache dir

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Update mypy version

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Fix: Exclude cookbook and quicstarts

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Remove unused import

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Add specific sub module ignore on error for future smaller fixing

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Reintroduce branches filter on push and pull_request

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* chore: Ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: ruff formatting

* Chore: F541

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: E401

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: F811

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: F841

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: Ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: E711

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

* Chore: ruff

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>

---------

Signed-off-by: Casper Guldbech Nielsen <scni@novonordisk.com>
2025-04-23 22:58:48 -07:00
Roberto Rodriguez c31e985d81
Merge pull request #97 from dapr/cyb3rward0g/update-local-executor
Executors: Sandbox support + per-project bootstrap + full refactor
2025-04-23 23:13:52 -04:00
Roberto Rodriguez f9eb48c02c added cookbook to show example 2025-04-23 18:34:21 -04:00
Roberto Rodriguez 6f0cfc8818 LocalExecutor sandbox support + per-project bootstrap + full refactor 2025-04-23 18:34:02 -04:00
Yaron Schneider fd28b02935
Add document agent+chainlit quickstart (#96)
* add document agent+chainlit quickstart

Signed-off-by: yaron2 <schneider.yaron@live.com>

* add upload response

Signed-off-by: yaron2 <schneider.yaron@live.com>

---------

Signed-off-by: yaron2 <schneider.yaron@live.com>
2025-04-22 21:41:11 -07:00
322 changed files with 10421 additions and 3197 deletions

.github/workflows/build.yaml (vendored, new file, 65 lines)

@@ -0,0 +1,65 @@
name: Lint and Build
on:
  push:
    branches:
      - feature/*
      - feat/*
      - bugfix/*
      - hotfix/*
      - fix/*
  pull_request:
    branches:
      - main
      - feature/*
      - release-*
  workflow_dispatch:
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel tox
      - name: Run Autoformatter
        run: |
          tox -e ruff
          statusResult=$(git status -u --porcelain)
          if [ -z $statusResult ]
          then
            exit 0
          else
            echo "Source files are not formatted correctly. Run 'tox -e ruff' to autoformat."
            exit 1
          fi
      - name: Run Linter
        run: |
          tox -e flake8
  build:
    needs: lint
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python_ver: ["3.10", "3.11", "3.12", "3.13"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python_ver }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_ver }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel tox
      - name: Check Typing
        run: |
          tox -e type
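The autoformat gate above fails the job whenever `git status -u --porcelain` prints anything after `tox -e ruff` has run, i.e. whenever the formatter rewrote a file. A minimal Python sketch of the same clean-tree check (the function name is illustrative, not part of the workflow; note that the workflow's unquoted `[ -z $statusResult ]` can mis-parse multi-file output, which comparing the whole string avoids):

```python
import subprocess

def working_tree_clean() -> bool:
    """Return True when `git status -u --porcelain` prints nothing,
    i.e. the autoformatter rewrote no tracked or untracked files."""
    out = subprocess.run(
        ["git", "status", "-u", "--porcelain"],
        capture_output=True, text=True, check=True,
    ).stdout
    return out.strip() == ""
```

Run inside a checkout after `tox -e ruff`; a non-empty porcelain listing means some files still need reformatting.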


@@ -4,12 +4,12 @@ on:
     branches:
       - main
     paths:
-      - docs
+      - docs/**
   pull_request:
     branches:
       - main
     paths:
-      - docs
+      - docs/**
   workflow_dispatch:
 permissions:
   contents: write
@@ -18,7 +18,7 @@ jobs:
     runs-on: ubuntu-latest
     name: Review changed files
     outputs:
-      docs_any_changed: NaN
+      docs_any_changed: ${{ steps.changed-files.outputs.docs_any_changed }}
     steps:
       - uses: actions/checkout@v4
       - name: Get changed files
@@ -42,10 +42,16 @@ jobs:
       - name: Remove plugins from mkdocs configuration
         run: |
           sed -i '/^plugins:/,/^[^ ]/d' mkdocs.yml
-      - name: Run MkDocs build
-        uses: Kjuly/mkdocs-page-builder@main
+      - name: Install Python dependencies
+        run: |
+          pip install mkdocs-material
+          pip install .[recommended,git,imaging]
+          pip install mkdocs-jupyter
+      - name: Validate build
+        run: mkdocs build
   deploy:
+    if: github.ref == 'refs/heads/main'
     runs-on: ubuntu-latest
     needs: documentation_validation
     steps:
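The `sed -i '/^plugins:/,/^[^ ]/d' mkdocs.yml` step deletes everything from the `plugins:` line through the next non-indented line. A small sketch of that range deletion in Python, against a hypothetical `mkdocs.yml` fragment (note that sed's address range is inclusive, so the boundary line, here `nav:`, is deleted along with the block):

```python
mkdocs = """site_name: Dapr Agents
plugins:
  - search
  - mkdocstrings
nav:
  - Home: index.md
"""

# Mimic `sed '/^plugins:/,/^[^ ]/d'`: drop lines from "plugins:" through
# the next line starting with a non-space character (inclusive).
out, skipping = [], False
for line in mkdocs.splitlines():
    if not skipping and line.startswith("plugins:"):
        skipping = True
        continue
    if skipping:
        if line[:1] not in ("", " "):
            skipping = False  # boundary line: deleted together with the range
        continue
    out.append(line)

# out == ['site_name: Dapr Agents', '  - Home: index.md']
```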
@ -53,7 +59,7 @@ jobs:
- uses: actions/setup-python@v5 - uses: actions/setup-python@v5
with: with:
python-version: 3.x python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4 - uses: actions/cache@v4
with: with:
key: mkdocs-material-${{ env.cache_id }} key: mkdocs-material-${{ env.cache_id }}
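The `docs` → `docs/**` change in the `paths:` filters above is what the "use glob pattern for docs" commit (#134) in the list refers to: a bare `docs` pattern matches only a path literally named `docs`, so changes to files under `docs/` never triggered the workflow. Python's `fnmatch` gives a rough approximation of the behavior (GitHub's glob syntax differs in details, so treat this as an illustration only):

```python
from fnmatch import fnmatch

changed = ["docs/concepts/agents.md", "mkdocs.yml", "docs"]

# A bare "docs" pattern matches only the literal path "docs".
only_literal = [p for p in changed if fnmatch(p, "docs")]

# "docs/**" matches any file under the docs/ tree.
under_docs = [p for p in changed if fnmatch(p, "docs/**")]

print(only_literal)  # ['docs']
print(under_docs)    # ['docs/concepts/agents.md']
```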

.gitignore (vendored, 4 lines changed)

@@ -165,3 +165,7 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 .idea
+.ruff_cache/
+quickstarts/05-multi-agent-workflow-dapr-workflows/services/**/*_state.json

(file removed)

@@ -1,86 +0,0 @@
# Code of Conduct
We are committed to fostering a welcoming, inclusive, and respectful environment for everyone involved in this project. This Code of Conduct outlines the expected behaviors within our community and the steps for reporting unacceptable actions. By participating, you agree to uphold these standards, helping to create a positive and collaborative space.
---
## Our Pledge
As members, contributors, and leaders of this community, we pledge to:
* Ensure participation in our project is free from harassment, discrimination, or exclusion.
* Treat everyone with respect and empathy, regardless of factors such as age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity or expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual orientation.
* Act in ways that contribute to a safe, welcoming, and supportive environment for all participants.
---
## Our Standards
We strive to create an environment where all members can thrive. Examples of positive behaviors include:
* Showing kindness, empathy, and consideration for others.
* Being respectful of differing opinions, experiences, and perspectives.
* Providing constructive feedback in a supportive manner.
* Taking responsibility for mistakes, apologizing when necessary, and learning from experiences.
* Prioritizing the success and well-being of the entire community over individual gains.
The following behaviors are considered unacceptable:
* Using sexualized language or imagery, or engaging in inappropriate sexual attention or advances.
* Making insulting, derogatory, or inflammatory comments, including trolling or personal attacks.
* Engaging in harassment, whether public or private.
* Publishing private or sensitive information about others without explicit consent.
* Engaging in behavior that disrupts discussions, events, or contributions in a negative way.
* Any conduct that could reasonably be deemed unprofessional or harmful to others.
---
## Scope
This Code of Conduct applies to all areas of interaction within the community, including but not limited to:
* Discussions on forums, repositories, or other official communication channels.
* Contributions made to the project, such as code, documentation, or issues.
* Public representation of the community, such as through official social media accounts or at events.
It also applies to actions outside these spaces if they negatively impact the health, safety, or inclusivity of the community.
---
## Enforcement Responsibilities
Community leaders are responsible for ensuring that this Code of Conduct is upheld. They may take appropriate and fair corrective actions in response to any behavior that violates these standards, including:
* Removing, editing, or rejecting comments, commits, issues, or other contributions not aligned with the Code of Conduct.
* Temporarily or permanently banning individuals for repeated or severe violations.
Leaders will always strive to communicate their decisions clearly and fairly.
---
## Reporting Issues
If you experience or witness unacceptable behavior, please report it to the project's owner [Roberto Rodriguez](https://www.linkedin.com/in/cyb3rward0g/). Your report will be handled with sensitivity, and we will respect your privacy and confidentiality while addressing the issue.
When reporting, please include:
* A description of the incident.
* When and where it occurred.
* Any additional context or supporting evidence, if available.
---
## Enforcement Process
We encourage resolving issues through dialogue when possible, but community leaders will intervene when necessary. Actions may include warnings, temporary bans, or permanent removal from the community, depending on the severity of the behavior.
---
## Attribution
This Code of Conduct is inspired by the [Contributor Covenant, version 2.0](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html) and has drawn inspiration from open source community guidelines by Microsoft, Mozilla, and others.
For further context on best practices for open source codes of conduct, see the [Contributor Covenant FAQ](https://www.contributor-covenant.org/faq).
---
Thank you for helping to create a positive environment! ❤️

@@ -1,5 +1,13 @@
 # Dapr Agents: A Framework for Agentic AI Systems
+[![PyPI - Version](https://img.shields.io/pypi/v/dapr-agents?style=flat&logo=pypi&logoColor=white&label=Latest%20version)](https://pypi.org/project/dapr-agents/)
+[![PyPI - Downloads](https://img.shields.io/pypi/dm/dapr-agents?style=flat&logo=pypi&logoColor=white&label=Downloads)](https://pypi.org/project/dapr-agents/)
+[![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/dapr/dapr-agents/.github%2Fworkflows%2Fbuild.yaml?branch=main&label=Build&logo=github)](https://github.com/dapr/dapr-agents/actions/workflows/build.yaml)
+[![GitHub License](https://img.shields.io/github/license/dapr/dapr-agents?style=flat&label=License&logo=github)](https://github.com/dapr/dapr-agents/blob/main/LICENSE)
+[![Discord](https://img.shields.io/discord/778680217417809931?label=Discord&style=flat&logo=discord)](http://bit.ly/dapr-discord)
+[![YouTube Channel Views](https://img.shields.io/youtube/channel/views/UCtpSQ9BLB_3EXdWAUQYwnRA?style=flat&label=YouTube%20views&logo=youtube)](https://youtube.com/@daprdev)
+[![X (formerly Twitter) Follow](https://img.shields.io/twitter/follow/daprdev?logo=x&style=flat)](https://twitter.com/daprdev)
 Dapr Agents is a developer framework designed to build production-grade resilient AI agent systems that operate at scale. Built on top of the battle-tested Dapr project, it enables software developers to create AI agents that reason, act, and collaborate using Large Language Models (LLMs), while leveraging built-in observability and stateful workflow execution to guarantee agentic workflows complete successfully, no matter how complex.
 ![](./docs/img/logo-workflows.png)
@@ -60,8 +68,8 @@ As a part of **CNCF**, Dapr Agents is vendor-neutral, eliminating concerns about
 Here are some of the major features we're working on for the current quarter:
-### Q2 2024
+### Q2 2025
 - **MCP Support** - Integration with Anthropic's MCP platform ([#50](https://github.com/dapr/dapr-agents/issues/50))
 - **Agent Interaction Tracing** - Enhanced observability of agent interactions with LLMs and tools ([#79](https://github.com/dapr/dapr-agents/issues/79))
 - **Streaming LLM Output** - Real-time streaming capabilities for LLM responses ([#80](https://github.com/dapr/dapr-agents/issues/80))
 - **HTTP Endpoint Tools** - Support for using Dapr's HTTP endpoint capabilities for tool calling ([#81](https://github.com/dapr/dapr-agents/issues/81))

View File

@@ -4,9 +4,10 @@ from datetime import datetime
 import requests
 import time


 class WeatherForecast(AgentTool):
-    name: str = 'WeatherForecast'
-    description: str = 'A tool for retrieving the weather/temperature for a given city.'
+    name: str = "WeatherForecast"
+    description: str = "A tool for retrieving the weather/temperature for a given city."

     # Default user agent
     user_agent: str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15"
@@ -23,24 +24,26 @@ class WeatherForecast(AgentTool):
                 f"No data found during {stage}. URL: {url}. Response: {response.text}"
             )

-    def _run(self, city: str, state: Optional[str] = None, country: Optional[str] = "usa") -> dict:
+    def _run(
+        self, city: str, state: Optional[str] = None, country: Optional[str] = "usa"
+    ) -> dict:
         """
         Retrieves weather data by first fetching geocode data for the city and then fetching weather data.

         Args:
             city (str): The name of the city to get weather for.
             state (Optional[str]): The two-letter state abbreviation (optional).
             country (Optional[str]): The two-letter country abbreviation. Defaults to 'usa'.

         Returns:
             dict: A dictionary containing the city, state, country, and current temperature.
         """
-        headers = {
-            "User-Agent": self.user_agent
-        }
+        headers = {"User-Agent": self.user_agent}

         # Construct the geocode URL, conditionally including the state if it's provided
-        geocode_url = f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        geocode_url = (
+            f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        )
         if state:
             geocode_url += f"&state={state}"
         geocode_url += "&limit=1&format=jsonv2"
@@ -64,7 +67,7 @@ class WeatherForecast(AgentTool):
         # Add delay between requests
         time.sleep(2)

         weather_data = weather_response.json()
         forecast_url = weather_data["properties"]["forecast"]
@@ -81,7 +84,7 @@ class WeatherForecast(AgentTool):
                 "state": state,
                 "country": country,
                 "temperature": today_forecast["temperature"],
-                "unit": "Fahrenheit"
+                "unit": "Fahrenheit",
             }
         else:
@@ -91,8 +94,12 @@ class WeatherForecast(AgentTool):
             self.handle_error(weather_response, met_no_url, "Met.no weather lookup")
             weather_data = weather_response.json()
-            temperature_unit = weather_data["properties"]["meta"]["units"]["air_temperature"]
-            today_forecast = weather_data["properties"]["timeseries"][0]["data"]["instant"]["details"]["air_temperature"]
+            temperature_unit = weather_data["properties"]["meta"]["units"][
+                "air_temperature"
+            ]
+            today_forecast = weather_data["properties"]["timeseries"][0]["data"][
+                "instant"
+            ]["details"]["air_temperature"]

         # Return the weather data along with the city, state, and country
         return {
@@ -100,12 +107,15 @@ class WeatherForecast(AgentTool):
             "state": state,
             "country": country,
             "temperature": today_forecast,
-            "unit": temperature_unit
+            "unit": temperature_unit,
         }


 class HistoricalWeather(AgentTool):
-    name: str = 'HistoricalWeather'
-    description: str = 'A tool for retrieving historical weather data (temperature) for a given city.'
+    name: str = "HistoricalWeather"
+    description: str = (
+        "A tool for retrieving historical weather data (temperature) for a given city."
+    )

     # Default user agent
     user_agent: str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15"
@@ -122,34 +132,48 @@ class HistoricalWeather(AgentTool):
                 f"No data found during {stage}. URL: {url}. Response: {response.text}"
             )

-    def _run(self, city: str, state: Optional[str] = None, country: Optional[str] = "usa", start_date: Optional[str] = None, end_date: Optional[str] = None) -> dict:
+    def _run(
+        self,
+        city: str,
+        state: Optional[str] = None,
+        country: Optional[str] = "usa",
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+    ) -> dict:
         """
         Retrieves historical weather data for the city by first fetching geocode data and then historical weather data.

         Args:
             city (str): The name of the city to get weather for.
             state (Optional[str]): The two-letter state abbreviation (optional).
             country (Optional[str]): The two-letter country abbreviation. Defaults to 'usa'.
             start_date (Optional[str]): Start date for historical data (YYYY-MM-DD format).
             end_date (Optional[str]): End date for historical data (YYYY-MM-DD format).

         Returns:
             dict: A dictionary containing the city, state, country, and historical temperature data.
         """
-        headers = {
-            "User-Agent": self.user_agent
-        }
+        headers = {"User-Agent": self.user_agent}

         # Validate dates
-        current_date = datetime.now().strftime('%Y-%m-%d')
+        current_date = datetime.now().strftime("%Y-%m-%d")
         if start_date >= current_date or end_date >= current_date:
-            raise ValueError("Both start_date and end_date must be earlier than the current date.")
-        if (datetime.strptime(end_date, "%Y-%m-%d") - datetime.strptime(start_date, "%Y-%m-%d")).days > 30:
-            raise ValueError("The time span between start_date and end_date cannot exceed 30 days.")
+            raise ValueError(
+                "Both start_date and end_date must be earlier than the current date."
+            )
+        if (
+            datetime.strptime(end_date, "%Y-%m-%d")
+            - datetime.strptime(start_date, "%Y-%m-%d")
+        ).days > 30:
+            raise ValueError(
+                "The time span between start_date and end_date cannot exceed 30 days."
+            )

         # Construct the geocode URL, conditionally including the state if it's provided
-        geocode_url = f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        geocode_url = (
+            f"https://nominatim.openstreetmap.org/search?city={city}&country={country}"
+        )
         if state:
             geocode_url += f"&state={state}"
         geocode_url += "&limit=1&format=jsonv2"
@@ -167,7 +191,9 @@ class HistoricalWeather(AgentTool):
         # Historical weather request
         historical_weather_url = f"https://archive-api.open-meteo.com/v1/archive?latitude={lat}&longitude={lon}&start_date={start_date}&end_date={end_date}&hourly=temperature_2m"
         weather_response = requests.get(historical_weather_url, headers=headers)
-        self.handle_error(weather_response, historical_weather_url, "historical weather lookup")
+        self.handle_error(
+            weather_response, historical_weather_url, "historical weather lookup"
+        )

         weather_data = weather_response.json()
@@ -177,7 +203,9 @@ class HistoricalWeather(AgentTool):
         temperature_unit = weather_data["hourly_units"]["temperature_2m"]

         # Combine timestamps and temperatures into a dictionary
-        temperature_data = {timestamps[i]: temperatures[i] for i in range(len(timestamps))}
+        temperature_data = {
+            timestamps[i]: temperatures[i] for i in range(len(timestamps))
+        }

         # Return the structured weather data along with the city, state, country
         return {
@@ -187,5 +215,5 @@ class HistoricalWeather(AgentTool):
             "start_date": start_date,
             "end_date": end_date,
             "temperature_data": temperature_data,
-            "unit": temperature_unit
+            "unit": temperature_unit,
         }
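
The date validation in `HistoricalWeather._run` is pure logic that can be exercised without any HTTP calls. A minimal sketch, assuming a standalone helper (`validate_date_range` is a name introduced here for illustration, not part of the tool); note that the string comparison against `current_date` works because `YYYY-MM-DD` dates sort lexicographically:

```python
from datetime import datetime


def validate_date_range(start_date: str, end_date: str) -> None:
    """Same checks as HistoricalWeather._run, extracted for illustration."""
    # ISO dates compare correctly as strings: "2020-01-01" < "2021-12-31"
    current_date = datetime.now().strftime("%Y-%m-%d")
    if start_date >= current_date or end_date >= current_date:
        raise ValueError(
            "Both start_date and end_date must be earlier than the current date."
        )
    if (
        datetime.strptime(end_date, "%Y-%m-%d")
        - datetime.strptime(start_date, "%Y-%m-%d")
    ).days > 30:
        raise ValueError(
            "The time span between start_date and end_date cannot exceed 30 days."
        )


validate_date_range("2020-01-01", "2020-01-31")  # exactly 30 days: accepted
```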

View File

@@ -0,0 +1,501 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "39c2dcc0",
"metadata": {},
"source": [
    "# Executor: LocalCodeExecutor Basic Examples\n",
"\n",
    "This notebook shows how to execute Python and shell snippets in **isolated, cached virtual environments**."
]
},
{
"cell_type": "markdown",
"id": "c4ff4b2b",
"metadata": {},
"source": [
"## Install Required Libraries\n",
"Before starting, ensure the required libraries are installed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b41a66a",
"metadata": {},
"outputs": [],
"source": [
"!pip install dapr-agents"
]
},
{
"cell_type": "markdown",
"id": "a9c01be3",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "508fd446",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"from dapr_agents.executors.local import LocalCodeExecutor\n",
"from dapr_agents.types.executor import CodeSnippet, ExecutionRequest\n",
"from rich.console import Console\n",
"from rich.ansi import AnsiDecoder\n",
"import shutil"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "27594072",
"metadata": {},
"outputs": [],
"source": [
"logging.basicConfig(level=logging.INFO)\n",
"\n",
"executor = LocalCodeExecutor()\n",
"console = Console()\n",
"decoder = AnsiDecoder()"
]
},
{
"cell_type": "markdown",
"id": "4d663475",
"metadata": {},
"source": [
"## Running a basic Python Code Snippet"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ba45ddc8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt\n",
"INFO:dapr_agents.executors.local:Created a new virtual environment\n",
"INFO:dapr_agents.executors.local:Installing print, rich\n",
"INFO:dapr_agents.executors.local:Snippet 1 finished in 2.442s\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">Hello executor!</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;32mHello executor!\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"code = \"\"\"\n",
"from rich import print\n",
"print(\"[bold green]Hello executor![/bold green]\")\n",
"\"\"\"\n",
"\n",
"request = ExecutionRequest(snippets=[\n",
" CodeSnippet(language='python', code=code, timeout=10)\n",
"])\n",
"\n",
"results = await executor.execute(request)\n",
"results[0] # raw result\n",
"\n",
    "# pretty-print with Rich\n",
"console.print(*decoder.decode(results[0].output))"
]
},
{
"cell_type": "markdown",
"id": "d28c7531",
"metadata": {},
"source": [
    "## Run a Shell Snippet"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4ea89b85",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt\n",
"INFO:dapr_agents.executors.local:Snippet 1 finished in 0.019s\n"
]
},
{
"data": {
"text/plain": [
"[ExecutionResult(status='success', output='4\\n', exit_code=0)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"shell_request = ExecutionRequest(snippets=[\n",
" CodeSnippet(language='sh', code='echo $((2+2))', timeout=5)\n",
"])\n",
"\n",
"await executor.execute(shell_request)"
]
},
{
"cell_type": "markdown",
"id": "da281b6e",
"metadata": {},
"source": [
"## Reuse the cached virtual environment"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3e9e7e9b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt\n",
"INFO:dapr_agents.executors.local:Reusing cached virtual environment.\n",
"INFO:dapr_agents.executors.local:Installing print, rich\n",
"INFO:dapr_agents.executors.local:Snippet 1 finished in 0.297s\n"
]
},
{
"data": {
"text/plain": [
"[ExecutionResult(status='success', output='\\x1b[1;32mHello executor!\\x1b[0m\\n', exit_code=0)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Rerunning the same Python request will reuse the cached venv, so it is faster\n",
"await executor.execute(request)"
]
},
{
"cell_type": "markdown",
"id": "14dc3e4c",
"metadata": {},
"source": [
"## Inject Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "82f9a168",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt\n",
"INFO:dapr_agents.executors.local:Created a new virtual environment\n",
"INFO:dapr_agents.executors.local:Snippet 1 finished in 1.408s\n"
]
},
{
"data": {
"text/plain": [
"[ExecutionResult(status='success', output='42\\n', exit_code=0)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fancy_sum(a: int, b: int) -> int:\n",
" return a + b\n",
"\n",
"executor.user_functions.append(fancy_sum)\n",
"\n",
"helper_request = ExecutionRequest(snippets=[\n",
" CodeSnippet(language='python', code='print(fancy_sum(40, 2))', timeout=5)\n",
"])\n",
"\n",
"await executor.execute(helper_request)"
]
},
{
"cell_type": "markdown",
"id": "25f9718c",
"metadata": {},
"source": [
"## Clean Up"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b09059f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cache directory removed ✅\n"
]
}
],
"source": [
"shutil.rmtree(executor.cache_dir, ignore_errors=True)\n",
"print(\"Cache directory removed ✅\")"
]
},
{
"cell_type": "markdown",
"id": "2c93cdef",
"metadata": {},
"source": [
"## Package-manager detection & automatic bootstrap"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8691f3e3",
"metadata": {},
"outputs": [],
"source": [
"from dapr_agents.executors.utils import package_manager as pm\n",
"import pathlib, tempfile"
]
},
{
"cell_type": "markdown",
"id": "e9e08d81",
"metadata": {},
"source": [
"### Create a throw-away project"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4c7dd9c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tmp project: /var/folders/9z/8xhqw8x1611fcbhzl339yrs40000gn/T/tmpmssk0m2b\n"
]
}
],
"source": [
"tmp_proj = pathlib.Path(tempfile.mkdtemp())\n",
"(tmp_proj / \"requirements.txt\").write_text(\"rich==13.7.0\\n\")\n",
"print(\"tmp project:\", tmp_proj)"
]
},
{
"cell_type": "markdown",
"id": "03558a95",
"metadata": {},
"source": [
"### Show what the helper detects"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3b5acbfb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"detect_package_managers -> [<PackageManagerType.PIP: 'pip'>]\n",
"get_install_command -> pip install -r requirements.txt\n"
]
}
],
"source": [
"print(\"detect_package_managers ->\",\n",
" [m.name for m in pm.detect_package_managers(tmp_proj)])\n",
"print(\"get_install_command ->\",\n",
" pm.get_install_command(tmp_proj))"
]
},
{
"cell_type": "markdown",
"id": "42f1ae7c",
"metadata": {},
"source": [
"### Point the executor at that directory"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "81e53cf4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from contextlib import contextmanager, ExitStack\n",
"\n",
"@contextmanager\n",
"def chdir(path):\n",
" \"\"\"\n",
" Temporarily change the process CWD to *path*.\n",
"\n",
" Works on every CPython ≥ 3.6 (and PyPy) and restores the old directory\n",
" even if an exception is raised inside the block.\n",
" \"\"\"\n",
" old_cwd = os.getcwd()\n",
" os.chdir(path)\n",
" try:\n",
" yield\n",
" finally:\n",
" os.chdir(old_cwd)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fb2f5052",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dapr_agents.executors.local:bootstrapping python project with 'pip install -r requirements.txt'\n",
"INFO:dapr_agents.executors.local:Sandbox backend enabled: seatbelt\n",
"INFO:dapr_agents.executors.local:Created a new virtual environment\n",
"INFO:dapr_agents.executors.local:Snippet 1 finished in 1.433s\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">bootstrap OK\n",
"\n",
"</pre>\n"
],
"text/plain": [
"bootstrap OK\n",
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"with ExitStack() as stack:\n",
    "    # keep a directory handle open (optional but handy if you'll delete tmp_proj later)\n",
" stack.enter_context(os.scandir(tmp_proj))\n",
"\n",
" # <-- our portable replacement for contextlib.chdir()\n",
" stack.enter_context(chdir(tmp_proj))\n",
"\n",
" # run a trivial snippet; executor will bootstrap because it now “sees”\n",
" # requirements.txt in the current working directory\n",
" out = await executor.execute(\n",
" ExecutionRequest(snippets=[\n",
" CodeSnippet(language=\"python\", code=\"print('bootstrap OK')\", timeout=5)\n",
" ])\n",
" )\n",
" console.print(out[0].output)"
]
},
{
"cell_type": "markdown",
"id": "45de2386",
"metadata": {},
"source": [
"### Clean Up the throw-away project "
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0c7aa010",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cache directory removed ✅\n",
"temporary project removed ✅\n"
]
}
],
"source": [
"shutil.rmtree(executor.cache_dir, ignore_errors=True)\n",
"print(\"Cache directory removed ✅\")\n",
"shutil.rmtree(tmp_proj, ignore_errors=True)\n",
"print(\"temporary project removed ✅\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36ea4010",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
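
The `chdir` helper defined in the notebook is useful outside Jupyter as well. A standalone sketch of the same pattern, verifying that the working directory is restored after the block:

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def chdir(path):
    """Temporarily change the working directory, restoring it even on error."""
    old_cwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(old_cwd)


before = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    with chdir(tmp):
        # realpath() so symlinked temp dirs (e.g. /var -> /private/var on macOS) compare equal
        assert os.path.realpath(os.getcwd()) == os.path.realpath(tmp)
assert os.getcwd() == before
```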

View File

@@ -12,11 +12,11 @@ from tools import mcp
 # Logging Configuration
 # ─────────────────────────────────────────────
 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s"
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger("mcp-server")

 # ─────────────────────────────────────────────
 # Starlette App Factory
 # ─────────────────────────────────────────────
@@ -29,27 +29,44 @@ def create_starlette_app():
     async def handle_sse(request: Request) -> None:
         logger.info("🔌 SSE connection established")
-        async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
+        async with sse.connect_sse(request.scope, request.receive, request._send) as (
+            read_stream,
+            write_stream,
+        ):
             logger.debug("Starting MCP server run loop over SSE")
-            await mcp._mcp_server.run(read_stream, write_stream, mcp._mcp_server.create_initialization_options())
+            await mcp._mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp._mcp_server.create_initialization_options(),
+            )
             logger.debug("MCP run loop completed")

     return Starlette(
         debug=False,
         routes=[
             Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message)
-        ]
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
     )

 # ─────────────────────────────────────────────
 # CLI Entrypoint
 # ─────────────────────────────────────────────
 def main():
     parser = argparse.ArgumentParser(description="Run an MCP tool server.")
-    parser.add_argument("--server_type", choices=["stdio", "sse"], default="stdio", help="Transport to use")
-    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to (SSE only)")
-    parser.add_argument("--port", type=int, default=8000, help="Port to bind to (SSE only)")
+    parser.add_argument(
+        "--server_type",
+        choices=["stdio", "sse"],
+        default="stdio",
+        help="Transport to use",
+    )
+    parser.add_argument(
+        "--host", default="127.0.0.1", help="Host to bind to (SSE only)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to bind to (SSE only)"
+    )
     args = parser.parse_args()

     logger.info(f"🚀 Starting MCP server in {args.server_type.upper()} mode")
@@ -61,5 +78,6 @@ def main():
         logger.info(f"🌐 Running SSE server on {args.host}:{args.port}")
         uvicorn.run(app, host=args.host, port=args.port)

 if __name__ == "__main__":
     main()
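
The reflowed `add_argument` calls behave exactly as before; the CLI defaults can be sanity-checked in isolation. A sketch mirroring the arguments above (only the parser, not the server itself):

```python
import argparse

# Mirror of the MCP server's CLI definition, for checking defaults
parser = argparse.ArgumentParser(description="Run an MCP tool server.")
parser.add_argument(
    "--server_type",
    choices=["stdio", "sse"],
    default="stdio",
    help="Transport to use",
)
parser.add_argument("--host", default="127.0.0.1", help="Host to bind to (SSE only)")
parser.add_argument("--port", type=int, default=8000, help="Port to bind to (SSE only)")

args = parser.parse_args([])  # no CLI flags: defaults apply
assert args.server_type == "stdio" and args.host == "127.0.0.1" and args.port == 8000

sse_args = parser.parse_args(["--server_type", "sse", "--port", "9000"])
assert sse_args.server_type == "sse" and sse_args.port == 9000
```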

View File

@@ -3,13 +3,15 @@ import random

 mcp = FastMCP("TestServer")


 @mcp.tool()
 async def get_weather(location: str) -> str:
     """Get weather information for a specific location."""
     temperature = random.randint(60, 80)
     return f"{location}: {temperature}F."


 @mcp.tool()
 async def jump(distance: str) -> str:
     """Simulate a jump of a given distance."""
     return f"I jumped the following distance: {distance}"

View File

@@ -5,6 +5,7 @@ from dotenv import load_dotenv
 from dapr_agents import AssistantAgent
 from dapr_agents.tool.mcp import MCPClient


 async def main():
     try:
         # Load MCP tools from server (stdio or sse)
@@ -34,11 +35,12 @@ async def main():
         # Start the FastAPI agent service
         await weather_agent.start()
     except Exception as e:
         logging.exception("Error starting weather agent service", exc_info=e)


 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())

View File

@@ -23,7 +23,7 @@ if __name__ == "__main__":
             print(f"Request failed: {e}")

         attempt += 1
-        print(f"Waiting 5s seconds before next health checkattempt...")
+        print("Waiting 5 seconds before next health check attempt...")
         time.sleep(5)

     if not healthy:
@@ -48,10 +48,10 @@ if __name__ == "__main__":
             print(f"Request failed: {e}")

         attempt += 1
-        print(f"Waiting 1s seconds before next attempt...")
+        print("Waiting 1 second before next attempt...")
         time.sleep(1)

-    print(f"Maximum attempts (10) reached without success.")
+    print("Maximum attempts (10) reached without success.")
     print("Failed to get successful response")
     sys.exit(1)
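
The polling loops above follow a common bounded-retry pattern. A deterministic sketch of that pattern (`wait_until_healthy` and `probe` are illustrative names, not from the quickstart; `delay` defaults to 0 here so the example runs instantly):

```python
import time


def wait_until_healthy(probe, max_attempts=10, delay=0):
    """Retry `probe` up to max_attempts times, pausing `delay` seconds between tries."""
    for attempt in range(1, max_attempts + 1):
        if probe():
            return attempt  # number of attempts it took to succeed
        time.sleep(delay)
    return None  # gave up, mirroring the script's exit path


calls = {"n": 0}


def probe():
    # pretend the service becomes healthy on the third check
    calls["n"] += 1
    return calls["n"] >= 3


assert wait_until_healthy(probe) == 3
assert wait_until_healthy(lambda: False, max_attempts=2) is None
```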

View File

@@ -12,11 +12,11 @@ from tools import mcp
 # Logging Configuration
 # ─────────────────────────────────────────────
 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(levelname)s - %(message)s"
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger("mcp-server")

 # ─────────────────────────────────────────────
 # Starlette App Factory
 # ─────────────────────────────────────────────
@@ -29,27 +29,44 @@ def create_starlette_app():
     async def handle_sse(request: Request) -> None:
         logger.info("🔌 SSE connection established")
-        async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
+        async with sse.connect_sse(request.scope, request.receive, request._send) as (
+            read_stream,
+            write_stream,
+        ):
             logger.debug("Starting MCP server run loop over SSE")
-            await mcp._mcp_server.run(read_stream, write_stream, mcp._mcp_server.create_initialization_options())
+            await mcp._mcp_server.run(
+                read_stream,
+                write_stream,
+                mcp._mcp_server.create_initialization_options(),
+            )
             logger.debug("MCP run loop completed")

     return Starlette(
         debug=False,
         routes=[
             Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message)
-        ]
+            Mount("/messages/", app=sse.handle_post_message),
+        ],
     )

 # ─────────────────────────────────────────────
 # CLI Entrypoint
 # ─────────────────────────────────────────────
 def main():
     parser = argparse.ArgumentParser(description="Run an MCP tool server.")
-    parser.add_argument("--server_type", choices=["stdio", "sse"], default="stdio", help="Transport to use")
-    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to (SSE only)")
-    parser.add_argument("--port", type=int, default=8000, help="Port to bind to (SSE only)")
+    parser.add_argument(
+        "--server_type",
+        choices=["stdio", "sse"],
+        default="stdio",
+        help="Transport to use",
+    )
+    parser.add_argument(
+        "--host", default="127.0.0.1", help="Host to bind to (SSE only)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to bind to (SSE only)"
+    )
     args = parser.parse_args()

     logger.info(f"🚀 Starting MCP server in {args.server_type.upper()} mode")
@@ -61,5 +78,6 @@ def main():
         logger.info(f"🌐 Running SSE server on {args.host}:{args.port}")
         uvicorn.run(app, host=args.host, port=args.port)

 if __name__ == "__main__":
     main()

View File

@@ -3,13 +3,15 @@ import random

 mcp = FastMCP("TestServer")


 @mcp.tool()
 async def get_weather(location: str) -> str:
     """Get weather information for a specific location."""
     temperature = random.randint(60, 80)
     return f"{location}: {temperature}F."


 @mcp.tool()
 async def jump(distance: str) -> str:
     """Simulate a jump of a given distance."""
     return f"I jumped the following distance: {distance}"

View File

@@ -3,39 +3,46 @@ import dapr.ext.workflow as wf

 wfr = wf.WorkflowRuntime()


-@wfr.workflow(name='random_workflow')
+@wfr.workflow(name="random_workflow")
 def task_chain_workflow(ctx: wf.DaprWorkflowContext, x: int):
     result1 = yield ctx.call_activity(step1, input=x)
     result2 = yield ctx.call_activity(step2, input=result1)
     result3 = yield ctx.call_activity(step3, input=result2)
     return [result1, result2, result3]


 @wfr.activity
 def step1(ctx, activity_input):
-    print(f'Step 1: Received input: {activity_input}.')
+    print(f"Step 1: Received input: {activity_input}.")
     # Do some work
     return activity_input + 1


 @wfr.activity
 def step2(ctx, activity_input):
-    print(f'Step 2: Received input: {activity_input}.')
+    print(f"Step 2: Received input: {activity_input}.")
     # Do some work
     return activity_input * 2


 @wfr.activity
 def step3(ctx, activity_input):
-    print(f'Step 3: Received input: {activity_input}.')
+    print(f"Step 3: Received input: {activity_input}.")
     # Do some work
     return activity_input ^ 2


-if __name__ == '__main__':
+if __name__ == "__main__":
     wfr.start()
     sleep(5)  # wait for workflow runtime to start

     wf_client = wf.DaprWorkflowClient()
-    instance_id = wf_client.schedule_new_workflow(workflow=task_chain_workflow, input=10)
-    print(f'Workflow started. Instance ID: {instance_id}')
+    instance_id = wf_client.schedule_new_workflow(
+        workflow=task_chain_workflow, input=10
+    )
+    print(f"Workflow started. Instance ID: {instance_id}")
     state = wf_client.wait_for_workflow_completion(instance_id)
-    print(f'Workflow completed! Status: {state.runtime_status}')
+    print(f"Workflow completed! Status: {state.runtime_status}")

     wfr.shutdown()
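
One subtlety worth noting in the chain above: Python's `^` is bitwise XOR, not exponentiation, so `step3` does not square its input. The chain's arithmetic can be verified without a Dapr sidecar:

```python
def step1(x: int) -> int:
    return x + 1


def step2(x: int) -> int:
    return x * 2


def step3(x: int) -> int:
    return x ^ 2  # bitwise XOR, as written in the workflow (x ** 2 would square)


# input 10 -> 11 -> 22 -> 22 ^ 2 == 20
assert [step1(10), step2(step1(10)), step3(step2(step1(10)))] == [11, 22, 20]
assert 10 ** 2 == 100  # exponentiation, by contrast, uses **
```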

View File

@@ -1,38 +1,43 @@
 import logging

 from dapr_agents.workflow import WorkflowApp, workflow, task
-from dapr_agents.types import DaprWorkflowContext
+from dapr.ext.workflow import DaprWorkflowContext


-@workflow(name='random_workflow')
-def task_chain_workflow(ctx:DaprWorkflowContext, input: int):
+@workflow(name="random_workflow")
+def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
     result1 = yield ctx.call_activity(step1, input=input)
     result2 = yield ctx.call_activity(step2, input=result1)
     result3 = yield ctx.call_activity(step3, input=result2)
     return [result1, result2, result3]


 @task
 def step1(activity_input):
-    print(f'Step 1: Received input: {activity_input}.')
+    print(f"Step 1: Received input: {activity_input}.")
     # Do some work
     return activity_input + 1


 @task
 def step2(activity_input):
-    print(f'Step 2: Received input: {activity_input}.')
+    print(f"Step 2: Received input: {activity_input}.")
     # Do some work
     return activity_input * 2


 @task
 def step3(activity_input):
-    print(f'Step 3: Received input: {activity_input}.')
+    print(f"Step 3: Received input: {activity_input}.")
     # Do some work
     return activity_input ^ 2


-if __name__ == '__main__':
+if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)

     wfapp = WorkflowApp()

     results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow, input=10)
     print(f"Results: {results}")

View File

@@ -2,7 +2,8 @@ import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr.ext.workflow import DaprWorkflowContext


@workflow(name="random_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
@@ -11,30 +12,32 @@ def task_chain_workflow(ctx: DaprWorkflowContext, input: int):
    result3 = yield ctx.call_activity(step3, input=result2)
    return [result1, result2, result3]


@task
def step1(activity_input: int) -> int:
    print(f"Step 1: Received input: {activity_input}.")
    return activity_input + 1


@task
def step2(activity_input: int) -> int:
    print(f"Step 2: Received input: {activity_input}.")
    return activity_input * 2


@task
def step3(activity_input: int) -> int:
    print(f"Step 3: Received input: {activity_input}.")
    return activity_input ^ 2


async def main():
    logging.basicConfig(level=logging.INFO)
    wfapp = WorkflowApp()
    result = await wfapp.run_and_monitor_workflow_async(task_chain_workflow, input=10)
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -1,27 +1,35 @@
from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr.ext.workflow import DaprWorkflowContext
from dotenv import load_dotenv
import logging


# Define Workflow logic
@workflow(name="lotr_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext):
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input={"character": result1})
    return result2


@task(
    description="""
    Pick a random character from The Lord of the Rings\n
    and respond with the character's name ONLY
"""
)
def get_character() -> str:
    pass


@task(
    description="What is a famous line by {character}",
)
def get_line(character: str) -> str:
    pass


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Load environment variables
@@ -32,4 +40,4 @@ if __name__ == '__main__':
    # Run workflow
    results = wfapp.run_and_monitor_workflow_sync(task_chain_workflow)
    print(results)

View File

@@ -2,27 +2,35 @@ import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr.ext.workflow import DaprWorkflowContext
from dotenv import load_dotenv


# Define Workflow logic
@workflow(name="lotr_workflow")
def task_chain_workflow(ctx: DaprWorkflowContext):
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input={"character": result1})
    return result2


@task(
    description="""
    Pick a random character from The Lord of the Rings\n
    and respond with the character's name ONLY
"""
)
def get_character() -> str:
    pass


@task(
    description="What is a famous line by {character}",
)
def get_line(character: str) -> str:
    pass


async def main():
    logging.basicConfig(level=logging.INFO)
@@ -31,10 +39,11 @@ async def main():
    # Initialize the WorkflowApp
    wfapp = WorkflowApp()

    # Run workflow
    result = await wfapp.run_and_monitor_workflow_async(task_chain_workflow)
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -1,33 +1,34 @@
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr.ext.workflow import DaprWorkflowContext
from pydantic import BaseModel
from dotenv import load_dotenv


@workflow
def question(ctx: DaprWorkflowContext, input: int):
    step1 = yield ctx.call_activity(ask, input=input)
    return step1


class Dog(BaseModel):
    name: str
    bio: str
    breed: str


@task("Who was {name}?")
def ask(name: str) -> Dog:
    pass


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    load_dotenv()

    wfapp = WorkflowApp()

    results = wfapp.run_and_monitor_workflow_sync(workflow=question, input="Scooby Doo")
    print(results)

View File

@@ -2,24 +2,28 @@ import asyncio
import logging

from dapr_agents.workflow import WorkflowApp, workflow, task
from dapr.ext.workflow import DaprWorkflowContext
from pydantic import BaseModel
from dotenv import load_dotenv


@workflow
def question(ctx: DaprWorkflowContext, input: int):
    step1 = yield ctx.call_activity(ask, input=input)
    return step1


class Dog(BaseModel):
    name: str
    bio: str
    breed: str


@task("Who was {name}?")
def ask(name: str) -> Dog:
    pass


async def main():
    logging.basicConfig(level=logging.INFO)
@@ -28,13 +32,13 @@ async def main():
    # Initialize the WorkflowApp
    wfapp = WorkflowApp()

    # Run workflow
    result = await wfapp.run_and_monitor_workflow_async(
        workflow=question, input="Scooby Doo"
    )
    print(f"Results: {result}")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -9,55 +9,54 @@ load_dotenv()
# Initialize Workflow Instance
wfr = wf.WorkflowRuntime()


# Define Workflow logic
@wfr.workflow(name="lotr_workflow")
def task_chain_workflow(ctx: wf.DaprWorkflowContext):
    result1 = yield ctx.call_activity(get_character)
    result2 = yield ctx.call_activity(get_line, input=result1)
    return result2


# Activity 1
@wfr.activity(name="step1")
def get_character(ctx):
    client = OpenAI()
    response = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Pick a random character from The Lord of the Rings and respond with the character name only",
            }
        ],
        model="gpt-4o",
    )
    character = response.choices[0].message.content
    print(f"Character: {character}")
    return character


# Activity 2
@wfr.activity(name="step2")
def get_line(ctx, character: str):
    client = OpenAI()
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": f"What is a famous line by {character}"}],
        model="gpt-4o",
    )
    line = response.choices[0].message.content
    print(f"Line: {line}")
    return line


if __name__ == "__main__":
    wfr.start()
    sleep(5)  # wait for workflow runtime to start

    wf_client = wf.DaprWorkflowClient()
    instance_id = wf_client.schedule_new_workflow(workflow=task_chain_workflow)
    print(f"Workflow started. Instance ID: {instance_id}")
    state = wf_client.wait_for_workflow_completion(instance_id)
    print(f"Workflow completed! Status: {state.runtime_status}")

    wfr.shutdown()

View File

@@ -1,5 +1,5 @@
from dapr_agents.document.reader.pdf.pypdf import PyPDFReader
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents import WorkflowApp
from urllib.parse import urlparse, unquote
from dotenv import load_dotenv
@@ -22,16 +22,19 @@ load_dotenv()
# Initialize the WorkflowApp
wfapp = WorkflowApp()


# Define structured output models
class SpeakerEntry(BaseModel):
    name: str
    text: str


class PodcastDialogue(BaseModel):
    participants: List[SpeakerEntry]


# Define Workflow logic
@wfapp.workflow(name="doc2podcast")
def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
    # Extract pre-validated input
    podcast_name = input["podcast_name"]
@@ -44,10 +47,13 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
    audio_model = input["audio_model"]

    # Step 1: Assign voices to the team
    team_config = yield ctx.call_activity(
        assign_podcast_voices,
        input={
            "host_config": host_config,
            "participant_configs": participant_configs,
        },
    )

    # Step 2: Read PDF and get documents
    file_path = yield ctx.call_activity(download_pdf, input=file_input)
@@ -67,7 +73,9 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
            "context": accumulated_context,
            "participants": [p["name"] for p in team_config["participants"]],
        }
        generated_prompt = yield ctx.call_activity(
            generate_prompt, input=document_with_context
        )

        # Use the prompt to generate the structured dialogue
        prompt_parameters = {
@@ -76,7 +84,9 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
            "prompt": generated_prompt,
            "max_rounds": max_rounds,
        }
        dialogue_entry = yield ctx.call_activity(
            generate_transcript, input=prompt_parameters
        )

        # Update context and transcript parts
        conversations = dialogue_entry["participants"]
@@ -85,18 +95,30 @@ def doc2podcast(ctx: DaprWorkflowContext, input: Dict[str, Any]):
            transcript_parts.append(participant)

    # Step 4: Write the final transcript to a file
    yield ctx.call_activity(
        write_transcript_to_file,
        input={
            "podcast_dialogue": transcript_parts,
            "output_path": output_transcript_path,
        },
    )

    # Step 5: Convert transcript to audio using team_config
    yield ctx.call_activity(
        convert_transcript_to_audio,
        input={
            "transcript_parts": transcript_parts,
            "output_path": output_audio_path,
            "voices": team_config,
            "model": audio_model,
        },
    )


@wfapp.task
def assign_podcast_voices(
    host_config: Dict[str, Any], participant_configs: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Assign voices to the podcast host and participants.
@@ -112,7 +134,9 @@ def assign_podcast_voices(
    # Assign voice to the host if not already specified
    if "voice" not in host_config:
        host_config["voice"] = next(
            voice for voice in allowed_voices if voice not in assigned_voices
        )
        assigned_voices.add(host_config["voice"])

    # Assign voices to participants, ensuring no duplicates
@@ -131,6 +155,7 @@ def assign_podcast_voices(
        "participants": updated_participants,
    }


@wfapp.task
def download_pdf(pdf_url: str, local_directory: str = ".") -> str:
    """
@@ -142,7 +167,7 @@ def download_pdf(pdf_url: str, local_directory: str = ".") -> str:
        if not filename:
            raise ValueError("Invalid URL: Cannot determine filename from the URL.")

        filename = filename.replace(" ", "_")

        local_directory_path = Path(local_directory).resolve()
        local_directory_path.mkdir(parents=True, exist_ok=True)
@@ -163,6 +188,7 @@ def download_pdf(pdf_url: str, local_directory: str = ".") -> str:
        logger.error(f"Error downloading PDF: {e}")
        raise


@wfapp.task
def read_pdf(file_path: str) -> List[dict]:
    """
@@ -176,8 +202,15 @@ def read_pdf(file_path: str) -> List[dict]:
        logger.error(f"Error reading document: {e}")
        raise


@wfapp.task
def generate_prompt(
    text: str,
    iteration_index: int,
    total_iterations: int,
    context: str,
    participants: List[str],
) -> str:
    """
    Generate a prompt dynamically for the chunk.
    """
@@ -189,7 +222,7 @@ def generate_prompt(
    """
    if participants:
        participant_names = ", ".join(participants)
        instructions += f"\nPARTICIPANTS: {participant_names}"
    else:
        instructions += "\nPARTICIPANTS: None (Host-only conversation)"
@@ -214,7 +247,7 @@ def generate_prompt(
        - Follow up on the previous discussion points and introduce the next topic naturally.
        """
    instructions += """
    TASK:
    - Use the provided TEXT to guide this part of the conversation.
    - Alternate between speakers, ensuring a natural conversational flow.
@@ -222,7 +255,9 @@ def generate_prompt(
    """
    return f"{instructions}\nTEXT:\n{text.strip()}"


@wfapp.task(
    """
    Generate a structured podcast dialogue based on the context and text provided.
    The podcast is titled '{podcast_name}' and is hosted by {host_name}.
    If participants are available, each speaker is limited to a maximum of {max_rounds} turns per iteration.
@@ -231,26 +266,39 @@
    If participants are not available, the host drives the conversation alone.
    Keep the dialogue concise and ensure a natural conversational flow.
    {prompt}
    """
)
def generate_transcript(
    podcast_name: str, host_name: str, prompt: str, max_rounds: int
) -> PodcastDialogue:
    pass


@wfapp.task
def write_transcript_to_file(
    podcast_dialogue: List[Dict[str, Any]], output_path: str
) -> None:
    """
    Write the final structured transcript to a file.
    """
    try:
        with open(output_path, "w", encoding="utf-8") as file:
            import json

            json.dump(podcast_dialogue, file, ensure_ascii=False, indent=4)
        logger.info(f"Podcast dialogue successfully written to {output_path}")
    except Exception as e:
        logger.error(f"Error writing podcast dialogue to file: {e}")
        raise


@wfapp.task
def convert_transcript_to_audio(
    transcript_parts: List[Dict[str, Any]],
    output_path: str,
    voices: Dict[str, Any],
    model: str = "tts-1",
) -> None:
    """
    Converts a transcript into a single audio file using the OpenAI Audio Client and pydub for concatenation.
@@ -271,24 +319,30 @@ def convert_transcript_to_audio(
        for part in transcript_parts:
            speaker_name = part["name"]
            speaker_text = part["text"]
            assigned_voice = voice_mapping.get(
                speaker_name, "alloy"
            )  # Default to "alloy" if not found

            # Log assigned voice for debugging
            logger.info(
                f"Generating audio for {speaker_name} using voice '{assigned_voice}'."
            )

            # Create TTS request
            tts_request = AudioSpeechRequest(
                model=model,
                input=speaker_text,
                voice=assigned_voice,
                response_format="mp3",
            )

            # Generate the audio
            audio_bytes = client.create_speech(request=tts_request)

            # Create an AudioSegment from the audio bytes
            audio_chunk = AudioSegment.from_file(
                io.BytesIO(audio_bytes), format=tts_request.response_format
            )

            # Append the audio to the combined segment
            combined_audio += audio_chunk + AudioSegment.silent(duration=300)
@@ -301,17 +355,18 @@ def convert_transcript_to_audio(
        logger.error(f"Error during audio generation: {e}")
        raise


if __name__ == "__main__":
    import argparse
    import json
    import yaml

    def load_config(file_path: str) -> dict:
        """Load configuration from a JSON or YAML file."""
        with open(file_path, "r") as file:
            if file_path.endswith(".yaml") or file_path.endswith(".yml"):
                return yaml.safe_load(file)
            elif file_path.endswith(".json"):
                return json.load(file)
            else:
                raise ValueError("Unsupported file format. Use JSON or YAML.")
@@ -323,11 +378,21 @@ if __name__ == "__main__":
    parser.add_argument("--podcast_name", type=str, help="Name of the podcast.")
    parser.add_argument("--host_name", type=str, help="Name of the host.")
    parser.add_argument("--host_voice", type=str, help="Voice for the host.")
    parser.add_argument(
        "--participants", type=str, nargs="+", help="List of participant names."
    )
    parser.add_argument(
        "--max_rounds", type=int, default=4, help="Number of turns per round."
    )
    parser.add_argument(
        "--output_transcript_path", type=str, help="Path to save the output transcript."
    )
    parser.add_argument(
        "--output_audio_path", type=str, help="Path to save the final audio file."
    )
    parser.add_argument(
        "--audio_model", type=str, default="tts-1", help="Audio model for TTS."
    )

    args = parser.parse_args()
@@ -337,15 +402,18 @@ if __name__ == "__main__":
    # Merge CLI and Config inputs
    user_input = {
        "pdf_url": args.pdf_url or config.get("pdf_url"),
        "podcast_name": args.podcast_name
        or config.get("podcast_name", "Default Podcast"),
        "host": {
            "name": args.host_name or config.get("host", {}).get("name", "Host"),
            "voice": args.host_voice or config.get("host", {}).get("voice", "alloy"),
        },
        "participants": config.get("participants", []),
        "max_rounds": args.max_rounds or config.get("max_rounds", 4),
        "output_transcript_path": args.output_transcript_path
        or config.get("output_transcript_path", "podcast_dialogue.json"),
        "output_audio_path": args.output_audio_path
        or config.get("output_audio_path", "final_podcast.mp3"),
        "audio_model": args.audio_model or config.get("audio_model", "tts-1"),
    }
@@ -356,6 +424,6 @@ if __name__ == "__main__":
    # Validate inputs
    if not user_input["pdf_url"]:
        raise ValueError("PDF URL must be provided via CLI or config file.")

    # Run the workflow
    wfapp.run_and_monitor_workflow_sync(workflow=doc2podcast, input=user_input)

View File

@@ -1,6 +1,6 @@
from dapr_agents import OpenAIChatClient, NVIDIAChatClient
from dapr.ext.workflow import DaprWorkflowContext
from dapr_agents.workflow import WorkflowApp, task, workflow
from dotenv import load_dotenv
import os
import logging
@@ -8,8 +8,7 @@ import logging
load_dotenv()

nvidia_llm = NVIDIAChatClient(
    model="meta/llama-3.1-8b-instruct", api_key=os.getenv("NVIDIA_API_KEY")
)

oai_llm = OpenAIChatClient(
@@ -22,7 +21,7 @@ azoai_llm = OpenAIChatClient(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_deployment="gpt-4o-mini",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    azure_api_version="2024-12-01-preview",
)
@@ -36,20 +35,29 @@ def test_workflow(ctx: DaprWorkflowContext):
    nvidia_results = yield ctx.call_activity(invoke_nvidia, input=azoai_results)
    return nvidia_results


@task(
    description="What is the name of the capital of {country}?. Reply with just the name.",
    llm=oai_llm,
)
def invoke_oai(country: str) -> str:
    pass


@task(description="What is a famous thing about {capital}?", llm=azoai_llm)
def invoke_azoai(capital: str) -> str:
    pass


@task(
    description="Context: {context}. From the previous context. Pick one thing to do.",
    llm=nvidia_llm,
)
def invoke_nvidia(context: str) -> str:
    pass


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    wfapp = WorkflowApp()
@@ -57,4 +65,4 @@ if __name__ == '__main__':
    results = wfapp.run_and_monitor_workflow_sync(workflow=test_workflow)
    logging.info("Workflow results: %s", results)
    logging.info("Workflow completed successfully.")

View File

@@ -31,12 +31,9 @@ def sub(a: float, b: float) -> float:
async def main():
    calculator_agent = Agent(
        name="MathematicsAgent",
        role="Calculator Assistant",
        goal="Assist Humans with calculation tasks.",
        instructions=[
            "Get accurate calculation results",
@@ -59,4 +56,4 @@ async def main():
if __name__ == "__main__":
    load_dotenv()
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())


@@ -7,6 +7,7 @@ from dapr.clients import DaprClient
 # Default Pub/Sub component
 PUBSUB_NAME = "pubsub"

 def main(orchestrator_topic, max_attempts=10, retry_delay=1):
     """
     Publishes a task to a specified Dapr Pub/Sub topic with retries.
@@ -26,8 +27,10 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):
     while attempt <= max_attempts:
         try:
-            print(f"📢 Attempt {attempt}: Publishing to topic '{orchestrator_topic}'...")
+            print(
+                f"📢 Attempt {attempt}: Publishing to topic '{orchestrator_topic}'..."
+            )
             with DaprClient() as client:
                 client.publish_event(
                     pubsub_name=PUBSUB_NAME,
@@ -36,7 +39,7 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):
                     data_content_type="application/json",
                     publish_metadata={
                         "cloudevent.type": "TriggerAction",
-                    }
+                    },
                 )
             print(f"✅ Successfully published request to '{orchestrator_topic}'")
@@ -44,7 +47,7 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):
         except Exception as e:
             print(f"❌ Request failed: {e}")
         attempt += 1
         print(f"⏳ Waiting {retry_delay}s before next attempt...")
         time.sleep(retry_delay)
@@ -52,8 +55,8 @@ def main(orchestrator_topic, max_attempts=10, retry_delay=1):
     print(f"❌ Maximum attempts ({max_attempts}) reached without success.")
     sys.exit(1)

 if __name__ == "__main__":
-    orchestrator_topic = 'LLMOrchestrator'
+    orchestrator_topic = "LLMOrchestrator"
     main(orchestrator_topic)
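The publisher above wraps `DaprClient.publish_event` in a bounded retry loop: attempt, report the failure, wait, and give up after `max_attempts`. The same loop can be factored out generically, sketched here without any Dapr dependency (`publish_with_retries` and `flaky` are illustrative names, not part of the repo):

```python
import time


def publish_with_retries(publish, max_attempts=10, retry_delay=0):
    # Generic form of the retry loop in the publisher script:
    # call, report failure, sleep, and stop after max_attempts.
    attempt = 1
    while attempt <= max_attempts:
        try:
            publish()
            return True  # success: stop retrying
        except Exception as e:
            print(f"Request failed: {e}")
        attempt += 1
        time.sleep(retry_delay)
    return False  # exhausted all attempts


# A publish callable that succeeds only on the third call,
# simulating a message broker that is still starting up.
calls = {"n": 0}


def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("broker not ready")


print(publish_with_retries(flaky))
```

Note the loop counts the failed attempt before sleeping, so a `max_attempts` of 10 means at most 10 calls to `publish`, not 10 sleeps after a success.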


@@ -6,10 +6,8 @@ import logging

 async def main():
     try:
         workflow_service = LLMOrchestrator(
             name="LLMOrchestrator",
             message_bus_name="pubsub",
             state_store_name="workflowstatestore",
             state_key="workflow_state",
@@ -28,4 +26,4 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,8 +16,8 @@ async def main():
                 "Be swift, silent, and precise, moving effortlessly across any terrain.",
                 "Use superior vision and heightened senses to scout ahead and detect threats.",
                 "Excel in ranged combat, delivering pinpoint arrow strikes from great distances.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
-            ]
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
+            ],
         )

         # Expose Agent as an Actor over a Service
@@ -32,9 +33,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,26 +16,27 @@ async def main():
                 "Endure hardships and temptations, staying true to the mission even when faced with doubt.",
                 "Seek guidance and trust allies, but bear the ultimate burden alone when necessary.",
                 "Move carefully through enemy-infested lands, avoiding unnecessary risks.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
-            ]
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
+            ],
         )

         # Expose Agent as an Actor over a Service
         hobbit_service = AgentActor(
             agent=hobbit_agent,
             message_bus_name="messagepubsub",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            service_port=8001
+            service_port=8001,
         )

         await hobbit_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,8 +16,8 @@ async def main():
                 "Provide strategic counsel, always considering the long-term consequences of actions.",
                 "Use magic sparingly, applying it when necessary to guide or protect.",
                 "Encourage allies to find strength within themselves rather than relying solely on your power.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
-            ]
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
+            ],
         )

         # Expose Agent as an Actor over a Service
@@ -25,16 +26,17 @@ async def main():
             message_bus_name="messagepubsub",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            service_port=8002
+            service_port=8002,
         )

         await wizard_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         agentic_orchestrator = LLMOrchestrator(
@@ -12,16 +13,17 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=25
+            max_iterations=25,
         ).as_service(port=8004)

         await agentic_orchestrator.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         random_workflow_service = RandomOrchestrator(
@@ -12,16 +13,17 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await random_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         roundrobin_workflow_service = RoundRobinOrchestrator(
@@ -12,16 +13,17 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await roundrobin_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
                 "Be strong-willed, fiercely loyal, and protective of companions.",
                 "Excel in close combat and battlefield tactics, favoring axes and brute strength.",
                 "Navigate caves, tunnels, and ancient stonework with expert knowledge.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
             ],
             message_bus_name="messagepubsub",
             state_store_name="agenticworkflowstate",
@@ -28,9 +29,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Eagle Agent
@@ -16,7 +17,7 @@ async def main():
                 "Provide swift and strategic transport for those on critical journeys.",
                 "Offer aerial insights, spotting dangers, tracking movements, and scouting strategic locations.",
                 "Speak with wisdom and authority, as one of the ancient and noble Great Eagles.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
             ],
             message_bus_name="messagepubsub",
             state_store_name="agenticworkflowstate",
@@ -29,9 +30,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
                 "Be swift, silent, and precise, moving effortlessly across any terrain.",
                 "Use superior vision and heightened senses to scout ahead and detect threats.",
                 "Excel in ranged combat, delivering pinpoint arrow strikes from great distances.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
             ],
             message_bus_name="messagepubsub",
             state_store_name="agenticworkflowstate",
@@ -28,9 +29,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
                 "Endure hardships and temptations, staying true to the mission even when faced with doubt.",
                 "Seek guidance and trust allies, but bear the ultimate burden alone when necessary.",
                 "Move carefully through enemy-infested lands, avoiding unnecessary risks.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
             ],
             message_bus_name="messagepubsub",
             state_store_name="agenticworkflowstate",
@@ -23,14 +24,15 @@ async def main():
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
         )

         await hobbit_agent.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         agentic_orchestrator = LLMOrchestrator(
@@ -12,16 +13,17 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await agentic_orchestrator.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
                 "Lead by example, inspiring courage and loyalty in allies.",
                 "Navigate wilderness with expert tracking and survival skills.",
                 "Master both swordplay and battlefield strategy, excelling in one-on-one combat and large-scale warfare.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
             ],
             message_bus_name="messagepubsub",
             state_store_name="agenticworkflowstate",
@@ -28,9 +29,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Define Agent
@@ -15,7 +16,7 @@ async def main():
                 "Provide strategic counsel, always considering the long-term consequences of actions.",
                 "Use magic sparingly, applying it when necessary to guide or protect.",
                 "Encourage allies to find strength within themselves rather than relying solely on your power.",
-                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task."
+                "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.",
             ],
             message_bus_name="messagepubsub",
             state_store_name="agenticworkflowstate",
@@ -28,9 +29,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         random_workflow_service = RandomOrchestrator(
@@ -12,16 +13,17 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await random_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         roundrobin_workflow_service = RoundRobinOrchestrator(
@@ -12,16 +13,17 @@ async def main():
             state_key="workflow_state",
             agents_registry_store_name="agentsregistrystore",
             agents_registry_key="agents_registry",
-            max_iterations=3
+            max_iterations=3,
         ).as_service(port=8004)

         await roundrobin_workflow_service.start()
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -3,6 +3,7 @@ from dotenv import load_dotenv
 import asyncio
 import logging

 async def main():
     try:
         # Create the Weather Agent using those tools
@@ -27,9 +28,10 @@ async def main():
     except Exception as e:
         print(f"Error starting service: {e}")

 if __name__ == "__main__":
     load_dotenv()
     logging.basicConfig(level=logging.INFO)
     asyncio.run(main())


@@ -23,7 +23,7 @@ if __name__ == "__main__":
             print(f"Request failed: {e}")
         attempt += 1
-        print(f"Waiting 5s seconds before next health checkattempt...")
+        print("Waiting 5s seconds before next health checkattempt...")
         time.sleep(5)

     if not healthy:
@@ -48,10 +48,10 @@ if __name__ == "__main__":
             print(f"Request failed: {e}")
         attempt += 1
-        print(f"Waiting 1s seconds before next attempt...")
+        print("Waiting 1s seconds before next attempt...")
         time.sleep(1)

-    print(f"Maximum attempts (10) reached without success.")
+    print("Maximum attempts (10) reached without success.")
     print("Failed to get successful response")
     sys.exit(1)


@@ -1,12 +1,25 @@
-from dapr_agents.agent import Agent, AgentActor, ReActAgent, ToolCallAgent, OpenAPIReActAgent
-from dapr_agents.llm.openai import OpenAIChatClient, OpenAIAudioClient, OpenAIEmbeddingClient
+from dapr_agents.agent import (
+    Agent,
+    AgentActor,
+    ReActAgent,
+    ToolCallAgent,
+    OpenAPIReActAgent,
+)
+from dapr_agents.llm.openai import (
+    OpenAIChatClient,
+    OpenAIAudioClient,
+    OpenAIEmbeddingClient,
+)
 from dapr_agents.llm.huggingface import HFHubChatClient
 from dapr_agents.llm.nvidia import NVIDIAChatClient, NVIDIAEmbeddingClient
 from dapr_agents.llm.elevenlabs import ElevenLabsSpeechClient
 from dapr_agents.tool import AgentTool, tool
 from dapr_agents.workflow import (
-    WorkflowApp, AgenticWorkflow,
-    LLMOrchestrator, RandomOrchestrator, RoundRobinOrchestrator,
-    AssistantAgent
+    WorkflowApp,
+    AgenticWorkflow,
+    LLMOrchestrator,
+    RandomOrchestrator,
+    RoundRobinOrchestrator,
+    AssistantAgent,
 )
 from dapr_agents.executors import LocalCodeExecutor, DockerCodeExecutor


@@ -1,4 +1,4 @@
 from .base import AgentBase
 from .utils.factory import Agent
 from .actor import AgentActor
 from .patterns import ReActAgent, ToolCallAgent, OpenAPIReActAgent


@@ -1,4 +1,4 @@
 from .base import AgentActorBase
 from .interface import AgentActorInterface
 from .service import AgentActorService
 from .agent import AgentActor


@@ -1,16 +1,21 @@
 import logging

-from dapr_agents.agent.actor.schemas import AgentTaskResponse, TriggerAction, BroadcastMessage
+from dapr_agents.agent.actor.schemas import (
+    AgentTaskResponse,
+    TriggerAction,
+    BroadcastMessage,
+)
 from dapr_agents.agent.actor.service import AgentActorService
 from dapr_agents.types.agent import AgentActorMessage
 from dapr_agents.workflow.messaging.decorator import message_router

 logger = logging.getLogger(__name__)

 class AgentActor(AgentActorService):
     """
     A Pydantic-based class for managing services and exposing FastAPI routes with Dapr pub/sub and actor support.
     """

     @message_router
     async def process_trigger_action(self, message: TriggerAction):
         """
@@ -35,17 +40,23 @@ class AgentActor(AgentActorService):
                 response = await self.invoke_task(task)

                 # Check if the response exists
-                content = response.body.decode() if response and response.body else "Task completed but no response generated."
+                content = (
+                    response.body.decode()
+                    if response and response.body
+                    else "Task completed but no response generated."
+                )

                 # Broadcast result
-                response_message = BroadcastMessage(name=self.agent.name, role="user", content=content)
+                response_message = BroadcastMessage(
+                    name=self.agent.name, role="user", content=content
+                )
                 await self.broadcast_message(message=response_message)

                 # Update response
                 response_message = response_message.model_dump()
                 response_message["workflow_instance_id"] = workflow_instance_id
                 agent_response = AgentTaskResponse(**response_message)

                 # Send the message to the target agent
                 await self.send_message_to_agent(name=source, message=agent_response)
         except Exception as e:
@@ -60,22 +71,30 @@ class AgentActor(AgentActorService):
             metadata = message.pop("_message_metadata", {})

             if not isinstance(metadata, dict):
-                logger.warning(f"{getattr(self, 'name', 'agent')} received a broadcast with invalid metadata. Ignoring.")
+                logger.warning(
+                    f"{getattr(self, 'name', 'agent')} received a broadcast with invalid metadata. Ignoring."
+                )
                 return

             source = metadata.get("source", "unknown_source")
             message_type = metadata.get("type", "unknown_type")
             message_content = message.get("content", "No content")

-            logger.info(f"{self.agent.name} received broadcast message of type '{message_type}' from '{source}'.")
+            logger.info(
+                f"{self.agent.name} received broadcast message of type '{message_type}' from '{source}'."
+            )

             # Ignore messages sent by this agent
             if source == self.agent.name:
-                logger.info(f"{self.agent.name} ignored its own broadcast message of type '{message_type}'.")
+                logger.info(
+                    f"{self.agent.name} ignored its own broadcast message of type '{message_type}'."
+                )
                 return

             # Log and process the valid broadcast message
-            logger.debug(f"{self.agent.name} is processing broadcast message of type '{message_type}' from '{source}'.")
+            logger.debug(
+                f"{self.agent.name} is processing broadcast message of type '{message_type}' from '{source}'."
+            )
             logger.debug(f"Message content: {message_content}")

             # Add the message to the agent's memory
@@ -86,4 +105,4 @@ class AgentActor(AgentActorService):
             await self.add_message(actor_message)
         except Exception as e:
             logger.error(f"Error processing broadcast message: {e}", exc_info=True)
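The broadcast handler above applies the same guard sequence on every message: pop the `_message_metadata` envelope, drop messages whose metadata is not a dict, and ignore broadcasts the agent itself sent. That filtering logic can be isolated as a small pure function, sketched here (`should_process_broadcast` is an illustrative name, not part of the library):

```python
def should_process_broadcast(message: dict, agent_name: str):
    # Mirrors the guard logic in process_broadcast_message:
    # pop the metadata envelope, reject invalid metadata,
    # and skip the agent's own broadcasts.
    metadata = message.pop("_message_metadata", {})
    if not isinstance(metadata, dict):
        return False, "invalid metadata"
    source = metadata.get("source", "unknown_source")
    if source == agent_name:
        return False, "own broadcast"
    return True, source


ok, info = should_process_broadcast(
    {"_message_metadata": {"source": "Gandalf"}, "content": "hi"}, "Frodo"
)
print(ok, info)
```

Popping (rather than reading) the metadata matters: the remaining dict is then pure message content, safe to store in the agent's memory without transport bookkeeping.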


@@ -16,6 +16,7 @@ from pydantic import ValidationError

logger = logging.getLogger(__name__)


class AgentActorBase(Actor, AgentActorInterface):
    """Base class for all agent actors, including task execution and agent state management."""

@@ -24,19 +25,23 @@ class AgentActorBase(Actor, AgentActorInterface):
        self.actor_id = actor_id
        self.agent: AgentBase
        self.agent_state_key = "agent_state"

    async def _on_activate(self) -> None:
        """
        Called when the actor is activated. Initializes the agent's state if not present.
        """
        logger.info(f"Activating actor with ID: {self.actor_id}")
        has_state, state_data = await self._state_manager.try_get_state(
            self.agent_state_key
        )
        if not has_state:
            # Initialize state with default values if it doesn't exist
            logger.info(f"Initializing state for {self.actor_id}")
            self.state = AgentActorState(overall_status=AgentStatus.IDLE)
            await self._state_manager.set_state(
                self.agent_state_key, self.state.model_dump()
            )
            await self._state_manager.save_state()
        else:
            # Load existing state
@@ -48,16 +53,20 @@ class AgentActorBase(Actor, AgentActorInterface):
        """
        Called when the actor is deactivated.
        """
        logger.info(
            f"Deactivate {self.__class__.__name__} actor with ID: {self.actor_id}."
        )

    async def set_status(self, status: AgentStatus) -> None:
        """
        Sets the current operational status of the agent and saves the state.
        """
        self.state.overall_status = status
        await self._state_manager.set_state(
            self.agent_state_key, self.state.model_dump()
        )
        await self._state_manager.save_state()

    async def invoke_task(self, task: Optional[str] = None) -> str:
        """
        Execute the agent's main task, log the input/output in the task history,
@@ -76,7 +85,9 @@ class AgentActorBase(Actor, AgentActorInterface):
        # Look for the last message in the conversation history
        last_message = messages[-1]
        default_task = last_message.get("content")
        logger.debug(
            f"Default task entry input derived from last message: {default_task}"
        )

        # Prepare the input for task entry
        task_entry_input = task or default_task or "Triggered without a specific task"
@@ -93,7 +104,9 @@ class AgentActorBase(Actor, AgentActorInterface):
        self.state.task_history.append(task_entry)

        # Save initial task state with IN_PROGRESS status
        await self._state_manager.set_state(
            self.agent_state_key, self.state.model_dump()
        )
        await self._state_manager.save_state()

        try:
@@ -120,11 +133,13 @@ class AgentActorBase(Actor, AgentActorInterface):
        finally:
            # Ensure the final state of the task is saved
            await self._state_manager.set_state(
                self.agent_state_key, self.state.model_dump()
            )
            await self._state_manager.save_state()

            # Revert the agent's status to idle
            await self.set_status(AgentStatus.IDLE)

    async def add_message(self, message: Union[AgentActorMessage, dict]) -> None:
        """
        Adds a message to the conversation history in the actor's state.
@@ -135,21 +150,25 @@ class AgentActorBase(Actor, AgentActorInterface):
        # Convert dictionary to AgentActorMessage if necessary
        if isinstance(message, dict):
            message = AgentActorMessage(**message)

        # Add the new message to the state
        self.state.messages.append(message)
        self.state.message_count += 1

        # Save state back to Dapr
        await self._state_manager.set_state(
            self.agent_state_key, self.state.model_dump()
        )
        await self._state_manager.save_state()

    async def get_messages(self) -> List[dict]:
        """
        Retrieves the messages from the actor's state, validates it using Pydantic,
        and returns a list of dictionaries if valid.
        """
        has_state, state_data = await self._state_manager.try_get_state(
            self.agent_state_key
        )
        if has_state:
            try:
@@ -162,4 +181,4 @@ class AgentActorBase(Actor, AgentActorInterface):
                # Handle validation errors
                print(f"Validation error: {e}")
                return []
        return []
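The round-trip above (`try_get_state`, stage with `set_state`, commit with `save_state`) is the core Dapr actor state pattern. A minimal self-contained sketch, using an in-memory stand-in for the state manager (names and the stand-in class are illustrative, not part of the Dapr SDK):

```python
import asyncio
from typing import Any, Dict, Tuple


class InMemoryStateManager:
    """Stand-in for Dapr's actor state manager: writes are staged until save_state()."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}
        self._staged: Dict[str, Any] = {}

    async def try_get_state(self, key: str) -> Tuple[bool, Any]:
        # Mirrors the (has_state, state_data) tuple shape used above
        if key in self._store:
            return True, self._store[key]
        return False, None

    async def set_state(self, key: str, value: Any) -> None:
        self._staged[key] = value  # staged, not yet persisted

    async def save_state(self) -> None:
        self._store.update(self._staged)
        self._staged.clear()


async def activate(manager: InMemoryStateManager, key: str) -> dict:
    # Mirrors _on_activate: initialize state only when none exists yet
    has_state, data = await manager.try_get_state(key)
    if not has_state:
        data = {"overall_status": "idle", "messages": [], "message_count": 0}
        await manager.set_state(key, data)
        await manager.save_state()
    return data


manager = InMemoryStateManager()
state = asyncio.run(activate(manager, "agent_state"))
print(state["overall_status"])  # idle
```

A second activation would find the persisted state and skip initialization, which is why `_on_activate` is safe to call on every actor wake-up.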

View File

@@ -3,9 +3,10 @@ from typing import List, Optional, Union

from dapr.actor import ActorInterface, actormethod
from dapr_agents.types.agent import AgentActorMessage, AgentStatus


class AgentActorInterface(ActorInterface):
    @abstractmethod
    @actormethod(name="InvokeTask")
    async def invoke_task(self, task: Optional[str] = None) -> str:
        """
        Invoke a task and returns the result as a string.
@@ -13,7 +14,7 @@ class AgentActorInterface(ActorInterface):
        pass

    @abstractmethod
    @actormethod(name="AddMessage")
    async def add_message(self, message: Union[AgentActorMessage, dict]) -> None:
        """
        Adds a message to the conversation history in the actor's state.
@@ -21,7 +22,7 @@ class AgentActorInterface(ActorInterface):
        pass

    @abstractmethod
    @actormethod(name="GetMessages")
    async def get_messages(self) -> List[dict]:
        """
        Retrieves the conversation history from the actor's state.
@@ -29,9 +30,9 @@ class AgentActorInterface(ActorInterface):
        pass

    @abstractmethod
    @actormethod(name="SetStatus")
    async def set_status(self, status: AgentStatus) -> None:
        """
        Sets the current operational status of the agent.
        """
        pass
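The `@actormethod(name=...)` decorators bind each snake_case Python method to the PascalCase wire name the actor proxy invokes (which is why the service later calls `self.actor_proxy.InvokeTask(...)`). A toy sketch of that mapping, with a hypothetical mini decorator standing in for the real SDK one:

```python
from typing import Callable, Optional


def actormethod(name: str) -> Callable:
    """Illustrative stand-in: attach the wire-level method name to the function."""

    def decorator(fn: Callable) -> Callable:
        fn.__actormethod__ = name  # the real SDK records this for proxy dispatch
        return fn

    return decorator


class AgentActorInterfaceSketch:
    @actormethod(name="InvokeTask")
    async def invoke_task(self, task: Optional[str] = None) -> str:
        ...


# The proxy would look up this wire name instead of the Python identifier.
wire_name = AgentActorInterfaceSketch.invoke_task.__actormethod__
print(wire_name)  # InvokeTask
```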

View File

@@ -2,21 +2,33 @@ from typing import Optional

from pydantic import BaseModel, Field

from dapr_agents.types.message import BaseMessage


class AgentTaskResponse(BaseMessage):
    """
    Represents a response message from an agent after completing a task.
    """

    workflow_instance_id: Optional[str] = Field(
        default=None, description="Dapr workflow instance id from source if available"
    )


class TriggerAction(BaseModel):
    """
    Represents a message used to trigger an agent's activity within the workflow.
    """

    task: Optional[str] = Field(
        None,
        description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
    )
    iteration: Optional[int] = Field(0, description="")
    workflow_instance_id: Optional[str] = Field(
        default=None, description="Dapr workflow instance id from source if available"
    )


class BroadcastMessage(BaseMessage):
    """
    Represents a broadcast message from an agent
    """

View File

@@ -18,7 +18,10 @@ from dapr.actor.runtime.config import (
)
from dapr.actor.runtime.runtime import ActorRuntime
from dapr.clients import DaprClient
from dapr.clients.grpc._request import (
    TransactionOperationType,
    TransactionalStateOperation,
)
from dapr.clients.grpc._response import StateResponse
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
from dapr.ext.fastapi import DaprActor
@@ -34,22 +37,56 @@ from dapr_agents.workflow.messaging.routing import MessageRoutingMixin

logger = logging.getLogger(__name__)


class AgentActorService(DaprPubSub, MessageRoutingMixin):
    agent: AgentBase
    name: Optional[str] = Field(
        default=None,
        description="Name of the agent actor, derived from the agent if not provided.",
    )
    agent_topic_name: Optional[str] = Field(
        None,
        description="The topic name dedicated to this specific agent, derived from the agent's name if not provided.",
    )
    broadcast_topic_name: str = Field(
        "beacon_channel",
        description="The default topic used for broadcasting messages to all agents.",
    )
    agents_registry_store_name: str = Field(
        ...,
        description="The name of the Dapr state store component used to store and share agent metadata centrally.",
    )
    agents_registry_key: str = Field(
        default="agents_registry",
        description="Dapr state store key for agentic workflow state.",
    )
    service_port: Optional[int] = Field(
        default=None, description="The port number to run the API server on."
    )
    service_host: Optional[str] = Field(
        default="0.0.0.0", description="Host address for the API server."
    )

    # Fields initialized in model_post_init
    actor: Optional[DaprActor] = Field(
        default=None, init=False, description="DaprActor for actor lifecycle support."
    )
    actor_name: Optional[str] = Field(
        default=None, init=False, description="Actor name"
    )
    actor_proxy: Optional[ActorProxy] = Field(
        default=None,
        init=False,
        description="Proxy for invoking methods on the agent's actor.",
    )
    actor_class: Optional[type] = Field(
        default=None,
        init=False,
        description="Dynamically created actor class for the agent",
    )
    agent_metadata: Optional[dict] = Field(
        default=None, init=False, description="Agent's metadata"
    )

    # Private internal attributes (not schema/validated)
    _http_server: Optional[Any] = PrivateAttr(default=None)
@@ -57,7 +94,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
    _dapr_client: Optional[DaprClient] = PrivateAttr(default=None)
    _is_running: bool = PrivateAttr(default=False)
    _subscriptions: Dict[str, Callable] = PrivateAttr(default_factory=dict)
    _topic_handlers: Dict[
        Tuple[str, str], Dict[Type[BaseModel], Callable]
    ] = PrivateAttr(default_factory=dict)

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -71,19 +110,25 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        if not values.get("name") and agent:
            values["name"] = agent.name or agent.role
        return values

    def model_post_init(self, __context: Any) -> None:
        # Proceed with base model setup
        super().model_post_init(__context)

        # Dynamically create the actor class based on the agent's name
        actor_class_name = f"{self.agent.name}Actor"

        # Create the actor class dynamically using the 'type' function
        self.actor_class = type(
            actor_class_name,
            (AgentActorBase,),
            {
                "__init__": lambda self, ctx, actor_id: AgentActorBase.__init__(
                    self, ctx, actor_id
                ),
                "agent": self.agent,
            },
        )
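The three-argument form of `type()` used here builds a per-agent actor subclass at runtime, so each agent gets a distinctly named actor type (Dapr routes by actor type name). A self-contained sketch of the same technique, with hypothetical stand-in classes:

```python
class ActorBaseSketch:
    """Stand-in for AgentActorBase, for illustration only."""

    agent = None

    def __init__(self, ctx, actor_id):
        self.ctx = ctx
        self.actor_id = actor_id


def make_actor_class(agent_name: str, agent: object) -> type:
    """Create a subclass named '<AgentName>Actor' with the agent attached as a class attribute."""
    return type(
        f"{agent_name}Actor",          # class name
        (ActorBaseSketch,),            # bases
        {
            "__init__": lambda self, ctx, actor_id: ActorBaseSketch.__init__(
                self, ctx, actor_id
            ),
            "agent": agent,            # shared by all instances of this actor type
        },
    )


WeatherActor = make_actor_class("Weather", object())
print(WeatherActor.__name__)  # WeatherActor
instance = WeatherActor(ctx=None, actor_id="weather-1")
print(instance.actor_id)  # weather-1
```

Note the trade-off: `agent` becomes a class attribute, so every activation of the same actor type shares one agent object.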
# Prepare agent metadata # Prepare agent metadata
self.agent_metadata = { self.agent_metadata = {
@ -92,12 +137,14 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
"goal": self.agent.goal, "goal": self.agent.goal,
"topic_name": self.agent_topic_name, "topic_name": self.agent_topic_name,
"pubsub_name": self.message_bus_name, "pubsub_name": self.message_bus_name,
"orchestrator": False "orchestrator": False,
} }
# Proxy for actor methods # Proxy for actor methods
self.actor_name = self.actor_class.__name__ self.actor_name = self.actor_class.__name__
self.actor_proxy = ActorProxy.create(self.actor_name, ActorId(self.agent.name), AgentActorInterface) self.actor_proxy = ActorProxy.create(
self.actor_name, ActorId(self.agent.name), AgentActorInterface
)
# Initialize Sync Dapr Client # Initialize Sync Dapr Client
self._dapr_client = DaprClient() self._dapr_client = DaprClient()
@ -106,13 +153,13 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
self._http_server: FastAPIServerBase = FastAPIServerBase( self._http_server: FastAPIServerBase = FastAPIServerBase(
service_name=self.agent.name, service_name=self.agent.name,
service_port=self.service_port, service_port=self.service_port,
service_host=self.service_host service_host=self.service_host,
) )
self._http_server.app.router.lifespan_context = self.lifespan self._http_server.app.router.lifespan_context = self.lifespan
# Create DaprActor using FastAPI app # Create DaprActor using FastAPI app
self.actor = DaprActor(self.app) self.actor = DaprActor(self.app)
self.app.add_api_route("/GetMessages", self.get_messages, methods=["GET"]) self.app.add_api_route("/GetMessages", self.get_messages, methods=["GET"])
logger.info(f"Dapr Actor class {self.actor_class.__name__} initialized.") logger.info(f"Dapr Actor class {self.actor_class.__name__} initialized.")
@ -128,20 +175,23 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
if self._http_server: if self._http_server:
return self._http_server.app return self._http_server.app
raise RuntimeError("FastAPI server not initialized.") raise RuntimeError("FastAPI server not initialized.")
@asynccontextmanager @asynccontextmanager
async def lifespan(self, app: FastAPI): async def lifespan(self, app: FastAPI):
# Register actor # Register actor
actor_runtime_config = ActorRuntimeConfig() actor_runtime_config = ActorRuntimeConfig()
actor_runtime_config.update_actor_type_configs([ actor_runtime_config.update_actor_type_configs(
ActorTypeConfig( [
actor_type=self.actor_class.__name__, ActorTypeConfig(
actor_idle_timeout=timedelta(hours=1), actor_type=self.actor_class.__name__,
actor_scan_interval=timedelta(seconds=30), actor_idle_timeout=timedelta(hours=1),
drain_ongoing_call_timeout=timedelta(minutes=1), actor_scan_interval=timedelta(seconds=30),
drain_rebalanced_actors=True, drain_ongoing_call_timeout=timedelta(minutes=1),
reentrancy=ActorReentrancyConfig(enabled=True)) drain_rebalanced_actors=True,
]) reentrancy=ActorReentrancyConfig(enabled=True),
)
]
)
ActorRuntime.set_actor_config(actor_runtime_config) ActorRuntime.set_actor_config(actor_runtime_config)
await self.actor.register_actor(self.actor_class) await self.actor.register_actor(self.actor_class)
@ -158,7 +208,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
async def start(self): async def start(self):
if self._is_running: if self._is_running:
logger.warning("Service is already running. Ignoring duplicate start request.") logger.warning(
"Service is already running. Ignoring duplicate start request."
)
return return
logger.info("Starting Agent Actor Service...") logger.info("Starting Agent Actor Service...")
@ -176,7 +228,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
for (pubsub_name, topic_name), close_fn in self._subscriptions.items(): for (pubsub_name, topic_name), close_fn in self._subscriptions.items():
try: try:
logger.info(f"Unsubscribing from pubsub '{pubsub_name}' topic '{topic_name}'") logger.info(
f"Unsubscribing from pubsub '{pubsub_name}' topic '{topic_name}'"
)
close_fn() close_fn()
except Exception as e: except Exception as e:
logger.error(f"Failed to unsubscribe from topic '{topic_name}': {e}") logger.error(f"Failed to unsubscribe from topic '{topic_name}': {e}")
@ -184,7 +238,7 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
self._subscriptions.clear() self._subscriptions.clear()
self._is_running = False self._is_running = False
logger.info("Agent Actor Service stopped.") logger.info("Agent Actor Service stopped.")
def get_data_from_store(self, store_name: str, key: str) -> Optional[dict]: def get_data_from_store(self, store_name: str, key: str) -> Optional[dict]:
""" """
Retrieve data from a specified Dapr state store using a provided key. Retrieve data from a specified Dapr state store using a provided key.
@ -197,15 +251,21 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
Optional[dict]: The data stored under the specified key if found; otherwise, None. Optional[dict]: The data stored under the specified key if found; otherwise, None.
""" """
try: try:
response: StateResponse = self._dapr_client.get_state(store_name=store_name, key=key) response: StateResponse = self._dapr_client.get_state(
store_name=store_name, key=key
)
data = response.data data = response.data
return json.loads(data) if data else None return json.loads(data) if data else None
except Exception as e: except Exception:
logger.warning(f"Error retrieving data for key '{key}' from store '{store_name}'") logger.warning(
f"Error retrieving data for key '{key}' from store '{store_name}'"
)
return None return None
def get_agents_metadata(self, exclude_self: bool = True, exclude_orchestrator: bool = False) -> dict: def get_agents_metadata(
self, exclude_self: bool = True, exclude_orchestrator: bool = False
) -> dict:
""" """
Retrieves metadata for all registered agents while ensuring orchestrators do not interact with other orchestrators. Retrieves metadata for all registered agents while ensuring orchestrators do not interact with other orchestrators.
@ -221,17 +281,28 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
""" """
try: try:
# Fetch agent metadata from the registry # Fetch agent metadata from the registry
agents_metadata = self.get_data_from_store(self.agents_registry_store_name, self.agents_registry_key) or {} agents_metadata = (
self.get_data_from_store(
self.agents_registry_store_name, self.agents_registry_key
)
or {}
)
if agents_metadata: if agents_metadata:
logger.info(f"Agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'.") logger.info(
f"Agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'."
)
# Filter based on exclusion rules # Filter based on exclusion rules
filtered_metadata = { filtered_metadata = {
name: metadata name: metadata
for name, metadata in agents_metadata.items() for name, metadata in agents_metadata.items()
if not (exclude_self and name == self.agent.name) # Exclude self if requested if not (
and not (exclude_orchestrator and metadata.get("orchestrator", False)) # Exclude all orchestrators if exclude_orchestrator=True exclude_self and name == self.agent.name
) # Exclude self if requested
and not (
exclude_orchestrator and metadata.get("orchestrator", False)
) # Exclude all orchestrators if exclude_orchestrator=True
} }
if not filtered_metadata: if not filtered_metadata:
@ -239,12 +310,14 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
return filtered_metadata return filtered_metadata
logger.info(f"No agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'.") logger.info(
f"No agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'."
)
return {} return {}
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve agents metadata: {e}", exc_info=True) logger.error(f"Failed to retrieve agents metadata: {e}", exc_info=True)
return {} return {}
def register_agent_metadata(self) -> None: def register_agent_metadata(self) -> None:
""" """
Registers the agent's metadata in the Dapr state store under 'agents_metadata'. Registers the agent's metadata in the Dapr state store under 'agents_metadata'.
@ -255,14 +328,20 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
store_name=self.agents_registry_store_name, store_name=self.agents_registry_store_name,
store_key=self.agents_registry_key, store_key=self.agents_registry_key,
agent_name=self.name, agent_name=self.name,
agent_metadata=self.agent_metadata agent_metadata=self.agent_metadata,
)
logger.info(
f"{self.name} registered its metadata under key '{self.agents_registry_key}'"
) )
logger.info(f"{self.name} registered its metadata under key '{self.agents_registry_key}'")
except Exception as e: except Exception as e:
logger.error(f"Failed to register metadata for agent {self.agent.name}: {e}") logger.error(
f"Failed to register metadata for agent {self.agent.name}: {e}"
)
raise e raise e
def register_agent(self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict) -> None: def register_agent(
self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict
) -> None:
""" """
Merges the existing data with the new data and updates the store. Merges the existing data with the new data and updates the store.
@ -274,7 +353,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
# retry the entire operation up to ten times sleeping 1 second between each attempt # retry the entire operation up to ten times sleeping 1 second between each attempt
for attempt in range(1, 11): for attempt in range(1, 11):
try: try:
response: StateResponse = self._dapr_client.get_state(store_name=store_name, key=store_key) response: StateResponse = self._dapr_client.get_state(
store_name=store_name, key=store_key
)
if not response.etag: if not response.etag:
# if there is no etag the following transaction won't work as expected # if there is no etag the following transaction won't work as expected
# so we need to save an empty object with a strong consistency to force the etag to be created # so we need to save an empty object with a strong consistency to force the etag to be created
@ -283,7 +364,10 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
key=store_key, key=store_key,
value=json.dumps({}), value=json.dumps({}),
state_metadata={"contentType": "application/json"}, state_metadata={"contentType": "application/json"},
options=StateOptions(concurrency=Concurrency.first_write, consistency=Consistency.strong) options=StateOptions(
concurrency=Concurrency.first_write,
consistency=Consistency.strong,
),
) )
# raise an exception to retry the entire operation # raise an exception to retry the entire operation
raise Exception(f"No etag found for key: {store_key}") raise Exception(f"No etag found for key: {store_key}")
@ -303,20 +387,22 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
key=store_key, key=store_key,
data=json.dumps(merged_data), data=json.dumps(merged_data),
etag=response.etag, etag=response.etag,
operation_type=TransactionOperationType.upsert operation_type=TransactionOperationType.upsert,
) )
], ],
transactional_metadata={"contentType": "application/json"} transactional_metadata={"contentType": "application/json"},
) )
except Exception as e: except Exception as e:
raise e raise e
return None return None
except Exception as e: except Exception as e:
logger.debug(f"Error on transaction attempt: {attempt}: {e}") logger.debug(f"Error on transaction attempt: {attempt}: {e}")
logger.debug(f"Sleeping for 1 second before retrying transaction...") logger.debug("Sleeping for 1 second before retrying transaction...")
time.sleep(1) time.sleep(1)
raise Exception(f"Failed to update state store key: {store_key} after 10 attempts.") raise Exception(
f"Failed to update state store key: {store_key} after 10 attempts."
)
async def invoke_task(self, task: Optional[str]) -> Response: async def invoke_task(self, task: Optional[str]) -> Response:
""" """
Use the actor to invoke a task by running the InvokeTask method through ActorProxy. Use the actor to invoke a task by running the InvokeTask method through ActorProxy.
@ -332,8 +418,10 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
return Response(content=response, status_code=status.HTTP_200_OK) return Response(content=response, status_code=status.HTTP_200_OK)
except Exception as e: except Exception as e:
logger.error(f"Failed to run task for {self.actor_name}: {e}") logger.error(f"Failed to run task for {self.actor_name}: {e}")
raise HTTPException(status_code=500, detail=f"Error invoking task: {str(e)}") raise HTTPException(
status_code=500, detail=f"Error invoking task: {str(e)}"
)
async def add_message(self, message: AgentActorMessage) -> None: async def add_message(self, message: AgentActorMessage) -> None:
""" """
Adds a message to the conversation history in the actor's state. Adds a message to the conversation history in the actor's state.
@ -342,19 +430,28 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
await self.actor_proxy.AddMessage(message.model_dump()) await self.actor_proxy.AddMessage(message.model_dump())
except Exception as e: except Exception as e:
logger.error(f"Failed to add message to {self.actor_name}: {e}") logger.error(f"Failed to add message to {self.actor_name}: {e}")
async def get_messages(self) -> Response: async def get_messages(self) -> Response:
""" """
Retrieve the conversation history from the actor. Retrieve the conversation history from the actor.
""" """
try: try:
messages = await self.actor_proxy.GetMessages() messages = await self.actor_proxy.GetMessages()
return JSONResponse(content=jsonable_encoder(messages), status_code=status.HTTP_200_OK) return JSONResponse(
content=jsonable_encoder(messages), status_code=status.HTTP_200_OK
)
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve messages for {self.actor_name}: {e}") logger.error(f"Failed to retrieve messages for {self.actor_name}: {e}")
raise HTTPException(status_code=500, detail=f"Error retrieving messages: {str(e)}") raise HTTPException(
status_code=500, detail=f"Error retrieving messages: {str(e)}"
    async def broadcast_message(
        self,
        message: Union[BaseModel, dict],
        exclude_orchestrator: bool = False,
        **kwargs,
    ) -> None:
        """
        Sends a message to all agents (or only to non-orchestrator agents if exclude_orchestrator=True).
@@ -365,7 +462,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        """
        try:
            # Retrieve agents metadata while respecting the exclude_orchestrator flag
            agents_metadata = self.get_agents_metadata(
                exclude_orchestrator=exclude_orchestrator
            )

            if not agents_metadata:
                logger.warning("No agents available for broadcast.")
@@ -385,7 +484,9 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        except Exception as e:
            logger.error(f"Failed to broadcast message: {e}", exc_info=True)

    async def send_message_to_agent(
        self, name: str, message: Union[BaseModel, dict], **kwargs
    ) -> None:
        """
        Sends a message to a specific agent.
@@ -396,9 +497,11 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
        """
        try:
            agents_metadata = self.get_agents_metadata()

            if name not in agents_metadata:
                logger.warning(
                    f"Target '{name}' is not registered as an agent. Skipping message send."
                )
                return  # Do not raise an error—just warn and move on.

            agent_metadata = agents_metadata[name]
@@ -414,4 +517,6 @@ class AgentActorService(DaprPubSub, MessageRoutingMixin):
            logger.debug(f"{self.name} sent message to agent '{name}'.")
        except Exception as e:
            logger.error(
                f"Failed to send message to agent '{name}': {e}", exc_info=True
            )
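The warn-and-skip behavior above keeps one unknown target from aborting a whole broadcast loop. A framework-free sketch of that rule, assuming a plain dict as the agent registry (the `send_to_agent` helper and registry shape are illustrative, not the dapr_agents API):

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("dispatch")


def send_to_agent(registry: dict, name: str, message: dict) -> bool:
    """Return True if the message was dispatched, False if the target is unknown."""
    if name not in registry:
        # Warn and move on instead of raising, mirroring the code above.
        logger.warning("Target '%s' is not registered. Skipping message send.", name)
        return False
    registry[name].append(message)  # stand-in for the real pub/sub publish
    return True


registry = {"weather-agent": []}
assert send_to_agent(registry, "weather-agent", {"task": "forecast"}) is True
assert send_to_agent(registry, "unknown-agent", {"task": "forecast"}) is False
assert registry["weather-agent"] == [{"task": "forecast"}]
```

Because the send returns instead of raising, a broadcast over many targets degrades gracefully when one registration is stale.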

View File

@@ -1,4 +1,8 @@
from dapr_agents.memory import (
    MemoryBase,
    ConversationListMemory,
    ConversationVectorMemory,
)
from dapr_agents.agent.utils.text_printer import ColorTextFormatter
from dapr_agents.types import MessageContent, MessagePlaceHolder
from dapr_agents.tool.executor import AgentToolExecutor
@@ -14,46 +18,79 @@ import logging

logger = logging.getLogger(__name__)


class AgentBase(BaseModel, ABC):
    """
    Base class for agents that interact with language models and manage tools for task execution.
    """

    name: Optional[str] = Field(
        default=None,
        description="The agent's name, defaulting to the role if not provided.",
    )
    role: Optional[str] = Field(
        default="Assistant",
        description="The agent's role in the interaction (e.g., 'Weather Expert').",
    )
    goal: Optional[str] = Field(
        default="Help humans",
        description="The agent's main objective (e.g., 'Provide Weather information').",
    )
    instructions: Optional[List[str]] = Field(
        default=None, description="Instructions guiding the agent's tasks."
    )
    system_prompt: Optional[str] = Field(
        default=None,
        description="A custom system prompt, overriding name, role, goal, and instructions.",
    )
    llm: LLMClientBase = Field(
        default_factory=OpenAIChatClient,
        description="Language model client for generating responses.",
    )
    prompt_template: Optional[PromptTemplateBase] = Field(
        default=None, description="The prompt template for the agent."
    )
    tools: List[Union[AgentTool, Callable]] = Field(
        default_factory=list,
        description="Tools available for the agent to assist with tasks.",
    )
    max_iterations: int = Field(
        default=10, description="Max iterations for conversation cycles."
    )
    memory: MemoryBase = Field(
        default_factory=ConversationListMemory,
        description="Handles conversation history and context storage.",
    )
    template_format: Literal["f-string", "jinja2"] = Field(
        default="jinja2",
        description="The format used for rendering the prompt template.",
    )

    # Private attributes
    _tool_executor: AgentToolExecutor = PrivateAttr()
    _text_formatter: ColorTextFormatter = PrivateAttr(
        default_factory=ColorTextFormatter
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="before")
    def set_name_from_role(cls, values: dict):
        # Set name to role if name is not provided
        if not values.get("name") and values.get("role"):
            values["name"] = values["role"]
        return values
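The `mode="before"` validator above fills `name` from `role` before field validation runs. The same defaulting rule can be sketched framework-free, as a plain dict-in, dict-out function (no pydantic required; the stand-alone function name mirrors the validator but is not part of the library):

```python
def set_name_from_role(values: dict) -> dict:
    # Default name to role only when name is missing or empty.
    if not values.get("name") and values.get("role"):
        values["name"] = values["role"]
    return values


assert set_name_from_role({"role": "Assistant"})["name"] == "Assistant"
assert set_name_from_role({"name": "Ava", "role": "Assistant"})["name"] == "Ava"
assert set_name_from_role({}) == {}
```

Note the truthiness check: an explicitly empty `name` (`""`) is also replaced, not just a missing key.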
    @property
    def tool_executor(self) -> AgentToolExecutor:
        """Returns the tool executor, ensuring it's accessible but read-only."""
        return self._tool_executor

    @property
    def text_formatter(self) -> ColorTextFormatter:
        """Returns the text formatter for the agent."""
        return self._text_formatter

    @property
    def chat_history(self, task: str = None) -> List[MessageContent]:
        """
@@ -69,7 +106,7 @@ class AgentBase(BaseModel, ABC):
            query_embeddings = self.memory.vector_store.embed_documents([task])
            return self.memory.get_messages(query_embeddings=query_embeddings)
        return self.memory.get_messages()

    @abstractmethod
    def run(self, input_data: Union[str, Dict[str, Any]]) -> Any:
        """
@@ -97,7 +134,9 @@ class AgentBase(BaseModel, ABC):
        # If the agent's prompt_template is provided, use it and skip further configuration
        if self.prompt_template:
            logger.info(
                "Using the provided agent prompt_template. Skipping system prompt construction."
            )
            self.llm.prompt_template = self.prompt_template
        # If the LLM client already has a prompt template, sync it and prefill/validate as needed
@@ -112,7 +151,7 @@ class AgentBase(BaseModel, ABC):
            logger.info("Using system_prompt to create the prompt template.")
            self.prompt_template = self.construct_prompt_template()

        # Pre-fill Agent Attributes if needed
        self.prefill_agent_attributes()
@@ -145,28 +184,44 @@ class AgentBase(BaseModel, ABC):
            prefill_data["instructions"] = "\n".join(self.instructions)

        # Collect attributes set but not in input_variables for informational logging
        set_attributes = {
            "name": self.name,
            "role": self.role,
            "goal": self.goal,
            "instructions": self.instructions,
        }

        # Use Pydantic's model_fields_set to detect if attributes were user-set
        user_set_attributes = {
            attr for attr in set_attributes if attr in self.model_fields_set
        }

        ignored_attributes = [
            attr
            for attr in set_attributes
            if attr not in self.prompt_template.input_variables
            and set_attributes[attr] is not None
            and attr in user_set_attributes
        ]

        # Apply pre-filled data only for attributes that are in input_variables
        if prefill_data:
            self.prompt_template = self.prompt_template.pre_fill_variables(
                **prefill_data
            )
            logger.info(
                f"Pre-filled prompt template with attributes: {list(prefill_data.keys())}"
            )
        elif ignored_attributes:
            raise ValueError(
                f"The following agent attributes were explicitly set by the user but are not considered by the prompt template: {', '.join(ignored_attributes)}. "
                "Please ensure that these attributes are included in the prompt template's input variables if they are needed."
            )
        else:
            logger.info(
                "No agent attributes were pre-filled, as the template did not require any."
            )

    def construct_system_prompt(self) -> str:
        """
        Constructs a system prompt with agent attributes like `name`, `role`, `goal`, and `instructions`.
@@ -191,7 +246,7 @@ class AgentBase(BaseModel, ABC):
            prompt_parts.append("## Instructions\n{{instructions}}")

        return "\n\n".join(prompt_parts)
    def construct_prompt_template(self) -> ChatPromptTemplate:
        """
        Constructs a ChatPromptTemplate that includes the system prompt and a placeholder for chat history.
@@ -206,19 +261,21 @@ class AgentBase(BaseModel, ABC):
        # Create the template with placeholders for system message and chat history
        return ChatPromptTemplate.from_messages(
            messages=[
                ("system", system_prompt),
                MessagePlaceHolder(variable_name="chat_history"),
            ],
            template_format=self.template_format,
        )

    def construct_messages(
        self, input_data: Union[str, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Constructs and formats initial messages based on input type, pre-filling chat history as needed.

        Args:
            input_data (Union[str, Dict[str, Any]]): User input, either as a string or dictionary.

        Returns:
            List[Dict[str, Any]]: List of formatted messages, including the user message if input_data is a string.
        """
@@ -244,7 +301,7 @@ class AgentBase(BaseModel, ABC):
    def reset_memory(self):
        """Clears all messages stored in the agent's memory."""
        self.memory.reset_memory()

    def get_last_message(self) -> Optional[MessageContent]:
        """
        Retrieves the last message from the chat history.
@@ -254,8 +311,10 @@ class AgentBase(BaseModel, ABC):
        """
        chat_history = self.chat_history
        return chat_history[-1] if chat_history else None

    def get_last_user_message(
        self, messages: List[Dict[str, Any]]
    ) -> Optional[MessageContent]:
        """
        Retrieves the last user message in a list of messages.
@@ -272,13 +331,13 @@ class AgentBase(BaseModel, ABC):
                message["content"] = message["content"].strip()
                return message
        return None
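`get_last_user_message` walks the list from the end and returns the first `role == "user"` entry with stripped content. A plain-Python equivalent of that scan (dict messages stand in for the library's `MessageContent` type):

```python
from typing import Any, Dict, List, Optional


def get_last_user_message(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    # Scan in reverse so the most recent user message wins.
    for message in reversed(messages):
        if message.get("role") == "user":
            # Strip surrounding whitespace in place, as the method above does.
            message["content"] = message["content"].strip()
            return message
    return None


msgs = [
    {"role": "user", "content": "first"},
    {"role": "assistant", "content": "hi"},
    {"role": "user", "content": "  last  "},
]
assert get_last_user_message(msgs) == {"role": "user", "content": "last"}
assert get_last_user_message([{"role": "assistant", "content": "hi"}]) is None
```

Returning `None` rather than raising lets callers treat "no user message yet" as a normal state.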
    def pre_fill_prompt_template(self, **kwargs: Union[str, Callable[[], str]]) -> None:
        """
        Pre-fills the prompt template with specified variables, updating input variables if applicable.

        Args:
            **kwargs: Variables to pre-fill in the prompt template. These can be strings or callables
                that return strings.

        Notes:
@@ -286,7 +345,9 @@ class AgentBase(BaseModel, ABC):
            - This method does not affect the `chat_history` which is dynamically updated.
        """
        if not self.prompt_template:
            raise ValueError(
                "Prompt template must be initialized before pre-filling variables."
            )
        self.prompt_template = self.prompt_template.pre_fill_variables(**kwargs)
        logger.debug(f"Pre-filled prompt template with variables: {kwargs.keys()}")
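`pre_fill_prompt_template` accepts both strings and zero-argument callables, which lets a value be deferred until render time. A sketch of the resolve step such a design implies (the `resolve_variables` helper is an assumption for illustration, not the library's code):

```python
from typing import Callable, Dict, Union


def resolve_variables(
    variables: Dict[str, Union[str, Callable[[], str]]]
) -> Dict[str, str]:
    # Call any callable so the template only ever sees final strings.
    return {k: v() if callable(v) else v for k, v in variables.items()}


resolved = resolve_variables({"role": "Assistant", "today": lambda: "2025-06-12"})
assert resolved == {"role": "Assistant", "today": "2025-06-12"}
```

Passing a lambda instead of a string means the value is recomputed each time the variables are resolved, which is useful for timestamps or request-scoped context.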

View File

@@ -1,3 +1,3 @@
from .react import ReActAgent
from .toolcall import ToolCallAgent
from .openapi import OpenAPIReActAgent

View File

@@ -1 +1 @@
from .react import OpenAPIReActAgent

View File

@@ -8,15 +8,18 @@ import logging

logger = logging.getLogger(__name__)


class OpenAPIReActAgent(ReActAgent):
    """
    Extends ReActAgent with OpenAPI handling capabilities, including tools for managing API calls.
    """

    role: str = Field(
        default="OpenAPI Expert", description="The agent's role in the interaction."
    )
    goal: str = Field(
        default="Help users work with OpenAPI specifications and API integrations.",
        description="The main objective of the agent.",
    )
    instructions: List[str] = Field(
        default=[
@@ -25,15 +28,23 @@ class OpenAPIReActAgent(ReActAgent):
            "You must first help users explore potential APIs by analyzing OpenAPI definitions, then assist in making authenticated API requests.",
            "Ensure that all API calls are executed with the correct parameters, authentication, and methods.",
            "Your responses should be concise, clear, and focus on guiding the user through the steps of working with APIs, including retrieving API definitions, understanding endpoint parameters, and handling errors.",
            "You only respond to questions directly related to your role.",
        ],
        description="Instructions to guide the agent's behavior.",
    )
    spec_parser: OpenAPISpecParser = Field(
        ..., description="Parser for handling OpenAPI specifications."
    )
    api_vector_store: VectorStoreBase = Field(
        ..., description="Vector store for storing API definitions."
    )
    auth_header: Optional[Dict] = Field(
        None, description="Authentication headers for executing API calls."
    )
    tool_vector_store: Optional[VectorToolStore] = Field(
        default=None, init=False, description="Internal vector store for OpenAPI tools."
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -52,13 +63,14 @@ class OpenAPIReActAgent(ReActAgent):
        # Generate OpenAPI-specific tools
        from .tools import generate_api_call_executor, generate_get_openapi_definition

        openapi_tools = [
            generate_get_openapi_definition(self.tool_vector_store),
            generate_api_call_executor(self.spec_parser, self.auth_header),
        ]

        # Extend tools with OpenAPI tools
        self.tools.extend(openapi_tools)

        # Call parent model_post_init for additional setup
        super().model_post_init(__context)

View File

@@ -1,4 +1,6 @@
import json
import logging
import requests
from urllib.parse import urlparse
from typing import Any, Dict, Optional, List
@@ -42,7 +44,10 @@ def _fmt_candidate(doc: str, meta: Dict[str, Any]) -> str:
class GetDefinitionInput(BaseModel):
    """Free-form query describing *one* desired operation (e.g. "multiply two numbers")."""

    user_input: str = Field(
        ..., description="Natural-language description of ONE desired API operation."
    )


def generate_get_openapi_definition(store: VectorToolStore):
@@ -65,17 +70,29 @@ def generate_get_openapi_definition(store: VectorToolStore):
class OpenAPIExecutorInput(BaseModel):
    path_template: str = Field(
        ..., description="Path template, may contain `{placeholder}` segments."
    )
    method: str = Field(..., description="HTTP verb, uppercase.")
    path_params: Dict[str, Any] = Field(
        default_factory=dict, description="Replacements for path placeholders."
    )
    data: Dict[str, Any] = Field(
        default_factory=dict, description="JSON body for POST/PUT/PATCH."
    )
    headers: Optional[Dict[str, Any]] = Field(
        default=None, description="Extra request headers."
    )
    params: Optional[Dict[str, Any]] = Field(
        default=None, description="Query params (?key=value)."
    )

    model_config = ConfigDict(extra="allow")


def generate_api_call_executor(
    spec: OpenAPISpecParser, auth_header: Optional[Dict[str, str]] = None
):
    base_url = spec.spec.servers[0].url  # assumes at least one server entry

    @tool(args_model=OpenAPIExecutorInput)
@@ -106,29 +123,37 @@ def generate_api_call_executor(spec: OpenAPISpecParser, auth_header: Optional[Di
            final_headers.update(headers)

        # redact auth key in debug logs
        safe_hdrs = {
            k: ("***" if "auth" in k.lower() or "key" in k.lower() else v)
            for k, v in final_headers.items()
        }

        # Only convert data to JSON if we're doing a request that requires a body
        # and there's actually data to send
        body = None
        if method.upper() in ["POST", "PUT", "PATCH"] and data:
            body = json.dumps(data)

        # Add more detailed logging similar to old implementation
        logger.debug(
            "%s %s | headers=%s params=%s data=%s",
            method,
            url,
            safe_hdrs,
            params,
            "***" if body else None,
        )

        # For debugging purposes, similar to the old implementation
        print(f"Base Url: {base_url}")
        print(f"Requested Url: {url}")
        print(f"Requested Method: {method}")
        print(f"Requested Parameters: {params}")

        resp = requests.request(
            method, url, headers=final_headers, params=params, data=body, **req_kwargs
        )
        resp.raise_for_status()
        return resp.json()

    return open_api_call_executor
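The executor above masks anything that looks like a credential before logging, using a case-insensitive substring match on "auth" and "key". That redaction rule in isolation (the helper name is illustrative):

```python
def redact_headers(headers: dict) -> dict:
    # Mask values whose key hints at a credential; leave the rest readable.
    return {
        k: ("***" if "auth" in k.lower() or "key" in k.lower() else v)
        for k, v in headers.items()
    }


hdrs = {
    "Authorization": "Bearer abc",
    "X-Api-Key": "secret",
    "Accept": "application/json",
}
assert redact_headers(hdrs) == {
    "Authorization": "***",
    "X-Api-Key": "***",
    "Accept": "application/json",
}
```

Substring matching is deliberately broad: it also catches keys like `Proxy-Authorization` or `Api-Key-Id`, at the cost of occasionally masking a harmless header containing "key".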

View File

@@ -1 +1 @@
from .base import ReActAgent

View File

@@ -14,22 +14,32 @@ from dapr_agents.types import AgentError, AssistantMessage, ChatCompletion

logger = logging.getLogger(__name__)


class ReActAgent(AgentBase):
    """
    Agent implementing the ReAct (Reasoning-Action) framework for dynamic, few-shot problem-solving by leveraging
    contextual reasoning, actions, and observations in a conversation flow.
    """

    stop_at_token: List[str] = Field(
        default=["\nObservation:"],
        description="Token(s) signaling the LLM to stop generation.",
    )
    tools: List[Union[AgentTool, Callable]] = Field(
        default_factory=list,
        description="Tools available for the agent, including final_answer.",
    )
    template_format: Literal["f-string", "jinja2"] = Field(
        default="jinja2",
        description="The format used for rendering the prompt template.",
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def construct_system_prompt(self) -> str:
        """
        Constructs a system prompt in the ReAct reasoning-action format based on the agent's attributes and tools.

        Returns:
            str: The structured system message content.
        """
@@ -51,11 +61,16 @@ class ReActAgent(AgentBase):
        # Tools section with schema details
        tools_section = "## Tools\nYou have access ONLY to the following tools:\n"
        for tool in self.tools:
            tools_section += (
                f"{tool.name}: {tool.description}. Args schema: {tool.args_schema}\n"
            )
        prompt_parts.append(
            tools_section.rstrip()
        )  # Trim any trailing newlines from tools_section

        # Additional Guidelines
        additional_guidelines = textwrap.dedent(
            """
            If you think about using tool, it must use the correct tool JSON blob format as shown below:
            ```
            {
@@ -63,11 +78,13 @@ class ReActAgent(AgentBase):
                "arguments": $INPUT
            }
            ```
            """
        ).strip()
        prompt_parts.append(additional_guidelines)

        # ReAct specific guidelines
        react_guidelines = textwrap.dedent(
            """
            ## ReAct Format
            Thought: Reflect on the current state of the conversation or task. If additional information is needed, determine if using a tool is necessary. When a tool is required, briefly explain why it is needed for the specific step at hand, and immediately follow this with an `Action:` statement to address that specific requirement. Avoid combining multiple tool requests in a single `Thought`. If no tools are needed, proceed directly to an `Answer:` statement.
            Action:
@@ -81,19 +98,19 @@ class ReActAgent(AgentBase):
            ... (repeat Thought/Action/Observation as needed, but **ALWAYS proceed to a final `Answer:` statement when you have enough information**)
            Thought: I now have sufficient information to answer the initial question.
            Answer: ALWAYS proceed to a final `Answer:` statement once enough information is gathered or if the tools do not provide the necessary data.

            ### Providing a Final Answer
            Once you have enough information to answer the question OR if tools cannot provide the necessary data, respond using one of the following formats:

            1. **Direct Answer without Tools**:
            Thought: I can answer directly without using any tools. Answer: Direct answer based on previous interactions or current knowledge.

            2. **When All Needed Information is Gathered**:
            Thought: I now have sufficient information to answer the question. Answer: Complete final answer here.

            3. **If Tools Cannot Provide the Needed Information**:
            Thought: The available tools do not provide the necessary information. Answer: Explanation of limitation and relevant information if possible.

            ### Key Guidelines
            - Always Conclude with an `Answer:` statement.
            - Ensure every response ends with an `Answer:` statement that summarizes the most recent findings or relevant information, avoiding incomplete thoughts.
@@ -104,12 +121,13 @@ class ReActAgent(AgentBase):
            - Progressively Move Towards Finality: Reflect on the current step and avoid re-evaluating the entire user request each time. Aim to advance towards the final Answer in each cycle.

            ## Chat History
            The chat history is provided to avoid repeating information and to ensure accurate references when summarizing past interactions.
            """
        ).strip()
        prompt_parts.append(react_guidelines)

        return "\n\n".join(prompt_parts)
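The guidelines above instruct the model to emit each action as a JSON blob with `"name"` and `"arguments"` keys. A sketch of how such a blob could be pulled back out of a raw completion (the `parse_action` helper is hypothetical; the library's actual parsing lives inside `run`):

```python
import json


def parse_action(text: str) -> dict:
    """Extract the JSON blob following an 'Action:' marker (hypothetical helper)."""
    _, _, blob = text.partition("Action:")
    # Take everything from the first '{' to the last '}' after the marker.
    start, end = blob.index("{"), blob.rindex("}") + 1
    return json.loads(blob[start:end])


llm_output = (
    "Thought: I need the weather.\n"
    'Action:\n{"name": "get_weather", "arguments": {"city": "Paris"}}'
)
action = parse_action(llm_output)
assert action["name"] == "get_weather"
assert action["arguments"] == {"city": "Paris"}
```

First-to-last brace slicing tolerates prose around the blob but assumes exactly one JSON object per `Action:`, which is what the prompt's "one tool request per Thought" rule enforces.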
    async def run(self, input_data: Optional[Union[str, Dict[str, Any]]] = None) -> Any:
        """
        Runs the agent in a ReAct-style loop until it generates a final answer or reaches max iterations.

@@ -123,7 +141,9 @@ class ReActAgent(AgentBase):

        Raises:
            AgentError: If LLM fails or tool execution encounters issues.
        """
        logger.debug(
            f"Agent run started with input: {input_data or 'Using memory context'}"
        )

        # Format messages; construct_messages already includes chat history.
        messages = self.construct_messages(input_data or {})
@@ -132,7 +152,7 @@ class ReActAgent(AgentBase):

        # Add the new user message to memory only if input_data is provided and user message exists.
        if input_data and user_message:
            self.memory.add_message(user_message)

        # Always print the last user message for context, even if no input_data is provided
        if user_message:
            self.text_formatter.print_message(user_message)
@@ -161,8 +181,12 @@ class ReActAgent(AgentBase):

                    break
                else:
                    # Append react_loop to the last message if no user message is found
                    logger.warning(
                        "No user message found in the current messages; appending react_loop to the last message."
                    )
                    iteration_messages[-1][
                        "content"
                    ] += f"\n{react_loop}"  # Append react_loop to the last message

                try:
                    response: ChatCompletion = self.llm.generate(

@@ -179,13 +203,17 @@ class ReActAgent(AgentBase):
                        assistant_final = AssistantMessage(final_answer)
                        self.memory.add_message(assistant_final)
                        self.text_formatter.print_separator()
                        self.text_formatter.print_message(
                            assistant_final, include_separator=False
                        )
                        logger.info("Agent provided a direct final answer.")
                        return final_answer

                    # If there's no action, update the loop and continue reasoning
                    if not action:
                        logger.info(
                            "No action specified; continuing with further reasoning."
                        )
                        react_loop += f"Thought:{thought_action}\n"
                        continue  # Proceed to the next iteration

@@ -210,9 +238,10 @@ class ReActAgent(AgentBase):

            raise AgentError(f"ReActAgent failed: {e}") from e

        logger.info("Max iterations reached. Agent has stopped.")
    def parse_response(
        self, response: ChatCompletion
    ) -> Tuple[str, Optional[dict], Optional[str]]:
        """
        Parses a ReAct-style LLM response into a Thought, optional Action (JSON blob), and optional Final Answer.

@@ -225,16 +254,18 @@ class ReActAgent(AgentBase):

            - Parsed Action dictionary, if present.
            - Final Answer string, if present.
        """
        pattern = r"\{(?:[^{}]|(?R))*\}"  # Recursive pattern to match nested JSON blobs
        content = response.get_content()

        # Compile reusable regex patterns
        action_split_regex = regex.compile(r"action:\s*", flags=regex.IGNORECASE)
        final_answer_regex = regex.compile(
            r"answer:\s*(.*)", flags=regex.IGNORECASE | regex.DOTALL
        )
        thought_label_regex = regex.compile(r"thought:\s*", flags=regex.IGNORECASE)

        # Strip leading "Thought:" labels (they get repeated a lot)
        content = thought_label_regex.sub("", content).strip()

        # Check if there's a final answer present
        if final_match := final_answer_regex.search(content):
@@ -247,24 +278,34 @@ class ReActAgent(AgentBase):

            thought_part, action_block = action_split_regex.split(content, 1)
            thought_part = thought_part.strip()
            logger.debug(f"[parse_response] Thought extracted: {thought_part}")
            logger.debug(
                f"[parse_response] Action block to parse: {action_block.strip()}"
            )
        else:
            logger.debug(
                f"[parse_response] No action or answer found. Returning content as Thought: {content}"
            )
            return content, None, None

        # Attempt to extract the first valid JSON blob from the action block
        for match in regex.finditer(pattern, action_block, flags=regex.DOTALL):
            try:
                action_dict = json.loads(match.group())
                logger.debug(
                    f"[parse_response] Successfully parsed action: {action_dict}"
                )
                return thought_part, action_dict, None
            except json.JSONDecodeError as e:
                logger.debug(
                    f"[parse_response] Failed to parse action JSON blob: {match.group()}{e}"
                )
                continue

        logger.debug(
            "[parse_response] No valid action JSON found. Returning Thought only."
        )
        return thought_part, None, None
    async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
        """
        Executes a tool by name, resolving async or sync tools automatically.

@@ -284,4 +325,4 @@ class ReActAgent(AgentBase):

            return await self.tool_executor.run_tool(tool_name, *args, **kwargs)
        except Exception as e:
            logger.error(f"Failed to run tool '{tool_name}' via ReActAgent: {e}")
            raise AgentError(f"Error running tool '{tool_name}': {e}") from e
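The `parse_response` hunk above relies on the third-party `regex` module for its recursive `(?R)` JSON pattern. As a rough stdlib-only sketch of the same idea (the helper name and the exact ReAct labels here are illustrative, not the dapr_agents API), `json.JSONDecoder.raw_decode` can pull the first valid JSON blob out of an action block:

```python
import json
import re
from typing import Optional, Tuple

def parse_react_reply(content: str) -> Tuple[str, Optional[dict], Optional[str]]:
    """Split a ReAct-style reply into (thought, action_dict, final_answer)."""
    # Strip leading "Thought:" labels, which models tend to repeat.
    content = re.sub(r"thought:\s*", "", content, flags=re.IGNORECASE).strip()

    # A final answer short-circuits everything else.
    if m := re.search(r"answer:\s*(.*)", content, flags=re.IGNORECASE | re.DOTALL):
        return content, None, m.group(1).strip()

    parts = re.split(r"action:\s*", content, maxsplit=1, flags=re.IGNORECASE)
    if len(parts) == 1:
        return content, None, None  # thought only, no action

    thought, action_block = parts[0].strip(), parts[1]
    decoder = json.JSONDecoder()
    # Try to decode JSON starting at each '{' until one parses cleanly;
    # raw_decode handles arbitrarily nested braces without a recursive regex.
    for i, ch in enumerate(action_block):
        if ch == "{":
            try:
                action, _ = decoder.raw_decode(action_block[i:])
                return thought, action, None
            except json.JSONDecodeError:
                continue
    return thought, None, None
```

This trades the recursive regex for a linear scan plus the decoder's own brace matching, which keeps the sketch dependency-free.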

@@ -1 +1 @@
from .base import ToolCallAgent

@@ -6,14 +6,20 @@ import logging

logger = logging.getLogger(__name__)

class ToolCallAgent(AgentBase):
    """
    Agent that manages tool calls and conversations using a language model.
    It integrates tools and processes them based on user inputs and task orchestration.
    """

    tool_history: List[ToolMessage] = Field(
        default_factory=list, description="Executed tool calls during the conversation."
    )
    tool_choice: Optional[str] = Field(
        default=None,
        description="Strategy for selecting tools ('auto', 'required', 'none'). Defaults to 'auto' if tools are provided.",
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -22,8 +28,8 @@ class ToolCallAgent(AgentBase):

        Initialize the agent's settings, such as tool choice and parent setup.
        Sets the tool choice strategy based on provided tools.
        """
        self.tool_choice = self.tool_choice or ("auto" if self.tools else None)

        # Proceed with base model setup
        super().model_post_init(__context)
@@ -40,12 +46,14 @@ class ToolCallAgent(AgentBase):

        Raises:
            AgentError: If user input is invalid or tool execution fails.
        """
        logger.debug(
            f"Agent run started with input: {input_data if input_data else 'Using memory context'}"
        )

        # Format messages; construct_messages already includes chat history.
        messages = self.construct_messages(input_data or {})
        user_message = self.get_last_user_message(messages)
        if input_data and user_message:
            # Add the new user message to memory only if input_data is provided and user message exists
            self.memory.add_message(user_message)

@@ -56,7 +64,7 @@ class ToolCallAgent(AgentBase):

        # Process conversation iterations
        return await self.process_iterations(messages)
    async def process_response(self, tool_calls: List[dict]) -> None:
        """
        Asynchronously executes tool calls and appends tool results to memory.

@@ -70,15 +78,21 @@ class ToolCallAgent(AgentBase):

        for tool in tool_calls:
            function_name = tool.function.name
            try:
                logger.info(
                    f"Executing {function_name} with arguments {tool.function.arguments}"
                )
                result = await self.tool_executor.run_tool(
                    function_name, **tool.function.arguments_dict
                )
                tool_message = ToolMessage(
                    tool_call_id=tool.id, name=function_name, content=str(result)
                )
                self.text_formatter.print_message(tool_message)
                self.tool_history.append(tool_message)
            except Exception as e:
                logger.error(f"Error executing tool {function_name}: {e}")
                raise AgentError(f"Error executing tool '{function_name}': {e}") from e

    async def process_iterations(self, messages: List[Dict[str, Any]]) -> Any:
        """
        Iteratively drives the agent conversation until a final answer or max iterations.

@@ -118,7 +132,7 @@ class ToolCallAgent(AgentBase):

            raise AgentError(f"Failed during chat generation: {e}") from e

        logger.info("Max iterations reached. Agent has stopped.")

    async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
        """
        Executes a registered tool by name, automatically handling sync or async tools.

@@ -138,4 +152,4 @@ class ToolCallAgent(AgentBase):

            return await self.tool_executor.run_tool(tool_name, *args, **kwargs)
        except Exception as e:
            logger.error(f"Agent failed to run tool '{tool_name}': {e}")
            raise AgentError(f"Failed to run tool '{tool_name}': {e}") from e
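The `process_response` hunk above resolves each tool call by name and records the result as a tool message. A minimal sketch of that dispatch pattern, with a plain dict standing in for the agent's tool executor (all names here are illustrative, not the dapr_agents API):

```python
import asyncio
import inspect
from typing import Any, Callable, Dict, List

async def run_tool_calls(
    registry: Dict[str, Callable[..., Any]], calls: List[dict]
) -> List[dict]:
    """Execute tool calls by name; sync and async tools are both supported."""
    results = []
    for call in calls:
        fn = registry.get(call["name"])
        if fn is None:
            raise ValueError(f"Unknown tool: {call['name']}")
        out = fn(**call.get("arguments", {}))
        if inspect.isawaitable(out):
            out = await out  # resolve async tools transparently
        # Mirror the ToolMessage shape: call id, tool name, stringified content.
        results.append(
            {"tool_call_id": call["id"], "name": call["name"], "content": str(out)}
        )
    return results

def add(a: int, b: int) -> int:
    return a + b

async def greet(name: str) -> str:
    return f"hello {name}"

messages = asyncio.run(
    run_tool_calls(
        {"add": add, "greet": greet},
        [
            {"id": "1", "name": "add", "arguments": {"a": 2, "b": 3}},
            {"id": "2", "name": "greet", "arguments": {"name": "dapr"}},
        ],
    )
)
```

Stringifying results before appending them mirrors the `content=str(result)` step above, so downstream chat messages stay plain text.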

@@ -0,0 +1 @@
from .otel import DaprAgentsOTel

@@ -0,0 +1,144 @@
from logging import Logger
from typing import Union

from opentelemetry._logs import set_logger_provider
from opentelemetry.metrics import set_meter_provider
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource, SERVICE_NAME
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import set_tracer_provider
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

class DaprAgentsOTel:
    """
    OpenTelemetry configuration for Dapr agents.
    """

    def __init__(self, service_name: str = "", otlp_endpoint: str = ""):
        # Configure OpenTelemetry
        self.service_name = service_name
        self.otlp_endpoint = otlp_endpoint
        self.setup_resources()
    def setup_resources(self):
        """
        Set up the resource for OpenTelemetry.
        """
        self._resource = Resource.create(
            attributes={
                SERVICE_NAME: str(self.service_name),
            }
        )

    def create_and_instrument_meter_provider(
        self,
        otlp_endpoint: str = "",
    ) -> MeterProvider:
        """
        Returns a `MeterProvider` that is configured to export metrics using the `PeriodicExportingMetricReader`,
        which means that metrics are exported periodically in the background. The interval can be set by
        the environment variable `OTEL_METRIC_EXPORT_INTERVAL`. The default value is 60000 ms (1 minute).

        Also sets the global OpenTelemetry meter provider to the returned meter provider.
        """
        # Ensure the endpoint is set correctly
        endpoint = self._endpoint_validator(
            endpoint=self.otlp_endpoint if otlp_endpoint == "" else otlp_endpoint,
            telemetry_type="metrics",
        )

        metric_exporter = OTLPMetricExporter(endpoint=str(endpoint))
        metric_reader = PeriodicExportingMetricReader(metric_exporter)
        meter_provider = MeterProvider(
            resource=self._resource, metric_readers=[metric_reader]
        )
        set_meter_provider(meter_provider)
        return meter_provider
    def create_and_instrument_tracer_provider(
        self,
        otlp_endpoint: str = "",
    ) -> TracerProvider:
        """
        Returns a `TracerProvider` that is configured to export traces using the `BatchSpanProcessor`,
        which means that traces are exported in batches. The batch size can be set by
        the environment variable `OTEL_TRACES_EXPORT_BATCH_SIZE`. The default value is 512.

        Also sets the global OpenTelemetry tracer provider to the returned tracer provider.
        """
        # Ensure the endpoint is set correctly
        endpoint = self._endpoint_validator(
            endpoint=self.otlp_endpoint if otlp_endpoint == "" else otlp_endpoint,
            telemetry_type="traces",
        )

        trace_exporter = OTLPSpanExporter(endpoint=str(endpoint))
        tracer_processor = BatchSpanProcessor(trace_exporter)
        tracer_provider = TracerProvider(resource=self._resource)
        tracer_provider.add_span_processor(tracer_processor)
        set_tracer_provider(tracer_provider)
        return tracer_provider
    def create_and_instrument_logging_provider(
        self,
        logger: Logger,
        otlp_endpoint: str = "",
    ) -> LoggerProvider:
        """
        Returns a `LoggerProvider` that is configured to export logs using the `BatchLogRecordProcessor`,
        which means that logs are exported in batches. The batch size can be set by
        the environment variable `OTEL_LOGS_EXPORT_BATCH_SIZE`. The default value is 512.

        Also sets the global OpenTelemetry logger provider to the returned logging provider.
        """
        # Ensure the endpoint is set correctly
        endpoint = self._endpoint_validator(
            endpoint=self.otlp_endpoint if otlp_endpoint == "" else otlp_endpoint,
            telemetry_type="logs",
        )

        log_exporter = OTLPLogExporter(endpoint=str(endpoint))
        logging_provider = LoggerProvider(resource=self._resource)
        logging_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
        set_logger_provider(logging_provider)

        handler = LoggingHandler(logger_provider=logging_provider)
        logger.addHandler(handler)
        return logging_provider
    def _endpoint_validator(
        self,
        endpoint: str,
        telemetry_type: str,
    ) -> str:
        """
        Validates and normalizes the endpoint for the given telemetry type.
        """
        if endpoint == "":
            raise ValueError(
                "OTLP endpoint must be set either in the environment variable OTEL_EXPORTER_OTLP_ENDPOINT or in the constructor."
            )
        if endpoint.startswith("https://"):
            raise NotImplementedError(
                "OTLP over HTTPS is not supported. Please use HTTP."
            )
        endpoint = (
            endpoint
            if endpoint.endswith(f"/v1/{telemetry_type}")
            else f"{endpoint}/v1/{telemetry_type}"
        )
        endpoint = endpoint if endpoint.startswith("http://") else f"http://{endpoint}"
        return endpoint
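The `_endpoint_validator` above normalizes whatever the caller passes into a full OTLP/HTTP URL for one telemetry signal. The same normalization, restated as a standalone function for clarity (a sketch of the logic shown above, not a public API):

```python
def normalize_otlp_endpoint(endpoint: str, telemetry_type: str) -> str:
    """Normalize an endpoint into http://host[:port]/v1/<telemetry_type> form."""
    if endpoint == "":
        raise ValueError("OTLP endpoint must be set.")
    if endpoint.startswith("https://"):
        raise NotImplementedError("OTLP over HTTPS is not supported. Please use HTTP.")
    # Append the signal-specific path only if it is not already present.
    if not endpoint.endswith(f"/v1/{telemetry_type}"):
        endpoint = f"{endpoint}/v1/{telemetry_type}"
    # Default to plain HTTP when no scheme was given.
    if not endpoint.startswith("http://"):
        endpoint = f"http://{endpoint}"
    return endpoint
```

So a bare `localhost:4318` becomes `http://localhost:4318/v1/traces` for the trace exporter, while an already-complete URL passes through unchanged.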


@@ -1,28 +1,33 @@

import requests
import os

def construct_auth_headers(auth_url, grant_type="client_credentials", **kwargs):
    """
    Construct authorization headers for API requests.

    :param auth_url: The authorization URL.
    :param grant_type: The type of OAuth grant (default is 'client_credentials').
    :param kwargs: Additional parameters for the POST request body.
    :return: A dictionary containing the Authorization header.
    """
    # Define default parameters based on the grant_type
    data = {
        "grant_type": grant_type,
    }

    # Defaults for client_credentials grant type
    if grant_type == "client_credentials":
        data.update(
            {
                "client_id": kwargs.get("client_id", os.getenv("CLIENT_ID")),
                "client_secret": kwargs.get(
                    "client_secret", os.getenv("CLIENT_SECRET")
                ),
            }
        )

    # Add any additional data passed in kwargs
    data.update(kwargs)
@@ -37,9 +42,9 @@ def construct_auth_headers(auth_url, grant_type="client_credentials", **kwargs):

    auth_response_data = auth_response.json()

    # Extract the access token
    access_token = auth_response_data.get("access_token")
    if not access_token:
        raise ValueError("No access token found in the response")

    return {"Authorization": f"Bearer {access_token}"}
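The body-construction half of `construct_auth_headers` can be exercised without any network call. A sketch of that step in isolation (hypothetical helper name; it uses `pop` instead of `get` so explicit kwargs are consumed once, a small deviation from the code above):

```python
import os

def build_token_request_body(grant_type: str = "client_credentials", **kwargs) -> dict:
    """Build the POST body for an OAuth token request (no network involved)."""
    data = {"grant_type": grant_type}
    if grant_type == "client_credentials":
        # Explicit kwargs win; otherwise fall back to environment variables.
        data["client_id"] = kwargs.pop("client_id", os.getenv("CLIENT_ID"))
        data["client_secret"] = kwargs.pop("client_secret", os.getenv("CLIENT_SECRET"))
    # Any remaining kwargs (scope, audience, ...) pass straight through.
    data.update(kwargs)
    return data

os.environ["CLIENT_ID"] = "env-id"  # illustrative values only
body = build_token_request_body(client_secret="s3cret", scope="read")
```

POSTing `body` to the auth URL and reading `access_token` from the JSON response then yields the `Authorization: Bearer ...` header, as in the function above.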


@@ -8,16 +8,18 @@ from dapr_agents.memory import MemoryBase

from dapr_agents.tool import AgentTool
from typing import Optional, List, Union, Type, TypeVar

T = TypeVar("T", ToolCallAgent, ReActAgent, OpenAPIReActAgent)

class AgentFactory:
    """
    Returns agent classes based on the provided pattern.
    """

    AGENT_PATTERNS = {
        "react": ReActAgent,
        "toolcalling": ToolCallAgent,
        "openapireact": OpenAPIReActAgent,
    }

    @staticmethod
@@ -54,7 +56,7 @@ class Agent(AgentBase):

        llm: Optional[LLMClientBase] = None,
        memory: Optional[MemoryBase] = None,
        tools: Optional[List[AgentTool]] = [],
        **kwargs,
    ) -> Union[ToolCallAgent, ReActAgent, OpenAPIReActAgent]:
        """
        Creates and returns an instance of the selected agent class.

@@ -77,11 +79,21 @@ class Agent(AgentBase):

        memory = memory or ConversationListMemory()

        if pattern == "openapireact":
            kwargs.update(
                {
                    "spec_parser": kwargs.get("spec_parser", OpenAPISpecParser()),
                    "auth_header": kwargs.get("auth_header", {}),
                }
            )

        instance = super().__new__(agent_class)
        agent_class.__init__(
            instance,
            role=role,
            name=name,
            llm=llm,
            memory=memory,
            tools=tools,
            **kwargs,
        )
        return instance
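The `Agent` class above picks a concrete agent type from a pattern-to-class registry at construction time. The registry idea in isolation, with toy classes standing in for the real agent types (all names below are illustrative):

```python
from typing import Dict, Type

class AgentBase:
    def __init__(self, role: str = "assistant"):
        self.role = role

class ReActAgent(AgentBase): ...
class ToolCallAgent(AgentBase): ...

# Pattern registry: string key -> concrete class, as in AGENT_PATTERNS above.
AGENT_PATTERNS: Dict[str, Type[AgentBase]] = {
    "react": ReActAgent,
    "toolcalling": ToolCallAgent,
}

def create_agent(pattern: str, **kwargs) -> AgentBase:
    """Look up the concrete class by pattern and instantiate it."""
    try:
        agent_class = AGENT_PATTERNS[pattern]
    except KeyError:
        raise ValueError(f"Unknown agent pattern: {pattern}") from None
    return agent_class(**kwargs)

agent = create_agent("react", role="Weather Expert")
```

A plain factory function is usually enough; the `__new__` route above additionally lets `Agent(...)` itself return the selected subclass instance.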


@@ -2,6 +2,7 @@ from dapr_agents.types import BaseMessage

from typing import List
from pydantic import ValidationError

def messages_to_string(messages: List[BaseMessage]) -> str:
    """
    Converts messages into a single string with roles and content.

@@ -36,4 +37,4 @@ def messages_to_string(messages: List[BaseMessage]) -> str:

        except (ValidationError, ValueError) as e:
            raise ValueError(f"Invalid message in chat history. Error: {e}")

    return "\n".join(formatted_history)
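A rough sketch of the flattening that `messages_to_string` performs, using plain dicts in place of `BaseMessage` (the role/content shape and the `Role: content` output format are assumptions for illustration):

```python
from typing import Dict, List

def messages_to_string(messages: List[Dict[str, str]]) -> str:
    """Render chat history as 'Role: content' lines, one message per line."""
    formatted_history = []
    for msg in messages:
        if "role" not in msg or "content" not in msg:
            raise ValueError(f"Invalid message in chat history: {msg}")
        formatted_history.append(f"{msg['role'].capitalize()}: {msg['content']}")
    return "\n".join(formatted_history)

history = messages_to_string(
    [
        {"role": "user", "content": "What is Dapr?"},
        {"role": "assistant", "content": "A portable runtime for building distributed apps."},
    ]
)
```

Flattening history like this is what lets an agent inject prior turns into a single prompt string.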


@@ -4,14 +4,15 @@ from colorama import Style

# Define your custom colors as a dictionary
COLORS = {
    "dapr_agents_teal": "\033[38;2;147;191;183m",
    "dapr_agents_mustard": "\033[38;2;242;182;128m",
    "dapr_agents_red": "\033[38;2;217;95;118m",
    "dapr_agents_pink": "\033[38;2;191;69;126m",
    "dapr_agents_purple": "\033[38;2;146;94;130m",
    "reset": Style.RESET_ALL,
}

class ColorTextFormatter:
    """
    A flexible text formatter class to print colored text dynamically.

@@ -40,7 +41,7 @@ class ColorTextFormatter:

        """
        color_code = COLORS.get(color, self.default_color)
        return f"{color_code}{text}{COLORS['reset']}"
    def print_colored_text(self, text_blocks: list[tuple[str, Optional[str]]]):
        """
        Print multiple blocks of text in specified colors dynamically, ensuring that newlines

@@ -55,9 +56,9 @@ class ColorTextFormatter:

        for i, line in enumerate(lines):
            formatted_line = self.format_text(line, color)
            print(formatted_line, end="\n" if i < len(lines) - 1 else "")

        print(COLORS["reset"])  # Ensure terminal color is reset at the end

    def print_separator(self):
        """
        Prints a separator line.

@@ -65,13 +66,17 @@ class ColorTextFormatter:

        separator = "-" * 80
        self.print_colored_text([(f"\n{separator}\n", "reset")])

    def print_message(
        self,
        message: Union[BaseMessage, Dict[str, Any]],
        include_separator: bool = True,
    ):
        """
        Prints messages with colored formatting based on the role and message content.

        Args:
            message (Union[BaseMessage, Dict[str, Any]]): The message content, either as a BaseMessage object or
                a dictionary. If a BaseMessage is provided, it will be
                converted to a dictionary using its `model_dump` method.
            include_separator (bool): Whether to include a separator line after the message. Defaults to True.
        """
@@ -86,14 +91,14 @@ class ColorTextFormatter:

        formatted_role = f"{name}({role})" if name else role
        content = message.get("content", "")

        color_map = {
            "user": "dapr_agents_mustard",
            "assistant": "dapr_agents_teal",
            "tool_calls": "dapr_agents_red",
            "tool": "dapr_agents_pink",
        }

        # Handle tool calls
        if "tool_calls" in message and message["tool_calls"]:
            tool_calls = message["tool_calls"]

@@ -103,13 +108,16 @@ class ColorTextFormatter:

                tool_id = tool_call["id"]
                tool_call_text = [
                    (f"{formatted_role}:\n", color_map["tool_calls"]),
                    (
                        f"Function name: {function_name} (Call Id: {tool_id})\n",
                        color_map["tool_calls"],
                    ),
                    (f"Arguments: {arguments}", color_map["tool_calls"]),
                ]
                self.print_colored_text(tool_call_text)
                if include_separator:
                    self.print_separator()

        elif role == "tool":
            # Handle tool messages
            tool_call_id = message.get("tool_call_id", "Unknown")

@@ -130,7 +138,7 @@ class ColorTextFormatter:

            self.print_colored_text(regular_message_text)
            if include_separator:
                self.print_separator()

    def print_react_part(self, part_type: str, content: str):
        """
        Prints a part of the ReAct loop (Thought, Action, Observation) with the corresponding color.

@@ -142,11 +150,11 @@ class ColorTextFormatter:

        color_map = {
            "Thought": "dapr_agents_red",
            "Action": "dapr_agents_pink",
            "Observation": "dapr_agents_purple",
        }

        # Get the color for the part type, defaulting to reset if not found
        color = color_map.get(part_type, "reset")

        # Print the part with the specified color
        self.print_colored_text([(f"{part_type}: {content}", color)])
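The `COLORS` table above uses 24-bit ANSI escape sequences of the form `\033[38;2;R;G;Bm`. A minimal stand-alone version of the formatting step, without colorama (only one of the palette entries is reproduced, for illustration):

```python
RESET = "\033[0m"  # equivalent of colorama's Style.RESET_ALL

def truecolor(r: int, g: int, b: int) -> str:
    """Build a 24-bit (truecolor) foreground escape sequence."""
    return f"\033[38;2;{r};{g};{b}m"

COLORS = {
    "dapr_agents_teal": truecolor(147, 191, 183),
    "reset": RESET,
}

def format_text(text: str, color: str) -> str:
    """Wrap text in a color code and always reset afterwards."""
    return f"{COLORS.get(color, RESET)}{text}{COLORS['reset']}"

line = format_text("Thought: check the forecast", "dapr_agents_teal")
```

Always appending the reset code, as `format_text` does here and above, is what keeps a partially printed line from bleeding color into subsequent terminal output.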

@@ -1,4 +1,4 @@
from .fetcher import ArxivFetcher
from .reader import PyMuPDFReader, PyPDFReader
from .splitter import TextSplitter
from .embedder import OpenAIEmbedder, SentenceTransformerEmbedder, NVIDIAEmbedder

@@ -1,3 +1,3 @@
from .openai import OpenAIEmbedder
from .sentence import SentenceTransformerEmbedder
from .nvidia import NVIDIAEmbedder

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod

from pydantic import BaseModel
from typing import List, Any

class EmbedderBase(BaseModel, ABC):
    """
    Abstract base class for Embedders.

@@ -19,4 +20,4 @@ class EmbedderBase(BaseModel, ABC):

        Returns:
            List[Any]: A list of results.
        """
        pass

View File

@@ -7,6 +7,7 @@ import logging
logger = logging.getLogger(__name__)

class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
    """
    NVIDIA-based embedder for generating text embeddings with support for indexing (passage) and querying.
@@ -17,10 +18,16 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        normalize (bool): Whether to normalize embeddings. Defaults to True.
    """
-    chunk_size: int = Field(default=1000, description="Batch size for embedding requests.")
-    normalize: bool = Field(default=True, description="Whether to normalize embeddings.")
+    chunk_size: int = Field(
+        default=1000, description="Batch size for embedding requests."
+    )
+    normalize: bool = Field(
+        default=True, description="Whether to normalize embeddings."
+    )

-    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) for indexing with default input_type set to 'passage'.
@@ -37,7 +44,9 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        """
        return self._generate_embeddings(input, input_type="passage")

-    def embed_query(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed_query(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) for querying with input_type set to 'query'.
@@ -54,7 +63,9 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        """
        return self._generate_embeddings(input, input_type="query")

-    def _generate_embeddings(self, input: Union[str, List[str]], input_type: str) -> Union[List[float], List[List[float]]]:
+    def _generate_embeddings(
+        self, input: Union[str, List[str]], input_type: str
+    ) -> Union[List[float], List[List[float]]]:
        """
        Helper function to generate embeddings for given input text(s) with specified input_type.
@@ -75,14 +86,15 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        # Process input in chunks for efficiency
        chunk_embeddings = []
        for i in range(0, len(input_list), self.chunk_size):
-            batch = input_list[i:i + self.chunk_size]
+            batch = input_list[i : i + self.chunk_size]
            response = self.create_embedding(input=batch, input_type=input_type)
            chunk_embeddings.extend(r.embedding for r in response.data)

        # Normalize embeddings if required
        if self.normalize:
            normalized_embeddings = [
-                (embedding / np.linalg.norm(embedding)).tolist() for embedding in chunk_embeddings
+                (embedding / np.linalg.norm(embedding)).tolist()
+                for embedding in chunk_embeddings
            ]
        else:
            normalized_embeddings = chunk_embeddings
@@ -90,7 +102,9 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        # Return a single embedding if the input was a single string; otherwise, return a list
        return normalized_embeddings[0] if single_input else normalized_embeddings

-    def __call__(self, input: Union[str, List[str]], query: bool = False) -> Union[List[float], List[List[float]]]:
+    def __call__(
+        self, input: Union[str, List[str]], query: bool = False
+    ) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).
@@ -103,4 +117,4 @@ class NVIDIAEmbedder(NVIDIAEmbeddingClient, EmbedderBase):
        """
        if query:
            return self.embed_query(input)
        return self.embed(input)
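The batching and optional L2 normalization in `_generate_embeddings` can be sketched in isolation. This is a simplified stand-alone version, not the embedder itself: `batch_and_normalize` and `fake_embed` are hypothetical names, and `fake_embed` stands in for the real `create_embedding` API call.

```python
import numpy as np

def batch_and_normalize(texts, embed_fn, chunk_size=2, normalize=True):
    """Embed texts in batches of `chunk_size`, optionally L2-normalizing each vector."""
    embeddings = []
    for i in range(0, len(texts), chunk_size):
        batch = texts[i : i + chunk_size]  # same slice pattern as the embedder
        embeddings.extend(embed_fn(t) for t in batch)
    if normalize:
        # Divide each vector by its L2 norm so all results have unit length.
        return [(np.asarray(e) / np.linalg.norm(e)).tolist() for e in embeddings]
    return embeddings

# Hypothetical stand-in for the real embedding API call.
def fake_embed(text):
    return [float(len(text)), 1.0]

vecs = batch_and_normalize(["a", "bb", "ccc"], fake_embed)
```

Normalizing at index time means cosine similarity at query time reduces to a plain dot product, which is why the embedder defaults `normalize` to True.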

View File

@@ -7,20 +7,31 @@ import logging
logger = logging.getLogger(__name__)

class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
    """
    OpenAI-based embedder for generating text embeddings with handling for long inputs.
    Inherits functionality from OpenAIEmbeddingClient for API interactions.
    """
-    max_tokens: int = Field(default=8191, description="Maximum tokens allowed per input.")
-    chunk_size: int = Field(default=1000, description="Batch size for embedding requests.")
-    normalize: bool = Field(default=True, description="Whether to normalize embeddings.")
-    encoding_name: Optional[str] = Field(default=None, description="Token encoding name (if provided).")
-    encoder: Optional[Any] = Field(default=None, init=False, description="TikToken Encoder")
+    max_tokens: int = Field(
+        default=8191, description="Maximum tokens allowed per input."
+    )
+    chunk_size: int = Field(
+        default=1000, description="Batch size for embedding requests."
+    )
+    normalize: bool = Field(
+        default=True, description="Whether to normalize embeddings."
+    )
+    encoding_name: Optional[str] = Field(
+        default=None, description="Token encoding name (if provided)."
+    )
+    encoder: Optional[Any] = Field(
+        default=None, init=False, description="TikToken Encoder"
+    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_post_init(self, __context: Any) -> None:
        """
        Initialize attributes after model validation.
@@ -59,9 +70,13 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
    def _chunk_tokens(self, tokens: List[int], chunk_length: int) -> List[List[int]]:
        """Splits tokens into chunks of the specified length."""
-        return [tokens[i:i + chunk_length] for i in range(0, len(tokens), chunk_length)]
+        return [
+            tokens[i : i + chunk_length] for i in range(0, len(tokens), chunk_length)
+        ]

-    def _process_embeddings(self, embeddings: List[List[float]], weights: List[int]) -> List[float]:
+    def _process_embeddings(
+        self, embeddings: List[List[float]], weights: List[int]
+    ) -> List[float]:
        """Combines embeddings using weighted averaging."""
        weighted_avg = np.average(embeddings, axis=0, weights=weights)
        if self.normalize:
@@ -69,7 +84,9 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
            return (weighted_avg / norm).tolist()
        return weighted_avg.tolist()

-    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Embeds input text(s) with support for both single and multiple inputs, handling long texts via chunking and batching.
@@ -116,7 +133,7 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
        chunk_embeddings = []  # Holds embeddings for all chunks
        for i in range(0, len(chunks), batch_size):
-            batch = chunks[i:i + batch_size]
+            batch = chunks[i : i + batch_size]
            response = self.create_embedding(input=batch)  # Batch API call
            chunk_embeddings.extend(r.embedding for r in response.data)
@@ -133,19 +150,23 @@ class OpenAIEmbedder(OpenAIEmbeddingClient, EmbedderBase):
            results.append(embeddings[0])
        else:
            # Combine chunk embeddings using weighted averaging
-            weights = [len(chunk) for chunk in self._chunk_tokens(tokens, self.max_tokens)]
+            weights = [
+                len(chunk) for chunk in self._chunk_tokens(tokens, self.max_tokens)
+            ]
            results.append(self._process_embeddings(embeddings, weights))

        # Return a single embedding if the input was a single string; otherwise, return a list
        return results[0] if single_input else results

-    def __call__(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def __call__(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).
        Args:
            input (Union[str, List[str]]): The input text(s) to embed.
        Returns:
            Union[List[float], List[List[float]]]: Embedding vector(s) for the input(s).
        """
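The long-input strategy above, splitting a text that exceeds `max_tokens` into chunks and recombining the per-chunk vectors by token-count-weighted averaging, can be illustrated with `_process_embeddings` reduced to a free function (the name `combine_chunk_embeddings` below is illustrative, not part of the library):

```python
import numpy as np

def combine_chunk_embeddings(embeddings, weights, normalize=True):
    """Weighted average of per-chunk embeddings; weights are chunk token counts,
    so longer chunks contribute proportionally more to the final vector."""
    weighted_avg = np.average(embeddings, axis=0, weights=weights)
    if normalize:
        norm = np.linalg.norm(weighted_avg)
        return (weighted_avg / norm).tolist()
    return weighted_avg.tolist()

# Two chunk embeddings; the first chunk had 3 tokens, the second only 1.
combined = combine_chunk_embeddings([[1.0, 0.0], [0.0, 1.0]], weights=[3, 1])
```

The weighted average keeps the combined vector closer to the longer chunk's embedding, which is the intended behavior when one chunk carries most of the document's content.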

View File

@@ -6,19 +6,33 @@ import os
logger = logging.getLogger(__name__)

class SentenceTransformerEmbedder(EmbedderBase):
    """
    SentenceTransformer-based embedder for generating text embeddings.
    Supports multi-process encoding for large datasets.
    """
-    model: str = Field(default="all-MiniLM-L6-v2", description="Name of the SentenceTransformer model to use.")
-    device: Literal["cpu", "cuda", "mps", "npu"] = Field(default="cpu", description="Device for computation.")
-    normalize_embeddings: bool = Field(default=False, description="Whether to normalize embeddings.")
-    multi_process: bool = Field(default=False, description="Whether to use multi-process encoding.")
-    cache_dir: Optional[str] = Field(default=None, description="Directory to cache or load the model.")
-    client: Optional[Any] = Field(default=None, init=False, description="Loaded SentenceTransformer model.")
+    model: str = Field(
+        default="all-MiniLM-L6-v2",
+        description="Name of the SentenceTransformer model to use.",
+    )
+    device: Literal["cpu", "cuda", "mps", "npu"] = Field(
+        default="cpu", description="Device for computation."
+    )
+    normalize_embeddings: bool = Field(
+        default=False, description="Whether to normalize embeddings."
+    )
+    multi_process: bool = Field(
+        default=False, description="Whether to use multi-process encoding."
+    )
+    cache_dir: Optional[str] = Field(
+        default=None, description="Directory to cache or load the model."
+    )
+    client: Optional[Any] = Field(
+        default=None, init=False, description="Loaded SentenceTransformer model."
+    )

    def model_post_init(self, __context: Any) -> None:
        """
@@ -35,26 +49,40 @@ class SentenceTransformerEmbedder(EmbedderBase):
        )
        # Determine whether to load from cache or download
-        model_path = self.cache_dir if self.cache_dir and os.path.exists(self.cache_dir) else self.model
+        model_path = (
+            self.cache_dir
+            if self.cache_dir and os.path.exists(self.cache_dir)
+            else self.model
+        )

        # Attempt to load the model
        try:
            if os.path.exists(model_path):
-                logger.info(f"Loading SentenceTransformer model from local path: {model_path}")
+                logger.info(
+                    f"Loading SentenceTransformer model from local path: {model_path}"
+                )
            else:
                logger.info(f"Downloading SentenceTransformer model: {self.model}")
                if self.cache_dir:
                    logger.info(f"Model will be cached to: {self.cache_dir}")
-            self.client: SentenceTransformer = SentenceTransformer(model_name_or_path=model_path, device=self.device)
+            self.client: SentenceTransformer = SentenceTransformer(
+                model_name_or_path=model_path, device=self.device
+            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer model: {e}")
            raise

        # Save to cache directory if downloaded
-        if model_path == self.model and self.cache_dir and not os.path.exists(self.cache_dir):
+        if (
+            model_path == self.model
+            and self.cache_dir
+            and not os.path.exists(self.cache_dir)
+        ):
            logger.info(f"Saving the downloaded model to: {self.cache_dir}")
            self.client.save(self.cache_dir)

-    def embed(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def embed(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Generate embeddings for input text(s).
@@ -82,7 +110,7 @@ class SentenceTransformerEmbedder(EmbedderBase):
                embeddings = self.client.encode_multi_process(
                    input_strings,
                    pool=pool,
-                    normalize_embeddings=self.normalize_embeddings
+                    normalize_embeddings=self.normalize_embeddings,
                )
            finally:
                logger.info("Stopping multi-process pool.")
@@ -91,14 +119,16 @@ class SentenceTransformerEmbedder(EmbedderBase):
            embeddings = self.client.encode(
                input_strings,
                convert_to_numpy=True,
-                normalize_embeddings=self.normalize_embeddings
+                normalize_embeddings=self.normalize_embeddings,
            )

        if single_input:
            return embeddings[0].tolist()
        return embeddings.tolist()

-    def __call__(self, input: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+    def __call__(
+        self, input: Union[str, List[str]]
+    ) -> Union[List[float], List[List[float]]]:
        """
        Allows the instance to be called directly to embed text(s).
@@ -108,4 +138,4 @@ class SentenceTransformerEmbedder(EmbedderBase):
        Returns:
            Union[List[float], List[List[float]]]: Embedding vector(s) for the input(s).
        """
        return self.embed(input)

View File

@@ -1 +1 @@
from .arxiv import ArxivFetcher

View File

@@ -16,7 +16,7 @@ class ArxivFetcher(FetcherBase):
    max_results: int = 10
    include_full_metadata: bool = False

    def search(
        self,
        query: str,
@@ -25,7 +25,7 @@ class ArxivFetcher(FetcherBase):
        download: bool = False,
        dirpath: Path = Path("./"),
        include_summary: bool = False,
-        **kwargs
+        **kwargs,
    ) -> Union[List[Dict], List["Document"]]:
        """
        Search for papers on arXiv and optionally download them.
@@ -64,12 +64,14 @@ class ArxivFetcher(FetcherBase):
                "The `arxiv` library is required to use the ArxivFetcher. "
                "Install it with `pip install arxiv`."
            )

        logger.info(f"Searching for query: {query}")

        # Enforce that both from_date and to_date are provided if one is specified
        if (from_date and not to_date) or (to_date and not from_date):
-            raise ValueError("Both 'from_date' and 'to_date' must be specified if one is provided.")
+            raise ValueError(
+                "Both 'from_date' and 'to_date' must be specified if one is provided."
+            )

        # Add date filter if both from_date and to_date are provided
        if from_date and to_date:
@@ -94,7 +96,7 @@ class ArxivFetcher(FetcherBase):
        content_id: str,
        download: bool = False,
        dirpath: Path = Path("./"),
-        include_summary: bool = False
+        include_summary: bool = False,
    ) -> Union[Optional[Dict], Optional[Document]]:
        """
        Search for a specific paper by its arXiv ID and optionally download it.
@@ -124,7 +126,7 @@ class ArxivFetcher(FetcherBase):
                "The `arxiv` library is required to use the ArxivFetcher. "
                "Install it with `pip install arxiv`."
            )

        logger.info(f"Searching for paper by ID: {content_id}")
        try:
            search = arxiv.Search(id_list=[content_id])
@@ -133,17 +135,15 @@ class ArxivFetcher(FetcherBase):
                logger.warning(f"No result found for ID: {content_id}")
                return None
-            return self._process_results([result], download, dirpath, include_summary)[0]
+            return self._process_results([result], download, dirpath, include_summary)[
+                0
+            ]
        except Exception as e:
            logger.error(f"Error fetching result for ID {content_id}: {e}")
            return None

    def _process_results(
-        self,
-        results: List[Any],
-        download: bool,
-        dirpath: Path,
-        include_summary: bool
+        self, results: List[Any], download: bool, dirpath: Path, include_summary: bool
    ) -> Union[List[Dict], List["Document"]]:
        """
        Process arXiv search results.
@@ -162,16 +162,22 @@ class ArxivFetcher(FetcherBase):
            metadata_list = []
            for result in results:
                file_path = self._download_result(result, dirpath)
-                metadata_list.append(self._format_result_metadata(result, file_path=file_path, include_summary=include_summary))
+                metadata_list.append(
+                    self._format_result_metadata(
+                        result, file_path=file_path, include_summary=include_summary
+                    )
+                )
            return metadata_list
        else:
            documents = []
            for result in results:
-                metadata = self._format_result_metadata(result, include_summary=include_summary)
+                metadata = self._format_result_metadata(
+                    result, include_summary=include_summary
+                )
                text = result.summary.strip()
                documents.append(Document(text=text, metadata=metadata))
            return documents

    def _download_result(self, result: Any, dirpath: Path) -> Optional[str]:
        """
        Download a paper from an arXiv result object.
@@ -194,7 +200,12 @@ class ArxivFetcher(FetcherBase):
            logger.error(f"Failed to download paper {result.title}: {e}")
            return None

-    def _format_result_metadata(self, result: Any, file_path: Optional[str] = None, include_summary: bool = False) -> Dict:
+    def _format_result_metadata(
+        self,
+        result: Any,
+        file_path: Optional[str] = None,
+        include_summary: bool = False,
+    ) -> Dict:
        """
        Format metadata from an arXiv result, optionally including file path and summary.
@@ -219,24 +230,26 @@ class ArxivFetcher(FetcherBase):
        }
        if self.include_full_metadata:
-            metadata.update({
-                "links": result.links,
-                "authors_comment": result.comment,
-                "DOI": result.doi,
-                "journal_reference": result.journal_ref,
-            })
+            metadata.update(
+                {
+                    "links": result.links,
+                    "authors_comment": result.comment,
+                    "DOI": result.doi,
+                    "journal_reference": result.journal_ref,
+                }
+            )
        if include_summary:
            metadata["summary"] = result.summary.strip()
        return {key: value for key, value in metadata.items() if value is not None}

    def _format_date(self, date: Union[str, datetime]) -> str:
        """
        Format a date into the 'YYYYMMDDHHMM' format required by the arXiv API.
        Args:
            date (Union[str, datetime]): The date to format. Can be a string in 'YYYYMMDD' or
                'YYYYMMDDHHMM' format, or a datetime object.
        Returns:
@@ -262,7 +275,9 @@ class ArxivFetcher(FetcherBase):
        if isinstance(date, str):
            # Check if the string matches the basic format
            if not re.fullmatch(r"^\d{8}(\d{4})?$", date):
-                raise ValueError(f"Invalid date format: {date}. Use 'YYYYMMDD' or 'YYYYMMDDHHMM'.")
+                raise ValueError(
+                    f"Invalid date format: {date}. Use 'YYYYMMDD' or 'YYYYMMDDHHMM'."
+                )

        # Validate that it is a real date
        try:
@@ -277,4 +292,6 @@ class ArxivFetcher(FetcherBase):
        elif isinstance(date, datetime):
            return date.strftime("%Y%m%d%H%M")
        else:
-            raise ValueError("Invalid date input. Provide a string in 'YYYYMMDD', 'YYYYMMDDHHMM' format, or a datetime object.")
+            raise ValueError(
+                "Invalid date input. Provide a string in 'YYYYMMDD', 'YYYYMMDDHHMM' format, or a datetime object."
+            )
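The `_format_date` validation flow (regex shape check, then a real-calendar check, then normalization to the 12-digit arXiv timestamp) can be sketched as a free function. This is a simplified approximation: the function name is hypothetical, and the zero-padding of 8-digit dates to midnight is an assumption, since the middle of the original method is elided in this diff.

```python
import re
from datetime import datetime

def format_arxiv_date(date):
    """Normalize a date to the 'YYYYMMDDHHMM' form the arXiv API expects."""
    if isinstance(date, str):
        # Shape check first: 8 digits, optionally followed by 4 more.
        if not re.fullmatch(r"^\d{8}(\d{4})?$", date):
            raise ValueError(
                f"Invalid date format: {date}. Use 'YYYYMMDD' or 'YYYYMMDDHHMM'."
            )
        # Then verify it is a real calendar date/time.
        fmt = "%Y%m%d" if len(date) == 8 else "%Y%m%d%H%M"
        datetime.strptime(date, fmt)
        # Assumption: pad date-only input to midnight (HHMM = 0000).
        return date.ljust(12, "0")
    elif isinstance(date, datetime):
        return date.strftime("%Y%m%d%H%M")
    raise ValueError("Provide 'YYYYMMDD', 'YYYYMMDDHHMM', or a datetime object.")
```

Doing the cheap regex check before `strptime` gives callers a clearer error message for malformed strings than the generic parse failure would.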

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List, Any

class FetcherBase(BaseModel, ABC):
    """
    Abstract base class for fetchers.
@@ -19,4 +20,4 @@ class FetcherBase(BaseModel, ABC):
        Returns:
            List[Any]: A list of results.
        """
        pass

View File

@@ -1 +1 @@
from .pdf import PyMuPDFReader, PyPDFReader

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
from pathlib import Path
from typing import List

class ReaderBase(BaseModel, ABC):
    """
    Abstract base class for file readers.
@@ -20,4 +21,4 @@ class ReaderBase(BaseModel, ABC):
        Returns:
            List[Document]: A list of Document objects.
        """
        pass

View File

@@ -1,2 +1,2 @@
from .pymupdf import PyMuPDFReader
from .pypdf import PyPDFReader

View File

@@ -9,7 +9,9 @@ class PyMuPDFReader(ReaderBase):
    Reader for PDF documents using PyMuPDF.
    """

-    def load(self, file_path: Path, additional_metadata: Optional[Dict] = None) -> List[Document]:
+    def load(
+        self, file_path: Path, additional_metadata: Optional[Dict] = None
+    ) -> List[Document]:
        """
        Load content from a PDF file using PyMuPDF.
@@ -45,4 +47,4 @@ class PyMuPDFReader(ReaderBase):
            documents.append(Document(text=text.strip(), metadata=metadata))
        doc.close()
        return documents

View File

@@ -9,7 +9,9 @@ class PyPDFReader(ReaderBase):
    Reader for PDF documents using PyPDF.
    """

-    def load(self, file_path: Path, additional_metadata: Optional[Dict] = None) -> List[Document]:
+    def load(
+        self, file_path: Path, additional_metadata: Optional[Dict] = None
+    ) -> List[Document]:
        """
        Load content from a PDF file using PyPDF.
@@ -43,4 +45,4 @@ class PyPDFReader(ReaderBase):
            documents.append(Document(text=text.strip(), metadata=metadata))
        return documents

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import List
from pydantic import Field

class TextLoader(ReaderBase):
    """
    Loader for plain text files.
@@ -11,6 +12,7 @@ class TextLoader(ReaderBase):
    Attributes:
        encoding (str): The text file encoding. Defaults to 'utf-8'.
    """

    encoding: str = Field(default="utf-8", description="Encoding of the text file.")

    def load(self, file_path: Path) -> List[Document]:
@@ -31,4 +33,4 @@ class TextLoader(ReaderBase):
            "file_path": str(file_path),
            "file_type": "text",
        }
        return [Document(text=content, metadata=metadata)]

View File

@@ -1,2 +1,2 @@
from .base import SplitterBase
from .text import TextSplitter

View File

@ -7,6 +7,7 @@ import logging
try: try:
from nltk.tokenize import sent_tokenize from nltk.tokenize import sent_tokenize
NLTK_AVAILABLE = True NLTK_AVAILABLE = True
except ImportError: except ImportError:
sent_tokenize = None sent_tokenize = None
@ -14,20 +15,42 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SplitterBase(BaseModel, ABC): class SplitterBase(BaseModel, ABC):
""" """
Base class for defining text splitting strategies. Base class for defining text splitting strategies.
Provides common utilities for breaking text into smaller chunks Provides common utilities for breaking text into smaller chunks
based on separators, regex patterns, or sentence-based splitting. based on separators, regex patterns, or sentence-based splitting.
""" """
chunk_size: int = Field(default=4000, description="Maximum size of chunks (in characters or tokens).", gt=0) chunk_size: int = Field(
chunk_overlap: int = Field(default=200, description="Overlap size between chunks for context continuity.", ge=0) default=4000,
chunk_size_function: Callable[[str], int] = Field(default=len, description="Function to calculate chunk size (e.g., by characters or tokens).") description="Maximum size of chunks (in characters or tokens).",
separator: Optional[str] = Field(default="\n\n", description="Primary separator for splitting text.") gt=0,
fallback_separators: List[str] = Field(default_factory=lambda: ["\n", " "], description="Fallback separators if the primary separator fails.") )
fallback_regex: str = Field(default=r"[^,.;。?!]+[,.;。?!]", description="Improved regex pattern for fallback splitting.") chunk_overlap: int = Field(
reserved_metadata_size: int = Field(default=0, description="Tokens reserved for metadata.", ge=0) default=200,
description="Overlap size between chunks for context continuity.",
ge=0,
)
chunk_size_function: Callable[[str], int] = Field(
default=len,
description="Function to calculate chunk size (e.g., by characters or tokens).",
)
separator: Optional[str] = Field(
default="\n\n", description="Primary separator for splitting text."
)
fallback_separators: List[str] = Field(
default_factory=lambda: ["\n", " "],
description="Fallback separators if the primary separator fails.",
)
fallback_regex: str = Field(
default=r"[^,.;。?!]+[,.;。?!]",
description="Improved regex pattern for fallback splitting.",
)
reserved_metadata_size: int = Field(
default=0, description="Tokens reserved for metadata.", ge=0
)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@ -55,12 +78,12 @@ class SplitterBase(BaseModel, ABC):
int: The size of the text chunk. int: The size of the text chunk.
""" """
return self.chunk_size_function(text) return self.chunk_size_function(text)
def _merge_splits(self, splits: List[str], max_size: int) -> List[str]: def _merge_splits(self, splits: List[str], max_size: int) -> List[str]:
""" """
Merge splits into chunks while ensuring size constraints and meaningful overlaps. Merge splits into chunks while ensuring size constraints and meaningful overlaps.
Unlike other implementations, this method prioritizes sentence boundaries Unlike other implementations, this method prioritizes sentence boundaries
when creating overlaps, ensuring that each chunk remains contextually meaningful. when creating overlaps, ensuring that each chunk remains contextually meaningful.
Args: Args:
@ -88,7 +111,9 @@ class SplitterBase(BaseModel, ABC):
chunks.append(full_chunk) chunks.append(full_chunk)
# Logging information for overlap and chunk size # Logging information for overlap and chunk size
logger.debug(f"Chunk {len(chunks)} finalized. Size: {current_size}. Overlap size: {self.chunk_overlap}") logger.debug(
f"Chunk {len(chunks)} finalized. Size: {current_size}. Overlap size: {self.chunk_overlap}"
)
# Create an overlap using sentences from the current chunk # Create an overlap using sentences from the current chunk
overlap = [] overlap = []
@ -122,12 +147,12 @@ class SplitterBase(BaseModel, ABC):
        logger.debug(f"Chunk {len(chunks)} finalized. Size: {current_size}.")
        return chunks

    def _split_by_separators(self, text: str, separators: List[str]) -> List[str]:
        """
        Split text using a prioritized list of separators while keeping separators in chunks.

        For each separator in the provided list, attempt to split the text. The separator
        is appended to each split except the last one to preserve structure.

        Args:
@@ -164,7 +189,7 @@ class SplitterBase(BaseModel, ABC):
        if NLTK_AVAILABLE:
            return sent_tokenize(text)
        return self._regex_split(text)

    def _regex_split(self, text: str) -> List[str]:
        """
        Split text using the fallback regex, retaining separators.
@@ -200,7 +225,7 @@ class SplitterBase(BaseModel, ABC):
        chunks = self._split_by_sentences(text)
        return chunks

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """
        Split documents into smaller chunks while retaining metadata.
@@ -223,14 +248,16 @@ class SplitterBase(BaseModel, ABC):
                end_index = start_index + self._get_chunk_size(chunk)
                metadata = doc.metadata.copy() if doc.metadata else {}
                metadata.update(
                    {
                        "chunk_number": chunk_num + 1,
                        "total_chunks": len(text_chunks),
                        "start_index": start_index,
                        "end_index": end_index,
                        "chunk_length": self._get_chunk_size(chunk),
                    }
                )
                chunked_documents.append(Document(metadata=metadata, text=chunk))
                previous_end = end_index

        return chunked_documents
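The metadata bookkeeping in `split_documents` can be exercised standalone. A minimal sketch (the `Doc` dataclass and `annotate_chunks` helper are hypothetical stand-ins for `Document` and the real method, and chunk sizes are plain character counts rather than `_get_chunk_size`), assuming chunks are laid out contiguously so each chunk starts where the previous one ended:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    text: str
    metadata: Dict = field(default_factory=dict)


def annotate_chunks(doc: Doc, chunks: List[str]) -> List[Doc]:
    """Attach positional metadata to each chunk, mirroring split_documents."""
    out: List[Doc] = []
    previous_end = 0
    for chunk_num, chunk in enumerate(chunks):
        start_index = previous_end
        end_index = start_index + len(chunk)
        meta = dict(doc.metadata)  # keep the parent document's metadata
        meta.update(
            {
                "chunk_number": chunk_num + 1,
                "total_chunks": len(chunks),
                "start_index": start_index,
                "end_index": end_index,
                "chunk_length": len(chunk),
            }
        )
        out.append(Doc(text=chunk, metadata=meta))
        previous_end = end_index
    return out
```
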

@@ -4,6 +4,7 @@ import logging

logger = logging.getLogger(__name__)


class TextSplitter(SplitterBase):
    """
    Concrete implementation of the SplitterBase class.
@@ -33,7 +34,9 @@ class TextSplitter(SplitterBase):
        # Step 2: Short-circuit for small texts
        if self._get_chunk_size(text) <= effective_chunk_size:
            logger.debug(
                "Text size is smaller than effective chunk size. Returning as a single chunk."
            )
            return [text]

        # Step 3: Use adaptive splitting strategy
@@ -44,4 +47,4 @@ class TextSplitter(SplitterBase):
        merged_chunks = self._merge_splits(chunks, effective_chunk_size)
        logger.debug(f"Merged into {len(merged_chunks)} chunks with overlap.")
        return merged_chunks
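The control flow above, short-circuit small inputs, split adaptively, then merge pieces back up to the chunk size, can be sketched without the class machinery. A self-contained approximation (paragraph splitting only, greedy merging, and no overlap, unlike the real `_merge_splits`):

```python
from typing import List


def split_text(text: str, chunk_size: int) -> List[str]:
    """Split on paragraph boundaries, then greedily merge pieces up to chunk_size."""
    if len(text) <= chunk_size:
        # Short-circuit: small texts are returned as a single chunk.
        return [text]

    pieces = [p for p in text.split("\n\n") if p]
    merged: List[str] = []
    current = ""
    for piece in pieces:
        candidate = f"{current}\n\n{piece}" if current else piece
        if len(candidate) <= chunk_size or not current:
            # Keep accumulating (an oversized single piece is kept whole).
            current = candidate
        else:
            merged.append(current)
            current = piece
    if current:
        merged.append(current)
    return merged
```
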

@@ -1,3 +1,3 @@
from .base import CodeExecutorBase
from .local import LocalCodeExecutor
from .docker import DockerCodeExecutor

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List, ClassVar


class CodeExecutorBase(BaseModel, ABC):
    """Abstract base class for executing code in different environments."""
@@ -18,4 +19,4 @@ class CodeExecutorBase(BaseModel, ABC):
        for snippet in snippets:
            if snippet.language not in self.SUPPORTED_LANGUAGES:
                raise ValueError(f"Unsupported language: {snippet.language}")
        return True
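`validate_snippets` is the contract every executor inherits. A minimal standalone sketch of that check (the `Snippet` dataclass and the `SUPPORTED_LANGUAGES` values here are assumptions for illustration, not the library's actual types):

```python
from dataclasses import dataclass
from typing import ClassVar, List


@dataclass
class Snippet:
    language: str
    code: str


class Executor:
    # Assumed supported set; the real class defines its own ClassVar.
    SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["python", "sh", "bash"]

    def validate_snippets(self, snippets: List[Snippet]) -> bool:
        """Reject any snippet whose language the executor cannot run."""
        for snippet in snippets:
            if snippet.language not in self.SUPPORTED_LANGUAGES:
                raise ValueError(f"Unsupported language: {snippet.language}")
        return True
```
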

@@ -11,26 +11,57 @@ import os

logger = logging.getLogger(__name__)


class DockerCodeExecutor(CodeExecutorBase):
    """Executes code securely inside a persistent Docker container with dynamic volume updates."""

    image: Optional[str] = Field(
        "python:3.9", description="Docker image used for execution."
    )
    container_name: Optional[str] = Field(
        "dapr_agents_code_executor", description="Name of the Docker container."
    )
    disable_network_access: bool = Field(
        default=True, description="Disable network access inside the container."
    )
    execution_timeout: int = Field(
        default=60, description="Max execution time (seconds)."
    )
    execution_mode: str = Field(
        "detached", description="Execution mode: 'interactive' or 'detached'."
    )
    restart_policy: str = Field(
        "no", description="Container restart policy: 'no', 'on-failure', 'always'."
    )
    max_memory: str = Field("500m", description="Max memory for execution.")
    cpu_quota: int = Field(50000, description="CPU quota limit.")
    runtime: Optional[str] = Field(
        default=None, description="Container runtime (e.g., 'nvidia')."
    )
    auto_remove: bool = Field(
        default=False, description="Keep container running to reuse it."
    )
    auto_cleanup: bool = Field(
        default=False,
        description="Automatically clean up the workspace after execution.",
    )
    volume_access_mode: Literal["ro", "rw"] = Field(
        default="ro", description="Access mode for the workspace volume."
    )
    host_workspace: Optional[str] = Field(
        default=None,
        description="Custom workspace on host. If None, defaults to system temp dir.",
    )

    docker_client: Optional[Any] = Field(
        default=None, init=False, description="Docker client instance."
    )
    execution_container: Optional[Any] = Field(
        default=None, init=False, description="Persistent Docker container."
    )
    container_workspace: Optional[str] = Field(
        default="/workspace", init=False, description="Mounted workspace in container."
    )

    def model_post_init(self, __context: Any) -> None:
        """Initializes the Docker client and ensures a reusable execution container is ready."""
@@ -38,7 +69,9 @@ class DockerCodeExecutor(CodeExecutorBase):
            from docker import DockerClient
            from docker.errors import DockerException
        except ImportError as e:
            raise ImportError(
                "Install 'docker' package with 'pip install docker'."
            ) from e

        try:
            self.docker_client: DockerClient = DockerClient.from_env()
@@ -47,9 +80,13 @@ class DockerCodeExecutor(CodeExecutorBase):

        # Validate or Set the Host Workspace
        if self.host_workspace:
            self.host_workspace = os.path.abspath(
                self.host_workspace
            )  # Ensure absolute path
        else:
            self.host_workspace = os.path.join(
                tempfile.gettempdir(), "dapr_agents_executor_workspace"
            )

        # Ensure the directory exists
        os.makedirs(self.host_workspace, exist_ok=True)
@@ -66,10 +103,14 @@ class DockerCodeExecutor(CodeExecutorBase):
        try:
            from docker.errors import NotFound
        except ImportError as e:
            raise ImportError(
                "Install 'docker' package with 'pip install docker'."
            ) from e

        try:
            self.execution_container = self.docker_client.containers.get(
                self.container_name
            )
            logger.info(f"Reusing existing container: {self.container_name}")
        except NotFound:
            logger.info(f"Creating a new container: {self.container_name}")
@@ -82,7 +123,9 @@ class DockerCodeExecutor(CodeExecutorBase):
        try:
            from docker.errors import DockerException, APIError
        except ImportError as e:
            raise ImportError(
                "Install 'docker' package with 'pip install docker'."
            ) from e

        try:
            self.execution_container = self.docker_client.containers.create(
                self.image,
@@ -99,13 +142,22 @@ class DockerCodeExecutor(CodeExecutorBase):
                restart_policy={"Name": self.restart_policy},
                runtime=self.runtime,
                working_dir=self.container_workspace,
                volumes={
                    self.host_workspace: {
                        "bind": self.container_workspace,
                        "mode": self.volume_access_mode,
                    }
                },
            )
        except (DockerException, APIError) as e:
            logger.error(f"Failed to create the execution container: {str(e)}")
            raise RuntimeError(
                f"Failed to create the execution container: {str(e)}"
            ) from e

    async def execute(
        self, request: Union[ExecutionRequest, dict]
    ) -> List[ExecutionResult]:
        """
        Executes code inside the persistent Docker container.
        The code is written to a shared volume instead of stopping & starting the container.
@@ -127,7 +179,9 @@ class DockerCodeExecutor(CodeExecutorBase):
                if snippet.language == "python":
                    required_packages = self._extract_imports(snippet.code)
                    if required_packages:
                        logger.info(
                            f"Installing missing dependencies: {required_packages}"
                        )
                        await self._install_missing_packages(required_packages)

                script_filename = f"script.{snippet.language}"
@@ -138,17 +192,24 @@ class DockerCodeExecutor(CodeExecutorBase):
                with open(script_path_host, "w", encoding="utf-8") as script_file:
                    script_file.write(snippet.code)

                cmd = (
                    f"timeout {self.execution_timeout} python3 {script_path_container}"
                    if snippet.language == "python"
                    else f"timeout {self.execution_timeout} sh {script_path_container}"
                )

                # Run command dynamically inside the running container
                exec_result = await asyncio.to_thread(
                    self.execution_container.exec_run, cmd
                )
                exit_code = exec_result.exit_code
                logs = exec_result.output.decode("utf-8", errors="ignore").strip()

                status = "success" if exit_code == 0 else "error"
                results.append(
                    ExecutionResult(status=status, output=logs, exit_code=exit_code)
                )

        except Exception as e:
            logs = self.get_container_logs()
@@ -159,8 +220,10 @@ class DockerCodeExecutor(CodeExecutorBase):
            if self.auto_cleanup:
                if os.path.exists(self.host_workspace):
                    shutil.rmtree(self.host_workspace, ignore_errors=True)
                    logger.info(
                        f"Temporary workspace {self.host_workspace} cleaned up."
                    )

            if self.auto_remove:
                self.execution_container.stop()
                logger.info(f"Container {self.execution_container.id} stopped.")
@@ -178,7 +241,7 @@ class DockerCodeExecutor(CodeExecutorBase):
            List[str]: A list of unique top-level module names imported in the script.

        Raises:
            SyntaxError: If the provided code is not valid Python, an error is logged,
                and an empty list is returned.
        """
        try:
@@ -191,9 +254,9 @@ class DockerCodeExecutor(CodeExecutorBase):
        for node in ast.walk(parsed_code):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    modules.add(alias.name.split(".")[0])
            elif isinstance(node, ast.ImportFrom) and node.module:
                modules.add(node.module.split(".")[0])

        return list(modules)
@@ -231,8 +294,10 @@ class DockerCodeExecutor(CodeExecutorBase):
            Exception: If log retrieval fails, an error message is logged.
        """
        try:
            logs = self.execution_container.logs(stdout=True, stderr=True).decode(
                "utf-8"
            )
            return logs
        except Exception as e:
            logger.error(f"Failed to retrieve container logs: {str(e)}")
            return ""
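The `_extract_imports` helper shown in this file is easy to lift into a standalone function for experimentation, the same AST walk over `Import` and `ImportFrom` nodes, minus the logging:

```python
import ast
from typing import List


def extract_imports(code: str) -> List[str]:
    """Return unique top-level module names imported by *code* (sorted)."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Invalid Python: report no imports rather than raising.
        return []

    modules = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                modules.add(alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom) and node.module:
            modules.add(node.module.split(".")[0])
    return sorted(modules)
```
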

@@ -1,227 +1,332 @@
"""Local executor that runs Python or shell snippets in cached virtual-envs."""

import asyncio
import ast
import hashlib
import inspect
import logging
import time
import venv
from pathlib import Path
from typing import Any, Callable, List, Sequence, Union

from pydantic import Field, PrivateAttr

from dapr_agents.executors import CodeExecutorBase
from dapr_agents.executors.sandbox import detect_backend, wrap_command, SandboxType
from dapr_agents.executors.utils.package_manager import (
    get_install_command,
    get_project_type,
)
from dapr_agents.types.executor import ExecutionRequest, ExecutionResult

logger = logging.getLogger(__name__)


class LocalCodeExecutor(CodeExecutorBase):
    """
    Run snippets locally with **optional OS-level sandboxing** and
    per-snippet virtual-env caching.
    """

    cache_dir: Path = Field(
        default_factory=lambda: Path.cwd() / ".dapr_agents_cached_envs",
        description="Directory that stores cached virtual environments.",
    )
    user_functions: List[Callable] = Field(
        default_factory=list,
        description="Functions whose source is prepended to every Python snippet.",
    )
    sandbox: SandboxType = Field(
        default="auto",
        description="'seatbelt' | 'firejail' | 'none' | 'auto' (best available)",
    )
    writable_paths: List[Path] = Field(
        default_factory=list,
        description="Extra paths the sandboxed process may write to.",
    )
    cleanup_threshold: int = Field(
        default=604_800,  # one week
        description="Seconds before a cached venv is considered stale.",
    )

    _env_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
    _bootstrapped_root: Path | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:  # noqa: D401
        """Create ``cache_dir`` after pydantic instantiation."""
        super().model_post_init(__context)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.debug("venv cache directory: %s", self.cache_dir)

    async def execute(
        self, request: Union[ExecutionRequest, dict]
    ) -> List[ExecutionResult]:
        """
        Run the snippets in *request* and return their results.

        Args:
            request: ``ExecutionRequest`` instance or a raw mapping that can
                be unpacked into one.

        Returns:
            A list with one ``ExecutionResult`` for every snippet in the
            original request.
        """
        if isinstance(request, dict):
            request = ExecutionRequest(**request)

        await self._bootstrap_project()
        self.validate_snippets(request.snippets)

        # Resolve sandbox once
        eff_backend: SandboxType = (
            detect_backend() if self.sandbox == "auto" else self.sandbox
        )
        if eff_backend != "none":
            logger.info(
                "Sandbox backend enabled: %s%s",
                eff_backend,
                f" (writable: {', '.join(map(str, self.writable_paths))})"
                if self.writable_paths
                else "",
            )
        else:
            logger.info("Sandbox disabled - running commands directly.")

        # Main loop
        results: list[ExecutionResult] = []
        for snip_idx, snippet in enumerate(request.snippets, start=1):
            start = time.perf_counter()

            # Assemble the *raw* command
            if snippet.language == "python":
                env = await self._prepare_python_env(snippet.code)
                python_bin = env / "bin" / "python3"
                prelude = "\n".join(inspect.getsource(fn) for fn in self.user_functions)
                script = f"{prelude}\n{snippet.code}" if prelude else snippet.code
                raw_cmd: Sequence[str] = [str(python_bin), "-c", script]
            else:
                raw_cmd = ["sh", "-c", snippet.code]

            # Wrap for sandbox
            final_cmd = wrap_command(raw_cmd, eff_backend, self.writable_paths)
            logger.debug(
                "Snippet %s - launch command: %s",
                snip_idx,
                " ".join(final_cmd),
            )

            # Run it
            snip_timeout = getattr(snippet, "timeout", request.timeout)
            results.append(await self._run_subprocess(final_cmd, snip_timeout))

            logger.info(
                "Snippet %s finished in %.3fs",
                snip_idx,
                time.perf_counter() - start,
            )

        return results

    async def _bootstrap_project(self) -> None:
        """Install top-level dependencies once per executor instance."""
        cwd = Path.cwd().resolve()
        if self._bootstrapped_root == cwd:
            return

        install_cmd = get_install_command(str(cwd))
        if install_cmd:
            logger.info(
                "bootstrapping %s project with '%s'",
                get_project_type(str(cwd)).value,
                install_cmd,
            )
            proc = await asyncio.create_subprocess_shell(
                install_cmd,
                cwd=cwd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            _, err = await proc.communicate()
            if proc.returncode:
                logger.warning(
                    "bootstrap failed (%d): %s", proc.returncode, err.decode().strip()
                )
        self._bootstrapped_root = cwd

    async def _prepare_python_env(self, code: str) -> Path:
        """
        Ensure a virtual-env exists that satisfies *code* imports.

        Args:
            code: User-supplied Python source.

        Returns:
            Path to the virtual-env directory.
        """
        imports = self._extract_imports(code)
        env = await self._get_or_create_cached_env(imports)
        missing = await self._get_missing_packages(imports, env)
        if missing:
            await self._install_missing_packages(missing, env)
        return env

    @staticmethod
    def _extract_imports(code: str) -> List[str]:
        """
        Return all top-level imported module names in *code*.

        Args:
            code: Python source to scan.

        Returns:
            Unique list of first-segment module names.

        Raises:
            SyntaxError: If *code* cannot be parsed.
        """
        try:
            tree = ast.parse(code)
        except SyntaxError:
            logger.error("cannot parse user code, assuming no imports")
            return []

        names = {
            alias.name.partition(".")[0]
            for node in ast.walk(tree)
            for alias in getattr(node, "names", [])
            if isinstance(node, (ast.Import, ast.ImportFrom))
        }
        if any(
            isinstance(node, ast.ImportFrom) and node.module for node in ast.walk(tree)
        ):
            names |= {
                node.module.partition(".")[0]
                for node in ast.walk(tree)
                if isinstance(node, ast.ImportFrom) and node.module
            }
        return sorted(names)

    async def _get_missing_packages(
        self, packages: List[str], env_path: Path
    ) -> List[str]:
        """
        Identify which *packages* are not importable from *env_path*.

        Args:
            packages: Candidate import names.
            env_path: Path to the virtual-env.

        Returns:
            Subset of *packages* that need installation.
        """
        python = env_path / "bin" / "python3"

        async def probe(pkg: str) -> str | None:
            proc = await asyncio.create_subprocess_exec(
                str(python),
                "- <<PY\nimport importlib.util, sys;"
                f"sys.exit(importlib.util.find_spec('{pkg}') is None)\nPY",
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.DEVNULL,
            )
            await proc.wait()
            return pkg if proc.returncode else None

        missing = await asyncio.gather(*(probe(p) for p in packages))
        return [m for m in missing if m]

    async def _get_or_create_cached_env(self, deps: List[str]) -> Path:
        """
        Return a cached venv path keyed by the sorted list *deps*.

        Args:
            deps: Import names required by user code.

        Returns:
            Path to the virtual-env directory.

        Raises:
            RuntimeError: If venv creation fails.
        """
        digest = hashlib.sha1(",".join(sorted(deps)).encode()).hexdigest()
        env_path = self.cache_dir / f"env_{digest}"

        async with self._env_lock:
            if env_path.exists():
                logger.info("Reusing cached virtual environment.")
            else:
                try:
                    venv.create(env_path, with_pip=True)
                    logger.info("Created a new virtual environment")
                    logger.debug("venv %s created", env_path)
                except Exception as exc:  # noqa: BLE001
                    raise RuntimeError("virtual-env creation failed") from exc
        return env_path

    async def _install_missing_packages(
        self, packages: List[str], env_dir: Path
    ) -> None:
        """
        ``pip install`` *packages* inside *env_dir*.

        Args:
            packages: Package names to install.
            env_dir: Target virtual-env directory.

        Raises:
            RuntimeError: If installation returns non-zero exit code.
        """
        python = env_dir / "bin" / "python3"
        cmd = [str(python), "-m", "pip", "install", *packages]
        logger.info("Installing %s", ", ".join(packages))

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, err = await proc.communicate()
        if proc.returncode != 0:
            msg = err.decode().strip()
            logger.error("pip install failed: %s", msg)
            raise RuntimeError(msg)
        logger.debug("Installed %d package(s)", len(packages))

    async def _run_subprocess(
        self, cmd: Sequence[str], timeout: int
    ) -> ExecutionResult:
        """
        Run *cmd* with *timeout* seconds.

        Args:
            cmd: Command list to execute.
            timeout: Maximum runtime in seconds.

        Returns:
            ``ExecutionResult`` with captured output.
        """
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )
            out, err = await asyncio.wait_for(proc.communicate(), timeout)
            status = "success" if proc.returncode == 0 else "error"
            if err:
                logger.debug("stderr: %s", err.decode().strip())
            return ExecutionResult(
                status=status, output=out.decode(), exit_code=proc.returncode
            )
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            return ExecutionResult(
                status="error", output="execution timed out", exit_code=1
            )
        except Exception as exc:
            return ExecutionResult(status="error", output=str(exc), exit_code=1)
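The venv cache key used by `_get_or_create_cached_env`, a SHA-1 digest over the sorted dependency list, can be checked in isolation. The `cached_env_path` helper below is a hypothetical extraction of just that logic; the point is that the same dependency set always maps to the same cached directory, regardless of order:

```python
import hashlib
from pathlib import Path
from typing import List


def cached_env_path(cache_dir: Path, deps: List[str]) -> Path:
    """Same dependency set -> same digest -> same cached venv directory."""
    digest = hashlib.sha1(",".join(sorted(deps)).encode()).hexdigest()
    return cache_dir / f"env_{digest}"
```
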

@@ -0,0 +1,200 @@
"""Light-weight cross-platform sandbox helpers."""

import platform
import shutil
from pathlib import Path
from typing import List, Literal, Sequence

SandboxType = Literal["none", "seatbelt", "firejail", "auto"]

_READ_ONLY_SEATBELT_POLICY = r"""
(version 1)

; ---------------- default = deny everything -----------------
(deny default)

; ---------------- read-only FS access -----------------------
(allow file-read*)

; ---------------- minimal process mgmt ----------------------
(allow process-exec)
(allow process-fork)
(allow signal (target self))

; ---------------- write-only to /dev/null -------------------
(allow file-write-data
  (require-all
    (path "/dev/null")
    (vnode-type CHARACTER-DEVICE)))

; ---------------- harmless sysctls --------------------------
(allow sysctl-read
  (sysctl-name "hw.activecpu")
  (sysctl-name "hw.busfrequency_compat")
  (sysctl-name "hw.byteorder")
  (sysctl-name "hw.cacheconfig")
  (sysctl-name "hw.cachelinesize_compat")
  (sysctl-name "hw.cpufamily")
  (sysctl-name "hw.cpufrequency_compat")
  (sysctl-name "hw.cputype")
  (sysctl-name "hw.l1dcachesize_compat")
  (sysctl-name "hw.l1icachesize_compat")
  (sysctl-name "hw.l2cachesize_compat")
  (sysctl-name "hw.l3cachesize_compat")
  (sysctl-name "hw.logicalcpu_max")
  (sysctl-name "hw.machine")
  (sysctl-name "hw.ncpu")
  (sysctl-name "hw.nperflevels")
  (sysctl-name "hw.memsize")
  (sysctl-name "hw.pagesize")
  (sysctl-name "hw.packages")
  (sysctl-name "hw.physicalcpu_max")
  (sysctl-name "kern.hostname")
  (sysctl-name "kern.osrelease")
  (sysctl-name "kern.ostype")
  (sysctl-name "kern.osversion")
  (sysctl-name "kern.version")
  (sysctl-name-prefix "hw.perflevel")
)
"""


def detect_backend() -> SandboxType:  # noqa: D401
    """Return the best-effort sandbox backend for the current host."""
    system = platform.system()
    if system == "Darwin" and shutil.which("sandbox-exec"):
        return "seatbelt"
    if system == "Linux" and shutil.which("firejail"):
        return "firejail"
    return "none"
def _seatbelt_cmd(cmd: Sequence[str], writable_paths: List[Path]) -> List[str]:
"""
Construct a **macOS seatbelt** command line.
The resulting list can be passed directly to `asyncio.create_subprocess_exec`.
It launches the target *cmd* under **sandbox-exec** with an
*initially-read-only* profile; every directory in *writable_paths* is added
as an explicit write-allowed sub-path.
Args:
cmd:
The *raw* command (program + args) that should run inside the sandbox.
writable_paths:
Absolute paths that the child process must be able to modify
(e.g. a temporary working directory).
Each entry becomes a param `-D WR<i>=<path>` and a corresponding
``file-write*`` rule in the generated profile.
Returns:
list[str]
A fully-assembled ``sandbox-exec`` invocation:
``['sandbox-exec', '-p', <profile>, *params, '--', *cmd]``.
"""
policy = _READ_ONLY_SEATBELT_POLICY
params: list[str] = []
if writable_paths:
# Build parameter substitutions and the matching `(allow file-write*)` stanza.
write_terms: list[str] = []
for idx, path in enumerate(writable_paths):
param = f"WR{idx}"
params.extend(["-D", f"{param}={path}"])
write_terms.append(f'(subpath (param "{param}"))')
policy += f"\n(allow file-write*\n {' '.join(write_terms)}\n)"
return [
"sandbox-exec",
"-p",
policy,
*params,
"--",
*cmd,
]
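The param-substitution scheme is easiest to see in isolation: each writable path becomes a `-D WR<i>=<path>` argument, plus a `(subpath (param "WR<i>"))` term appended to the read-only policy. A minimal sketch, with a deliberately abbreviated base policy and a hypothetical helper name:

```python
from pathlib import Path

BASE_POLICY = "(version 1)\n(deny default)\n(allow file-read*)"  # abbreviated

def seatbelt_write_params(writable_paths: list[Path]) -> tuple[list[str], str]:
    params: list[str] = []
    terms: list[str] = []
    for idx, path in enumerate(writable_paths):
        name = f"WR{idx}"
        params.extend(["-D", f"{name}={path}"])  # sandbox-exec -D substitution
        terms.append(f'(subpath (param "{name}"))')
    policy = BASE_POLICY
    if terms:
        # Grant writes only under the whitelisted subpaths.
        policy += f"\n(allow file-write*\n {' '.join(terms)}\n)"
    return params, policy

params, policy = seatbelt_write_params([Path("/tmp/work")])
```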
def _firejail_cmd(cmd: Sequence[str], writable_paths: List[Path]) -> List[str]:
"""
Build a **Firejail** command line (Linux only).
The wrapper enables seccomp, disables sound and networking, and whitelists
the provided *writable_paths* so the child process can persist data there.
Args:
cmd:
The command (program + args) to execute.
writable_paths:
Directories that must remain writable inside the Firejail sandbox.
Returns:
list[str]
A Firejail-prefixed command suitable for
``asyncio.create_subprocess_exec``.
Raises:
ValueError
If *writable_paths* contains non-absolute paths.
"""
for p in writable_paths:
if not p.is_absolute():
raise ValueError(f"Firejail whitelist paths must be absolute: {p}")
rw_flags = sum([["--whitelist", str(p)] for p in writable_paths], [])
return [
"firejail",
"--quiet", # suppress banner
"--seccomp", # enable seccomp filter
"--nosound",
"--net=none",
*rw_flags,
"--",
*cmd,
]
def wrap_command(
cmd: Sequence[str],
backend: SandboxType,
writable_paths: List[Path] | None = None,
) -> List[str]:
"""
Produce a sandbox-wrapped command according to *backend*.
This is the single public helper used by the executors: it hides the
platform-specific details of **seatbelt** and **Firejail** while providing
a graceful fallback to no sandbox.
Args:
cmd:
The raw command (program + args) to execute.
backend:
One of ``'seatbelt'``, ``'firejail'``, ``'none'`` or ``'auto'``.
When ``'auto'`` is supplied the caller should already have resolved the
platform with :func:`detect_backend`; the value is treated as ``'none'``.
writable_paths:
Extra directories that must remain writable inside the sandbox.
Ignored when *backend* is ``'none'`` / ``'auto'``.
Returns:
list[str]
The command list ready for ``asyncio.create_subprocess_exec``.
If sandboxing is disabled, this is simply ``list(cmd)``.
Raises:
ValueError
If an unrecognised *backend* value is given.
"""
if backend in ("none", "auto"):
return list(cmd)
writable_paths = writable_paths or []
if backend == "seatbelt":
return _seatbelt_cmd(cmd, writable_paths)
if backend == "firejail":
return _firejail_cmd(cmd, writable_paths)
raise ValueError(f"Unknown sandbox backend: {backend!r}")
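For the Firejail path, the only subtle step is flattening the per-path flag pairs into one argv fragment. A standalone sketch of that assembly (the function name is hypothetical):

```python
from pathlib import Path

def firejail_whitelist_flags(writable_paths: list[Path]) -> list[str]:
    # Reject relative paths up front, as the wrapper above does.
    for p in writable_paths:
        if not p.is_absolute():
            raise ValueError(f"Firejail whitelist paths must be absolute: {p}")
    # sum(..., []) flattens [['--whitelist', p1], ['--whitelist', p2]] into one list.
    return sum([["--whitelist", str(p)] for p in writable_paths], [])

flags = firejail_whitelist_flags([Path("/tmp/a"), Path("/tmp/b")])
```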

View File

@ -0,0 +1,309 @@
from __future__ import annotations
import shutil
from pathlib import Path
from typing import List, Optional, Set
from functools import lru_cache
from enum import Enum
class PackageManagerType(str, Enum):
"""Types of package managers that can be detected."""
PIP = "pip"
POETRY = "poetry"
PIPENV = "pipenv"
CONDA = "conda"
NPM = "npm"
YARN = "yarn"
PNPM = "pnpm"
BUN = "bun"
CARGO = "cargo"
GO = "go"
MAVEN = "maven"
GRADLE = "gradle"
COMPOSER = "composer"
UNKNOWN = "unknown"
class ProjectType(str, Enum):
"""Types of projects that can be detected."""
PYTHON = "python"
NODE = "node"
RUST = "rust"
GO = "go"
JAVA = "java"
PHP = "php"
UNKNOWN = "unknown"
class PackageManager:
"""Information about a package manager and its commands."""
def __init__(
self,
name: PackageManagerType,
project_type: ProjectType,
install_cmd: str,
add_cmd: Optional[str] = None,
remove_cmd: Optional[str] = None,
update_cmd: Optional[str] = None,
markers: Optional[List[str]] = None,
) -> None:
"""
Initialize a package manager.
Args:
name: Package manager identifier.
project_type: Type of project this manager serves.
install_cmd: Command to install project dependencies.
add_cmd: Command to add a single package.
remove_cmd: Command to remove a package.
update_cmd: Command to update packages.
markers: Filenames indicating this manager in a project.
"""
self.name = name
self.project_type = project_type
self.install_cmd = install_cmd
self.add_cmd = add_cmd or f"{name.value} install"
self.remove_cmd = remove_cmd or f"{name.value} remove"
self.update_cmd = update_cmd or f"{name.value} update"
self.markers = markers or []
def __str__(self) -> str:
return self.name.value
# Known package managers
PACKAGE_MANAGERS: dict[PackageManagerType, PackageManager] = {
# Python package managers
PackageManagerType.PIP: PackageManager(
name=PackageManagerType.PIP,
project_type=ProjectType.PYTHON,
install_cmd="pip install -r requirements.txt",
add_cmd="pip install",
remove_cmd="pip uninstall",
update_cmd="pip install --upgrade",
markers=["requirements.txt", "setup.py", "setup.cfg"],
),
PackageManagerType.POETRY: PackageManager(
name=PackageManagerType.POETRY,
project_type=ProjectType.PYTHON,
install_cmd="poetry install",
add_cmd="poetry add",
remove_cmd="poetry remove",
update_cmd="poetry update",
markers=["pyproject.toml", "poetry.lock"],
),
PackageManagerType.PIPENV: PackageManager(
name=PackageManagerType.PIPENV,
project_type=ProjectType.PYTHON,
install_cmd="pipenv install",
add_cmd="pipenv install",
remove_cmd="pipenv uninstall",
update_cmd="pipenv update",
markers=["Pipfile", "Pipfile.lock"],
),
PackageManagerType.CONDA: PackageManager(
name=PackageManagerType.CONDA,
project_type=ProjectType.PYTHON,
install_cmd="conda env update -f environment.yml",
add_cmd="conda install",
remove_cmd="conda remove",
update_cmd="conda update",
markers=["environment.yml", "environment.yaml"],
),
# JavaScript package managers
PackageManagerType.NPM: PackageManager(
name=PackageManagerType.NPM,
project_type=ProjectType.NODE,
install_cmd="npm install",
add_cmd="npm install",
remove_cmd="npm uninstall",
update_cmd="npm update",
markers=["package.json", "package-lock.json"],
),
PackageManagerType.YARN: PackageManager(
name=PackageManagerType.YARN,
project_type=ProjectType.NODE,
install_cmd="yarn install",
add_cmd="yarn add",
remove_cmd="yarn remove",
update_cmd="yarn upgrade",
markers=["package.json", "yarn.lock"],
),
PackageManagerType.PNPM: PackageManager(
name=PackageManagerType.PNPM,
project_type=ProjectType.NODE,
install_cmd="pnpm install",
add_cmd="pnpm add",
remove_cmd="pnpm remove",
update_cmd="pnpm update",
markers=["package.json", "pnpm-lock.yaml"],
),
PackageManagerType.BUN: PackageManager(
name=PackageManagerType.BUN,
project_type=ProjectType.NODE,
install_cmd="bun install",
add_cmd="bun add",
remove_cmd="bun remove",
update_cmd="bun update",
markers=["package.json", "bun.lockb"],
),
}
@lru_cache(maxsize=None)
def is_installed(name: str) -> bool:
"""
Check if a given command exists on PATH.
Args:
name: Command name to check.
Returns:
True if the command is available, False otherwise.
"""
return shutil.which(name) is not None
@lru_cache(maxsize=None)
def detect_package_managers(directory: str) -> List[PackageManager]:
"""
Detect all installed package managers by looking for marker files.
Args:
directory: Path to the project root.
Returns:
A list of PackageManager instances found in the directory.
"""
dir_path = Path(directory)
if not dir_path.is_dir():
return []
found: List[PackageManager] = []
for pm in PACKAGE_MANAGERS.values():
for marker in pm.markers:
if (dir_path / marker).exists() and is_installed(pm.name.value):
found.append(pm)
break
return found
@lru_cache(maxsize=None)
def get_primary_package_manager(directory: str) -> Optional[PackageManager]:
"""
Determine the primary package manager using lockfile heuristics.
Args:
directory: Path to the project root.
Returns:
The chosen PackageManager or None if none detected.
"""
managers = detect_package_managers(directory)
if not managers:
return None
if len(managers) == 1:
return managers[0]
dir_path = Path(directory)
# Prefer lockfiles over others
lock_priority = [
(PackageManagerType.POETRY, "poetry.lock"),
(PackageManagerType.PIPENV, "Pipfile.lock"),
(PackageManagerType.PNPM, "pnpm-lock.yaml"),
(PackageManagerType.YARN, "yarn.lock"),
(PackageManagerType.BUN, "bun.lockb"),
(PackageManagerType.NPM, "package-lock.json"),
]
for pm_type, lock in lock_priority:
if (dir_path / lock).exists() and PACKAGE_MANAGERS.get(pm_type) in managers:
return PACKAGE_MANAGERS[pm_type]
return managers[0]
def get_install_command(directory: str) -> Optional[str]:
"""
Get the shell command to install project dependencies.
Args:
directory: Path to the project root.
Returns:
A shell command string or None if no manager detected.
"""
pm = get_primary_package_manager(directory)
return pm.install_cmd if pm else None
def get_add_command(directory: str, package: str, dev: bool = False) -> Optional[str]:
"""
Get the shell command to add a package to the project.
Args:
directory: Path to the project root.
package: Package name to add.
dev: Whether to add as a development dependency.
Returns:
A shell command string or None if no manager detected.
"""
pm = get_primary_package_manager(directory)
if not pm:
return None
base = pm.add_cmd
if dev and pm.name in {
PackageManagerType.PIP,
PackageManagerType.POETRY,
PackageManagerType.NPM,
PackageManagerType.YARN,
PackageManagerType.PNPM,
PackageManagerType.BUN,
PackageManagerType.COMPOSER,
}:
flag = (
"--dev"
if pm.name in {PackageManagerType.PIP, PackageManagerType.POETRY}
else "--save-dev"
)
return f"{base} {package} {flag}"
return f"{base} {package}"
def get_project_type(directory: str) -> ProjectType:
"""
Infer project type from the primary package manager or file extensions.
Args:
directory: Path to the project root.
Returns:
The detected ProjectType.
"""
pm = get_primary_package_manager(directory)
if pm:
return pm.project_type
# Fallback by extension scanning
exts: Set[str] = set()
for path in Path(directory).rglob("*"):
if path.is_file():
exts.add(path.suffix.lower())
if len(exts) > 50:
break
if ".py" in exts:
return ProjectType.PYTHON
if {".js", ".ts"} & exts:
return ProjectType.NODE
if ".rs" in exts:
return ProjectType.RUST
if ".go" in exts:
return ProjectType.GO
if ".java" in exts:
return ProjectType.JAVA
if ".php" in exts:
return ProjectType.PHP
return ProjectType.UNKNOWN

View File

@ -10,4 +10,4 @@ from .nvidia.client import NVIDIAClientBase
from .nvidia.chat import NVIDIAChatClient
from .nvidia.embeddings import NVIDIAEmbeddingClient
from .elevenlabs import ElevenLabsSpeechClient
from .dapr import DaprChatClient

View File

@ -2,10 +2,12 @@ from pydantic import BaseModel, PrivateAttr
from abc import ABC, abstractmethod
from typing import Any
class LLMClientBase(BaseModel, ABC):
"""
Abstract base class for LLM models.
"""
# Private attributes for provider and api
_provider: str = PrivateAttr()
_api: str = PrivateAttr()
@ -13,7 +15,7 @@ class LLMClientBase(BaseModel, ABC):
# Private attributes for config and client
_config: Any = PrivateAttr()
_client: Any = PrivateAttr()
@property
def provider(self) -> str:
return self._provider
@ -21,7 +23,7 @@ class LLMClientBase(BaseModel, ABC):
@property
def api(self) -> str:
return self._api
@property
def config(self) -> Any:
return self._config
@ -46,4 +48,4 @@ class LLMClientBase(BaseModel, ABC):
"""
# Refresh config and client using the current state
self._config = self.get_config()
self._client = self.get_client()

View File

@ -5,17 +5,27 @@ from pydantic import BaseModel, Field
from abc import ABC, abstractmethod
from pathlib import Path
class ChatClientBase(BaseModel, ABC):
"""
Base class for chat-specific functionality.
Handles Prompty integration and provides abstract methods for chat client configuration.
"""
prompty: Optional[Prompty] = Field(
default=None, description="Instance of the Prompty object (optional)."
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="Prompt template for rendering (optional)."
)
@classmethod
@abstractmethod
def from_prompty(
cls,
prompty_source: Union[str, Path],
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "ChatClientBase":
"""
Abstract method to load a Prompty source and configure the chat client.
@ -31,13 +41,15 @@ class ChatClientBase(BaseModel, ABC):
@abstractmethod
def generate(
self,
messages: Union[
str, Dict[str, Any], BaseModel, Iterable[Union[Dict[str, Any], BaseModel]]
] = None,
input_data: Optional[Dict[str, Any]] = None,
model: Optional[str] = None,
tools: Optional[List[Union[Dict[str, Any]]]] = None,
response_format: Optional[Type[BaseModel]] = None,
structured_mode: Optional[str] = None,
**kwargs,
) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
"""
Abstract method to generate chat completions.
@ -54,4 +66,4 @@ class ChatClientBase(BaseModel, ABC):
Returns:
Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
"""
pass

View File

@ -1,2 +1,2 @@
from .chat import DaprChatClient
from .client import DaprInferenceClientBase

View File

@ -5,7 +5,18 @@ from dapr_agents.types.message import BaseMessage
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.tool import AgentTool
from dapr.clients.grpc._request import ConversationInput
from typing import (
Union,
Optional,
Iterable,
Dict,
Any,
List,
Iterator,
Type,
Literal,
ClassVar,
)
from pydantic import BaseModel
from pathlib import Path
import logging
@ -14,6 +25,7 @@ import time
logger = logging.getLogger(__name__)
class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
"""
Concrete class for Dapr's chat completion API using the Inference API.
@ -28,17 +40,21 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
"""
# Set the private provider and api attributes
self._api = "chat"
self._llm_component = os.environ["DAPR_LLM_COMPONENT_DEFAULT"]
return super().model_post_init(__context)
@classmethod
def from_prompty(
cls,
prompty_source: Union[str, Path],
timeout: Union[int, float, Dict[str, Any]] = 1500,
) -> "DaprChatClient":
"""
Initializes a DaprChatClient client using a Prompty source, which can be a file path or inline content.
Args:
prompty_source (Union[str, Path]): The source of the Prompty file, which can be a path to a file
or inline Prompty content as a string.
timeout (Union[int, float, Dict[str, Any]], optional): Timeout for requests, defaults to 1500 seconds.
@ -52,11 +68,13 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
prompt_template = Prompty.to_prompt_template(prompty_instance)
# Initialize the DaprChatClient based on the Prompty model configuration
return cls.model_validate(
{
"timeout": timeout,
"prompty": prompty_instance,
"prompt_template": prompt_template,
}
)
def translate_response(self, response: dict, model: str) -> dict:
"""Converts a Dapr response dict into a structure compatible with Choice and ChatCompletion."""
@ -64,36 +82,40 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
{
"finish_reason": "stop",
"index": i,
"message": {"content": output["result"], "role": "assistant"},
"logprobs": None,
}
for i, output in enumerate(response.get("outputs", []))
]
return {
"choices": choices,
"created": int(time.time()),
"model": model,
"object": "chat.completion",
"usage": {"total_tokens": "-1"},
}
def convert_to_conversation_inputs(
self, inputs: List[Dict[str, Any]]
) -> List[ConversationInput]:
return [
ConversationInput(
content=item["content"],
role=item.get("role"),
scrub_pii=item.get("scrubPII") == "true",
)
for item in inputs
]
def generate(
self,
messages: Union[
str,
Dict[str, Any],
BaseMessage,
Iterable[Union[Dict[str, Any], BaseMessage]],
] = None,
input_data: Optional[Dict[str, Any]] = None,
llm_component: Optional[str] = None,
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
@ -101,7 +123,7 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
structured_mode: Literal["function_call"] = "function_call",
scrubPII: Optional[bool] = False,
temperature: Optional[float] = None,
**kwargs,
) -> Union[Iterator[Dict[str, Any]], Dict[str, Any]]:
"""
Generate chat completions based on provided messages or input_data for prompt templates.
@ -120,13 +142,17 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
Union[Iterator[Dict[str, Any]], Dict[str, Any]]: The chat completion response(s).
"""
if structured_mode not in self.SUPPORTED_STRUCTURED_MODES:
raise ValueError(
f"Invalid structured_mode '{structured_mode}'. Must be one of {self.SUPPORTED_STRUCTURED_MODES}."
)
# If input_data is provided, check for a prompt_template
if input_data:
if not self.prompt_template:
raise ValueError(
"Inputs are provided but no 'prompt_template' is set. Please set a 'prompt_template' to use the input_data."
)
logger.info("Using prompt template to generate messages.")
messages = self.prompt_template.format_prompt(**input_data)
@ -135,36 +161,43 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
raise ValueError("Either 'messages' or 'input_data' must be provided.")
# Process and normalize the messages
params = {"inputs": RequestHandler.normalize_chat_messages(messages)}
# Merge Prompty parameters if available, then override with any explicit kwargs
if self.prompty:
params = {**self.prompty.model.parameters.model_dump(), **params, **kwargs}
else:
params.update(kwargs)
# Prepare request parameters
params = RequestHandler.process_params(
params,
llm_provider=self.provider,
tools=tools,
response_format=response_format,
structured_mode=structured_mode,
)
inputs = self.convert_to_conversation_inputs(params["inputs"])
try:
logger.info("Invoking the Dapr Conversation API.")
response = self.client.chat_completion(
llm=llm_component or self._llm_component,
conversation_inputs=inputs,
scrub_pii=scrubPII,
temperature=temperature,
)
transposed_response = self.translate_response(response, self._llm_component)
logger.info("Chat completion retrieved successfully.")
return ResponseHandler.process_response(
transposed_response,
llm_provider=self.provider,
response_format=response_format,
structured_mode=structured_mode,
stream=params.get("stream", False),
)
except Exception as e:
logger.error(
f"An error occurred during the Dapr Conversation API call: {e}"
)
raise
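The response translation in this file is self-contained enough to run standalone; this sketch mirrors the mapping shown in the diff (Dapr `outputs` entries onto OpenAI-style `choices`, with `usage` deliberately stubbed to `-1`):

```python
import time

def translate_response(response: dict, model: str) -> dict:
    # Map each Dapr 'outputs' entry onto an OpenAI-style 'choices' item.
    choices = [
        {
            "finish_reason": "stop",
            "index": i,
            "message": {"content": output["result"], "role": "assistant"},
            "logprobs": None,
        }
        for i, output in enumerate(response.get("outputs", []))
    ]
    return {
        "choices": choices,
        "created": int(time.time()),
        "model": model,
        "object": "chat.completion",
        "usage": {"total_tokens": "-1"},
    }

resp = translate_response({"outputs": [{"result": "hello"}]}, "echo-model")
```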

View File

@ -10,6 +10,7 @@ import logging
logger = logging.getLogger(__name__)
class DaprInferenceClient:
def __init__(self):
self.dapr_client = DaprClient()
@ -25,9 +26,20 @@ class DaprInferenceClient:
}
return response_dict
def chat_completion(
self,
llm: str,
conversation_inputs: List[ConversationInput],
scrub_pii: bool | None = None,
temperature: float | None = None,
) -> Any:
response = self.dapr_client.converse_alpha1(
name=llm,
inputs=conversation_inputs,
scrub_pii=scrub_pii,
temperature=temperature,
)
output = self.translate_to_json(response)
return output
@ -38,6 +50,7 @@ class DaprInferenceClientBase(LLMClientBase):
Base class for managing Dapr Inference API clients.
Handles client initialization, configuration, and shared logic.
"""
@model_validator(mode="before")
def validate_and_initialize(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values
@ -52,7 +65,7 @@ class DaprInferenceClientBase(LLMClientBase):
self._config = self.get_config()
self._client = self.get_client()
return super().model_post_init(__context)
def get_config(self) -> DaprInferenceClientConfig:
"""
Returns the appropriate configuration for the Dapr Conversation API.
@ -64,9 +77,11 @@ class DaprInferenceClientBase(LLMClientBase):
Initializes and returns the Dapr Inference client.
"""
return DaprInferenceClient()
@classmethod
def from_config(
cls, client_options: DaprInferenceClientConfig, timeout: float = 1500
):
"""
Initializes the DaprInferenceClientBase using DaprInferenceClientConfig.

View File

@ -1 +1 @@
from .speech import ElevenLabsSpeechClient

View File

@ -7,14 +7,21 @@ import logging
logger = logging.getLogger(__name__)
class ElevenLabsClientBase(LLMClientBase):
"""
Base class for managing ElevenLabs LLM clients.
Handles client initialization, configuration, and shared logic specific to the ElevenLabs API.
"""
api_key: Optional[str] = Field(
default=None,
description="API key for authenticating with the ElevenLabs API. Defaults to environment variables 'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY'.",
)
base_url: Optional[str] = Field(
default="https://api.elevenlabs.io",
description="Base URL for the ElevenLabs API endpoints.",
)
def model_post_init(self, __context: Any) -> None:
"""
@ -26,26 +33,27 @@ class ElevenLabsClientBase(LLMClientBase):
# Use environment variable if `api_key` is not explicitly provided
if self.api_key is None:
self.api_key = os.getenv("ELEVENLABS_API_KEY") or os.getenv(
"ELEVEN_API_KEY"
)
if self.api_key is None:
raise ValueError(
"API key is required. Set it explicitly or in the 'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY' environment variable."
)
# Initialize configuration and client
self._config = self.get_config()
self._client = self.get_client()
logger.info("ElevenLabs client initialized successfully.")
return super().model_post_init(__context)
def get_config(self) -> ElevenLabsClientConfig:
"""
Returns the configuration object for the ElevenLabs API client.
"""
return ElevenLabsClientConfig(api_key=self.api_key, base_url=self.base_url)
def get_client(self) -> Any:
"""
@ -59,14 +67,11 @@ class ElevenLabsClientBase(LLMClientBase):
raise ImportError( raise ImportError(
"The 'elevenlabs' package is required but not installed. Install it with 'pip install elevenlabs'." "The 'elevenlabs' package is required but not installed. Install it with 'pip install elevenlabs'."
) from e ) from e
config = self.config config = self.config
logger.info("Initializing ElevenLabs API client...") logger.info("Initializing ElevenLabs API client...")
return ElevenLabs( return ElevenLabs(api_key=config.api_key, base_url=config.base_url)
api_key=config.api_key,
base_url=config.base_url
)
     @property
     def config(self) -> ElevenLabsClientConfig:
@@ -80,4 +85,4 @@ class ElevenLabsClientBase(LLMClientBase):
         """
         Provides access to the ElevenLabs API client instance.
         """
         return self._client
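The env-var fallback in `model_post_init` above can be exercised on its own. A minimal sketch, assuming only the standard library — `resolve_api_key` is a hypothetical helper mirroring the diff's logic, not part of the library:

```python
import os

def resolve_api_key(explicit_key=None):
    # Hypothetical helper mirroring the fallback order in model_post_init:
    # an explicitly provided key wins, then ELEVENLABS_API_KEY, then the
    # legacy ELEVEN_API_KEY variable; otherwise raise, as the client does.
    if explicit_key is not None:
        return explicit_key
    key = os.getenv("ELEVENLABS_API_KEY") or os.getenv("ELEVEN_API_KEY")
    if key is None:
        raise ValueError(
            "API key is required. Set it explicitly or in the "
            "'ELEVENLABS_API_KEY' or 'ELEVEN_API_KEY' environment variable."
        )
    return key

os.environ.pop("ELEVENLABS_API_KEY", None)
os.environ["ELEVEN_API_KEY"] = "legacy-key"
print(resolve_api_key())            # falls back to the legacy variable
print(resolve_api_key("explicit"))  # an explicit argument always wins
```

Checking both variables keeps older deployments that still export `ELEVEN_API_KEY` working while preferring the newer name.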


@@ -5,17 +5,32 @@ import logging
 
 logger = logging.getLogger(__name__)
 
 
 class ElevenLabsSpeechClient(ElevenLabsClientBase):
     """
     Client for ElevenLabs speech generation functionality.
     Handles text-to-speech conversions with customizable options.
     """
-    voice: Optional[Any] = Field(default=None, description="Default voice (ID, name, or object) for speech generation.")
-    model: Optional[str] = Field(default="eleven_multilingual_v2", description="Default model for speech generation.")
-    output_format: Optional[str] = Field(default="mp3_44100_128", description="Default audio output format.")
-    optimize_streaming_latency: Optional[int] = Field(default=0, description="Default latency optimization level (0 means no optimizations).")
-    voice_settings: Optional[Any] = Field(default=None, description="Default voice settings (stability, similarity boost, etc.).")
+    voice: Optional[Any] = Field(
+        default=None,
+        description="Default voice (ID, name, or object) for speech generation.",
+    )
+    model: Optional[str] = Field(
+        default="eleven_multilingual_v2",
+        description="Default model for speech generation.",
+    )
+    output_format: Optional[str] = Field(
+        default="mp3_44100_128", description="Default audio output format."
+    )
+    optimize_streaming_latency: Optional[int] = Field(
+        default=0,
+        description="Default latency optimization level (0 means no optimizations).",
+    )
+    voice_settings: Optional[Any] = Field(
+        default=None,
+        description="Default voice settings (stability, similarity boost, etc.).",
+    )
     def model_post_init(self, __context: Any) -> None:
         """
@@ -71,7 +86,9 @@ class ElevenLabsSpeechClient(ElevenLabsClientBase):
         voice = voice or self.voice
         model = model or self.model
         output_format = output_format or self.output_format
-        optimize_streaming_latency = optimize_streaming_latency or self.optimize_streaming_latency
+        optimize_streaming_latency = (
+            optimize_streaming_latency or self.optimize_streaming_latency
+        )
         voice_settings = voice_settings or self.voice_settings
 
         logger.info(f"Generating speech with voice '{voice}', model '{model}'.")
@@ -100,4 +117,4 @@ class ElevenLabsSpeechClient(ElevenLabsClientBase):
         except Exception as e:
             logger.error(f"Failed to generate speech: {e}")
             raise ValueError(f"An error occurred during speech generation: {e}")


@@ -1,2 +1,2 @@
 from .chat import HFHubChatClient
 from .client import HFHubInferenceClientBase
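This `__init__.py` re-exports the submodules' classes so callers can import them from the package root. A self-contained sketch of that pattern, building a throwaway package on disk — the `hfpkg` name and stub classes are illustrative, not the real dapr-agents modules:

```python
import os
import sys
import tempfile

# Lay out a tiny package mirroring the structure in the diff:
# chat.py and client.py define classes; __init__.py re-exports them.
root = tempfile.mkdtemp()
pkg = os.path.join(root, "hfpkg")
os.makedirs(pkg)

with open(os.path.join(pkg, "chat.py"), "w") as f:
    f.write("class HFHubChatClient:\n    pass\n")
with open(os.path.join(pkg, "client.py"), "w") as f:
    f.write("class HFHubInferenceClientBase:\n    pass\n")
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write(
        "from .chat import HFHubChatClient\n"
        "from .client import HFHubInferenceClientBase\n"
    )

sys.path.insert(0, root)
from hfpkg import HFHubChatClient  # resolved via the package-root re-export

print(HFHubChatClient.__name__)
```

The re-export keeps the public import path (`from hfpkg import HFHubChatClient`) stable even if the internal module layout changes later.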

Some files were not shown because too many files have changed in this diff.