feat(components): Adding Visualization component for PyTorch - KFP (#5810)
* Adding visualization component Signed-off-by: ankan94 <ankan@ideas2it.com> * Switch to Apache2 License Signed-off-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com> Co-authored-by: Arvind-Ideas2IT <arvindkumarsingh.gautam@ideas2it.com> Co-authored-by: Shrinath Suresh <shrinath@ideas2it.com>
This commit is contained in:
parent
f843121e91
commit
941879dd5c
|
|
@ -0,0 +1,109 @@
|
|||
#!/usr/bin/env/python3
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Visualization Component Class."""
|
||||
from pytorch_kfp_components.types import standard_component_specs
|
||||
from pytorch_kfp_components.components.base.base_component import BaseComponent
|
||||
from pytorch_kfp_components.components.visualization.executor import Executor
|
||||
|
||||
|
||||
class Visualization(BaseComponent): # pylint: disable=R0903
|
||||
"""Visualization Component Class."""
|
||||
|
||||
def __init__( # pylint: disable=R0913
|
||||
self,
|
||||
mlpipeline_ui_metadata=None,
|
||||
mlpipeline_metrics=None,
|
||||
confusion_matrix_dict=None,
|
||||
test_accuracy=None,
|
||||
markdown=None,
|
||||
):
|
||||
"""Initializes the Visualization component.
|
||||
|
||||
Args:
|
||||
mlpipeline_ui_metadata : path to save ui metadata
|
||||
mlpipeline_metrics : metrics to be uploaded
|
||||
confusion_metrics_dict : dict for the confusion metrics
|
||||
test_accuracy : test accuracy of the model
|
||||
markdown : markdown dictionary
|
||||
"""
|
||||
super(BaseComponent, self).__init__() # pylint: disable=E1003
|
||||
|
||||
input_dict = {
|
||||
standard_component_specs.VIZ_CONFUSION_MATRIX_DICT:
|
||||
confusion_matrix_dict,
|
||||
standard_component_specs.VIZ_TEST_ACCURACY: test_accuracy,
|
||||
standard_component_specs.VIZ_MARKDOWN: markdown,
|
||||
}
|
||||
|
||||
output_dict = {}
|
||||
|
||||
exec_properties = {
|
||||
standard_component_specs.VIZ_MLPIPELINE_UI_METADATA:
|
||||
mlpipeline_ui_metadata,
|
||||
standard_component_specs.VIZ_MLPIPELINE_METRICS:
|
||||
mlpipeline_metrics,
|
||||
}
|
||||
|
||||
spec = standard_component_specs.VisualizationSpec()
|
||||
self._validate_spec(
|
||||
spec=spec,
|
||||
input_dict=input_dict,
|
||||
output_dict=output_dict,
|
||||
exec_properties=exec_properties,
|
||||
)
|
||||
if markdown:
|
||||
self._validate_markdown_spec(spec=spec, markdown_dict=markdown)
|
||||
|
||||
if confusion_matrix_dict:
|
||||
self._validate_confusion_matrix_spec(
|
||||
spec=spec, confusion_matrix_dict=confusion_matrix_dict
|
||||
)
|
||||
|
||||
Executor().Do(
|
||||
input_dict=input_dict,
|
||||
output_dict=output_dict,
|
||||
exec_properties=exec_properties,
|
||||
)
|
||||
|
||||
self.output_dict = output_dict
|
||||
|
||||
def _validate_markdown_spec(
|
||||
self, spec: standard_component_specs, markdown_dict: dict
|
||||
):
|
||||
"""Vaildates markdown specs type"""
|
||||
for key in spec.MARKDOWN_DICT:
|
||||
if key not in markdown_dict:
|
||||
raise ValueError(f"Missing mandatory key - {key}")
|
||||
if key in markdown_dict:
|
||||
self._type_check(
|
||||
actual_value=markdown_dict[key],
|
||||
key=key,
|
||||
spec_dict=spec.MARKDOWN_DICT,
|
||||
)
|
||||
|
||||
def _validate_confusion_matrix_spec(
|
||||
self, spec: standard_component_specs, confusion_matrix_dict: dict
|
||||
):
|
||||
"""Validates confusion matrix specs type"""
|
||||
for key in spec.CONFUSION_MATRIX_DICT:
|
||||
if key not in confusion_matrix_dict:
|
||||
raise ValueError(f"Missing mandatory key - {key}")
|
||||
if key in confusion_matrix_dict:
|
||||
self._type_check(
|
||||
actual_value=confusion_matrix_dict[key],
|
||||
key=key,
|
||||
spec_dict=spec.CONFUSION_MATRIX_DICT,
|
||||
)
|
||||
|
|
@ -0,0 +1,286 @@
|
|||
#!/usr/bin/env/python3
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Visualization Executor Class."""
|
||||
# pylint: disable=C0103
|
||||
# pylint: disable=R0201
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pandas as pd
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
|
||||
from pytorch_kfp_components.components.minio.component import MinIO
|
||||
from pytorch_kfp_components.types import standard_component_specs
|
||||
|
||||
|
||||
class Executor(BaseExecutor): # pylint: disable=R0903
|
||||
"""Visualization Executor Class."""
|
||||
|
||||
def __init__(self):
|
||||
super(Executor, self).__init__() # pylint: disable=R1725
|
||||
self.mlpipeline_ui_metadata = None
|
||||
self.mlpipeline_metrics = None
|
||||
|
||||
def _write_ui_metadata(
|
||||
self, metadata_filepath, metadata_dict, key="outputs"
|
||||
):
|
||||
"""Function to write the metadata to UI."""
|
||||
if not os.path.exists(metadata_filepath):
|
||||
metadata = {key: [metadata_dict]}
|
||||
else:
|
||||
with open(metadata_filepath) as fp:
|
||||
metadata = json.load(fp)
|
||||
metadata_outputs = metadata[key]
|
||||
metadata_outputs.append(metadata_dict)
|
||||
|
||||
print("Writing to file: {}".format(metadata_filepath))
|
||||
with open(metadata_filepath, "w") as fp:
|
||||
json.dump(metadata, fp)
|
||||
|
||||
def _generate_markdown(self, markdown_dict):
|
||||
"""Generates a markdown.
|
||||
|
||||
Args:
|
||||
markdown_dict : dict of markdown specifications
|
||||
"""
|
||||
source_str = json.dumps(
|
||||
markdown_dict["source"], sort_keys=True, indent=4
|
||||
)
|
||||
source = f"```json \n {source_str} ```"
|
||||
markdown_metadata = {
|
||||
"storage": markdown_dict["storage"],
|
||||
"source": source,
|
||||
"type": "markdown",
|
||||
}
|
||||
|
||||
self._write_ui_metadata(
|
||||
metadata_filepath=self.mlpipeline_ui_metadata,
|
||||
metadata_dict=markdown_metadata,
|
||||
)
|
||||
|
||||
def _generate_confusion_matrix_metadata(
|
||||
self, confusion_matrix_path, classes
|
||||
):
|
||||
"""Generates the confusion matrix metadata and writes in ui."""
|
||||
print("Generating Confusion matrix Metadata")
|
||||
metadata = {
|
||||
"type": "confusion_matrix",
|
||||
"format": "csv",
|
||||
"schema": [
|
||||
{"name": "target", "type": "CATEGORY"},
|
||||
{"name": "predicted", "type": "CATEGORY"},
|
||||
{"name": "count", "type": "NUMBER"},
|
||||
],
|
||||
"source": confusion_matrix_path,
|
||||
"labels": list(map(str, classes)),
|
||||
}
|
||||
|
||||
self._write_ui_metadata(
|
||||
metadata_filepath=self.mlpipeline_ui_metadata,
|
||||
metadata_dict=metadata,
|
||||
)
|
||||
|
||||
def _upload_confusion_matrix_to_minio(
|
||||
self, confusion_matrix_url, confusion_matrix_output_path
|
||||
):
|
||||
parse_obj = urlparse(confusion_matrix_url, allow_fragments=False)
|
||||
bucket_name = parse_obj.netloc
|
||||
folder_name = str(parse_obj.path).lstrip("/")
|
||||
|
||||
# TODO: # pylint: disable=W0511
|
||||
endpoint = "minio-service.kubeflow:9000"
|
||||
MinIO(
|
||||
source=confusion_matrix_output_path,
|
||||
bucket_name=bucket_name,
|
||||
destination=folder_name,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
def _generate_confusion_matrix(
|
||||
self, confusion_matrix_dict
|
||||
): # pylint: disable=R0914
|
||||
"""Generates confusion matrix in minio."""
|
||||
actuals = confusion_matrix_dict["actuals"]
|
||||
preds = confusion_matrix_dict["preds"]
|
||||
confusion_matrix_url = confusion_matrix_dict["url"]
|
||||
|
||||
# Generating confusion matrix
|
||||
df = pd.DataFrame(
|
||||
list(zip(actuals, preds)), columns=["target", "predicted"]
|
||||
)
|
||||
vocab = list(df["target"].unique())
|
||||
cm = confusion_matrix(df["target"], df["predicted"], labels=vocab)
|
||||
data = []
|
||||
for target_index, target_row in enumerate(cm):
|
||||
for predicted_index, count in enumerate(target_row):
|
||||
data.append(
|
||||
(vocab[target_index], vocab[predicted_index], count)
|
||||
)
|
||||
|
||||
confusion_matrix_df = pd.DataFrame(
|
||||
data, columns=["target", "predicted", "count"]
|
||||
)
|
||||
|
||||
confusion_matrix_output_dir = str(tempfile.mkdtemp())
|
||||
confusion_matrix_output_path = os.path.join(
|
||||
confusion_matrix_output_dir, "confusion_matrix.csv"
|
||||
)
|
||||
# saving confusion matrix
|
||||
confusion_matrix_df.to_csv(
|
||||
confusion_matrix_output_path, index=False, header=False
|
||||
)
|
||||
|
||||
self._upload_confusion_matrix_to_minio(
|
||||
confusion_matrix_url=confusion_matrix_url,
|
||||
confusion_matrix_output_path=confusion_matrix_output_path,
|
||||
)
|
||||
|
||||
# Generating metadata
|
||||
self._generate_confusion_matrix_metadata(
|
||||
confusion_matrix_path=os.path.join(
|
||||
confusion_matrix_url, "confusion_matrix.csv"
|
||||
),
|
||||
classes=vocab,
|
||||
)
|
||||
|
||||
def _visualize_accuracy_metric(self, accuracy):
|
||||
"""Generates the visualization for accuracy."""
|
||||
metadata = {
|
||||
"name": "accuracy-score",
|
||||
"numberValue": accuracy,
|
||||
"format": "PERCENTAGE",
|
||||
}
|
||||
self._write_ui_metadata(
|
||||
metadata_filepath=self.mlpipeline_metrics,
|
||||
metadata_dict=metadata,
|
||||
key="metrics",
|
||||
)
|
||||
|
||||
def _get_fn_args(self, input_dict: dict, exec_properties: dict):
|
||||
"""Extracts the confusion matrix dict, test accuracy, markdown from the
|
||||
input dict and mlpipeline ui metadata & metrics from exec_properties.
|
||||
|
||||
Args:
|
||||
input_dict : a dictionary of inputs,
|
||||
example: confusion matrix dict, markdown
|
||||
exe_properties : a dict of execution properties
|
||||
example : mlpipeline_ui_metadata
|
||||
Returns:
|
||||
confusion_matrix_dict : dict of confusion metrics
|
||||
test_accuracy : model test accuracy metrics
|
||||
markdown : markdown dict
|
||||
mlpipeline_ui_metadata : path of ui metadata
|
||||
mlpipeline_metrics : metrics to be uploaded
|
||||
"""
|
||||
confusion_matrix_dict = input_dict.get(
|
||||
standard_component_specs.VIZ_CONFUSION_MATRIX_DICT
|
||||
)
|
||||
test_accuracy = input_dict.get(
|
||||
standard_component_specs.VIZ_TEST_ACCURACY
|
||||
)
|
||||
markdown = input_dict.get(standard_component_specs.VIZ_MARKDOWN)
|
||||
|
||||
mlpipeline_ui_metadata = exec_properties.get(
|
||||
standard_component_specs.VIZ_MLPIPELINE_UI_METADATA
|
||||
)
|
||||
mlpipeline_metrics = exec_properties.get(
|
||||
standard_component_specs.VIZ_MLPIPELINE_METRICS
|
||||
)
|
||||
|
||||
return (
|
||||
confusion_matrix_dict,
|
||||
test_accuracy,
|
||||
markdown,
|
||||
mlpipeline_ui_metadata,
|
||||
mlpipeline_metrics,
|
||||
)
|
||||
|
||||
def _set_defalt_mlpipeline_path(
|
||||
self, mlpipeline_ui_metadata: str, mlpipeline_metrics: str
|
||||
):
|
||||
"""Sets the default mlpipeline path."""
|
||||
|
||||
if mlpipeline_ui_metadata:
|
||||
Path(os.path.dirname(mlpipeline_ui_metadata)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
else:
|
||||
mlpipeline_ui_metadata = "/mlpipeline-ui-metadata.json"
|
||||
|
||||
if mlpipeline_metrics:
|
||||
Path(os.path.dirname(mlpipeline_metrics)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
else:
|
||||
mlpipeline_metrics = "/mlpipeline-metrics.json"
|
||||
|
||||
return mlpipeline_ui_metadata, mlpipeline_metrics
|
||||
|
||||
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
|
||||
"""Executes the visualization process and uploads to minio
|
||||
Args:
|
||||
input_dict : a dictionary of inputs,
|
||||
example: confusion matrix dict, markdown
|
||||
output_dict :
|
||||
exec_properties : a dict of execution properties
|
||||
example : mlpipeline_ui_metadata
|
||||
"""
|
||||
|
||||
(
|
||||
confusion_matrix_dict,
|
||||
test_accuracy,
|
||||
markdown,
|
||||
mlpipeline_ui_metadata,
|
||||
mlpipeline_metrics,
|
||||
) = self._get_fn_args(
|
||||
input_dict=input_dict, exec_properties=exec_properties
|
||||
)
|
||||
|
||||
(
|
||||
self.mlpipeline_ui_metadata,
|
||||
self.mlpipeline_metrics,
|
||||
) = self._set_defalt_mlpipeline_path(
|
||||
mlpipeline_ui_metadata=mlpipeline_ui_metadata,
|
||||
mlpipeline_metrics=mlpipeline_metrics,
|
||||
)
|
||||
|
||||
if not (confusion_matrix_dict or test_accuracy or markdown):
|
||||
raise ValueError(
|
||||
"Any one of these keys should be set - "
|
||||
"confusion_matrix_dict, test_accuracy, markdown"
|
||||
)
|
||||
|
||||
if confusion_matrix_dict:
|
||||
self._generate_confusion_matrix(
|
||||
confusion_matrix_dict=confusion_matrix_dict,
|
||||
)
|
||||
|
||||
if test_accuracy:
|
||||
self._visualize_accuracy_metric(accuracy=test_accuracy)
|
||||
|
||||
if markdown:
|
||||
self._generate_markdown(markdown_dict=markdown)
|
||||
|
||||
output_dict[
|
||||
standard_component_specs.VIZ_MLPIPELINE_UI_METADATA
|
||||
] = self.mlpipeline_ui_metadata
|
||||
output_dict[
|
||||
standard_component_specs.VIZ_MLPIPELINE_METRICS
|
||||
] = self.mlpipeline_metrics
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module for defining standard specifications and validation of parameter
|
||||
type."""
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,325 @@
|
|||
#!/usr/bin/env/python3
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for visualization component."""
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
import mock
|
||||
from pytorch_kfp_components.components.visualization.component import Visualization
|
||||
from pytorch_kfp_components.components.visualization.executor import Executor
|
||||
import pytest
|
||||
|
||||
metdata_dir = tempfile.mkdtemp()
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def viz_params():
|
||||
"""Setting visualization parameters.
|
||||
|
||||
Returns:
|
||||
viz_param : dict of visualization parameters.
|
||||
"""
|
||||
markdown_params = {
|
||||
"storage": "dummy-storage",
|
||||
"source": {
|
||||
"dummy_key": "dummy_value"
|
||||
},
|
||||
}
|
||||
|
||||
viz_param = {
|
||||
"mlpipeline_ui_metadata":
|
||||
os.path.join(metdata_dir, "mlpipeline_ui_metadata.json"),
|
||||
"mlpipeline_metrics":
|
||||
os.path.join(metdata_dir, "mlpipeline_metrics"),
|
||||
"confusion_matrix_dict": {},
|
||||
"test_accuracy":
|
||||
99.05,
|
||||
"markdown":
|
||||
markdown_params,
|
||||
}
|
||||
return viz_param
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def confusion_matrix_params():
|
||||
"""Setting the confusion matrix parameters.
|
||||
|
||||
Returns:
|
||||
confusion_matrix_params : Dict of confusion matrix parmas
|
||||
"""
|
||||
confusion_matrix_param = {
|
||||
"actuals": ["1", "2", "3", "4"],
|
||||
"preds": ["2", "3", "4", "0"],
|
||||
"classes": ["dummy", "dummy"],
|
||||
"url": "minio://dummy_bucket/folder_name",
|
||||
}
|
||||
return confusion_matrix_param
|
||||
|
||||
|
||||
def generate_visualization(viz_params: dict): #pylint: disable=redefined-outer-name
|
||||
"""Generates the visualization object.
|
||||
|
||||
Returns:
|
||||
output_dict : output dict of vizualization obj.
|
||||
"""
|
||||
viz_obj = Visualization(
|
||||
mlpipeline_ui_metadata=viz_params["mlpipeline_ui_metadata"],
|
||||
mlpipeline_metrics=viz_params["mlpipeline_metrics"],
|
||||
confusion_matrix_dict=viz_params["confusion_matrix_dict"],
|
||||
test_accuracy=viz_params["test_accuracy"],
|
||||
markdown=viz_params["markdown"],
|
||||
)
|
||||
|
||||
return viz_obj.output_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"viz_key",
|
||||
[
|
||||
"confusion_matrix_dict",
|
||||
"test_accuracy",
|
||||
"markdown",
|
||||
],
|
||||
)
|
||||
def test_invalid_type_viz_params(viz_params, viz_key): #pylint: disable=redefined-outer-name
|
||||
"""Test visualization for invalid parameter type."""
|
||||
viz_params[viz_key] = "dummy"
|
||||
if viz_key == "test_accuracy":
|
||||
expected_type = "<class 'float'>"
|
||||
else:
|
||||
expected_type = "<class 'dict'>"
|
||||
expected_exception_msg = f"{viz_key} must be of type {expected_type} but" \
|
||||
f" received as {type(viz_params[viz_key])}"
|
||||
with pytest.raises(TypeError, match=expected_exception_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"viz_key",
|
||||
[
|
||||
"mlpipeline_ui_metadata",
|
||||
"mlpipeline_metrics",
|
||||
],
|
||||
)
|
||||
def test_invalid_type_metadata_path(viz_params, viz_key): #pylint: disable=redefined-outer-name
|
||||
"""Test visualization with invalid metadata path."""
|
||||
|
||||
viz_params[viz_key] = ["dummy"]
|
||||
expected_exception_msg = f"{viz_key} must be of type <class 'str'> " \
|
||||
f"but received as {type(viz_params[viz_key])}"
|
||||
with pytest.raises(TypeError, match=expected_exception_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"viz_key",
|
||||
[
|
||||
"mlpipeline_ui_metadata",
|
||||
"mlpipeline_metrics",
|
||||
],
|
||||
)
|
||||
def test_default_metadata_path(viz_params, viz_key): #pylint: disable=redefined-outer-name
|
||||
"""Test visualization with default metadata path."""
|
||||
viz_params[viz_key] = None
|
||||
expected_output = {
|
||||
"mlpipeline_ui_metadata": "/mlpipeline-ui-metadata.json",
|
||||
"mlpipeline_metrics": "/mlpipeline-metrics.json",
|
||||
}
|
||||
with patch(
|
||||
"test_visualization.generate_visualization",
|
||||
return_value=expected_output,
|
||||
):
|
||||
output_dict = generate_visualization(viz_params)
|
||||
assert output_dict == expected_output
|
||||
|
||||
|
||||
def test_custom_metadata_path(viz_params, tmpdir): #pylint: disable=redefined-outer-name
|
||||
"""Test visualization with custom metadata path."""
|
||||
metadata_ui_path = os.path.join(str(tmpdir), "mlpipeline_ui_metadata.json")
|
||||
metadata_metrics_path = os.path.join(str(tmpdir),
|
||||
"mlpipeline_metrics.json")
|
||||
viz_params["mlpipeline_ui_metadata"] = metadata_ui_path
|
||||
viz_params["mlpipeline_metrics"] = metadata_metrics_path
|
||||
output_dict = generate_visualization(viz_params)
|
||||
assert output_dict is not None
|
||||
assert output_dict["mlpipeline_ui_metadata"] == metadata_ui_path
|
||||
assert output_dict["mlpipeline_metrics"] == metadata_metrics_path
|
||||
assert os.path.exists(metadata_ui_path)
|
||||
assert os.path.exists(metadata_metrics_path)
|
||||
|
||||
|
||||
def test_setting_all_keys_to_none(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test visialization with all parameters set to None tyoe."""
|
||||
for key in viz_params.keys():
|
||||
viz_params[key] = None
|
||||
|
||||
expected_exception_msg = r"Any one of these keys should be set -" \
|
||||
r" confusion_matrix_dict, test_accuracy, markdown"
|
||||
with pytest.raises(ValueError, match=expected_exception_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
def test_accuracy_metric(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for getting proper accuracy metric."""
|
||||
output_dict = generate_visualization(viz_params)
|
||||
assert output_dict is not None
|
||||
metadata_metric_file = viz_params["mlpipeline_metrics"]
|
||||
assert os.path.exists(metadata_metric_file)
|
||||
with open(metadata_metric_file) as file:
|
||||
data = json.load(file)
|
||||
assert data["metrics"][0]["numberValue"] == viz_params["test_accuracy"]
|
||||
|
||||
|
||||
def test_markdown_storage_invalid_datatype(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for passing invalid markdown storage datatype."""
|
||||
viz_params["markdown"]["storage"] = ["test"]
|
||||
expected_exception_msg = (
|
||||
r"storage must be of type <class 'str'> but received as {}".format(
|
||||
type(viz_params["markdown"]["storage"])))
|
||||
with pytest.raises(TypeError, match=expected_exception_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
def test_markdown_source_invalid_datatype(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for passing invalid markdown source datatype."""
|
||||
viz_params["markdown"]["source"] = "test"
|
||||
expected_exception_msg = (
|
||||
r"source must be of type <class 'dict'> but received as {}".format(
|
||||
type(viz_params["markdown"]["source"])))
|
||||
with pytest.raises(TypeError, match=expected_exception_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"markdown_key",
|
||||
[
|
||||
"source",
|
||||
"storage",
|
||||
],
|
||||
)
|
||||
def test_markdown_source_missing_key(viz_params, markdown_key): #pylint: disable=redefined-outer-name
|
||||
"""Test with markdown source missing keys."""
|
||||
del viz_params["markdown"][markdown_key]
|
||||
expected_exception_msg = r"Missing mandatory key - {}".format(markdown_key)
|
||||
with pytest.raises(ValueError, match=expected_exception_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
def test_markdown_success(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for successful markdown generation."""
|
||||
output_dict = generate_visualization(viz_params)
|
||||
assert output_dict is not None
|
||||
assert "mlpipeline_ui_metadata" in output_dict
|
||||
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
|
||||
with open(output_dict["mlpipeline_ui_metadata"]) as file:
|
||||
data = file.read()
|
||||
assert "dummy_key" in data
|
||||
assert "dummy_value" in data
|
||||
|
||||
|
||||
def test_different_storage_value(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for different storgae values for markdown."""
|
||||
viz_params["markdown"]["storage"] = "inline"
|
||||
output_dict = generate_visualization(viz_params)
|
||||
assert output_dict is not None
|
||||
assert "mlpipeline_ui_metadata" in output_dict
|
||||
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
|
||||
with open(output_dict["mlpipeline_ui_metadata"]) as file:
|
||||
data = file.read()
|
||||
assert "inline" in data
|
||||
|
||||
|
||||
def test_multiple_metadata_appends(viz_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for multiple metadata append."""
|
||||
if os.path.exists(viz_params["mlpipeline_ui_metadata"]):
|
||||
os.remove(viz_params["mlpipeline_ui_metadata"])
|
||||
|
||||
if os.path.exists(viz_params["mlpipeline_metrics"]):
|
||||
os.remove(viz_params["mlpipeline_metrics"])
|
||||
generate_visualization(viz_params)
|
||||
generate_visualization(viz_params)
|
||||
output_dict = generate_visualization(viz_params)
|
||||
assert output_dict is not None
|
||||
assert "mlpipeline_ui_metadata" in output_dict
|
||||
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
|
||||
with open(output_dict["mlpipeline_ui_metadata"]) as file:
|
||||
data = json.load(file)
|
||||
assert len(data["outputs"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cm_key",
|
||||
["actuals", "preds", "classes", "url"],
|
||||
)
|
||||
def test_confusion_matrix_invalid_types(
|
||||
viz_params, #pylint: disable=redefined-outer-name
|
||||
confusion_matrix_params, #pylint: disable=redefined-outer-name
|
||||
cm_key):
|
||||
"""Test for invalid type keys for confusion matrix."""
|
||||
confusion_matrix_params[cm_key] = {"test": "dummy"}
|
||||
viz_params["confusion_matrix_dict"] = confusion_matrix_params
|
||||
with pytest.raises(TypeError):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cm_key",
|
||||
["actuals", "preds", "classes", "url"],
|
||||
)
|
||||
def test_confusion_matrix_optional_check(
|
||||
viz_params, #pylint: disable=redefined-outer-name
|
||||
confusion_matrix_params, #pylint: disable=redefined-outer-name
|
||||
cm_key):
|
||||
"""Tests for passing confusion matrix keys as optional."""
|
||||
confusion_matrix_params[cm_key] = {}
|
||||
viz_params["confusion_matrix_dict"] = confusion_matrix_params
|
||||
expected_error_msg = f"{cm_key} is not optional. " \
|
||||
f"Received value: {confusion_matrix_params[cm_key]}"
|
||||
with pytest.raises(ValueError, match=expected_error_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cm_key",
|
||||
["actuals", "preds", "classes", "url"],
|
||||
)
|
||||
def test_confusion_matrix_missing_check(
|
||||
viz_params, #pylint: disable=redefined-outer-name
|
||||
confusion_matrix_params, #pylint: disable=redefined-outer-name
|
||||
cm_key):
|
||||
"""Tests for missing confusion matrix keys."""
|
||||
del confusion_matrix_params[cm_key]
|
||||
viz_params["confusion_matrix_dict"] = confusion_matrix_params
|
||||
expected_error_msg = f"Missing mandatory key - {cm_key}"
|
||||
with pytest.raises(ValueError, match=expected_error_msg):
|
||||
generate_visualization(viz_params)
|
||||
|
||||
|
||||
def test_confusion_matrix_success(viz_params, confusion_matrix_params): #pylint: disable=redefined-outer-name
|
||||
"""Test for successful confusion matrix generation."""
|
||||
if os.path.exists(viz_params["mlpipeline_ui_metadata"]):
|
||||
os.remove(viz_params["mlpipeline_ui_metadata"])
|
||||
viz_params["confusion_matrix_dict"] = confusion_matrix_params
|
||||
with mock.patch.object(Executor, "_upload_confusion_matrix_to_minio"):
|
||||
output_dict = generate_visualization(viz_params)
|
||||
|
||||
assert output_dict is not None
|
||||
assert "mlpipeline_ui_metadata" in output_dict
|
||||
assert os.path.exists(output_dict["mlpipeline_ui_metadata"])
|
||||
with open(output_dict["mlpipeline_ui_metadata"]) as file:
|
||||
data = file.read()
|
||||
assert "confusion_matrix" in data
|
||||
Loading…
Reference in New Issue