80 lines
3.1 KiB
Python
80 lines
3.1 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import sys
|
|
import types
|
|
from typing import Any, Callable, Mapping, Sequence
|
|
from inspect import Parameter, Signature
|
|
|
|
|
|
class KwParameter(Parameter):
    """An ``inspect.Parameter`` whose kind is always ``POSITIONAL_OR_KEYWORD``.

    Only the name, the default value and the annotation can be chosen by the
    caller; the parameter kind is fixed by this class.
    """

    def __init__(self, name, default=Parameter.empty, annotation=Parameter.empty):
        positional_or_keyword = Parameter.POSITIONAL_OR_KEYWORD
        super().__init__(
            name,
            kind=positional_or_keyword,
            default=default,
            annotation=annotation,
        )
|
|
|
|
|
|
def create_function_from_parameter_names(func: Callable[[Mapping[str, Any]], Any], parameter_names: Sequence[str], documentation=None, func_name=None, func_filename=None):
    """Create a function whose parameters are given by name only.

    Each name is wrapped in a ``KwParameter`` (no default, no annotation) and
    the work is delegated to ``create_function_from_parameters``.

    Args:
        func: Callable that is handed to ``create_function_from_parameters``.
        parameter_names: Names of the parameters of the generated function.
        documentation: Forwarded unchanged to ``create_function_from_parameters``.
        func_name: Forwarded unchanged to ``create_function_from_parameters``.
        func_filename: Forwarded unchanged to ``create_function_from_parameters``.

    Returns:
        Whatever ``create_function_from_parameters`` returns.
    """
    keyword_parameters = [
        KwParameter(parameter_name) for parameter_name in parameter_names
    ]
    return create_function_from_parameters(
        func,
        keyword_parameters,
        documentation,
        func_name,
        func_filename,
    )
|
|
|
|
|
|
def create_function_from_parameters(func: Callable[[Mapping[str, Any]], Any], parameters: Sequence[Parameter], documentation=None, func_name=None, func_filename=None):
    """Create a function with the given parameters that forwards its arguments to ``func``.

    The generated function gathers the values bound to its parameters into a
    dict (by calling ``locals()``) and returns ``func(that_dict)``. It is built
    by taking the code object of a small template function and replacing its
    argument count, local variable names, name, file name and first line.

    Args:
        func: Callable that receives the mapping of parameter names to the
            argument values of each call. Its result is returned to the caller.
        parameters: Parameters of the generated function. Their consistency is
            checked by constructing an ``inspect.Signature`` from them.
            NOTE: the code object is created with ``co_argcount`` equal to the
            number of parameters, so at call time every parameter is bound like
            an ordinary positional-or-keyword argument, whatever kind it
            declares. Defaults are applied right-aligned, i.e. to the trailing
            parameters.
        documentation: Value assigned to ``__doc__`` of the generated function.
        func_name: Name of the generated function. When falsy, the code object
            keeps the name of the internal template (``pass_locals``).
        func_filename: When truthy, used as the file name of the code object
            and the first line number is set to 1. Otherwise the file name and
            line number of the internal template are kept.

    Returns:
        The generated function. Its ``__signature__`` is the ``Signature``
        built from ``parameters``.
    """
    # Materialize once: the parameters are traversed several times below and
    # passed to len(), which a one-shot iterable would not support.
    parameters = tuple(parameters)
    new_signature = Signature(parameters)  # Checks the parameter consistency

    # Template whose code object is rewritten below. ``dict_func`` and
    # ``locals`` are resolved through the globals dict given to FunctionType.
    def pass_locals():
        return dict_func(locals())  # noqa: F821 TODO

    code = pass_locals.__code__
    mod_co_argcount = len(parameters)
    mod_co_nlocals = len(parameters)
    mod_co_varnames = tuple(param.name for param in parameters)
    mod_co_name = func_name or code.co_name
    if func_filename:
        mod_co_filename = func_filename
        mod_co_firstlineno = 1
    else:
        mod_co_filename = code.co_filename
        mod_co_firstlineno = code.co_firstlineno

    if sys.version_info >= (3, 8):
        modified_code = code.replace(
            co_argcount=mod_co_argcount,
            co_nlocals=mod_co_nlocals,
            co_varnames=mod_co_varnames,
            co_filename=mod_co_filename,
            co_name=mod_co_name,
            co_firstlineno=mod_co_firstlineno,
        )
    else:
        modified_code = types.CodeType(
            mod_co_argcount,
            code.co_kwonlyargcount,
            mod_co_nlocals,
            code.co_stacksize,
            code.co_flags,
            code.co_code,
            code.co_consts,
            code.co_names,
            mod_co_varnames,
            mod_co_filename,
            mod_co_name,
            mod_co_firstlineno,
            code.co_lnotab
        )

    # argdefs "starts from the right"/"is right-aligned".
    # Identity check on purpose: ``!=`` would call the default value's own
    # comparison method, which may return a non-bool or claim equality with
    # ``Parameter.empty`` and thereby drop a legitimate default.
    default_arg_values = tuple(
        p.default for p in parameters if p.default is not Parameter.empty
    )
    modified_func = types.FunctionType(
        modified_code,
        {'dict_func': func, 'locals': locals},
        name=func_name,
        argdefs=default_arg_values,
    )
    modified_func.__doc__ = documentation
    modified_func.__signature__ = new_signature
    # Keep the qualified name in sync with the code name so that the nested
    # template's qualified name never shows up on the generated function.
    modified_func.__qualname__ = mod_co_name

    return modified_func
|