Fix unittests for TensorFlow nightly

PiperOrigin-RevId: 602740133
This commit is contained in:
Richard Stotz 2024-01-30 08:50:45 -08:00 committed by Copybara-Service
parent 14b58329b8
commit cf97bc70fd
4 changed files with 75 additions and 2 deletions

51
third_party/tensorflow/tf-216.patch vendored Normal file
View File

@ -0,0 +1,51 @@
diff --git a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
index 9261a652f9c..0557e345ab1 100644
--- a/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
+++ b/tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl
@@ -2,7 +2,7 @@
load("//tensorflow/tools/toolchains:cpus/aarch64/aarch64.bzl", "remote_aarch64_configure")
load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure")
-load("//third_party/py:python_configure.bzl", "remote_python_configure")
+load("//third_party/py/non_hermetic:python_configure.bzl", "remote_python_configure")
def ml2014_tf_aarch64_configs(name_container_map, env):
for name, container in name_container_map.items():
diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
index 9f71a414bf7..57f70752323 100644
--- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
+++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl
@@ -1,6 +1,6 @@
"""Macro that creates external repositories for remote config."""
-load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure")
+load("//third_party/py/non_hermetic:python_configure.bzl", "local_python_configure", "remote_python_configure")
load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure")
load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure")
load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure")
diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index 056df85ffdb..7422baf8c59 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -37,7 +37,7 @@ load("//third_party/nasm:workspace.bzl", nasm = "repo")
load("//third_party/nccl:nccl_configure.bzl", "nccl_configure")
load("//third_party/opencl_headers:workspace.bzl", opencl_headers = "repo")
load("//third_party/pasta:workspace.bzl", pasta = "repo")
-load("//third_party/py:python_configure.bzl", "python_configure")
+load("//third_party/py/non_hermetic:python_configure.bzl", "python_configure")
load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
diff --git a/third_party/py/non_hermetic/python_configure.bzl b/third_party/py/non_hermetic/python_configure.bzl
index 89732c3e33d..4ac1c8f5c04 100644
--- a/third_party/py/non_hermetic/python_configure.bzl
+++ b/third_party/py/non_hermetic/python_configure.bzl
@@ -203,7 +203,7 @@ def _create_local_python_repository(repository_ctx):
# Resolve all labels before doing any real work. Resolving causes the
# function to be restarted with all previous state being lost. This
# can easily lead to a O(n^2) runtime in the number of labels.
- build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl"))
+ build_tpl = repository_ctx.path(Label("//third_party/py/non_hermetic:BUILD.tpl"))
python_bin = get_python_bin(repository_ctx)
_check_python_bin(repository_ctx, python_bin)

View File

@ -8,6 +8,8 @@ def deps(from_git_repo = True):
name = "ydf",
urls = ["https://github.com/google/yggdrasil-decision-forests/archive/refs/heads/main.zip"],
strip_prefix = "yggdrasil-decision-forests-main",
# patch_args = ["-p1"],
# patches = ["@ydf//yggdrasil_decision_forests:ydf.patch"],
)
else:
# You can also clone the YDF repository manually.

View File

@ -0,0 +1,13 @@
diff --git a/yggdrasil_decision_forests/learner/decision_tree/BUILD b/yggdrasil_decision_forests/learner/decision_tree/BUILD
index 201418c..23de5cc 100644
--- a/yggdrasil_decision_forests/learner/decision_tree/BUILD
+++ b/yggdrasil_decision_forests/learner/decision_tree/BUILD
@@ -49,7 +49,7 @@ cc_library_ydf(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
- "@eigen_archive//:eigen3_internal",
+ "@eigen_archive//:eigen3",
] + select({
"//conditions:default": [
],

View File

@ -58,12 +58,11 @@ if [ ${TF_VERSION} == "nightly" ]; then
${PYTHON} -m pip install tf-nightly --force-reinstall
else
${PYTHON} -m pip install tensorflow==${TF_VERSION} --force-reinstall
TF_MINOR=$(echo $TF_VERSION | grep -oE '[0-9]+\.[0-9]*')
fi
ext=""
pip list
ext=""
if is_macos; then
ext='""'
# Tensorflow requires the use of GNU realpath instead of MacOS realpath.
@ -72,6 +71,14 @@ if is_macos; then
export PATH="/opt/homebrew/opt/coreutils/libexec/gnubin:$PATH"
fi
# For Tensorflow versions > 2.15, apply compatibility patches.
TF_MINOR=$(echo $TF_VERSION | grep -oP '[0-9]+\.[0-9]+')
if [[ ${TF_MINOR} != "2.15" ]]; then
sed -i $ext "s/tensorflow:tf.patch/tensorflow:tf-216.patch/" WORKSPACE
sed -i $ext "s/# patch_args = \[\"-p1\"\],/patch_args = \[\"-p1\"\],/" third_party/yggdrasil_decision_forests/workspace.bzl
sed -i $ext "s/# patches = \[\"\/\/third_party\/yggdrasil_decision_forests:ydf.patch\"\],/patches = \[\"\/\/third_party\/yggdrasil_decision_forests:ydf.patch\"\],/" third_party/yggdrasil_decision_forests/workspace.bzl
fi
# Get the commit SHA
short_commit_sha=$(${PYTHON} -c 'import tensorflow as tf; print(tf.__git_version__)' | tail -1)
if is_macos; then