[tfjs-node] replace deprecated utils (#8425)

Co-authored-by: Matthew Soulanille <msoulanille@google.com>
This commit is contained in:
Valérian Rousset 2025-04-23 20:30:04 +02:00 committed by GitHub
parent 407c6e56b9
commit f2e55729ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 8 deletions

View File

@ -16,7 +16,6 @@
*/
import {KernelConfig, scalar, TopK, TopKAttrs, TopKInputs} from '@tensorflow/tfjs';
import {isNullOrUndefined} from 'util';
import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend';
@ -28,8 +27,8 @@ export const topKConfig: KernelConfig = {
const backend = args.backend as NodeJSKernelBackend;
const {k, sorted} = args.attrs as unknown as TopKAttrs;
const kCount = isNullOrUndefined(k) ? 1 : k;
const isSorted = isNullOrUndefined(sorted) ? true : sorted;
const kCount = k ?? 1;
const isSorted = sorted ?? true;
const opAttrs = [
{name: 'sorted', type: backend.binding.TF_ATTR_BOOL, value: isSorted},
createTensorsTypeOpAttr('T', x.dtype),

View File

@ -17,7 +17,6 @@
import * as tf from '@tensorflow/tfjs';
import {backend_util, BackendTimingInfo, DataId, DataType, KernelBackend, ModelTensorInfo, Rank, Scalar, scalar, ScalarLike, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorInfo, tidy, util} from '@tensorflow/tfjs';
import {isArray, isNullOrUndefined} from 'util';
import {encodeInt32ArrayAsInt64, Int64Scalar} from './int64_tensors';
import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding';
@ -740,7 +739,7 @@ export function getTFDType(dataType: tf.DataType): number {
export function createTensorsTypeOpAttr(
attrName: string,
tensorsOrDtype: tf.Tensor|tf.Tensor[]|tf.DataType): TFEOpAttr {
if (isNullOrUndefined(tensorsOrDtype)) {
if (tensorsOrDtype === null || tensorsOrDtype === undefined) {
throw new Error('Invalid input tensors value.');
}
return {
@ -757,7 +756,7 @@ export function createTensorsTypeOpAttr(
export function createOpAttr(
attrName: string, tensorsOrDtype: tf.Tensor|tf.Tensor[]|tf.DataType,
value: ScalarLike): TFEOpAttr {
if (isNullOrUndefined(tensorsOrDtype)) {
if (tensorsOrDtype === null || tensorsOrDtype === undefined) {
throw new Error('Invalid input tensors value.');
}
return {name: attrName, type: nodeBackend().binding.TF_BOOL, value};
@ -765,10 +764,10 @@ export function createOpAttr(
/** Returns the dtype number for a single or list of input Tensors. */
function getTFDTypeForInputs(tensors: tf.Tensor|tf.Tensor[]): number {
if (isNullOrUndefined(tensors)) {
if (tensors === null || tensors === undefined) {
throw new Error('Invalid input tensors value.');
}
if (isArray(tensors)) {
if (Array.isArray(tensors)) {
for (let i = 0; i < tensors.length; i++) {
return getTFDType(tensors[i].dtype);
}