mirror of https://github.com/dapr/java-sdk.git
Add Conversation AI to Java SDK (#1235)
* Conversation first commit Signed-off-by: Siri Varma Vegiraju <svegiraju@microsoft.com> Signed-off-by: sirivarma <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> * Add unit tests Signed-off-by: sirivarma <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> * change ai to conv Signed-off-by: sirivarma <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> * Move to single module Signed-off-by: sirivarma <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> * Remove module Signed-off-by: sirivarma <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> * Add Integration tests Signed-off-by: siri-varma <siri.varma@outlook.com> * Update sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com> Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> * Fix things Signed-off-by: siri-varma <siri.varma@outlook.com> * Address comments Signed-off-by: siri-varma <siri.varma@outlook.com> * Import tag Signed-off-by: siri-varma <siri.varma@outlook.com> * Address comments Signed-off-by: siri-varma <siri.varma@outlook.com> * Make common config Signed-off-by: siri-varma <siri.varma@outlook.com> * Address comments Signed-off-by: siri-varma <siri.varma@outlook.com> * fix constant Signed-off-by: siri-varma <siri.varma@outlook.com> * fix constant Signed-off-by: siri-varma <siri.varma@outlook.com> * fix constant Signed-off-by: siri-varma <siri.varma@outlook.com> * fix s Signed-off-by: siri-varma <siri.varma@outlook.com> * Fix things Signed-off-by: siri-varma <siri.varma@outlook.com> * Fix things Signed-off-by: siri-varma <siri.varma@outlook.com> * Fix things Signed-off-by: siri-varma <siri.varma@outlook.com> * Make common config Signed-off-by: siri-varma <siri.varma@outlook.com> * Update README.md Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com> * Update README.md Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com> --------- Signed-off-by: Siri Varma Vegiraju <svegiraju@microsoft.com> Signed-off-by: sirivarma <siri.varma@outlook.com> Signed-off-by: siri-varma <siri.varma@outlook.com> Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com> Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com> Co-authored-by: Cassie Coyle <cassie@diagrid.io>
This commit is contained in:
parent
c53f000ef2
commit
99e21db465
|
@ -121,6 +121,12 @@ jobs:
|
|||
mm.py ./src/main/java/io/dapr/examples/jobs/README.md
|
||||
env:
|
||||
DOCKER_HOST: ${{steps.setup_docker.outputs.sock}}
|
||||
- name: Validate conversation ai example
|
||||
working-directory: ./examples
|
||||
run: |
|
||||
mm.py ./src/main/java/io/dapr/examples/conversation/README.md
|
||||
env:
|
||||
DOCKER_HOST: ${{steps.setup_docker.outputs.sock}}
|
||||
- name: Validate invoke http example
|
||||
working-directory: ./examples
|
||||
run: |
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
apiVersion: dapr.io/v1alpha1
|
||||
kind: Component
|
||||
metadata:
|
||||
name: echo
|
||||
spec:
|
||||
type: conversation.echo
|
||||
version: v1
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Copyright 2021 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.examples.conversation;
|
||||
|
||||
import io.dapr.client.DaprClientBuilder;
|
||||
import io.dapr.client.DaprPreviewClient;
|
||||
import io.dapr.client.domain.ConversationInput;
|
||||
import io.dapr.client.domain.ConversationRequest;
|
||||
import io.dapr.client.domain.ConversationResponse;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class DemoConversationAI {
|
||||
/**
|
||||
* The main method to start the client.
|
||||
*
|
||||
* @param args Input arguments (unused).
|
||||
*/
|
||||
public static void main(String[] args) {
|
||||
try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) {
|
||||
System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567");
|
||||
|
||||
ConversationInput daprConversationInput = new ConversationInput("Hello How are you? "
|
||||
+ "This is the my number 672-123-4567");
|
||||
|
||||
// Component name is the name provided in the metadata block of the conversation.yaml file.
|
||||
Mono<ConversationResponse> responseMono = client.converse(new ConversationRequest("echo",
|
||||
List.of(daprConversationInput))
|
||||
.setContextId("contextId")
|
||||
.setScrubPii(true).setTemperature(1.1d));
|
||||
ConversationResponse response = responseMono.block();
|
||||
System.out.printf("Conversation output: %s", response.getConversationOutputs().get(0).getResult());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
## Manage Dapr via the Conversation API
|
||||
|
||||
This example provides the different capabilities provided by Dapr Java SDK for Conversation. For further information about Conversation APIs please refer to [this link](https://docs.dapr.io/developing-applications/building-blocks/conversation/conversation-overview/)
|
||||
|
||||
### Using the Conversation API
|
||||
|
||||
The Java SDK exposes several methods for this -
|
||||
* `client.converse(...)` for conversing with an LLM through Dapr.
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
* [Dapr CLI](https://docs.dapr.io/getting-started/install-dapr-cli/).
|
||||
* Java JDK 11 (or greater):
|
||||
* [Microsoft JDK 11](https://docs.microsoft.com/en-us/java/openjdk/download#openjdk-11)
|
||||
* [Oracle JDK 11](https://www.oracle.com/technetwork/java/javase/downloads/index.html#JDK11)
|
||||
* [OpenJDK 11](https://jdk.java.net/11/)
|
||||
* [Apache Maven](https://maven.apache.org/install.html) version 3.x.
|
||||
|
||||
### Checking out the code
|
||||
|
||||
Clone this repository:
|
||||
|
||||
```sh
|
||||
git clone https://github.com/dapr/java-sdk.git
|
||||
cd java-sdk
|
||||
```
|
||||
|
||||
Then build the Maven project:
|
||||
|
||||
```sh
|
||||
# make sure you are in the `java-sdk` directory
|
||||
mvn install
|
||||
```
|
||||
|
||||
Then get into the examples directory:
|
||||
|
||||
```sh
|
||||
cd examples
|
||||
```
|
||||
|
||||
### Initialize Dapr
|
||||
|
||||
Run `dapr init` to initialize Dapr in Self-Hosted Mode if it's not already initialized.
|
||||
|
||||
### Running the example
|
||||
|
||||
This example uses the Java SDK Dapr client in order to **Converse** with an LLM.
|
||||
`DemoConversationAI.java` is the example class demonstrating these features.
|
||||
Kindly check [DaprPreviewClient.java](https://github.com/dapr/java-sdk/blob/master/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java) for a detailed description of the supported APIs.
|
||||
|
||||
```java
|
||||
public class DemoConversationAI {
|
||||
/**
|
||||
* The main method to start the client.
|
||||
*
|
||||
* @param args Input arguments (unused).
|
||||
*/
|
||||
public static void main(String[] args) {
|
||||
try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) {
|
||||
System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567");
|
||||
|
||||
ConversationInput daprConversationInput = new ConversationInput("Hello How are you? "
|
||||
+ "This is the my number 672-123-4567");
|
||||
|
||||
// Component name is the name provided in the metadata block of the conversation.yaml file.
|
||||
Mono<ConversationResponse> responseMono = client.converse(new ConversationRequest("echo",
|
||||
List.of(daprConversationInput))
|
||||
.setContextId("contextId")
|
||||
.setScrubPii(true).setTemperature(1.1d));
|
||||
ConversationResponse response = responseMono.block();
|
||||
System.out.printf("Conversation output: %s", response.getConversationOutpus().get(0).getResult());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Use the following command to run this example-
|
||||
|
||||
<!-- STEP
|
||||
name: Run Demo Conversation Client example
|
||||
expected_stdout_lines:
|
||||
- "== APP == Conversation output: Hello How are you? This is the my number <ISBN>"
|
||||
background: true
|
||||
output_match_mode: substring
|
||||
sleep: 10
|
||||
-->
|
||||
|
||||
```bash
|
||||
dapr run --resources-path ./components/conversation --app-id myapp --app-port 8080 --dapr-http-port 3500 --dapr-grpc-port 51439 --log-level debug -- java -jar target/dapr-java-sdk-examples-exec.jar io.dapr.examples.conversation.DemoConversationAI
|
||||
```
|
||||
|
||||
<!-- END_STEP -->
|
||||
|
||||
### Sample output
|
||||
```
|
||||
== APP == Conversation output: Hello How are you? This is the my number <ISBN>
|
||||
```
|
||||
### Cleanup
|
||||
|
||||
To stop the app, run (or press CTRL+C):
|
||||
|
||||
<!-- STEP
|
||||
|
||||
name: Cleanup
|
||||
-->
|
||||
|
||||
```bash
|
||||
dapr stop --app-id myapp
|
||||
```
|
||||
|
||||
<!-- END_STEP -->
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Copyright 2021 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.it.testcontainers;
|
||||
|
||||
import io.dapr.client.DaprPreviewClient;
|
||||
import io.dapr.client.domain.ConversationInput;
|
||||
import io.dapr.client.domain.ConversationRequest;
|
||||
import io.dapr.client.domain.ConversationResponse;
|
||||
import io.dapr.testcontainers.Component;
|
||||
import io.dapr.testcontainers.DaprContainer;
|
||||
import io.dapr.testcontainers.DaprLogLevel;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
|
||||
import org.springframework.test.context.DynamicPropertyRegistry;
|
||||
import org.springframework.test.context.DynamicPropertySource;
|
||||
import org.testcontainers.containers.Network;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static io.dapr.it.testcontainers.ContainerConstants.DAPR_RUNTIME_IMAGE_TAG;
|
||||
|
||||
@SpringBootTest(
|
||||
webEnvironment = WebEnvironment.RANDOM_PORT,
|
||||
classes = {
|
||||
DaprPreviewClientConfiguration.class,
|
||||
TestConversationApplication.class
|
||||
}
|
||||
)
|
||||
@Testcontainers
|
||||
@Tag("testcontainers")
|
||||
public class DaprConversationIT {
|
||||
|
||||
private static final Network DAPR_NETWORK = Network.newNetwork();
|
||||
private static final Random RANDOM = new Random();
|
||||
private static final int PORT = RANDOM.nextInt(1000) + 8000;
|
||||
|
||||
@Container
|
||||
private static final DaprContainer DAPR_CONTAINER = new DaprContainer(DAPR_RUNTIME_IMAGE_TAG)
|
||||
.withAppName("conversation-dapr-app")
|
||||
.withComponent(new Component("echo", "conversation.echo", "v1", new HashMap<>()))
|
||||
.withNetwork(DAPR_NETWORK)
|
||||
.withDaprLogLevel(DaprLogLevel.DEBUG)
|
||||
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
|
||||
.withAppChannelAddress("host.testcontainers.internal")
|
||||
.withAppPort(PORT);
|
||||
|
||||
/**
|
||||
* Expose the Dapr port to the host.
|
||||
*
|
||||
* @param registry the dynamic property registry
|
||||
*/
|
||||
@DynamicPropertySource
|
||||
static void daprProperties(DynamicPropertyRegistry registry) {
|
||||
registry.add("dapr.http.endpoint", DAPR_CONTAINER::getHttpEndpoint);
|
||||
registry.add("dapr.grpc.endpoint", DAPR_CONTAINER::getGrpcEndpoint);
|
||||
registry.add("server.port", () -> PORT);
|
||||
}
|
||||
|
||||
@Autowired
|
||||
private DaprPreviewClient daprPreviewClient;
|
||||
|
||||
@BeforeEach
|
||||
public void setUp(){
|
||||
org.testcontainers.Testcontainers.exposeHostPorts(PORT);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConversationSDKShouldHaveSameOutputAndInput() {
|
||||
ConversationInput conversationInput = new ConversationInput("input this");
|
||||
List<ConversationInput> conversationInputList = new ArrayList<>();
|
||||
conversationInputList.add(conversationInput);
|
||||
|
||||
ConversationResponse response =
|
||||
this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block();
|
||||
|
||||
Assertions.assertEquals("", response.getContextId());
|
||||
Assertions.assertEquals("input this", response.getConversationOutputs().get(0).getResult());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConversationSDKShouldScrubPIIWhenScrubPIIIsSetInRequestBody() {
|
||||
List<ConversationInput> conversationInputList = new ArrayList<>();
|
||||
conversationInputList.add(new ConversationInput("input this abcd@gmail.com"));
|
||||
conversationInputList.add(new ConversationInput("input this +12341567890"));
|
||||
|
||||
ConversationResponse response =
|
||||
this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)
|
||||
.setScrubPii(true)).block();
|
||||
|
||||
Assertions.assertEquals("", response.getContextId());
|
||||
Assertions.assertEquals("input this <EMAIL_ADDRESS>",
|
||||
response.getConversationOutputs().get(0).getResult());
|
||||
Assertions.assertEquals("input this <PHONE_NUMBER>",
|
||||
response.getConversationOutputs().get(1).getResult());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConversationSDKShouldScrubPIIOnlyForTheInputWhereScrubPIIIsSet() {
|
||||
List<ConversationInput> conversationInputList = new ArrayList<>();
|
||||
conversationInputList.add(new ConversationInput("input this abcd@gmail.com"));
|
||||
conversationInputList.add(new ConversationInput("input this +12341567890").setScrubPii(true));
|
||||
|
||||
ConversationResponse response =
|
||||
this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block();
|
||||
|
||||
Assertions.assertEquals("", response.getContextId());
|
||||
Assertions.assertEquals("input this abcd@gmail.com",
|
||||
response.getConversationOutputs().get(0).getResult());
|
||||
Assertions.assertEquals("input this <PHONE_NUMBER>",
|
||||
response.getConversationOutputs().get(1).getResult());
|
||||
}
|
||||
}
|
|
@ -45,7 +45,7 @@ import static org.junit.Assert.assertEquals;
|
|||
@SpringBootTest(
|
||||
webEnvironment = WebEnvironment.RANDOM_PORT,
|
||||
classes = {
|
||||
TestDaprJobsConfiguration.class,
|
||||
DaprPreviewClientConfiguration.class,
|
||||
TestJobsApplication.class
|
||||
}
|
||||
)
|
||||
|
|
|
@ -14,11 +14,9 @@ limitations under the License.
|
|||
package io.dapr.it.testcontainers;
|
||||
|
||||
import io.dapr.client.DaprClientBuilder;
|
||||
import io.dapr.client.DaprClientImpl;
|
||||
import io.dapr.client.DaprPreviewClient;
|
||||
import io.dapr.config.Properties;
|
||||
import io.dapr.config.Property;
|
||||
import io.dapr.serializer.DefaultObjectSerializer;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
@ -26,7 +24,7 @@ import org.springframework.context.annotation.Configuration;
|
|||
import java.util.Map;
|
||||
|
||||
@Configuration
|
||||
public class TestDaprJobsConfiguration {
|
||||
public class DaprPreviewClientConfiguration {
|
||||
@Bean
|
||||
public DaprPreviewClient daprPreviewClient(
|
||||
@Value("${dapr.http.endpoint}") String daprHttpEndpoint,
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright 2024 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.it.testcontainers;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
|
||||
@SpringBootApplication
|
||||
public class TestConversationApplication {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(TestConversationApplication.class, args);
|
||||
}
|
||||
}
|
|
@ -27,6 +27,10 @@ import io.dapr.client.domain.BulkPublishResponseFailedEntry;
|
|||
import io.dapr.client.domain.CloudEvent;
|
||||
import io.dapr.client.domain.ComponentMetadata;
|
||||
import io.dapr.client.domain.ConfigurationItem;
|
||||
import io.dapr.client.domain.ConversationInput;
|
||||
import io.dapr.client.domain.ConversationOutput;
|
||||
import io.dapr.client.domain.ConversationRequest;
|
||||
import io.dapr.client.domain.ConversationResponse;
|
||||
import io.dapr.client.domain.DaprMetadata;
|
||||
import io.dapr.client.domain.DeleteJobRequest;
|
||||
import io.dapr.client.domain.DeleteStateRequest;
|
||||
|
@ -99,7 +103,6 @@ import javax.annotation.Nonnull;
|
|||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.time.ZoneOffset;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
|
@ -1552,6 +1555,79 @@ public class DaprClientImpl extends AbstractDaprClient {
|
|||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public Mono<ConversationResponse> converse(ConversationRequest conversationRequest) {
|
||||
|
||||
try {
|
||||
validateConversationRequest(conversationRequest);
|
||||
|
||||
DaprProtos.ConversationRequest.Builder protosConversationRequestBuilder = DaprProtos.ConversationRequest
|
||||
.newBuilder().setTemperature(conversationRequest.getTemperature())
|
||||
.setScrubPII(conversationRequest.isScrubPii())
|
||||
.setName(conversationRequest.getName());
|
||||
|
||||
if (conversationRequest.getContextId() != null) {
|
||||
protosConversationRequestBuilder.setContextID(conversationRequest.getContextId());
|
||||
}
|
||||
|
||||
for (ConversationInput input : conversationRequest.getInputs()) {
|
||||
if (input.getContent() == null || input.getContent().isEmpty()) {
|
||||
throw new IllegalArgumentException("Conversation input content cannot be null or empty.");
|
||||
}
|
||||
|
||||
DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder()
|
||||
.setContent(input.getContent())
|
||||
.setScrubPII(input.isScrubPii());
|
||||
|
||||
if (input.getRole() != null) {
|
||||
conversationInputOrBuilder.setRole(input.getRole().toString());
|
||||
}
|
||||
|
||||
protosConversationRequestBuilder.addInputs(conversationInputOrBuilder.build());
|
||||
}
|
||||
|
||||
Mono<DaprProtos.ConversationResponse> conversationResponseMono = Mono.deferContextual(
|
||||
context -> this.createMono(
|
||||
it -> intercept(context, asyncStub)
|
||||
.converseAlpha1(protosConversationRequestBuilder.build(), it)
|
||||
)
|
||||
);
|
||||
|
||||
return conversationResponseMono.map(conversationResponse -> {
|
||||
|
||||
List<ConversationOutput> conversationOutputs = new ArrayList<>();
|
||||
for (DaprProtos.ConversationResult conversationResult : conversationResponse.getOutputsList()) {
|
||||
Map<String, byte[]> parameters = new HashMap<>();
|
||||
for (Map.Entry<String, Any> entrySet : conversationResult.getParametersMap().entrySet()) {
|
||||
parameters.put(entrySet.getKey(), entrySet.getValue().toByteArray());
|
||||
}
|
||||
|
||||
ConversationOutput conversationOutput =
|
||||
new ConversationOutput(conversationResult.getResult(), parameters);
|
||||
conversationOutputs.add(conversationOutput);
|
||||
}
|
||||
|
||||
return new ConversationResponse(conversationResponse.getContextID(), conversationOutputs);
|
||||
});
|
||||
} catch (Exception ex) {
|
||||
return DaprException.wrapMono(ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void validateConversationRequest(ConversationRequest conversationRequest) {
|
||||
if ((conversationRequest.getName() == null) || (conversationRequest.getName().trim().isEmpty())) {
|
||||
throw new IllegalArgumentException("LLM name cannot be null or empty.");
|
||||
}
|
||||
|
||||
if ((conversationRequest.getInputs() == null) || (conversationRequest
|
||||
.getInputs().isEmpty())) {
|
||||
throw new IllegalArgumentException("Conversation inputs cannot be null or empty.");
|
||||
}
|
||||
}
|
||||
|
||||
private DaprMetadata buildDaprMetadata(DaprProtos.GetMetadataResponse response) throws IOException {
|
||||
String id = response.getId();
|
||||
String runtimeVersion = response.getRuntimeVersion();
|
||||
|
|
|
@ -17,6 +17,8 @@ import io.dapr.client.domain.BulkPublishEntry;
|
|||
import io.dapr.client.domain.BulkPublishRequest;
|
||||
import io.dapr.client.domain.BulkPublishResponse;
|
||||
import io.dapr.client.domain.BulkPublishResponseFailedEntry;
|
||||
import io.dapr.client.domain.ConversationRequest;
|
||||
import io.dapr.client.domain.ConversationResponse;
|
||||
import io.dapr.client.domain.DeleteJobRequest;
|
||||
import io.dapr.client.domain.GetJobRequest;
|
||||
import io.dapr.client.domain.GetJobResponse;
|
||||
|
@ -304,4 +306,12 @@ public interface DaprPreviewClient extends AutoCloseable {
|
|||
* @throws IllegalArgumentException If the request or its required fields like name are null or empty.
|
||||
*/
|
||||
public Mono<Void> deleteJob(DeleteJobRequest deleteJobRequest);
|
||||
|
||||
/*
|
||||
* Converse with an LLM.
|
||||
*
|
||||
* @param conversationRequest request to be passed to the LLM.
|
||||
* @return {@link ConversationResponse}.
|
||||
*/
|
||||
public Mono<ConversationResponse> converse(ConversationRequest conversationRequest);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright 2021 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.client.domain;
|
||||
|
||||
/**
|
||||
* Represents an input message for a conversation with an LLM.
|
||||
*/
|
||||
public class ConversationInput {
|
||||
|
||||
private final String content;
|
||||
|
||||
private String role;
|
||||
|
||||
private boolean scrubPii;
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*
|
||||
* @param content for the llm.
|
||||
*/
|
||||
public ConversationInput(String content) {
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
/**
|
||||
* The message content to send to the LLM. Required
|
||||
*
|
||||
* @return The content to be sent to the LLM.
|
||||
*/
|
||||
public String getContent() {
|
||||
return content;
|
||||
}
|
||||
|
||||
/**
|
||||
* The role for the LLM to assume.
|
||||
*
|
||||
* @return this.
|
||||
*/
|
||||
public String getRole() {
|
||||
return role;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the role for LLM to assume.
|
||||
*
|
||||
* @param role The role to assign to the message.
|
||||
* @return this.
|
||||
*/
|
||||
public ConversationInput setRole(String role) {
|
||||
this.role = role;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if Personally Identifiable Information (PII) should be scrubbed before sending to the LLM.
|
||||
*
|
||||
* @return {@code true} if PII should be scrubbed, {@code false} otherwise.
|
||||
*/
|
||||
public boolean isScrubPii() {
|
||||
return scrubPii;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable obfuscation of sensitive information present in the content field. Optional
|
||||
*
|
||||
* @param scrubPii A boolean indicating whether to remove PII.
|
||||
* @return this.
|
||||
*/
|
||||
public ConversationInput setScrubPii(boolean scrubPii) {
|
||||
this.scrubPii = scrubPii;
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Copyright 2021 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.client.domain;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Returns the conversation output.
|
||||
*/
|
||||
public class ConversationOutput {
|
||||
|
||||
private final String result;
|
||||
|
||||
private final Map<String, byte[]> parameters;
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*
|
||||
* @param result result for one of the conversation input.
|
||||
* @param parameters all custom fields.
|
||||
*/
|
||||
public ConversationOutput(String result, Map<String, byte[]> parameters) {
|
||||
this.result = result;
|
||||
this.parameters = Map.copyOf(parameters);
|
||||
}
|
||||
|
||||
/**
|
||||
* Result for the one conversation input.
|
||||
*
|
||||
* @return result output from the LLM.
|
||||
*/
|
||||
public String getResult() {
|
||||
return this.result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameters for all custom fields.
|
||||
*
|
||||
* @return parameters.
|
||||
*/
|
||||
public Map<String, byte[]> getParameters() {
|
||||
return this.parameters;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright 2021 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.client.domain;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents a conversation configuration with details about component name,
|
||||
* conversation inputs, context identifier, PII scrubbing, and temperature control.
|
||||
*/
|
||||
public class ConversationRequest {
|
||||
|
||||
private final String name;
|
||||
private final List<ConversationInput> inputs;
|
||||
private String contextId;
|
||||
private boolean scrubPii;
|
||||
private double temperature;
|
||||
|
||||
/**
|
||||
* Constructs a DaprConversation with a component name and conversation inputs.
|
||||
*
|
||||
* @param name The name of the Dapr conversation component. See a list of all available conversation components
|
||||
* @see <a href="https://docs.dapr.io/reference/components-reference/supported-conversation/"></a>
|
||||
* @param inputs the list of Dapr conversation inputs
|
||||
*/
|
||||
public ConversationRequest(String name, List<ConversationInput> inputs) {
|
||||
this.name = name;
|
||||
this.inputs = inputs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the conversation component name.
|
||||
*
|
||||
* @return the conversation component name
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the list of Dapr conversation input.
|
||||
*
|
||||
* @return the list of conversation input
|
||||
*/
|
||||
public List<ConversationInput> getInputs() {
|
||||
return inputs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the context identifier.
|
||||
*
|
||||
* @return the context identifier
|
||||
*/
|
||||
public String getContextId() {
|
||||
return contextId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the context identifier.
|
||||
*
|
||||
* @param contextId the context identifier to set
|
||||
* @return the current instance of {@link ConversationRequest}
|
||||
*/
|
||||
public ConversationRequest setContextId(String contextId) {
|
||||
this.contextId = contextId;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if PII scrubbing is enabled.
|
||||
*
|
||||
* @return true if PII scrubbing is enabled, false otherwise
|
||||
*/
|
||||
public boolean isScrubPii() {
|
||||
return scrubPii;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable obfuscation of sensitive information returning from the LLM. Optional.
|
||||
*
|
||||
* @param scrubPii whether to enable PII scrubbing
|
||||
* @return the current instance of {@link ConversationRequest}
|
||||
*/
|
||||
public ConversationRequest setScrubPii(boolean scrubPii) {
|
||||
this.scrubPii = scrubPii;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the temperature of the model. Used to optimize for consistency and creativity. Optional
|
||||
*
|
||||
* @return the temperature value
|
||||
*/
|
||||
public double getTemperature() {
|
||||
return temperature;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the temperature of the model. Used to optimize for consistency and creativity. Optional
|
||||
*
|
||||
* @param temperature the temperature value to set
|
||||
* @return the current instance of {@link ConversationRequest}
|
||||
*/
|
||||
public ConversationRequest setTemperature(double temperature) {
|
||||
this.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Copyright 2021 The Dapr Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package io.dapr.client.domain;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Response from the Dapr Conversation API.
|
||||
*/
|
||||
public class ConversationResponse {
|
||||
|
||||
private String contextId;
|
||||
|
||||
private final List<ConversationOutput> outputs;
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*
|
||||
* @param contextId context id supplied to LLM.
|
||||
* @param outputs outputs from the LLM.
|
||||
*/
|
||||
public ConversationResponse(String contextId, List<ConversationOutput> outputs) {
|
||||
this.contextId = contextId;
|
||||
this.outputs = List.copyOf(outputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* The ID of an existing chat (like in ChatGPT).
|
||||
*
|
||||
* @return String identifier.
|
||||
*/
|
||||
public String getContextId() {
|
||||
return this.contextId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of conversation outputs.
|
||||
*
|
||||
* @return List{@link ConversationOutput}.
|
||||
*/
|
||||
public List<ConversationOutput> getConversationOutputs() {
|
||||
return this.outputs;
|
||||
}
|
||||
}
|
|
@ -26,6 +26,9 @@ import io.dapr.client.domain.DeleteJobRequest;
|
|||
import io.dapr.client.domain.GetJobRequest;
|
||||
import io.dapr.client.domain.GetJobResponse;
|
||||
import io.dapr.client.domain.JobSchedule;
|
||||
import io.dapr.client.domain.ConversationInput;
|
||||
import io.dapr.client.domain.ConversationRequest;
|
||||
import io.dapr.client.domain.ConversationResponse;
|
||||
import io.dapr.client.domain.QueryStateItem;
|
||||
import io.dapr.client.domain.QueryStateRequest;
|
||||
import io.dapr.client.domain.QueryStateResponse;
|
||||
|
@ -71,6 +74,7 @@ import java.util.concurrent.Semaphore;
|
|||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static io.dapr.utils.TestUtils.assertThrowsDaprException;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
@ -558,6 +562,141 @@ public class DaprPreviewClientGrpcTest {
|
|||
assertEquals(numErrors, errors.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception {
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
inputs.add(new ConversationInput("Hello there !"));
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
previewClient.converse(new ConversationRequest(null, inputs)).block());
|
||||
assertEquals("LLM name cannot be null or empty.", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception {
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
inputs.add(new ConversationInput("Hello there !"));
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
previewClient.converse(new ConversationRequest("", inputs)).block());
|
||||
assertEquals("LLM name cannot be null or empty.", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldThrowIllegalArgumentExceptionWhenInputsIsEmpty() throws Exception {
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
previewClient.converse(new ConversationRequest("openai", inputs)).block());
|
||||
assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldThrowIllegalArgumentExceptionWhenInputsIsNull() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
previewClient.converse(new ConversationRequest("openai", null)).block());
|
||||
assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldThrowIllegalArgumentExceptionWhenInputContentIsNull() throws Exception {
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
inputs.add(new ConversationInput(null));
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
previewClient.converse(new ConversationRequest("openai", inputs)).block());
|
||||
assertEquals("Conversation input content cannot be null or empty.", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldThrowIllegalArgumentExceptionWhenInputContentIsEmpty() throws Exception {
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
inputs.add(new ConversationInput(""));
|
||||
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
previewClient.converse(new ConversationRequest("openai", inputs)).block());
|
||||
assertEquals("Conversation input content cannot be null or empty.", exception.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception {
|
||||
DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder()
|
||||
.addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build();
|
||||
|
||||
doAnswer(invocation -> {
|
||||
StreamObserver<DaprProtos.ConversationResponse> observer = invocation.getArgument(1);
|
||||
observer.onNext(conversationResponse);
|
||||
observer.onCompleted();
|
||||
return null;
|
||||
}).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any());
|
||||
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
inputs.add(new ConversationInput("Hello there"));
|
||||
ConversationResponse response =
|
||||
previewClient.converse(new ConversationRequest("openai", inputs)).block();
|
||||
|
||||
ArgumentCaptor<DaprProtos.ConversationRequest> captor =
|
||||
ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class);
|
||||
verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any());
|
||||
|
||||
DaprProtos.ConversationRequest conversationRequest = captor.getValue();
|
||||
|
||||
assertEquals("openai", conversationRequest.getName());
|
||||
assertEquals("Hello there", conversationRequest.getInputs(0).getContent());
|
||||
assertEquals("Hello How are you",
|
||||
response.getConversationOutputs().get(0).getResult());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception {
|
||||
DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder()
|
||||
.setContextID("contextId")
|
||||
.addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build();
|
||||
|
||||
doAnswer(invocation -> {
|
||||
StreamObserver<DaprProtos.ConversationResponse> observer = invocation.getArgument(1);
|
||||
observer.onNext(conversationResponse);
|
||||
observer.onCompleted();
|
||||
return null;
|
||||
}).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any());
|
||||
|
||||
ConversationInput daprConversationInput = new ConversationInput("Hello there")
|
||||
.setRole("Assistant")
|
||||
.setScrubPii(true);
|
||||
|
||||
List<ConversationInput> inputs = new ArrayList<>();
|
||||
inputs.add(daprConversationInput);
|
||||
|
||||
ConversationResponse response =
|
||||
previewClient.converse(new ConversationRequest("openai", inputs)
|
||||
.setContextId("contextId")
|
||||
.setScrubPii(true)
|
||||
.setTemperature(1.1d)).block();
|
||||
|
||||
ArgumentCaptor<DaprProtos.ConversationRequest> captor =
|
||||
ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class);
|
||||
verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any());
|
||||
|
||||
DaprProtos.ConversationRequest conversationRequest = captor.getValue();
|
||||
|
||||
assertEquals("openai", conversationRequest.getName());
|
||||
assertEquals("contextId", conversationRequest.getContextID());
|
||||
assertTrue(conversationRequest.getScrubPII());
|
||||
assertEquals(1.1d, conversationRequest.getTemperature(), 0d);
|
||||
assertEquals("Hello there", conversationRequest.getInputs(0).getContent());
|
||||
assertTrue(conversationRequest.getInputs(0).getScrubPII());
|
||||
assertEquals("Assistant", conversationRequest.getInputs(0).getRole());
|
||||
assertEquals("contextId", response.getContextId());
|
||||
assertEquals("Hello How are you",
|
||||
response.getConversationOutputs().get(0).getResult());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scheduleJobShouldSucceedWhenAllFieldsArePresentInRequest() {
|
||||
DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
|
||||
|
|
Loading…
Reference in New Issue