# python-sdk/tests/test_transaction_context.py
# (176 lines, 5.0 KiB, Python)

import asyncio
import threading
from unittest.mock import MagicMock
import pytest
from openfeature.api import (
get_transaction_context,
set_transaction_context,
set_transaction_context_propagator,
)
from openfeature.evaluation_context import EvaluationContext
from openfeature.transaction_context import (
ContextVarsTransactionContextPropagator,
TransactionContextPropagator,
)
from openfeature.transaction_context.no_op_transaction_context_propagator import (
NoOpTransactionContextPropagator,
)
# Test cases
def test_should_return_default_evaluation_context_with_noop_propagator():
    """The no-op propagator hands back an EvaluationContext with no attributes."""
    noop = NoOpTransactionContextPropagator()
    set_transaction_context_propagator(noop)

    ctx = get_transaction_context()

    assert isinstance(ctx, EvaluationContext)
    assert ctx.attributes == {}
def test_should_set_and_get_custom_transaction_context():
    """A context stored via the contextvars propagator is read back unchanged."""
    set_transaction_context_propagator(ContextVarsTransactionContextPropagator())
    stored = EvaluationContext("custom_key", {"attr1": "val1"})

    set_transaction_context(stored)

    retrieved = get_transaction_context()
    assert retrieved.targeting_key == "custom_key"
    assert retrieved.attributes == {"attr1": "val1"}
def test_should_override_propagator_and_reset_context():
    """Installing a new propagator fully replaces the previous one.

    After a mock propagator is swapped out for the no-op propagator, reads
    must return a context equal to a default ``EvaluationContext()`` and must
    not be routed to the replaced mock.
    """
    # Given
    custom_propagator = MagicMock(spec=TransactionContextPropagator)
    default_context = EvaluationContext()
    set_transaction_context_propagator(custom_propagator)
    # When
    set_transaction_context_propagator(NoOpTransactionContextPropagator())
    # Then
    assert get_transaction_context() == default_context
    # Verify directly that the replaced propagator is never consulted, rather
    # than relying on the mock's return value happening to differ from the
    # default context.
    custom_propagator.get_transaction_context.assert_not_called()
def test_should_call_set_transaction_context_on_propagator():
    """set_transaction_context delegates to the installed propagator exactly once."""
    # Given
    custom_propagator = MagicMock(spec=TransactionContextPropagator)
    evaluation_context = EvaluationContext("custom_key", {"attr1": "val1"})
    set_transaction_context_propagator(custom_propagator)
    # When
    set_transaction_context(evaluation_context)
    # Then
    # assert_called_with only inspects the most recent call; pin the call
    # count as well so duplicate delegation is caught.
    custom_propagator.set_transaction_context.assert_called_once_with(
        evaluation_context
    )
def test_should_return_default_context_with_noop_propagator_set():
    """With the no-op propagator installed, the context equals a fresh default."""
    set_transaction_context_propagator(NoOpTransactionContextPropagator())
    assert get_transaction_context() == EvaluationContext()
def test_should_propagate_event_when_context_set():
    """set_transaction_context writes into the propagator's underlying ContextVar.

    White-box check: reads the private ``_transaction_context_var`` directly
    instead of going through get_transaction_context.
    """
    propagator = ContextVarsTransactionContextPropagator()
    set_transaction_context_propagator(propagator)
    new_context = EvaluationContext("custom_key", {"attr1": "val1"})

    set_transaction_context(new_context)

    stored = propagator._transaction_context_var.get()
    assert stored.targeting_key == "custom_key"
    assert stored.attributes == {"attr1": "val1"}
def test_context_vars_transaction_context_propagator_multiple_threads():
    """Each thread reads back only the transaction context it set itself.

    Every thread sets its own context, then waits at a barrier before reading
    it back, so all writes happen before any read. A store shared between
    threads would hand the same (last-written) value to every reader.
    """
    # Given
    context_var_propagator = ContextVarsTransactionContextPropagator()
    set_transaction_context_propagator(context_var_propagator)
    number_of_threads = 3
    # The timeout turns a crashed or stuck peer into BrokenBarrierError in the
    # remaining threads instead of hanging the whole test run.
    barrier = threading.Barrier(number_of_threads, timeout=5)

    def thread_func(context_value, result_list, index):
        context = EvaluationContext(
            f"context_{context_value}", {"thread": context_value}
        )
        set_transaction_context(context)
        barrier.wait()
        result_list[index] = get_transaction_context()

    results = [None] * number_of_threads
    # Spawn exactly number_of_threads workers; the barrier's party count must
    # match or it never releases.
    threads = [
        threading.Thread(target=thread_func, args=(i, results, i))
        for i in range(number_of_threads)
    ]
    # When
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    # Then
    for i, result in enumerate(results):
        # A None slot means the worker died before recording its context.
        assert result is not None, f"thread {i} did not record a context"
        assert result.targeting_key == f"context_{i}"
        assert result.attributes == {"thread": i}
@pytest.mark.asyncio
async def test_context_vars_transaction_context_propagator_asyncio():
    """Concurrent asyncio tasks each keep the transaction context they set.

    Every task sets its context, then waits on a shared event that is only
    released once all tasks have done so. Each task then reads its context
    back and returns it to the gather call.
    """
    set_transaction_context_propagator(ContextVarsTransactionContextPropagator())
    task_total = 3
    all_ready = asyncio.Event()
    arrived = 0

    async def set_then_read(value):
        nonlocal arrived
        set_transaction_context(
            EvaluationContext(f"context_{value}", {"async": value})
        )
        arrived += 1
        # The last task to arrive releases everyone, so every write precedes
        # any read.
        if arrived == task_total:
            all_ready.set()
        await all_ready.wait()
        return get_transaction_context()

    # gather preserves argument order in its result list.
    observed = await asyncio.gather(*(set_then_read(n) for n in range(task_total)))

    for n, ctx in enumerate(observed):
        assert ctx.targeting_key == f"context_{n}"
        assert ctx.attributes == {"async": n}