pipelines/test/sdk-execution-tests/sdk_execution_tests.py

126 lines
4.5 KiB
Python

# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import functools
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

from kfp import client
from kfp import dsl  # noqa
import kfp_server_api
import pytest
import yaml
# Both settings are required: a missing variable raises KeyError as soon as
# this module is imported.
KFP_ENDPOINT = os.environ['KFP_ENDPOINT']
# Kept as the raw environment string; it is converted to int where it is used.
TIMEOUT_SECONDS = os.environ['TIMEOUT_SECONDS']

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
# PROJECT_ROOT is two directory levels above the directory holding this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(CURRENT_DIR, os.path.pardir, os.path.pardir))
CONFIG_PATH = os.path.join(
    PROJECT_ROOT, 'sdk', 'python', 'test_data', 'test_data_config.yaml')

# One client instance shared by every helper in this module.
kfp_client = client.Client(host=KFP_ENDPOINT)
@dataclasses.dataclass
class TestCase:
    """One executable sample pipeline described by the test data config.

    Attributes:
        name: Identifier built from the config group name and the module
            name, joined with ``-``.
        module_path: Path of the ``.py`` file that defines the pipeline.
        yaml_path: Path of the ``.yaml`` file that sits next to the module.
        function_name: Name of the object to load from the module.
        arguments: Run arguments for the pipeline, or ``None`` when the
            config entry has no ``arguments`` key.
    """
    # The class name starts with "Test", so pytest would otherwise try to
    # collect it and warn about the generated __init__. This attribute is
    # unannotated on purpose so that it does not become a dataclass field.
    __test__ = False

    name: str
    module_path: str
    yaml_path: str
    function_name: str
    arguments: Optional[Dict[str, Any]]
def create_test_case_parameters() -> List[TestCase]:
    """Build the list of test cases to run from the YAML test data config.

    Reads ``CONFIG_PATH``, walks every top-level group in it and keeps only
    the entries of ``test_cases`` whose ``execute`` value is truthy.

    Returns:
        One ``TestCase`` per selected config entry, in config order.
    """
    with open(CONFIG_PATH) as config_file:
        config = yaml.safe_load(config_file)

    parameters: List[TestCase] = []
    for name, test_group in config.items():
        test_data_dir = os.path.join(PROJECT_ROOT, test_group['test_data_dir'])
        for test_case in test_group['test_cases']:
            # Entries not marked for execution are skipped entirely.
            if not test_case['execute']:
                continue
            module = test_case['module']
            parameters.append(
                TestCase(
                    name=name + '-' + module,
                    module_path=os.path.join(test_data_dir, f'{module}.py'),
                    yaml_path=os.path.join(test_data_dir, f'{module}.yaml'),
                    function_name=test_case['name'],
                    arguments=test_case.get('arguments'),
                ))
    return parameters
def wait(run_result: client.client.RunPipelineResult) -> kfp_server_api.V2beta1Run:
    """Wait for a submitted run to complete and return the resulting run.

    Delegates to ``kfp_client.wait_for_run_completion``, using the
    ``TIMEOUT_SECONDS`` setting converted to an integer as the timeout.

    Args:
        run_result: Result of submitting the run; only ``run_id`` is read.

    Returns:
        Whatever ``wait_for_run_completion`` returns for that run.
    """
    run_id = run_result.run_id
    timeout_seconds = int(TIMEOUT_SECONDS)
    return kfp_client.wait_for_run_completion(
        run_id=run_id, timeout=timeout_seconds)
def import_obj_from_file(python_path: str, obj_name: str) -> Any:
    """Import the module stored at ``python_path`` and return one attribute.

    The module's directory is placed at the front of ``sys.path`` and the
    module is imported by its bare file name (without ``.py``).

    Args:
        python_path: Path of the Python source file to import.
        obj_name: Name of the attribute to fetch from the imported module.

    Returns:
        The attribute ``obj_name`` of the imported module.

    Raises:
        ValueError: If the imported module has no attribute ``obj_name``.
    """
    module_dir = os.path.dirname(python_path)
    module_name = os.path.splitext(os.path.basename(python_path))[0]

    # The directory must be first on sys.path so that this file (and any
    # sibling module it imports) wins the lookup. Skip the insert when it is
    # already first to avoid stacking duplicate entries on repeated calls.
    if sys.path[:1] != [module_dir]:
        sys.path.insert(0, module_dir)

    # Imports are cached by bare module name. If a module with the same name
    # was loaded earlier from a *different* file, the import below would
    # silently return that stale module, so evict it first. Modules without
    # a __file__ (e.g. built-ins) are left alone.
    cached = sys.modules.get(module_name)
    cached_file = getattr(cached, '__file__', None)
    if cached_file is not None and (os.path.realpath(cached_file) !=
                                    os.path.realpath(python_path)):
        del sys.modules[module_name]

    module = __import__(module_name, fromlist=[obj_name])
    if not hasattr(module, obj_name):
        raise ValueError(
            f'Object "{obj_name}" not found in module {python_path}.')
    return getattr(module, obj_name)
def run(test_case: TestCase) -> Tuple[str, client.client.RunPipelineResult]:
    """Submit the pipeline of ``test_case`` as a new run and print its URL.

    Args:
        test_case: Describes which module and object to load and which
            arguments to submit.

    Returns:
        A ``(run_url, run_result)`` tuple: the details URL built from
        ``KFP_ENDPOINT`` and the run id, and the client's submission result.
    """
    module_file = os.path.join(PROJECT_ROOT, test_case.module_path)
    pipeline_func = import_obj_from_file(module_file, test_case.function_name)

    run_result = kfp_client.create_run_from_pipeline_func(
        pipeline_func,
        arguments=test_case.arguments,
        enable_caching=True,
    )

    run_url = f'{KFP_ENDPOINT}/#/runs/details/{run_result.run_id}'
    # Echo the run location so a failing case can be found from the test log.
    print(
        f'- Created run {test_case.name}\n\tModule: {test_case.module_path}\n\tURL: {run_url}\n'
    )
    return run_url, run_result
def get_kfp_package_path() -> str:
    """Return the pip-installable git URL of the KFP SDK to test against.

    When the ``PULL_NUMBER`` environment variable holds a value, the URL
    points at that pull request's merge ref; otherwise it points at
    ``master``. The chosen path is also printed.

    Returns:
        A ``git+https://...#subdirectory=sdk/python`` package path.
    """
    # An empty or whitespace-only PULL_NUMBER is treated as unset; otherwise
    # it would produce the invalid ref "refs/pull//merge".
    pull_number = (os.environ.get('PULL_NUMBER') or '').strip()
    if pull_number:
        path = f'git+https://github.com/kubeflow/pipelines.git@refs/pull/{pull_number}/merge#subdirectory=sdk/python'
    else:
        path = 'git+https://github.com/kubeflow/pipelines.git@master#subdirectory=sdk/python'
    print(f'Using the following KFP package path for tests: {path}')
    return path
# Variant of dsl.component with kfp_package_path pre-bound. The package path is
# computed once, when this module is imported; the test function in this file
# substitutes this object for dsl.component via mocker.patch.object.
partial_component_decorator = functools.partial(
    dsl.component, kfp_package_path=get_kfp_package_path())
@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize('test_case', create_test_case_parameters())
async def test(test_case: TestCase, mocker) -> None:
    """Submit one sample pipeline and assert that its run ends SUCCEEDED.

    Args:
        test_case: The parametrized sample to submit.
        mocker: Fixture used to replace ``dsl.component`` for this test.

    Raises:
        RuntimeError: If submitting the pipeline fails for any reason; the
            original exception is attached as the cause.
    """
    # Replace dsl.component with the variant that has kfp_package_path bound.
    mocker.patch.object(dsl, 'component', partial_component_decorator)

    try:
        run_url, run_result = run(test_case)
    except Exception as err:
        raise RuntimeError(
            f'Error triggering pipeline {test_case.name}.') from err

    # wait() is a blocking call, so hand it to the loop's default executor and
    # await the result instead of calling it directly in the coroutine.
    loop = asyncio.get_running_loop()
    api_run = await loop.run_in_executor(None, wait, run_result)

    assert api_run.state == 'SUCCEEDED', f'Pipeline {test_case.name} ended with incorrect status: {api_run.state}. More info: {run_url}'