feat: Add async functionality to providers (#413)

Signed-off-by: leohoare <leo@insight.co>
This commit is contained in:
Leo 2025-02-07 04:30:54 +11:00 committed by GitHub
parent 154d8345e7
commit 86e7c07112
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1008 additions and 140 deletions

View File

@ -316,6 +316,25 @@ async def some_endpoint():
return create_response()
```
### Asynchronous Feature Retrieval
The OpenFeature API supports asynchronous calls, enabling non-blocking feature evaluations for improved performance, especially useful in concurrent or latency-sensitive scenarios. If a provider *hasn't* implemented asynchronous calls, the client can still be used asynchronously, but calls will be blocking (synchronous).
```python
import asyncio
from openfeature import api
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
my_flags = { "v2_enabled": InMemoryFlag("on", {"on": True, "off": False}) }
api.set_provider(InMemoryProvider(my_flags))
client = api.get_client()
flag_value = await client.get_boolean_value_async("v2_enabled", False) # API calls are suffixed by _async
print("Value: " + str(flag_value))
```
See the [develop a provider](#develop-a-provider) for how to support asynchronous functionality in providers.
### Shutdown
The OpenFeature API provides a shutdown function to perform a cleanup of all registered providers. This should only be called when your application is in the process of shutting down.
@ -390,6 +409,56 @@ class MyProvider(AbstractProvider):
...
```
Providers can also be extended to support async functionality.
To support add asynchronous calls to a provider:
* Implement the `AbstractProvider` as shown above.
* Define asynchronous calls for each data type.
```python
class MyProvider(AbstractProvider):
...
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
...
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
...
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
...
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
...
async def resolve_object_details_async(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
...
```
> Built a new provider? [Let us know](https://github.com/open-feature/openfeature.dev/issues/new?assignees=&labels=provider&projects=&template=document-provider.yaml&title=%5BProvider%5D%3A+) so we can add it to the docs!
### Develop a hook

View File

@ -20,7 +20,7 @@ from openfeature.flag_evaluation import (
FlagType,
Reason,
)
from openfeature.hook import Hook, HookContext
from openfeature.hook import Hook, HookContext, HookHints
from openfeature.hook._hook_support import (
after_all_hooks,
after_hooks,
@ -55,6 +55,28 @@ GetDetailCallable = typing.Union[
FlagResolutionDetails[typing.Union[dict, list]],
],
]
GetDetailCallableAsync = typing.Union[
typing.Callable[
[str, bool, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[bool]],
],
typing.Callable[
[str, int, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[int]],
],
typing.Callable[
[str, float, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[float]],
],
typing.Callable[
[str, str, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[str]],
],
typing.Callable[
[str, typing.Union[dict, list], typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[typing.Union[dict, list]]],
],
]
TypeMap = typing.Dict[
FlagType,
typing.Union[
@ -113,6 +135,21 @@ class OpenFeatureClient:
flag_evaluation_options,
).value
async def get_boolean_value_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> bool:
details = await self.get_boolean_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_boolean_details(
self,
flag_key: str,
@ -128,6 +165,21 @@ class OpenFeatureClient:
flag_evaluation_options,
)
async def get_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[bool]:
return await self.evaluate_flag_details_async(
FlagType.BOOLEAN,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_string_value(
self,
flag_key: str,
@ -142,6 +194,21 @@ class OpenFeatureClient:
flag_evaluation_options,
).value
async def get_string_value_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> str:
details = await self.get_string_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_string_details(
self,
flag_key: str,
@ -157,6 +224,21 @@ class OpenFeatureClient:
flag_evaluation_options,
)
async def get_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[str]:
return await self.evaluate_flag_details_async(
FlagType.STRING,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_integer_value(
self,
flag_key: str,
@ -171,6 +253,21 @@ class OpenFeatureClient:
flag_evaluation_options,
).value
async def get_integer_value_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> int:
details = await self.get_integer_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_integer_details(
self,
flag_key: str,
@ -186,6 +283,21 @@ class OpenFeatureClient:
flag_evaluation_options,
)
async def get_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[int]:
return await self.evaluate_flag_details_async(
FlagType.INTEGER,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_float_value(
self,
flag_key: str,
@ -200,6 +312,21 @@ class OpenFeatureClient:
flag_evaluation_options,
).value
async def get_float_value_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> float:
details = await self.get_float_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_float_details(
self,
flag_key: str,
@ -215,6 +342,21 @@ class OpenFeatureClient:
flag_evaluation_options,
)
async def get_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[float]:
return await self.evaluate_flag_details_async(
FlagType.FLOAT,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_object_value(
self,
flag_key: str,
@ -229,6 +371,21 @@ class OpenFeatureClient:
flag_evaluation_options,
).value
async def get_object_value_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> typing.Union[dict, list]:
details = await self.get_object_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_object_details(
self,
flag_key: str,
@ -244,26 +401,35 @@ class OpenFeatureClient:
flag_evaluation_options,
)
def evaluate_flag_details( # noqa: PLR0915
async def get_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Union[dict, list]]:
return await self.evaluate_flag_details_async(
FlagType.OBJECT,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def _establish_hooks_and_provider(
self,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Any]:
"""
Evaluate the flag requested by the user from the clients provider.
:param flag_type: the type of the flag being returned
:param flag_key: the string key of the selected flag
:param default_value: backup value returned if no result found by the provider
:param evaluation_context: Information for the purposes of flag evaluation
:param flag_evaluation_options: Additional flag evaluation information
:return: a FlagEvaluationDetails object with the fully evaluated flag from a
provider
"""
evaluation_context: typing.Optional[EvaluationContext],
flag_evaluation_options: typing.Optional[FlagEvaluationOptions],
) -> typing.Tuple[
FeatureProvider,
HookContext,
HookHints,
typing.List[Hook],
typing.List[Hook],
]:
if evaluation_context is None:
evaluation_context = EvaluationContext()
@ -295,54 +461,179 @@ class OpenFeatureClient:
reversed_merged_hooks = merged_hooks[:]
reversed_merged_hooks.reverse()
try:
status = self.get_provider_status()
if status == ProviderStatus.NOT_READY:
error_hooks(
flag_type,
hook_context,
ProviderNotReadyError(),
reversed_merged_hooks,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_NOT_READY,
)
return flag_evaluation
if status == ProviderStatus.FATAL:
error_hooks(
flag_type,
hook_context,
ProviderFatalError(),
reversed_merged_hooks,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_FATAL,
)
return flag_evaluation
return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks
# https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
# Any resulting evaluation context from a before hook will overwrite
# duplicate fields defined globally, on the client, or in the invocation.
# Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context
invocation_context = before_hooks(
flag_type, hook_context, merged_hooks, hook_hints
)
def _assert_provider_status(
self,
) -> None:
status = self.get_provider_status()
if status == ProviderStatus.NOT_READY:
raise ProviderNotReadyError()
if status == ProviderStatus.FATAL:
raise ProviderFatalError()
return None
def _before_hooks_and_merge_context(
self,
flag_type: FlagType,
hook_context: HookContext,
merged_hooks: typing.List[Hook],
hook_hints: HookHints,
evaluation_context: typing.Optional[EvaluationContext],
) -> EvaluationContext:
# https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
# Any resulting evaluation context from a before hook will overwrite
# duplicate fields defined globally, on the client, or in the invocation.
# Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context
invocation_context = before_hooks(
flag_type, hook_context, merged_hooks, hook_hints
)
if evaluation_context:
invocation_context = invocation_context.merge(ctx2=evaluation_context)
# Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context
merged_context = (
api.get_evaluation_context()
.merge(api.get_transaction_context())
.merge(self.context)
.merge(invocation_context)
# Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context
merged_context = (
api.get_evaluation_context()
.merge(api.get_transaction_context())
.merge(self.context)
.merge(invocation_context)
)
return merged_context
async def evaluate_flag_details_async(
self,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Any]:
"""
Evaluate the flag requested by the user from the clients provider.
:param flag_type: the type of the flag being returned
:param flag_key: the string key of the selected flag
:param default_value: backup value returned if no result found by the provider
:param evaluation_context: Information for the purposes of flag evaluation
:param flag_evaluation_options: Additional flag evaluation information
:return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a
provider
"""
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
)
try:
self._assert_provider_status()
merged_context = self._before_hooks_and_merge_context(
flag_type,
hook_context,
merged_hooks,
hook_hints,
evaluation_context,
)
flag_evaluation = await self._create_provider_evaluation_async(
provider,
flag_type,
flag_key,
default_value,
merged_context,
)
after_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
hook_hints,
)
return flag_evaluation
except OpenFeatureError as err:
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=err.error_code,
error_message=err.error_message,
)
return flag_evaluation
# Catch any type of exception here since the user can provide any exception
# in the error hooks
except Exception as err: # pragma: no cover
logger.exception(
"Unable to correctly evaluate flag with key: '%s'", flag_key
)
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
error_message = getattr(err, "error_message", str(err))
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.GENERAL,
error_message=error_message,
)
return flag_evaluation
finally:
after_all_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
hook_hints,
)
def evaluate_flag_details(
self,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Any]:
"""
Evaluate the flag requested by the user from the clients provider.
:param flag_type: the type of the flag being returned
:param flag_key: the string key of the selected flag
:param default_value: backup value returned if no result found by the provider
:param evaluation_context: Information for the purposes of flag evaluation
:param flag_evaluation_options: Additional flag evaluation information
:return: a FlagEvaluationDetails object with the fully evaluated flag from a
provider
"""
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
)
try:
self._assert_provider_status()
merged_context = self._before_hooks_and_merge_context(
flag_type,
hook_context,
merged_hooks,
hook_hints,
evaluation_context,
)
flag_evaluation = self._create_provider_evaluation(
@ -402,6 +693,48 @@ class OpenFeatureClient:
hook_hints,
)
async def _create_provider_evaluation_async(
self,
provider: FeatureProvider,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagEvaluationDetails[typing.Any]:
args = (
flag_key,
default_value,
evaluation_context,
)
get_details_callables_async: typing.Mapping[
FlagType, GetDetailCallableAsync
] = {
FlagType.BOOLEAN: provider.resolve_boolean_details_async,
FlagType.INTEGER: provider.resolve_integer_details_async,
FlagType.FLOAT: provider.resolve_float_details_async,
FlagType.OBJECT: provider.resolve_object_details_async,
FlagType.STRING: provider.resolve_string_details_async,
}
get_details_callable = get_details_callables_async.get(flag_type)
if not get_details_callable:
raise GeneralError(error_message="Unknown flag type")
resolution = await get_details_callable(*args)
resolution.raise_for_error()
# we need to check the get_args to be compatible with union types.
_typecheck_flag_value(resolution.value, flag_type)
return FlagEvaluationDetails(
flag_key=flag_key,
value=resolution.value,
variant=resolution.variant,
flag_metadata=resolution.flag_metadata or {},
reason=resolution.reason,
error_code=resolution.error_code,
error_message=resolution.error_message,
)
def _create_provider_evaluation(
self,
provider: FeatureProvider,

View File

@ -47,6 +47,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]: ...
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]: ...
def resolve_string_details(
self,
flag_key: str,
@ -54,6 +61,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]: ...
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]: ...
def resolve_integer_details(
self,
flag_key: str,
@ -61,6 +75,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]: ...
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]: ...
def resolve_float_details(
self,
flag_key: str,
@ -68,6 +89,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]: ...
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]: ...
def resolve_object_details(
self,
flag_key: str,
@ -75,6 +103,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]: ...
async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]: ...
class AbstractProvider(FeatureProvider):
def attach(
@ -111,6 +146,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[bool]:
pass
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return self.resolve_boolean_details(flag_key, default_value, evaluation_context)
@abstractmethod
def resolve_string_details(
self,
@ -120,6 +163,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[str]:
pass
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return self.resolve_string_details(flag_key, default_value, evaluation_context)
@abstractmethod
def resolve_integer_details(
self,
@ -129,6 +180,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[int]:
pass
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return self.resolve_integer_details(flag_key, default_value, evaluation_context)
@abstractmethod
def resolve_float_details(
self,
@ -138,6 +197,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[float]:
pass
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return self.resolve_float_details(flag_key, default_value, evaluation_context)
@abstractmethod
def resolve_object_details(
self,
@ -147,6 +214,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[typing.Union[dict, list]]:
pass
async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self.resolve_object_details(flag_key, default_value, evaluation_context)
def emit_provider_ready(self, details: ProviderEventDetails) -> None:
self.emit(ProviderEvent.PROVIDER_READY, details)

View File

@ -76,6 +76,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[bool]:
return self._resolve(flag_key, evaluation_context)
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_string_details(
self,
flag_key: str,
@ -84,6 +92,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[str]:
return self._resolve(flag_key, evaluation_context)
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_integer_details(
self,
flag_key: str,
@ -92,6 +108,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[int]:
return self._resolve(flag_key, evaluation_context)
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_float_details(
self,
flag_key: str,
@ -100,6 +124,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[float]:
return self._resolve(flag_key, evaluation_context)
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_object_details(
self,
flag_key: str,
@ -108,6 +140,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self._resolve(flag_key, evaluation_context)
async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return await self._resolve_async(flag_key, evaluation_context)
def _resolve(
self,
flag_key: str,
@ -117,3 +157,10 @@ class InMemoryProvider(AbstractProvider):
if flag is None:
raise FlagNotFoundError(f"Flag '{flag_key}' not found")
return flag.resolve(evaluation_context)
async def _resolve_async(
self,
flag_key: str,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[V]:
return self._resolve(flag_key, evaluation_context)

View File

@ -17,16 +17,20 @@ def test_should_return_in_memory_provider_metadata():
assert metadata.name == "In-Memory Provider"
def test_should_handle_unknown_flags_correctly():
@pytest.mark.asyncio
async def test_should_handle_unknown_flags_correctly():
# Given
provider = InMemoryProvider({})
# When
with pytest.raises(FlagNotFoundError):
provider.resolve_boolean_details(flag_key="Key", default_value=True)
with pytest.raises(FlagNotFoundError):
await provider.resolve_integer_details_async(flag_key="Key", default_value=1)
# Then
def test_calls_context_evaluator_if_present():
@pytest.mark.asyncio
async def test_calls_context_evaluator_if_present():
# Given
def context_evaluator(flag: InMemoryFlag, evaluation_context: dict):
return FlagResolutionDetails(
@ -44,57 +48,81 @@ def test_calls_context_evaluator_if_present():
}
)
# When
flag = provider.resolve_boolean_details(flag_key="Key", default_value=False)
flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False)
flag_async = await provider.resolve_boolean_details_async(
flag_key="Key", default_value=False
)
# Then
assert flag is not None
assert flag.value is False
assert isinstance(flag.value, bool)
assert flag.reason == Reason.TARGETING_MATCH
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value is False
assert isinstance(flag.value, bool)
assert flag.reason == Reason.TARGETING_MATCH
def test_should_resolve_boolean_flag_from_in_memory():
@pytest.mark.asyncio
async def test_should_resolve_boolean_flag_from_in_memory():
# Given
provider = InMemoryProvider(
{"Key": InMemoryFlag("true", {"true": True, "false": False})}
)
# When
flag = provider.resolve_boolean_details(flag_key="Key", default_value=False)
flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False)
flag_async = await provider.resolve_boolean_details_async(
flag_key="Key", default_value=False
)
# Then
assert flag is not None
assert flag.value is True
assert isinstance(flag.value, bool)
assert flag.variant == "true"
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value is True
assert isinstance(flag.value, bool)
assert flag.variant == "true"
def test_should_resolve_integer_flag_from_in_memory():
@pytest.mark.asyncio
async def test_should_resolve_integer_flag_from_in_memory():
# Given
provider = InMemoryProvider(
{"Key": InMemoryFlag("hundred", {"zero": 0, "hundred": 100})}
)
# When
flag = provider.resolve_integer_details(flag_key="Key", default_value=0)
flag_sync = provider.resolve_integer_details(flag_key="Key", default_value=0)
flag_async = await provider.resolve_integer_details_async(
flag_key="Key", default_value=0
)
# Then
assert flag is not None
assert flag.value == 100
assert isinstance(flag.value, Number)
assert flag.variant == "hundred"
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value == 100
assert isinstance(flag.value, Number)
assert flag.variant == "hundred"
def test_should_resolve_float_flag_from_in_memory():
@pytest.mark.asyncio
async def test_should_resolve_float_flag_from_in_memory():
# Given
provider = InMemoryProvider(
{"Key": InMemoryFlag("ten", {"zero": 0.0, "ten": 10.23})}
)
# When
flag = provider.resolve_float_details(flag_key="Key", default_value=0.0)
flag_sync = provider.resolve_float_details(flag_key="Key", default_value=0.0)
flag_async = await provider.resolve_float_details_async(
flag_key="Key", default_value=0.0
)
# Then
assert flag is not None
assert flag.value == 10.23
assert isinstance(flag.value, Number)
assert flag.variant == "ten"
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value == 10.23
assert isinstance(flag.value, Number)
assert flag.variant == "ten"
def test_should_resolve_string_flag_from_in_memory():
@pytest.mark.asyncio
async def test_should_resolve_string_flag_from_in_memory():
# Given
provider = InMemoryProvider(
{
@ -105,29 +133,41 @@ def test_should_resolve_string_flag_from_in_memory():
}
)
# When
flag = provider.resolve_string_details(flag_key="Key", default_value="Default")
flag_sync = provider.resolve_string_details(flag_key="Key", default_value="Default")
flag_async = await provider.resolve_string_details_async(
flag_key="Key", default_value="Default"
)
# Then
assert flag is not None
assert flag.value == "String"
assert isinstance(flag.value, str)
assert flag.variant == "stringVariant"
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value == "String"
assert isinstance(flag.value, str)
assert flag.variant == "stringVariant"
def test_should_resolve_list_flag_from_in_memory():
@pytest.mark.asyncio
async def test_should_resolve_list_flag_from_in_memory():
# Given
provider = InMemoryProvider(
{"Key": InMemoryFlag("twoItems", {"empty": [], "twoItems": ["item1", "item2"]})}
)
# When
flag = provider.resolve_object_details(flag_key="Key", default_value=[])
flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[])
flag_async = await provider.resolve_object_details_async(
flag_key="Key", default_value=[]
)
# Then
assert flag is not None
assert flag.value == ["item1", "item2"]
assert isinstance(flag.value, list)
assert flag.variant == "twoItems"
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value == ["item1", "item2"]
assert isinstance(flag.value, list)
assert flag.variant == "twoItems"
def test_should_resolve_object_flag_from_in_memory():
@pytest.mark.asyncio
async def test_should_resolve_object_flag_from_in_memory():
# Given
return_value = {
"String": "string",
@ -138,9 +178,12 @@ def test_should_resolve_object_flag_from_in_memory():
{"Key": InMemoryFlag("obj", {"obj": return_value, "empty": {}})}
)
# When
flag = provider.resolve_object_details(flag_key="Key", default_value={})
flag_sync = provider.resolve_object_details(flag_key="Key", default_value={})
flag_async = provider.resolve_object_details(flag_key="Key", default_value={})
# Then
assert flag is not None
assert flag.value == return_value
assert isinstance(flag.value, dict)
assert flag.variant == "obj"
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value == return_value
assert isinstance(flag.value, dict)
assert flag.variant == "obj"

View File

@ -0,0 +1,197 @@
from typing import Optional, Union
import pytest
from openfeature.api import get_client, set_provider
from openfeature.evaluation_context import EvaluationContext
from openfeature.flag_evaluation import FlagResolutionDetails
from openfeature.provider import AbstractProvider, Metadata
class SynchronousProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="SynchronousProvider")
def get_provider_hooks(self):
return []
def resolve_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)
def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return FlagResolutionDetails(value="string")
def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return FlagResolutionDetails(value=1)
def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return FlagResolutionDetails(value=10.0)
def resolve_object_details(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
return FlagResolutionDetails(value={"key": "value"})
@pytest.mark.parametrize(
"flag_type, default_value, get_method",
(
(bool, True, "get_boolean_value_async"),
(str, "string", "get_string_value_async"),
(int, 1, "get_integer_value_async"),
(float, 10.0, "get_float_value_async"),
(
dict,
{"key": "value"},
"get_object_value_async",
),
),
)
@pytest.mark.asyncio
async def test_sync_provider_can_be_called_async(flag_type, default_value, get_method):
# Given
set_provider(SynchronousProvider(), "SynchronousProvider")
client = get_client("SynchronousProvider")
# When
async_callable = getattr(client, get_method)
flag = await async_callable(flag_key="Key", default_value=default_value)
# Then
assert flag is not None
assert flag == default_value
assert isinstance(flag, flag_type)
@pytest.mark.asyncio
async def test_sync_provider_can_be_extended_async():
# Given
class ExtendedAsyncProvider(SynchronousProvider):
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=False)
set_provider(ExtendedAsyncProvider(), "ExtendedAsyncProvider")
client = get_client("ExtendedAsyncProvider")
# When
flag = await client.get_boolean_value_async(flag_key="Key", default_value=True)
# Then
assert flag is not None
assert flag is False
# We're not allowing providers to only have async methods
def test_sync_methods_enforced_for_async_providers():
# Given
class AsyncProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="AsyncProvider")
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)
# When
with pytest.raises(TypeError) as exception:
set_provider(AsyncProvider(), "AsyncProvider")
# Then
# assert
exception_message = str(exception.value)
assert exception_message.startswith(
"Can't instantiate abstract class AsyncProvider"
)
assert exception_message.__contains__("resolve_boolean_details")
@pytest.mark.asyncio
async def test_async_provider_not_implemented_exception_workaround():
# Given
class SyncNotImplementedProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="AsyncProvider")
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)
def resolve_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
raise NotImplementedError("Use the async method")
def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
raise NotImplementedError("Use the async method")
def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
raise NotImplementedError("Use the async method")
def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
raise NotImplementedError("Use the async method")
def resolve_object_details(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
raise NotImplementedError("Use the async method")
# When
set_provider(SyncNotImplementedProvider(), "SyncNotImplementedProvider")
client = get_client("SyncNotImplementedProvider")
flag = await client.get_boolean_value_async(flag_key="Key", default_value=False)
# Then
assert flag is not None
assert flag is True

View File

@ -1,3 +1,4 @@
import asyncio
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
@ -7,7 +8,7 @@ import pytest
from openfeature import api
from openfeature.api import add_hooks, clear_hooks, get_client, set_provider
from openfeature.client import OpenFeatureClient
from openfeature.client import GeneralError, OpenFeatureClient, _typecheck_flag_value
from openfeature.evaluation_context import EvaluationContext
from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails
from openfeature.exception import ErrorCode, OpenFeatureError
@ -23,9 +24,13 @@ from openfeature.transaction_context import ContextVarsTransactionContextPropaga
"flag_type, default_value, get_method",
(
(bool, True, "get_boolean_value"),
(bool, True, "get_boolean_value_async"),
(str, "String", "get_string_value"),
(str, "String", "get_string_value_async"),
(int, 100, "get_integer_value"),
(int, 100, "get_integer_value_async"),
(float, 10.23, "get_float_value"),
(float, 10.23, "get_float_value_async"),
(
dict,
{
@ -35,21 +40,38 @@ from openfeature.transaction_context import ContextVarsTransactionContextPropaga
},
"get_object_value",
),
(
dict,
{
"String": "string",
"Number": 2,
"Boolean": True,
},
"get_object_value_async",
),
(
list,
["string1", "string2"],
"get_object_value",
),
(
list,
["string1", "string2"],
"get_object_value_async",
),
),
)
def test_should_get_flag_value_based_on_method_type(
@pytest.mark.asyncio
async def test_should_get_flag_value_based_on_method_type(
flag_type, default_value, get_method, no_op_provider_client
):
# Given
# When
flag = getattr(no_op_provider_client, get_method)(
flag_key="Key", default_value=default_value
)
method = getattr(no_op_provider_client, get_method)
if asyncio.iscoroutinefunction(method):
flag = await method(flag_key="Key", default_value=default_value)
else:
flag = method(flag_key="Key", default_value=default_value)
# Then
assert flag is not None
assert flag == default_value
@ -60,9 +82,13 @@ def test_should_get_flag_value_based_on_method_type(
"flag_type, default_value, get_method",
(
(bool, True, "get_boolean_details"),
(bool, True, "get_boolean_details_async"),
(str, "String", "get_string_details"),
(str, "String", "get_string_details_async"),
(int, 100, "get_integer_details"),
(int, 100, "get_integer_details_async"),
(float, 10.23, "get_float_details"),
(float, 10.23, "get_float_details_async"),
(
dict,
{
@ -72,38 +98,62 @@ def test_should_get_flag_value_based_on_method_type(
},
"get_object_details",
),
(
dict,
{
"String": "string",
"Number": 2,
"Boolean": True,
},
"get_object_details_async",
),
(
list,
["string1", "string2"],
"get_object_details",
),
(
list,
["string1", "string2"],
"get_object_details_async",
),
),
)
def test_should_get_flag_detail_based_on_method_type(
@pytest.mark.asyncio
async def test_should_get_flag_detail_based_on_method_type(
flag_type, default_value, get_method, no_op_provider_client
):
# Given
# When
flag = getattr(no_op_provider_client, get_method)(
flag_key="Key", default_value=default_value
)
method = getattr(no_op_provider_client, get_method)
if asyncio.iscoroutinefunction(method):
flag = await method(flag_key="Key", default_value=default_value)
else:
flag = method(flag_key="Key", default_value=default_value)
# Then
assert flag is not None
assert flag.value == default_value
assert isinstance(flag.value, flag_type)
def test_should_raise_exception_when_invalid_flag_type_provided(no_op_provider_client):
@pytest.mark.asyncio
async def test_should_raise_exception_when_invalid_flag_type_provided(
no_op_provider_client,
):
# Given
# When
flag = no_op_provider_client.evaluate_flag_details(
flag_sync = no_op_provider_client.evaluate_flag_details(
flag_type=None, flag_key="Key", default_value=True
)
flag_async = await no_op_provider_client.evaluate_flag_details_async(
flag_type=None, flag_key="Key", default_value=True
)
# Then
assert flag.value
assert flag.error_message == "Unknown flag type"
assert flag.error_code == ErrorCode.GENERAL
assert flag.reason == Reason.ERROR
for flag in [flag_sync, flag_async]:
assert flag.value
assert flag.error_message == "Unknown flag type"
assert flag.error_code == ErrorCode.GENERAL
assert flag.reason == Reason.ERROR
def test_should_pass_flag_metadata_from_resolution_to_evaluation_details():
@ -202,7 +252,8 @@ def test_should_define_a_provider_status_accessor(no_op_provider_client):
# Requirement 1.7.6
def test_should_shortcircuit_if_provider_is_not_ready(
@pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_not_ready(
no_op_provider_client, monkeypatch
):
# Given
@ -212,20 +263,27 @@ def test_should_shortcircuit_if_provider_is_not_ready(
spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook])
# When
flag_details = no_op_provider_client.get_boolean_details(
flag_details_sync = no_op_provider_client.get_boolean_details(
flag_key="Key", default_value=True
)
spy_hook.error.assert_called_once()
spy_hook.reset_mock()
flag_details_async = await no_op_provider_client.get_boolean_details_async(
flag_key="Key", default_value=True
)
# Then
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY
spy_hook.error.assert_called_once()
spy_hook.finally_after.assert_called_once()
# Requirement 1.7.7
def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
@pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
no_op_provider_client, monkeypatch
):
# Given
@ -235,40 +293,86 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook])
# When
flag_details = no_op_provider_client.get_boolean_details(
flag_details_sync = no_op_provider_client.get_boolean_details(
flag_key="Key", default_value=True
)
spy_hook.error.assert_called_once()
spy_hook.reset_mock()
flag_details_async = await no_op_provider_client.get_boolean_details_async(
flag_key="Key", default_value=True
)
# Then
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
spy_hook.error.assert_called_once()
spy_hook.finally_after.assert_called_once()
def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code():
@pytest.mark.asyncio
async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code():
# Given
spy_hook = MagicMock(spec=Hook)
provider = MagicMock(spec=FeatureProvider)
provider.get_provider_hooks.return_value = []
provider.resolve_boolean_details.return_value = FlagResolutionDetails(
mock_resolution = FlagResolutionDetails(
value=True,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_FATAL,
error_message="This is an error message",
)
provider.resolve_boolean_details.return_value = mock_resolution
provider.resolve_boolean_details_async.return_value = mock_resolution
set_provider(provider)
client = get_client()
client.add_hooks([spy_hook])
# When
flag_details = client.get_boolean_details(flag_key="Key", default_value=True)
# Then
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True)
spy_hook.error.assert_called_once()
spy_hook.reset_mock()
flag_details_async = await client.get_boolean_details_async(
flag_key="Key", default_value=True
)
# Then
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
spy_hook.error.assert_called_once()
@pytest.mark.asyncio
async def test_client_type_mismatch_exceptions():
# Given
client = get_client()
# When
flag_details_sync = client.get_boolean_details(
flag_key="Key", default_value="type mismatch"
)
flag_details_async = await client.get_boolean_details_async(
flag_key="Key", default_value="type mismatch"
)
# Then
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.TYPE_MISMATCH
@pytest.mark.asyncio
async def test_client_general_exception():
# Given
flag_value = "A"
flag_type = None
# When
with pytest.raises(GeneralError) as e:
flag_type = _typecheck_flag_value(flag_value, flag_type)
# Then
assert e.value.error_message == "Unknown flag type"
def test_provider_events():