mirror of https://github.com/tensorflow/tfjs.git
[tfjs-node] replace deprecated utils (#8425)
Co-authored-by: Matthew Soulanille <msoulanille@google.com>
This commit is contained in:
parent
407c6e56b9
commit
f2e55729ba
|
|
@ -16,7 +16,6 @@
|
|||
*/
|
||||
|
||||
import {KernelConfig, scalar, TopK, TopKAttrs, TopKInputs} from '@tensorflow/tfjs';
|
||||
import {isNullOrUndefined} from 'util';
|
||||
|
||||
import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend';
|
||||
|
||||
|
|
@ -28,8 +27,8 @@ export const topKConfig: KernelConfig = {
|
|||
const backend = args.backend as NodeJSKernelBackend;
|
||||
const {k, sorted} = args.attrs as unknown as TopKAttrs;
|
||||
|
||||
const kCount = isNullOrUndefined(k) ? 1 : k;
|
||||
const isSorted = isNullOrUndefined(sorted) ? true : sorted;
|
||||
const kCount = k ?? 1;
|
||||
const isSorted = sorted ?? true;
|
||||
const opAttrs = [
|
||||
{name: 'sorted', type: backend.binding.TF_ATTR_BOOL, value: isSorted},
|
||||
createTensorsTypeOpAttr('T', x.dtype),
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
import * as tf from '@tensorflow/tfjs';
|
||||
import {backend_util, BackendTimingInfo, DataId, DataType, KernelBackend, ModelTensorInfo, Rank, Scalar, scalar, ScalarLike, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorInfo, tidy, util} from '@tensorflow/tfjs';
|
||||
import {isArray, isNullOrUndefined} from 'util';
|
||||
|
||||
import {encodeInt32ArrayAsInt64, Int64Scalar} from './int64_tensors';
|
||||
import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding';
|
||||
|
|
@ -740,7 +739,7 @@ export function getTFDType(dataType: tf.DataType): number {
|
|||
export function createTensorsTypeOpAttr(
|
||||
attrName: string,
|
||||
tensorsOrDtype: tf.Tensor|tf.Tensor[]|tf.DataType): TFEOpAttr {
|
||||
if (isNullOrUndefined(tensorsOrDtype)) {
|
||||
if (tensorsOrDtype === null || tensorsOrDtype === undefined) {
|
||||
throw new Error('Invalid input tensors value.');
|
||||
}
|
||||
return {
|
||||
|
|
@ -757,7 +756,7 @@ export function createTensorsTypeOpAttr(
|
|||
export function createOpAttr(
|
||||
attrName: string, tensorsOrDtype: tf.Tensor|tf.Tensor[]|tf.DataType,
|
||||
value: ScalarLike): TFEOpAttr {
|
||||
if (isNullOrUndefined(tensorsOrDtype)) {
|
||||
if (tensorsOrDtype === null || tensorsOrDtype === undefined) {
|
||||
throw new Error('Invalid input tensors value.');
|
||||
}
|
||||
return {name: attrName, type: nodeBackend().binding.TF_BOOL, value};
|
||||
|
|
@ -765,10 +764,10 @@ export function createOpAttr(
|
|||
|
||||
/** Returns the dtype number for a single or list of input Tensors. */
|
||||
function getTFDTypeForInputs(tensors: tf.Tensor|tf.Tensor[]): number {
|
||||
if (isNullOrUndefined(tensors)) {
|
||||
if (tensors === null || tensors === undefined) {
|
||||
throw new Error('Invalid input tensors value.');
|
||||
}
|
||||
if (isArray(tensors)) {
|
||||
if (Array.isArray(tensors)) {
|
||||
for (let i = 0; i < tensors.length; i++) {
|
||||
return getTFDType(tensors[i].dtype);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue