mirror of https://github.com/tensorflow/tfjs.git
Body: FEATURE Co-authored-by: Matthew Soulanille <msoulanille@google.com>
This commit is contained in:
parent
0677375de6
commit
936b448c20
|
|
@ -314,3 +314,22 @@ export function MSE(yTrue: Tensor, yPred: Tensor): Tensor {
|
|||
export function mse(yTrue: Tensor, yPred: Tensor): Tensor {
|
||||
return losses.meanSquaredError(yTrue, yPred);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes R2 score.
|
||||
*
|
||||
* ```js
|
||||
* const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
|
||||
* const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
|
||||
* const r2Score = tf.metrics.r2Score(yTrue, yPred);
|
||||
* r2Score.print();
|
||||
* ```
|
||||
* @param yTrue Truth Tensor.
|
||||
* @param yPred Prediction Tensor.
|
||||
* @return R2 score Tensor.
|
||||
*
|
||||
* @doc {heading: 'Metrics', namespace: 'metrics'}
|
||||
*/
|
||||
export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor {
|
||||
return metrics.r2Score(yTrue, yPred);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ import {Tensor, tidy} from '@tensorflow/tfjs-core';
|
|||
|
||||
import * as K from './backend/tfjs_backend';
|
||||
import {NotImplementedError, ValueError} from './errors';
|
||||
import {categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
|
||||
import {binaryCrossentropy as lossBinaryCrossentropy} from './losses';
|
||||
import {lossesMap} from './losses';
|
||||
import {binaryCrossentropy as lossBinaryCrossentropy, categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, lossesMap, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
|
||||
import {LossOrMetricFn} from './types';
|
||||
import * as util from './utils/generic_utils';
|
||||
|
||||
|
|
@ -112,6 +110,14 @@ export function sparseTopKCategoricalAccuracy(
|
|||
throw new NotImplementedError();
|
||||
}
|
||||
|
||||
export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor {
|
||||
return tidy(() => {
|
||||
const sumSquaresResiduals = yTrue.sub(yPred).square().sum();
|
||||
const sumSquares = yTrue.sub(yTrue.mean()).square().sum();
|
||||
return tfc.scalar(1).sub(sumSquaresResiduals.div(sumSquares));
|
||||
});
|
||||
}
|
||||
|
||||
// Aliases.
|
||||
export const mse = meanSquaredError;
|
||||
export const MSE = meanSquaredError;
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import {scalar, Tensor, tensor, tensor1d, tensor2d} from '@tensorflow/tfjs-core'
|
|||
|
||||
import {setEpsilon} from './backend/common';
|
||||
import * as tfl from './index';
|
||||
import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName} from './metrics';
|
||||
import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName, r2Score} from './metrics';
|
||||
import {LossOrMetricFn} from './types';
|
||||
import {describeMathCPUAndGPU, describeMathCPUAndWebGL2, expectTensorsClose} from './utils/test_utils';
|
||||
|
||||
|
|
@ -283,6 +283,27 @@ describeMathCPUAndGPU('recall metric', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describeMathCPUAndGPU('r2Score', () => {
|
||||
it('1D', () => {
|
||||
const yTrue = tensor1d([3, -0.5, 2, 7, 4.2, 8.5, 1.3, 2.8, 6.7, 9.0]);
|
||||
const yPred = tensor1d([2.5, 0.0, 2.1, 7.8, 4.0, 8.2, 1.4, 2.9, 6.5, 9.1]);
|
||||
const score = r2Score(yTrue, yPred);
|
||||
expectTensorsClose(score, scalar(0.985));
|
||||
});
|
||||
it('2D', () => {
|
||||
const yTrue = tensor2d([
|
||||
[3, 2.5], [-0.5, 3.2], [2, 1.9], [7, 5.1], [4.2, 3.8], [8.5, 7.4],
|
||||
[1.3, 0.6], [2.8, 2.1], [6.7, 5.3], [9.0, 8.7]
|
||||
]);
|
||||
const yPred = tensor2d([
|
||||
[2.7, 2.3], [0.0, 3.1], [2.1, 1.8], [6.8, 5.0], [4.1, 3.7], [8.4, 7.2],
|
||||
[1.4, 0.7], [2.9, 2.2], [6.6, 5.2], [9.2, 8.9]
|
||||
]);
|
||||
const score = r2Score(yTrue, yPred);
|
||||
expectTensorsClose(score, scalar(0.995));
|
||||
});
|
||||
});
|
||||
|
||||
describe('metrics.get', () => {
|
||||
it('valid name, not alias', () => {
|
||||
expect(get('binaryAccuracy') === get('categoricalAccuracy')).toEqual(false);
|
||||
|
|
|
|||
Loading…
Reference in New Issue