Neural ODE
Motivated by ideas from deep learning, neuroscience and mathematics, the neural ODE replaces the layers of a neural network with a continuous-depth model computed by an ODE solver. The model's parameters are updated by solving a second ODE backwards in time (the adjoint method), which makes training efficient in both memory and time. The model also offers a controllable compute cost at test time and has potential applications in continuous time-series models.
Definition (ODE): equations of the form $$\frac{dz(t)}{dt} = f(z(t), t).$$
For an ODE, we are usually interested in solving an initial value problem:
Given an initial state $z(t_0) = z_0$, find the state at a later time, $$z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(t), t)\, dt.$$
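As a small worked example (not from the paper): for the dynamics $f(z, t) = -z$, the initial value problem can be solved in closed form, $$\frac{dz(t)}{dt} = -z(t), \quad z(0) = 1 \;\;\Longrightarrow\;\; z(t) = e^{-t}, \quad z(1) = e^{-1} \approx 0.368.$$ Most ODEs of practical interest have no such closed form, which is where numerical solvers come in.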
Yes!
Real brains are continuous-time systems, hence early neural network research took continuous-time systems as a starting point.
Artificial neural networks today model the brain with discrete layers.
We can modify ResNet to make it more ODE-like, i.e., a continuous-time system!
We want to optimize a scalar loss whose input is the result of the ODE solver, $$L(z(t_1)) = L\!\left(z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)\, dt\right) = L\big(\mathrm{ODESolve}(z(t_0), f, t_0, t_1, \theta)\big),$$ with respect to the parameters $\theta$.
Directly backpropagating through the solver's operations is slow and incurs extra numerical error and memory cost.
The authors compute gradients using the adjoint method, which directly approximates the gradient, rather than differentiating the approximation. Proof can be found here.
This approach computes gradients by solving a second, augmented ODE backwards in time, and is applicable to all ODE solvers.
The key results from the proof are as follows. Define the adjoint state $a(t) = \frac{\partial L}{\partial z(t)}$; it obeys $$\frac{da(t)}{dt} = -a(t)^{\top} \frac{\partial f(z(t), t, \theta)}{\partial z},$$ which is solved backwards in time from $a(t_1) = \frac{\partial L}{\partial z(t_1)}$, and the parameter gradient is $$\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^{\top} \frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt.$$
When the loss depends on intermediate states, the reverse-mode derivative must be broken into a sequence of separate solves.
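To make the backward augmented solve concrete, here is a minimal numerical sketch (not the authors' code) for the toy dynamics $\frac{dz}{dt} = \theta z$ with loss $L = z(t_1)^2$. It uses scipy's general-purpose solver, and all names (`theta`, `z0`, `augmented`, ...) are illustrative.

```python
# Adjoint-method sketch on a toy ODE dz/dt = theta * z with loss L = z(t1)^2.
# The gradient dL/dtheta is obtained by one backward solve of an augmented ODE.
import numpy as np
from scipy.integrate import solve_ivp

theta, z0, t0, t1 = 0.7, 1.5, 0.0, 2.0

# Forward pass: solve the original ODE from t0 to t1.
fwd = solve_ivp(lambda t, z: theta * z, (t0, t1), [z0], rtol=1e-10, atol=1e-12)
z1 = fwd.y[0, -1]

# Backward pass: integrate the augmented state [z, a, dL/dtheta] from t1 to t0,
# where a(t) = dL/dz(t) follows da/dt = -a * df/dz, and the gradient accumulates
# d(dL/dtheta)/dt = -a * df/dtheta.
def augmented(t, s):
    z, a, _ = s
    return [theta * z,   # dz/dt: the state is re-solved backwards alongside a(t)
            -a * theta,  # da/dt = -a * df/dz
            -a * z]      # d(dL/dtheta)/dt = -a * df/dtheta

a1 = 2.0 * z1            # dL/dz(t1) for L = z(t1)^2
bwd = solve_ivp(augmented, (t1, t0), [z1, a1, 0.0], rtol=1e-10, atol=1e-12)
grad_adjoint = bwd.y[2, -1]

grad_exact = 2.0 * z1**2 * (t1 - t0)   # closed-form gradient for this toy problem
print(grad_adjoint, grad_exact)        # should agree up to solver tolerance
```

The same idea, with the dynamics and vector-Jacobian products supplied by a neural network, is what a general ODE solver with adjoint gradients carries out.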
The authors compared the performance of a small residual network that downsamples the input twice and then applies 6 residual blocks; in the ODE-Net variant these blocks are replaced by an ODESolve module. They also tested a network that backpropagates directly through a Runge-Kutta integrator (a classical numerical method for solving ODEs), referred to as RK-Net.
One advantage of using ODE solvers is that many of them approximately ensure that the output is within a given tolerance of the true solution.
The authors verify that error can indeed be controlled in figure 3a. They also show in figure 3b that tuning the tolerance gives a trade-off between accuracy and computational cost.
Figure 3c shows that the number of function evaluations in the backward pass is roughly half that of the forward pass. This suggests that the adjoint sensitivity method is not only more memory efficient, but also more computationally efficient than direct BP.
Note that it is not clear how to define the "depth" of an ODE solution. A related quantity is the number of evaluations of the hidden state dynamics required, a detail delegated to the ODE solver and dependent on the initial state or input. In figure 3d, the authors show that the number of function evaluations increases throughout training, presumably adapting to the increasing complexity of the model. Duvenaud has also mentioned that, in the end, their model can be 2-4x slower than a ResNet.
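To see both the tolerance/accuracy trade-off and the function-evaluation count in action, here is a small illustration using scipy's generic adaptive Runge-Kutta solver (not the paper's setup); the test problem and tolerances are placeholders.

```python
# Accuracy vs. compute in an adaptive ODE solver: looser tolerances mean fewer
# function evaluations (NFE) but larger error. Exact solution at t = 1 is exp(-1).
import numpy as np
from scipy.integrate import solve_ivp

for tol in (1e-3, 1e-6, 1e-9):
    sol = solve_ivp(lambda t, z: -z, (0.0, 1.0), [1.0], rtol=tol, atol=tol)
    err = abs(sol.y[0, -1] - np.exp(-1.0))
    print(f"tol={tol:.0e}  error={err:.2e}  NFE={sol.nfev}")
```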
The discretized equation $z_{t+1} = z_t + f(z_t, \theta_t)$ is also the update used in normalizing flows, where tracking the density requires a (costly) determinant of the Jacobian.
Surprisingly, moving from a discrete set of layers to a continuous transformation simplifies the computation of the change in normalizing constant.
The derivative of the determinant can be expressed using Jacobi's formula, which gives $$\frac{\partial \log p(z(t))}{\partial t} = -\operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right).$$
This completes the proof.
Benefits of differentiable ODE solvers include:
1. Memory efficiency: no intermediate variables are stored for the chain rule, hence constant memory cost.
2. Adaptive computation: the choice of ODE solver is orthogonal to the model, so different solvers can be used with the same model. ODE solver theory has more than 120 years of development behind it, and modern adaptive solvers allow a controllable trade-off between time cost and precision at test time.
3. Parameter sharing across "layers".
4. Easier computation for the change of variables.
5. Continuous time-series models.
Given a hidden layer that maps $z_t$ to $z_{t+1}$, a ResNet block enforces $$z_{t+1} = z_t + f(z_t, \theta_t).$$ The motivation is that if you cannot do better with more layers, the block can just make $f \approx 0$ and recover the identity mapping.
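A minimal PyTorch sketch of such a block (the dimensions and inner layers are illustrative choices, not the paper's architecture):

```python
# A residual block: the layers learn only the increment f(z), so driving f toward
# zero recovers the identity mapping z_{t+1} = z_t.
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, z):
        return z + self.f(z)   # z_{t+1} = z_t + f(z_t, theta_t)

block = ResidualBlock()
z = torch.randn(8, 64)
print(block(z).shape)          # same shape as the input, like every residual layer
```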
ODEs are helpful for describing a continuous-time system. Consider $$\frac{dz(t)}{dt} = f(z(t), t), \quad z(t_0) = z_0;$$
we want to know $z(t_1)$ at some later time $t_1$.
When it is difficult or impossible to solve for $z(t)$ analytically, we can numerically approximate it at the points of interest. This is what people working in numerical differential equations / computational mathematics do.
The most basic method is Euler's method: $$z(t + h) \approx z(t) + h\, f(z(t), t).$$
Comparing with ResNet, $z_{t+1} = z_t + f(z_t, \theta_t)$, we can view each residual layer as performing one Euler step with step size $h = 1$; the final output is just $z_T$.
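A plain-Python sketch of Euler's method (the step count and test problem are arbitrary); with step size $h = 1$ each update has exactly the form of a residual layer:

```python
# Euler's method for dz/dt = f(z, t): repeatedly apply z <- z + h * f(z, t).
import numpy as np

def euler(f, z0, t0, t1, n_steps):
    z, t = np.asarray(z0, dtype=float), t0
    h = (t1 - t0) / n_steps
    for _ in range(n_steps):          # each iteration looks like one residual layer
        z = z + h * f(z, t)
        t += h
    return z

# Test problem dz/dt = -z, z(0) = 1: the exact value at t = 1 is exp(-1) ~ 0.3679.
print(euler(lambda z, t: -z, 1.0, 0.0, 1.0, n_steps=100))
```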
But it seems that $t$ can take any real value, so can we have an infinite number of layers?
We can replace the ResNet with an ODE solver, which approximates $$z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)\, dt,$$ and we fit the dynamics $f$ by changing its parameters $\theta$.
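A minimal sketch of such a replacement, assuming the `torchdiffeq` package (the authors' solver library) is available; the hidden width, integration interval $[0, 1]$, and tolerances are illustrative:

```python
# An ODE block that plays the role of a stack of residual layers:
# z(1) = z(0) + integral_0^1 f(z(t), t, theta) dt, computed by an adaptive solver.
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint   # adjoint-based, O(1) memory

class ODEFunc(nn.Module):
    """The dynamics f(z(t), t, theta), parametrized by a small neural net."""
    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim))

    def forward(self, t, z):                        # torchdiffeq passes (t, z)
        return self.net(z)

class ODEBlock(nn.Module):
    def __init__(self, func, rtol=1e-3, atol=1e-3):
        super().__init__()
        self.func, self.rtol, self.atol = func, rtol, atol

    def forward(self, z0):
        t = torch.tensor([0.0, 1.0])
        zt = odeint(self.func, z0, t, rtol=self.rtol, atol=self.atol)
        return zt[-1]                               # the state at t = 1

block = ODEBlock(ODEFunc())
out = block(torch.randn(32, 64))                    # same shape in and out
```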
In the comparison table from the paper (Table 1), $L$ denotes the number of layers in the ResNet and $\tilde{L}$ denotes the number of function evaluations that the ODE solver requests in a single forward pass, which can be interpreted as an implicit number of layers.
For reference, a neural net with a single hidden layer of 300 units has around the same number of parameters as the ODE-Net and RK-Net architectures.
The discretized equation $z_{t+1} = z_t + f(z_t, \theta_t)$ also appears in normalizing flows and the NICE framework. These methods use the change of variables theorem to compute exact changes in probability when samples are transformed through a bijective function $f$: $$z_1 = f(z_0) \;\Longrightarrow\; \log p(z_1) = \log p(z_0) - \log\left|\det \frac{\partial f}{\partial z_0}\right|.$$
Generally, the main bottleneck to using the change of variables formula is computing the determinant of the Jacobian $\frac{\partial f}{\partial z}$, which has a cost cubic in either the dimension of $z$ or the number of hidden units.
Theorem (Instantaneous Change of Variables). Let $z(t)$ be a finite continuous random variable with probability $p(z(t))$ dependent on time. Let $\frac{dz(t)}{dt} = f(z(t), t)$ be a differential equation describing a continuous-in-time transformation of $z(t)$. Assuming that $f$ is uniformly Lipschitz continuous in $z$ ($\exists$ a universal Lipschitz constant) and continuous in $t$, the change in log probability also follows a differential equation: $$\frac{\partial \log p(z(t))}{\partial t} = -\operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right).$$
Proof. Let $T_\varepsilon$ denote the transformation of $z$ over an $\varepsilon$ change in time, i.e. $z(t + \varepsilon) = T_\varepsilon(z(t))$. We assume that $f$ is Lipschitz continuous in $z(t)$ and continuous in $t$, so every initial value problem has a unique solution by Picard's existence theorem. We also assume $z(t)$ is bounded. These conditions imply that $f$, $T_\varepsilon$, and $\frac{\partial T_\varepsilon}{\partial z}$ are all bounded.
The change of variables formula gives $$\frac{\partial \log p(z(t))}{\partial t} = -\lim_{\varepsilon \to 0^{+}} \frac{1}{\varepsilon}\log\left|\det \frac{\partial}{\partial z} T_\varepsilon(z(t))\right|.$$ By expanding the Taylor series of $T_\varepsilon$ around $\varepsilon = 0$, the equation above equals $$-\lim_{\varepsilon \to 0^{+}} \frac{1}{\varepsilon}\log\left|\det\!\left(I + \varepsilon \frac{\partial f}{\partial z(t)} + O(\varepsilon^{2})\right)\right| = -\operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right).$$
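As a quick sanity check (not from the paper): for linear dynamics $f(z) = A z$ the Jacobian is constant, so the theorem gives $$\frac{\partial \log p(z(t))}{\partial t} = -\operatorname{tr}(A) \;\;\Longrightarrow\;\; \log p(z(t)) = \log p(z(0)) - t\operatorname{tr}(A),$$ matching the familiar fact that a linear flow with $\operatorname{tr}(A) > 0$ expands volume and thins out the density.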
The trace is a much cheaper operation, with a cost only quadratic in either the dimension of $z$ or the number of hidden units.
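A small PyTorch sketch of the trace computation (the explicit `jacobian` call is for clarity only, and the network and dimension are placeholders; in practice the trace can be estimated even more cheaply):

```python
# d log p(z(t)) / dt = -tr(df/dz), computed here with an explicit Jacobian.
import torch
from torch.autograd.functional import jacobian

D = 4
f = torch.nn.Sequential(torch.nn.Linear(D, D), torch.nn.Tanh(), torch.nn.Linear(D, D))

z = torch.randn(D)
J = jacobian(lambda x: f(x), z)     # (D, D) Jacobian of the dynamics at z
dlogp_dt = -torch.trace(J)          # instantaneous change in log density
print(dlogp_dt.item())
```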
The authors present a continuous-time generative approach to modeling time series. Their model represents each time series by a latent trajectory. Each trajectory is determined from a local initial state $z_{t_0}$ and a global set of latent dynamics shared across all time series. Given observation times $t_0, t_1, \ldots, t_N$ and an initial state $z_{t_0}$, an ODE solver produces $z_{t_1}, \ldots, z_{t_N}$, which describe the latent state at each observation. They define this generative model formally through a sampling procedure: $$z_{t_0} \sim p(z_{t_0}),$$ $$z_{t_1}, \ldots, z_{t_N} = \mathrm{ODESolve}(z_{t_0}, f, \theta_f, t_0, \ldots, t_N),$$ $$x_{t_i} \sim p(x \mid z_{t_i}, \theta_x) \;\text{ for each } i.$$
The function $f$ is a time-invariant function that takes the value $z$ at the current time step and outputs the gradient $\frac{\partial z(t)}{\partial t} = f(z(t), \theta_f)$. It is parametrized using a neural network. Because $f$ is time-invariant, given any latent state $z(t)$, the entire latent trajectory is uniquely defined. Extrapolating this latent trajectory lets us make predictions arbitrarily far forwards or backwards in time.
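A minimal sketch of this sampling procedure, again assuming `torchdiffeq`; the latent dimension, decoder, and observation times are placeholder choices:

```python
# Sampling from the latent-ODE generative model: draw z_{t0}, integrate the shared
# dynamics over the observation times, then decode each latent state.
import torch
import torch.nn as nn
from torchdiffeq import odeint

latent_dim, obs_dim = 4, 2
f = nn.Sequential(nn.Linear(latent_dim, 32), nn.Tanh(), nn.Linear(32, latent_dim))
decoder = nn.Linear(latent_dim, obs_dim)        # parametrizes p(x | z_t, theta_x)

def dynamics(t, z):                             # time-invariant latent dynamics
    return f(z)

t_obs = torch.linspace(0.0, 5.0, steps=20)      # observation times t_0, ..., t_N
z_t0 = torch.randn(latent_dim)                  # z_{t_0} ~ p(z_{t_0})
z_t = odeint(dynamics, z_t0, t_obs)             # latent state at every t_i
x_mean = decoder(z_t)                           # mean of p(x | z_{t_i}); add noise to sample
```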