This writeup introduces what I'm calling Bonsai Networks - extremely sparse computational graphs produced by training and pruning RNNs. They provide an interpretable view into the solutions that networks learn for simple logic problems, pulling those solutions out of the black box that neural networks typically reside in.
I give an overview of the process I use to create these networks, which includes several custom neural network components and a training pipeline implemented from scratch in Tinygrad. I also include many interactive visualizations of the generated graphs and use them to reverse engineer some interesting solutions the networks learned for a variety of logic problems.
If you aren't that interested in the process I used to develop and train these networks, feel free to skip down to the Results section to check out the reverse engineering part.
My original inspiration for this project came from a blog post called Differentiable Finite State Machines. It lays out a strategy for directly learning state machines to model binary sequences. The FSMs are learned using gradient descent in a similar way to how neural networks are trained, hence the "differentiable" part of the title.
I have a big interest in ML interpretability, and I've done some work on reverse engineering small neural networks in the past. The prospect of taking the big blob of floating point numbers + matrix multiplication and turning it into something that can be reasoned about and manipulated was very exciting to me.
This got me thinking about how I could expand the ideas from that post into something that could handle slightly more involved problems. At some point, I had the idea of extending them to full-fledged RNNs, which started me down the path that eventually led to this project and writeup.
Here's the headline:
Since RNNs (like all other neural networks) are a kind of computational graph, it's also true that:
That sparsity is their first key feature. Most bonsai networks I trained ended up pruning out >90% of their weights. And once you get networks that sparse and convert them into graphs, some unique things become possible.
By doing that, the sparsity goes up even further and the most minimal graphs can be created.
At a high level, the whole process of growing them looks like this:
There's a good bit of nuance that's required for each of these points, so let's dive into that now.
The design goals of the RNNs I trained for the bonsai networks are significantly different from those of RNNs used for conventional purposes.
In order to get good results with creating these bonsai networks, I needed to design some custom components and architectures for their training.
The activation functions typically used when training neural networks - functions like sigmoid, ReLU, tanh, etc. - are quite simple. For large, deep networks like the vast majority of those trained today, they're a great choice. Neural networks excel at composing complex solutions out of these simple functions, building up increasingly sophisticated internal representations over dozens of layers.
There's no rule saying that the activation function has to be simple, though. Technically, any differentiable non-linear function will work. For my purposes here, I wanted to pack as much computational power into the function as possible without making it impossible to train or numerically unstable.
I ended up using a custom activation function which looks like this:
I developed this activation function around a year ago as part of some earlier research into using neural networks to perform boolean logic. It has some unique properties which make it very well-suited for this particular use-case:
a && b, !a && (b || !c), a ? b : c, etc.) in a single neuron as well.

All of these things contribute to making it very well-suited for modeling the kinds of binary logic programs I was targeting. However, there are some downsides as well:
These certainly created some headaches to deal with, but they were well worth it for the power that this function unlocks for the networks.
Another key modification I made was to the base architecture of the RNNs themselves.
I started out with a vanilla RNN modeled after TensorFlow's SimpleRNNCell. This is about as simple as it gets for an RNN, and I didn't use any of the fancy features like dropout, LSTM/GRU, etc.
Here's the architecture that TensorFlow uses for it:
Like all RNNs, each cell has its own internal state, which is fed back in for the next timestep of the sequence. In this vanilla architecture, the value passed on as output from the cell is the same as the new state.
One important thing to note is how the inputs and current state are combined by adding them together before being passed to the activation function. In the diagram above, that's the + operator just before the activation function σ.
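For reference, here is a minimal NumPy sketch of one timestep of a vanilla cell in the style of TensorFlow's SimpleRNNCell, written from the description above rather than copied from TensorFlow's source. It makes the two separate dot products and the + that joins them explicit, and shows that the output and the new state are the same value. The function names are hypothetical, and tanh is used only because it is SimpleRNNCell's default activation, not the custom activation described earlier.

```python
import numpy as np

def simple_rnn_step(x, h_prev, kernel, recurrent_kernel, bias, activation=np.tanh):
    """One timestep of a vanilla RNN cell (hypothetical helper).

    x:                (input_dim,)        input at this timestep
    h_prev:           (units,)            state from the previous timestep
    kernel:           (input_dim, units)  input weights
    recurrent_kernel: (units, units)      state weights
    bias:             (units,)
    """
    # Two separate dot products: one for the input, one for the previous state
    from_input = x @ kernel + bias
    from_state = h_prev @ recurrent_kernel
    # The "+" that merges them happens just before the activation
    h_new = activation(from_input + from_state)
    # In the vanilla cell, the output and the new state are the same tensor
    return h_new, h_new

def run_sequence(xs, kernel, recurrent_kernel, bias):
    """Run the cell over a sequence, starting from an all-zero state."""
    h = np.zeros(recurrent_kernel.shape[0])
    outputs = []
    for x in xs:
        out, h = simple_rnn_step(x, h, kernel, recurrent_kernel, bias)
        outputs.append(out)
    return np.stack(outputs)
```

Because the input and state contributions only meet at that +, a graph drawn directly from this cell needs an extra combining step in front of every activation.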
To solve this, I implemented a custom RNN cell using a modified architecture:
It's pretty similar to the vanilla architecture but with a few key differences.
This is important for multiple reasons:
+ operator that combines the outputs from both dot products, all the nodes can instead implement the same operation.

This last point is important for keeping the converted graphs at the end of the process as simple to understand as possible. Every node can be thought of as a single neuron implementing this identical sum weighted inputs -> add bias -> activation operation, making the flow of data and overall operation much easier to follow.
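As an illustration of that uniform operation, here is a small Python sketch in which every node of a sparse graph is evaluated the same way: a weighted sum over its incoming edges, plus a bias, passed through an activation. The Node class, the evaluate function and the edge format are hypothetical, and math.tanh stands in for the custom activation, which this sketch does not implement. Recurrent edges (reading a node's value from the previous timestep) are left out to keep it short.

```python
import math
from dataclasses import dataclass, field

@dataclass
class Node:
    # Hypothetical format: each incoming edge is (source_name, weight)
    inputs: list[tuple[str, float]] = field(default_factory=list)
    bias: float = 0.0

def evaluate(nodes, order, values, activation=math.tanh):
    """Evaluate `nodes` in topological `order`, starting from `values`.

    Every node performs the identical operation:
    sum of weighted inputs -> add bias -> activation.
    """
    values = dict(values)  # don't mutate the caller's dict
    for name in order:
        node = nodes[name]
        total = sum(w * values[src] for src, w in node.inputs)
        values[name] = activation(total + node.bias)
    return values
```

Since there is only one kind of node, reading a graph like this comes down to following edges and weights rather than tracking several different operation types.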
One of the most important goals when training these RNNs is making them maximally sparse - meaning that as many of their weights as possible end up at or extremely close to zero.
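To make "almost exactly zero" concrete, one simple approach is to pick a small magnitude threshold, count the weights that fall below it, and zero them out. The sketch below assumes a hypothetical threshold of 1e-3 and plain magnitude-based pruning; the actual criterion used for the bonsai networks isn't specified in this section.

```python
import numpy as np

def sparsity(weights, tol=1e-3):
    """Fraction of entries whose magnitude is below `tol` (treated as zero)."""
    w = np.asarray(weights)
    return float(np.mean(np.abs(w) < tol))

def magnitude_prune(weights, tol=1e-3):
    """Return a copy with near-zero entries set to zero, plus the keep-mask."""
    w = np.array(weights, dtype=float)
    mask = np.abs(w) >= tol  # True for weights that survive pruning
    return w * mask, mask
```

A network "pruning out >90% of its weights" would, under this definition, have a sparsity above 0.9 across its weight matrices.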
A commonly used technique for encouraging sparsity while training neural networks is the use of regularizers during the training process.
Regularizers can be thought of as secondary training objectives. They are functions that take a tensor (weights, biases, or any other trainable parameter) and return a cost value. That regularization cost is then added to the base loss value computed for the batch:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{base}} + \lambda \, R(\theta)$$

where $R$ is the regularization function applied to the parameters $\theta$, and $\lambda$ is the regularization coefficient, which controls the intensity of the regularization.
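As a minimal sketch, assuming an L1 penalty (a common choice for encouraging sparsity, though this section doesn't say which regularizer the bonsai networks use), the regularized loss adds a coefficient times the sum of absolute parameter values to the base loss. The function names are hypothetical, and NumPy arrays stand in for Tinygrad tensors.

```python
import numpy as np

def l1_penalty(params):
    """Assumed regularizer R: sum of absolute values over all parameter tensors."""
    return sum(np.abs(p).sum() for p in params)

def total_loss(base_loss, params, reg_coeff):
    """Base loss plus reg_coeff (the lambda coefficient) times the L1 penalty."""
    return base_loss + reg_coeff * l1_penalty(params)
```

Because the gradient of |w| has constant magnitude away from zero, an L1 penalty keeps pushing small weights toward zero instead of just shrinking them, which is why it is a common choice when sparsity is the goal.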