# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Submits the sample pipelines listed in the test data config and checks
that every run finishes in the SUCCEEDED state."""

import asyncio
import dataclasses
import functools
import importlib.util
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

from kfp import client
from kfp import dsl  # noqa
import kfp_server_api
import pytest
import yaml

# Both variables are required; a missing one fails fast with a KeyError at
# import time.
KFP_ENDPOINT = os.environ['KFP_ENDPOINT']
# Kept as the raw string from the environment; converted with int() in wait().
TIMEOUT_SECONDS = os.environ['TIMEOUT_SECONDS']
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
# Two directory levels above this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(CURRENT_DIR, *([os.path.pardir] * 2)))
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'sdk', 'python', 'test_data',
                           'test_data_config.yaml')

kfp_client = client.Client(host=KFP_ENDPOINT)


@dataclasses.dataclass
class TestCase:
    """One executable sample described by the test data config."""

    # The name starts with "Test", so pytest would otherwise try to collect
    # this dataclass as a test class and emit a collection warning. This is a
    # plain (un-annotated) class attribute, not a dataclass field.
    __test__ = False

    # '<group name>-<module>' label used in log and failure messages.
    name: str
    # Path of the Python file that defines the pipeline/component.
    module_path: str
    # Path of the YAML file sitting next to the Python module.
    yaml_path: str
    # Name of the object inside the module to submit as a run.
    function_name: str
    # Run arguments; None when the config entry has no 'arguments' key.
    arguments: Optional[Dict[str, Any]]


def create_test_case_parameters() -> List[TestCase]:
    """Builds the pytest parameters from the YAML config at CONFIG_PATH.

    Only config entries whose 'execute' value is truthy are included.

    Returns:
        One TestCase per executable entry, in config order.
    """
    parameters: List[TestCase] = []
    with open(CONFIG_PATH) as f:
        config = yaml.safe_load(f)
    for name, test_group in config.items():
        test_data_dir = os.path.join(PROJECT_ROOT, test_group['test_data_dir'])

        parameters.extend(
            TestCase(
                name=name + '-' + test_case['module'],
                module_path=os.path.join(test_data_dir,
                                         f'{test_case["module"]}.py'),
                yaml_path=os.path.join(test_data_dir,
                                       f'{test_case["module"]}.yaml'),
                function_name=test_case['name'],
                arguments=test_case.get('arguments'),
            ) for test_case in test_group['test_cases'] if test_case['execute'])

    return parameters


def wait(
        run_result: client.client.RunPipelineResult
) -> kfp_server_api.V2beta1Run:
    """Waits for a submitted run to complete and returns its API object.

    The timeout is TIMEOUT_SECONDS converted to an int. `test` dispatches
    this function through `run_in_executor` rather than calling it directly
    from the coroutine.

    Args:
        run_result: Result returned when the run was created.

    Returns:
        Whatever `kfp_client.wait_for_run_completion` returns for the run.
    """
    return kfp_client.wait_for_run_completion(
        run_id=run_result.run_id, timeout=int(TIMEOUT_SECONDS))


@functools.lru_cache(maxsize=None)
def _load_module_from_path(python_path: str) -> Any:
    """Executes the Python file at `python_path` once and returns the module.

    The module is located by file path, not by searching sys.path for its
    bare name, so two files with the same name in different directories each
    yield their own module. Results are cached per path so a file is executed
    at most once.
    """
    module_name = os.path.splitext(os.path.basename(python_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, python_path)
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot load a Python module from {python_path}.')
    module = importlib.util.module_from_spec(spec)

    # Register the module under its name while it executes, as a regular
    # import would, but never replace a module that is already loaded under
    # that name.
    registered = sys.modules.setdefault(module_name, module) is module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: a module that failed to execute must not
        # stay behind in sys.modules.
        if registered:
            sys.modules.pop(module_name, None)
        raise
    return module


def import_obj_from_file(python_path: str, obj_name: str) -> Any:
    """Loads the Python file at `python_path` and returns one of its objects.

    Args:
        python_path: Path of the Python source file to load.
        obj_name: Name of the attribute to fetch from the loaded module.

    Returns:
        The attribute `obj_name` of the loaded module.

    Raises:
        ValueError: If the module has no attribute named `obj_name`.
        ImportError: If no module can be loaded from `python_path`.
    """
    # Keep the module's directory importable so the module can import files
    # that live next to it; add it only once instead of on every call.
    module_dir = os.path.dirname(python_path)
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

    module = _load_module_from_path(python_path)
    if not hasattr(module, obj_name):
        raise ValueError(
            f'Object "{obj_name}" not found in module {python_path}.')
    return getattr(module, obj_name)


def run(test_case: TestCase) -> Tuple[str, client.client.RunPipelineResult]:
    """Creates a run for the test case's pipeline function.

    Args:
        test_case: The sample to submit.

    Returns:
        A (run URL, run result) tuple; the URL is built from KFP_ENDPOINT and
        the new run's id.
    """
    full_path = os.path.join(PROJECT_ROOT, test_case.module_path)
    pipeline_func = import_obj_from_file(full_path, test_case.function_name)
    run_result = kfp_client.create_run_from_pipeline_func(
        pipeline_func,
        enable_caching=True,
        arguments=test_case.arguments,
    )
    run_url = f'{KFP_ENDPOINT}/#/runs/details/{run_result.run_id}'
    print(
        f'- Created run {test_case.name}\n\tModule: {test_case.module_path}\n\tURL: {run_url}\n'
    )
    return run_url, run_result


def get_kfp_package_path() -> str:
    """Returns the pip-installable git URL of the KFP SDK to test against.

    Points at the merge ref of the pull request named by the PULL_NUMBER
    environment variable when it is set, and at the master branch otherwise.
    The chosen path is also printed.
    """
    pull_number = os.environ.get('PULL_NUMBER')
    if pull_number is not None:
        path = f'git+https://github.com/kubeflow/pipelines.git@refs/pull/{pull_number}/merge#subdirectory=sdk/python'
    else:
        path = 'git+https://github.com/kubeflow/pipelines.git@master#subdirectory=sdk/python'
    print(f'Using the following KFP package path for tests: {path}')
    return path


# `dsl.component` with kfp_package_path pre-filled; `test` patches it over
# `dsl.component` before the sample modules are loaded.
partial_component_decorator = functools.partial(
    dsl.component, kfp_package_path=get_kfp_package_path())


@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize('test_case', create_test_case_parameters())
async def test(test_case: TestCase, mocker) -> None:
    """Asynchronously runs all samples and test that they succeed."""
    mocker.patch.object(dsl, 'component', partial_component_decorator)
    event_loop = asyncio.get_running_loop()
    try:
        run_url, run_result = run(test_case)
    except Exception as e:
        raise RuntimeError(
            f'Error triggering pipeline {test_case.name}.') from e

    # Wait in the loop's default executor so this coroutine yields while the
    # run is in progress.
    api_run = await event_loop.run_in_executor(None, wait, run_result)
    assert api_run.state == 'SUCCEEDED', f'Pipeline {test_case.name} ended with incorrect status: {api_run.state}. More info: {run_url}'