feat(frontend): Display scalar metrics table (#8026)

* Update the compare table to include parent column labels.

* Show scalar metrics table through rough calculations.

* Remove unnecessary comment.

* Remove unnecessary print statements and simplify code.

* Break up the loading of scalar metrics table data.

* Update display and error when scalar metrics are not available.

* Update compare table comment and error formatting.

* Add tests and update artifact index to global variable.

* Update formatting.

* Move the scalar metrics table calculation into a separate utils file.

* Add compare utils test and reorganize calculation of compare table props.

* Ensure that the table organization works without names.

* Update compare utils and table tests.

* Update formatting and add compare table test and snapshot.

* Rename file from tsx to ts.

* Reset this branch to only hold the CompareTable changes.

* Update formatting.

* Update naming of artifactCount to newArtifactIndex.

* Update new tests to use React Testing Library.

* Minor fixes - newline and remove unused variable.

* Update formatting.

* Add back the changes to incorporate the CompareTable into the CompareV2 page.

* Remove metrics tab text from the MetricsDropdown.

* Fix the value type.

* Change getExecutionName to getExecutionDisplayName.

* Move imports to CompareUtils and to the top of the file.

* Add return type to getCompareTableProps.

* Update return type to defined interface.

* Extract the data map key into a shared function.

* Update the dataMap to more quickly fill the scalar metric table rows.

* Update the sorting mechanism and remove the use of lodash.

* Fix test.
Chris Elliott, 2022-07-22 12:11:48 -07:00, committed by GitHub
parent 061905b6df
commit d3fe514db6
7 changed files with 446 additions and 42 deletions
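For orientation before the per-file diffs: the heart of this change is the new getCompareTableProps helper in src/lib/v2/CompareUtils, whose output drives the existing CompareTable component. A minimal TypeScript sketch of the intended call pattern, assuming the two inputs come from filterRunArtifactsByType as shown in the CompareV2.tsx diff below:

import { getCompareTableProps, RunArtifact } from 'src/lib/v2/CompareUtils';
import { CompareTableProps } from 'src/components/CompareTable';

// Both values come from filterRunArtifactsByType(runArtifacts, artifactTypes,
// MetricsType.SCALAR_METRICS), which now returns { runArtifacts, artifactCount }.
declare const scalarMetricsArtifacts: RunArtifact[];
declare const artifactCount: number;

const props: CompareTableProps = getCompareTableProps(scalarMetricsArtifacts, artifactCount);
// props.xLabels: one 'execution > artifact' label per column.
// props.xParentLabels: one { label, colSpan } group header per run.
// props.yLabels: one scalar metric name per row.
// props.rows: rows[i][j] is the stringified value of metric yLabels[i] for
// column xLabels[j]; an empty string means that artifact did not report it.

An empty yLabels array signals that no scalar metrics exist, which CompareV2 uses to show the "There are no Scalar Metrics artifacts available on the selected runs." fallback instead of the table.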

View File

@@ -21,10 +21,11 @@ import TestUtils, { testBestPractices } from 'src/TestUtils';
import { Artifact, Event, Execution, Value } from 'src/third_party/mlmd';
import * as metricsVisualizations from 'src/components/viewers/MetricsVisualizations';
import * as Utils from 'src/lib/Utils';
import { MetricsType, RunArtifact, SelectedArtifact } from 'src/pages/CompareV2';
import { SelectedArtifact } from 'src/pages/CompareV2';
import { LinkedArtifact } from 'src/mlmd/MlmdUtils';
import * as jspb from 'google-protobuf';
import MetricsDropdown from './MetricsDropdown';
import { MetricsType, RunArtifact } from 'src/lib/v2/CompareUtils';
function newMockExecution(id: number, displayName?: string): Execution {
const execution = new Execution();
@@ -343,7 +344,7 @@ describe('MetricsDropdown', () => {
});
});
it('HTML file loading and error display', async () => {
it('HTML file loading and error display with namespace input', async () => {
const getHtmlViewerConfigSpy = jest.spyOn(metricsVisualizations, 'getHtmlViewerConfig');
getHtmlViewerConfigSpy.mockRejectedValue(new Error('HTML file not found.'));
@@ -354,6 +355,7 @@ describe('MetricsDropdown', () => {
metricsTab={MetricsType.HTML}
selectedArtifacts={emptySelectedArtifacts}
updateSelectedArtifacts={updateSelectedArtifactsSpy}
namespace='namespaceInput'
/>
</CommonTestWrapper>,
);
@@ -365,6 +367,10 @@ describe('MetricsDropdown', () => {
screen.getByRole('circularprogress');
await waitFor(() => {
expect(getHtmlViewerConfigSpy).toHaveBeenLastCalledWith(
[firstLinkedArtifact],
'namespaceInput',
);
screen.getByText('Error: failed loading HTML file. Click Details for more information.');
});
});
@@ -402,4 +408,6 @@ describe('MetricsDropdown', () => {
screen.getByText('Choose a first Confusion Matrix artifact');
screen.getByTitle('run1 > execution1 > artifact1');
});
// TODO: Namespace...
});

View File

@@ -32,11 +32,16 @@ import PlotCard from 'src/components/PlotCard';
import { ViewerConfig } from 'src/components/viewers/Viewer';
import CircularProgress from '@material-ui/core/CircularProgress';
import Banner from 'src/components/Banner';
import { ExecutionArtifact, MetricsType, RunArtifact, SelectedArtifact } from 'src/pages/CompareV2';
import { SelectedArtifact } from 'src/pages/CompareV2';
import { useQuery } from 'react-query';
import { errorToMessage, logger } from 'src/lib/Utils';
import { Execution } from 'src/third_party/mlmd';
import { metricsTypeToString } from 'src/lib/v2/CompareUtils';
import {
metricsTypeToString,
ExecutionArtifact,
MetricsType,
RunArtifact,
} from 'src/lib/v2/CompareUtils';
const css = stylesheet({
leftCell: {
@@ -79,7 +84,7 @@ interface MetricsDropdownProps {
metricsTab: MetricsType;
selectedArtifacts: SelectedArtifact[];
updateSelectedArtifacts: (selectedArtifacts: SelectedArtifact[]) => void;
namespace: string | undefined;
namespace?: string;
}
export default function MetricsDropdown(props: MetricsDropdownProps) {
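One nuance in this file: changing the prop from namespace: string | undefined to namespace?: string means call sites may now omit the prop entirely, rather than having to pass undefined explicitly. A small self-contained TypeScript illustration (the names here are hypothetical, not from this codebase):

interface OldProps { namespace: string | undefined; } // key must be present
interface NewProps { namespace?: string; }            // key may be omitted

declare function renderOld(props: OldProps): void;
declare function renderNew(props: NewProps): void;

renderNew({});                       // OK: optional property omitted
renderOld({ namespace: undefined }); // OK: explicitly undefined
// renderOld({});                    // type error: property 'namespace' is missing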

View File

@@ -0,0 +1,201 @@
/*
* Copyright 2022 The Kubeflow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { testBestPractices } from 'src/TestUtils';
import { getCompareTableProps, RunArtifact } from './CompareUtils';
import { Artifact, Event, Execution, Value } from 'src/third_party/mlmd';
import { LinkedArtifact } from 'src/mlmd/MlmdUtils';
import * as jspb from 'google-protobuf';
function newMockExecution(id: number, displayName?: string): Execution {
const execution = new Execution();
execution.setId(id);
if (displayName) {
const customPropertiesMap: Map<string, Value> = new Map();
const displayNameValue = new Value();
displayNameValue.setStringValue(displayName);
customPropertiesMap.set('display_name', displayNameValue);
jest.spyOn(execution, 'getCustomPropertiesMap').mockReturnValue(customPropertiesMap);
}
return execution;
}
function newMockEvent(id: number, displayName?: string): Event {
const event = new Event();
event.setArtifactId(id);
event.setExecutionId(id);
event.setType(Event.Type.OUTPUT);
if (displayName) {
const path = new Event.Path();
const step = new Event.Path.Step();
step.setKey(displayName);
path.addSteps(step);
event.setPath(path);
}
return event;
}
function newMockArtifact(id: number, scalarMetricValues: number[], displayName?: string): Artifact {
const artifact = new Artifact();
artifact.setId(id);
const customPropertiesMap: jspb.Map<string, Value> = jspb.Map.fromObject([], null, null);
if (displayName) {
const displayNameValue = new Value();
displayNameValue.setStringValue(displayName);
customPropertiesMap.set('display_name', displayNameValue);
}
scalarMetricValues.forEach((scalarMetricValue, index) => {
const value = new Value();
value.setDoubleValue(scalarMetricValue);
customPropertiesMap.set(`scalarMetric${index}`, value);
});
jest.spyOn(artifact, 'getCustomPropertiesMap').mockReturnValue(customPropertiesMap);
return artifact;
}
function newMockLinkedArtifact(
id: number,
scalarMetricValues: number[],
displayName?: string,
): LinkedArtifact {
return {
artifact: newMockArtifact(id, scalarMetricValues, displayName),
event: newMockEvent(id, displayName),
} as LinkedArtifact;
}
testBestPractices();
describe('CompareUtils', () => {
it('Empty scalar metrics artifacts results in empty table data', () => {
expect(getCompareTableProps([], 0)).toMatchObject({
xLabels: [],
yLabels: [],
xParentLabels: [],
rows: [],
});
});
it('Scalar metrics artifacts with all data and names populated', () => {
const scalarMetricsArtifacts: RunArtifact[] = [
{
run: {
run: {
id: '1',
name: 'run1',
},
},
executionArtifacts: [
{
execution: newMockExecution(1, 'execution1'),
linkedArtifacts: [
newMockLinkedArtifact(1, [1, 2], 'artifact1'),
newMockLinkedArtifact(2, [1], 'artifact2'),
],
},
{
execution: newMockExecution(2, 'execution2'),
linkedArtifacts: [newMockLinkedArtifact(3, [3], 'artifact3')],
},
],
},
{
run: {
run: {
id: '2',
name: 'run2',
},
},
executionArtifacts: [
{
execution: newMockExecution(3, 'execution1'),
linkedArtifacts: [newMockLinkedArtifact(4, [4], 'artifact1')],
},
],
},
];
const artifactCount: number = 4;
expect(getCompareTableProps(scalarMetricsArtifacts, artifactCount)).toMatchObject({
xLabels: [
'execution1 > artifact1',
'execution1 > artifact2',
'execution2 > artifact3',
'execution1 > artifact1',
],
yLabels: ['scalarMetric0', 'scalarMetric1'],
xParentLabels: [
{ colSpan: 3, label: 'run1' },
{ colSpan: 1, label: 'run2' },
],
rows: [
['1', '1', '3', '4'],
['2', '', '', ''],
],
});
});
it('Scalar metrics artifacts with data populated and no names', () => {
const scalarMetricsArtifacts: RunArtifact[] = [
{
run: {
run: {
id: '1',
},
},
executionArtifacts: [
{
execution: newMockExecution(1),
linkedArtifacts: [newMockLinkedArtifact(1, [1, 2]), newMockLinkedArtifact(2, [1])],
},
{
execution: newMockExecution(2),
linkedArtifacts: [newMockLinkedArtifact(3, [3])],
},
],
},
{
run: {
run: {
id: '2',
},
},
executionArtifacts: [
{
execution: newMockExecution(3),
linkedArtifacts: [newMockLinkedArtifact(4, [4])],
},
],
},
];
const artifactCount: number = 4;
expect(getCompareTableProps(scalarMetricsArtifacts, artifactCount)).toMatchObject({
xLabels: ['- > -', '- > -', '- > -', '- > -'],
yLabels: ['scalarMetric0', 'scalarMetric1'],
xParentLabels: [
{ colSpan: 3, label: '-' },
{ colSpan: 1, label: '-' },
],
rows: [
['1', '1', '3', '4'],
['2', '', '', ''],
],
});
});
});

View File

@@ -14,7 +14,156 @@
* limitations under the License.
*/
import { MetricsType } from 'src/pages/CompareV2';
import { CompareTableProps, xParentLabel } from 'src/components/CompareTable';
import { getArtifactName, getExecutionDisplayName, LinkedArtifact } from 'src/mlmd/MlmdUtils';
import { getMetadataValue } from 'src/mlmd/Utils';
import { Execution, Value } from 'src/third_party/mlmd';
import * as jspb from 'google-protobuf';
import { ApiRunDetail } from 'src/apis/run';
export interface ExecutionArtifact {
execution: Execution;
linkedArtifacts: LinkedArtifact[];
}
export interface RunArtifact {
run: ApiRunDetail;
executionArtifacts: ExecutionArtifact[];
}
interface ScalarRowData {
row: string[];
dataCount: number;
}
interface ScalarTableData {
xLabels: string[];
xParentLabels: xParentLabel[];
dataMap: { [key: string]: ScalarRowData };
}
export interface RunArtifactData {
runArtifacts: RunArtifact[];
artifactCount: number;
}
export const getCompareTableProps = (
scalarMetricsArtifacts: RunArtifact[],
artifactCount: number,
): CompareTableProps => {
const scalarTableData = getScalarTableData(scalarMetricsArtifacts, artifactCount);
// Sort by decreasing data item count.
const sortedDataList = Object.entries(scalarTableData.dataMap).sort(
(a, b) => b[1].dataCount - a[1].dataCount,
);
const yLabels: string[] = [];
const rows: string[][] = [];
for (const sortedDataItem of sortedDataList) {
yLabels.push(sortedDataItem[0]);
rows.push(sortedDataItem[1].row);
}
return {
xLabels: scalarTableData.xLabels,
yLabels,
xParentLabels: scalarTableData.xParentLabels,
rows,
} as CompareTableProps;
};
// Get different components needed to construct the scalar metrics table.
const getScalarTableData = (
scalarMetricsArtifacts: RunArtifact[],
artifactCount: number,
): ScalarTableData => {
const xLabels: string[] = [];
const xParentLabels: xParentLabel[] = [];
const dataMap: { [key: string]: ScalarRowData } = {};
let artifactIndex = 0;
for (const runArtifact of scalarMetricsArtifacts) {
const runName = runArtifact.run.run?.name || '-';
const newArtifactIndex = loadScalarExecutionArtifacts(
runArtifact.executionArtifacts,
xLabels,
dataMap,
artifactIndex,
artifactCount,
);
const xParentLabel: xParentLabel = {
label: runName,
colSpan: newArtifactIndex - artifactIndex,
};
xParentLabels.push(xParentLabel);
artifactIndex = newArtifactIndex;
}
return {
xLabels,
xParentLabels,
dataMap,
} as ScalarTableData;
};
// Load the data as well as row and column labels from execution artifacts.
const loadScalarExecutionArtifacts = (
executionArtifacts: ExecutionArtifact[],
xLabels: string[],
dataMap: { [key: string]: ScalarRowData },
artifactIndex: number,
artifactCount: number,
): number => {
for (const executionArtifact of executionArtifacts) {
const executionText: string = getExecutionDisplayName(executionArtifact.execution) || '-';
for (const linkedArtifact of executionArtifact.linkedArtifacts) {
const linkedArtifactText: string = getArtifactName(linkedArtifact) || '-';
const xLabel = `${executionText} > ${linkedArtifactText}`;
xLabels.push(xLabel);
const customProperties = linkedArtifact.artifact.getCustomPropertiesMap();
addScalarDataItems(customProperties, dataMap, artifactIndex, artifactCount);
artifactIndex++;
}
}
return artifactIndex;
};
// Add the scalar metric names and data items.
const addScalarDataItems = (
customProperties: jspb.Map<string, Value>,
dataMap: { [key: string]: ScalarRowData },
artifactIndex: number,
artifactCount: number,
) => {
for (const entry of customProperties.getEntryList()) {
const scalarMetricName: string = entry[0];
if (scalarMetricName === 'display_name') {
continue;
}
if (!dataMap[scalarMetricName]) {
dataMap[scalarMetricName] = {
row: Array(artifactCount).fill(''),
dataCount: 0,
};
}
dataMap[scalarMetricName].row[artifactIndex] = JSON.stringify(
getMetadataValue(customProperties.get(scalarMetricName)),
);
dataMap[scalarMetricName].dataCount++;
}
};
export enum MetricsType {
SCALAR_METRICS,
CONFUSION_MATRIX,
ROC_CURVE,
HTML,
MARKDOWN,
}
export const metricsTypeToString = (metricsType: MetricsType): string => {
switch (metricsType) {
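To make the row ordering concrete: in the first CompareUtils test above, scalarMetric0 is reported by all four artifacts (dataCount 4) while scalarMetric1 appears only on the first (dataCount 1), so the descending sort yields yLabels ['scalarMetric0', 'scalarMetric1']. A standalone sketch of that lodash-free sort, with the dataMap hard-coded to the values from that test:

// Each entry tracks one metric's row of cell values plus how many cells are filled.
const dataMap: { [key: string]: { row: string[]; dataCount: number } } = {
  scalarMetric1: { row: ['2', '', '', ''], dataCount: 1 },
  scalarMetric0: { row: ['1', '1', '3', '4'], dataCount: 4 },
};

// Metrics with more populated cells float to the top of the table.
const sortedDataList = Object.entries(dataMap).sort((a, b) => b[1].dataCount - a[1].dataCount);

console.log(sortedDataList.map(([name]) => name)); // ['scalarMetric0', 'scalarMetric1']
console.log(sortedDataList.map(([, data]) => data.row)); // [['1','1','3','4'], ['2','','','']]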

View File

@@ -377,6 +377,13 @@ export function filterLinkedArtifactsByType(
return artifacts.filter(x => artifactTypeIds.includes(x.artifact.getTypeId()));
}
export function getExecutionDisplayName(execution: Execution): string | undefined {
return execution
.getCustomPropertiesMap()
.get('display_name')
?.getStringValue();
}
export function getArtifactName(linkedArtifact: LinkedArtifact): string | undefined {
return getArtifactNameFromEvent(linkedArtifact.event);
}

View File

@@ -423,20 +423,25 @@ describe('CompareV2', () => {
);
await TestUtils.flushPromises();
// TODO(zpChris): This test will be improved after default error states are provided in #8029.
screen.getByText('This is the Scalar Metrics tab.');
screen.getByText('There are no Scalar Metrics artifacts available on the selected runs.');
fireEvent.click(screen.getByText('ROC Curve'));
expect(screen.queryByText('This is the Scalar Metrics Tab')).toBeNull();
fireEvent.click(screen.getByText('Confusion Matrix'));
screen.getByText('There are no Confusion Matrix artifacts available on the selected runs.');
expect(
screen.queryByText('There are no Scalar Metrics artifacts available on the selected runs.'),
).toBeNull();
fireEvent.click(screen.getByText('ROC Curve'));
expect(screen.queryByText('This is the Scalar Metrics Tab')).toBeNull();
fireEvent.click(screen.getByText('Confusion Matrix'));
screen.getByText('There are no Confusion Matrix artifacts available on the selected runs.');
fireEvent.click(screen.getByText('Scalar Metrics'));
screen.getByText('This is the Scalar Metrics tab.');
screen.getByText('There are no Scalar Metrics artifacts available on the selected runs.');
expect(
screen.queryByText('There are no Confusion Matrix artifacts available on the selected runs.'),
).toBeNull();
});
it('Two-panel tabs have no dropdown loaded as content is not present', async () => {
it('Metrics tabs have no content loaded as artifacts are not present', async () => {
const getRunSpy = jest.spyOn(Apis.runServiceApi, 'getRun');
runs = [newMockRun(MOCK_RUN_1_ID), newMockRun(MOCK_RUN_2_ID), newMockRun(MOCK_RUN_3_ID)];
getRunSpy.mockImplementation((id: string) => runs.find(r => r.run!.id === id));
@@ -454,6 +459,8 @@ describe('CompareV2', () => {
);
await TestUtils.flushPromises();
screen.getByText('There are no Scalar Metrics artifacts available on the selected runs.');
fireEvent.click(screen.getByText('Confusion Matrix'));
screen.getByText('There are no Confusion Matrix artifacts available on the selected runs.');
@@ -518,7 +525,7 @@ describe('CompareV2', () => {
);
await TestUtils.flushPromises();
await waitFor(() => expect(filterLinkedArtifactsByTypeSpy).toHaveBeenCalledTimes(12));
await waitFor(() => expect(filterLinkedArtifactsByTypeSpy).toHaveBeenCalledTimes(15));
expect(screen.queryByText(/Confusion matrix: artifactName/)).toBeNull();

View File

@@ -42,6 +42,14 @@ import RunList from './RunList';
import { METRICS_SECTION_NAME, OVERVIEW_SECTION_NAME, PARAMS_SECTION_NAME } from './Compare';
import { SelectedItem } from 'src/components/TwoLevelDropdown';
import MD2Tabs from 'src/atoms/MD2Tabs';
import CompareTable, { CompareTableProps } from 'src/components/CompareTable';
import {
ExecutionArtifact,
getCompareTableProps,
MetricsType,
RunArtifact,
RunArtifactData,
} from 'src/lib/v2/CompareUtils';
import { ConfidenceMetricsSection } from 'src/components/viewers/MetricsVisualizations';
import { flatMapDeep } from 'lodash';
import { NamespaceContext, useNamespaceChangeEvent } from 'src/lib/KubeflowClient';
@@ -63,24 +71,6 @@ interface MlmdPackage {
events: Event[];
}
export interface ExecutionArtifact {
execution: Execution;
linkedArtifacts: LinkedArtifact[];
}
export interface RunArtifact {
run: ApiRunDetail;
executionArtifacts: ExecutionArtifact[];
}
export enum MetricsType {
SCALAR_METRICS,
CONFUSION_MATRIX,
ROC_CURVE,
HTML,
MARKDOWN,
}
const metricsTypeToFilter = (metricsType: MetricsType): string => {
switch (metricsType) {
case MetricsType.SCALAR_METRICS:
@@ -103,9 +93,10 @@ function filterRunArtifactsByType(
runArtifacts: RunArtifact[],
artifactTypes: ArtifactType[],
metricsType: MetricsType,
): RunArtifact[] {
): RunArtifactData {
const metricsFilter = metricsTypeToFilter(metricsType);
const typeRuns: RunArtifact[] = [];
let artifactCount: number = 0;
for (const runArtifact of runArtifacts) {
const typeExecutions: ExecutionArtifact[] = [];
for (const e of runArtifact.executionArtifacts) {
@@ -124,6 +115,7 @@
);
}
if (typeArtifacts.length > 0) {
artifactCount += typeArtifacts.length;
typeExecutions.push({
execution: e.execution,
linkedArtifacts: typeArtifacts,
@@ -137,7 +129,10 @@
} as RunArtifact);
}
}
return typeRuns;
return {
runArtifacts: typeRuns,
artifactCount,
};
}
function getRunArtifacts(runs: ApiRunDetail[], mlmdPackages: MlmdPackage[]): RunArtifact[] {
@@ -202,6 +197,12 @@ function CompareV2(props: CompareV2Props) {
const [rocCurveArtifacts, setRocCurveArtifacts] = useState<Artifact[]>([]);
const [selectedRocCurveArtifacts, setSelectedRocCurveArtifacts] = useState<Artifact[]>([]);
const [scalarMetricsArtifacts, setScalarMetricsArtifacts] = useState<RunArtifact[]>([]);
const [scalarMetricsArtifactCount, setScalarMetricsArtifactCount] = useState<number>(0);
const [scalarMetricsTableData, setScalarMetricsTableData] = useState<
CompareTableProps | undefined
>(undefined);
// Selected artifacts for two-panel layout.
const createSelectedArtifactArray = (count: number): SelectedArtifact[] => {
const array: SelectedArtifact[] = [];
@@ -279,19 +280,29 @@
useEffect(() => {
if (runs && mlmdPackages && artifactTypes) {
const runArtifacts: RunArtifact[] = getRunArtifacts(runs, mlmdPackages);
setConfusionMatrixRunArtifacts(
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.CONFUSION_MATRIX),
const scalarMetricsArtifactData = filterRunArtifactsByType(
runArtifacts,
artifactTypes,
MetricsType.SCALAR_METRICS,
);
setScalarMetricsArtifacts(scalarMetricsArtifactData.runArtifacts);
setScalarMetricsArtifactCount(scalarMetricsArtifactData.artifactCount);
setConfusionMatrixRunArtifacts(
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.CONFUSION_MATRIX)
.runArtifacts,
);
setHtmlRunArtifacts(
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.HTML).runArtifacts,
);
setHtmlRunArtifacts(filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.HTML));
setMarkdownRunArtifacts(
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.MARKDOWN),
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.MARKDOWN).runArtifacts,
);
const rocCurveRunArtifacts: RunArtifact[] = filterRunArtifactsByType(
runArtifacts,
artifactTypes,
MetricsType.ROC_CURVE,
);
).runArtifacts;
const rocCurveArtifacts: Artifact[] = flatMapDeep(
rocCurveRunArtifacts.map(rocCurveArtifact =>
rocCurveArtifact.executionArtifacts.map(executionArtifact =>
@@ -394,6 +405,18 @@ function CompareV2(props: CompareV2Props) {
setSelectedIds(selectedIds);
};
useEffect(() => {
const compareTableProps: CompareTableProps = getCompareTableProps(
scalarMetricsArtifacts,
scalarMetricsArtifactCount,
);
if (compareTableProps.yLabels.length === 0) {
setScalarMetricsTableData(undefined);
} else {
setScalarMetricsTableData(compareTableProps);
}
}, [scalarMetricsArtifacts, scalarMetricsArtifactCount]);
const updateSelectedArtifacts = (newArtifacts: SelectedArtifact[]) => {
selectedArtifactsMap[metricsTab] = newArtifacts;
setSelectedArtifactsMap(selectedArtifactsMap);
@@ -452,8 +475,12 @@
onSwitch={setMetricsTab}
/>
<div className={classes(padding(20, 'lrt'), css.outputsOverflow)}>
{/* TODO(zpChris): Add the scalar metrics table. */}
{metricsTab === MetricsType.SCALAR_METRICS && <p>This is the Scalar Metrics tab.</p>}
{metricsTab === MetricsType.SCALAR_METRICS &&
(scalarMetricsTableData ? (
<CompareTable {...scalarMetricsTableData} />
) : (
<p>There are no Scalar Metrics artifacts available on the selected runs.</p>
))}
{metricsTab === MetricsType.CONFUSION_MATRIX && (
<MetricsDropdown
filteredRunArtifacts={confusionMatrixRunArtifacts}