Add Saga pattern support in java-sdk (#956)

* first versoin of saga pattern support based on workflow

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* add unit test for SagaConfiguration to improve code coverage

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* save draft version before refactory to not hide saga.registerCompensatation

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* remove auto register compensation activity on callActivity()

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* rollback COVEREDRATIO to 80%

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* improve code implementation accordings to proposal

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* use ctx.allOf() to do compensation in parallel

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* add code to handle ContinueAsNewInterruption exception for saga compensation

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* add saga context for saga related method

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

* fix for checkstyle

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>

---------

Signed-off-by: Sky Ao <aoxiaojian@gmail.com>
Co-authored-by: Artur Souza <artursouza.ms@outlook.com>
Co-authored-by: Mukundan Sundararajan <65565396+mukundansundar@users.noreply.github.com>
This commit is contained in:
Sky Ao 2024-01-14 04:15:47 +08:00 committed by GitHub
parent fdb4200a2c
commit 4a7b9856ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1634 additions and 17 deletions

View File

@ -18,6 +18,9 @@ import com.microsoft.durabletask.Task;
import com.microsoft.durabletask.TaskCanceledException;
import com.microsoft.durabletask.TaskOptions;
import com.microsoft.durabletask.TaskOrchestrationContext;
import io.dapr.workflows.saga.DaprSagaContextImpl;
import io.dapr.workflows.saga.Saga;
import io.dapr.workflows.saga.SagaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.helpers.NOPLogger;
@ -32,6 +35,7 @@ import java.util.UUID;
public class DaprWorkflowContextImpl implements WorkflowContext {
private final TaskOrchestrationContext innerContext;
private final Logger logger;
private final Saga saga;
/**
* Constructor for DaprWorkflowContextImpl.
@ -51,6 +55,23 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
* @throws IllegalArgumentException if context or logger is null
*/
public DaprWorkflowContextImpl(TaskOrchestrationContext context, Logger logger) throws IllegalArgumentException {
this(context, logger, null);
}
public DaprWorkflowContextImpl(TaskOrchestrationContext context, Saga saga) throws IllegalArgumentException {
this(context, LoggerFactory.getLogger(WorkflowContext.class), saga);
}
/**
* Constructor for DaprWorkflowContextImpl.
*
* @param context TaskOrchestrationContext
* @param logger Logger
* @param saga saga object, if null, saga is disabled
* @throws IllegalArgumentException if context or logger is null
*/
public DaprWorkflowContextImpl(TaskOrchestrationContext context, Logger logger, Saga saga)
throws IllegalArgumentException {
if (context == null) {
throw new IllegalArgumentException("Context cannot be null");
}
@ -60,6 +81,7 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
this.innerContext = context;
this.logger = logger;
this.saga = saga;
}
/**
@ -110,15 +132,20 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
}
/**
* Waits for an event to be raised named {@code name} and returns a {@link Task} that completes when the event is
* Waits for an event to be raised named {@code name} and returns a {@link Task}
* that completes when the event is
* received or is canceled when {@code timeout} expires.
*
* <p>See {@link #waitForExternalEvent(String, Duration, Class)} for a full description.
* <p>See {@link #waitForExternalEvent(String, Duration, Class)} for a full
* description.
*
* @param name the case-insensitive name of the event to wait for
* @param timeout the amount of time to wait before canceling the returned {@code Task}
* @return a new {@link Task} that completes when the external event is received or when {@code timeout} expires
* @throws TaskCanceledException if the specified {@code timeout} value expires before the event is received
* @param timeout the amount of time to wait before canceling the returned
* {@code Task}
* @return a new {@link Task} that completes when the external event is received
* or when {@code timeout} expires
* @throws TaskCanceledException if the specified {@code timeout} value expires
* before the event is received
*/
@Override
public <V> Task<Void> waitForExternalEvent(String name, Duration timeout) throws TaskCanceledException {
@ -126,10 +153,12 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
}
/**
* Waits for an event to be raised named {@code name} and returns a {@link Task} that completes when the event is
* Waits for an event to be raised named {@code name} and returns a {@link Task}
* that completes when the event is
* received.
*
* <p>See {@link #waitForExternalEvent(String, Duration, Class)} for a full description.
* <p>See {@link #waitForExternalEvent(String, Duration, Class)} for a full
* description.
*
* @param name the case-insensitive name of the event to wait for
* @return a new {@link Task} that completes when the external event is received
@ -172,7 +201,6 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
return this.innerContext.createTimer(duration);
}
/**
* {@inheritDoc}
*/
@ -185,7 +213,7 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
*/
@Override
public <V> Task<V> callSubWorkflow(String name, @Nullable Object input, @Nullable String instanceID,
@Nullable TaskOptions options, Class<V> returnType) {
@Nullable TaskOptions options, Class<V> returnType) {
return this.innerContext.callSubOrchestrator(name, input, instanceID, options, returnType);
}
@ -213,4 +241,13 @@ public class DaprWorkflowContextImpl implements WorkflowContext {
public UUID newUuid() {
return this.innerContext.newUUID();
}
@Override
public SagaContext getSagaContext() {
if (this.saga == null) {
throw new UnsupportedOperationException("Saga is not enabled");
}
return new DaprSagaContextImpl(this.saga, this);
}
}

View File

@ -13,11 +13,16 @@ limitations under the License.
package io.dapr.workflows;
import com.microsoft.durabletask.interruption.ContinueAsNewInterruption;
import com.microsoft.durabletask.interruption.OrchestratorBlockedException;
import io.dapr.workflows.saga.SagaCompensationException;
import io.dapr.workflows.saga.SagaOption;
/**
* Common interface for workflow implementations.
*/
public abstract class Workflow {
public Workflow(){
public Workflow() {
}
/**
@ -30,10 +35,50 @@ public abstract class Workflow {
/**
* Executes the workflow logic.
*
* @param ctx provides access to methods for scheduling durable tasks and getting information about the current
* @param ctx provides access to methods for scheduling durable tasks and
* getting information about the current
* workflow instance.
*/
public void run(WorkflowContext ctx) {
this.create().run(ctx);
WorkflowStub stub = this.create();
if (!this.isSagaEnabled()) {
// saga disabled
stub.run(ctx);
} else {
// saga enabled
try {
stub.run(ctx);
} catch (OrchestratorBlockedException | ContinueAsNewInterruption e) {
throw e;
} catch (SagaCompensationException e) {
// Saga compensation is triggered gracefully but failed in exception
// don't need to trigger compensation again
throw e;
} catch (Exception e) {
try {
ctx.getSagaContext().compensate();
} catch (Exception se) {
se.addSuppressed(e);
throw se;
}
throw e;
}
}
}
public boolean isSagaEnabled() {
return this.getSagaOption() != null;
}
/**
* get saga configuration.
*
* @return saga configuration
*/
public SagaOption getSagaOption() {
// by default, saga is disabled
return null;
}
}

View File

@ -18,6 +18,7 @@ import com.microsoft.durabletask.Task;
import com.microsoft.durabletask.TaskCanceledException;
import com.microsoft.durabletask.TaskFailedException;
import com.microsoft.durabletask.TaskOptions;
import io.dapr.workflows.saga.SagaContext;
import org.slf4j.Logger;
import javax.annotation.Nullable;
@ -530,4 +531,12 @@ public interface WorkflowContext {
default UUID newUuid() {
throw new RuntimeException("No implementation found.");
}
/**
* get saga context.
*
* @return saga context
* @throws UnsupportedOperationException if saga is not enabled.
*/
SagaContext getSagaContext();
}

View File

@ -13,8 +13,6 @@ limitations under the License.
package io.dapr.workflows;
import io.dapr.workflows.WorkflowContext;
@FunctionalInterface
public interface WorkflowStub {
void run(WorkflowContext ctx);

View File

@ -17,6 +17,7 @@ import com.microsoft.durabletask.TaskOrchestration;
import com.microsoft.durabletask.TaskOrchestrationFactory;
import io.dapr.workflows.DaprWorkflowContextImpl;
import io.dapr.workflows.Workflow;
import io.dapr.workflows.saga.Saga;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
@ -55,7 +56,13 @@ class OrchestratorWrapper<T extends Workflow> implements TaskOrchestrationFactor
String.format("Unable to instantiate instance of workflow class '%s'", this.name), e
);
}
workflow.run(new DaprWorkflowContextImpl(ctx));
if (workflow.getSagaOption() != null) {
Saga saga = new Saga(workflow.getSagaOption());
workflow.run(new DaprWorkflowContextImpl(ctx, saga));
} else {
workflow.run(new DaprWorkflowContextImpl(ctx));
}
};
}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
import com.microsoft.durabletask.TaskOptions;
/**
* Information for a compensation activity.
*/
class CompensatationInformation {
private final String compensatationActivityClassName;
private final Object compensatationActivityInput;
private final TaskOptions taskOptions;
/**
* Constructor for a compensation information.
*
* @param compensatationActivityClassName Class name of the activity to do
* compensatation.
* @param compensatationActivityInput Input of the activity to do
* compensatation.
* @param taskOptions task options to set retry strategy
*/
public CompensatationInformation(String compensatationActivityClassName,
Object compensatationActivityInput, TaskOptions taskOptions) {
this.compensatationActivityClassName = compensatationActivityClassName;
this.compensatationActivityInput = compensatationActivityInput;
this.taskOptions = taskOptions;
}
/**
* Gets the class name of the activity.
*
* @return the class name of the activity.
*/
public String getCompensatationActivityClassName() {
return compensatationActivityClassName;
}
/**
* Gets the input of the activity.
*
* @return the input of the activity.
*/
public Object getCompensatationActivityInput() {
return compensatationActivityInput;
}
/**
* get task options.
*
* @return task options, null if not set
*/
public TaskOptions getTaskOptions() {
return taskOptions;
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
import io.dapr.workflows.WorkflowContext;
/**
* Dapr Saga Context implementation.
*/
public class DaprSagaContextImpl implements SagaContext {
private final Saga saga;
private final WorkflowContext workflowContext;
/**
* Constructor to build up instance.
*
* @param saga Saga instance.
* @param workflowContext Workflow context.
* @throws IllegalArgumentException if saga or workflowContext is null.
*/
public DaprSagaContextImpl(Saga saga, WorkflowContext workflowContext) {
if (saga == null) {
throw new IllegalArgumentException("Saga should not be null");
}
if (workflowContext == null) {
throw new IllegalArgumentException("workflowContext should not be null");
}
this.saga = saga;
this.workflowContext = workflowContext;
}
@Override
public void registerCompensation(String activityClassName, Object activityInput) {
this.saga.registerCompensation(activityClassName, activityInput);
}
@Override
public void compensate() {
this.saga.compensate(workflowContext);
}
}

View File

@ -0,0 +1,130 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
import com.microsoft.durabletask.Task;
import com.microsoft.durabletask.TaskOptions;
import com.microsoft.durabletask.interruption.ContinueAsNewInterruption;
import com.microsoft.durabletask.interruption.OrchestratorBlockedException;
import io.dapr.workflows.WorkflowContext;
import java.util.ArrayList;
import java.util.List;
public final class Saga {
private final SagaOption option;
private final List<CompensatationInformation> compensationActivities = new ArrayList<>();
/**
* Build up a Saga with its options.
*
* @param option Saga option.
*/
public Saga(SagaOption option) {
if (option == null) {
throw new IllegalArgumentException("option is required and should not be null.");
}
this.option = option;
}
/**
* Register a compensation activity.
*
* @param activityClassName name of the activity class
* @param activityInput input of the activity to be compensated
*/
public void registerCompensation(String activityClassName, Object activityInput) {
this.registerCompensation(activityClassName, activityInput, null);
}
/**
* Register a compensation activity.
*
* @param activityClassName name of the activity class
* @param activityInput input of the activity to be compensated
* @param taskOptions task options to set retry strategy
*/
public void registerCompensation(String activityClassName, Object activityInput, TaskOptions taskOptions) {
if (activityClassName == null || activityClassName.isEmpty()) {
throw new IllegalArgumentException("activityClassName is required and should not be null or empty.");
}
this.compensationActivities.add(new CompensatationInformation(activityClassName, activityInput, taskOptions));
}
/**
* Compensate all registered activities.
*
* @param ctx Workflow context.
*/
public void compensate(WorkflowContext ctx) {
// Check if parallel compensation is enabled
// Specical case: when parallel compensation is enabled and there is only one
// compensation, we still
// compensate sequentially.
if (option.isParallelCompensation() && compensationActivities.size() > 1) {
compensateInParallel(ctx);
} else {
compensateSequentially(ctx);
}
}
private void compensateInParallel(WorkflowContext ctx) {
List<Task<Void>> tasks = new ArrayList<>(compensationActivities.size());
for (CompensatationInformation compensationActivity : compensationActivities) {
Task<Void> task = executeCompensateActivity(ctx, compensationActivity);
tasks.add(task);
}
try {
ctx.allOf(tasks).await();
} catch (Exception e) {
throw new SagaCompensationException("Failed to compensate in parallel.", e);
}
}
private void compensateSequentially(WorkflowContext ctx) {
SagaCompensationException sagaException = null;
for (int i = compensationActivities.size() - 1; i >= 0; i--) {
String activityClassName = compensationActivities.get(i).getCompensatationActivityClassName();
try {
executeCompensateActivity(ctx, compensationActivities.get(i)).await();
} catch (OrchestratorBlockedException | ContinueAsNewInterruption e) {
throw e;
} catch (Exception e) {
if (sagaException == null) {
sagaException = new SagaCompensationException(
"Exception in saga compensatation: activity=" + activityClassName, e);
;
} else {
sagaException.addSuppressed(e);
}
if (!option.isContinueWithError()) {
throw sagaException;
}
}
}
if (sagaException != null) {
throw sagaException;
}
}
private Task<Void> executeCompensateActivity(WorkflowContext ctx, CompensatationInformation info)
throws SagaCompensationException {
String activityClassName = info.getCompensatationActivityClassName();
return ctx.callActivity(activityClassName, info.getCompensatationActivityInput(),
info.getTaskOptions());
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
/**
* saga compensation exception.
*/
public class SagaCompensationException extends RuntimeException {
/**
* build up a SagaCompensationException.
* @param message exception message
* @param cause exception cause
*/
public SagaCompensationException(String message, Exception cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,34 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
/**
* Saga context.
*/
public interface SagaContext {
/**
* Register a compensation activity.
*
* @param activityClassName name of the activity class
* @param activityInput input of the activity to be compensated
*/
void registerCompensation(String activityClassName, Object activityInput);
/**
* Compensate all registered activities.
*
*/
void compensate();
}

View File

@ -0,0 +1,102 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
/**
* Saga option.
*/
public final class SagaOption {
private final boolean parallelCompensation;
private final int maxParallelThread;
private final boolean continueWithError;
private SagaOption(boolean parallelCompensation, int maxParallelThread, boolean continueWithError) {
this.parallelCompensation = parallelCompensation;
this.maxParallelThread = maxParallelThread;
this.continueWithError = continueWithError;
}
public boolean isParallelCompensation() {
return parallelCompensation;
}
public boolean isContinueWithError() {
return continueWithError;
}
public int getMaxParallelThread() {
return maxParallelThread;
}
public static Builder newBuilder() {
return new Builder();
}
public static final class Builder {
// by default compensation is sequential
private boolean parallelCompensation = false;
// by default max parallel thread is 16, it's enough for most cases
private int maxParallelThread = 16;
// by default set continueWithError to be true
// So if a compensation fails, we should continue with the next compensations
private boolean continueWithError = true;
/**
* Set parallel compensation.
* @param parallelCompensation parallel compensation or not
* @return this builder itself
*/
public Builder setParallelCompensation(boolean parallelCompensation) {
this.parallelCompensation = parallelCompensation;
return this;
}
/**
* set max parallel thread.
*
* <p>Only valid when parallelCompensation is true.
* @param maxParallelThread max parallel thread
* @return this builder itself
*/
public Builder setMaxParallelThread(int maxParallelThread) {
if (maxParallelThread <= 2) {
throw new IllegalArgumentException("maxParallelThread should be greater than 1.");
}
this.maxParallelThread = maxParallelThread;
return this;
}
/**
* Set continue with error.
*
* <p>Only valid when parallelCompensation is false.
* @param continueWithError continue with error or not
* @return this builder itself
*/
public Builder setContinueWithError(boolean continueWithError) {
this.continueWithError = continueWithError;
return this;
}
/**
* Build Saga optiion.
* @return Saga optiion
*/
public SagaOption build() {
return new SagaOption(this.parallelCompensation, this.maxParallelThread, this.continueWithError);
}
}
}

View File

@ -20,6 +20,9 @@ import com.microsoft.durabletask.TaskCanceledException;
import com.microsoft.durabletask.TaskOptions;
import com.microsoft.durabletask.TaskOrchestrationContext;
import io.dapr.workflows.saga.Saga;
import io.dapr.workflows.saga.SagaContext;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -31,6 +34,7 @@ import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
@ -130,6 +134,11 @@ public class DaprWorkflowContextImplTest {
public void continueAsNew(Object input, boolean preserveUnprocessedEvents) {
}
@Override
public SagaContext getSagaContext() {
return null;
}
};
}
@ -181,13 +190,13 @@ public class DaprWorkflowContextImplTest {
@Test
public void DaprWorkflowContextWithEmptyInnerContext() {
assertThrows(IllegalArgumentException.class, () -> {
context = new DaprWorkflowContextImpl(mockInnerContext, null);
context = new DaprWorkflowContextImpl(mockInnerContext, (Logger)null);
}); }
@Test
public void DaprWorkflowContextWithEmptyLogger() {
assertThrows(IllegalArgumentException.class, () -> {
context = new DaprWorkflowContextImpl(null, null);
context = new DaprWorkflowContextImpl(null, (Logger)null);
});
}
@ -309,4 +318,21 @@ public class DaprWorkflowContextImplTest {
String expectedMessage = "No implementation found.";
assertEquals(expectedMessage, runtimeException.getMessage());
}
@Test
public void getSagaContextTest_sagaEnabled() {
Saga saga = mock(Saga.class);
WorkflowContext context = new DaprWorkflowContextImpl(mockInnerContext, saga);
SagaContext sagaContext = context.getSagaContext();
assertNotNull("SagaContext should not be null", sagaContext);
}
@Test
public void getSagaContextTest_sagaDisabled() {
WorkflowContext context = new DaprWorkflowContextImpl(mockInnerContext);
assertThrows(UnsupportedOperationException.class, () -> {
context.getSagaContext();
});
}
}

View File

@ -0,0 +1,197 @@
package io.dapr.workflows;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import org.junit.Test;
import com.microsoft.durabletask.interruption.ContinueAsNewInterruption;
import com.microsoft.durabletask.interruption.OrchestratorBlockedException;
import io.dapr.workflows.saga.SagaCompensationException;
import io.dapr.workflows.saga.SagaContext;
import io.dapr.workflows.saga.SagaOption;
public class WorkflowTest {
@Test
public void testWorkflow_WithoutSaga() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithoutSaga(stub);
assertNull(workflow.getSagaOption());
assertFalse(workflow.isSagaEnabled());
WorkflowContext ctx = mock(WorkflowContext.class);
doNothing().when(stub).run(ctx);
workflow.run(ctx);
verify(stub, times(1)).run(eq(ctx));
}
@Test
public void testWorkflow_WithoutSaga_throwException() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithoutSaga(stub);
WorkflowContext ctx = mock(WorkflowContext.class);
Exception e = new RuntimeException();
doThrow(e).when(stub).run(ctx);
// should throw the exception, not catch
assertThrows(RuntimeException.class, () -> {
workflow.run(ctx);
});
verify(stub, times(1)).run(eq(ctx));
}
@Test
public void testWorkflow_WithSaga() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithSaga(stub);
assertNotNull(workflow.getSagaOption());
assertTrue(workflow.isSagaEnabled());
WorkflowContext ctx = mock(WorkflowContext.class);
doNothing().when(stub).run(ctx);
workflow.run(ctx);
verify(stub, times(1)).run(eq(ctx));
}
@Test
public void testWorkflow_WithSaga_shouldNotCatch_OrchestratorBlockedException() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithSaga(stub);
WorkflowContext ctx = mock(WorkflowContext.class);
Exception e = new OrchestratorBlockedException("test");
doThrow(e).when(stub).run(ctx);
// should not catch OrchestratorBlockedException
assertThrows(OrchestratorBlockedException.class, () -> {
workflow.run(ctx);
});
verify(stub, times(1)).run(eq(ctx));
}
@Test
public void testWorkflow_WithSaga_shouldNotCatch_ContinueAsNewInterruption() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithSaga(stub);
WorkflowContext ctx = mock(WorkflowContext.class);
Exception e = new ContinueAsNewInterruption("test");
doThrow(e).when(stub).run(ctx);
// should not catch ContinueAsNewInterruption
assertThrows(ContinueAsNewInterruption.class, () -> {
workflow.run(ctx);
});
verify(stub, times(1)).run(eq(ctx));
}
@Test
public void testWorkflow_WithSaga_shouldNotCatch_SagaCompensationException() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithSaga(stub);
WorkflowContext ctx = mock(WorkflowContext.class);
Exception e = new SagaCompensationException("test", null);
doThrow(e).when(stub).run(ctx);
// should not catch SagaCompensationException
assertThrows(SagaCompensationException.class, () -> {
workflow.run(ctx);
});
verify(stub, times(1)).run(eq(ctx));
}
@Test
public void testWorkflow_WithSaga_triggerCompensate() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithSaga(stub);
WorkflowContext ctx = mock(WorkflowContext.class);
Exception e = new RuntimeException("test", null);
doThrow(e).when(stub).run(ctx);
SagaContext sagaContext = mock(SagaContext.class);
doReturn(sagaContext).when(ctx).getSagaContext();
doNothing().when(sagaContext).compensate();
assertThrows(RuntimeException.class, () -> {
workflow.run(ctx);
});
verify(stub, times(1)).run(eq(ctx));
verify(sagaContext, times(1)).compensate();
}
@Test
public void testWorkflow_WithSaga_compensateFaile() {
WorkflowStub stub = mock(WorkflowStub.class);
Workflow workflow = new WorkflowWithSaga(stub);
WorkflowContext ctx = mock(WorkflowContext.class);
Exception e = new RuntimeException("workflow fail", null);
doThrow(e).when(stub).run(ctx);
SagaContext sagaContext = mock(SagaContext.class);
doReturn(sagaContext).when(ctx).getSagaContext();
Exception e2 = new RuntimeException("compensate fail", null);
doThrow(e2).when(sagaContext).compensate();
try {
workflow.run(ctx);
fail("sholdd throw exception");
} catch (Exception ex) {
assertEquals(e2.getMessage(), ex.getMessage());
assertEquals(1, ex.getSuppressed().length);
assertEquals(e.getMessage(), ex.getSuppressed()[0].getMessage());
}
verify(stub, times(1)).run(eq(ctx));
verify(sagaContext, times(1)).compensate();
}
public static class WorkflowWithoutSaga extends Workflow {
private final WorkflowStub stub;
public WorkflowWithoutSaga(WorkflowStub stub) {
this.stub = stub;
}
@Override
public WorkflowStub create() {
return stub;
}
}
public static class WorkflowWithSaga extends Workflow {
private final WorkflowStub stub;
public WorkflowWithSaga(WorkflowStub stub) {
this.stub = stub;
}
@Override
public WorkflowStub create() {
return stub;
}
@Override
public SagaOption getSagaOption() {
return SagaOption.newBuilder()
.setParallelCompensation(false)
.build();
}
}
}

View File

@ -0,0 +1,54 @@
package io.dapr.workflows.saga;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import org.junit.Test;
import io.dapr.workflows.WorkflowContext;
public class DaprSagaContextImplTest {
@Test
public void testDaprSagaContextImpl_IllegalArgumentException() {
Saga saga = mock(Saga.class);
WorkflowContext workflowContext = mock(WorkflowContext.class);
assertThrows(IllegalArgumentException.class, () -> {
new DaprSagaContextImpl(saga, null);
});
assertThrows(IllegalArgumentException.class, () -> {
new DaprSagaContextImpl(null, workflowContext);
});
}
@Test
public void test_registerCompensation() {
Saga saga = mock(Saga.class);
WorkflowContext workflowContext = mock(WorkflowContext.class);
DaprSagaContextImpl ctx = new DaprSagaContextImpl(saga, workflowContext);
String activityClassName = "name1";
Object activityInput = new Object();
doNothing().when(saga).registerCompensation(activityClassName, activityInput);
ctx.registerCompensation(activityClassName, activityInput);
verify(saga, times(1)).registerCompensation(activityClassName, activityInput);
}
@Test
public void test_compensate() {
Saga saga = mock(Saga.class);
WorkflowContext workflowContext = mock(WorkflowContext.class);
DaprSagaContextImpl ctx = new DaprSagaContextImpl(saga, workflowContext);
doNothing().when(saga).compensate(workflowContext);
ctx.compensate();
verify(saga, times(1)).compensate(workflowContext);
}
}

View File

@ -0,0 +1,324 @@
package io.dapr.workflows.saga;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.Test;
import com.microsoft.durabletask.TaskActivityContext;
import io.dapr.workflows.runtime.WorkflowActivity;
import io.dapr.workflows.runtime.WorkflowActivityContext;
public class SagaIntegrationTest {
private static int count = 0;
private static Object countLock = new Object();
@Test
public void testSaga_CompensateSequentially() {
int runCount = 10;
int succeedCount = 0;
int compensateCount = 0;
for (int i = 0; i < runCount; i++) {
boolean isSuccueed = doExecuteWorkflowWithSaga(false);
if (isSuccueed) {
succeedCount++;
} else {
compensateCount++;
}
}
System.out.println("Run workflow with saga " + runCount + " times: succeed " + succeedCount
+ " times, failed and compensated " + compensateCount + " times");
}
@Test
public void testSaga_compensateInParallel() {
int runCount = 100;
int succeedCount = 0;
int compensateCount = 0;
for (int i = 0; i < runCount; i++) {
boolean isSuccueed = doExecuteWorkflowWithSaga(true);
if (isSuccueed) {
succeedCount++;
} else {
compensateCount++;
}
}
System.out.println("Run workflow with saga " + runCount + " times: succeed " + succeedCount
+ " times, failed and compensated " + compensateCount + " times");
}
private boolean doExecuteWorkflowWithSaga(boolean parallelCompensation) {
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(parallelCompensation)
.setContinueWithError(true).build();
Saga saga = new Saga(config);
boolean workflowSuccess = false;
// reset count to zero
synchronized (countLock) {
count = 0;
}
Integer addInput = 100;
Integer subtractInput = 20;
Integer multiplyInput = 10;
Integer divideInput = 5;
try {
// step1: add activity
String result = callActivity(AddActivity.class.getName(), addInput, String.class);
saga.registerCompensation(AddCompentationActivity.class.getName(), addInput);
// step2: subtract activity
result = callActivity(SubtractActivity.class.getName(), subtractInput, String.class);
saga.registerCompensation(SubtractCompentationActivity.class.getName(), subtractInput);
if (parallelCompensation) {
// only add/subtract activities support parallel compensation
// so in step3 and step4 we repeat add/subtract activities
// step3: add activity again
result = callActivity(AddActivity.class.getName(), addInput, String.class);
saga.registerCompensation(AddCompentationActivity.class.getName(), addInput);
// step4: substract activity again
result = callActivity(SubtractActivity.class.getName(), subtractInput, String.class);
saga.registerCompensation(SubtractCompentationActivity.class.getName(), subtractInput);
} else {
// step3: multiply activity
result = callActivity(MultiplyActivity.class.getName(), multiplyInput, String.class);
saga.registerCompensation(MultiplyCompentationActivity.class.getName(), multiplyInput);
// step4: divide activity
result = callActivity(DivideActivity.class.getName(), divideInput, String.class);
saga.registerCompensation(DivideCompentationActivity.class.getName(), divideInput);
}
randomFail();
workflowSuccess = true;
} catch (Exception e) {
saga.compensate(SagaTest.createMockContext());
}
if (workflowSuccess) {
int expectResult = 0;
if (parallelCompensation) {
expectResult = 0 + addInput - subtractInput + addInput - subtractInput;
} else {
expectResult = (0 + addInput - subtractInput) * multiplyInput / divideInput;
}
assertEquals(expectResult, count);
} else {
assertEquals(0, count);
}
return workflowSuccess;
}
// mock to call activity in dapr workflow
private <V> V callActivity(String activityClassName, Object input, Class<V> returnType) {
try {
Class<?> activityClass = Class.forName(activityClassName);
WorkflowActivity activity = (WorkflowActivity) activityClass.getDeclaredConstructor().newInstance();
WorkflowActivityContext ctx = new WorkflowActivityContext(new TaskActivityContext() {
@Override
public java.lang.String getName() {
return activityClassName;
}
@Override
public <T> T getInput(Class<T> targetType) {
return (T) input;
}
});
randomFail();
return (V) activity.run(ctx);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static void randomFail() {
int randomInt = (int) (Math.random() * 100);
// if randomInt mod 10 is 0, then throw exception
if (randomInt % 10 == 0) {
throw new RuntimeException("random fail");
}
}
public static class AddActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount + input;
count = updatedCount;
}
String resultString = "current count is updated from " + originalCount + " to " + updatedCount
+ " after adding " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class AddCompentationActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount - input;
count = updatedCount;
}
String resultString = "current count is compensated from " + originalCount + " to "
+ updatedCount + " after compensate adding " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class SubtractActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount - input;
count = updatedCount;
}
String resultString = "current count is updated from " + originalCount + " to " + updatedCount
+ " after substracting " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class SubtractCompentationActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount + input;
count = updatedCount;
}
String resultString = "current count is compensated from " + originalCount + " to " + updatedCount
+ " after compensate substracting " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class MultiplyActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount * input;
count = updatedCount;
}
String resultString = "current count is updated from " + originalCount + " to " + updatedCount
+ " after multiplying " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class MultiplyCompentationActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount / input;
count = updatedCount;
}
String resultString = "current count is compensated from " + originalCount + " to " + updatedCount
+ " after compensate multiplying " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class DivideActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount / input;
count = updatedCount;
}
String resultString = "current count is updated from " + originalCount + " to " + updatedCount
+ " after dividing " + input;
// System.out.println(resultString);
return resultString;
}
}
public static class DivideCompentationActivity implements WorkflowActivity {
@Override
public String run(WorkflowActivityContext ctx) {
Integer input = ctx.getInput(Integer.class);
int originalCount = 0;
int updatedCount = 0;
synchronized (countLock) {
originalCount = count;
updatedCount = originalCount * input;
count = updatedCount;
}
String resultString = "current count is compensated from " + originalCount + " to " + updatedCount
+ " after compensate dividing " + input;
// System.out.println(resultString);
return resultString;
}
}
}

View File

@ -0,0 +1,50 @@
package io.dapr.workflows.saga;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.Test;
public class SagaOptionTest {
@Test
public void testBuild() {
SagaOption.Builder builder = SagaOption.newBuilder();
builder.setParallelCompensation(true);
builder.setMaxParallelThread(32);
builder.setContinueWithError(false);
SagaOption option = builder.build();
assertEquals(true, option.isParallelCompensation());
assertEquals(32, option.getMaxParallelThread());
assertEquals(false, option.isContinueWithError());
}
@Test
public void testBuild_default() {
SagaOption.Builder builder = SagaOption.newBuilder();
SagaOption option = builder.build();
assertEquals(false, option.isParallelCompensation());
assertEquals(16, option.getMaxParallelThread());
assertEquals(true, option.isContinueWithError());
}
@Test
public void testsetMaxParallelThread() {
SagaOption.Builder builder = SagaOption.newBuilder();
assertThrows(IllegalArgumentException.class, () -> {
builder.setMaxParallelThread(0);
});
assertThrows(IllegalArgumentException.class, () -> {
builder.setMaxParallelThread(1);
});
assertThrows(IllegalArgumentException.class, () -> {
builder.setMaxParallelThread(-1);
});
}
}

View File

@ -0,0 +1,454 @@
/*
* Copyright 2023 The Dapr Authors
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
limitations under the License.
*/
package io.dapr.workflows.saga;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import com.microsoft.durabletask.Task;
import com.microsoft.durabletask.TaskOptions;
import io.dapr.workflows.WorkflowContext;
import io.dapr.workflows.runtime.WorkflowActivity;
import io.dapr.workflows.runtime.WorkflowActivityContext;
public class SagaTest {
public static WorkflowContext createMockContext() {
WorkflowContext workflowContext = mock(WorkflowContext.class);
when(workflowContext.callActivity(anyString(), any(), eq((TaskOptions) null))).thenAnswer(new ActivityAnswer());
when(workflowContext.allOf(anyList())).thenAnswer(new AllActivityAnswer());
return workflowContext;
}
@Test
public void testSaga_IllegalArgument() {
assertThrows(IllegalArgumentException.class, () -> {
new Saga(null);
});
}
@Test
public void testregisterCompensation() {
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(false)
.setContinueWithError(true).build();
Saga saga = new Saga(config);
saga.registerCompensation(MockActivity.class.getName(), new MockActivityInput());
}
@Test
public void testregisterCompensation_IllegalArgument() {
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(false)
.setContinueWithError(true).build();
Saga saga = new Saga(config);
assertThrows(IllegalArgumentException.class, () -> {
saga.registerCompensation(null, "input");
});
assertThrows(IllegalArgumentException.class, () -> {
saga.registerCompensation("", "input");
});
}
@Test
public void testCompensateInParallel() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(true).build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
saga.compensate(createMockContext());
assertEquals(3, MockCompentationActivity.compensateOrder.size());
}
@Test
public void testCompensateInParallel_exception_1failed() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(true).build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
input2.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> {
saga.compensate(createMockContext());
});
assertNotNull(exception.getCause());
// 3 compentation activities, 2 succeed, 1 failed
assertEquals(0, exception.getSuppressed().length);
assertEquals(2, MockCompentationActivity.compensateOrder.size());
}
@Test
public void testCompensateInParallel_exception_2failed() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(true).build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
input2.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
input3.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> {
saga.compensate(createMockContext());
});
assertNotNull(exception.getCause());
// 3 compentation activities, 1 succeed, 2 failed
assertEquals(1, MockCompentationActivity.compensateOrder.size());
}
@Test
public void testCompensateInParallel_exception_3failed() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(true).build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
input1.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
input2.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
input3.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> {
saga.compensate(createMockContext());
});
assertNotNull(exception.getCause());
// 3 compentation activities, 0 succeed, 3 failed
assertEquals(0, MockCompentationActivity.compensateOrder.size());
}
@Test
public void testCompensateSequentially() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(false).build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
saga.compensate(createMockContext());
assertEquals(3, MockCompentationActivity.compensateOrder.size());
// the order should be 3 / 2 / 1
assertEquals(Integer.valueOf(3), MockCompentationActivity.compensateOrder.get(0));
assertEquals(Integer.valueOf(2), MockCompentationActivity.compensateOrder.get(1));
assertEquals(Integer.valueOf(1), MockCompentationActivity.compensateOrder.get(2));
}
@Test
public void testCompensateSequentially_continueWithError() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(false)
.setContinueWithError(true)
.build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
input2.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> {
saga.compensate(createMockContext());
});
assertNotNull(exception.getCause());
assertEquals(0, exception.getSuppressed().length);
// 3 compentation activities, 2 succeed, 1 failed
assertEquals(2, MockCompentationActivity.compensateOrder.size());
// the order should be 3 / 1
assertEquals(Integer.valueOf(3), MockCompentationActivity.compensateOrder.get(0));
assertEquals(Integer.valueOf(1), MockCompentationActivity.compensateOrder.get(1));
}
@Test
public void testCompensateSequentially_continueWithError_suppressed() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(false)
.setContinueWithError(true)
.build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
input2.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
input3.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> {
saga.compensate(createMockContext());
});
assertNotNull(exception.getCause());
assertEquals(1, exception.getSuppressed().length);
// 3 compentation activities, 1 succeed, 2 failed
assertEquals(1, MockCompentationActivity.compensateOrder.size());
// the order should be 3 / 1
assertEquals(Integer.valueOf(1), MockCompentationActivity.compensateOrder.get(0));
}
@Test
public void testCompensateSequentially_notContinueWithError() {
MockCompentationActivity.compensateOrder.clear();
SagaOption config = SagaOption.newBuilder()
.setParallelCompensation(false)
.setContinueWithError(false)
.build();
Saga saga = new Saga(config);
MockActivityInput input1 = new MockActivityInput();
input1.setOrder(1);
saga.registerCompensation(MockCompentationActivity.class.getName(), input1);
MockActivityInput input2 = new MockActivityInput();
input2.setOrder(2);
input2.setThrowException(true);
saga.registerCompensation(MockCompentationActivity.class.getName(), input2);
MockActivityInput input3 = new MockActivityInput();
input3.setOrder(3);
saga.registerCompensation(MockCompentationActivity.class.getName(), input3);
SagaCompensationException exception = assertThrows(SagaCompensationException.class, () -> {
saga.compensate(createMockContext());
});
assertNotNull(exception.getCause());
assertEquals(0, exception.getSuppressed().length);
// 3 compentation activities, 1 succeed, 1 failed and not continue
assertEquals(1, MockCompentationActivity.compensateOrder.size());
// the order should be 3 / 1
assertEquals(Integer.valueOf(3), MockCompentationActivity.compensateOrder.get(0));
}
public static class MockActivity implements WorkflowActivity {
@Override
public Object run(WorkflowActivityContext ctx) {
MockActivityOutput output = new MockActivityOutput();
output.setSucceed(true);
return output;
}
}
public static class MockCompentationActivity implements WorkflowActivity {
private static List<Integer> compensateOrder = Collections.synchronizedList(new ArrayList<>());
@Override
public Object run(WorkflowActivityContext ctx) {
MockActivityInput input = ctx.getInput(MockActivityInput.class);
if (input.isThrowException()) {
throw new RuntimeException("compensate failed: order=" + input.getOrder());
}
compensateOrder.add(input.getOrder());
return null;
}
}
public static class MockActivityInput {
private int order = 0;
private boolean throwException;
public int getOrder() {
return order;
}
public void setOrder(int order) {
this.order = order;
}
public boolean isThrowException() {
return throwException;
}
public void setThrowException(boolean throwException) {
this.throwException = throwException;
}
}
public static class MockActivityOutput {
private boolean succeed;
public boolean isSucceed() {
return succeed;
}
public void setSucceed(boolean succeed) {
this.succeed = succeed;
}
}
public static class ActivityAnswer implements Answer<Task<Void>> {
@Override
public Task<Void> answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
String name = (String) args[0];
Object input = args[1];
WorkflowActivity activity;
WorkflowActivityContext activityContext = Mockito.mock(WorkflowActivityContext.class);
try {
activity = (WorkflowActivity) Class.forName(name).getDeclaredConstructor().newInstance();
} catch (Exception e) {
fail(e);
return null;
}
Task<Void> task = mock(Task.class);
when(task.await()).thenAnswer(invocation1 -> {
Mockito.doReturn(input).when(activityContext).getInput(Mockito.any());
activity.run(activityContext);
return null;
});
return task;
}
}
public static class AllActivityAnswer implements Answer<Task<Void>> {
@Override
public Task<Void> answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
List<Task<Void>> tasks = (List<Task<Void>>) args[0];
ExecutorService executor = Executors.newFixedThreadPool(5);
List<Callable<Void>> compensationTasks = new ArrayList<>();
for (Task<Void> task : tasks) {
Callable<Void> compensationTask = new Callable<Void>() {
@Override
public Void call() {
return task.await();
}
};
compensationTasks.add(compensationTask);
}
List<Future<Void>> resultFutures;
try {
resultFutures = executor.invokeAll(compensationTasks, 2, TimeUnit.SECONDS);
} catch (InterruptedException e) {
fail(e);
return null;
}
Task<Void> task = mock(Task.class);
when(task.await()).thenAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
Exception exception = null;
for (Future<Void> resultFuture : resultFutures) {
try {
resultFuture.get();
} catch (Exception e) {
exception = e;
}
}
if (exception != null) {
throw exception;
}
return null;
}
});
return task;
}
}
}