176 lines
5.0 KiB
Python
176 lines
5.0 KiB
Python
import asyncio
|
|
import threading
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from openfeature.api import (
|
|
get_transaction_context,
|
|
set_transaction_context,
|
|
set_transaction_context_propagator,
|
|
)
|
|
from openfeature.evaluation_context import EvaluationContext
|
|
from openfeature.transaction_context import (
|
|
ContextVarsTransactionContextPropagator,
|
|
TransactionContextPropagator,
|
|
)
|
|
from openfeature.transaction_context.no_op_transaction_context_propagator import (
|
|
NoOpTransactionContextPropagator,
|
|
)
|
|
|
|
|
|
# Test cases
|
|
def test_should_return_default_evaluation_context_with_noop_propagator():
    """The no-op propagator yields an ``EvaluationContext`` with no attributes."""
    # Given: the no-op propagator is the active propagator
    noop = NoOpTransactionContextPropagator()
    set_transaction_context_propagator(noop)

    # When: the current transaction context is requested
    current = get_transaction_context()

    # Then: an attribute-less evaluation context comes back
    assert isinstance(current, EvaluationContext)
    assert current.attributes == {}
|
|
|
|
|
|
def test_should_set_and_get_custom_transaction_context():
    """A context stored via the API is returned by ``get_transaction_context``."""
    # Given: a context-vars propagator and a context carrying a key and one attribute
    propagator = ContextVarsTransactionContextPropagator()
    set_transaction_context_propagator(propagator)
    stored = EvaluationContext("custom_key", {"attr1": "val1"})

    # When: the context is stored as the transaction context
    set_transaction_context(stored)

    # Then: reading it back exposes the same targeting key and attributes
    retrieved = get_transaction_context()
    assert retrieved.targeting_key == "custom_key"
    assert retrieved.attributes == {"attr1": "val1"}
|
|
|
|
|
|
def test_should_override_propagator_and_reset_context():
    """Replacing a mocked propagator with the no-op one gives a default context."""
    # Given: a mocked propagator is installed first
    mocked_propagator = MagicMock(spec=TransactionContextPropagator)
    set_transaction_context_propagator(mocked_propagator)

    # When: the no-op propagator replaces it
    set_transaction_context_propagator(NoOpTransactionContextPropagator())

    # Then: the transaction context equals a freshly constructed EvaluationContext
    expected = EvaluationContext()
    assert get_transaction_context() == expected
|
|
|
|
|
|
def test_should_call_set_transaction_context_on_propagator():
    """``set_transaction_context`` forwards the context to the active propagator."""
    # Given: a mocked propagator is active and a context is ready to be stored
    propagator_mock = MagicMock(spec=TransactionContextPropagator)
    set_transaction_context_propagator(propagator_mock)
    ctx = EvaluationContext("custom_key", {"attr1": "val1"})

    # When: the context is set through the public API
    set_transaction_context(ctx)

    # Then: the propagator's setter was last called with that same context
    propagator_mock.set_transaction_context.assert_called_with(ctx)
|
|
|
|
|
|
def test_should_return_default_context_with_noop_propagator_set():
    """With the no-op propagator set, the context equals ``EvaluationContext()``."""
    # Given: the no-op propagator is installed
    set_transaction_context_propagator(NoOpTransactionContextPropagator())

    # When: the transaction context is read
    current = get_transaction_context()

    # Then: it compares equal to a default-constructed evaluation context
    expected = EvaluationContext()
    assert current == expected
|
|
|
|
|
|
def test_should_propagate_event_when_context_set():
    """Setting a context via the API stores it in the propagator's context variable."""
    # Given: a context-vars propagator is active
    propagator = ContextVarsTransactionContextPropagator()
    set_transaction_context_propagator(propagator)
    ctx = EvaluationContext("custom_key", {"attr1": "val1"})

    # When: the context is set through the public API
    set_transaction_context(ctx)

    # Then: the propagator's private context variable holds the same data.
    # The variable is read directly to inspect what the propagator stored.
    held = propagator._transaction_context_var.get()
    assert held.targeting_key == "custom_key"
    assert held.attributes == {"attr1": "val1"}
|
|
|
|
|
|
def test_context_vars_transaction_context_propagator_multiple_threads():
    """Check that each thread reads back only the transaction context it set.

    Every worker sets its own context, waits on a shared barrier until all
    workers have done the same, and only then reads the context back. The
    reads therefore happen after every write has taken place.
    """
    # Given
    context_var_propagator = ContextVarsTransactionContextPropagator()
    set_transaction_context_propagator(context_var_propagator)

    number_of_threads = 3
    barrier = threading.Barrier(number_of_threads)

    def thread_func(context_value, result_list, index):
        """Set a per-thread context, sync on the barrier, then record what is read."""
        context = EvaluationContext(
            f"context_{context_value}", {"thread": context_value}
        )
        set_transaction_context(context)
        # Do not read until every thread has written its own context.
        barrier.wait()
        result_list[index] = get_transaction_context()

    results = [None] * number_of_threads

    # When
    # The number of started threads must equal the barrier's party count;
    # deriving both from number_of_threads keeps barrier.wait() from blocking
    # forever if the constant is ever changed.
    threads = [
        threading.Thread(target=thread_func, args=(i, results, i))
        for i in range(number_of_threads)
    ]
    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

    # Then
    for i, result in enumerate(results):
        assert result.targeting_key == f"context_{i}"
        assert result.attributes == {"thread": i}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_context_vars_transaction_context_propagator_asyncio():
    """Check that concurrently gathered coroutines each read back their own context.

    Each coroutine sets a context, then waits on a shared event that is only
    set once all of them have done so, and finally records what it reads.
    """
    # Given
    propagator = ContextVarsTransactionContextPropagator()
    set_transaction_context_propagator(propagator)

    number_of_tasks = 3
    all_ready = asyncio.Event()
    ready_count = 0

    async def async_func(context_value, results, index):
        """Set a per-coroutine context, wait for the others, then record the read."""
        nonlocal ready_count
        set_transaction_context(
            EvaluationContext(f"context_{context_value}", {"async": context_value})
        )

        # The last coroutine to get here releases everyone waiting on the event.
        ready_count += 1
        if ready_count == number_of_tasks:
            all_ready.set()

        await all_ready.wait()
        results[index] = get_transaction_context()

    # One slot per coroutine, filled in by index
    collected = [None] * number_of_tasks

    # When
    await asyncio.gather(
        *(async_func(i, collected, i) for i in range(number_of_tasks))
    )

    # Then
    for i, result in enumerate(collected):
        assert result.targeting_key == f"context_{i}"
        assert result.attributes == {"async": i}
|