560 lines
18 KiB
Python
560 lines
18 KiB
Python
import inspect
|
|
import time
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from openfeature import api
|
|
from openfeature.api import add_hooks, clear_hooks, get_client, set_provider
|
|
from openfeature.client import OpenFeatureClient, _typecheck_flag_value
|
|
from openfeature.evaluation_context import EvaluationContext
|
|
from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails
|
|
from openfeature.exception import ErrorCode, OpenFeatureError
|
|
from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason
|
|
from openfeature.hook import Hook
|
|
from openfeature.provider import FeatureProvider, ProviderStatus
|
|
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
|
|
from openfeature.provider.no_op_provider import NoOpProvider
|
|
from openfeature.transaction_context import ContextVarsTransactionContextPropagator
|
|
|
|
|
|
@pytest.mark.parametrize(
    "flag_type, default_value, get_method",
    (
        (bool, True, "get_boolean_value"),
        (bool, True, "get_boolean_value_async"),
        (str, "String", "get_string_value"),
        (str, "String", "get_string_value_async"),
        (int, 100, "get_integer_value"),
        (int, 100, "get_integer_value_async"),
        (float, 10.23, "get_float_value"),
        (float, 10.23, "get_float_value_async"),
        (
            dict,
            {"String": "string", "Number": 2, "Boolean": True},
            "get_object_value",
        ),
        (
            dict,
            {"String": "string", "Number": 2, "Boolean": True},
            "get_object_value_async",
        ),
        (list, ["string1", "string2"], "get_object_value"),
        (list, ["string1", "string2"], "get_object_value_async"),
    ),
)
@pytest.mark.asyncio
async def test_should_get_flag_value_based_on_method_type(
    flag_type, default_value, get_method, no_op_provider_client
):
    """Each typed ``get_*_value`` accessor, sync or async, should return the
    supplied default unchanged and with the expected Python type when called on
    the ``no_op_provider_client`` fixture."""
    # Given
    getter = getattr(no_op_provider_client, get_method)

    # When
    outcome = getter(flag_key="Key", default_value=default_value)
    if inspect.isawaitable(outcome):
        # The *_async accessors hand back a coroutine that still has to be driven.
        outcome = await outcome

    # Then
    assert outcome is not None
    assert outcome == default_value
    assert isinstance(outcome, flag_type)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "flag_type, default_value, get_method",
    (
        (bool, True, "get_boolean_details"),
        (bool, True, "get_boolean_details_async"),
        (str, "String", "get_string_details"),
        (str, "String", "get_string_details_async"),
        (int, 100, "get_integer_details"),
        (int, 100, "get_integer_details_async"),
        (float, 10.23, "get_float_details"),
        (float, 10.23, "get_float_details_async"),
        (
            dict,
            {"String": "string", "Number": 2, "Boolean": True},
            "get_object_details",
        ),
        (
            dict,
            {"String": "string", "Number": 2, "Boolean": True},
            "get_object_details_async",
        ),
        (list, ["string1", "string2"], "get_object_details"),
        (list, ["string1", "string2"], "get_object_details_async"),
    ),
)
@pytest.mark.asyncio
async def test_should_get_flag_detail_based_on_method_type(
    flag_type, default_value, get_method, no_op_provider_client
):
    """Each typed ``get_*_details`` accessor, sync or async, should return
    evaluation details whose ``value`` is the supplied default, with the
    expected Python type."""
    # Given
    getter = getattr(no_op_provider_client, get_method)

    # When
    details = getter(flag_key="Key", default_value=default_value)
    if inspect.isawaitable(details):
        # Async accessors return a coroutine; resolve it to the details object.
        details = await details

    # Then
    assert details is not None
    resolved = details.value
    assert resolved == default_value
    assert isinstance(resolved, flag_type)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_should_raise_exception_when_invalid_flag_type_provided(
    no_op_provider_client,
):
    """An unknown flag type must not propagate an exception: both the sync and
    async evaluation paths should return the default value with a GENERAL
    error and an ERROR reason."""
    # Given
    # When
    flag_sync = no_op_provider_client.evaluate_flag_details(
        flag_type=None, flag_key="Key", default_value=True
    )
    flag_async = await no_op_provider_client.evaluate_flag_details_async(
        flag_type=None, flag_key="Key", default_value=True
    )
    # Then
    for flag in [flag_sync, flag_async]:
        # `is True` pins the returned value to the default; a bare truthiness
        # check would also accept any other truthy value.
        assert flag.value is True
        assert flag.error_message == "Unknown flag type"
        assert flag.error_code == ErrorCode.GENERAL
        assert flag.reason == Reason.ERROR
|
|
|
|
|
|
def test_should_pass_flag_metadata_from_resolution_to_evaluation_details():
    """Metadata attached to an in-memory flag should appear unchanged on the
    evaluation details the client returns."""
    # Given
    flags = {
        "Key": InMemoryFlag(
            "true",
            {"true": True, "false": False},
            flag_metadata={"foo": "bar"},
        )
    }
    set_provider(InMemoryProvider(flags), "my-client")
    client = OpenFeatureClient("my-client", None)

    # When
    details = client.get_boolean_details(flag_key="Key", default_value=False)

    # Then
    assert details is not None
    assert details.flag_metadata == {"foo": "bar"}
|
|
|
|
|
|
def test_should_handle_a_generic_exception_thrown_by_a_provider(no_op_provider_client):
    """A plain ``Exception`` raised during evaluation (here from a hook's
    ``after`` stage) should be turned into an ERROR result that carries the
    default value and the exception message, rather than propagating."""
    # Given
    exception_hook = MagicMock(spec=Hook)
    exception_hook.after.side_effect = Exception("Generic exception raised")
    no_op_provider_client.add_hooks([exception_hook])
    # When
    flag_details = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )
    # Then
    assert flag_details is not None
    # Pin the default itself; a truthiness check would accept any truthy value.
    assert flag_details.value is True
    assert flag_details.reason == Reason.ERROR
    # A non-OpenFeature exception has no error code of its own, so GENERAL is
    # the expected fallback.
    assert flag_details.error_code == ErrorCode.GENERAL
    assert flag_details.error_message == "Generic exception raised"
|
|
|
|
|
|
def test_should_handle_an_open_feature_exception_thrown_by_a_provider(
    no_op_provider_client,
):
    """An ``OpenFeatureError`` raised during evaluation (here from a hook's
    ``after`` stage) should be turned into an ERROR result that carries the
    default value plus the error's own code and message."""
    # Given
    exception_hook = MagicMock(spec=Hook)
    exception_hook.after.side_effect = OpenFeatureError(
        ErrorCode.GENERAL, "error_message"
    )
    no_op_provider_client.add_hooks([exception_hook])

    # When
    flag_details = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )
    # Then
    assert flag_details is not None
    # Pin the default itself; a truthiness check would accept any truthy value.
    assert flag_details.value is True
    assert flag_details.reason == Reason.ERROR
    # The code carried by the raised OpenFeatureError should be surfaced.
    assert flag_details.error_code == ErrorCode.GENERAL
    assert flag_details.error_message == "error_message"
|
|
|
|
|
|
def test_should_return_client_metadata_with_domain():
    """Client metadata should report the domain the client was built with."""
    # Given
    domain = "my-client"
    client = OpenFeatureClient(domain, None, NoOpProvider())

    # When
    metadata = client.get_metadata()

    # Then
    assert metadata is not None
    assert metadata.domain == domain
|
|
|
|
|
|
def test_should_call_api_level_hooks(no_op_provider_client):
    """Hooks registered globally through ``add_hooks`` should run their
    ``before`` and ``after`` stages for an evaluation made via a client."""
    # Given: start from an empty global hook list so only the spy is registered.
    clear_hooks()
    global_hook = MagicMock(spec=Hook)
    add_hooks([global_hook])

    # When
    no_op_provider_client.get_boolean_details(flag_key="Key", default_value=True)

    # Then
    for stage in (global_hook.before, global_hook.after):
        stage.assert_called_once()
|
|
|
|
|
|
# Requirement 1.7.5
def test_should_define_a_provider_status_accessor(no_op_provider_client):
    """The client exposes its provider's status; the fixture client's provider
    is expected to report READY."""
    # When
    provider_status = no_op_provider_client.get_provider_status()

    # Then
    assert provider_status is not None
    assert provider_status == ProviderStatus.READY
|
|
|
|
|
|
# Requirement 1.7.6
@pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_not_ready(
    no_op_provider_client, monkeypatch
):
    """When the provider reports NOT_READY, evaluation short-circuits:
    - the default value is returned with PROVIDER_NOT_READY;
    - the error hook runs once per evaluation, on both the sync and async paths;
    - the finally hook runs once per evaluation, on both paths."""
    # Given
    monkeypatch.setattr(
        no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY
    )
    spy_hook = MagicMock(spec=Hook)
    no_op_provider_client.add_hooks([spy_hook])
    # When
    flag_details_sync = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )
    # Verify the sync path's hooks before the spy is reset, so both the error
    # and finally stages are covered for the sync call, not only the async one.
    spy_hook.error.assert_called_once()
    spy_hook.finally_after.assert_called_once()
    spy_hook.reset_mock()
    flag_details_async = await no_op_provider_client.get_boolean_details_async(
        flag_key="Key", default_value=True
    )
    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        # Pin the default itself; a truthiness check would accept any truthy value.
        assert flag_details.value is True
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY
    spy_hook.error.assert_called_once()
    spy_hook.finally_after.assert_called_once()
|
|
|
|
|
|
# Requirement 1.7.7
@pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
    no_op_provider_client, monkeypatch
):
    """When the provider reports FATAL, evaluation short-circuits:
    - the default value is returned with PROVIDER_FATAL;
    - the error hook runs once per evaluation, on both the sync and async paths;
    - the finally hook runs once per evaluation, on both paths."""
    # Given
    monkeypatch.setattr(
        no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL
    )
    spy_hook = MagicMock(spec=Hook)
    no_op_provider_client.add_hooks([spy_hook])
    # When
    flag_details_sync = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )
    # Verify the sync path's hooks before the spy is reset, so both the error
    # and finally stages are covered for the sync call, not only the async one.
    spy_hook.error.assert_called_once()
    spy_hook.finally_after.assert_called_once()
    spy_hook.reset_mock()
    flag_details_async = await no_op_provider_client.get_boolean_details_async(
        flag_key="Key", default_value=True
    )
    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        # Pin the default itself; a truthiness check would accept any truthy value.
        assert flag_details.value is True
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
    spy_hook.error.assert_called_once()
    spy_hook.finally_after.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code():
    """A resolution that carries an error code, rather than raising, must still
    trigger the error hooks. The result should be an ERROR with that code, for
    both the sync and async evaluation paths."""
    # Given
    spy_hook = MagicMock(spec=Hook)
    provider = MagicMock(spec=FeatureProvider)
    provider.get_provider_hooks.return_value = []
    mock_resolution = FlagResolutionDetails(
        value=True,
        reason=Reason.ERROR,
        error_code=ErrorCode.PROVIDER_FATAL,
        error_message="This is an error message",
    )
    provider.resolve_boolean_details.return_value = mock_resolution
    # NOTE(review): spec=FeatureProvider presumably makes the async resolver an
    # AsyncMock, so this return_value is what the awaited call yields — confirm
    # if the spec changes.
    provider.resolve_boolean_details_async.return_value = mock_resolution
    set_provider(provider)
    client = get_client()
    client.add_hooks([spy_hook])
    # When
    flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True)
    spy_hook.error.assert_called_once()
    spy_hook.reset_mock()
    flag_details_async = await client.get_boolean_details_async(
        flag_key="Key", default_value=True
    )
    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        # Both the resolved and the default value are True here, so pin it exactly.
        assert flag_details.value is True
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
    spy_hook.error.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_client_type_mismatch_exceptions():
    """Requesting a boolean with a non-boolean default should produce a
    TYPE_MISMATCH error, and the client should fall back to the supplied
    default value on both the sync and async paths."""
    # Given
    client = get_client()
    default_value = "type mismatch"
    # When
    flag_details_sync = client.get_boolean_details(
        flag_key="Key", default_value=default_value
    )
    flag_details_async = await client.get_boolean_details_async(
        flag_key="Key", default_value=default_value
    )
    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        # Assert the exact fallback; truthiness alone would also pass for a
        # successfully resolved `True`.
        assert flag_details.value == default_value
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.TYPE_MISMATCH
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_typecheck_flag_value_general_error():
    """``_typecheck_flag_value`` reports a GENERAL error when the flag type is
    unknown (``None``)."""
    # Given / When
    err = _typecheck_flag_value(value="A", flag_type=None)

    # Then
    assert err.error_code == ErrorCode.GENERAL
    assert err.error_message == "Unknown flag type"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_typecheck_flag_value_type_mismatch_error():
    """``_typecheck_flag_value`` reports TYPE_MISMATCH when a string value is
    checked against the BOOLEAN flag type, naming both Python types."""
    # Given / When
    err = _typecheck_flag_value(value="A", flag_type=FlagType.BOOLEAN)

    # Then
    assert err.error_code == ErrorCode.TYPE_MISMATCH
    assert err.error_message == "Expected type <class 'bool'> but got <class 'str'>"
|
|
|
|
|
|
def test_provider_events():
    """Handlers on the default-domain client fire for events from the default
    provider only. Events from a provider bound to another domain are ignored.
    READY is delivered immediately on registration."""
    # Given
    provider = NoOpProvider()
    set_provider(provider)

    other_provider = NoOpProvider()
    set_provider(other_provider, "my-domain")

    provider_details = ProviderEventDetails(message="message")
    details = EventDetails.from_provider_event_details(
        provider.get_metadata().name, provider_details
    )

    def emit_everything(source):
        source.emit_provider_configuration_changed(provider_details)
        source.emit_provider_error(provider_details)
        source.emit_provider_stale(provider_details)

    spy = MagicMock()
    client = get_client()
    registrations = (
        (ProviderEvent.PROVIDER_READY, spy.provider_ready),
        (
            ProviderEvent.PROVIDER_CONFIGURATION_CHANGED,
            spy.provider_configuration_changed,
        ),
        (ProviderEvent.PROVIDER_ERROR, spy.provider_error),
        (ProviderEvent.PROVIDER_STALE, spy.provider_stale),
    )
    for event, callback in registrations:
        client.add_handler(event, callback)

    # When
    for source in (provider, other_provider):
        emit_everything(source)

    # Then
    # NOTE: provider_ready is called immediately after adding the handler
    spy.provider_ready.assert_called_once()
    spy.provider_configuration_changed.assert_called_once_with(details)
    spy.provider_error.assert_called_once_with(details)
    spy.provider_stale.assert_called_once_with(details)
|
|
|
|
|
|
def test_add_remove_event_handler():
    """A handler that is added and then removed must not be invoked by a later
    event."""
    # Given
    provider = NoOpProvider()
    set_provider(provider)

    spy = MagicMock()
    callback = spy.provider_configuration_changed
    event = ProviderEvent.PROVIDER_CONFIGURATION_CHANGED

    client = get_client()
    client.add_handler(event, callback)
    client.remove_handler(event, callback)

    # When
    provider.emit_provider_configuration_changed(
        ProviderEventDetails(message="message")
    )

    # Then
    callback.assert_not_called()
|
|
|
|
|
|
# Requirement 5.1.2, Requirement 5.1.3
def test_provider_event_late_binding():
    """A handler registered on a domain client follows the domain when its
    provider is replaced. Events from the replaced provider are no longer
    delivered; events from the new provider are."""
    # Given
    original = NoOpProvider()
    set_provider(original, "my-domain")
    replacement = NoOpProvider()

    spy = MagicMock()
    get_client("my-domain").add_handler(
        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
    )

    # Rebind the domain after the handler is already registered.
    set_provider(replacement, "my-domain")

    original_details = ProviderEventDetails(message="message from provider")
    replacement_details = ProviderEventDetails(message="message from other provider")
    expected = EventDetails.from_provider_event_details(
        replacement.get_metadata().name, replacement_details
    )

    # When
    original.emit_provider_configuration_changed(original_details)
    replacement.emit_provider_configuration_changed(replacement_details)

    # Then
    spy.provider_configuration_changed.assert_called_once_with(expected)
|
|
|
|
|
|
def test_client_handlers_thread_safety():
    """Register handlers on fresh clients while another thread emits events.
    The test passes as long as neither worker raises. It presumably targets
    races on shared handler state (unconfirmed)."""
    provider = NoOpProvider()
    set_provider(provider)

    def slow_handler(*args, **kwargs):
        time.sleep(0.005)

    def register_handlers():
        for _ in range(10):
            time.sleep(0.01)
            get_client(str(uuid.uuid4())).add_handler(
                ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, slow_handler
            )

    def emit_events():
        for _ in range(10):
            time.sleep(0.01)
            provider.emit_provider_configuration_changed(ProviderEventDetails())

    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = [pool.submit(register_handlers), pool.submit(emit_events)]
        for future in futures:
            # result() re-raises any exception raised inside the worker.
            future.result()
|
|
|
|
|
|
def test_client_should_merge_contexts():
    """Global, transaction, client and invocation evaluation contexts should all
    be merged before reaching the provider. The invocation context's targeting
    key wins, and every level's attributes are present."""
    api.clear_hooks()
    api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator())

    provider = NoOpProvider()
    # Wrap the real resolver so the call still succeeds but its arguments are recorded.
    provider.resolve_boolean_details = MagicMock(wraps=provider.resolve_boolean_details)
    api.set_provider(provider)

    # Global evaluation context
    global_context = EvaluationContext(
        targeting_key="global", attributes={"global_attr": "global_value"}
    )
    api.set_evaluation_context(global_context)

    # Transaction context
    transaction_context = EvaluationContext(
        targeting_key="transaction",
        attributes={"transaction_attr": "transaction_value"},
    )
    api.set_transaction_context(transaction_context)

    # Client-specific context
    client_context = EvaluationContext(
        targeting_key="client", attributes={"client_attr": "client_value"}
    )
    client = OpenFeatureClient(domain=None, version=None, context=client_context)

    # Invocation-specific context
    invocation_context = EvaluationContext(
        targeting_key="invocation", attributes={"invocation_attr": "invocation_value"}
    )
    flag_input = "flag"
    flag_default = False
    client.get_boolean_details(flag_input, flag_default, invocation_context)

    # Fail with a clear message if the provider was not called exactly once;
    # otherwise call_args would be None and the lookups below would raise an
    # unrelated TypeError.
    provider.resolve_boolean_details.assert_called_once()
    # Only keyword arguments are inspected, so read them directly rather than
    # unpacking an unused positional tuple.
    call_kwargs = provider.resolve_boolean_details.call_args.kwargs
    flag_key = call_kwargs["flag_key"]
    default_value = call_kwargs["default_value"]
    context = call_kwargs["evaluation_context"]

    assert flag_key == flag_input
    assert default_value is flag_default
    assert context.targeting_key == "invocation"  # Last one in the merge chain
    assert context.attributes["global_attr"] == "global_value"
    assert context.attributes["transaction_attr"] == "transaction_value"
    assert context.attributes["client_attr"] == "client_value"
    assert context.attributes["invocation_attr"] == "invocation_value"
|