#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for visualization component."""
import os
import json
import tempfile
from unittest.mock import patch

import mock
import pytest

from pytorch_kfp_components.components.visualization.component import Visualization
from pytorch_kfp_components.components.visualization.executor import Executor
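
# Scratch directory shared by all tests for the generated metadata files.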
metadata_dir = tempfile.mkdtemp()


@pytest.fixture(scope="class")
def viz_params():
"""Setting visualization parameters.
Returns:
viz_param : dict of visualization parameters.
"""
markdown_params = {
"storage": "dummy-storage",
"source": {
"dummy_key": "dummy_value"
},
}
    viz_param = {
        "mlpipeline_ui_metadata":
            os.path.join(metadata_dir, "mlpipeline_ui_metadata.json"),
        "mlpipeline_metrics":
            os.path.join(metadata_dir, "mlpipeline_metrics"),
        "confusion_matrix_dict": {},
        "test_accuracy": 99.05,
        "markdown": markdown_params,
    }
    return viz_param


@pytest.fixture(scope="class")
def confusion_matrix_params():
    """Setting the confusion matrix parameters.

    Returns:
        confusion_matrix_param : dict of confusion matrix params.
    """
confusion_matrix_param = {
"actuals": ["1", "2", "3", "4"],
"preds": ["2", "3", "4", "0"],
"classes": ["dummy", "dummy"],
"url": "minio://dummy_bucket/folder_name",
}
    return confusion_matrix_param


def generate_visualization(viz_params: dict): #pylint: disable=redefined-outer-name
    """Generates the visualization object.

    Returns:
        output_dict : output dict of the visualization object.
    """
viz_obj = Visualization(
mlpipeline_ui_metadata=viz_params["mlpipeline_ui_metadata"],
mlpipeline_metrics=viz_params["mlpipeline_metrics"],
confusion_matrix_dict=viz_params["confusion_matrix_dict"],
test_accuracy=viz_params["test_accuracy"],
markdown=viz_params["markdown"],
)
return viz_obj.output_dict


@pytest.mark.parametrize(
"viz_key",
[
"confusion_matrix_dict",
"test_accuracy",
"markdown",
],
)
def test_invalid_type_viz_params(viz_params, viz_key): #pylint: disable=redefined-outer-name
"""Test visualization for invalid parameter type."""
viz_params[viz_key] = "dummy"
if viz_key == "test_accuracy":
expected_type = "<class 'float'>"
else:
expected_type = "<class 'dict'>"
expected_exception_msg = f"{viz_key} must be of type {expected_type} but" \
f" received as {type(viz_params[viz_key])}"
with pytest.raises(TypeError, match=expected_exception_msg):
generate_visualization(viz_params)


@pytest.mark.parametrize(
"viz_key",
[
"mlpipeline_ui_metadata",
"mlpipeline_metrics",
],
)
def test_invalid_type_metadata_path(viz_params, viz_key): #pylint: disable=redefined-outer-name
"""Test visualization with invalid metadata path."""
viz_params[viz_key] = ["dummy"]
expected_exception_msg = f"{viz_key} must be of type <class 'str'> " \
f"but received as {type(viz_params[viz_key])}"
with pytest.raises(TypeError, match=expected_exception_msg):
generate_visualization(viz_params)


@pytest.mark.parametrize(
"viz_key",
[
"mlpipeline_ui_metadata",
"mlpipeline_metrics",
],
)
def test_default_metadata_path(viz_params, viz_key): #pylint: disable=redefined-outer-name
"""Test visualization with default metadata path."""
viz_params[viz_key] = None
expected_output = {
"mlpipeline_ui_metadata": "/mlpipeline-ui-metadata.json",
"mlpipeline_metrics": "/mlpipeline-metrics.json",
}
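    # Patching the module-level helper short-circuits the component call;
    # the test only checks the documented fallback paths.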
with patch(
"test_visualization.generate_visualization",
return_value=expected_output,
):
output_dict = generate_visualization(viz_params)
assert output_dict == expected_output


def test_custom_metadata_path(viz_params, tmpdir): #pylint: disable=redefined-outer-name
"""Test visualization with custom metadata path."""
metadata_ui_path = os.path.join(str(tmpdir), "mlpipeline_ui_metadata.json")
metadata_metrics_path = os.path.join(str(tmpdir),
"mlpipeline_metrics.json")
viz_params["mlpipeline_ui_metadata"] = metadata_ui_path
viz_params["mlpipeline_metrics"] = metadata_metrics_path
output_dict = generate_visualization(viz_params)
assert output_dict is not None
assert output_dict["mlpipeline_ui_metadata"] == metadata_ui_path
assert output_dict["mlpipeline_metrics"] == metadata_metrics_path
assert os.path.exists(metadata_ui_path)
assert os.path.exists(metadata_metrics_path)


def test_setting_all_keys_to_none(viz_params): #pylint: disable=redefined-outer-name
    """Test visualization with all parameters set to None."""
for key in viz_params.keys():
viz_params[key] = None
expected_exception_msg = r"Any one of these keys should be set -" \
r" confusion_matrix_dict, test_accuracy, markdown"
with pytest.raises(ValueError, match=expected_exception_msg):
generate_visualization(viz_params)


def test_accuracy_metric(viz_params): #pylint: disable=redefined-outer-name
"""Test for getting proper accuracy metric."""
output_dict = generate_visualization(viz_params)
assert output_dict is not None
metadata_metric_file = viz_params["mlpipeline_metrics"]
assert os.path.exists(metadata_metric_file)
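    # The metrics file follows the KFP v1 metrics schema, e.g.
    # {"metrics": [{"name": ..., "numberValue": 99.05, ...}]}; only the
    # recorded value is asserted here.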
with open(metadata_metric_file) as file:
data = json.load(file)
assert data["metrics"][0]["numberValue"] == viz_params["test_accuracy"]


def test_markdown_storage_invalid_datatype(viz_params): #pylint: disable=redefined-outer-name
"""Test for passing invalid markdown storage datatype."""
viz_params["markdown"]["storage"] = ["test"]
expected_exception_msg = (
r"storage must be of type <class 'str'> but received as {}".format(
type(viz_params["markdown"]["storage"])))
with pytest.raises(TypeError, match=expected_exception_msg):
generate_visualization(viz_params)


def test_markdown_source_invalid_datatype(viz_params): #pylint: disable=redefined-outer-name
"""Test for passing invalid markdown source datatype."""
viz_params["markdown"]["source"] = "test"
expected_exception_msg = (
r"source must be of type <class 'dict'> but received as {}".format(
type(viz_params["markdown"]["source"])))
with pytest.raises(TypeError, match=expected_exception_msg):
generate_visualization(viz_params)


@pytest.mark.parametrize(
"markdown_key",
[
"source",
"storage",
],
)
def test_markdown_source_missing_key(viz_params, markdown_key): #pylint: disable=redefined-outer-name
"""Test with markdown source missing keys."""
del viz_params["markdown"][markdown_key]
expected_exception_msg = r"Missing mandatory key - {}".format(markdown_key)
with pytest.raises(ValueError, match=expected_exception_msg):
generate_visualization(viz_params)


def test_markdown_success(viz_params): #pylint: disable=redefined-outer-name
"""Test for successful markdown generation."""
output_dict = generate_visualization(viz_params)
assert output_dict is not None
assert "mlpipeline_ui_metadata" in output_dict
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
with open(output_dict["mlpipeline_ui_metadata"]) as file:
data = file.read()
assert "dummy_key" in data
assert "dummy_value" in data


def test_different_storage_value(viz_params): #pylint: disable=redefined-outer-name
    """Test different storage values for markdown."""
viz_params["markdown"]["storage"] = "inline"
output_dict = generate_visualization(viz_params)
assert output_dict is not None
assert "mlpipeline_ui_metadata" in output_dict
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
with open(output_dict["mlpipeline_ui_metadata"]) as file:
data = file.read()
assert "inline" in data


def test_multiple_metadata_appends(viz_params): #pylint: disable=redefined-outer-name
    """Test that successive runs append to the metadata file."""
if os.path.exists(viz_params["mlpipeline_ui_metadata"]):
os.remove(viz_params["mlpipeline_ui_metadata"])
if os.path.exists(viz_params["mlpipeline_metrics"]):
os.remove(viz_params["mlpipeline_metrics"])
generate_visualization(viz_params)
generate_visualization(viz_params)
output_dict = generate_visualization(viz_params)
assert output_dict is not None
assert "mlpipeline_ui_metadata" in output_dict
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
with open(output_dict["mlpipeline_ui_metadata"]) as file:
data = json.load(file)
assert len(data["outputs"]) == 3


@pytest.mark.parametrize(
"cm_key",
["actuals", "preds", "classes", "url"],
)
def test_confusion_matrix_invalid_types(
viz_params, #pylint: disable=redefined-outer-name
confusion_matrix_params, #pylint: disable=redefined-outer-name
cm_key):
"""Test for invalid type keys for confusion matrix."""
confusion_matrix_params[cm_key] = {"test": "dummy"}
viz_params["confusion_matrix_dict"] = confusion_matrix_params
with pytest.raises(TypeError):
generate_visualization(viz_params)


@pytest.mark.parametrize(
"cm_key",
["actuals", "preds", "classes", "url"],
)
def test_confusion_matrix_optional_check(
viz_params, #pylint: disable=redefined-outer-name
confusion_matrix_params, #pylint: disable=redefined-outer-name
cm_key):
"""Tests for passing confusion matrix keys as optional."""
confusion_matrix_params[cm_key] = {}
viz_params["confusion_matrix_dict"] = confusion_matrix_params
expected_error_msg = f"{cm_key} is not optional. " \
f"Received value: {confusion_matrix_params[cm_key]}"
with pytest.raises(ValueError, match=expected_error_msg):
generate_visualization(viz_params)


@pytest.mark.parametrize(
"cm_key",
["actuals", "preds", "classes", "url"],
)
def test_confusion_matrix_missing_check(
viz_params, #pylint: disable=redefined-outer-name
confusion_matrix_params, #pylint: disable=redefined-outer-name
cm_key):
"""Tests for missing confusion matrix keys."""
del confusion_matrix_params[cm_key]
viz_params["confusion_matrix_dict"] = confusion_matrix_params
expected_error_msg = f"Missing mandatory key - {cm_key}"
with pytest.raises(ValueError, match=expected_error_msg):
generate_visualization(viz_params)


def test_confusion_matrix_success(viz_params, confusion_matrix_params): #pylint: disable=redefined-outer-name
"""Test for successful confusion matrix generation."""
if os.path.exists(viz_params["mlpipeline_ui_metadata"]):
os.remove(viz_params["mlpipeline_ui_metadata"])
viz_params["confusion_matrix_dict"] = confusion_matrix_params
with mock.patch.object(Executor, "_upload_confusion_matrix_to_minio"):
output_dict = generate_visualization(viz_params)
assert output_dict is not None
assert "mlpipeline_ui_metadata" in output_dict
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
with open(output_dict["mlpipeline_ui_metadata"]) as file:
data = file.read()
assert "confusion_matrix" in data