# Substrates (JAX, NumPy)
Current as of 9/15/2020
TFP supports alternative numerical backends to TensorFlow, including both JAX and NumPy. This document explains some of the details of supporting disparate backends and how we have handled them, with the aim of helping contributors read and write code and understand how these alternative-substrate packages are assembled.
## Alternative backends
In `tensorflow_probability/python/internal/backend` we find the implementations
of both the NumPy and JAX backends. These imitate the portion of the TensorFlow
API used by TFP, but are implemented in terms of the corresponding substrate's
primitives.
Since JAX provides `jax.numpy`, we are in most cases able to write a single
NumPy implementation under `tensorflow_probability/python/internal/backend`,
then use a rewrite script (found at
`tensorflow_probability/python/internal/backend/jax/rewrite.py`) to generate a
JAX variant at bazel build time. See the genrules in
`tensorflow_probability/python/internal/backend/jax/BUILD` for more details of
how this occurs. In cases where JAX provides a different API (e.g. `random`) or
a more performant API (e.g. batched matrix decompositions, `vmap`, etc.), we
will special-case using an `if JAX_MODE:` block.
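
The special-casing pattern looks roughly like the following. This is an
illustrative sketch rather than a verbatim excerpt from the backend sources;
`JAX_MODE` is the real flag, but the sampler below is made up for the example.

```python
import numpy as np

JAX_MODE = False  # The JAX variant of this file is generated with JAX_MODE = True.

if JAX_MODE:
  import jax


def _stateless_normal(shape, seed, dtype=np.float32):
  """Mimics a stateless normal sampler on the active substrate (illustrative)."""
  if JAX_MODE:
    # JAX exposes a functional PRNG API keyed by an explicit PRNG key.
    return jax.random.normal(seed, shape=tuple(shape), dtype=dtype)
  # NumPy: fall back to a locally-seeded stateful generator.
  return np.random.RandomState(seed).normal(size=shape).astype(dtype)
```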
## Wrinkles
-   **Shapes**

    In TensorFlow, the `shape` attribute of a `Tensor` is a `tf.TensorShape`
    object, whereas in JAX and NumPy it is a simple `tuple` of `int`s. To
    handle both systems nicely, we use
    `from tensorflow_probability.python.internal import tensorshape_util` and
    move special-casing into this library. (See the sketch after this list.)
-   **DTypes**

    TensorFlow, and to some extent JAX, use the dtype of inputs to infer the
    dtype of outputs, and generally try to preserve the `Tensor` dtype in a
    binary op between a `Tensor` and a non-`Tensor`. NumPy, on the other hand,
    aggressively pushes dtypes toward 64-bit when left unspecified. When
    debugging the JAX substrate, we have in some cases seen dtypes change from
    float32 to float64 or vice versa across the iterations of a while loop.
    Finding where this change happens can be tricky. Where possible, we aim to
    fix these issues in the `internal/backend` package, as opposed to
    implementation files. (See the sketch after this list.)
-   **Shapes, again (JAX "omnistaging" / `prefer_static`)**

    Every JAX value observed within a `jit` or JAX control flow context becomes
    an abstract `Tracer`. This is similar to `@tf.function`. The main challenge
    this introduces for TFP is that TF allows dynamic shapes as `Tensor`s,
    whereas JAX (being built atop XLA) needs shapes to be statically available
    (i.e. a `tuple` or `numpy.ndarray`, not a JAX `ndarray`).

    If you observe issues with shapes derived from `Tracer`s in JAX, often a
    simple fix is
    `from tensorflow_probability.python.internal import prefer_static as ps`
    followed by replacing `tf.shape` with `ps.shape`, and similarly for other
    ops such as `tf.rank`, `tf.size`, `tf.concat` (when dealing with shapes),
    (the args to) `tf.range`, etc. It's also useful to be aware of
    `ps.convert_to_shape_tensor`, which behaves like `tf.convert_to_tensor` for
    TF but leaves things as `np.ndarray` for JAX. Similarly, in constructors,
    use the `as_shape_tensor=True` arg to
    `tensor_util.convert_nonref_to_tensor` for shape-related values. (See the
    sketch after this list.)
-   **`tf.GradientTape`**

    TF uses a tape to record ops for later gradient evaluation, whereas JAX
    rewrites a function while tracing its execution. Since the function
    transform is more general, we aim to replace usage of `GradientTape` (in
    tests, TFP implementations, etc.) with `tfp.math.value_and_gradient` or
    similar. Then we can special-case `JAX_MODE` inside the body of
    `value_and_gradient`. (See the sketch after this list.)
-   **`tf.Variable`, `tf_keras.optimizers.Optimizer`**

    TF provides a `Variable` abstraction so that graph functions may modify
    state, including using the Keras `Optimizer` subclasses like `Adam`. JAX,
    in contrast, operates only on pure functions. In general, TFP is fairly
    functional (e.g. `tfp.optimizer.lbfgs_minimize`), but in some cases (e.g.
    `tfp.vi.fit_surrogate_posterior`,
    `tfp.optimizer.StochasticGradientLangevinDynamics`) we have felt the
    mismatch too strong to try to port the code to JAX. Some approaches to
    hoisting state out of a stateful function can be seen in the TFP spinoff
    project `oryx`.
-   **Custom derivatives**

    JAX supports both forward- and reverse-mode autodifferentiation, and where
    possible TFP aims to support both in JAX. To do so, in places where we
    define a custom derivative, we use an internal wrapper that provides a
    function decorator supporting both TF's and JAX's interfaces for custom
    derivatives, namely:

    ```python
    from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient


    def _f(..): pass


    def _f_fwd(..):
      return _f(..), bwd_auxiliary_data


    def _f_bwd(bwd_auxiliary_data, dy): pass


    def _f_jvp(primals, tangents):
      return _f(*primals), df(primals, tangents)


    @tfp_custom_gradient.custom_gradient(
        vjp_fwd=_f_fwd, vjp_bwd=_f_bwd, jvp_fn=_f_jvp)
    def f(..):
      return _f(..)
    ```

    For more information, the JAX custom derivatives doc can be useful.
-   **Randomness**

    In TF, we support both "stateful" (i.e. some latent memory tracks the state
    of the sampler) and "stateless" sampling. JAX natively supports only
    stateless sampling, for functional purity reasons. For internal use, we
    have `from tensorflow_probability.python.internal import samplers`, a
    library that provides methods to:

    -   convert stateful seeds to stateless ones and add salts
        (`sanitize_seed`)
    -   split stateless seeds into multiple descendant seeds (`split_seed`)
    -   proxy through to a number of stateless samplers (`normal`, `uniform`,
        ...)

    When the rewrite script is dealing with a `..._test.py` file, it will
    rewrite calls to `tf.random.{uniform,normal,...}` to
    `tf.random.stateless_{uniform,normal,...}` to ensure compatibility with the
    JAX backend, which only implements the stateless samplers. (See the sketch
    after this list.)
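
Below are a few small sketches for the wrinkles above; all are illustrative
rather than verbatim TFP source.

For the **Shapes** wrinkle, `tensorshape_util` lets shape-handling code work
whether `x.shape` is a `tf.TensorShape` (TF) or a plain tuple (JAX/NumPy):

```python
from tensorflow_probability.python.internal import tensorshape_util


def static_rank_or_none(x):
  """Returns the statically-known rank of `x`, or None if it is unknown."""
  # Works for both tf.TensorShape and plain tuples; the rank can be None only
  # under TF, where shapes may be partially unknown.
  return tensorshape_util.rank(x.shape)
```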
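
For the **DTypes** wrinkle, the promotion mismatch we run into looks like this:

```python
import numpy as np
import tensorflow.compat.v2 as tf

# NumPy defaults to 64-bit when dtypes are left unspecified...
print(np.ones(3).dtype)             # float64
# ...while TF defaults to 32-bit and preserves the Tensor dtype in binary ops
# against Python scalars.
print(tf.ones(3).dtype)             # tf.float32
print((tf.ones(3) + 3.14).dtype)    # tf.float32
```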
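
For the "Shapes, again" wrinkle, a typical `prefer_static` fix looks like this
(sketch):

```python
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import prefer_static as ps


def flatten_leading_two_dims(x):
  """Merges the two leading dimensions of `x`."""
  # Under a JAX jit, tf.shape(x) would be a traced value; ps.shape (and
  # ps.concat, ps.range, ...) return plain NumPy values whenever the shape is
  # statically available.
  shape = ps.shape(x)
  new_shape = ps.concat([[shape[0] * shape[1]], shape[2:]], axis=0)
  return tf.reshape(x, new_shape)
```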
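
For the `tf.GradientTape` wrinkle, the substrate-agnostic pattern is (sketch):

```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp


def loss(x):
  return tf.reduce_sum(tf.square(x))

# Instead of recording onto a tf.GradientTape (which has no JAX analogue):
loss_value, grad = tfp.math.value_and_gradient(loss, tf.constant([1., 2., 3.]))
```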
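
For the **Randomness** wrinkle, typical usage of the `samplers` library looks
like this (sketch, using the `sanitize_seed`, `split_seed`, and stateless
sampler helpers named above):

```python
from tensorflow_probability.python.internal import samplers


def sample_pair(seed):
  """Draws a normal and a uniform sample from independent descendant seeds."""
  seed = samplers.sanitize_seed(seed, salt='sample_pair')  # stateless-ify + salt
  normal_seed, uniform_seed = samplers.split_seed(seed, n=2)
  x = samplers.normal([3], seed=normal_seed)
  y = samplers.uniform([3], seed=uniform_seed)
  return x, y
```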
## Rewriting TF code
In a couple of cases, we commit into the repository script-munged source from
TensorFlow. These files can be found under
`tensorflow_probability/python/internal/backend/numpy/gen`.
They currently include:

-   an implementation of `tf.TensorShape`
-   several parts of `tf.linalg`, especially the `tf.linalg.LinearOperator`
    classes
The actual rewriting is accomplished by scripts found under
`tensorflow_probability/python/internal/backend/meta`,
namely `gen_linear_operators.py` and `gen_tensor_shape.py`.
The test
`tensorflow_probability/python/internal/backend/numpy/rewrite_equivalence_test.py`
verifies that the files in TensorFlow, when rewritten, match the files in the
`gen/` directory. The test uses BUILD dependencies on genrules that apply the
rewrite scripts, and compares those genrules' outputs to the source of the
files under the `gen/` directory.
Similar to the sources in `internal/backend/numpy`, the sources in
`internal/backend/numpy/gen` are rewritten by `jax/rewrite.py`. Note that the
files under `gen/` do not have the `numpy` import rewritten. This is because we
only want to rewrite the TensorFlow usage in this TensorFlow-ported code;
typically when TF code uses NumPy, it is munging shapes, and JAX does not like
shapes to be munged using `jax.numpy` (plain `numpy` must be used).
## Rewriting TFP code
With `internal/backend/{numpy,jax}` now ready to provide a "tf2jax" or
"tf2numpy" backend, we can proceed to the core packages of TFP.
The script
`tensorflow_probability/substrates/meta/rewrite.py`
runs on TFP sources to auto-generate JAX and NumPy Python source corresponding
to the given TF source.
The most important job of the rewrite script is to rewrite
`import tensorflow.compat.v2 as tf` to
`from tensorflow_probability.python.internal.backend.jax import v2 as tf`.
Second to that, the script rewrites dependencies on TFP subpackages to
dependencies on the corresponding substrate-specific TFP subpackages. For
example, the line `from tensorflow_probability.python import math as tfp_math`
becomes `from tensorflow_probability.substrates.jax import math as tfp_math`.
Beyond that, there are a number of peripheral replacements to work around other
wrinkles we've accumulated over time.
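
Schematically, for the JAX substrate (these are the two rewrites quoted above,
shown side by side):

```python
# Before rewriting (TF substrate):
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import math as tfp_math

# After rewriting (JAX substrate):
from tensorflow_probability.python.internal.backend.jax import v2 as tf
from tensorflow_probability.substrates.jax import math as tfp_math
```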
In rare cases we will put an explicit `if JAX_MODE:` or `if NUMPY_MODE:` block
into the implementation code of a TFP submodule. This should be very uncommon.
Whenever possible, the intent is for such special-casing to live under
`python/internal`. For example, today we see in `bijectors/softplus.py`:
```python
# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:
  @tf.custom_gradient
  def _stable_grad_softplus(x):  # ...
```
Note that the rewrite currently adds exactly a 10-line header, so line numbers
in stack traces will be offset by +10 relative to the raw source.
## BUILD rules
`tensorflow_probability/python/build_defs.bzl`
defines a pair of bazel build rules: `multi_substrate_py_library` and
`multi_substrate_py_test`.
These rules automatically invoke
`tensorflow_probability/substrates/meta/rewrite.py`
to emit JAX/NumPy source variants. For example, `bijectors/softplus.py` gets
rewritten into `bijectors/_generated_jax_softplus.py` (you can view the output
under the corresponding bazel-genfiles directory).
These build rules are also responsible for rewriting TFP-internal deps to the
substrate-specific replacements `some_dep.jax` or `some_dep.numpy`.
The `multi_substrate_py_library` rule emits three targets: a TF `py_library`
with the name given by the `name` argument, a JAX `py_library` named
`name + '.jax'`, and a NumPy `py_library` named `name + '.numpy'`.

The `multi_substrate_py_test` rule emits three targets, named `name + '.tf'`,
`name + '.jax'`, and `name + '.numpy'`. Substrates listed in the
`disabled_substrates` arg will not have test rules emitted at all; `jax_tags`
and `numpy_tags` can be used to attach tags that drop CI coverage while keeping
the target buildable and testable. The distinction is useful so that we can
track cases where we think a test should be fixable but we haven't fixed it
yet, as opposed to cases like HMC where we know the test will never pass for
NumPy, so we prefer not to have the test target at all. All emitted test
targets are aggregated into a `test_suite` whose name corresponds to the
original `name` arg.
In cases where we know we will never be able to support a given feature, the
`substrates_omit_deps`, `jax_omit_deps`, and `numpy_omit_deps` args to
`multi_substrate_py_library` can be used to exclude things. Examples include
non-pure code or code using `tf.Variable` (JAX wants pure functions), or HMC
(no gradients in NumPy!). When rewriting an `__init__.py` file, the rewrite
script is set up to comment out imports and `__all__` lines corresponding to
the omitted deps.
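
A hypothetical pair of BUILD targets using these rules might look like the
following. Target, file, and tag names are made up for illustration; the rule
and attribute names are the ones described above.

```python
# BUILD (illustrative)
load(
    "//tensorflow_probability/python:build_defs.bzl",
    "multi_substrate_py_library",
    "multi_substrate_py_test",
)

multi_substrate_py_library(
    name = "my_bijector",
    srcs = ["my_bijector.py"],
    numpy_omit_deps = [":gradient_only_helper"],  # e.g. code requiring gradients
    deps = [
        ":gradient_only_helper",
        "//tensorflow_probability/python/internal:prefer_static",
    ],
)

multi_substrate_py_test(
    name = "my_bijector_test",
    srcs = ["my_bijector_test.py"],
    jax_tags = ["my-ci-exclusion-tag"],  # keep buildable/testable, drop from CI
    deps = [":my_bijector"],
)
```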
In order to test against the same directory hierarchy as we use for wheel
packaging, `multi_substrate_py_library` does some internal gymnastics with a
custom bazel rule that adds symlinks into `tensorflow_probability/substrates`
pointing to implementation files generated under
`bazel-genfiles/tensorflow_probability/python` (details in
`_substrate_runfiles_symlinks_impl` of `build_defs.bzl`).
## Wheel packaging
When building the wheel, we must first use `cp -L` to resolve the symlinks
added as part of the bazel build. Otherwise the wheel does not follow them and
fails to include `tfp.substrates`. This `cp -L` command sits in `pip_pkg.sh`
(currently adjacent to this doc).
## Integration testing
A couple of integration tests sit in
`tensorflow_probability/substrates/meta/jax_integration_test.py`
and
`tensorflow_probability/substrates/meta/numpy_integration_test.py`.
We run these under CI after building and installing a wheel, to verify that the
`tfp.substrates` packages load correctly and do not require a TensorFlow
install.
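
A minimal check in the spirit of the JAX integration test (illustrative) is
that the JAX substrate imports and runs without TensorFlow installed:

```python
import jax
from tensorflow_probability.substrates import jax as tfp

dist = tfp.distributions.Normal(loc=0., scale=1.)
samples = dist.sample(3, seed=jax.random.PRNGKey(42))
print(dist.log_prob(samples))
```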