Motivated by deep learning, neuroscience and mathematics, the neural ODE replaces a stack of neural network layers with a continuous-depth model that is evaluated by an ODE solver. Gradients with respect to the model parameters are computed with the adjoint method, which solves a second ODE backwards in time and avoids storing intermediate activations, so the memory cost stays constant in depth. The model also lets the time cost be traded against accuracy at test time, and has potential applications in continuous time-series models.
Motivations
Deep Learning Motivation: ResNet
Given a hidden layer that maps $x$ to $y$, a ResNet block enforces $y = F(x) + x$. The motivation is that if more layers cannot do better, the block can simply learn $F(x) = 0$ and pass its input through unchanged.
ODEs can be helpful for describing a continuous-time system. Consider $z: \mathbb{R} \to \mathbb{R}^n$, $t \mapsto z(t)$, with
$$G(t, z(t), z'(t)) = 0 \iff z'(t) = f(z(t), t)$$
With ODEs, we are usually interested in solving an initial value problem:
Given
$$\begin{cases} z(0) \\ z'(t) = f(z(t), t) \end{cases}$$
we want to know $z(t) = z(0) + \int_0^t z'(s)\, ds$.
When it is difficult or impossible to solve for $z(t)$ analytically, we can approximate it numerically at the points of interest. This is the subject of numerical differential equations and computational mathematics.
The most basic method is Euler's method: $z(t_0 + h) \approx z(t_0) + h z'(t_0)$
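As a minimal sketch of Euler's method, the following integrates an initial value problem with a fixed step size and shows the error shrinking as $h$ decreases. The test problem $z' = -z$, the name `euler_solve` and the step counts are illustrative assumptions, not part of the original post.

```python
import math

def euler_solve(f, z0, t0, t1, steps):
    """Approximate z(t1) for z'(t) = f(z, t), z(t0) = z0, with fixed-step Euler."""
    h = (t1 - t0) / steps
    z, t = z0, t0
    for _ in range(steps):
        z = z + h * f(z, t)   # z(t + h) ≈ z(t) + h z'(t)
        t = t + h
    return z

# Illustrative test problem: z' = -z, z(0) = 1, exact solution z(t) = exp(-t).
decay = lambda z, t: -z
exact = math.exp(-1.0)
coarse = euler_solve(decay, 1.0, 0.0, 1.0, steps=10)
fine = euler_solve(decay, 1.0, 0.0, 1.0, steps=1000)
print(abs(coarse - exact), abs(fine - exact))  # the second error is smaller
```

Euler's method is first order, so the global error shrinks roughly in proportion to $h$.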
Comparing with the ResNet update $y = x + F(x)$, we can view each residual layer as one Euler step with step size $h = 1$, performing $z(t) \mapsto z(t+1)$; the final output is just $z(T)$.
But $h$ can seemingly take any real value, so can we let $h \to 0$ and have an infinite number of layers?
Yes!
Neuroscience Motivation
Real brains are continuous-time systems, hence early neural network research took continuous-time systems as a starting point.
Today's artificial neural networks model the brain in discretized form.
Proposal
We can modify ResNet to make it more ODE-like, turning it into a continuous-time system!
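    # nnet is any neural network; T is the number of residual blocks
    def f(z, t, theta):
        # block t has its own parameters theta[t]
        return nnet(z, theta[t])

    def resnet(z, theta):
        for t in range(T):
            z = z + f(z, t, theta)
        return z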
⇓
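    # nnet is any neural network; T is the number of residual blocks
    def f(z, t, theta):
        # t is now an input and theta is shared across all depths
        return nnet([z, t], theta)

    def resnet(z, theta):
        for t in range(T):
            z = z + f(z, t, theta)
        return z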
We can then replace the loop in resnet with an ODE solver, which approximates $z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)\, dt$.
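The step from a residual loop to an ODE solve can be sketched as follows. Here `f` is a hypothetical one-unit stand-in for `nnet([z, t], theta)`, and `odenet` uses fixed-step Euler purely for illustration; the point is only that the ResNet loop is the special case $h = 1$.

```python
import math

def f(z, t, theta):
    # Hypothetical stand-in for nnet([z, t], theta): a single tanh unit.
    w, u, b = theta
    return math.tanh(w * z + u * t + b)

def resnet(z, theta, T):
    """T residual blocks that share theta and receive the depth index t."""
    for t in range(T):
        z = z + f(z, t, theta)
    return z

def odenet(z, theta, t0, t1, steps):
    """Same dynamics, integrated from t0 to t1 with step size h = (t1 - t0) / steps."""
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        z = z + h * f(z, t, theta)
        t += h
    return z

theta = (0.8, 0.1, -0.2)                        # illustrative parameter values
print(resnet(0.5, theta, 3))                    # three residual blocks
print(odenet(0.5, theta, 0.0, 3.0, steps=3))    # identical: Euler with h = 1
print(odenet(0.5, theta, 0.0, 3.0, steps=3000)) # closer to the continuous limit
```

With more steps the output approaches the solution of the underlying ODE, and the number of "layers" becomes a choice made by the solver rather than by the architecture.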
BP through ODE Solvers
We want to optimize
$$L(\mathrm{ODESolve}(z(t_0), f, t_0, t_1, \theta))$$
by changing its parameters $z(t_0)$, $t_0$, $t_1$, $\theta$.
Backpropagating directly through the operations of the solver is slow, has a high memory cost, and introduces extra numerical error.
The authors compute gradients using the adjoint method, which directly approximates the gradient, rather than differentiating the approximation. Proof can be found here.
This approach computes gradients by solving a second, augmented ODE backwards in time, and is applicable to all ODE solvers.
The key results from the proof are as follows:
Let $a(t) = \frac{\partial L}{\partial z(t)}$. Then
$$\frac{da(t)}{dt} = -a(t)^T \frac{\partial f(z(t), t, \theta)}{\partial z}$$
$$\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^T \frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt$$
When the loss depends on intermediate states, the reverse-mode derivative must be broken into a sequence of separate solves.
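The adjoint equations can be checked on a scalar toy problem. The sketch below assumes the dynamics $f(z, t, \theta) = \theta z$ and the loss $L = z(t_1)$, for which $dL/d\theta = z(t_0)(t_1 - t_0)e^{\theta (t_1 - t_0)}$ is known in closed form. The function name and the fixed-step Euler integration are illustrative choices, not the authors' implementation.

```python
import math

def adjoint_gradient(theta, z0, t0, t1, steps=10000):
    """Return (L, dL/dtheta) for z' = theta * z and L = z(t1), via the adjoint method."""
    h = (t1 - t0) / steps

    # Forward pass: only the final state is kept, no intermediate values.
    z = z0
    for _ in range(steps):
        z = z + h * theta * z
    loss = z

    # Backward pass: integrate the augmented state (z, a, dL/dtheta) from t1 to t0.
    a = 1.0       # a(t1) = dL/dz(t1)
    grad = 0.0    # accumulates the integral for dL/dtheta
    for _ in range(steps):
        dz = theta * z        # f(z, t, theta)
        da = -a * theta       # da/dt = -a * df/dz
        dg = -a * z           # integrand -a * df/dtheta, integrated from t1 to t0
        # Euler step with negative step size (backwards in time).
        z, a, grad = z - h * dz, a - h * da, grad - h * dg
    return loss, grad

loss, grad = adjoint_gradient(theta=0.5, z0=2.0, t0=0.0, t1=1.0)
print(grad, 2.0 * 1.0 * math.exp(0.5))  # adjoint estimate vs. closed form
```

The state $z(t)$ is recomputed backwards alongside $a(t)$, which is why no activations from the forward pass need to be stored.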
Experiments in Supervised Learning
The authors evaluate a small residual network that downsamples the input twice and then applies 6 residual blocks. In the ODE-Net variant, the residual blocks are replaced by an ODESolve module. They also test a network that backpropagates directly through a Runge-Kutta integrator (a numerical method for iteratively solving a nonlinear ODE), referred to as RK-Net.
In the figure above, $L$ denotes the number of layers in the ResNet and $\tilde{L}$ denotes the number of function evaluations that the ODE solver requests in a single forward pass, which can be interpreted as an implicit number of layers.
For reference, a neural net with a single hidden layer of 300 units has around the same number of parameters as the ODE-Net and RK-Net architecture.
One advantage of using ODE solvers is that many of them approximately ensure that the output is within a given tolerance of the true solution.
The authors verify in figure 3a that the error can indeed be controlled. Figure 3b shows that tuning the tolerance gives a trade-off between accuracy and computational cost.
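The tolerance trade-off can be illustrated with a toy adaptive solver that counts function evaluations. This is a minimal sketch built on a Heun-Euler embedded pair with a simple step-size controller, applied to scalar states; it is not the solver used by the authors, and the controller constants are assumptions.

```python
import math

def adaptive_solve(f, z0, t0, t1, tol):
    """Integrate z' = f(z, t) from t0 to t1; return (z(t1), number of f evaluations)."""
    t, z, h, nfe = t0, z0, (t1 - t0) / 10, 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)
        k1 = f(z, t)
        k2 = f(z + h * k1, t + h)
        nfe += 2
        # Local error estimate: difference between the Heun and Euler updates.
        err = abs(h * (k2 - k1) / 2)
        if err <= tol:
            t += h
            z += h * (k1 + k2) / 2   # accept the Heun step
        # Grow or shrink the step size for the next attempt.
        h *= min(2.0, max(0.2, 0.9 * math.sqrt(tol / (err + 1e-16))))
    return z, nfe

decay = lambda z, t: -z   # illustrative problem with exact solution exp(-t)
for tol in (1e-2, 1e-4, 1e-6):
    z1, nfe = adaptive_solve(decay, 1.0, 0.0, 1.0, tol)
    print(tol, nfe, abs(z1 - math.exp(-1.0)))
```

A tighter tolerance forces smaller steps, so the solver requests more function evaluations in exchange for a smaller error.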
Figure 3c shows that the number of function evaluations in the backward pass is roughly half of that in the forward pass. This suggests that the adjoint sensitivity method is not only more memory efficient, but also more computationally efficient than direct BP.
Note that it is not clear how to define the "depth" of an ODE solution. A related quantity is the number of evaluations of the hidden state dynamics required, a detail delegated to the ODE solver and dependent on the initial state or input. In figure 3d, the authors show that the number of function evaluations increases throughout training, presumably adapting to increasing complexity of the model. Note that Duvenaud mentioned that their model in the end can be 2-4x slower than ResNet.
Continuous Normalizing Flows
The discretized equation
$$h_{t+1} = h_t + f(h_t, \theta_t)$$
also appears in normalizing flows and the NICE framework. These methods use the change of variables theorem to compute exact changes in probability if samples are transformed through a bijective function f:
$$z_1 = f(z_0) \implies \log p(z_1) = \log p(z_0) - \log \left| \det \frac{\partial f}{\partial z_0} \right|$$
Generally, the main bottleneck to using the change of variables formula is computing the determinant of the Jacobian $\frac{\partial f}{\partial z}$, which has a cubic cost in either the dimension of $z$ or the number of hidden units.
Surprisingly, moving from a discrete set of layers to a continuous transformation simplifies the computation of the change in normalizing constant.
Theorem (Instantaneous Change of Variables). Let $z(t)$ be a finite continuous random variable with probability $p(z(t))$ dependent on time. Let $\frac{dz}{dt} = f(z(t), t)$ be a differential equation describing a continuous-in-time transformation of $z(t)$. Assuming that $f$ is uniformly Lipschitz continuous in $z$ (there exists a single Lipschitz constant valid for all $t$) and continuous in $t$, the change in log probability also follows a differential equation,
$$\frac{\partial \log p(z(t))}{\partial t} = -\operatorname{tr}\left(\frac{df}{dz(t)}\right).$$
Proof. Let $z(t + \epsilon) = T_\epsilon(z(t))$. We assume that $f$ is Lipschitz continuous in $z(t)$ and continuous in $t$, so every initial value problem has a unique solution by Picard's existence theorem. We also assume $z(t)$ is bounded. These conditions imply that $f$, $T_\epsilon$ and $\frac{\partial}{\partial z} T_\epsilon$ are all bounded.
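Applying the change of variables formula to $T_\epsilon$, taking the limit $\epsilon \to 0^+$, and expanding $T_\epsilon(z(t)) = z(t) + \epsilon f(z(t), t) + O(\epsilon^2)$ as a Taylor series around $z(t)$, the time derivative of $\log p(z(t))$ reduces to
$$-\operatorname{tr}\left(\frac{\partial f}{\partial z}(z(t), t)\right),$$
which completes the proof.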
The trace is a much cheaper operation, with a quadratic cost in either the dimension of $z$ or the number of hidden units.
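As a sanity check of the instantaneous change of variables, consider the hypothetical linear dynamics $f(z) = Az$, for which $\frac{\partial f}{\partial z} = A$ everywhere. Integrating $-\operatorname{tr}(A)$ over time should then agree with $-\log|\det|$ of the Jacobian of the full flow map. The matrix `A`, the helper names and the Euler discretization below are assumptions made for illustration.

```python
import math

# Hypothetical linear dynamics f(z) = A z, so df/dz = A for every z and t.
A = [[0.3, -1.0],
     [1.0, -0.1]]

def trace(M):
    return M[0][0] + M[1][1]

def det(M):
    return M[0][0] * M[1][1] - M[0][1] * M[1][0]

def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def continuous_delta_logp(A, T, steps):
    """Integrate d log p / dt = -tr(df/dz) with Euler; only a trace per step."""
    h = T / steps
    delta = 0.0
    for _ in range(steps):
        delta -= h * trace(A)
    return delta

def discrete_delta_logp(A, T, steps):
    """Change of variables for the whole flow map: -log|det dz(T)/dz(0)|."""
    h = T / steps
    step = [[1 + h * A[0][0], h * A[0][1]],
            [h * A[1][0], 1 + h * A[1][1]]]
    J = [[1.0, 0.0], [0.0, 1.0]]
    for _ in range(steps):
        J = matmul(step, J)   # Jacobian of the composed Euler steps
    return -math.log(abs(det(J)))

print(continuous_delta_logp(A, 1.0, 2000), discrete_delta_logp(A, 1.0, 2000))
```

The two numbers agree up to discretization error, while the continuous version never forms a determinant.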
Generative latent function time-series model
The authors present a continuous-time generative approach to modeling time series. Their model represents each time series by a latent trajectory. Each trajectory is determined from a local initial state $z_{t_0}$ and a global set of latent dynamics shared across all time series. Given observation times $t_0, t_1, \dots, t_N$ and an initial state $z_{t_0}$, an ODE solver produces $z_{t_1}, \dots, z_{t_N}$, which describe the latent state at each observation. They define this generative model formally through a sampling procedure:
$$z_{t_0} \sim p(z_{t_0})$$
$$z_{t_1}, z_{t_2}, \dots, z_{t_N} = \mathrm{ODESolve}(z_{t_0}, f, \theta_f, t_0, \dots, t_N)$$
$$\text{each } x_{t_i} \sim p(x \mid z_{t_i}, \theta_x)$$
The function $f$ is time-invariant: it takes the value $z$ at the current time and outputs its time derivative, $\frac{\partial z(t)}{\partial t} = f(z(t), \theta_f)$. The function is parametrized by a neural network. Because $f$ is time-invariant, given any latent state $z(t)$, the entire trajectory is uniquely defined. Extrapolating this latent trajectory lets us make predictions arbitrarily far forwards or backwards in time.
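A toy version of this generative procedure is sketched below: sample an initial latent state, solve the latent ODE at the (possibly irregular) observation times, then emit a noisy observation per latent state. The scalar dynamics in `latent_dynamics` stand in for the neural network, and the standard normal prior and Gaussian observation noise are assumptions chosen for illustration.

```python
import math
import random

def latent_dynamics(z):
    # Hypothetical stand-in for the network f(z, theta_f); note it does not depend on t.
    return -0.5 * z

def ode_solve(z0, times, substeps=100):
    """Return the latent state at each time in `times`, using Euler between observations."""
    zs, z, t_prev = [z0], z0, times[0]
    for t_next in times[1:]:
        h = (t_next - t_prev) / substeps
        for _ in range(substeps):
            z = z + h * latent_dynamics(z)
        zs.append(z)
        t_prev = t_next
    return zs

def sample_series(times, noise_std=0.1, rng=random):
    z0 = rng.gauss(0.0, 1.0)                     # initial latent state (assumed N(0, 1) prior)
    zs = ode_solve(z0, times)                    # latent state at every observation time
    xs = [rng.gauss(z, noise_std) for z in zs]   # observations (assumed Gaussian noise)
    return zs, xs

times = [0.0, 0.3, 1.0, 2.5]                     # irregularly spaced observation times
zs, xs = sample_series(times, rng=random.Random(0))
print(zs)
print(xs)
```

Because the dynamics are time-invariant, extending `times` beyond the last observation extrapolates the same latent trajectory.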
Benefits
Benefits of differentiable ODE solvers include:
1. Memory efficiency: no intermediate variables are stored for the chain rule, hence a constant memory cost.
2. Adaptive computation: the choice of ODE solver is orthogonal to the model, so different ODE solvers can be used for the same model. The theory of ODE solvers has more than 120 years of development behind it, and modern ODE solvers allow a controllable trade-off between time cost and precision at test time.
3. Parameter sharing across layers.
4. Easier computation for the change of variables.
5. Continuous time-series models.