Differentiable Programming


If you haven’t already, read this very thought-provoking article on colah’s blog about the connection between functional programming and neural networks. The basic idea is that you can express many common kinds of neural networks very simply as functional programs. I would bet that we’re going to see ideas like this cropping up more and more, and I intend to write about similar ideas too. In this blog post I’m going to talk about ways to actually turn that idea into reality. The article discusses various kinds of (now classical) neural networks and how they fit into that picture. To set the tone for this post, let’s first refresh our memory of some very common kinds of neural nets:

  • Multi-Layer Perceptrons (MLPs). Possibly the simplest useful neural net model. MLPs are just sequential stacks of densely connected layers.
  • Convolutional Neural Nets (CNNs). Used widely for image recognition. CNNs are, like MLPs, sequential stacks of layers, but in this case they are convolutional layers, and they often also include max-pooling layers as well.
  • Recurrent Neural Nets (RNNs). Used for time series, audio, and text processing (anything with variable-length sequences of data). They work by feeding the output of the net back into its input. It is possible to simply feed an MLP output back into the input to create a simple RNN, but in practice the most popular variant uses LSTM cells.

These types of neural nets are easy to lay out as a graph that shows the relationships between layers. In the old days, they had a fairly small number of layers. The method used to fit or train the nets was (and continues to be) gradient descent. The motivation for choosing gradient descent was that, as long as the models were simple, it was known to be an effective learning strategy. Gradient descent uses the gradient of the function, which has the advantage that it typically converges faster than algorithms that don’t use the gradient, and the disadvantage that you have to be able to calculate the gradient. With MLPs, though, there is a simple algorithm (backpropagation) that lets you calculate the gradient efficiently.
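To make the gradient descent idea concrete, here is a minimal sketch in plain Python that fits a single sigmoid neuron to four toy data points. The data, learning rate, step count, and squared-error loss are all my own assumptions for illustration; the gradient is written out by hand with the chain rule, which is backpropagation in the one-layer case.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss(data, w, b):
    """Sum of squared errors 0.5 * (y - t)^2 over all (x, t) pairs."""
    return sum(0.5 * (sigmoid(w * x + b) - t) ** 2 for x, t in data)

def train_neuron(data, lr=0.5, steps=2000):
    """Fit y = sigmoid(w * x + b) by plain (full-batch) gradient descent."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        grad_w = grad_b = 0.0
        for x, t in data:
            y = sigmoid(w * x + b)
            # Chain rule: dL/dz = (y - t) * y * (1 - y), where z = w * x + b.
            dz = (y - t) * y * (1.0 - y)
            grad_w += dz * x   # dz/dw = x
            grad_b += dz       # dz/db = 1
        # Step against the gradient.
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

# Toy data (an assumption): negative inputs map to 0, positive inputs to 1.
data = [(-2.0, 0.0), (-1.0, 0.0), (1.0, 1.0), (2.0, 1.0)]
w, b = train_neuron(data)
print(loss(data, 0.0, 0.0), loss(data, w, b))
```

The loss after training should be much smaller than the loss at the starting point w = b = 0.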

About ten years ago, people started experimenting with very deep (many-layer) graph structures. This includes nets like VGG-derived nets. With Google’s Inception model and later models, we’re starting to see the emergence of networks with very complicated, non-sequential structures. Some of the models created in the past few years have such complicated network structures that they make even the Inception model seem simple and straightforward. The major insight that has emerged from all of this work is that gradient descent still works, even when you scale up to thousands of layers. Of course, you have to be somewhat clever about how you do gradient descent, and you need to design your model in a way that avoids certain pitfalls, but if you follow some simple rules then you’re golden. This has sparked a lot of speculation and research into just how far you can push gradient descent and what the limits are, especially since we now have automatic systems that can compute gradients from models, so the researcher no longer has to painstakingly derive the gradients by hand and hard-code them into the model.
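As a rough illustration of what “computing gradients automatically” means, here is a small sketch of forward-mode automatic differentiation using dual numbers. This is only one simple technique, chosen because it fits in a few lines; the big frameworks generally rely on reverse-mode differentiation instead. The `Dual` class and the `model` function are hypothetical names of my own.

```python
import math

class Dual:
    """A value paired with its derivative with respect to one chosen input."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)
    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule.
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)
    __rmul__ = __mul__

def _lift(x):
    """Treat plain numbers as constants (derivative zero)."""
    return x if isinstance(x, Dual) else Dual(float(x))

def sigmoid(d):
    s = 1.0 / (1.0 + math.exp(-d.value))
    # Chain rule: d/dz sigmoid(z) = s * (1 - s).
    return Dual(s, s * (1.0 - s) * d.deriv)

def model(w, x, b):
    """The usual neural building block, written with no gradient code at all."""
    return sigmoid(w * x + b)

def d_model_dw(w, x, b):
    """Derivative of model with respect to w, obtained by seeding w's derivative with 1."""
    return model(Dual(w, 1.0), x, b).deriv

print(d_model_dw(0.0, 2.0, 0.0))
```

Note that `model` contains no derivative code at all: the derivative falls out of running the ordinary program on `Dual` values.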

In recent years another new class of models has emerged: Neural Turing machines (NTMs). These models attempt to mimic the fundamental functions of computers – central processing units, memory units, buses, etc. – entirely using neural functions. The significant thing about this is that you get general-purpose computation systems that are entirely differentiable. That is, you can input programs and they execute them, and you can take the derivative of your outputs with respect to the machine’s parameters. This means you can optimize the entire machine using gradient descent to obtain some desired output.

Some of these NTM-like architectures include more recent models like memory networks, memory networks with a soft attention mechanism, pointer networks, stack-augmented RNNs, stack/queue-augmented LSTM nets, NTMs trained with reinforcement learning, and NPIs, as I discussed in my NPI post. People like Ilya Sutskever and Jürgen Schmidhuber have talked about how these sorts of methods are the future.

So now that we have neural architectures that act like general-purpose computers, we have almost come full circle. The blog article I linked at the beginning of this post offers a tentative way of expressing neural nets using functional language, and NTM-derived models show certain classes of program-like behaviors being carried out by custom neural models. It seems that what would be needed to close the loop is a functional language in which you could describe arbitrary program-like behaviors, and which would then be automatically translated or compiled into a neural system (perhaps an NTM or NPI, as appropriate), allowing some of the parameters or functions of a program to be learned entirely from data. This sort of programming language, were someone to implement it, could be called a Differentiable Programming Language (DPL). The special case of this where all primitives (except for very simple ones like addition) are neural, that is, have the following form:

y = f(x) = σ(Wx + b)

could be called a Neural Programming Language (NPL).
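For concreteness, the neural primitive above can be sketched in a few lines of plain Python. The name `neural_primitive` and the example weights are placeholders of my own; in an actual NPL, W and b would be learned from data rather than written down by hand.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def neural_primitive(W, b):
    """Return the function f(x) = sigmoid(W x + b) for a fixed matrix W and bias b."""
    def f(x):
        return [
            sigmoid(sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i)
            for row, b_i in zip(W, b)
        ]
    return f

# Hypothetical hand-picked parameters, standing in for learned ones.
f = neural_primitive(W=[[1.0, -1.0], [0.5, 0.5]], b=[0.0, 0.0])
y = f([2.0, 2.0])
print(y)
```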

One could argue that TensorFlow is already a low-level DPL. It has many of the language constructs that we expect in traditional programming but don’t usually expect from neural nets, such as while loops, conditional statements, and data structures like queues. And all of this is completely differentiable! Yet I would argue that TensorFlow is not the ideal DPL we desire, because implementing NTMs, NPIs, and other models in TensorFlow doesn’t seem ‘natural’. Also, TensorFlow lacks an innate memory mechanism, which all of those architectures require. In addition, there is no type system in TensorFlow.

Instead, we would like to use a more high-level approach to designing a DPL, and let the ‘compiler’ take care of the details for us.

High-Level Differentiable Programming Languages

Let’s give an example to make the idea more concrete. Let’s say our program is, in pseudocode:

function h(x)
  return f(g(x))

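Then, assuming f and g are both neural primitives (f with parameters W1, b1 and g with parameters W2, b2), this would be compiled to:

h(x) = σ(W1 σ(W2x + b2) + b1)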

That is, a standard two-layer MLP. Simple enough. What about a more complicated example? Consider a conditional:

function f(x, a)
  if x > 1.0
    return a + 1
  else
    return a

It is possible to translate this to the following neural structure:

+(x, y) = σ(Wx + Wy + b)

f(x, a) = if(x, 1.0, +(a, 1.0), a)

Here we have used recursive neural nets (also known as TreeNets) together with a differentiable if function.
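One way a differentiable if could work (this particular relaxation is my assumption, not the only option) is to replace the hard comparison with a sigmoid gate and blend the two branches. In the sketch below, `soft_if` and its `sharpness` parameter are hypothetical, and I use ordinary addition in place of the neural + so that the behavior is easy to check.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def soft_if(x, threshold, then_value, else_value, sharpness=10.0):
    """Smooth stand-in for: then_value if x > threshold else else_value."""
    # gate is close to 1 when x is well above threshold, close to 0 well below.
    gate = sigmoid(sharpness * (x - threshold))
    return gate * then_value + (1.0 - gate) * else_value

def f(x, a):
    # Mirrors f(x, a) = if(x, 1.0, +(a, 1.0), a), with ordinary addition.
    return soft_if(x, 1.0, a + 1.0, a)

print(f(3.0, 5.0), f(1.0, 5.0), f(-1.0, 5.0))
```

Far from the threshold this agrees closely with the ordinary conditional, and exactly at the threshold it returns the midpoint of the two branches. Unlike a hard if, the output varies smoothly with x, so a gradient can flow through the condition.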

Ok, now let’s make things a bit more formal and actually define a usable language. We have to make the distinction between functions defined in a neural way (that is, those functions we want to find or optimize) and functions given as combinations of other functions. This we will do in the next section.

Lambda Calculus

It’s good to keep the language simple at first, so that we only need to implement the most basic functions required to get us up and going. Lisp is a language that is highly minimalistic in every way: its interpreter is simple, its evaluation semantics are simple, and its ‘core’ set of functions is simple. You can describe the entire implementation of a Lisp interpreter/evaluator in just half a page of code (page 13 of this book). Other functional languages like Haskell are also built on a very small and simple core. Here, I’m going to use a subset of Lisp that doesn’t include quoted expressions. This greatly reduces the power of the language, and the result is probably not a true Lisp, but for our purposes it suffices. We’ll also include a rudimentary type system.

Our language constructs are:

  • A set of primitive functions, which carry out the basic neural σ(Wx + b) building block, with W and b learned from the data. Each primitive function has a simple type: f:T->S. Functions have to be applied to correct types.
  • A set of composite functions, which we specify by composing together primitive functions. For example, we could define a two-layer MLP as mlp(x) = f(g(x)), assuming f and g are primitive functions.
  • And a set of higher-level functions, such as map and fold, which take other functions as input, and perform operations like mapping primitive functions across the data, and so on.
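The three kinds of construct above can be mimicked in ordinary Python to get a feel for them. This is only a sketch under my own assumptions: the primitives are scalar, the parameters are hand-picked instead of learned, and the names `primitive`, `binary_primitive`, and `compose` are hypothetical. Python’s built-in `map` and `functools.reduce` play the roles of map and fold.

```python
import math
from functools import reduce

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def primitive(w, b):
    """Primitive function x -> sigmoid(w * x + b); w and b would be learned."""
    return lambda x: sigmoid(w * x + b)

def binary_primitive(w, b):
    """Two-argument primitive (acc, x) -> sigmoid(w * (acc + x) + b)."""
    return lambda acc, x: sigmoid(w * (acc + x) + b)

def compose(f, g):
    """Composite function: compose(f, g)(x) = f(g(x))."""
    return lambda x: f(g(x))

f = primitive(2.0, 0.0)
g = primitive(-1.0, 0.5)
mlp = compose(f, g)  # the two-layer MLP mlp(x) = f(g(x))

xs = [0.0, 1.0, 2.0]
mapped = list(map(mlp, xs))                           # apply the same net to every element
folded = reduce(binary_primitive(1.0, 0.0), xs, 0.0)  # carry a state along the sequence

print(mapped, folded)
```

The fold is the interesting one: it carries a state along the sequence, which is the same shape of computation as a simple recurrent net.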

Together, a collection of these functions defines our programs. A simple proof-of-concept compiler can be written to take a program to a TreeNet, but a more sophisticated compiler that detects and exploits structure in the program would be more efficient. I suspect that someone, somewhere, is probably working on something like this right now! In fact, people have already started to carry out research on differentiable languages based not on lambda calculus per se but on stack machines, for instance a Differentiable Forth interpreter.

How is memory handled in a functional setting? We maintain persistent memory by threading some data structure through our functions. In the purest functional languages, this is expressed in a simple way via monads. I’m going to caution here that I haven’t worked out all the details (maybe I’ll flesh it out in later posts), but the experience with pure functional languages shows that a very simple paradigm of merely composing functions can encapsulate very complex ideas, and describe many otherwise complicated programs in a simple way. That is the goal of producing a DPL.
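Here is a minimal sketch of what “threading a data structure through our functions” can look like. Each step is a pure function that takes the old memory plus an input and returns the new memory plus an output. The scalar memory, the fixed weights, and the names `step` and `run` are assumptions of mine, and this is plain state passing with no monad machinery.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def step(memory, x, w_mem=0.5, w_in=1.0, b=0.0):
    """One pure step: (old memory, input) -> (new memory, output)."""
    new_memory = sigmoid(w_mem * memory + w_in * x + b)
    output = new_memory  # simplest choice: expose the memory as the output
    return new_memory, output

def run(inputs, memory=0.0):
    """Thread the memory through a sequence of steps without mutating any shared state."""
    outputs = []
    for x in inputs:
        memory, y = step(memory, x)
        outputs.append(y)
    return memory, outputs

final_memory, outputs = run([1.0, -1.0, 0.5])
print(final_memory, outputs)
```

Because `step` never touches anything outside its arguments, running the same inputs from the same starting memory always gives the same result.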

Statistical Models

There’s also been a parallel development in the field of statistical modeling, called probabilistic programming. The story of statistical modeling has been one of continuous unification among disparate ideas. Classical statistical models, like mixture models and regression, which were first proposed in the 19th century, became unified as Bayesian networks (BNs) in the 1980s with the work of Judea Pearl. These models were then unified with Markov Random Fields (MRFs) in the general theory of Probabilistic Graphical Models (PGMs), which started in the 1990s and was further expanded during the 2000s. Each of these represents a more and more general class of models. Bayesian nets are those models where the dependencies between variables can be expressed as a directed acyclic graph (DAG). In this respect the analogous neural models would be non-recurrent neural networks. MRFs are models where the dependencies also take on a fixed graph structure, but the graph is not required to be a DAG.

The latest iteration of unification has been probabilistic programming, where even the graph structure itself does not need to be fixed and can vary depending on the data. Probabilistic programs are written pretty much as ordinary computer programs, and can be ‘run’ generatively, taking a set of input parameters to some output observations. The important distinction between probabilistic programs and classical programs, though, is that probabilistic programs can also, in a way, be run in reverse, taking output observations to (distributions over) input parameters. So they can be used both for generative purposes and for inference and model fitting.
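To illustrate the forward and reverse directions, here is a toy sketch in plain Python, not written in any actual PPL. The model is a coin whose bias takes one of five values. The candidate biases, the uniform prior, and the function names are all assumptions of mine, and the ‘reverse’ direction is done by exact enumeration, which only works because the parameter space is tiny.

```python
import random

# Hypothetical discrete set of candidate coin biases, with a uniform prior.
BIASES = [0.1, 0.3, 0.5, 0.7, 0.9]

def generate(bias, n, rng):
    """Forward direction: input parameter (bias) -> output observations (flips)."""
    return [rng.random() < bias for _ in range(n)]

def infer(flips):
    """Reverse direction: observations -> posterior distribution over the bias."""
    heads = sum(flips)
    tails = len(flips) - heads
    # Likelihood of the observed flips under each candidate bias.
    weights = [p ** heads * (1.0 - p) ** tails for p in BIASES]
    total = sum(weights)
    return {p: w / total for p, w in zip(BIASES, weights)}

rng = random.Random(0)
flips = generate(0.7, 50, rng)   # run the program forwards
posterior = infer(flips)         # run it 'in reverse'
print(posterior)
```

Real PPLs replace the enumeration in `infer` with general-purpose inference algorithms, so the same model description can be used with different inference methods.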

Some of the better-developed probabilistic programming languages (PPLs) include Stan, JAGS, and Venture. Church is a very simple, minimalistic PPL with Lisp-like syntax. This page has a substantial list of PPLs. One characteristic that most PPLs share is that they separate models from optimization algorithms. This is a key development, and one that will also be needed for a practical DPL or NPL.




