Copyright 2018 The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Standards for tfp.distributions
This document represents the TFP standards for all members of the
tfp.distributions module. These standards apply only to the
tfp.distributions module and are intentionally a "high bar" meant to protect
users of the module. The "TensorFlow Distributions" whitepaper is an exposition
on the standards described here (see Section 3). These standards have been in
effect since 16-aug-2016.
You are encouraged to subclass tfp.distributions.Distribution and invited to
disregard any/all of these standards. We especially recommend you ignore point
#4.
All TFP code (including tfp.distributions) follows the TFP Style Guide.
Requirements (Comprehensive)
- A `Distribution` subclass must implement a `sample` function. The `sample` function shall generate random outcomes which in expectation correspond to other properties of the `Distribution`. Other properties must assume the semantics implied by the output of `sample`.
- A `Distribution` must implement a `log_prob` function. This function is the log of the probability mass function or probability density function, depending on the underlying measure.
- All member functions must be efficiently computable. To us, efficiently computable means: computable in (at most) expected polynomial time.
- All member functions (except `sample`) must be deterministic, reproducible, and platform invariant. This means no stochastic approximations (even seeded). For example, it is acceptable to implement an efficiently computable analytical upper bound on entropy, but it is not acceptable to implement a Monte Carlo estimate of entropy, even if that Monte Carlo estimate uses a seeded RNG for reproducibility. It is also acceptable to implement an approximation of a statistic (e.g., an approximation of the `LogitNormal` mean), as long as it can be computed non-stochastically.
- Implementing other distribution properties is highly encouraged iff they are mathematically well-defined and efficiently computable. For example `mean_log_det` is a meaningful property for a `Wishart` but not for a (scalar) `Normal`.
- A `Distribution`'s inputs/outputs are `Tensor`s with "`Distribution` shape semantics." For example, functions like `prob` and `cdf` accept a `Tensor` with "`Distribution` shape semantics" whereas `sample` returns a `Tensor` with "`Distribution` shape semantics."
  - Exception: `tfd.JointDistribution` takes/returns a `list`-like of `Tensor`s.
- A `Distribution`'s `event_ndims` must be known statically. For example, `Wishart` has `event_ndims=2`, `MultivariateNormalDiag` has `event_ndims=1`, `Normal` has `event_ndims=0`, `Categorical` has `event_ndims=0` and `OneHotCategorical` has `event_ndims=1`. The `event_shape` need not be known statically, i.e., this might only be known at runtime (in Eager mode) or at graph execution time (in graph mode). Often a `Distribution`'s `event_ndims` will be self-evident from the class name itself.
  - Redacted 07-nov-2019.
- All `Distribution`s are implicitly or explicitly conditioned on global or local (per-sample) parameters. A `Distribution` infers dtype and batch/event shape from its global parameters. For example, a scalar `Distribution`'s `event_shape` is implicitly inferable (`event_shape=[]`) and thus always known statically; the same is not necessarily true of a `MultivariateNormalDiag`.
- When possible, a `Distribution` must support Numpy-like broadcasting for all arguments. When broadcasting is not possible, arguments must be validated.
- All possible effort is made to validate arguments prior to graph execution. Any validation requiring graph execution must be gated by a Boolean, global parameter.
- `Distribution` parameters are descriptive English (i.e., not Greek letters) and draw upon a relatively small, shared lexicon. Examples include: `loc`, `scale`, `concentration`, `rate`, `probs`, `logits`, `df`. When forced to choose between mathematical purity and conveying intuitive meaning, prefer the latter but provide extensive documentation.
- TFP `Distribution`s guarantee that input arguments are not manipulated by `__init__` except to convert non-TF-derived inputs to `tf.Tensor`s.

  Among other things, this contract implies:
  - `tf.Variable`-derived `__init__` arguments are not read ("concretized") until some computation is requested, e.g. a member function is called. We call this "maximally deferred read" idea: "`tf.Variable` safety."
  - No additional computation result is stored in lieu of `__init__` arguments except to convert them to `tf.Tensor`s (if they aren't already). For reasons made clear below, we call this "non manipulation" idea: "`tf.GradientTape` safety".
  The above contract ensures several desirable features of `Distribution`s:

  - Evaluations of mutable arguments (including assertions) are re-run any time the underlying values could possibly change. Example:

    ```python
    import tensorflow as tf
    import tensorflow_probability as tfp

    loc = tf.constant(0.)
    scale = tf.Variable(1.)
    d = tfp.distributions.Normal(loc, scale, validate_args=True)
    d.log_prob(0.)
    # ==> -0.918938
    d.scale.assign(-1.)
    d.log_prob(0.)
    # ==> InvalidArgumentError: Argument `scale` must be positive.
    ```

  - Gradients of public `Distribution` methods with respect to `__init__` arguments are valid regardless of the `Distribution` being created inside or outside the `tf.GradientTape`. Example:

    ```python
    import tensorflow as tf
    import tensorflow_probability as tfp

    loc = tf.constant(0.)
    scale = tf.Variable(1.)
    d = tfp.distributions.Normal(loc, scale, validate_args=True)
    with tf.GradientTape() as tape:
      tape.watch(loc)
      # `tape.watch(scale)` is not required since `tf.GradientTape`
      # automatically watches `tf.Variable` dependencies (by default).
      x = -d.log_prob(1.)
    grad = tape.gradient(x, [loc, scale, d.loc, d.scale])
    assert all([g is not None for g in grad])
    ```

  Note that both of these properties would be lost if `__init__` memoized any derived computation in lieu of the original `Tensor`-convertible arguments.
- `Distribution` and subclasses' `@property` methods shall never execute TF ops. For example, in the graph execution regime this implies that calling a `@property` will never mutate the graph.
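Several of the requirements above can be seen together in a small, framework-free sketch: a `sample` function, a `log_prob` function, deterministic non-`sample` members, descriptive parameter names, an `__init__` that stores its arguments untouched, and `@property` methods that compute nothing. The class `ToyNormal` below is hypothetical and uses only the Python standard library. It is not a TFP class, and it ignores `Tensor` shape semantics, broadcasting, and argument validation entirely.

```python
import math
import random


class ToyNormal:
  """Hypothetical scalar Normal; an illustration, not part of TFP."""

  def __init__(self, loc, scale):
    # Store the arguments as given; nothing derived from them is cached.
    self._loc = loc
    self._scale = scale

  @property
  def loc(self):
    # Properties only return stored arguments; they compute nothing.
    return self._loc

  @property
  def scale(self):
    return self._scale

  def sample(self, n, seed=None):
    # The only stochastic member function.
    rng = random.Random(seed)
    return [rng.gauss(self._loc, self._scale) for _ in range(n)]

  def log_prob(self, x):
    # Log of the probability density function; deterministic.
    z = (x - self._loc) / self._scale
    return -0.5 * z * z - math.log(self._scale) - 0.5 * math.log(2. * math.pi)

  def mean(self):
    return self._loc

  def entropy(self):
    # Analytical, so deterministic and reproducible (no Monte Carlo estimate).
    return 0.5 * math.log(2. * math.pi * math.e) + math.log(self._scale)


d = ToyNormal(loc=0., scale=1.)
draws = d.sample(20000, seed=42)
sample_mean = sum(draws) / len(draws)
```

Here `d.log_prob(0.)` evaluates to about -0.918938, the same value as the standard Normal example in the requirements. The sample mean agrees with `d.mean()` only in expectation, up to Monte Carlo error, which is the sense in which `sample` defines the semantics of the other properties.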
Non-Requirements (Noncomprehensive)
In this section we list items which have historically been presumed true of
tfp.distributions but are not official requirements.
- Mutable state is not explicitly disallowed. However, it is highly discouraged as it makes reasoning about the object more challenging (for both API owner and user). As of 11-nov-2019, no `tfp.distributions` member has its own mutable state, although all distributions do mutate the global random number generator state on access to `sample`.
- Subclasses are free to override public base class members. I.e., you don't have to follow the "public calls private" pattern. (However, as of 11-nov-2019, there has not yet been a reason to deviate from this pattern.)
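For readers unfamiliar with the "public calls private" pattern, the following is a minimal sketch of the idea in plain Python. The names `ToyDistribution` and `ToyExponential` are hypothetical, and the sketch makes no claim about how the TFP base class is actually written: the public method lives in the base class and delegates to an underscore-prefixed hook, which is the only thing a subclass overrides.

```python
import math


class ToyDistribution:
  """Hypothetical base class; an illustration, not TFP's implementation."""

  def log_prob(self, value):
    # Public entry point: shared handling lives here, once.
    return self._log_prob(float(value))

  def prob(self, value):
    # Default public method derived from the private hook.
    return math.exp(self.log_prob(value))

  def _log_prob(self, value):
    raise NotImplementedError(
        f'{type(self).__name__} does not implement _log_prob.')


class ToyExponential(ToyDistribution):
  """Hypothetical subclass that overrides only the private hook."""

  def __init__(self, rate):
    self._rate = rate

  def _log_prob(self, value):
    if value < 0.:
      return -math.inf
    return math.log(self._rate) - self._rate * value


d = ToyExponential(rate=2.)
```

With this layout, `d.log_prob(1.)` and `d.prob(1.)` both route through the subclass's `_log_prob`, while a subclass that implements no hook fails with a `NotImplementedError` when the public method is called.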