SDK - Reduce python component limitations - no import errors for cust… (#3106)

* SDK - Reduce python component limitations - no import errors for custom type annotations

By default, create_component_from_func copies the source code of the function and creates a component using that source code. No global imports are captured. This is problematic for the function definition, since any annotation, that uses a type that needs to be imported, will cause error. There were some special provisions for
NamedTuple,  InputPath and OutputPath, but even they were brittle (for example, "typing.NamedTuple" or "components.InputPath" annotations still caused failures at runtime).

This commit fixes the issue by stripping the type annotations from function declarations.

Fixes cases that were failing before:

```python
import typing
import collections

MyFuncOutputs = typing.NamedTuple('Outputs', [('sum', int), ('product', int)])

@create_component_from_func
def my_func(
    param1: CustomType,  # This caused failure previously
    param2: collections.OrderedDict,  # This caused failure previously
) -> MyFuncOutputs: # This caused failure previously
    pass
```

* Fixed the compiler tests

* Fixed crashes on print function

Code `print(line, end="")` was causing error: "lib2to3.pgen2.parse.ParseError: bad input: type=22, value='=', context=('', (2, 15))"

* Using the strip_hints library to strip the annotations

* Updating test workflow yamls

* Workaround for bug in untokenize

* Switched to the new strip_string_to_string method

* Fixed typo.

Co-Authored-By: Jiaxiao Zheng <jxzheng@google.com>

Co-authored-by: Jiaxiao Zheng <jxzheng@google.com>
This commit is contained in:
Alexey Volkov 2020-02-24 20:50:48 -08:00 committed by GitHub
parent 2e39867be7
commit 578d8de91d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 92 additions and 11 deletions

View File

@ -191,6 +191,66 @@ import pickle
return function_loading_code
def strip_type_hints(source_code: str) -> str:
try:
return _strip_type_hints_using_strip_hints(source_code)
except Exception as ex:
print('Error when stripping type annotations: ' + str(ex))
try:
return _strip_type_hints_using_lib2to3(source_code)
except Exception as ex:
print('Error when stripping type annotations: ' + str(ex))
return source_code
def _strip_type_hints_using_strip_hints(source_code: str) -> str:
from strip_hints import strip_string_to_string
# Workaround for https://github.com/abarker/strip-hints/issues/4 , https://bugs.python.org/issue35107
# I could not repro it though
if source_code[-1] != '\n':
source_code += '\n'
return strip_string_to_string(source_code, to_empty=True)
def _strip_type_hints_using_lib2to3(source_code: str) -> str:
"""Strips type annotations from the function definitions in the provided source code."""
# Using the standard lib2to3 library to strip type annotations.
# Switch to another library like strip-hints if issues are found.
from lib2to3 import fixer_base, refactor, fixer_util
class StripAnnotations(fixer_base.BaseFix):
PATTERN = r'''
typed_func_parameter=tname
|
typed_func_return_value=funcdef< any+ '->' any+ >
'''
def transform(self, node, results):
if 'typed_func_parameter' in results:
# Delete the annotation part of the function parameter declaration
del node.children[1:]
elif 'typed_func_return_value' in results:
# Delete the return annotation part of the function declaration
del node.children[-4:-2]
return node
class Refactor(refactor.RefactoringTool):
def __init__(self, fixers):
self._fixers = [cls(None, None) for cls in fixers]
super().__init__(None, {'print_function': True})
def get_fixers(self):
return self._fixers, []
stripped_code = Refactor([StripAnnotations]).refactor_string(source_code, '')
return stripped_code
def _capture_function_code_using_source_copy(func) -> str:
import textwrap
@ -208,13 +268,13 @@ def _capture_function_code_using_source_copy(func) -> str:
if not func_code_lines:
raise ValueError('Failed to dedent and clean up the source of function "{}". It is probably not properly indented.'.format(func.__name__))
#TODO: Add support for copying the NamedTuple subclass declaration code
#Adding NamedTuple import if needed
if hasattr(inspect.signature(func).return_annotation, '_fields'): #NamedTuple
func_code_lines.insert(0, '')
func_code_lines.insert(0, 'from typing import NamedTuple')
func_code = '\n'.join(func_code_lines)
return '\n'.join(func_code_lines)
# Stripping type annotations to prevent import errors.
# The most common cases are InputPath/OutputPath and typing.NamedTuple annotations
func_code = strip_type_hints(func_code)
return func_code
def _extract_component_interface(func) -> ComponentSpec:
@ -416,7 +476,6 @@ def _func_to_component_spec(func, extra_code='', base_image : str = None, packag
def get_argparse_type_for_input_file(passing_style):
if passing_style is None:
return None
pre_func_definitions.add(inspect.getsource(passing_style))
if passing_style is InputPath:
return 'str'

View File

@ -16,3 +16,4 @@ jsonschema >= 3.0.1
tabulate == 0.8.3
click == 7.0
Deprecated
strip-hints

View File

@ -38,6 +38,7 @@ REQUIRES = [
'tabulate == 0.8.3',
'click == 7.0',
'Deprecated',
'strip-hints',
]
def find_version(*file_path_parts):

View File

@ -485,7 +485,7 @@ spec:
- "-u"
- "-c"
- |
def produce_list_of_dicts() -> list:
def produce_list_of_dicts() :
return ([{"aaa": "aaa1", "bbb": "bbb1"}, {"aaa": "aaa2", "bbb": "bbb2"}],)
def _serialize_json(obj) -> str:
@ -548,7 +548,7 @@ spec:
- "-u"
- "-c"
- |
def produce_list_of_ints() -> list:
def produce_list_of_ints() :
return ([1234567890, 987654321],)
def _serialize_json(obj) -> str:
@ -611,7 +611,7 @@ spec:
- "-u"
- "-c"
- |
def produce_list_of_strings() -> list:
def produce_list_of_strings() :
return (["a", "z"],)
def _serialize_json(obj) -> str:
@ -674,7 +674,7 @@ spec:
- "-u"
- "-c"
- |
def produce_str() -> str:
def produce_str() :
return "Hello"
def _serialize_str(str_value: str) -> str:

View File

@ -711,6 +711,26 @@ class PythonOpTestCase(unittest.TestCase):
},
)
def test_annotations_stripping(self):
import typing
import collections
MyFuncOutputs = typing.NamedTuple('Outputs', [('sum', int), ('product', int)])
class CustomType1:
pass
def my_func(
param1: CustomType1 = None, # This caused failure previously
param2: collections.OrderedDict = None, # This caused failure previously
) -> MyFuncOutputs: # This caused failure previously
assert param1 == None
assert param2 == None
return (8, 15)
task_factory = comp.create_component_from_func(my_func)
self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'sum': '8', 'product': '15'})
def test_file_input_name_conversion(self):
# Checking the input name conversion rules for file inputs: