import inspect
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock

import pytest

from openfeature import api
from openfeature.api import add_hooks, clear_hooks, get_client, set_provider
from openfeature.client import OpenFeatureClient, _typecheck_flag_value
from openfeature.evaluation_context import EvaluationContext
from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails
from openfeature.exception import ErrorCode, OpenFeatureError
from openfeature.flag_evaluation import FlagResolutionDetails, FlagType, Reason
from openfeature.hook import Hook
from openfeature.provider import FeatureProvider, ProviderStatus
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
from openfeature.provider.no_op_provider import NoOpProvider
from openfeature.transaction_context import ContextVarsTransactionContextPropagator


@pytest.mark.parametrize(
    "flag_type, default_value, get_method",
    (
        (bool, True, "get_boolean_value"),
        (bool, True, "get_boolean_value_async"),
        (str, "String", "get_string_value"),
        (str, "String", "get_string_value_async"),
        (int, 100, "get_integer_value"),
        (int, 100, "get_integer_value_async"),
        (float, 10.23, "get_float_value"),
        (float, 10.23, "get_float_value_async"),
        (
            dict,
            {
                "String": "string",
                "Number": 2,
                "Boolean": True,
            },
            "get_object_value",
        ),
        (
            dict,
            {
                "String": "string",
                "Number": 2,
                "Boolean": True,
            },
            "get_object_value_async",
        ),
        (
            list,
            ["string1", "string2"],
            "get_object_value",
        ),
        (
            list,
            ["string1", "string2"],
            "get_object_value_async",
        ),
    ),
)
@pytest.mark.asyncio
async def test_should_get_flag_value_based_on_method_type(
    flag_type, default_value, get_method, no_op_provider_client
):
    """Each typed ``get_*_value`` accessor (sync or async) returns the default.

    The accessor is looked up by name on the client fixture; coroutine
    functions are awaited, plain functions are called directly.
    """
    # Given
    # When
    accessor = getattr(no_op_provider_client, get_method)
    outcome = accessor(flag_key="Key", default_value=default_value)
    flag = await outcome if inspect.iscoroutinefunction(accessor) else outcome

    # Then
    assert flag is not None
    assert flag == default_value
    assert isinstance(flag, flag_type)


@pytest.mark.parametrize(
    "flag_type, default_value, get_method",
    (
        (bool, True, "get_boolean_details"),
        (bool, True, "get_boolean_details_async"),
        (str, "String", "get_string_details"),
        (str, "String", "get_string_details_async"),
        (int, 100, "get_integer_details"),
        (int, 100, "get_integer_details_async"),
        (float, 10.23, "get_float_details"),
        (float, 10.23, "get_float_details_async"),
        (
            dict,
            {
                "String": "string",
                "Number": 2,
                "Boolean": True,
            },
            "get_object_details",
        ),
        (
            dict,
            {
                "String": "string",
                "Number": 2,
                "Boolean": True,
            },
            "get_object_details_async",
        ),
        (
            list,
            ["string1", "string2"],
            "get_object_details",
        ),
        (
            list,
            ["string1", "string2"],
            "get_object_details_async",
        ),
    ),
)
@pytest.mark.asyncio
async def test_should_get_flag_detail_based_on_method_type(
    flag_type, default_value, get_method, no_op_provider_client
):
    """Each typed ``get_*_details`` accessor (sync or async) wraps the default.

    Same dispatch as the value test, but the assertions inspect the
    ``value`` attribute of the returned details object.
    """
    # Given
    # When
    accessor = getattr(no_op_provider_client, get_method)
    outcome = accessor(flag_key="Key", default_value=default_value)
    flag = await outcome if inspect.iscoroutinefunction(accessor) else outcome

    # Then
    assert flag is not None
    assert flag.value == default_value
    assert isinstance(flag.value, flag_type)


@pytest.mark.asyncio
async def test_should_raise_exception_when_invalid_flag_type_provided(
    no_op_provider_client,
):
    """Evaluating with ``flag_type=None`` reports a GENERAL error result.

    Both the sync and the async entry points are exercised; the error is
    observed on the returned details rather than as a raised exception.
    """
    # Given
    # When
    sync_result = no_op_provider_client.evaluate_flag_details(
        flag_type=None, flag_key="Key", default_value=True
    )
    async_result = await no_op_provider_client.evaluate_flag_details_async(
        flag_type=None, flag_key="Key", default_value=True
    )

    # Then
    for flag in (sync_result, async_result):
        assert flag.value
        assert flag.error_message == "Unknown flag type"
        assert flag.error_code == ErrorCode.GENERAL
        assert flag.reason == Reason.ERROR


def test_should_pass_flag_metadata_from_resolution_to_evaluation_details():
    """Flag metadata configured on an in-memory flag shows up on the details."""
    # Given
    flag = InMemoryFlag(
        "true",
        {"true": True, "false": False},
        flag_metadata={"foo": "bar"},
    )
    provider = InMemoryProvider({"Key": flag})
    set_provider(provider, "my-client")
    client = OpenFeatureClient("my-client", None)

    # When
    details = client.get_boolean_details(flag_key="Key", default_value=False)

    # Then
    assert details is not None
    assert details.flag_metadata == {"foo": "bar"}


def test_should_handle_a_generic_exception_thrown_by_a_provider(no_op_provider_client):
    """A plain ``Exception`` from an ``after`` hook becomes an ERROR result.

    The evaluation still returns a truthy boolean value and carries the
    exception text as its error message.
    """
    # Given
    failing_hook = MagicMock(spec=Hook)
    failing_hook.after.side_effect = Exception("Generic exception raised")
    no_op_provider_client.add_hooks([failing_hook])

    # When
    details = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )

    # Then
    assert details is not None
    assert details.value
    assert isinstance(details.value, bool)
    assert details.reason == Reason.ERROR
    assert details.error_message == "Generic exception raised"


def test_should_handle_an_open_feature_exception_thrown_by_a_provider(
    no_op_provider_client,
):
    """An ``OpenFeatureError`` from an ``after`` hook becomes an ERROR result.

    The error message given to the exception is surfaced on the details.
    """
    # Given
    failing_hook = MagicMock(spec=Hook)
    failing_hook.after.side_effect = OpenFeatureError(
        ErrorCode.GENERAL, "error_message"
    )
    no_op_provider_client.add_hooks([failing_hook])

    # When
    details = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )

    # Then
    assert details is not None
    assert details.value
    assert isinstance(details.value, bool)
    assert details.reason == Reason.ERROR
    assert details.error_message == "error_message"


def test_should_return_client_metadata_with_domain():
    """Client metadata exposes the domain the client was constructed with."""
    # Given
    client = OpenFeatureClient("my-client", None, NoOpProvider())

    # When
    metadata = client.get_metadata()

    # Then
    assert metadata is not None
    assert metadata.domain == "my-client"


def test_should_call_api_level_hooks(no_op_provider_client):
    """Hooks registered through the API run once around a client evaluation."""
    # Given
    clear_hooks()
    global_hook = MagicMock(spec=Hook)
    add_hooks([global_hook])

    # When
    no_op_provider_client.get_boolean_details(flag_key="Key", default_value=True)

    # Then
    global_hook.before.assert_called_once()
    global_hook.after.assert_called_once()
# Requirement 1.7.5
def test_should_define_a_provider_status_accessor(no_op_provider_client):
    """The client exposes ``get_provider_status`` and the fixture reports READY."""
    # When
    status = no_op_provider_client.get_provider_status()

    # Then
    assert status is not None
    assert status == ProviderStatus.READY


# Requirement 1.7.6
@pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_not_ready(
    no_op_provider_client, monkeypatch
):
    """With status NOT_READY, evaluations return a PROVIDER_NOT_READY error.

    The status accessor is monkeypatched on the client; the sync and async
    evaluations must each trigger the spy hook's ``error`` stage once.
    """
    # Given
    monkeypatch.setattr(
        no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY
    )
    spy_hook = MagicMock(spec=Hook)
    no_op_provider_client.add_hooks([spy_hook])

    # When
    flag_details_sync = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )
    spy_hook.error.assert_called_once()
    # Reset so the call counts asserted below cover only the async evaluation.
    spy_hook.reset_mock()
    flag_details_async = await no_op_provider_client.get_boolean_details_async(
        flag_key="Key", default_value=True
    )

    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        assert flag_details.value
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY
    spy_hook.error.assert_called_once()
    spy_hook.finally_after.assert_called_once()


# Requirement 1.7.7
@pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
    no_op_provider_client, monkeypatch
):
    """With status FATAL, evaluations return a PROVIDER_FATAL error.

    Mirrors the NOT_READY test: the status accessor is monkeypatched and the
    spy hook's ``error`` stage must run once per evaluation.
    """
    # Given
    monkeypatch.setattr(
        no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL
    )
    spy_hook = MagicMock(spec=Hook)
    no_op_provider_client.add_hooks([spy_hook])

    # When
    flag_details_sync = no_op_provider_client.get_boolean_details(
        flag_key="Key", default_value=True
    )
    spy_hook.error.assert_called_once()
    # Reset so the call counts asserted below cover only the async evaluation.
    spy_hook.reset_mock()
    flag_details_async = await no_op_provider_client.get_boolean_details_async(
        flag_key="Key", default_value=True
    )

    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        assert flag_details.value
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
    spy_hook.error.assert_called_once()
    spy_hook.finally_after.assert_called_once()


@pytest.mark.asyncio
async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code():
    """A resolution carrying an error code triggers the ``error`` hook stage.

    The provider is a mock whose boolean resolvers (sync and async) return a
    resolution with ``reason=ERROR`` and ``error_code=PROVIDER_FATAL`` instead
    of raising.
    """
    # Given
    spy_hook = MagicMock(spec=Hook)
    provider = MagicMock(spec=FeatureProvider)
    provider.get_provider_hooks.return_value = []
    mock_resolution = FlagResolutionDetails(
        value=True,
        reason=Reason.ERROR,
        error_code=ErrorCode.PROVIDER_FATAL,
        error_message="This is an error message",
    )
    provider.resolve_boolean_details.return_value = mock_resolution
    provider.resolve_boolean_details_async.return_value = mock_resolution
    set_provider(provider)
    client = get_client()
    client.add_hooks([spy_hook])

    # When
    flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True)
    spy_hook.error.assert_called_once()
    # Reset so the final assertion covers only the async evaluation.
    spy_hook.reset_mock()
    flag_details_async = await client.get_boolean_details_async(
        flag_key="Key", default_value=True
    )

    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        assert flag_details.value
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
    spy_hook.error.assert_called_once()


@pytest.mark.asyncio
async def test_client_type_mismatch_exceptions():
    """A string default passed to the boolean accessors yields TYPE_MISMATCH."""
    # Given
    client = get_client()

    # When
    flag_details_sync = client.get_boolean_details(
        flag_key="Key", default_value="type mismatch"
    )
    flag_details_async = await client.get_boolean_details_async(
        flag_key="Key", default_value="type mismatch"
    )

    # Then
    for flag_details in [flag_details_sync, flag_details_async]:
        assert flag_details is not None
        assert flag_details.value
        assert flag_details.reason == Reason.ERROR
        assert flag_details.error_code == ErrorCode.TYPE_MISMATCH


@pytest.mark.asyncio
async def test_typecheck_flag_value_general_error():
    """``_typecheck_flag_value`` returns a GENERAL error for an unknown type."""
    # Given
    flag_value = "A"
    flag_type = None

    # When
    err = _typecheck_flag_value(value=flag_value, flag_type=flag_type)

    # Then
    assert err.error_code == ErrorCode.GENERAL
    assert err.error_message == "Unknown flag type"


@pytest.mark.asyncio
async def test_typecheck_flag_value_type_mismatch_error():
    """``_typecheck_flag_value`` returns TYPE_MISMATCH for a str given BOOLEAN."""
    # Given
    flag_value = "A"
    flag_type = FlagType.BOOLEAN

    # When
    err = _typecheck_flag_value(value=flag_value, flag_type=flag_type)

    # Then
    assert err.error_code == ErrorCode.TYPE_MISMATCH
    # NOTE(review): the previous literal ("Expected type but got ") looked like
    # the class reprs had been stripped as markup; restored here assuming the
    # message embeds repr(bool) and repr(str) — confirm against the
    # implementation of _typecheck_flag_value.
    assert err.error_message == "Expected type <class 'bool'> but got <class 'str'>"


def test_provider_events():
    """Handlers on the default client fire once per event type.

    A second provider registered under "my-domain" emits the same events;
    the ``assert_called_once*`` checks verify the default client's handlers
    are not invoked a second time by it.
    """
    # Given
    provider = NoOpProvider()
    set_provider(provider)
    other_provider = NoOpProvider()
    set_provider(other_provider, "my-domain")
    provider_details = ProviderEventDetails(message="message")
    details = EventDetails.from_provider_event_details(
        provider.get_metadata().name, provider_details
    )

    def emit_all_events(source):
        # Emit every non-READY event once from the given provider.
        source.emit_provider_configuration_changed(provider_details)
        source.emit_provider_error(provider_details)
        source.emit_provider_stale(provider_details)

    spy = MagicMock()
    client = get_client()
    client.add_handler(ProviderEvent.PROVIDER_READY, spy.provider_ready)
    client.add_handler(
        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
    )
    client.add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error)
    client.add_handler(ProviderEvent.PROVIDER_STALE, spy.provider_stale)

    # When
    emit_all_events(provider)
    emit_all_events(other_provider)

    # Then
    # NOTE: provider_ready is called immediately after adding the handler
    spy.provider_ready.assert_called_once()
    spy.provider_configuration_changed.assert_called_once_with(details)
    spy.provider_error.assert_called_once_with(details)
    spy.provider_stale.assert_called_once_with(details)


def test_add_remove_event_handler():
    """A handler that was added and then removed is not invoked on emit."""
    # Given
    provider = NoOpProvider()
    set_provider(provider)
    spy = MagicMock()
    client = get_client()
    client.add_handler(
        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
    )
    client.remove_handler(
        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
    )
    provider_details = ProviderEventDetails(message="message")

    # When
    provider.emit_provider_configuration_changed(provider_details)

    # Then
    spy.provider_configuration_changed.assert_not_called()


# Requirement 5.1.2, Requirement 5.1.3
def test_provider_event_late_binding():
    """After the domain's provider is swapped, only the new one reaches handlers.

    The handler is added while the first provider is bound to "my-domain";
    both providers then emit, and the handler must be called exactly once
    with the details built from the replacement provider's event.
    """
    # Given
    provider = NoOpProvider()
    set_provider(provider, "my-domain")
    other_provider = NoOpProvider()
    spy = MagicMock()
    client = get_client("my-domain")
    client.add_handler(
        ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, spy.provider_configuration_changed
    )
    set_provider(other_provider, "my-domain")
    provider_details = ProviderEventDetails(message="message from provider")
    other_provider_details = ProviderEventDetails(message="message from other provider")
    details = EventDetails.from_provider_event_details(
        other_provider.get_metadata().name, other_provider_details
    )

    # When
    provider.emit_provider_configuration_changed(provider_details)
    other_provider.emit_provider_configuration_changed(other_provider_details)

    # Then
    spy.provider_configuration_changed.assert_called_once_with(details)


def test_client_handlers_thread_safety():
    """Adding handlers while events are emitted concurrently must not raise.

    One worker registers handlers on freshly named clients while another
    emits configuration-changed events; ``Future.result()`` re-raises any
    exception that occurred in either worker.
    """
    provider = NoOpProvider()
    set_provider(provider)

    def add_handlers_task():
        def handler(*args, **kwargs):
            time.sleep(0.005)

        for _ in range(10):
            time.sleep(0.01)
            client = get_client(str(uuid.uuid4()))
            client.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler)

    def emit_events_task():
        for _ in range(10):
            time.sleep(0.01)
            provider.emit_provider_configuration_changed(ProviderEventDetails())

    with ThreadPoolExecutor(max_workers=2) as executor:
        f1 = executor.submit(add_handlers_task)
        f2 = executor.submit(emit_events_task)
        f1.result()
        f2.result()


def test_client_should_merge_contexts():
    """Global, transaction, client and invocation contexts are merged.

    The provider's ``resolve_boolean_details`` is wrapped in a mock so the
    merged context handed to it can be inspected: attributes from all four
    levels must be present and the invocation targeting key must win.
    """
    api.clear_hooks()
    api.set_transaction_context_propagator(ContextVarsTransactionContextPropagator())

    provider = NoOpProvider()
    provider.resolve_boolean_details = MagicMock(wraps=provider.resolve_boolean_details)
    api.set_provider(provider)

    # Global evaluation context
    global_context = EvaluationContext(
        targeting_key="global", attributes={"global_attr": "global_value"}
    )
    api.set_evaluation_context(global_context)

    # Transaction context
    transaction_context = EvaluationContext(
        targeting_key="transaction",
        attributes={"transaction_attr": "transaction_value"},
    )
    api.set_transaction_context(transaction_context)

    # Client-specific context
    client_context = EvaluationContext(
        targeting_key="client", attributes={"client_attr": "client_value"}
    )
    client = OpenFeatureClient(domain=None, version=None, context=client_context)

    # Invocation-specific context
    invocation_context = EvaluationContext(
        targeting_key="invocation", attributes={"invocation_attr": "invocation_value"}
    )
    flag_input = "flag"
    flag_default = False

    client.get_boolean_details(flag_input, flag_default, invocation_context)

    # Retrieve the call arguments; only the keyword arguments are inspected.
    _, kwargs = provider.resolve_boolean_details.call_args
    flag_key, default_value, context = (
        kwargs["flag_key"],
        kwargs["default_value"],
        kwargs["evaluation_context"],
    )

    assert flag_key == flag_input
    assert default_value is flag_default
    assert context.targeting_key == "invocation"  # Last one in the merge chain
    assert context.attributes["global_attr"] == "global_value"
    assert context.attributes["transaction_attr"] == "transaction_value"
    assert context.attributes["client_attr"] == "client_value"
    assert context.attributes["invocation_attr"] == "invocation_value"