feat(frontend): Display scalar metrics table (#8026)
* Update the compare table to include parent column labels. * Show scalar metrics tablethrough rough calculations. * Remove unnecessary comment. * Remove unnecesssary print statements and simplify code. * Break up the loading of scalar metrics table data. * Update display and error when scalar metrics are not available. * Update compare table comment and error formatting. * Add tests and update artifact index to global variable. * Update formatting. * Separate scalar metrics table calculation into a separate utils file. * Add compare utils test and reorganize calculation of compare table props. * Ensure that the table organization works without names. * Update compare utils and table tests. * Update formatting and add compare table test and snapshot. * Rename file from tsx to ts. * Reset this branch to only hold the CompareTable changes. * Update formatting. * Update naming of artifactCount to newArtifactIndex. * Update new tests to use React Testing Library. * Minor fixes - newline and remove unused variable. * Update formatting. * Add back the changes to incorporate the CompareTable into the CompareV2 page. * Remove metrics tab text from the MetricsDropdown. * Fix the value type. * Change getExecutionName to getExecutionDisplayName. * Move imports to CompareUtils and top-of-file. * Add return type to getCompareTableProps. * Update return type to defined interface. * UExtract the data map key to a shared function. * Update the dataMap to more quickly fill the scalar metric table rows. * Update the sorting mechanism and remove the use of lodash. * Fix test.
This commit is contained in:
parent
061905b6df
commit
d3fe514db6
|
|
@ -21,10 +21,11 @@ import TestUtils, { testBestPractices } from 'src/TestUtils';
|
|||
import { Artifact, Event, Execution, Value } from 'src/third_party/mlmd';
|
||||
import * as metricsVisualizations from 'src/components/viewers/MetricsVisualizations';
|
||||
import * as Utils from 'src/lib/Utils';
|
||||
import { MetricsType, RunArtifact, SelectedArtifact } from 'src/pages/CompareV2';
|
||||
import { SelectedArtifact } from 'src/pages/CompareV2';
|
||||
import { LinkedArtifact } from 'src/mlmd/MlmdUtils';
|
||||
import * as jspb from 'google-protobuf';
|
||||
import MetricsDropdown from './MetricsDropdown';
|
||||
import { MetricsType, RunArtifact } from 'src/lib/v2/CompareUtils';
|
||||
|
||||
function newMockExecution(id: number, displayName?: string): Execution {
|
||||
const execution = new Execution();
|
||||
|
|
@ -343,7 +344,7 @@ describe('MetricsDropdown', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('HTML file loading and error display', async () => {
|
||||
it('HTML file loading and error display with namespace input', async () => {
|
||||
const getHtmlViewerConfigSpy = jest.spyOn(metricsVisualizations, 'getHtmlViewerConfig');
|
||||
getHtmlViewerConfigSpy.mockRejectedValue(new Error('HTML file not found.'));
|
||||
|
||||
|
|
@ -354,6 +355,7 @@ describe('MetricsDropdown', () => {
|
|||
metricsTab={MetricsType.HTML}
|
||||
selectedArtifacts={emptySelectedArtifacts}
|
||||
updateSelectedArtifacts={updateSelectedArtifactsSpy}
|
||||
namespace='namespaceInput'
|
||||
/>
|
||||
</CommonTestWrapper>,
|
||||
);
|
||||
|
|
@ -365,6 +367,10 @@ describe('MetricsDropdown', () => {
|
|||
|
||||
screen.getByRole('circularprogress');
|
||||
await waitFor(() => {
|
||||
expect(getHtmlViewerConfigSpy).toHaveBeenLastCalledWith(
|
||||
[firstLinkedArtifact],
|
||||
'namespaceInput',
|
||||
);
|
||||
screen.getByText('Error: failed loading HTML file. Click Details for more information.');
|
||||
});
|
||||
});
|
||||
|
|
@ -402,4 +408,6 @@ describe('MetricsDropdown', () => {
|
|||
screen.getByText('Choose a first Confusion Matrix artifact');
|
||||
screen.getByTitle('run1 > execution1 > artifact1');
|
||||
});
|
||||
|
||||
// TODO: Namespace...
|
||||
});
|
||||
|
|
|
|||
|
|
@ -32,11 +32,16 @@ import PlotCard from 'src/components/PlotCard';
|
|||
import { ViewerConfig } from 'src/components/viewers/Viewer';
|
||||
import CircularProgress from '@material-ui/core/CircularProgress';
|
||||
import Banner from 'src/components/Banner';
|
||||
import { ExecutionArtifact, MetricsType, RunArtifact, SelectedArtifact } from 'src/pages/CompareV2';
|
||||
import { SelectedArtifact } from 'src/pages/CompareV2';
|
||||
import { useQuery } from 'react-query';
|
||||
import { errorToMessage, logger } from 'src/lib/Utils';
|
||||
import { Execution } from 'src/third_party/mlmd';
|
||||
import { metricsTypeToString } from 'src/lib/v2/CompareUtils';
|
||||
import {
|
||||
metricsTypeToString,
|
||||
ExecutionArtifact,
|
||||
MetricsType,
|
||||
RunArtifact,
|
||||
} from 'src/lib/v2/CompareUtils';
|
||||
|
||||
const css = stylesheet({
|
||||
leftCell: {
|
||||
|
|
@ -79,7 +84,7 @@ interface MetricsDropdownProps {
|
|||
metricsTab: MetricsType;
|
||||
selectedArtifacts: SelectedArtifact[];
|
||||
updateSelectedArtifacts: (selectedArtifacts: SelectedArtifact[]) => void;
|
||||
namespace: string | undefined;
|
||||
namespace?: string;
|
||||
}
|
||||
|
||||
export default function MetricsDropdown(props: MetricsDropdownProps) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,201 @@
|
|||
/*
|
||||
* Copyright 2022 The Kubeflow Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import { testBestPractices } from 'src/TestUtils';
|
||||
import { getCompareTableProps, RunArtifact } from './CompareUtils';
|
||||
import { Artifact, Event, Execution, Value } from 'src/third_party/mlmd';
|
||||
import { LinkedArtifact } from 'src/mlmd/MlmdUtils';
|
||||
import * as jspb from 'google-protobuf';
|
||||
|
||||
function newMockExecution(id: number, displayName?: string): Execution {
|
||||
const execution = new Execution();
|
||||
execution.setId(id);
|
||||
if (displayName) {
|
||||
const customPropertiesMap: Map<string, Value> = new Map();
|
||||
const displayNameValue = new Value();
|
||||
displayNameValue.setStringValue(displayName);
|
||||
customPropertiesMap.set('display_name', displayNameValue);
|
||||
jest.spyOn(execution, 'getCustomPropertiesMap').mockReturnValue(customPropertiesMap);
|
||||
}
|
||||
return execution;
|
||||
}
|
||||
|
||||
function newMockEvent(id: number, displayName?: string): Event {
|
||||
const event = new Event();
|
||||
event.setArtifactId(id);
|
||||
event.setExecutionId(id);
|
||||
event.setType(Event.Type.OUTPUT);
|
||||
if (displayName) {
|
||||
const path = new Event.Path();
|
||||
const step = new Event.Path.Step();
|
||||
step.setKey(displayName);
|
||||
path.addSteps(step);
|
||||
event.setPath(path);
|
||||
}
|
||||
return event;
|
||||
}
|
||||
|
||||
function newMockArtifact(id: number, scalarMetricValues: number[], displayName?: string): Artifact {
|
||||
const artifact = new Artifact();
|
||||
artifact.setId(id);
|
||||
|
||||
const customPropertiesMap: jspb.Map<string, Value> = jspb.Map.fromObject([], null, null);
|
||||
if (displayName) {
|
||||
const displayNameValue = new Value();
|
||||
displayNameValue.setStringValue(displayName);
|
||||
customPropertiesMap.set('display_name', displayNameValue);
|
||||
}
|
||||
|
||||
scalarMetricValues.forEach((scalarMetricValue, index) => {
|
||||
const value = new Value();
|
||||
value.setDoubleValue(scalarMetricValue);
|
||||
customPropertiesMap.set(`scalarMetric${index}`, value);
|
||||
});
|
||||
|
||||
jest.spyOn(artifact, 'getCustomPropertiesMap').mockReturnValue(customPropertiesMap);
|
||||
return artifact;
|
||||
}
|
||||
|
||||
function newMockLinkedArtifact(
|
||||
id: number,
|
||||
scalarMetricValues: number[],
|
||||
displayName?: string,
|
||||
): LinkedArtifact {
|
||||
return {
|
||||
artifact: newMockArtifact(id, scalarMetricValues, displayName),
|
||||
event: newMockEvent(id, displayName),
|
||||
} as LinkedArtifact;
|
||||
}
|
||||
|
||||
testBestPractices();
|
||||
describe('CompareUtils', () => {
|
||||
it('Empty scalar metrics artifacts results in empty table data', () => {
|
||||
expect(getCompareTableProps([], 0)).toMatchObject({
|
||||
xLabels: [],
|
||||
yLabels: [],
|
||||
xParentLabels: [],
|
||||
rows: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('Scalar metrics artifacts with all data and names populated', () => {
|
||||
const scalarMetricsArtifacts: RunArtifact[] = [
|
||||
{
|
||||
run: {
|
||||
run: {
|
||||
id: '1',
|
||||
name: 'run1',
|
||||
},
|
||||
},
|
||||
executionArtifacts: [
|
||||
{
|
||||
execution: newMockExecution(1, 'execution1'),
|
||||
linkedArtifacts: [
|
||||
newMockLinkedArtifact(1, [1, 2], 'artifact1'),
|
||||
newMockLinkedArtifact(2, [1], 'artifact2'),
|
||||
],
|
||||
},
|
||||
{
|
||||
execution: newMockExecution(2, 'execution2'),
|
||||
linkedArtifacts: [newMockLinkedArtifact(3, [3], 'artifact3')],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
run: {
|
||||
run: {
|
||||
id: '2',
|
||||
name: 'run2',
|
||||
},
|
||||
},
|
||||
executionArtifacts: [
|
||||
{
|
||||
execution: newMockExecution(3, 'execution1'),
|
||||
linkedArtifacts: [newMockLinkedArtifact(4, [4], 'artifact1')],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
const artifactCount: number = 4;
|
||||
|
||||
expect(getCompareTableProps(scalarMetricsArtifacts, artifactCount)).toMatchObject({
|
||||
xLabels: [
|
||||
'execution1 > artifact1',
|
||||
'execution1 > artifact2',
|
||||
'execution2 > artifact3',
|
||||
'execution1 > artifact1',
|
||||
],
|
||||
yLabels: ['scalarMetric0', 'scalarMetric1'],
|
||||
xParentLabels: [
|
||||
{ colSpan: 3, label: 'run1' },
|
||||
{ colSpan: 1, label: 'run2' },
|
||||
],
|
||||
rows: [
|
||||
['1', '1', '3', '4'],
|
||||
['2', '', '', ''],
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('Scalar metrics artifacts with data populated and no names', () => {
|
||||
const scalarMetricsArtifacts: RunArtifact[] = [
|
||||
{
|
||||
run: {
|
||||
run: {
|
||||
id: '1',
|
||||
},
|
||||
},
|
||||
executionArtifacts: [
|
||||
{
|
||||
execution: newMockExecution(1),
|
||||
linkedArtifacts: [newMockLinkedArtifact(1, [1, 2]), newMockLinkedArtifact(2, [1])],
|
||||
},
|
||||
{
|
||||
execution: newMockExecution(2),
|
||||
linkedArtifacts: [newMockLinkedArtifact(3, [3])],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
run: {
|
||||
run: {
|
||||
id: '2',
|
||||
},
|
||||
},
|
||||
executionArtifacts: [
|
||||
{
|
||||
execution: newMockExecution(3),
|
||||
linkedArtifacts: [newMockLinkedArtifact(4, [4])],
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
const artifactCount: number = 4;
|
||||
|
||||
expect(getCompareTableProps(scalarMetricsArtifacts, artifactCount)).toMatchObject({
|
||||
xLabels: ['- > -', '- > -', '- > -', '- > -'],
|
||||
yLabels: ['scalarMetric0', 'scalarMetric1'],
|
||||
xParentLabels: [
|
||||
{ colSpan: 3, label: '-' },
|
||||
{ colSpan: 1, label: '-' },
|
||||
],
|
||||
rows: [
|
||||
['1', '1', '3', '4'],
|
||||
['2', '', '', ''],
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -14,7 +14,156 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import { MetricsType } from 'src/pages/CompareV2';
|
||||
import { CompareTableProps, xParentLabel } from 'src/components/CompareTable';
|
||||
import { getArtifactName, getExecutionDisplayName, LinkedArtifact } from 'src/mlmd/MlmdUtils';
|
||||
import { getMetadataValue } from 'src/mlmd/Utils';
|
||||
import { Execution, Value } from 'src/third_party/mlmd';
|
||||
import * as jspb from 'google-protobuf';
|
||||
import { ApiRunDetail } from 'src/apis/run';
|
||||
|
||||
export interface ExecutionArtifact {
|
||||
execution: Execution;
|
||||
linkedArtifacts: LinkedArtifact[];
|
||||
}
|
||||
|
||||
export interface RunArtifact {
|
||||
run: ApiRunDetail;
|
||||
executionArtifacts: ExecutionArtifact[];
|
||||
}
|
||||
|
||||
interface ScalarRowData {
|
||||
row: string[];
|
||||
dataCount: number;
|
||||
}
|
||||
|
||||
interface ScalarTableData {
|
||||
xLabels: string[];
|
||||
xParentLabels: xParentLabel[];
|
||||
dataMap: { [key: string]: ScalarRowData };
|
||||
}
|
||||
|
||||
export interface RunArtifactData {
|
||||
runArtifacts: RunArtifact[];
|
||||
artifactCount: number;
|
||||
}
|
||||
|
||||
export const getCompareTableProps = (
|
||||
scalarMetricsArtifacts: RunArtifact[],
|
||||
artifactCount: number,
|
||||
): CompareTableProps => {
|
||||
const scalarTableData = getScalarTableData(scalarMetricsArtifacts, artifactCount);
|
||||
|
||||
// Sort by decreasing data item count.
|
||||
const sortedDataList = Object.entries(scalarTableData.dataMap).sort(
|
||||
(a, b) => b[1].dataCount - a[1].dataCount,
|
||||
);
|
||||
const yLabels: string[] = [];
|
||||
const rows: string[][] = [];
|
||||
for (const sortedDataItem of sortedDataList) {
|
||||
yLabels.push(sortedDataItem[0]);
|
||||
rows.push(sortedDataItem[1].row);
|
||||
}
|
||||
return {
|
||||
xLabels: scalarTableData.xLabels,
|
||||
yLabels,
|
||||
xParentLabels: scalarTableData.xParentLabels,
|
||||
rows,
|
||||
} as CompareTableProps;
|
||||
};
|
||||
|
||||
// Get different components needed to construct the scalar metrics table.
|
||||
const getScalarTableData = (
|
||||
scalarMetricsArtifacts: RunArtifact[],
|
||||
artifactCount: number,
|
||||
): ScalarTableData => {
|
||||
const xLabels: string[] = [];
|
||||
const xParentLabels: xParentLabel[] = [];
|
||||
const dataMap: { [key: string]: ScalarRowData } = {};
|
||||
|
||||
let artifactIndex = 0;
|
||||
for (const runArtifact of scalarMetricsArtifacts) {
|
||||
const runName = runArtifact.run.run?.name || '-';
|
||||
|
||||
const newArtifactIndex = loadScalarExecutionArtifacts(
|
||||
runArtifact.executionArtifacts,
|
||||
xLabels,
|
||||
dataMap,
|
||||
artifactIndex,
|
||||
artifactCount,
|
||||
);
|
||||
|
||||
const xParentLabel: xParentLabel = {
|
||||
label: runName,
|
||||
colSpan: newArtifactIndex - artifactIndex,
|
||||
};
|
||||
xParentLabels.push(xParentLabel);
|
||||
artifactIndex = newArtifactIndex;
|
||||
}
|
||||
|
||||
return {
|
||||
xLabels,
|
||||
xParentLabels,
|
||||
dataMap,
|
||||
} as ScalarTableData;
|
||||
};
|
||||
|
||||
// Load the data as well as row and column labels from execution artifacts.
|
||||
const loadScalarExecutionArtifacts = (
|
||||
executionArtifacts: ExecutionArtifact[],
|
||||
xLabels: string[],
|
||||
dataMap: { [key: string]: ScalarRowData },
|
||||
artifactIndex: number,
|
||||
artifactCount: number,
|
||||
): number => {
|
||||
for (const executionArtifact of executionArtifacts) {
|
||||
const executionText: string = getExecutionDisplayName(executionArtifact.execution) || '-';
|
||||
for (const linkedArtifact of executionArtifact.linkedArtifacts) {
|
||||
const linkedArtifactText: string = getArtifactName(linkedArtifact) || '-';
|
||||
const xLabel = `${executionText} > ${linkedArtifactText}`;
|
||||
xLabels.push(xLabel);
|
||||
|
||||
const customProperties = linkedArtifact.artifact.getCustomPropertiesMap();
|
||||
addScalarDataItems(customProperties, dataMap, artifactIndex, artifactCount);
|
||||
artifactIndex++;
|
||||
}
|
||||
}
|
||||
return artifactIndex;
|
||||
};
|
||||
|
||||
// Add the scalar metric names and data items.
|
||||
const addScalarDataItems = (
|
||||
customProperties: jspb.Map<string, Value>,
|
||||
dataMap: { [key: string]: ScalarRowData },
|
||||
artifactIndex: number,
|
||||
artifactCount: number,
|
||||
) => {
|
||||
for (const entry of customProperties.getEntryList()) {
|
||||
const scalarMetricName: string = entry[0];
|
||||
if (scalarMetricName === 'display_name') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!dataMap[scalarMetricName]) {
|
||||
dataMap[scalarMetricName] = {
|
||||
row: Array(artifactCount).fill(''),
|
||||
dataCount: 0,
|
||||
};
|
||||
}
|
||||
|
||||
dataMap[scalarMetricName].row[artifactIndex] = JSON.stringify(
|
||||
getMetadataValue(customProperties.get(scalarMetricName)),
|
||||
);
|
||||
dataMap[scalarMetricName].dataCount++;
|
||||
}
|
||||
};
|
||||
|
||||
export enum MetricsType {
|
||||
SCALAR_METRICS,
|
||||
CONFUSION_MATRIX,
|
||||
ROC_CURVE,
|
||||
HTML,
|
||||
MARKDOWN,
|
||||
}
|
||||
|
||||
export const metricsTypeToString = (metricsType: MetricsType): string => {
|
||||
switch (metricsType) {
|
||||
|
|
|
|||
|
|
@ -377,6 +377,13 @@ export function filterLinkedArtifactsByType(
|
|||
return artifacts.filter(x => artifactTypeIds.includes(x.artifact.getTypeId()));
|
||||
}
|
||||
|
||||
export function getExecutionDisplayName(execution: Execution): string | undefined {
|
||||
return execution
|
||||
.getCustomPropertiesMap()
|
||||
.get('display_name')
|
||||
?.getStringValue();
|
||||
}
|
||||
|
||||
export function getArtifactName(linkedArtifact: LinkedArtifact): string | undefined {
|
||||
return getArtifactNameFromEvent(linkedArtifact.event);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -423,20 +423,25 @@ describe('CompareV2', () => {
|
|||
);
|
||||
await TestUtils.flushPromises();
|
||||
|
||||
// TODO(zpChris): This test will be improved after default error states are provided in #8029.
|
||||
screen.getByText('This is the Scalar Metrics tab.');
|
||||
screen.getByText('There are no Scalar Metrics artifacts available on the selected runs.');
|
||||
|
||||
fireEvent.click(screen.getByText('ROC Curve'));
|
||||
expect(screen.queryByText('This is the Scalar Metrics Tab')).toBeNull();
|
||||
fireEvent.click(screen.getByText('Confusion Matrix'));
|
||||
screen.getByText('There are no Confusion Matrix artifacts available on the selected runs.');
|
||||
expect(
|
||||
screen.queryByText('There are no Scalar Metrics artifacts available on the selected runs.'),
|
||||
).toBeNull();
|
||||
|
||||
fireEvent.click(screen.getByText('ROC Curve'));
|
||||
expect(screen.queryByText('This is the Scalar Metrics Tab')).toBeNull();
|
||||
fireEvent.click(screen.getByText('Confusion Matrix'));
|
||||
screen.getByText('There are no Confusion Matrix artifacts available on the selected runs.');
|
||||
|
||||
fireEvent.click(screen.getByText('Scalar Metrics'));
|
||||
screen.getByText('This is the Scalar Metrics tab.');
|
||||
screen.getByText('There are no Scalar Metrics artifacts available on the selected runs.');
|
||||
expect(
|
||||
screen.queryByText('There are no Confusion Matrix artifacts available on the selected runs.'),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it('Two-panel tabs have no dropdown loaded as content is not present', async () => {
|
||||
it('Metrics tabs have no content loaded as artifacts are not present', async () => {
|
||||
const getRunSpy = jest.spyOn(Apis.runServiceApi, 'getRun');
|
||||
runs = [newMockRun(MOCK_RUN_1_ID), newMockRun(MOCK_RUN_2_ID), newMockRun(MOCK_RUN_3_ID)];
|
||||
getRunSpy.mockImplementation((id: string) => runs.find(r => r.run!.id === id));
|
||||
|
|
@ -454,6 +459,8 @@ describe('CompareV2', () => {
|
|||
);
|
||||
await TestUtils.flushPromises();
|
||||
|
||||
screen.getByText('There are no Scalar Metrics artifacts available on the selected runs.');
|
||||
|
||||
fireEvent.click(screen.getByText('Confusion Matrix'));
|
||||
screen.getByText('There are no Confusion Matrix artifacts available on the selected runs.');
|
||||
|
||||
|
|
@ -518,7 +525,7 @@ describe('CompareV2', () => {
|
|||
);
|
||||
await TestUtils.flushPromises();
|
||||
|
||||
await waitFor(() => expect(filterLinkedArtifactsByTypeSpy).toHaveBeenCalledTimes(12));
|
||||
await waitFor(() => expect(filterLinkedArtifactsByTypeSpy).toHaveBeenCalledTimes(15));
|
||||
|
||||
expect(screen.queryByText(/Confusion matrix: artifactName/)).toBeNull();
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,14 @@ import RunList from './RunList';
|
|||
import { METRICS_SECTION_NAME, OVERVIEW_SECTION_NAME, PARAMS_SECTION_NAME } from './Compare';
|
||||
import { SelectedItem } from 'src/components/TwoLevelDropdown';
|
||||
import MD2Tabs from 'src/atoms/MD2Tabs';
|
||||
import CompareTable, { CompareTableProps } from 'src/components/CompareTable';
|
||||
import {
|
||||
ExecutionArtifact,
|
||||
getCompareTableProps,
|
||||
MetricsType,
|
||||
RunArtifact,
|
||||
RunArtifactData,
|
||||
} from 'src/lib/v2/CompareUtils';
|
||||
import { ConfidenceMetricsSection } from 'src/components/viewers/MetricsVisualizations';
|
||||
import { flatMapDeep } from 'lodash';
|
||||
import { NamespaceContext, useNamespaceChangeEvent } from 'src/lib/KubeflowClient';
|
||||
|
|
@ -63,24 +71,6 @@ interface MlmdPackage {
|
|||
events: Event[];
|
||||
}
|
||||
|
||||
export interface ExecutionArtifact {
|
||||
execution: Execution;
|
||||
linkedArtifacts: LinkedArtifact[];
|
||||
}
|
||||
|
||||
export interface RunArtifact {
|
||||
run: ApiRunDetail;
|
||||
executionArtifacts: ExecutionArtifact[];
|
||||
}
|
||||
|
||||
export enum MetricsType {
|
||||
SCALAR_METRICS,
|
||||
CONFUSION_MATRIX,
|
||||
ROC_CURVE,
|
||||
HTML,
|
||||
MARKDOWN,
|
||||
}
|
||||
|
||||
const metricsTypeToFilter = (metricsType: MetricsType): string => {
|
||||
switch (metricsType) {
|
||||
case MetricsType.SCALAR_METRICS:
|
||||
|
|
@ -103,9 +93,10 @@ function filterRunArtifactsByType(
|
|||
runArtifacts: RunArtifact[],
|
||||
artifactTypes: ArtifactType[],
|
||||
metricsType: MetricsType,
|
||||
): RunArtifact[] {
|
||||
): RunArtifactData {
|
||||
const metricsFilter = metricsTypeToFilter(metricsType);
|
||||
const typeRuns: RunArtifact[] = [];
|
||||
let artifactCount: number = 0;
|
||||
for (const runArtifact of runArtifacts) {
|
||||
const typeExecutions: ExecutionArtifact[] = [];
|
||||
for (const e of runArtifact.executionArtifacts) {
|
||||
|
|
@ -124,6 +115,7 @@ function filterRunArtifactsByType(
|
|||
);
|
||||
}
|
||||
if (typeArtifacts.length > 0) {
|
||||
artifactCount += typeArtifacts.length;
|
||||
typeExecutions.push({
|
||||
execution: e.execution,
|
||||
linkedArtifacts: typeArtifacts,
|
||||
|
|
@ -137,7 +129,10 @@ function filterRunArtifactsByType(
|
|||
} as RunArtifact);
|
||||
}
|
||||
}
|
||||
return typeRuns;
|
||||
return {
|
||||
runArtifacts: typeRuns,
|
||||
artifactCount,
|
||||
};
|
||||
}
|
||||
|
||||
function getRunArtifacts(runs: ApiRunDetail[], mlmdPackages: MlmdPackage[]): RunArtifact[] {
|
||||
|
|
@ -202,6 +197,12 @@ function CompareV2(props: CompareV2Props) {
|
|||
const [rocCurveArtifacts, setRocCurveArtifacts] = useState<Artifact[]>([]);
|
||||
const [selectedRocCurveArtifacts, setSelectedRocCurveArtifacts] = useState<Artifact[]>([]);
|
||||
|
||||
const [scalarMetricsArtifacts, setScalarMetricsArtifacts] = useState<RunArtifact[]>([]);
|
||||
const [scalarMetricsArtifactCount, setScalarMetricsArtifactCount] = useState<number>(0);
|
||||
const [scalarMetricsTableData, setScalarMetricsTableData] = useState<
|
||||
CompareTableProps | undefined
|
||||
>(undefined);
|
||||
|
||||
// Selected artifacts for two-panel layout.
|
||||
const createSelectedArtifactArray = (count: number): SelectedArtifact[] => {
|
||||
const array: SelectedArtifact[] = [];
|
||||
|
|
@ -279,19 +280,29 @@ function CompareV2(props: CompareV2Props) {
|
|||
useEffect(() => {
|
||||
if (runs && mlmdPackages && artifactTypes) {
|
||||
const runArtifacts: RunArtifact[] = getRunArtifacts(runs, mlmdPackages);
|
||||
setConfusionMatrixRunArtifacts(
|
||||
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.CONFUSION_MATRIX),
|
||||
const scalarMetricsArtifactData = filterRunArtifactsByType(
|
||||
runArtifacts,
|
||||
artifactTypes,
|
||||
MetricsType.SCALAR_METRICS,
|
||||
);
|
||||
setScalarMetricsArtifacts(scalarMetricsArtifactData.runArtifacts);
|
||||
setScalarMetricsArtifactCount(scalarMetricsArtifactData.artifactCount);
|
||||
setConfusionMatrixRunArtifacts(
|
||||
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.CONFUSION_MATRIX)
|
||||
.runArtifacts,
|
||||
);
|
||||
setHtmlRunArtifacts(
|
||||
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.HTML).runArtifacts,
|
||||
);
|
||||
setHtmlRunArtifacts(filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.HTML));
|
||||
setMarkdownRunArtifacts(
|
||||
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.MARKDOWN),
|
||||
filterRunArtifactsByType(runArtifacts, artifactTypes, MetricsType.MARKDOWN).runArtifacts,
|
||||
);
|
||||
|
||||
const rocCurveRunArtifacts: RunArtifact[] = filterRunArtifactsByType(
|
||||
runArtifacts,
|
||||
artifactTypes,
|
||||
MetricsType.ROC_CURVE,
|
||||
);
|
||||
).runArtifacts;
|
||||
const rocCurveArtifacts: Artifact[] = flatMapDeep(
|
||||
rocCurveRunArtifacts.map(rocCurveArtifact =>
|
||||
rocCurveArtifact.executionArtifacts.map(executionArtifact =>
|
||||
|
|
@ -394,6 +405,18 @@ function CompareV2(props: CompareV2Props) {
|
|||
setSelectedIds(selectedIds);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const compareTableProps: CompareTableProps = getCompareTableProps(
|
||||
scalarMetricsArtifacts,
|
||||
scalarMetricsArtifactCount,
|
||||
);
|
||||
if (compareTableProps.yLabels.length === 0) {
|
||||
setScalarMetricsTableData(undefined);
|
||||
} else {
|
||||
setScalarMetricsTableData(compareTableProps);
|
||||
}
|
||||
}, [scalarMetricsArtifacts, scalarMetricsArtifactCount]);
|
||||
|
||||
const updateSelectedArtifacts = (newArtifacts: SelectedArtifact[]) => {
|
||||
selectedArtifactsMap[metricsTab] = newArtifacts;
|
||||
setSelectedArtifactsMap(selectedArtifactsMap);
|
||||
|
|
@ -452,8 +475,12 @@ function CompareV2(props: CompareV2Props) {
|
|||
onSwitch={setMetricsTab}
|
||||
/>
|
||||
<div className={classes(padding(20, 'lrt'), css.outputsOverflow)}>
|
||||
{/* TODO(zpChris): Add the scalar metrics table. */}
|
||||
{metricsTab === MetricsType.SCALAR_METRICS && <p>This is the Scalar Metrics tab.</p>}
|
||||
{metricsTab === MetricsType.SCALAR_METRICS &&
|
||||
(scalarMetricsTableData ? (
|
||||
<CompareTable {...scalarMetricsTableData} />
|
||||
) : (
|
||||
<p>There are no Scalar Metrics artifacts available on the selected runs.</p>
|
||||
))}
|
||||
{metricsTab === MetricsType.CONFUSION_MATRIX && (
|
||||
<MetricsDropdown
|
||||
filteredRunArtifacts={confusionMatrixRunArtifacts}
|
||||
|
|
|
|||
Loading…
Reference in New Issue