diff --git a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java index e6c1f7d9c..1b9aa7f6e 100644 --- a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java +++ b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java @@ -1715,6 +1715,10 @@ public class DaprClientImpl extends AbstractDaprClient { throw new IllegalArgumentException("LLM name cannot be null or empty."); } + if (conversationRequestAlpha2.getInputs() == null || conversationRequestAlpha2.getInputs().isEmpty()) { + throw new IllegalArgumentException("Conversation Inputs cannot be null or empty."); + } + DaprProtos.ConversationRequestAlpha2.Builder builder = DaprProtos.ConversationRequestAlpha2 .newBuilder() .setTemperature(conversationRequestAlpha2.getTemperature()) @@ -1776,21 +1780,19 @@ public class DaprClientImpl extends AbstractDaprClient { builder.putAllParameters(parameters); } - if (request.getInputs() != null) { - for (ConversationInputAlpha2 input : request.getInputs()) { - DaprProtos.ConversationInputAlpha2.Builder inputBuilder = DaprProtos.ConversationInputAlpha2 - .newBuilder() - .setScrubPii(input.isScrubPii()); + for (ConversationInputAlpha2 input : request.getInputs()) { + DaprProtos.ConversationInputAlpha2.Builder inputBuilder = DaprProtos.ConversationInputAlpha2 + .newBuilder() + .setScrubPii(input.isScrubPii()); - if (input.getMessages() != null) { - for (ConversationMessage message : input.getMessages()) { - DaprProtos.ConversationMessage protoMessage = buildConversationMessage(message); - inputBuilder.addMessages(protoMessage); - } + if (input.getMessages() != null) { + for (ConversationMessage message : input.getMessages()) { + DaprProtos.ConversationMessage protoMessage = buildConversationMessage(message); + inputBuilder.addMessages(protoMessage); } - - builder.addInputs(inputBuilder.build()); } + + builder.addInputs(inputBuilder.build()); } return builder.build(); @@ -1945,9 +1947,7 @@ public class DaprClientImpl extends AbstractDaprClient { toolCalls.add(conversationToolCalls); } - return new ConversationResultMessage( - protoChoice.getMessage().getContent(), - toolCalls + return new ConversationResultMessage(protoChoice.getMessage().getContent(), toolCalls ); } diff --git a/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java b/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java index cfe12ab9e..1007066aa 100644 --- a/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java +++ b/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java @@ -31,8 +31,8 @@ public class AssistantMessage implements ConversationMessage { * @param toolCalls the tool calls requested by the assistant. */ public AssistantMessage(List content, List toolCalls) { - this.content = content != null ? List.copyOf(content) : null; - this.toolCalls = toolCalls != null ? List.copyOf(toolCalls) : null; + this.content = List.copyOf(content); + this.toolCalls = List.copyOf(toolCalls); } @Override diff --git a/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java b/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java index 9224c818c..5dbaa58a5 100644 --- a/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java +++ b/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java @@ -30,7 +30,7 @@ public class DeveloperMessage implements ConversationMessage { * @param content the content of the developer message */ public DeveloperMessage(List content) { - this.content = content != null ? List.copyOf(content) : null; + this.content = List.copyOf(content); } @Override diff --git a/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java b/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java index 20651a519..aacdb80d5 100644 --- a/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java +++ b/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java @@ -30,7 +30,7 @@ public class SystemMessage implements ConversationMessage { * @param content the content of the system message */ public SystemMessage(List content) { - this.content = content != null ? List.copyOf(content) : null; + this.content = List.copyOf(content); } @Override diff --git a/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java b/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java index efc68c464..e88e37af4 100644 --- a/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java +++ b/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java @@ -31,7 +31,7 @@ public class ToolMessage implements ConversationMessage { * @param content the content containing the tool execution result */ public ToolMessage(List content) { - this.content = content != null ? List.copyOf(content) : null; + this.content = List.copyOf(content); } @Override diff --git a/sdk/src/main/java/io/dapr/client/domain/UserMessage.java b/sdk/src/main/java/io/dapr/client/domain/UserMessage.java index 85b0ce44c..30ae023a8 100644 --- a/sdk/src/main/java/io/dapr/client/domain/UserMessage.java +++ b/sdk/src/main/java/io/dapr/client/domain/UserMessage.java @@ -30,7 +30,7 @@ public class UserMessage implements ConversationMessage { * @param content the content of the user message */ public UserMessage(List content) { - this.content = content != null ? List.copyOf(content) : null; + this.content = List.copyOf(content); } @Override diff --git a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java index 0b0bed5ae..7e4fd5e3d 100644 --- a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java +++ b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java @@ -16,6 +16,7 @@ package io.dapr.client; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import io.dapr.client.domain.AssistantMessage; @@ -1347,7 +1348,14 @@ public class DaprPreviewClientGrpcTest { @Test public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsNull() { - ConversationRequestAlpha2 request = new ConversationRequestAlpha2(null, null); + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2(null, List.of(input)); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> previewClient.converseAlpha2(request).block()); @@ -1356,7 +1364,14 @@ public class DaprPreviewClientGrpcTest { @Test public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsEmpty() { - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("", null); + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("", List.of(input)); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> previewClient.converseAlpha2(request).block()); @@ -1365,6 +1380,13 @@ public class DaprPreviewClientGrpcTest { @Test public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsWhitespace() { + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + ConversationRequestAlpha2 request = new ConversationRequestAlpha2(" ", null); IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> @@ -1372,6 +1394,24 @@ public class DaprPreviewClientGrpcTest { assertEquals("LLM name cannot be null or empty.", exception.getMessage()); } + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenInputIsNull() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("abc", null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenInputIsEmpty() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("abc", new ArrayList<>()); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + @Test public void converseAlpha2ExceptionThrownTest() { doAnswer((Answer) invocation -> { @@ -1380,11 +1420,7 @@ public class DaprPreviewClientGrpcTest { ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); - assertThrowsDaprException( - StatusRuntimeException.class, - "INVALID_ARGUMENT", - "INVALID_ARGUMENT: bad argument", - () -> previewClient.converseAlpha2(request).block()); + assertThrows(IllegalArgumentException.class, () -> previewClient.converseAlpha2(request).block()); } @Test @@ -1396,7 +1432,14 @@ public class DaprPreviewClientGrpcTest { return null; }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + List messages = new ArrayList<>(); + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("System info"))); + systemMsg.setName("system"); + messages.add(systemMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); Mono result = previewClient.converseAlpha2(request); assertThrowsDaprException( @@ -1429,7 +1472,14 @@ public class DaprPreviewClientGrpcTest { return null; }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + List messages = new ArrayList<>(); + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); assertNotNull(response); @@ -1618,42 +1668,16 @@ public class DaprPreviewClientGrpcTest { assertTrue(capturedMessages.get(4).hasOfDeveloper()); } - @Test - public void converseAlpha2EmptyInputsTest() { - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", new ArrayList<>()); - - DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() - .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() - .addChoices(DaprProtos.ConversationResultChoices.newBuilder() - .setFinishReason("stop") - .setIndex(0) - .build()) - .build()) - .build(); - - doAnswer((Answer) invocation -> { - StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; - observer.onNext(grpcResponse); - observer.onCompleted(); - return null; - }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); - - ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); - - assertNotNull(response); - - ArgumentCaptor captor = - ArgumentCaptor.forClass(DaprProtos.ConversationRequestAlpha2.class); - verify(daprStub).converseAlpha2(captor.capture(), any()); - - DaprProtos.ConversationRequestAlpha2 capturedRequest = captor.getValue(); - assertEquals(0, capturedRequest.getInputsCount()); - } - @Test public void converseAlpha2ResponseWithoutMessageTest() { - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + List messages = new ArrayList<>(); + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() @@ -1684,7 +1708,14 @@ public class DaprPreviewClientGrpcTest { @Test public void converseAlpha2MultipleResultsTest() { - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + List messages = new ArrayList<>(); + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() @@ -1741,8 +1772,14 @@ public class DaprPreviewClientGrpcTest { @Test public void converseAlpha2ToolCallWithoutFunctionTest() { - ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + List messages = new ArrayList<>(); + UserMessage userMsg = new UserMessage(List.of(new ConversationMessageContent("Debug info"))); + userMsg.setName("developer"); + messages.add(userMsg); + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() .addChoices(DaprProtos.ConversationResultChoices.newBuilder()