diff --git a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java index f9fefd372..207f9bb27 100644 --- a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java +++ b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java @@ -13,6 +13,7 @@ limitations under the License. package io.dapr.client; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Strings; import com.google.protobuf.Any; import com.google.protobuf.ByteString; @@ -21,6 +22,7 @@ import com.google.protobuf.Message; import io.dapr.client.domain.ActorMetadata; import io.dapr.client.domain.AppConnectionPropertiesHealthMetadata; import io.dapr.client.domain.AppConnectionPropertiesMetadata; +import io.dapr.client.domain.AssistantMessage; import io.dapr.client.domain.BulkPublishEntry; import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; @@ -28,7 +30,7 @@ import io.dapr.client.domain.BulkPublishResponseFailedEntry; import io.dapr.client.domain.CloudEvent; import io.dapr.client.domain.ComponentMetadata; import io.dapr.client.domain.ConfigurationItem; -import io.dapr.client.domain.ConversationFunction; +import io.dapr.client.domain.ConversationToolsFunction; import io.dapr.client.domain.ConversationInput; import io.dapr.client.domain.ConversationInputAlpha2; import io.dapr.client.domain.ConversationMessage; @@ -42,7 +44,7 @@ import io.dapr.client.domain.ConversationResultAlpha2; import io.dapr.client.domain.ConversationResultChoices; import io.dapr.client.domain.ConversationResultMessage; import io.dapr.client.domain.ConversationToolCalls; -import io.dapr.client.domain.ConversationToolCallsFunction; +import io.dapr.client.domain.ConversationToolCallsOfFunction; import io.dapr.client.domain.ConversationTools; import io.dapr.client.domain.DaprMetadata; import io.dapr.client.domain.DeleteJobRequest; @@ -73,6 +75,7 @@ import io.dapr.client.domain.StateOptions; import io.dapr.client.domain.SubscribeConfigurationRequest; import io.dapr.client.domain.SubscribeConfigurationResponse; import io.dapr.client.domain.SubscriptionMetadata; +import io.dapr.client.domain.ToolMessage; import io.dapr.client.domain.TransactionalStateOperation; import io.dapr.client.domain.UnlockRequest; import io.dapr.client.domain.UnlockResponseStatus; @@ -1572,6 +1575,7 @@ public class DaprClientImpl extends AbstractDaprClient { /** * {@inheritDoc} */ + @Deprecated(forRemoval = true) @Override public Mono converse(ConversationRequest conversationRequest) { @@ -1685,7 +1689,7 @@ public class DaprClientImpl extends AbstractDaprClient { } private DaprProtos.ConversationRequestAlpha2 buildConversationRequestProto(ConversationRequestAlpha2 request, - DaprProtos.ConversationRequestAlpha2.Builder builder) { + DaprProtos.ConversationRequestAlpha2.Builder builder) { if (request.getTools() != null) { buildConversationTools(request.getTools(), builder); } @@ -1694,13 +1698,21 @@ public class DaprClientImpl extends AbstractDaprClient { builder.putAllMetadata(request.getMetadata()); } + if (request.getParameters() != null) { Map parameters = request.getParameters() .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, - e -> Any.pack((Message) e.getValue()) - )); + e -> { + try { + return Any.newBuilder().setValue(ByteString.copyFrom(objectSerializer.serialize(e.getValue()))) + .build(); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }) + ); builder.putAllParameters(parameters); } @@ -1727,7 +1739,7 @@ public class DaprClientImpl extends AbstractDaprClient { private void buildConversationTools(List tools, DaprProtos.ConversationRequestAlpha2.Builder builder) { for (ConversationTools tool : tools) { - ConversationFunction function = tool.getFunction(); + ConversationToolsFunction function = tool.getFunction(); DaprProtos.ConversationToolsFunction.Builder protoFunction = DaprProtos.ConversationToolsFunction.newBuilder() .setName(function.getName()); @@ -1741,7 +1753,14 @@ public class DaprClientImpl extends AbstractDaprClient { .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, - e -> Any.pack((Message) e.getValue()) + e -> { + try { + return Any.newBuilder().setValue(ByteString.copyFrom(objectSerializer.serialize(e.getValue()))) + .build(); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } )); protoFunction.putAllParameters(functionParams); @@ -1766,8 +1785,8 @@ public class DaprClientImpl extends AbstractDaprClient { if (message.getContent() != null) { toolMessage.addAllContent(getConversationMessageContent(message)); } - if (message.getToolId() != null) { - toolMessage.setToolId(message.getToolId()); + if (((ToolMessage)message).getToolId() != null) { + toolMessage.setToolId(((ToolMessage)message).getToolId()); } messageBuilder.setOfTool(toolMessage); break; @@ -1785,14 +1804,15 @@ public class DaprClientImpl extends AbstractDaprClient { case ASSISTANT: DaprProtos.ConversationMessageOfAssistant.Builder assistantMessage = DaprProtos.ConversationMessageOfAssistant.newBuilder(); + if (message.getName() != null) { assistantMessage.setName(message.getName()); } if (message.getContent() != null) { assistantMessage.addAllContent(getConversationMessageContent(message)); } - if (message.getToolCalls() != null) { - assistantMessage.addAllToolCalls(getConversationToolCalls(message)); + if (((AssistantMessage)message).getToolCalls() != null) { + assistantMessage.addAllToolCalls(getConversationToolCalls((AssistantMessage)message)); } messageBuilder.setOfAssistant(assistantMessage); break; @@ -1851,15 +1871,18 @@ public class DaprClientImpl extends AbstractDaprClient { List toolCalls = new ArrayList<>(); for (DaprProtos.ConversationToolCalls protoToolCall : protoChoice.getMessage().getToolCallsList()) { - ConversationToolCallsFunction function = null; + ConversationToolCallsOfFunction function = null; if (protoToolCall.hasFunction()) { - function = new ConversationToolCallsFunction( + function = new ConversationToolCallsOfFunction( protoToolCall.getFunction().getName(), protoToolCall.getFunction().getArguments() ); } - - toolCalls.add(new ConversationToolCalls(protoToolCall.getId(), function)); + + ConversationToolCalls conversationToolCalls = new ConversationToolCalls(function); + conversationToolCalls.setId(protoToolCall.getId()); + + toolCalls.add(conversationToolCalls); } return new ConversationResultMessage( @@ -1882,16 +1905,19 @@ public class DaprClientImpl extends AbstractDaprClient { } private List getConversationToolCalls( - ConversationMessage conversationMessage) { + AssistantMessage assistantMessage) { List conversationToolCalls = new ArrayList<>(); - for (ConversationToolCalls conversationToolCall: conversationMessage.getToolCalls()) { - conversationToolCalls.add(DaprProtos.ConversationToolCalls.newBuilder() - .setId(conversationToolCall.getId()) + for (ConversationToolCalls conversationToolCall: assistantMessage.getToolCalls()) { + DaprProtos.ConversationToolCalls.Builder toolCallsBuilder = DaprProtos.ConversationToolCalls.newBuilder() .setFunction(DaprProtos.ConversationToolCallsOfFunction.newBuilder() - .setName(conversationToolCall.getFunction().getName()) - .setArguments(conversationToolCall.getFunction().getArguments()) - .build()) - .build()); + .setName(conversationToolCall.getFunction().getName()) + .setArguments(conversationToolCall.getFunction().getArguments()) + .build()); + if (conversationToolCall.getId() != null) { + toolCallsBuilder.setId(conversationToolCall.getId()); + } + + conversationToolCalls.add(toolCallsBuilder.build()); } return conversationToolCalls; diff --git a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java index 3d43b9e6d..92c6a61c3 100644 --- a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java +++ b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java @@ -315,6 +315,7 @@ public interface DaprPreviewClient extends AutoCloseable { * @param conversationRequest request to be passed to the LLM. * @return {@link ConversationResponse}. */ + @Deprecated public Mono converse(ConversationRequest conversationRequest); /* diff --git a/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java b/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java new file mode 100644 index 000000000..5e4f72f74 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/AssistantMessage.java @@ -0,0 +1,65 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Assistant message containing responses from the AI model. + * Can include regular content and/or tool calls that the model wants to make. + */ +public class AssistantMessage implements ConversationMessage { + + private String name; + private final List content; + private final List toolCalls; + + /** + * Creates an assistant message with content and optional tool calls. + * @param content the content of the assistant message. + * @param toolCalls the tool calls requested by the assistant. + */ + public AssistantMessage(List content, List toolCalls) { + this.content = content != null ? List.copyOf(content) : null; + this.toolCalls = toolCalls != null ? List.copyOf(toolCalls) : null; + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.ASSISTANT; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the assistant participant. + * + * @param name the name to set + */ + public void setName(String name) { + this.name = name; + } + + @Override + public List getContent() { + return content; + } + + public List getToolCalls() { + return toolCalls; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java b/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java index 039525c7a..c26c0d041 100644 --- a/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationMessage.java @@ -16,118 +16,29 @@ package io.dapr.client.domain; import java.util.List; /** - * Represents a conversation message with role-specific content. + * Interface representing a conversation message with role-specific content. * Supports different message types: system, user, assistant, developer, and tool. */ -public class ConversationMessage { - - /** - * Enum representing the different roles a message can have. - */ - public enum Role { - SYSTEM, - USER, - ASSISTANT, - DEVELOPER, - TOOL - } - - private final Role role; - private final String name; - private final List content; - private final List toolCalls; - private final String toolId; - - /** - * Constructor for creating a message with basic content. - * - * @param role the role of the message sender - * @param content the content of the message - */ - public ConversationMessage(Role role, List content) { - this(role, null, content, null, null); - } - - /** - * Constructor for creating a message with name and content. - * - * @param role the role of the message sender - * @param name the name of the participant (optional) - * @param content the content of the message - */ - public ConversationMessage(Role role, String name, List content) { - this(role, name, content, null, null); - } - - /** - * Full constructor for creating a message with all properties. - * - * @param role the role of the message sender - * @param name the name of the participant (optional) - * @param content the content of the message - * @param toolCalls tool calls for assistant messages (optional) - * @param toolId tool ID for tool messages (optional) - */ - public ConversationMessage(Role role, String name, List content, - List toolCalls, String toolId) { - this.role = role; - this.name = name; - this.content = content != null ? List.copyOf(content) : null; - this.toolCalls = toolCalls != null ? List.copyOf(toolCalls) : null; - this.toolId = toolId; - } +public interface ConversationMessage { /** * Gets the role of the message sender. * * @return the message role */ - public Role getRole() { - return role; - } + ConversationMessageRole getRole(); /** * Gets the name of the participant in the message. * * @return the participant name, or null if not specified */ - public String getName() { - return name; - } + String getName(); /** * Gets the content of the message. * * @return the message content */ - public List getContent() { - return content; - } - - /** - * Gets the tool calls generated by the model (for assistant messages). - * - * @return the tool calls, or null if none - */ - public List getToolCalls() { - return toolCalls; - } - - /** - * Gets the tool ID (for tool messages). - * - * @return the tool ID, or null if not a tool message - */ - public String getToolId() { - return toolId; - } - - /** - * Checks if this message has tool calls. - * - * @return true if the message has tool calls, false otherwise - */ - public boolean hasToolCalls() { - return toolCalls != null && !toolCalls.isEmpty(); - } + List getContent(); } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationMessageRole.java b/sdk/src/main/java/io/dapr/client/domain/ConversationMessageRole.java new file mode 100644 index 000000000..0bfd1b076 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationMessageRole.java @@ -0,0 +1,44 @@ +/* + * Copyright 2022 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Enum representing the different roles a conversation message can have. + */ +public enum ConversationMessageRole { + /** + * System message that sets the behavior or context for the conversation. + */ + SYSTEM, + + /** + * User message containing input from the human user. + */ + USER, + + /** + * Assistant message containing responses from the AI model. + */ + ASSISTANT, + + /** + * Tool message containing results from function/tool calls. + */ + TOOL, + + /** + * Developer message for development and debugging purposes. + */ + DEVELOPER +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java index b1f9163ab..a849611a2 100644 --- a/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCalls.java @@ -18,27 +18,17 @@ package io.dapr.client.domain; */ public class ConversationToolCalls { - private final String id; - private final ConversationToolCallsFunction function; + private String id; + private final ConversationToolCallsOfFunction function; - /** - * Constructor. - * - * @param id the unique identifier for the tool call (optional) - * @param function the function to call - */ - public ConversationToolCalls(String id, ConversationToolCallsFunction function) { - this.id = id; - this.function = function; - } /** * Constructor without ID. * * @param function the function to call */ - public ConversationToolCalls(ConversationToolCallsFunction function) { - this(null, function); + public ConversationToolCalls(ConversationToolCallsOfFunction function) { + this.function = function; } /** @@ -50,12 +40,22 @@ public class ConversationToolCalls { return id; } + /** + * Set with ID. + * + * @param id the unique identifier for the tool call + */ + public ConversationToolCalls setId(String id) { + this.id = id; + return this; + } + /** * Gets the function to call. * * @return the function details */ - public ConversationToolCallsFunction getFunction() { + public ConversationToolCallsOfFunction getFunction() { return function; } } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsFunction.java b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsOfFunction.java similarity index 92% rename from sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsFunction.java rename to sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsOfFunction.java index e2d33c944..0edacd7d9 100644 --- a/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsFunction.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationToolCallsOfFunction.java @@ -16,7 +16,7 @@ package io.dapr.client.domain; /** * Represents a function call within a tool call. */ -public class ConversationToolCallsFunction { +public class ConversationToolCallsOfFunction { private final String name; private final String arguments; @@ -27,7 +27,7 @@ public class ConversationToolCallsFunction { * @param name the name of the function to call * @param arguments the arguments to call the function with, as generated by the model in JSON format */ - public ConversationToolCallsFunction(String name, String arguments) { + public ConversationToolCallsOfFunction(String name, String arguments) { this.name = name; this.arguments = arguments; } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java b/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java index 8f7f0cc64..f36fa0545 100644 --- a/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationTools.java @@ -18,35 +18,23 @@ package io.dapr.client.domain; */ public class ConversationTools { - private final String type; - private final ConversationFunction function; + private final ConversationToolsFunction function; /** * Constructor. * - * @param type the type of tool (e.g., "function") * @param function the function definition */ - public ConversationTools(String type, ConversationFunction function) { - this.type = type; + public ConversationTools(ConversationToolsFunction function) { this.function = function; } - /** - * Gets the tool type. - * - * @return the tool type - */ - public String getType() { - return type; - } - /** * Gets the function definition. * * @return the function definition */ - public ConversationFunction getFunction() { + public ConversationToolsFunction getFunction() { return function; } } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationFunction.java b/sdk/src/main/java/io/dapr/client/domain/ConversationToolsFunction.java similarity index 81% rename from sdk/src/main/java/io/dapr/client/domain/ConversationFunction.java rename to sdk/src/main/java/io/dapr/client/domain/ConversationToolsFunction.java index fb741d7ff..11e3afdf3 100644 --- a/sdk/src/main/java/io/dapr/client/domain/ConversationFunction.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationToolsFunction.java @@ -18,22 +18,20 @@ import java.util.Map; /** * Represents a function definition for conversation tools. */ -public class ConversationFunction { +public class ConversationToolsFunction { + private String description; private final String name; - private final String description; private final Map parameters; /** * Constructor. * * @param name the function name - * @param description the function description * @param parameters the function parameters schema */ - public ConversationFunction(String name, String description, Map parameters) { + public ConversationToolsFunction(String name, Map parameters) { this.name = name; - this.description = description; this.parameters = parameters; } @@ -55,6 +53,17 @@ public class ConversationFunction { return description; } + /** + * Sets the function description. + * + * @param description the function description + * @return this instance for method chaining + */ + public ConversationToolsFunction setDescription(String description) { + this.description = description; + return this; + } + /** * Gets the function parameters schema. * diff --git a/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java b/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java new file mode 100644 index 000000000..009f3dafe --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/DeveloperMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Developer message for development and debugging purposes. + * Used for providing additional context or instructions during development. + */ +public class DeveloperMessage implements ConversationMessage { + + private String name; + private final List content; + + /** + * Creates a developer message with content. + * + * @param content the content of the developer message + */ + public DeveloperMessage(List content) { + this.content = content != null ? List.copyOf(content) : null; + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.DEVELOPER; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the developer participant. + * + * @param name the name to set + */ + public void setName(String name) { + this.name = name; + } + + @Override + public List getContent() { + return content; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java b/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java new file mode 100644 index 000000000..20651a519 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/SystemMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * System message that sets the behavior or context for the conversation. + * Used to provide instructions or context to the AI model. + */ +public class SystemMessage implements ConversationMessage { + + private String name; + private final List content; + + /** + * Creates a system message with content. + * + * @param content the content of the system message + */ + public SystemMessage(List content) { + this.content = content != null ? List.copyOf(content) : null; + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.SYSTEM; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the system participant. + * + * @param name the name to set + */ + public void setName(String name) { + this.name = name; + } + + @Override + public List getContent() { + return content; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java b/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java new file mode 100644 index 000000000..8d764f5fb --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ToolMessage.java @@ -0,0 +1,73 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Tool message containing results from function/tool calls. + * Used to provide the response from a tool execution back to the AI model. + */ +public class ToolMessage implements ConversationMessage { + + private String toolId; + private String name; + private final List content; + + /** + * Creates a tool message with content. + * + * @param content the content containing the tool execution result + */ + public ToolMessage(List content) { + this.content = content != null ? List.copyOf(content) : null; + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.TOOL; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the tool identifier. + * + * @param toolId the tool identifier to set + */ + public void setToolId(String toolId) { + this.toolId = toolId; + } + + /** + * Sets the name of the tool participant. + * + * @param name the name to set + */ + public void setName(String name) { + this.name = name; + } + + @Override + public List getContent() { + return content; + } + + public String getToolId() { + return toolId; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/UserMessage.java b/sdk/src/main/java/io/dapr/client/domain/UserMessage.java new file mode 100644 index 000000000..0ae225c2e --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/UserMessage.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * User message containing input from the human user. + * Represents questions, requests, or other input from the end user. + */ +public class UserMessage implements ConversationMessage { + + private String name; + private final List content; + + /** + * Creates a user message with content. + * + * @param content the content of the user message + */ + public UserMessage(List content) { + this.content = content != null ? List.copyOf(content) : null; + } + + @Override + public ConversationMessageRole getRole() { + return ConversationMessageRole.USER; + } + + @Override + public String getName() { + return name; + } + + /** + * Sets the name of the user participant. + * + * @param name the name to set + */ + public void setName(String name) { + this.name = name; + } + + @Override + public List getContent() { + return content; + } +} diff --git a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java index 5cc49edfd..95bf588bd 100644 --- a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java +++ b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java @@ -18,11 +18,24 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.Any; import com.google.protobuf.ByteString; +import io.dapr.client.domain.AssistantMessage; import io.dapr.client.domain.BulkPublishEntry; import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; import io.dapr.client.domain.CloudEvent; +import io.dapr.client.domain.ConversationToolCallsOfFunction; +import io.dapr.client.domain.ConversationToolsFunction; +import io.dapr.client.domain.ConversationInputAlpha2; +import io.dapr.client.domain.ConversationMessage; +import io.dapr.client.domain.ConversationMessageContent; +import io.dapr.client.domain.ConversationRequestAlpha2; +import io.dapr.client.domain.ConversationResponseAlpha2; +import io.dapr.client.domain.ConversationResultAlpha2; +import io.dapr.client.domain.ConversationResultChoices; +import io.dapr.client.domain.ConversationToolCalls; +import io.dapr.client.domain.ConversationTools; import io.dapr.client.domain.DeleteJobRequest; +import io.dapr.client.domain.DeveloperMessage; import io.dapr.client.domain.GetJobRequest; import io.dapr.client.domain.GetJobResponse; import io.dapr.client.domain.JobSchedule; @@ -33,7 +46,10 @@ import io.dapr.client.domain.QueryStateItem; import io.dapr.client.domain.QueryStateRequest; import io.dapr.client.domain.QueryStateResponse; import io.dapr.client.domain.ScheduleJobRequest; +import io.dapr.client.domain.SystemMessage; +import io.dapr.client.domain.ToolMessage; import io.dapr.client.domain.UnlockResponseStatus; +import io.dapr.client.domain.UserMessage; import io.dapr.client.domain.query.Query; import io.dapr.serializer.DaprObjectSerializer; import io.dapr.serializer.DefaultObjectSerializer; @@ -984,6 +1000,437 @@ public class DaprPreviewClientGrpcTest { assertEquals("Name in the request cannot be null or empty", exception.getMessage()); } + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsNull() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2(null, null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsEmpty() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("", null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ShouldThrowIllegalArgumentExceptionWhenNameIsWhitespace() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2(" ", null); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> + previewClient.converseAlpha2(request).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseAlpha2ExceptionThrownTest() { + doAnswer((Answer) invocation -> { + throw newStatusRuntimeException("INVALID_ARGUMENT", "bad argument"); + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + + assertThrowsDaprException( + StatusRuntimeException.class, + "INVALID_ARGUMENT", + "INVALID_ARGUMENT: bad argument", + () -> previewClient.converseAlpha2(request).block()); + } + + @Test + public void converseAlpha2CallbackExceptionThrownTest() { + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onError(newStatusRuntimeException("INVALID_ARGUMENT", "bad argument")); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + Mono result = previewClient.converseAlpha2(request); + + assertThrowsDaprException( + ExecutionException.class, + "INVALID_ARGUMENT", + "INVALID_ARGUMENT: bad argument", + () -> result.block()); + } + + @Test + public void converseAlpha2MinimalRequestTest() { + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .setContextId("test-context") + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Hello! How can I help you today?") + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals("test-context", response.getContextId()); + assertEquals(1, response.getOutputs().size()); + + ConversationResultAlpha2 result = response.getOutputs().get(0); + assertEquals(1, result.getChoices().size()); + + ConversationResultChoices choice = result.getChoices().get(0); + assertEquals("stop", choice.getFinishReason()); + assertEquals(0, choice.getIndex()); + assertEquals("Hello! How can I help you today?", choice.getMessage().getContent()); + } + + @Test + public void converseAlpha2ComplexRequestTest() { + // Create messages + List messages = new ArrayList<>(); + UserMessage userMessage = new UserMessage(List.of(new ConversationMessageContent("Hello, how are you?"))); + userMessage.setName("John"); + messages.add(userMessage); + + // Create input + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + input.setScrubPii(true); + + // Create tools + Map functionParams = new HashMap<>(); + functionParams.put("location", "Required location parameter"); + List tools = new ArrayList<>(); + ConversationToolsFunction function = new ConversationToolsFunction("get_weather", functionParams); + function.setDescription("Get current weather"); + + ConversationTools tool = new ConversationTools(function); + tools.add(tool); + + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + + Map parameters = new HashMap<>(); + parameters.put("max_tokens", "1000"); + + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + request.setContextId("test-context"); + request.setTemperature(0.7); + request.setScrubPii(true); + request.setTools(tools); + request.setToolChoice("auto"); + request.setMetadata(metadata); + request.setParameters(parameters); + + // Mock response with tool calls + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .setContextId("test-context") + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("tool_calls") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("I'll help you get the weather information.") + .addToolCalls(DaprProtos.ConversationToolCalls.newBuilder() + .setId("call_123") + .setFunction(DaprProtos.ConversationToolCallsOfFunction.newBuilder() + .setName("get_weather") + .setArguments("{\"location\": \"New York\"}") + .build()) + .build()) + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals("test-context", response.getContextId()); + + ConversationResultChoices choice = response.getOutputs().get(0).getChoices().get(0); + assertEquals("tool_calls", choice.getFinishReason()); + assertEquals("I'll help you get the weather information.", choice.getMessage().getContent()); + assertEquals(1, choice.getMessage().getToolCalls().size()); + + ConversationToolCalls toolCall = choice.getMessage().getToolCalls().get(0); + assertEquals("call_123", toolCall.getId()); + assertEquals("get_weather", toolCall.getFunction().getName()); + assertEquals("{\"location\": \"New York\"}", toolCall.getFunction().getArguments()); + + // Verify the request was built correctly + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequestAlpha2.class); + verify(daprStub).converseAlpha2(captor.capture(), any()); + + DaprProtos.ConversationRequestAlpha2 capturedRequest = captor.getValue(); + assertEquals("openai", capturedRequest.getName()); + assertEquals("test-context", capturedRequest.getContextId()); + assertEquals(0.7, capturedRequest.getTemperature(), 0.001); + assertTrue(capturedRequest.getScrubPii()); + assertEquals("auto", capturedRequest.getToolChoice()); + assertEquals("value1", capturedRequest.getMetadataMap().get("key1")); + assertEquals(1, capturedRequest.getToolsCount()); + assertEquals("get_weather", capturedRequest.getTools(0).getFunction().getName()); + } + + @Test + public void converseAlpha2AllMessageTypesTest() { + List messages = new ArrayList<>(); + + // System message + SystemMessage systemMsg = new SystemMessage(List.of(new ConversationMessageContent("You are a helpful assistant."))); + systemMsg.setName("system"); + messages.add(systemMsg); + + // User message + UserMessage userMsg = new UserMessage(List.of(new ConversationMessageContent("Hello!"))); + userMsg.setName("user"); + messages.add(userMsg); + + // Assistant message + AssistantMessage assistantMsg = new AssistantMessage(List.of(new ConversationMessageContent("Hi there!")), + List.of(new ConversationToolCalls(new ConversationToolCallsOfFunction("abc", "parameters")))); + assistantMsg.setName("assistant"); + messages.add(assistantMsg); + + // Tool message + ToolMessage toolMsg = new ToolMessage(List.of(new ConversationMessageContent("Weather data: 72F"))); + toolMsg.setName("tool"); + messages.add(toolMsg); + + // Developer message + DeveloperMessage devMsg = new DeveloperMessage(List.of(new ConversationMessageContent("Debug info"))); + devMsg.setName("developer"); + messages.add(devMsg); + + ConversationInputAlpha2 input = new ConversationInputAlpha2(messages); + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", List.of(input)); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Processed all message types") + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals("Processed all message types", response.getOutputs().get(0).getChoices().get(0).getMessage().getContent()); + + // Verify all message types were processed + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequestAlpha2.class); + verify(daprStub).converseAlpha2(captor.capture(), any()); + + DaprProtos.ConversationRequestAlpha2 capturedRequest = captor.getValue(); + assertEquals(1, capturedRequest.getInputsCount()); + assertEquals(5, capturedRequest.getInputs(0).getMessagesCount()); + + // Verify each message type was converted correctly + List capturedMessages = capturedRequest.getInputs(0).getMessagesList(); + assertTrue(capturedMessages.get(0).hasOfSystem()); + assertTrue(capturedMessages.get(1).hasOfUser()); + assertTrue(capturedMessages.get(2).hasOfAssistant()); + assertTrue(capturedMessages.get(3).hasOfTool()); + assertTrue(capturedMessages.get(4).hasOfDeveloper()); + } + + @Test + public void converseAlpha2EmptyInputsTest() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", new ArrayList<>()); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequestAlpha2.class); + verify(daprStub).converseAlpha2(captor.capture(), any()); + + DaprProtos.ConversationRequestAlpha2 capturedRequest = captor.getValue(); + assertEquals(0, capturedRequest.getInputsCount()); + } + + @Test + public void converseAlpha2ResponseWithoutMessageTest() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + // No message set + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + ConversationResultChoices choice = response.getOutputs().get(0).getChoices().get(0); + assertEquals("stop", choice.getFinishReason()); + assertEquals(0, choice.getIndex()); + assertNull(choice.getMessage()); + } + + @Test + public void converseAlpha2MultipleResultsTest() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("First choice") + .build()) + .build()) + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("stop") + .setIndex(1) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Second choice") + .build()) + .build()) + .build()) + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("length") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Third result") + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + assertEquals(2, response.getOutputs().size()); + + // First result with 2 choices + ConversationResultAlpha2 firstResult = response.getOutputs().get(0); + assertEquals(2, firstResult.getChoices().size()); + assertEquals("First choice", firstResult.getChoices().get(0).getMessage().getContent()); + assertEquals("Second choice", firstResult.getChoices().get(1).getMessage().getContent()); + + // Second result with 1 choice + ConversationResultAlpha2 secondResult = response.getOutputs().get(1); + assertEquals(1, secondResult.getChoices().size()); + assertEquals("Third result", secondResult.getChoices().get(0).getMessage().getContent()); + } + + @Test + public void converseAlpha2ToolCallWithoutFunctionTest() { + ConversationRequestAlpha2 request = new ConversationRequestAlpha2("openai", null); + + DaprProtos.ConversationResponseAlpha2 grpcResponse = DaprProtos.ConversationResponseAlpha2.newBuilder() + .addOutputs(DaprProtos.ConversationResultAlpha2.newBuilder() + .addChoices(DaprProtos.ConversationResultChoices.newBuilder() + .setFinishReason("tool_calls") + .setIndex(0) + .setMessage(DaprProtos.ConversationResultMessage.newBuilder() + .setContent("Test content") + .addToolCalls(DaprProtos.ConversationToolCalls.newBuilder() + .setId("call_123") + // No function set + .build()) + .build()) + .build()) + .build()) + .build(); + + doAnswer((Answer) invocation -> { + StreamObserver observer = + (StreamObserver) invocation.getArguments()[1]; + observer.onNext(grpcResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha2(any(DaprProtos.ConversationRequestAlpha2.class), any()); + + ConversationResponseAlpha2 response = previewClient.converseAlpha2(request).block(); + + assertNotNull(response); + ConversationToolCalls toolCall = response.getOutputs().get(0).getChoices().get(0) + .getMessage().getToolCalls().get(0); + assertEquals("call_123", toolCall.getId()); + assertNull(toolCall.getFunction()); + } + private DaprProtos.QueryStateResponse buildQueryStateResponse(List> resp,String token) throws JsonProcessingException { List items = new ArrayList<>();