diff --git a/frontend/src/lib/OutputArtifactLoader.ts b/frontend/src/lib/OutputArtifactLoader.ts index 52d8494833..35b6f16eef 100644 --- a/frontend/src/lib/OutputArtifactLoader.ts +++ b/frontend/src/lib/OutputArtifactLoader.ts @@ -294,21 +294,36 @@ export class OutputArtifactLoader { return buildArtifactViewer({ script, namespace }); }), ); - const anomaliesArtifactUris = filterArtifactUrisByType( - 'ExampleAnomalies', - artifactTypes, - artifacts, - ); + const anomaliesArtifacts = filterArtifactsByType('ExampleAnomalies', artifactTypes, artifacts); viewers = viewers.concat( - anomaliesArtifactUris.map(uri => { - uri = uri + '/anomalies.pbtxt'; - const script = [ - 'import tensorflow_data_validation as tfdv', - `anomalies = tfdv.load_anomalies_text('${uri}')`, - 'tfdv.display_anomalies(anomalies)', - ]; - return buildArtifactViewer({ script, namespace }); - }), + anomaliesArtifacts + .map(artifact => { + const splitNamesJSON = artifact + .getPropertiesMap() + .get('split_names') + ?.getStringValue(); + if (!splitNamesJSON) { + return []; + } + let splitNames; + try { + splitNames = JSON.parse(splitNamesJSON); + } catch (e) { + logger.warn('Failed to parse split names as a JSON array:', e); + } + if (!Array.isArray(splitNames)) { + return []; + } + return splitNames.map(name => { + const script = [ + 'import tensorflow_data_validation as tfdv', + `anomalies = tfdv.load_anomalies_text('${artifact.getUri()}/${name}')`, + 'tfdv.display_anomalies(anomalies)', + ]; + return buildArtifactViewer({ script, namespace }); + }); + }) + .flat(), ); const EvaluatorArtifactUris = filterArtifactUrisByType( 'ModelEvaluation', @@ -452,22 +467,25 @@ async function getArtifactTypes(): Promise { return res.getArtifactTypesList(); } +function filterArtifactsByType( + artifactTypeName: string, + artifactTypes: ArtifactType[], + artifacts: Artifact[], +): Artifact[] { + const artifactTypeIds = artifactTypes + .filter(artifactType => artifactType.getName() === artifactTypeName) + .map(artifactType => artifactType.getId()); + return artifacts.filter(artifact => artifactTypeIds.includes(artifact.getTypeId())); +} + function filterArtifactUrisByType( artifactTypeName: string, artifactTypes: ArtifactType[], artifacts: Artifact[], ): string[] { - const artifactTypeIds = artifactTypes - .filter(artifactType => artifactType.getName() === artifactTypeName) - .map(artifactType => artifactType.getId()); - const matchingArtifacts = artifacts.filter(artifact => - artifactTypeIds.includes(artifact.getTypeId()), - ); - - const tfdvArtifactsPaths = matchingArtifacts + return filterArtifactsByType(artifactTypeName, artifactTypes, artifacts) .map(artifact => artifact.getUri()) .filter(uri => uri); // uri not empty - return tfdvArtifactsPaths; } async function buildArtifactViewer({