pipelines/test/kfp-kubernetes-execution-tests/sdk_execution_tests.py

142 lines
4.5 KiB
Python

# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import functools
import os
import sys
from typing import Any, Dict, List, Tuple
from kfp import client
from kfp import dsl
import kfp_server_api
import pytest
import yaml
KFP_ENDPOINT = os.environ['KFP_ENDPOINT']
TIMEOUT_SECONDS = os.environ['TIMEOUT_SECONDS']
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
PROJECT_ROOT = os.path.abspath(
os.path.join(CURRENT_DIR, *([os.path.pardir] * 2)))
CONFIG_PATH = os.path.join(
PROJECT_ROOT,
'kubernetes_platform',
'python',
'test',
'snapshot',
'test_data_config.yaml',
)
kfp_client = client.Client(host=KFP_ENDPOINT)
@dataclasses.dataclass
class TestCase:
name: str
module_path: str
yaml_path: str
function_name: str
arguments: Dict[str, Any]
def create_test_case_parameters() -> List[TestCase]:
parameters: List[TestCase] = []
with open(CONFIG_PATH) as f:
config = yaml.safe_load(f)
test_data_dir = os.path.join(
PROJECT_ROOT,
'kubernetes_platform',
'python',
'test',
'snapshot',
'data',
)
parameters.extend(
TestCase(
name=test_case['name'] + '-' + test_case['module'],
module_path=os.path.join(test_data_dir,
f'{test_case["module"]}.py'),
yaml_path=os.path.join(test_data_dir,
f'{test_case["module"]}.yaml'),
function_name=test_case['name'],
arguments=test_case.get('arguments'),
) for test_case in config['test_cases'])
return parameters
def wait(
run_result: client.client.RunPipelineResult
) -> kfp_server_api.V2beta1Run:
return kfp_client.wait_for_run_completion(
run_id=run_result.run_id, timeout=int(TIMEOUT_SECONDS))
def import_obj_from_file(python_path: str, obj_name: str) -> Any:
sys.path.insert(0, os.path.dirname(python_path))
module_name = os.path.splitext(os.path.split(python_path)[1])[0]
module = __import__(module_name, fromlist=[obj_name])
if not hasattr(module, obj_name):
raise ValueError(
f'Object "{obj_name}" not found in module {python_path}.')
return getattr(module, obj_name)
def run(test_case: TestCase) -> Tuple[str, client.client.RunPipelineResult]:
full_path = os.path.join(PROJECT_ROOT, test_case.module_path)
pipeline_func = import_obj_from_file(full_path, test_case.function_name)
run_result = kfp_client.create_run_from_pipeline_func(
pipeline_func,
enable_caching=True,
arguments=test_case.arguments,
)
run_url = f'{KFP_ENDPOINT}/#/runs/details/{run_result.run_id}'
print(
f'- Created run {test_case.name}\n\tModule: {test_case.module_path}\n\tURL: {run_url}\n'
)
return run_url, run_result
def get_kfp_package_path() -> str:
if os.environ.get('PULL_NUMBER') is not None:
path = f'git+https://github.com/kubeflow/pipelines.git@refs/pull/{os.environ.get("PULL_NUMBER")}/merge#subdirectory=sdk/python'
else:
path = 'git+https://github.com/kubeflow/pipelines.git@master#subdirectory=sdk/python'
print(f'Using the following KFP package path for tests: {path}')
return path
dsl.component = functools.partial(
dsl.component, kfp_package_path=get_kfp_package_path())
@pytest.mark.asyncio_cooperative
@pytest.mark.parametrize('test_case', create_test_case_parameters())
async def test(test_case: TestCase) -> None:
"""Asynchronously runs all samples and test that they succeed."""
event_loop = asyncio.get_running_loop()
try:
run_url, run_result = run(test_case)
except Exception as e:
raise RuntimeError(
f'Error triggering pipeline {test_case.name}.') from e
api_run = await event_loop.run_in_executor(None, wait, run_result)
assert api_run.state == 'SUCCEEDED', f'Pipeline {test_case.name} ended with incorrect status: {api_run.state}. More info: {run_url}'
if __name__ == '__main__':
pytest.main()