Add Conversation AI to Java SDK (#1235)

* Conversation first commit

Signed-off-by: Siri Varma Vegiraju <svegiraju@microsoft.com>
Signed-off-by: sirivarma <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>

* Add unit tests

Signed-off-by: sirivarma <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>

* change ai to conv

Signed-off-by: sirivarma <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>

* Move to single module

Signed-off-by: sirivarma <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>

* Remove module

Signed-off-by: sirivarma <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>

* Add Integration tests

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Update sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java

Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>

* Fix things

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Address comments

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Import tag

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Address comments

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Make common config

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Address comments

Signed-off-by: siri-varma <siri.varma@outlook.com>

* fix constant

Signed-off-by: siri-varma <siri.varma@outlook.com>

* fix constant

Signed-off-by: siri-varma <siri.varma@outlook.com>

* fix constant

Signed-off-by: siri-varma <siri.varma@outlook.com>

* fix s

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Fix things

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Fix things

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Fix things

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Make common config

Signed-off-by: siri-varma <siri.varma@outlook.com>

* Update README.md

Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com>

* Update README.md

Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com>

---------

Signed-off-by: Siri Varma Vegiraju <svegiraju@microsoft.com>
Signed-off-by: sirivarma <siri.varma@outlook.com>
Signed-off-by: siri-varma <siri.varma@outlook.com>
Signed-off-by: Siri Varma Vegiraju <siri.varma@outlook.com>
Co-authored-by: Cassie Coyle <cassie.i.coyle@gmail.com>
Co-authored-by: Cassie Coyle <cassie@diagrid.io>
This commit is contained in:
Siri Varma Vegiraju 2025-05-01 12:11:03 -07:00 committed by GitHub
parent c53f000ef2
commit 99e21db465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1005 additions and 133 deletions

View File

@ -121,6 +121,12 @@ jobs:
mm.py ./src/main/java/io/dapr/examples/jobs/README.md
env:
DOCKER_HOST: ${{steps.setup_docker.outputs.sock}}
- name: Validate conversation ai example
working-directory: ./examples
run: |
mm.py ./src/main/java/io/dapr/examples/conversation/README.md
env:
DOCKER_HOST: ${{steps.setup_docker.outputs.sock}}
- name: Validate invoke http example
working-directory: ./examples
run: |

View File

@ -0,0 +1,7 @@
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: echo
spec:
type: conversation.echo
version: v1

View File

@ -0,0 +1,49 @@
/*
* Copyright 2021 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.examples.conversation;
import io.dapr.client.DaprClientBuilder;
import io.dapr.client.DaprPreviewClient;
import io.dapr.client.domain.ConversationInput;
import io.dapr.client.domain.ConversationRequest;
import io.dapr.client.domain.ConversationResponse;
import reactor.core.publisher.Mono;
import java.util.List;
public class DemoConversationAI {
/**
* The main method to start the client.
*
* @param args Input arguments (unused).
*/
public static void main(String[] args) {
try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) {
System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567");
ConversationInput daprConversationInput = new ConversationInput("Hello How are you? "
+ "This is the my number 672-123-4567");
// Component name is the name provided in the metadata block of the conversation.yaml file.
Mono<ConversationResponse> responseMono = client.converse(new ConversationRequest("echo",
List.of(daprConversationInput))
.setContextId("contextId")
.setScrubPii(true).setTemperature(1.1d));
ConversationResponse response = responseMono.block();
System.out.printf("Conversation output: %s", response.getConversationOutputs().get(0).getResult());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,114 @@
## Manage Dapr via the Conversation API
This example provides the different capabilities provided by Dapr Java SDK for Conversation. For further information about Conversation APIs please refer to [this link](https://docs.dapr.io/developing-applications/building-blocks/conversation/conversation-overview/)
### Using the Conversation API
The Java SDK exposes several methods for this -
* `client.converse(...)` for conversing with an LLM through Dapr.
## Pre-requisites
* [Dapr CLI](https://docs.dapr.io/getting-started/install-dapr-cli/).
* Java JDK 11 (or greater):
* [Microsoft JDK 11](https://docs.microsoft.com/en-us/java/openjdk/download#openjdk-11)
* [Oracle JDK 11](https://www.oracle.com/technetwork/java/javase/downloads/index.html#JDK11)
* [OpenJDK 11](https://jdk.java.net/11/)
* [Apache Maven](https://maven.apache.org/install.html) version 3.x.
### Checking out the code
Clone this repository:
```sh
git clone https://github.com/dapr/java-sdk.git
cd java-sdk
```
Then build the Maven project:
```sh
# make sure you are in the `java-sdk` directory
mvn install
```
Then get into the examples directory:
```sh
cd examples
```
### Initialize Dapr
Run `dapr init` to initialize Dapr in Self-Hosted Mode if it's not already initialized.
### Running the example
This example uses the Java SDK Dapr client in order to **Converse** with an LLM.
`DemoConversationAI.java` is the example class demonstrating these features.
Kindly check [DaprPreviewClient.java](https://github.com/dapr/java-sdk/blob/master/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java) for a detailed description of the supported APIs.
```java
public class DemoConversationAI {
/**
* The main method to start the client.
*
* @param args Input arguments (unused).
*/
public static void main(String[] args) {
try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) {
System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567");
ConversationInput daprConversationInput = new ConversationInput("Hello How are you? "
+ "This is the my number 672-123-4567");
// Component name is the name provided in the metadata block of the conversation.yaml file.
Mono<ConversationResponse> responseMono = client.converse(new ConversationRequest("echo",
List.of(daprConversationInput))
.setContextId("contextId")
.setScrubPii(true).setTemperature(1.1d));
ConversationResponse response = responseMono.block();
System.out.printf("Conversation output: %s", response.getConversationOutpus().get(0).getResult());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
```
Use the following command to run this example-
<!-- STEP
name: Run Demo Conversation Client example
expected_stdout_lines:
- "== APP == Conversation output: Hello How are you? This is the my number <ISBN>"
background: true
output_match_mode: substring
sleep: 10
-->
```bash
dapr run --resources-path ./components/conversation --app-id myapp --app-port 8080 --dapr-http-port 3500 --dapr-grpc-port 51439 --log-level debug -- java -jar target/dapr-java-sdk-examples-exec.jar io.dapr.examples.conversation.DemoConversationAI
```
<!-- END_STEP -->
### Sample output
```
== APP == Conversation output: Hello How are you? This is the my number <ISBN>
```
### Cleanup
To stop the app, run (or press CTRL+C):
<!-- STEP
name: Cleanup
-->
```bash
dapr stop --app-id myapp
```
<!-- END_STEP -->

View File

@ -0,0 +1,133 @@
/*
* Copyright 2021 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.it.testcontainers;
import io.dapr.client.DaprPreviewClient;
import io.dapr.client.domain.ConversationInput;
import io.dapr.client.domain.ConversationRequest;
import io.dapr.client.domain.ConversationResponse;
import io.dapr.testcontainers.Component;
import io.dapr.testcontainers.DaprContainer;
import io.dapr.testcontainers.DaprLogLevel;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;
import org.testcontainers.containers.Network;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import static io.dapr.it.testcontainers.ContainerConstants.DAPR_RUNTIME_IMAGE_TAG;
@SpringBootTest(
webEnvironment = WebEnvironment.RANDOM_PORT,
classes = {
DaprPreviewClientConfiguration.class,
TestConversationApplication.class
}
)
@Testcontainers
@Tag("testcontainers")
public class DaprConversationIT {
private static final Network DAPR_NETWORK = Network.newNetwork();
private static final Random RANDOM = new Random();
private static final int PORT = RANDOM.nextInt(1000) + 8000;
@Container
private static final DaprContainer DAPR_CONTAINER = new DaprContainer(DAPR_RUNTIME_IMAGE_TAG)
.withAppName("conversation-dapr-app")
.withComponent(new Component("echo", "conversation.echo", "v1", new HashMap<>()))
.withNetwork(DAPR_NETWORK)
.withDaprLogLevel(DaprLogLevel.DEBUG)
.withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String()))
.withAppChannelAddress("host.testcontainers.internal")
.withAppPort(PORT);
/**
* Expose the Dapr port to the host.
*
* @param registry the dynamic property registry
*/
@DynamicPropertySource
static void daprProperties(DynamicPropertyRegistry registry) {
registry.add("dapr.http.endpoint", DAPR_CONTAINER::getHttpEndpoint);
registry.add("dapr.grpc.endpoint", DAPR_CONTAINER::getGrpcEndpoint);
registry.add("server.port", () -> PORT);
}
@Autowired
private DaprPreviewClient daprPreviewClient;
@BeforeEach
public void setUp(){
org.testcontainers.Testcontainers.exposeHostPorts(PORT);
}
@Test
public void testConversationSDKShouldHaveSameOutputAndInput() {
ConversationInput conversationInput = new ConversationInput("input this");
List<ConversationInput> conversationInputList = new ArrayList<>();
conversationInputList.add(conversationInput);
ConversationResponse response =
this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block();
Assertions.assertEquals("", response.getContextId());
Assertions.assertEquals("input this", response.getConversationOutputs().get(0).getResult());
}
@Test
public void testConversationSDKShouldScrubPIIWhenScrubPIIIsSetInRequestBody() {
List<ConversationInput> conversationInputList = new ArrayList<>();
conversationInputList.add(new ConversationInput("input this abcd@gmail.com"));
conversationInputList.add(new ConversationInput("input this +12341567890"));
ConversationResponse response =
this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)
.setScrubPii(true)).block();
Assertions.assertEquals("", response.getContextId());
Assertions.assertEquals("input this <EMAIL_ADDRESS>",
response.getConversationOutputs().get(0).getResult());
Assertions.assertEquals("input this <PHONE_NUMBER>",
response.getConversationOutputs().get(1).getResult());
}
@Test
public void testConversationSDKShouldScrubPIIOnlyForTheInputWhereScrubPIIIsSet() {
List<ConversationInput> conversationInputList = new ArrayList<>();
conversationInputList.add(new ConversationInput("input this abcd@gmail.com"));
conversationInputList.add(new ConversationInput("input this +12341567890").setScrubPii(true));
ConversationResponse response =
this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block();
Assertions.assertEquals("", response.getContextId());
Assertions.assertEquals("input this abcd@gmail.com",
response.getConversationOutputs().get(0).getResult());
Assertions.assertEquals("input this <PHONE_NUMBER>",
response.getConversationOutputs().get(1).getResult());
}
}

View File

@ -45,7 +45,7 @@ import static org.junit.Assert.assertEquals;
@SpringBootTest(
webEnvironment = WebEnvironment.RANDOM_PORT,
classes = {
TestDaprJobsConfiguration.class,
DaprPreviewClientConfiguration.class,
TestJobsApplication.class
}
)

View File

@ -14,11 +14,9 @@ limitations under the License.
package io.dapr.it.testcontainers;
import io.dapr.client.DaprClientBuilder;
import io.dapr.client.DaprClientImpl;
import io.dapr.client.DaprPreviewClient;
import io.dapr.config.Properties;
import io.dapr.config.Property;
import io.dapr.serializer.DefaultObjectSerializer;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@ -26,7 +24,7 @@ import org.springframework.context.annotation.Configuration;
import java.util.Map;
@Configuration
public class TestDaprJobsConfiguration {
public class DaprPreviewClientConfiguration {
@Bean
public DaprPreviewClient daprPreviewClient(
@Value("${dapr.http.endpoint}") String daprHttpEndpoint,

View File

@ -0,0 +1,25 @@
/*
* Copyright 2024 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.it.testcontainers;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class TestConversationApplication {
public static void main(String[] args) {
SpringApplication.run(TestConversationApplication.class, args);
}
}

View File

@ -27,6 +27,10 @@ import io.dapr.client.domain.BulkPublishResponseFailedEntry;
import io.dapr.client.domain.CloudEvent;
import io.dapr.client.domain.ComponentMetadata;
import io.dapr.client.domain.ConfigurationItem;
import io.dapr.client.domain.ConversationInput;
import io.dapr.client.domain.ConversationOutput;
import io.dapr.client.domain.ConversationRequest;
import io.dapr.client.domain.ConversationResponse;
import io.dapr.client.domain.DaprMetadata;
import io.dapr.client.domain.DeleteJobRequest;
import io.dapr.client.domain.DeleteStateRequest;
@ -99,7 +103,6 @@ import javax.annotation.Nonnull;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
@ -1552,6 +1555,79 @@ public class DaprClientImpl extends AbstractDaprClient {
});
}
/**
* {@inheritDoc}
*/
@Override
public Mono<ConversationResponse> converse(ConversationRequest conversationRequest) {
try {
validateConversationRequest(conversationRequest);
DaprProtos.ConversationRequest.Builder protosConversationRequestBuilder = DaprProtos.ConversationRequest
.newBuilder().setTemperature(conversationRequest.getTemperature())
.setScrubPII(conversationRequest.isScrubPii())
.setName(conversationRequest.getName());
if (conversationRequest.getContextId() != null) {
protosConversationRequestBuilder.setContextID(conversationRequest.getContextId());
}
for (ConversationInput input : conversationRequest.getInputs()) {
if (input.getContent() == null || input.getContent().isEmpty()) {
throw new IllegalArgumentException("Conversation input content cannot be null or empty.");
}
DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder()
.setContent(input.getContent())
.setScrubPII(input.isScrubPii());
if (input.getRole() != null) {
conversationInputOrBuilder.setRole(input.getRole().toString());
}
protosConversationRequestBuilder.addInputs(conversationInputOrBuilder.build());
}
Mono<DaprProtos.ConversationResponse> conversationResponseMono = Mono.deferContextual(
context -> this.createMono(
it -> intercept(context, asyncStub)
.converseAlpha1(protosConversationRequestBuilder.build(), it)
)
);
return conversationResponseMono.map(conversationResponse -> {
List<ConversationOutput> conversationOutputs = new ArrayList<>();
for (DaprProtos.ConversationResult conversationResult : conversationResponse.getOutputsList()) {
Map<String, byte[]> parameters = new HashMap<>();
for (Map.Entry<String, Any> entrySet : conversationResult.getParametersMap().entrySet()) {
parameters.put(entrySet.getKey(), entrySet.getValue().toByteArray());
}
ConversationOutput conversationOutput =
new ConversationOutput(conversationResult.getResult(), parameters);
conversationOutputs.add(conversationOutput);
}
return new ConversationResponse(conversationResponse.getContextID(), conversationOutputs);
});
} catch (Exception ex) {
return DaprException.wrapMono(ex);
}
}
private void validateConversationRequest(ConversationRequest conversationRequest) {
if ((conversationRequest.getName() == null) || (conversationRequest.getName().trim().isEmpty())) {
throw new IllegalArgumentException("LLM name cannot be null or empty.");
}
if ((conversationRequest.getInputs() == null) || (conversationRequest
.getInputs().isEmpty())) {
throw new IllegalArgumentException("Conversation inputs cannot be null or empty.");
}
}
private DaprMetadata buildDaprMetadata(DaprProtos.GetMetadataResponse response) throws IOException {
String id = response.getId();
String runtimeVersion = response.getRuntimeVersion();

View File

@ -17,6 +17,8 @@ import io.dapr.client.domain.BulkPublishEntry;
import io.dapr.client.domain.BulkPublishRequest;
import io.dapr.client.domain.BulkPublishResponse;
import io.dapr.client.domain.BulkPublishResponseFailedEntry;
import io.dapr.client.domain.ConversationRequest;
import io.dapr.client.domain.ConversationResponse;
import io.dapr.client.domain.DeleteJobRequest;
import io.dapr.client.domain.GetJobRequest;
import io.dapr.client.domain.GetJobResponse;
@ -304,4 +306,12 @@ public interface DaprPreviewClient extends AutoCloseable {
* @throws IllegalArgumentException If the request or its required fields like name are null or empty.
*/
public Mono<Void> deleteJob(DeleteJobRequest deleteJobRequest);
/*
* Converse with an LLM.
*
* @param conversationRequest request to be passed to the LLM.
* @return {@link ConversationResponse}.
*/
public Mono<ConversationResponse> converse(ConversationRequest conversationRequest);
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2021 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.client.domain;
/**
* Represents an input message for a conversation with an LLM.
*/
public class ConversationInput {
private final String content;
private String role;
private boolean scrubPii;
/**
* Constructor.
*
* @param content for the llm.
*/
public ConversationInput(String content) {
this.content = content;
}
/**
* The message content to send to the LLM. Required
*
* @return The content to be sent to the LLM.
*/
public String getContent() {
return content;
}
/**
* The role for the LLM to assume.
*
* @return this.
*/
public String getRole() {
return role;
}
/**
* Set the role for LLM to assume.
*
* @param role The role to assign to the message.
* @return this.
*/
public ConversationInput setRole(String role) {
this.role = role;
return this;
}
/**
* Checks if Personally Identifiable Information (PII) should be scrubbed before sending to the LLM.
*
* @return {@code true} if PII should be scrubbed, {@code false} otherwise.
*/
public boolean isScrubPii() {
return scrubPii;
}
/**
* Enable obfuscation of sensitive information present in the content field. Optional
*
* @param scrubPii A boolean indicating whether to remove PII.
* @return this.
*/
public ConversationInput setScrubPii(boolean scrubPii) {
this.scrubPii = scrubPii;
return this;
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright 2021 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.client.domain;
import java.util.Collections;
import java.util.Map;
/**
* Returns the conversation output.
*/
public class ConversationOutput {
private final String result;
private final Map<String, byte[]> parameters;
/**
* Constructor.
*
* @param result result for one of the conversation input.
* @param parameters all custom fields.
*/
public ConversationOutput(String result, Map<String, byte[]> parameters) {
this.result = result;
this.parameters = Map.copyOf(parameters);
}
/**
* Result for the one conversation input.
*
* @return result output from the LLM.
*/
public String getResult() {
return this.result;
}
/**
* Parameters for all custom fields.
*
* @return parameters.
*/
public Map<String, byte[]> getParameters() {
return this.parameters;
}
}

View File

@ -0,0 +1,119 @@
/*
* Copyright 2021 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.client.domain;
import java.util.List;
/**
* Represents a conversation configuration with details about component name,
* conversation inputs, context identifier, PII scrubbing, and temperature control.
*/
public class ConversationRequest {
private final String name;
private final List<ConversationInput> inputs;
private String contextId;
private boolean scrubPii;
private double temperature;
/**
* Constructs a DaprConversation with a component name and conversation inputs.
*
* @param name The name of the Dapr conversation component. See a list of all available conversation components
* @see <a href="https://docs.dapr.io/reference/components-reference/supported-conversation/"></a>
* @param inputs the list of Dapr conversation inputs
*/
public ConversationRequest(String name, List<ConversationInput> inputs) {
this.name = name;
this.inputs = inputs;
}
/**
* Gets the conversation component name.
*
* @return the conversation component name
*/
public String getName() {
return name;
}
/**
* Gets the list of Dapr conversation input.
*
* @return the list of conversation input
*/
public List<ConversationInput> getInputs() {
return inputs;
}
/**
* Gets the context identifier.
*
* @return the context identifier
*/
public String getContextId() {
return contextId;
}
/**
* Sets the context identifier.
*
* @param contextId the context identifier to set
* @return the current instance of {@link ConversationRequest}
*/
public ConversationRequest setContextId(String contextId) {
this.contextId = contextId;
return this;
}
/**
* Checks if PII scrubbing is enabled.
*
* @return true if PII scrubbing is enabled, false otherwise
*/
public boolean isScrubPii() {
return scrubPii;
}
/**
* Enable obfuscation of sensitive information returning from the LLM. Optional.
*
* @param scrubPii whether to enable PII scrubbing
* @return the current instance of {@link ConversationRequest}
*/
public ConversationRequest setScrubPii(boolean scrubPii) {
this.scrubPii = scrubPii;
return this;
}
/**
* Gets the temperature of the model. Used to optimize for consistency and creativity. Optional
*
* @return the temperature value
*/
public double getTemperature() {
return temperature;
}
/**
* Sets the temperature of the model. Used to optimize for consistency and creativity. Optional
*
* @param temperature the temperature value to set
* @return the current instance of {@link ConversationRequest}
*/
public ConversationRequest setTemperature(double temperature) {
this.temperature = temperature;
return this;
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright 2021 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.client.domain;
import java.util.Collections;
import java.util.List;
/**
* Response from the Dapr Conversation API.
*/
public class ConversationResponse {
private String contextId;
private final List<ConversationOutput> outputs;
/**
* Constructor.
*
* @param contextId context id supplied to LLM.
* @param outputs outputs from the LLM.
*/
public ConversationResponse(String contextId, List<ConversationOutput> outputs) {
this.contextId = contextId;
this.outputs = List.copyOf(outputs);
}
/**
* The ID of an existing chat (like in ChatGPT).
*
* @return String identifier.
*/
public String getContextId() {
return this.contextId;
}
/**
* Get list of conversation outputs.
*
* @return List{@link ConversationOutput}.
*/
public List<ConversationOutput> getConversationOutputs() {
return this.outputs;
}
}

View File

@ -26,6 +26,9 @@ import io.dapr.client.domain.DeleteJobRequest;
import io.dapr.client.domain.GetJobRequest;
import io.dapr.client.domain.GetJobResponse;
import io.dapr.client.domain.JobSchedule;
import io.dapr.client.domain.ConversationInput;
import io.dapr.client.domain.ConversationRequest;
import io.dapr.client.domain.ConversationResponse;
import io.dapr.client.domain.QueryStateItem;
import io.dapr.client.domain.QueryStateRequest;
import io.dapr.client.domain.QueryStateResponse;
@ -71,6 +74,7 @@ import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import static io.dapr.utils.TestUtils.assertThrowsDaprException;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -110,7 +114,7 @@ public class DaprPreviewClientGrpcTest {
daprHttp = mock(DaprHttp.class);
when(daprStub.withInterceptors(any())).thenReturn(daprStub);
previewClient = new DaprClientImpl(
channel, daprStub, daprHttp, new DefaultObjectSerializer(), new DefaultObjectSerializer());
channel, daprStub, daprHttp, new DefaultObjectSerializer(), new DefaultObjectSerializer());
doNothing().when(channel).close();
}
@ -128,28 +132,28 @@ public class DaprPreviewClientGrpcTest {
}).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any());
assertThrowsDaprException(
StatusRuntimeException.class,
"INVALID_ARGUMENT",
"INVALID_ARGUMENT: bad bad argument",
() -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.EMPTY_LIST)).block());
StatusRuntimeException.class,
"INVALID_ARGUMENT",
"INVALID_ARGUMENT: bad bad argument",
() -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.EMPTY_LIST)).block());
}
@Test
public void publishEventsCallbackExceptionThrownTest() {
doAnswer((Answer<Void>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
observer.onError(newStatusRuntimeException("INVALID_ARGUMENT", "bad bad argument"));
return null;
}).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any());
assertThrowsDaprException(
ExecutionException.class,
"INVALID_ARGUMENT",
"INVALID_ARGUMENT: bad bad argument",
() -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.EMPTY_LIST)).block());
ExecutionException.class,
"INVALID_ARGUMENT",
"INVALID_ARGUMENT: bad bad argument",
() -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.EMPTY_LIST)).block());
}
@Test
@ -157,7 +161,7 @@ public class DaprPreviewClientGrpcTest {
DaprObjectSerializer mockSerializer = mock(DaprObjectSerializer.class);
doAnswer((Answer<Void>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance());
observer.onCompleted();
return null;
@ -165,9 +169,9 @@ public class DaprPreviewClientGrpcTest {
BulkPublishEntry<String> entry = new BulkPublishEntry<>("1", "testEntry"
, "application/octet-stream", null);
, "application/octet-stream", null);
BulkPublishRequest<String> wrongReq = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.singletonList(entry));
Collections.singletonList(entry));
assertThrows(IllegalArgumentException.class, () -> previewClient.publishEvents(wrongReq).block());
}
@ -178,30 +182,30 @@ public class DaprPreviewClientGrpcTest {
previewClient = new DaprClientImpl(channel, daprStub, daprHttp, mockSerializer, new DefaultObjectSerializer());
doAnswer((Answer<Void>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance());
observer.onCompleted();
return null;
}).when(daprStub).publishEvent(any(DaprProtos.PublishEventRequest.class), any());
BulkPublishEntry<Map<String, String>> entry = new BulkPublishEntry<>("1", new HashMap<>(),
"application/json", null);
"application/json", null);
BulkPublishRequest<Map<String, String>> req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.singletonList(entry));
Collections.singletonList(entry));
when(mockSerializer.serialize(any())).thenThrow(IOException.class);
Mono<BulkPublishResponse<Map<String, String>>> result = previewClient.publishEvents(req);
assertThrowsDaprException(
IOException.class,
"UNKNOWN",
"UNKNOWN: ",
() -> result.block());
IOException.class,
"UNKNOWN",
"UNKNOWN: ",
() -> result.block());
}
@Test
public void publishEventsTest() {
doAnswer((Answer<BulkPublishResponse>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
DaprProtos.BulkPublishResponse.Builder builder = DaprProtos.BulkPublishResponse.newBuilder();
observer.onNext(builder.build());
observer.onCompleted();
@ -209,9 +213,9 @@ public class DaprPreviewClientGrpcTest {
}).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any());
BulkPublishEntry<String> entry = new BulkPublishEntry<>("1", "test",
"text/plain", null);
"text/plain", null);
BulkPublishRequest<String> req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.singletonList(entry));
Collections.singletonList(entry));
Mono<BulkPublishResponse<String>> result = previewClient.publishEvents(req);
BulkPublishResponse res = result.block();
Assertions.assertNotNull(res);
@ -222,7 +226,7 @@ public class DaprPreviewClientGrpcTest {
public void publishEventsWithoutMetaTest() {
doAnswer((Answer<BulkPublishResponse>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
DaprProtos.BulkPublishResponse.Builder builder = DaprProtos.BulkPublishResponse.newBuilder();
observer.onNext(builder.build());
observer.onCompleted();
@ -230,7 +234,7 @@ public class DaprPreviewClientGrpcTest {
}).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any());
Mono<BulkPublishResponse<String>> result = previewClient.publishEvents(PUBSUB_NAME, TOPIC_NAME,
"text/plain", Collections.singletonList("test"));
"text/plain", Collections.singletonList("test"));
BulkPublishResponse<String> res = result.block();
Assertions.assertNotNull(res);
assertEquals( 0, res.getFailedEntries().size(), "expected no entries in failed entries list");
@ -240,7 +244,7 @@ public class DaprPreviewClientGrpcTest {
public void publishEventsWithRequestMetaTest() {
doAnswer((Answer<BulkPublishResponse>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
DaprProtos.BulkPublishResponse.Builder builder = DaprProtos.BulkPublishResponse.newBuilder();
observer.onNext(builder.build());
observer.onCompleted();
@ -248,9 +252,9 @@ public class DaprPreviewClientGrpcTest {
}).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any());
Mono<BulkPublishResponse<String>> result = previewClient.publishEvents(PUBSUB_NAME, TOPIC_NAME,
"text/plain", new HashMap<String, String>(){{
put("ttlInSeconds", "123");
}}, Collections.singletonList("test"));
"text/plain", new HashMap<String, String>(){{
put("ttlInSeconds", "123");
}}, Collections.singletonList("test"));
BulkPublishResponse<String> res = result.block();
Assertions.assertNotNull(res);
assertEquals( 0, res.getFailedEntries().size(), "expected no entry in failed entries list");
@ -260,7 +264,7 @@ public class DaprPreviewClientGrpcTest {
public void publishEventsObjectTest() {
doAnswer((Answer<Void>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance());
observer.onCompleted();
return null;
@ -271,7 +275,7 @@ public class DaprPreviewClientGrpcTest {
}
if (!"{\"id\":1,\"value\":\"Event\"}".equals(new String(entry.getEvent().toByteArray())) &&
!"{\"value\":\"Event\",\"id\":1}".equals(new String(entry.getEvent().toByteArray()))) {
!"{\"value\":\"Event\",\"id\":1}".equals(new String(entry.getEvent().toByteArray()))) {
return false;
}
return true;
@ -280,9 +284,9 @@ public class DaprPreviewClientGrpcTest {
DaprClientGrpcTest.MyObject event = new DaprClientGrpcTest.MyObject(1, "Event");
BulkPublishEntry<DaprClientGrpcTest.MyObject> entry = new BulkPublishEntry<>("1", event,
"application/json", null);
"application/json", null);
BulkPublishRequest<DaprClientGrpcTest.MyObject> req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.singletonList(entry));
Collections.singletonList(entry));
BulkPublishResponse<DaprClientGrpcTest.MyObject> result = previewClient.publishEvents(req).block();
Assertions.assertNotNull(result);
Assertions.assertEquals(0, result.getFailedEntries().size(), "expected no entries to be failed");
@ -292,7 +296,7 @@ public class DaprPreviewClientGrpcTest {
public void publishEventsContentTypeOverrideTest() {
doAnswer((Answer<Void>) invocation -> {
StreamObserver<DaprProtos.BulkPublishResponse> observer =
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.BulkPublishResponse>) invocation.getArguments()[1];
observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance());
observer.onCompleted();
return null;
@ -309,9 +313,9 @@ public class DaprPreviewClientGrpcTest {
}), any());
BulkPublishEntry<String> entry = new BulkPublishEntry<>("1", "hello",
"", null);
"", null);
BulkPublishRequest<String> req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME,
Collections.singletonList(entry));
Collections.singletonList(entry));
BulkPublishResponse<String> result = previewClient.publishEvents(req).block();
Assertions.assertNotNull(result);
Assertions.assertEquals( 0, result.getFailedEntries().size(), "expected no entries to be failed");
@ -351,7 +355,7 @@ public class DaprPreviewClientGrpcTest {
assertEquals(0, req.getMetadataCount());
StreamObserver<DaprProtos.QueryStateResponse> observer = (StreamObserver<DaprProtos.QueryStateResponse>)
invocation.getArguments()[1];
invocation.getArguments()[1];
observer.onNext(responseEnvelope);
observer.onCompleted();
return null;
@ -378,14 +382,14 @@ public class DaprPreviewClientGrpcTest {
assertEquals(1, req.getMetadataCount());
StreamObserver<DaprProtos.QueryStateResponse> observer = (StreamObserver<DaprProtos.QueryStateResponse>)
invocation.getArguments()[1];
invocation.getArguments()[1];
observer.onNext(responseEnvelope);
observer.onCompleted();
return null;
}).when(daprStub).queryStateAlpha1(any(DaprProtos.QueryStateRequest.class), any());
QueryStateResponse<String> response = previewClient.queryState(QUERY_STORE_NAME, "query",
new HashMap<String, String>(){{ put("key", "error"); }}, String.class).block();
new HashMap<String, String>(){{ put("key", "error"); }}, String.class).block();
assertNotNull(response);
assertEquals(1, response.getResults().size(), "result size must be 1");
assertEquals( "1", response.getResults().get(0).getKey(), "result must be same");
@ -396,7 +400,7 @@ public class DaprPreviewClientGrpcTest {
public void tryLock() {
DaprProtos.TryLockResponse.Builder builder = DaprProtos.TryLockResponse.newBuilder()
.setSuccess(true);
.setSuccess(true);
DaprProtos.TryLockResponse response = builder.build();
@ -408,7 +412,7 @@ public class DaprPreviewClientGrpcTest {
assertEquals(10, req.getExpiryInSeconds());
StreamObserver<DaprProtos.TryLockResponse> observer =
(StreamObserver<DaprProtos.TryLockResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.TryLockResponse>) invocation.getArguments()[1];
observer.onNext(response);
observer.onCompleted();
return null;
@ -422,7 +426,7 @@ public class DaprPreviewClientGrpcTest {
public void unLock() {
DaprProtos.UnlockResponse.Builder builder = DaprProtos.UnlockResponse.newBuilder()
.setStatus(DaprProtos.UnlockResponse.Status.SUCCESS);
.setStatus(DaprProtos.UnlockResponse.Status.SUCCESS);
DaprProtos.UnlockResponse response = builder.build();
@ -433,7 +437,7 @@ public class DaprPreviewClientGrpcTest {
assertEquals("owner", req.getLockOwner());
StreamObserver<DaprProtos.UnlockResponse> observer =
(StreamObserver<DaprProtos.UnlockResponse>) invocation.getArguments()[1];
(StreamObserver<DaprProtos.UnlockResponse>) invocation.getArguments()[1];
observer.onNext(response);
observer.onCompleted();
return null;
@ -457,7 +461,7 @@ public class DaprPreviewClientGrpcTest {
doAnswer((Answer<StreamObserver<DaprProtos.SubscribeTopicEventsRequestAlpha1>>) invocation -> {
StreamObserver<DaprProtos.SubscribeTopicEventsResponseAlpha1> observer =
(StreamObserver<DaprProtos.SubscribeTopicEventsResponseAlpha1>) invocation.getArguments()[0];
(StreamObserver<DaprProtos.SubscribeTopicEventsResponseAlpha1>) invocation.getArguments()[0];
var emitterThread = new Thread(() -> {
try {
started.acquire();
@ -467,27 +471,27 @@ public class DaprPreviewClientGrpcTest {
observer.onNext(DaprProtos.SubscribeTopicEventsResponseAlpha1.getDefaultInstance());
for (int i = 0; i < numEvents; i++) {
observer.onNext(DaprProtos.SubscribeTopicEventsResponseAlpha1.newBuilder()
.setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder()
.setId(Integer.toString(i))
.setPubsubName(pubsubName)
.setTopic(topicName)
.setData(ByteString.copyFromUtf8("\"" + data + "\""))
.setDataContentType("application/json")
.build())
.build());
.setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder()
.setId(Integer.toString(i))
.setPubsubName(pubsubName)
.setTopic(topicName)
.setData(ByteString.copyFromUtf8("\"" + data + "\""))
.setDataContentType("application/json")
.build())
.build());
}
for (int i = 0; i < numDrops; i++) {
// Bad messages
observer.onNext(DaprProtos.SubscribeTopicEventsResponseAlpha1.newBuilder()
.setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder()
.setId(UUID.randomUUID().toString())
.setPubsubName("bad pubsub")
.setTopic("bad topic")
.setData(ByteString.copyFromUtf8("\"\""))
.setDataContentType("application/json")
.build())
.build());
.setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder()
.setId(UUID.randomUUID().toString())
.setPubsubName("bad pubsub")
.setTopic("bad topic")
.setData(ByteString.copyFromUtf8("\"\""))
.setDataContentType("application/json")
.build())
.build());
}
observer.onCompleted();
});
@ -517,38 +521,38 @@ public class DaprPreviewClientGrpcTest {
final AtomicInteger errorsToBeEmitted = new AtomicInteger(numErrors);
var subscription = previewClient.subscribeToEvents(
"pubsubname",
"topic",
new SubscriptionListener<>() {
@Override
public Mono<Status> onEvent(CloudEvent<String> event) {
if (event.getPubsubName().equals(pubsubName) &&
event.getTopic().equals(topicName) &&
event.getData().equals(data)) {
"pubsubname",
"topic",
new SubscriptionListener<>() {
@Override
public Mono<Status> onEvent(CloudEvent<String> event) {
if (event.getPubsubName().equals(pubsubName) &&
event.getTopic().equals(topicName) &&
event.getData().equals(data)) {
// Simulate an error
if ((success.size() == 4 /* some random entry */) && errorsToBeEmitted.decrementAndGet() >= 0) {
throw new RuntimeException("simulated exception on event " + event.getId());
// Simulate an error
if ((success.size() == 4 /* some random entry */) && errorsToBeEmitted.decrementAndGet() >= 0) {
throw new RuntimeException("simulated exception on event " + event.getId());
}
success.add(event.getId());
if (success.size() >= numEvents) {
gotAll.release();
}
return Mono.just(Status.SUCCESS);
}
dropCounter.incrementAndGet();
return Mono.just(Status.DROP);
}
success.add(event.getId());
if (success.size() >= numEvents) {
gotAll.release();
@Override
public void onError(RuntimeException exception) {
errors.add(exception.getMessage());
}
return Mono.just(Status.SUCCESS);
}
dropCounter.incrementAndGet();
return Mono.just(Status.DROP);
}
@Override
public void onError(RuntimeException exception) {
errors.add(exception.getMessage());
}
},
TypeRef.STRING);
},
TypeRef.STRING);
gotAll.acquire();
subscription.close();
@ -558,17 +562,152 @@ public class DaprPreviewClientGrpcTest {
assertEquals(numErrors, errors.size());
}
@Test
public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception {
List<ConversationInput> inputs = new ArrayList<>();
inputs.add(new ConversationInput("Hello there !"));
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () ->
previewClient.converse(new ConversationRequest(null, inputs)).block());
assertEquals("LLM name cannot be null or empty.", exception.getMessage());
}
@Test
public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception {
List<ConversationInput> inputs = new ArrayList<>();
inputs.add(new ConversationInput("Hello there !"));
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () ->
previewClient.converse(new ConversationRequest("", inputs)).block());
assertEquals("LLM name cannot be null or empty.", exception.getMessage());
}
@Test
public void converseShouldThrowIllegalArgumentExceptionWhenInputsIsEmpty() throws Exception {
List<ConversationInput> inputs = new ArrayList<>();
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () ->
previewClient.converse(new ConversationRequest("openai", inputs)).block());
assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage());
}
@Test
public void converseShouldThrowIllegalArgumentExceptionWhenInputsIsNull() throws Exception {
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () ->
previewClient.converse(new ConversationRequest("openai", null)).block());
assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage());
}
@Test
public void converseShouldThrowIllegalArgumentExceptionWhenInputContentIsNull() throws Exception {
List<ConversationInput> inputs = new ArrayList<>();
inputs.add(new ConversationInput(null));
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () ->
previewClient.converse(new ConversationRequest("openai", inputs)).block());
assertEquals("Conversation input content cannot be null or empty.", exception.getMessage());
}
@Test
public void converseShouldThrowIllegalArgumentExceptionWhenInputContentIsEmpty() throws Exception {
List<ConversationInput> inputs = new ArrayList<>();
inputs.add(new ConversationInput(""));
IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () ->
previewClient.converse(new ConversationRequest("openai", inputs)).block());
assertEquals("Conversation input content cannot be null or empty.", exception.getMessage());
}
@Test
public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception {
DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder()
.addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build();
doAnswer(invocation -> {
StreamObserver<DaprProtos.ConversationResponse> observer = invocation.getArgument(1);
observer.onNext(conversationResponse);
observer.onCompleted();
return null;
}).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any());
List<ConversationInput> inputs = new ArrayList<>();
inputs.add(new ConversationInput("Hello there"));
ConversationResponse response =
previewClient.converse(new ConversationRequest("openai", inputs)).block();
ArgumentCaptor<DaprProtos.ConversationRequest> captor =
ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class);
verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any());
DaprProtos.ConversationRequest conversationRequest = captor.getValue();
assertEquals("openai", conversationRequest.getName());
assertEquals("Hello there", conversationRequest.getInputs(0).getContent());
assertEquals("Hello How are you",
response.getConversationOutputs().get(0).getResult());
}
@Test
public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception {
DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder()
.setContextID("contextId")
.addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build();
doAnswer(invocation -> {
StreamObserver<DaprProtos.ConversationResponse> observer = invocation.getArgument(1);
observer.onNext(conversationResponse);
observer.onCompleted();
return null;
}).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any());
ConversationInput daprConversationInput = new ConversationInput("Hello there")
.setRole("Assistant")
.setScrubPii(true);
List<ConversationInput> inputs = new ArrayList<>();
inputs.add(daprConversationInput);
ConversationResponse response =
previewClient.converse(new ConversationRequest("openai", inputs)
.setContextId("contextId")
.setScrubPii(true)
.setTemperature(1.1d)).block();
ArgumentCaptor<DaprProtos.ConversationRequest> captor =
ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class);
verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any());
DaprProtos.ConversationRequest conversationRequest = captor.getValue();
assertEquals("openai", conversationRequest.getName());
assertEquals("contextId", conversationRequest.getContextID());
assertTrue(conversationRequest.getScrubPII());
assertEquals(1.1d, conversationRequest.getTemperature(), 0d);
assertEquals("Hello there", conversationRequest.getInputs(0).getContent());
assertTrue(conversationRequest.getInputs(0).getScrubPII());
assertEquals("Assistant", conversationRequest.getInputs(0).getRole());
assertEquals("contextId", response.getContextId());
assertEquals("Hello How are you",
response.getConversationOutputs().get(0).getResult());
}
@Test
public void scheduleJobShouldSucceedWhenAllFieldsArePresentInRequest() {
DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
.withZone(ZoneOffset.UTC);
.withZone(ZoneOffset.UTC);
ScheduleJobRequest expectedScheduleJobRequest = new ScheduleJobRequest("testJob",
JobSchedule.fromString("*/5 * * * *"))
.setData("testData".getBytes())
.setTtl(Instant.now().plus(1, ChronoUnit.DAYS))
.setRepeat(5)
.setDueTime(Instant.now().plus(10, ChronoUnit.MINUTES));
JobSchedule.fromString("*/5 * * * *"))
.setData("testData".getBytes())
.setTtl(Instant.now().plus(1, ChronoUnit.DAYS))
.setRepeat(5)
.setDueTime(Instant.now().plus(10, ChronoUnit.MINUTES));
doAnswer(invocation -> {
StreamObserver<DaprProtos.ScheduleJobResponse> observer = invocation.getArgument(1);
@ -579,14 +718,14 @@ public class DaprPreviewClientGrpcTest {
assertDoesNotThrow(() -> previewClient.scheduleJob(expectedScheduleJobRequest).block());
ArgumentCaptor<DaprProtos.ScheduleJobRequest> captor =
ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class);
ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class);
verify(daprStub, times(1)).scheduleJobAlpha1(captor.capture(), Mockito.any());
DaprProtos.ScheduleJobRequest actualScheduleJobReq = captor.getValue();
assertEquals("testJob", actualScheduleJobReq.getJob().getName());
assertEquals("testData",
new String(actualScheduleJobReq.getJob().getData().getValue().toByteArray(), StandardCharsets.UTF_8));
new String(actualScheduleJobReq.getJob().getData().getValue().toByteArray(), StandardCharsets.UTF_8));
assertEquals("*/5 * * * *", actualScheduleJobReq.getJob().getSchedule());
assertEquals(iso8601Formatter.format(expectedScheduleJobRequest.getTtl()), actualScheduleJobReq.getJob().getTtl());
assertEquals(expectedScheduleJobRequest.getRepeats(), actualScheduleJobReq.getJob().getRepeats());
@ -596,7 +735,7 @@ public class DaprPreviewClientGrpcTest {
@Test
public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndDueTimeArePresentInRequest() {
DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
.withZone(ZoneOffset.UTC);
.withZone(ZoneOffset.UTC);
doAnswer(invocation -> {
StreamObserver<DaprProtos.ScheduleJobResponse> observer = invocation.getArgument(1);
@ -605,11 +744,11 @@ public class DaprPreviewClientGrpcTest {
}).when(daprStub).scheduleJobAlpha1(any(DaprProtos.ScheduleJobRequest.class), any());
ScheduleJobRequest expectedScheduleJobRequest =
new ScheduleJobRequest("testJob", Instant.now().plus(10, ChronoUnit.MINUTES));
new ScheduleJobRequest("testJob", Instant.now().plus(10, ChronoUnit.MINUTES));
assertDoesNotThrow(() -> previewClient.scheduleJob(expectedScheduleJobRequest).block());
ArgumentCaptor<DaprProtos.ScheduleJobRequest> captor =
ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class);
ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class);
verify(daprStub, times(1)).scheduleJobAlpha1(captor.capture(), Mockito.any());
DaprProtos.ScheduleJobRequest actualScheduleJobRequest = captor.getValue();
@ -620,13 +759,13 @@ public class DaprPreviewClientGrpcTest {
assertEquals(0, job.getRepeats());
assertFalse(job.hasTtl());
assertEquals(iso8601Formatter.format(expectedScheduleJobRequest.getDueTime()),
actualScheduleJobRequest.getJob().getDueTime());
actualScheduleJobRequest.getJob().getDueTime());
}
@Test
public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndScheduleArePresentInRequest() {
DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
.withZone(ZoneOffset.UTC);
.withZone(ZoneOffset.UTC);
doAnswer(invocation -> {
StreamObserver<DaprProtos.ScheduleJobResponse> observer = invocation.getArgument(1);
@ -635,11 +774,11 @@ public class DaprPreviewClientGrpcTest {
}).when(daprStub).scheduleJobAlpha1(any(DaprProtos.ScheduleJobRequest.class), any());
ScheduleJobRequest expectedScheduleJobRequest = new ScheduleJobRequest("testJob",
JobSchedule.fromString("* * * * * *"));
JobSchedule.fromString("* * * * * *"));
assertDoesNotThrow(() -> previewClient.scheduleJob(expectedScheduleJobRequest).block());
ArgumentCaptor<DaprProtos.ScheduleJobRequest> captor =
ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class);
ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class);
verify(daprStub, times(1)).scheduleJobAlpha1(captor.capture(), Mockito.any());
DaprProtos.ScheduleJobRequest actualScheduleJobRequest = captor.getValue();
@ -681,24 +820,24 @@ public class DaprPreviewClientGrpcTest {
@Test
public void getJobShouldReturnResponseWhenAllFieldsArePresentInRequest() {
DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
.withZone(ZoneOffset.UTC);
.withZone(ZoneOffset.UTC);
GetJobRequest getJobRequest = new GetJobRequest("testJob");
DaprProtos.Job job = DaprProtos.Job.newBuilder()
.setName("testJob")
.setTtl(OffsetDateTime.now().format(iso8601Formatter))
.setData(Any.newBuilder().setValue(ByteString.copyFrom("testData".getBytes())).build())
.setSchedule("*/5 * * * *")
.setRepeats(5)
.setDueTime(iso8601Formatter.format(Instant.now().plus(10, ChronoUnit.MINUTES)))
.build();
.setName("testJob")
.setTtl(OffsetDateTime.now().format(iso8601Formatter))
.setData(Any.newBuilder().setValue(ByteString.copyFrom("testData".getBytes())).build())
.setSchedule("*/5 * * * *")
.setRepeats(5)
.setDueTime(iso8601Formatter.format(Instant.now().plus(10, ChronoUnit.MINUTES)))
.build();
doAnswer(invocation -> {
StreamObserver<DaprProtos.GetJobResponse> observer = invocation.getArgument(1);
observer.onNext(DaprProtos.GetJobResponse.newBuilder()
.setJob(job)
.build());
.setJob(job)
.build());
observer.onCompleted();
return null;
}).when(daprStub).getJobAlpha1(any(DaprProtos.GetJobRequest.class), any());
@ -720,15 +859,15 @@ public class DaprPreviewClientGrpcTest {
GetJobRequest getJobRequest = new GetJobRequest("testJob");
DaprProtos.Job job = DaprProtos.Job.newBuilder()
.setName("testJob")
.setSchedule("0 0 0 1 1 *")
.build();
.setName("testJob")
.setSchedule("0 0 0 1 1 *")
.build();
doAnswer(invocation -> {
StreamObserver<DaprProtos.GetJobResponse> observer = invocation.getArgument(1);
observer.onNext(DaprProtos.GetJobResponse.newBuilder()
.setJob(job)
.build());
.setJob(job)
.build());
observer.onCompleted();
return null;
}).when(daprStub).getJobAlpha1(any(DaprProtos.GetJobRequest.class), any());
@ -751,15 +890,15 @@ public class DaprPreviewClientGrpcTest {
String datetime = OffsetDateTime.now().toString();
DaprProtos.Job job = DaprProtos.Job.newBuilder()
.setName("testJob")
.setDueTime(datetime)
.build();
.setName("testJob")
.setDueTime(datetime)
.build();
doAnswer(invocation -> {
StreamObserver<DaprProtos.GetJobResponse> observer = invocation.getArgument(1);
observer.onNext(DaprProtos.GetJobResponse.newBuilder()
.setJob(job)
.build());
.setJob(job)
.build());
observer.onCompleted();
return null;
}).when(daprStub).getJobAlpha1(any(DaprProtos.GetJobRequest.class), any());
@ -846,15 +985,15 @@ public class DaprPreviewClientGrpcTest {
}
private DaprProtos.QueryStateResponse buildQueryStateResponse(List<QueryStateItem<?>> resp,String token)
throws JsonProcessingException {
throws JsonProcessingException {
List<DaprProtos.QueryStateItem> items = new ArrayList<>();
for (QueryStateItem<?> item: resp) {
items.add(buildQueryStateItem(item));
}
return DaprProtos.QueryStateResponse.newBuilder()
.addAllResults(items)
.setToken(token)
.build();
.addAllResults(items)
.setToken(token)
.build();
}
private DaprProtos.QueryStateItem buildQueryStateItem(QueryStateItem<?> item) throws JsonProcessingException {