feat: Add async functionality to providers (#413)

Signed-off-by: leohoare <leo@insight.co>
This commit is contained in:
Leo 2025-02-07 04:30:54 +11:00 committed by GitHub
parent 154d8345e7
commit 86e7c07112
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1008 additions and 140 deletions

View File

@ -316,6 +316,25 @@ async def some_endpoint():
return create_response() return create_response()
``` ```
### Asynchronous Feature Retrieval
The OpenFeature API supports asynchronous calls, enabling non-blocking feature evaluations for improved performance, especially useful in concurrent or latency-sensitive scenarios. If a provider *hasn't* implemented asynchronous calls, the client can still be used asynchronously, but calls will be blocking (synchronous).
```python
import asyncio
from openfeature import api
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
my_flags = { "v2_enabled": InMemoryFlag("on", {"on": True, "off": False}) }
api.set_provider(InMemoryProvider(my_flags))
client = api.get_client()
flag_value = await client.get_boolean_value_async("v2_enabled", False) # API calls are suffixed by _async
print("Value: " + str(flag_value))
```
See the [develop a provider](#develop-a-provider) for how to support asynchronous functionality in providers.
### Shutdown ### Shutdown
The OpenFeature API provides a shutdown function to perform a cleanup of all registered providers. This should only be called when your application is in the process of shutting down. The OpenFeature API provides a shutdown function to perform a cleanup of all registered providers. This should only be called when your application is in the process of shutting down.
@ -390,6 +409,56 @@ class MyProvider(AbstractProvider):
... ...
``` ```
Providers can also be extended to support async functionality.
To support add asynchronous calls to a provider:
* Implement the `AbstractProvider` as shown above.
* Define asynchronous calls for each data type.
```python
class MyProvider(AbstractProvider):
...
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
...
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
...
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
...
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
...
async def resolve_object_details_async(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
...
```
> Built a new provider? [Let us know](https://github.com/open-feature/openfeature.dev/issues/new?assignees=&labels=provider&projects=&template=document-provider.yaml&title=%5BProvider%5D%3A+) so we can add it to the docs! > Built a new provider? [Let us know](https://github.com/open-feature/openfeature.dev/issues/new?assignees=&labels=provider&projects=&template=document-provider.yaml&title=%5BProvider%5D%3A+) so we can add it to the docs!
### Develop a hook ### Develop a hook

View File

@ -20,7 +20,7 @@ from openfeature.flag_evaluation import (
FlagType, FlagType,
Reason, Reason,
) )
from openfeature.hook import Hook, HookContext from openfeature.hook import Hook, HookContext, HookHints
from openfeature.hook._hook_support import ( from openfeature.hook._hook_support import (
after_all_hooks, after_all_hooks,
after_hooks, after_hooks,
@ -55,6 +55,28 @@ GetDetailCallable = typing.Union[
FlagResolutionDetails[typing.Union[dict, list]], FlagResolutionDetails[typing.Union[dict, list]],
], ],
] ]
GetDetailCallableAsync = typing.Union[
typing.Callable[
[str, bool, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[bool]],
],
typing.Callable[
[str, int, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[int]],
],
typing.Callable[
[str, float, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[float]],
],
typing.Callable[
[str, str, typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[str]],
],
typing.Callable[
[str, typing.Union[dict, list], typing.Optional[EvaluationContext]],
typing.Awaitable[FlagResolutionDetails[typing.Union[dict, list]]],
],
]
TypeMap = typing.Dict[ TypeMap = typing.Dict[
FlagType, FlagType,
typing.Union[ typing.Union[
@ -113,6 +135,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
).value ).value
async def get_boolean_value_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> bool:
details = await self.get_boolean_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_boolean_details( def get_boolean_details(
self, self,
flag_key: str, flag_key: str,
@ -128,6 +165,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
) )
async def get_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[bool]:
return await self.evaluate_flag_details_async(
FlagType.BOOLEAN,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_string_value( def get_string_value(
self, self,
flag_key: str, flag_key: str,
@ -142,6 +194,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
).value ).value
async def get_string_value_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> str:
details = await self.get_string_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_string_details( def get_string_details(
self, self,
flag_key: str, flag_key: str,
@ -157,6 +224,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
) )
async def get_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[str]:
return await self.evaluate_flag_details_async(
FlagType.STRING,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_integer_value( def get_integer_value(
self, self,
flag_key: str, flag_key: str,
@ -171,6 +253,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
).value ).value
async def get_integer_value_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> int:
details = await self.get_integer_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_integer_details( def get_integer_details(
self, self,
flag_key: str, flag_key: str,
@ -186,6 +283,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
) )
async def get_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[int]:
return await self.evaluate_flag_details_async(
FlagType.INTEGER,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_float_value( def get_float_value(
self, self,
flag_key: str, flag_key: str,
@ -200,6 +312,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
).value ).value
async def get_float_value_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> float:
details = await self.get_float_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_float_details( def get_float_details(
self, self,
flag_key: str, flag_key: str,
@ -215,6 +342,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
) )
async def get_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[float]:
return await self.evaluate_flag_details_async(
FlagType.FLOAT,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def get_object_value( def get_object_value(
self, self,
flag_key: str, flag_key: str,
@ -229,6 +371,21 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
).value ).value
async def get_object_value_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> typing.Union[dict, list]:
details = await self.get_object_details_async(
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
return details.value
def get_object_details( def get_object_details(
self, self,
flag_key: str, flag_key: str,
@ -244,26 +401,35 @@ class OpenFeatureClient:
flag_evaluation_options, flag_evaluation_options,
) )
def evaluate_flag_details( # noqa: PLR0915 async def get_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Union[dict, list]]:
return await self.evaluate_flag_details_async(
FlagType.OBJECT,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
def _establish_hooks_and_provider(
self, self,
flag_type: FlagType, flag_type: FlagType,
flag_key: str, flag_key: str,
default_value: typing.Any, default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None, evaluation_context: typing.Optional[EvaluationContext],
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions],
) -> FlagEvaluationDetails[typing.Any]: ) -> typing.Tuple[
""" FeatureProvider,
Evaluate the flag requested by the user from the clients provider. HookContext,
HookHints,
:param flag_type: the type of the flag being returned typing.List[Hook],
:param flag_key: the string key of the selected flag typing.List[Hook],
:param default_value: backup value returned if no result found by the provider ]:
:param evaluation_context: Information for the purposes of flag evaluation
:param flag_evaluation_options: Additional flag evaluation information
:return: a FlagEvaluationDetails object with the fully evaluated flag from a
provider
"""
if evaluation_context is None: if evaluation_context is None:
evaluation_context = EvaluationContext() evaluation_context = EvaluationContext()
@ -295,39 +461,26 @@ class OpenFeatureClient:
reversed_merged_hooks = merged_hooks[:] reversed_merged_hooks = merged_hooks[:]
reversed_merged_hooks.reverse() reversed_merged_hooks.reverse()
try: return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks
def _assert_provider_status(
self,
) -> None:
status = self.get_provider_status() status = self.get_provider_status()
if status == ProviderStatus.NOT_READY: if status == ProviderStatus.NOT_READY:
error_hooks( raise ProviderNotReadyError()
flag_type,
hook_context,
ProviderNotReadyError(),
reversed_merged_hooks,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_NOT_READY,
)
return flag_evaluation
if status == ProviderStatus.FATAL: if status == ProviderStatus.FATAL:
error_hooks( raise ProviderFatalError()
flag_type, return None
hook_context,
ProviderFatalError(),
reversed_merged_hooks,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_FATAL,
)
return flag_evaluation
def _before_hooks_and_merge_context(
self,
flag_type: FlagType,
hook_context: HookContext,
merged_hooks: typing.List[Hook],
hook_hints: HookHints,
evaluation_context: typing.Optional[EvaluationContext],
) -> EvaluationContext:
# https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
# Any resulting evaluation context from a before hook will overwrite # Any resulting evaluation context from a before hook will overwrite
# duplicate fields defined globally, on the client, or in the invocation. # duplicate fields defined globally, on the client, or in the invocation.
@ -335,6 +488,7 @@ class OpenFeatureClient:
invocation_context = before_hooks( invocation_context = before_hooks(
flag_type, hook_context, merged_hooks, hook_hints flag_type, hook_context, merged_hooks, hook_hints
) )
if evaluation_context:
invocation_context = invocation_context.merge(ctx2=evaluation_context) invocation_context = invocation_context.merge(ctx2=evaluation_context)
# Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context # Requirement 3.2.2 merge: API.context->transaction.context->client.context->invocation.context
@ -344,6 +498,143 @@ class OpenFeatureClient:
.merge(self.context) .merge(self.context)
.merge(invocation_context) .merge(invocation_context)
) )
return merged_context
async def evaluate_flag_details_async(
self,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Any]:
"""
Evaluate the flag requested by the user from the clients provider.
:param flag_type: the type of the flag being returned
:param flag_key: the string key of the selected flag
:param default_value: backup value returned if no result found by the provider
:param evaluation_context: Information for the purposes of flag evaluation
:param flag_evaluation_options: Additional flag evaluation information
:return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a
provider
"""
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
)
try:
self._assert_provider_status()
merged_context = self._before_hooks_and_merge_context(
flag_type,
hook_context,
merged_hooks,
hook_hints,
evaluation_context,
)
flag_evaluation = await self._create_provider_evaluation_async(
provider,
flag_type,
flag_key,
default_value,
merged_context,
)
after_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
hook_hints,
)
return flag_evaluation
except OpenFeatureError as err:
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=err.error_code,
error_message=err.error_message,
)
return flag_evaluation
# Catch any type of exception here since the user can provide any exception
# in the error hooks
except Exception as err: # pragma: no cover
logger.exception(
"Unable to correctly evaluate flag with key: '%s'", flag_key
)
error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints)
error_message = getattr(err, "error_message", str(err))
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.GENERAL,
error_message=error_message,
)
return flag_evaluation
finally:
after_all_hooks(
flag_type,
hook_context,
flag_evaluation,
reversed_merged_hooks,
hook_hints,
)
def evaluate_flag_details(
self,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None,
) -> FlagEvaluationDetails[typing.Any]:
"""
Evaluate the flag requested by the user from the clients provider.
:param flag_type: the type of the flag being returned
:param flag_key: the string key of the selected flag
:param default_value: backup value returned if no result found by the provider
:param evaluation_context: Information for the purposes of flag evaluation
:param flag_evaluation_options: Additional flag evaluation information
:return: a FlagEvaluationDetails object with the fully evaluated flag from a
provider
"""
provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = (
self._establish_hooks_and_provider(
flag_type,
flag_key,
default_value,
evaluation_context,
flag_evaluation_options,
)
)
try:
self._assert_provider_status()
merged_context = self._before_hooks_and_merge_context(
flag_type,
hook_context,
merged_hooks,
hook_hints,
evaluation_context,
)
flag_evaluation = self._create_provider_evaluation( flag_evaluation = self._create_provider_evaluation(
provider, provider,
@ -402,6 +693,48 @@ class OpenFeatureClient:
hook_hints, hook_hints,
) )
async def _create_provider_evaluation_async(
self,
provider: FeatureProvider,
flag_type: FlagType,
flag_key: str,
default_value: typing.Any,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagEvaluationDetails[typing.Any]:
args = (
flag_key,
default_value,
evaluation_context,
)
get_details_callables_async: typing.Mapping[
FlagType, GetDetailCallableAsync
] = {
FlagType.BOOLEAN: provider.resolve_boolean_details_async,
FlagType.INTEGER: provider.resolve_integer_details_async,
FlagType.FLOAT: provider.resolve_float_details_async,
FlagType.OBJECT: provider.resolve_object_details_async,
FlagType.STRING: provider.resolve_string_details_async,
}
get_details_callable = get_details_callables_async.get(flag_type)
if not get_details_callable:
raise GeneralError(error_message="Unknown flag type")
resolution = await get_details_callable(*args)
resolution.raise_for_error()
# we need to check the get_args to be compatible with union types.
_typecheck_flag_value(resolution.value, flag_type)
return FlagEvaluationDetails(
flag_key=flag_key,
value=resolution.value,
variant=resolution.variant,
flag_metadata=resolution.flag_metadata or {},
reason=resolution.reason,
error_code=resolution.error_code,
error_message=resolution.error_message,
)
def _create_provider_evaluation( def _create_provider_evaluation(
self, self,
provider: FeatureProvider, provider: FeatureProvider,

View File

@ -47,6 +47,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None, evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]: ... ) -> FlagResolutionDetails[bool]: ...
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]: ...
def resolve_string_details( def resolve_string_details(
self, self,
flag_key: str, flag_key: str,
@ -54,6 +61,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None, evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]: ... ) -> FlagResolutionDetails[str]: ...
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]: ...
def resolve_integer_details( def resolve_integer_details(
self, self,
flag_key: str, flag_key: str,
@ -61,6 +75,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None, evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]: ... ) -> FlagResolutionDetails[int]: ...
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]: ...
def resolve_float_details( def resolve_float_details(
self, self,
flag_key: str, flag_key: str,
@ -68,6 +89,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None, evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]: ... ) -> FlagResolutionDetails[float]: ...
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]: ...
def resolve_object_details( def resolve_object_details(
self, self,
flag_key: str, flag_key: str,
@ -75,6 +103,13 @@ class FeatureProvider(typing.Protocol): # pragma: no cover
evaluation_context: typing.Optional[EvaluationContext] = None, evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]: ... ) -> FlagResolutionDetails[typing.Union[dict, list]]: ...
async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]: ...
class AbstractProvider(FeatureProvider): class AbstractProvider(FeatureProvider):
def attach( def attach(
@ -111,6 +146,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[bool]: ) -> FlagResolutionDetails[bool]:
pass pass
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return self.resolve_boolean_details(flag_key, default_value, evaluation_context)
@abstractmethod @abstractmethod
def resolve_string_details( def resolve_string_details(
self, self,
@ -120,6 +163,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[str]: ) -> FlagResolutionDetails[str]:
pass pass
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return self.resolve_string_details(flag_key, default_value, evaluation_context)
@abstractmethod @abstractmethod
def resolve_integer_details( def resolve_integer_details(
self, self,
@ -129,6 +180,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[int]: ) -> FlagResolutionDetails[int]:
pass pass
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return self.resolve_integer_details(flag_key, default_value, evaluation_context)
@abstractmethod @abstractmethod
def resolve_float_details( def resolve_float_details(
self, self,
@ -138,6 +197,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[float]: ) -> FlagResolutionDetails[float]:
pass pass
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return self.resolve_float_details(flag_key, default_value, evaluation_context)
@abstractmethod @abstractmethod
def resolve_object_details( def resolve_object_details(
self, self,
@ -147,6 +214,14 @@ class AbstractProvider(FeatureProvider):
) -> FlagResolutionDetails[typing.Union[dict, list]]: ) -> FlagResolutionDetails[typing.Union[dict, list]]:
pass pass
async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self.resolve_object_details(flag_key, default_value, evaluation_context)
def emit_provider_ready(self, details: ProviderEventDetails) -> None: def emit_provider_ready(self, details: ProviderEventDetails) -> None:
self.emit(ProviderEvent.PROVIDER_READY, details) self.emit(ProviderEvent.PROVIDER_READY, details)

View File

@ -76,6 +76,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[bool]: ) -> FlagResolutionDetails[bool]:
return self._resolve(flag_key, evaluation_context) return self._resolve(flag_key, evaluation_context)
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_string_details( def resolve_string_details(
self, self,
flag_key: str, flag_key: str,
@ -84,6 +92,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[str]: ) -> FlagResolutionDetails[str]:
return self._resolve(flag_key, evaluation_context) return self._resolve(flag_key, evaluation_context)
async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_integer_details( def resolve_integer_details(
self, self,
flag_key: str, flag_key: str,
@ -92,6 +108,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[int]: ) -> FlagResolutionDetails[int]:
return self._resolve(flag_key, evaluation_context) return self._resolve(flag_key, evaluation_context)
async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_float_details( def resolve_float_details(
self, self,
flag_key: str, flag_key: str,
@ -100,6 +124,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[float]: ) -> FlagResolutionDetails[float]:
return self._resolve(flag_key, evaluation_context) return self._resolve(flag_key, evaluation_context)
async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return await self._resolve_async(flag_key, evaluation_context)
def resolve_object_details( def resolve_object_details(
self, self,
flag_key: str, flag_key: str,
@ -108,6 +140,14 @@ class InMemoryProvider(AbstractProvider):
) -> FlagResolutionDetails[typing.Union[dict, list]]: ) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self._resolve(flag_key, evaluation_context) return self._resolve(flag_key, evaluation_context)
async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return await self._resolve_async(flag_key, evaluation_context)
def _resolve( def _resolve(
self, self,
flag_key: str, flag_key: str,
@ -117,3 +157,10 @@ class InMemoryProvider(AbstractProvider):
if flag is None: if flag is None:
raise FlagNotFoundError(f"Flag '{flag_key}' not found") raise FlagNotFoundError(f"Flag '{flag_key}' not found")
return flag.resolve(evaluation_context) return flag.resolve(evaluation_context)
async def _resolve_async(
self,
flag_key: str,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[V]:
return self._resolve(flag_key, evaluation_context)

View File

@ -17,16 +17,20 @@ def test_should_return_in_memory_provider_metadata():
assert metadata.name == "In-Memory Provider" assert metadata.name == "In-Memory Provider"
def test_should_handle_unknown_flags_correctly(): @pytest.mark.asyncio
async def test_should_handle_unknown_flags_correctly():
# Given # Given
provider = InMemoryProvider({}) provider = InMemoryProvider({})
# When # When
with pytest.raises(FlagNotFoundError): with pytest.raises(FlagNotFoundError):
provider.resolve_boolean_details(flag_key="Key", default_value=True) provider.resolve_boolean_details(flag_key="Key", default_value=True)
with pytest.raises(FlagNotFoundError):
await provider.resolve_integer_details_async(flag_key="Key", default_value=1)
# Then # Then
def test_calls_context_evaluator_if_present(): @pytest.mark.asyncio
async def test_calls_context_evaluator_if_present():
# Given # Given
def context_evaluator(flag: InMemoryFlag, evaluation_context: dict): def context_evaluator(flag: InMemoryFlag, evaluation_context: dict):
return FlagResolutionDetails( return FlagResolutionDetails(
@ -44,57 +48,81 @@ def test_calls_context_evaluator_if_present():
} }
) )
# When # When
flag = provider.resolve_boolean_details(flag_key="Key", default_value=False) flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False)
flag_async = await provider.resolve_boolean_details_async(
flag_key="Key", default_value=False
)
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value is False assert flag.value is False
assert isinstance(flag.value, bool) assert isinstance(flag.value, bool)
assert flag.reason == Reason.TARGETING_MATCH assert flag.reason == Reason.TARGETING_MATCH
def test_should_resolve_boolean_flag_from_in_memory(): @pytest.mark.asyncio
async def test_should_resolve_boolean_flag_from_in_memory():
# Given # Given
provider = InMemoryProvider( provider = InMemoryProvider(
{"Key": InMemoryFlag("true", {"true": True, "false": False})} {"Key": InMemoryFlag("true", {"true": True, "false": False})}
) )
# When # When
flag = provider.resolve_boolean_details(flag_key="Key", default_value=False) flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=False)
flag_async = await provider.resolve_boolean_details_async(
flag_key="Key", default_value=False
)
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value is True assert flag.value is True
assert isinstance(flag.value, bool) assert isinstance(flag.value, bool)
assert flag.variant == "true" assert flag.variant == "true"
def test_should_resolve_integer_flag_from_in_memory(): @pytest.mark.asyncio
async def test_should_resolve_integer_flag_from_in_memory():
# Given # Given
provider = InMemoryProvider( provider = InMemoryProvider(
{"Key": InMemoryFlag("hundred", {"zero": 0, "hundred": 100})} {"Key": InMemoryFlag("hundred", {"zero": 0, "hundred": 100})}
) )
# When # When
flag = provider.resolve_integer_details(flag_key="Key", default_value=0) flag_sync = provider.resolve_integer_details(flag_key="Key", default_value=0)
flag_async = await provider.resolve_integer_details_async(
flag_key="Key", default_value=0
)
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value == 100 assert flag.value == 100
assert isinstance(flag.value, Number) assert isinstance(flag.value, Number)
assert flag.variant == "hundred" assert flag.variant == "hundred"
def test_should_resolve_float_flag_from_in_memory(): @pytest.mark.asyncio
async def test_should_resolve_float_flag_from_in_memory():
# Given # Given
provider = InMemoryProvider( provider = InMemoryProvider(
{"Key": InMemoryFlag("ten", {"zero": 0.0, "ten": 10.23})} {"Key": InMemoryFlag("ten", {"zero": 0.0, "ten": 10.23})}
) )
# When # When
flag = provider.resolve_float_details(flag_key="Key", default_value=0.0) flag_sync = provider.resolve_float_details(flag_key="Key", default_value=0.0)
flag_async = await provider.resolve_float_details_async(
flag_key="Key", default_value=0.0
)
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value == 10.23 assert flag.value == 10.23
assert isinstance(flag.value, Number) assert isinstance(flag.value, Number)
assert flag.variant == "ten" assert flag.variant == "ten"
def test_should_resolve_string_flag_from_in_memory(): @pytest.mark.asyncio
async def test_should_resolve_string_flag_from_in_memory():
# Given # Given
provider = InMemoryProvider( provider = InMemoryProvider(
{ {
@ -105,29 +133,41 @@ def test_should_resolve_string_flag_from_in_memory():
} }
) )
# When # When
flag = provider.resolve_string_details(flag_key="Key", default_value="Default") flag_sync = provider.resolve_string_details(flag_key="Key", default_value="Default")
flag_async = await provider.resolve_string_details_async(
flag_key="Key", default_value="Default"
)
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value == "String" assert flag.value == "String"
assert isinstance(flag.value, str) assert isinstance(flag.value, str)
assert flag.variant == "stringVariant" assert flag.variant == "stringVariant"
def test_should_resolve_list_flag_from_in_memory(): @pytest.mark.asyncio
async def test_should_resolve_list_flag_from_in_memory():
# Given # Given
provider = InMemoryProvider( provider = InMemoryProvider(
{"Key": InMemoryFlag("twoItems", {"empty": [], "twoItems": ["item1", "item2"]})} {"Key": InMemoryFlag("twoItems", {"empty": [], "twoItems": ["item1", "item2"]})}
) )
# When # When
flag = provider.resolve_object_details(flag_key="Key", default_value=[]) flag_sync = provider.resolve_object_details(flag_key="Key", default_value=[])
flag_async = await provider.resolve_object_details_async(
flag_key="Key", default_value=[]
)
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value == ["item1", "item2"] assert flag.value == ["item1", "item2"]
assert isinstance(flag.value, list) assert isinstance(flag.value, list)
assert flag.variant == "twoItems" assert flag.variant == "twoItems"
def test_should_resolve_object_flag_from_in_memory(): @pytest.mark.asyncio
async def test_should_resolve_object_flag_from_in_memory():
# Given # Given
return_value = { return_value = {
"String": "string", "String": "string",
@ -138,8 +178,11 @@ def test_should_resolve_object_flag_from_in_memory():
{"Key": InMemoryFlag("obj", {"obj": return_value, "empty": {}})} {"Key": InMemoryFlag("obj", {"obj": return_value, "empty": {}})}
) )
# When # When
flag = provider.resolve_object_details(flag_key="Key", default_value={}) flag_sync = provider.resolve_object_details(flag_key="Key", default_value={})
flag_async = provider.resolve_object_details(flag_key="Key", default_value={})
# Then # Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None assert flag is not None
assert flag.value == return_value assert flag.value == return_value
assert isinstance(flag.value, dict) assert isinstance(flag.value, dict)

View File

@ -0,0 +1,197 @@
from typing import Optional, Union
import pytest
from openfeature.api import get_client, set_provider
from openfeature.evaluation_context import EvaluationContext
from openfeature.flag_evaluation import FlagResolutionDetails
from openfeature.provider import AbstractProvider, Metadata
class SynchronousProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="SynchronousProvider")
def get_provider_hooks(self):
return []
def resolve_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)
def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return FlagResolutionDetails(value="string")
def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return FlagResolutionDetails(value=1)
def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return FlagResolutionDetails(value=10.0)
def resolve_object_details(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
return FlagResolutionDetails(value={"key": "value"})
@pytest.mark.parametrize(
"flag_type, default_value, get_method",
(
(bool, True, "get_boolean_value_async"),
(str, "string", "get_string_value_async"),
(int, 1, "get_integer_value_async"),
(float, 10.0, "get_float_value_async"),
(
dict,
{"key": "value"},
"get_object_value_async",
),
),
)
@pytest.mark.asyncio
async def test_sync_provider_can_be_called_async(flag_type, default_value, get_method):
# Given
set_provider(SynchronousProvider(), "SynchronousProvider")
client = get_client("SynchronousProvider")
# When
async_callable = getattr(client, get_method)
flag = await async_callable(flag_key="Key", default_value=default_value)
# Then
assert flag is not None
assert flag == default_value
assert isinstance(flag, flag_type)
@pytest.mark.asyncio
async def test_sync_provider_can_be_extended_async():
# Given
class ExtendedAsyncProvider(SynchronousProvider):
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=False)
set_provider(ExtendedAsyncProvider(), "ExtendedAsyncProvider")
client = get_client("ExtendedAsyncProvider")
# When
flag = await client.get_boolean_value_async(flag_key="Key", default_value=True)
# Then
assert flag is not None
assert flag is False
# We're not allowing providers to only have async methods
def test_sync_methods_enforced_for_async_providers():
# Given
class AsyncProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="AsyncProvider")
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)
# When
with pytest.raises(TypeError) as exception:
set_provider(AsyncProvider(), "AsyncProvider")
# Then
# assert
exception_message = str(exception.value)
assert exception_message.startswith(
"Can't instantiate abstract class AsyncProvider"
)
assert exception_message.__contains__("resolve_boolean_details")
@pytest.mark.asyncio
async def test_async_provider_not_implemented_exception_workaround():
# Given
class SyncNotImplementedProvider(AbstractProvider):
def get_metadata(self):
return Metadata(name="AsyncProvider")
async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return FlagResolutionDetails(value=True)
def resolve_boolean_details(
self,
flag_key: str,
default_value: bool,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
raise NotImplementedError("Use the async method")
def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
raise NotImplementedError("Use the async method")
def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
raise NotImplementedError("Use the async method")
def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
raise NotImplementedError("Use the async method")
def resolve_object_details(
self,
flag_key: str,
default_value: Union[dict, list],
evaluation_context: Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[Union[dict, list]]:
raise NotImplementedError("Use the async method")
# When
set_provider(SyncNotImplementedProvider(), "SyncNotImplementedProvider")
client = get_client("SyncNotImplementedProvider")
flag = await client.get_boolean_value_async(flag_key="Key", default_value=False)
# Then
assert flag is not None
assert flag is True

View File

@ -1,3 +1,4 @@
import asyncio
import time import time
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -7,7 +8,7 @@ import pytest
from openfeature import api from openfeature import api
from openfeature.api import add_hooks, clear_hooks, get_client, set_provider from openfeature.api import add_hooks, clear_hooks, get_client, set_provider
from openfeature.client import OpenFeatureClient from openfeature.client import GeneralError, OpenFeatureClient, _typecheck_flag_value
from openfeature.evaluation_context import EvaluationContext from openfeature.evaluation_context import EvaluationContext
from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails from openfeature.event import EventDetails, ProviderEvent, ProviderEventDetails
from openfeature.exception import ErrorCode, OpenFeatureError from openfeature.exception import ErrorCode, OpenFeatureError
@ -23,9 +24,13 @@ from openfeature.transaction_context import ContextVarsTransactionContextPropaga
"flag_type, default_value, get_method", "flag_type, default_value, get_method",
( (
(bool, True, "get_boolean_value"), (bool, True, "get_boolean_value"),
(bool, True, "get_boolean_value_async"),
(str, "String", "get_string_value"), (str, "String", "get_string_value"),
(str, "String", "get_string_value_async"),
(int, 100, "get_integer_value"), (int, 100, "get_integer_value"),
(int, 100, "get_integer_value_async"),
(float, 10.23, "get_float_value"), (float, 10.23, "get_float_value"),
(float, 10.23, "get_float_value_async"),
( (
dict, dict,
{ {
@ -35,21 +40,38 @@ from openfeature.transaction_context import ContextVarsTransactionContextPropaga
}, },
"get_object_value", "get_object_value",
), ),
(
dict,
{
"String": "string",
"Number": 2,
"Boolean": True,
},
"get_object_value_async",
),
( (
list, list,
["string1", "string2"], ["string1", "string2"],
"get_object_value", "get_object_value",
), ),
(
list,
["string1", "string2"],
"get_object_value_async",
),
), ),
) )
def test_should_get_flag_value_based_on_method_type( @pytest.mark.asyncio
async def test_should_get_flag_value_based_on_method_type(
flag_type, default_value, get_method, no_op_provider_client flag_type, default_value, get_method, no_op_provider_client
): ):
# Given # Given
# When # When
flag = getattr(no_op_provider_client, get_method)( method = getattr(no_op_provider_client, get_method)
flag_key="Key", default_value=default_value if asyncio.iscoroutinefunction(method):
) flag = await method(flag_key="Key", default_value=default_value)
else:
flag = method(flag_key="Key", default_value=default_value)
# Then # Then
assert flag is not None assert flag is not None
assert flag == default_value assert flag == default_value
@ -60,9 +82,13 @@ def test_should_get_flag_value_based_on_method_type(
"flag_type, default_value, get_method", "flag_type, default_value, get_method",
( (
(bool, True, "get_boolean_details"), (bool, True, "get_boolean_details"),
(bool, True, "get_boolean_details_async"),
(str, "String", "get_string_details"), (str, "String", "get_string_details"),
(str, "String", "get_string_details_async"),
(int, 100, "get_integer_details"), (int, 100, "get_integer_details"),
(int, 100, "get_integer_details_async"),
(float, 10.23, "get_float_details"), (float, 10.23, "get_float_details"),
(float, 10.23, "get_float_details_async"),
( (
dict, dict,
{ {
@ -72,34 +98,58 @@ def test_should_get_flag_value_based_on_method_type(
}, },
"get_object_details", "get_object_details",
), ),
(
dict,
{
"String": "string",
"Number": 2,
"Boolean": True,
},
"get_object_details_async",
),
( (
list, list,
["string1", "string2"], ["string1", "string2"],
"get_object_details", "get_object_details",
), ),
(
list,
["string1", "string2"],
"get_object_details_async",
),
), ),
) )
def test_should_get_flag_detail_based_on_method_type( @pytest.mark.asyncio
async def test_should_get_flag_detail_based_on_method_type(
flag_type, default_value, get_method, no_op_provider_client flag_type, default_value, get_method, no_op_provider_client
): ):
# Given # Given
# When # When
flag = getattr(no_op_provider_client, get_method)( method = getattr(no_op_provider_client, get_method)
flag_key="Key", default_value=default_value if asyncio.iscoroutinefunction(method):
) flag = await method(flag_key="Key", default_value=default_value)
else:
flag = method(flag_key="Key", default_value=default_value)
# Then # Then
assert flag is not None assert flag is not None
assert flag.value == default_value assert flag.value == default_value
assert isinstance(flag.value, flag_type) assert isinstance(flag.value, flag_type)
def test_should_raise_exception_when_invalid_flag_type_provided(no_op_provider_client): @pytest.mark.asyncio
async def test_should_raise_exception_when_invalid_flag_type_provided(
no_op_provider_client,
):
# Given # Given
# When # When
flag = no_op_provider_client.evaluate_flag_details( flag_sync = no_op_provider_client.evaluate_flag_details(
flag_type=None, flag_key="Key", default_value=True
)
flag_async = await no_op_provider_client.evaluate_flag_details_async(
flag_type=None, flag_key="Key", default_value=True flag_type=None, flag_key="Key", default_value=True
) )
# Then # Then
for flag in [flag_sync, flag_async]:
assert flag.value assert flag.value
assert flag.error_message == "Unknown flag type" assert flag.error_message == "Unknown flag type"
assert flag.error_code == ErrorCode.GENERAL assert flag.error_code == ErrorCode.GENERAL
@ -202,7 +252,8 @@ def test_should_define_a_provider_status_accessor(no_op_provider_client):
# Requirement 1.7.6 # Requirement 1.7.6
def test_should_shortcircuit_if_provider_is_not_ready( @pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_not_ready(
no_op_provider_client, monkeypatch no_op_provider_client, monkeypatch
): ):
# Given # Given
@ -212,10 +263,16 @@ def test_should_shortcircuit_if_provider_is_not_ready(
spy_hook = MagicMock(spec=Hook) spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook]) no_op_provider_client.add_hooks([spy_hook])
# When # When
flag_details = no_op_provider_client.get_boolean_details( flag_details_sync = no_op_provider_client.get_boolean_details(
flag_key="Key", default_value=True
)
spy_hook.error.assert_called_once()
spy_hook.reset_mock()
flag_details_async = await no_op_provider_client.get_boolean_details_async(
flag_key="Key", default_value=True flag_key="Key", default_value=True
) )
# Then # Then
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None assert flag_details is not None
assert flag_details.value assert flag_details.value
assert flag_details.reason == Reason.ERROR assert flag_details.reason == Reason.ERROR
@ -225,7 +282,8 @@ def test_should_shortcircuit_if_provider_is_not_ready(
# Requirement 1.7.7 # Requirement 1.7.7
def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( @pytest.mark.asyncio
async def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
no_op_provider_client, monkeypatch no_op_provider_client, monkeypatch
): ):
# Given # Given
@ -235,10 +293,16 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
spy_hook = MagicMock(spec=Hook) spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook]) no_op_provider_client.add_hooks([spy_hook])
# When # When
flag_details = no_op_provider_client.get_boolean_details( flag_details_sync = no_op_provider_client.get_boolean_details(
flag_key="Key", default_value=True
)
spy_hook.error.assert_called_once()
spy_hook.reset_mock()
flag_details_async = await no_op_provider_client.get_boolean_details_async(
flag_key="Key", default_value=True flag_key="Key", default_value=True
) )
# Then # Then
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None assert flag_details is not None
assert flag_details.value assert flag_details.value
assert flag_details.reason == Reason.ERROR assert flag_details.reason == Reason.ERROR
@ -247,23 +311,32 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
spy_hook.finally_after.assert_called_once() spy_hook.finally_after.assert_called_once()
def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): @pytest.mark.asyncio
async def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code():
# Given # Given
spy_hook = MagicMock(spec=Hook) spy_hook = MagicMock(spec=Hook)
provider = MagicMock(spec=FeatureProvider) provider = MagicMock(spec=FeatureProvider)
provider.get_provider_hooks.return_value = [] provider.get_provider_hooks.return_value = []
provider.resolve_boolean_details.return_value = FlagResolutionDetails( mock_resolution = FlagResolutionDetails(
value=True, value=True,
reason=Reason.ERROR, reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_FATAL, error_code=ErrorCode.PROVIDER_FATAL,
error_message="This is an error message", error_message="This is an error message",
) )
provider.resolve_boolean_details.return_value = mock_resolution
provider.resolve_boolean_details_async.return_value = mock_resolution
set_provider(provider) set_provider(provider)
client = get_client() client = get_client()
client.add_hooks([spy_hook]) client.add_hooks([spy_hook])
# When # When
flag_details = client.get_boolean_details(flag_key="Key", default_value=True) flag_details_sync = client.get_boolean_details(flag_key="Key", default_value=True)
spy_hook.error.assert_called_once()
spy_hook.reset_mock()
flag_details_async = await client.get_boolean_details_async(
flag_key="Key", default_value=True
)
# Then # Then
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None assert flag_details is not None
assert flag_details.value assert flag_details.value
assert flag_details.reason == Reason.ERROR assert flag_details.reason == Reason.ERROR
@ -271,6 +344,37 @@ def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code()
spy_hook.error.assert_called_once() spy_hook.error.assert_called_once()
@pytest.mark.asyncio
async def test_client_type_mismatch_exceptions():
# Given
client = get_client()
# When
flag_details_sync = client.get_boolean_details(
flag_key="Key", default_value="type mismatch"
)
flag_details_async = await client.get_boolean_details_async(
flag_key="Key", default_value="type mismatch"
)
# Then
for flag_details in [flag_details_sync, flag_details_async]:
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.TYPE_MISMATCH
@pytest.mark.asyncio
async def test_client_general_exception():
# Given
flag_value = "A"
flag_type = None
# When
with pytest.raises(GeneralError) as e:
flag_type = _typecheck_flag_value(flag_value, flag_type)
# Then
assert e.value.error_message == "Unknown flag type"
def test_provider_events(): def test_provider_events():
# Given # Given
provider = NoOpProvider() provider = NoOpProvider()