Hey there and welcome to this video. Today we will
implement a really cool article called "Growing Neural Cellular Automata". It was published on
Distill and you can see it on the screen. First of all, the article itself is absolutely amazing,
both because of the presented methods and because it contains interactive visualizations.
Before I start coding I will give you a brief summary of what the article is about. However,
I would definitely encourage you to pause the video and try to go through it yourself because
it is a pleasure to read. Also note that I won't cover all the topics, and I don't want you to miss
out on anything. Anyway, the authors published their code as a Colab notebook, which you can
see here. It is written in TensorFlow and I used it heavily in my PyTorch implementation.
Additionally, I used two other resources that I will link in the description box below. The
first one being this PyTorch implementation that I found on GitHub and the second resource
was a YouTube tutorial from the author himself, Alexander Mordvintsev. I would definitely recommend
you to check these resources out for yourself, and if you never come back to my video, that's fine;
I'm not going to take it personally. Anyway, all the credit goes to the authors and the
aforementioned resources. Let me now try to give you a very rough and high-level explanation of what
this article is trying to do. Personally, I try to relate it to Conway's Game of Life. As you
probably know, it takes place on a two-dimensional grid. Each of the grid cells can either be dead
or alive. First of all, we are supposed to design an initial state of the entire grid. Let's say
it will look like this. Second of all, there are very simple rules that determine what happens
to each of the grid cells in the next iteration. So for example if a given cell is dead and it
happens to have exactly three neighbors that are alive the cell itself will become alive. Note that
the next state of a given cell is fully determined by the cell's current state and the state of its
eight immediate neighbors. So in other words, the cell cannot really look any further to determine
its future state. We have the initial state and we have the rules and now we can just see what
happens. As you saw there was a lot of movement going on and at one point the pattern became
stable. Let us now try a different initial state. We have a completely new pattern and as you
would imagine you can just play around with this forever. Let me just stress that in this game
of life we are given the rules and we are supposed to choose the initial state. However, what if
someone fixed the initial state and we were supposed to design the rules? More specifically,
imagine somebody told us that the initial state looks like this: we have only one
cell that is alive and it is exactly in the middle of our grid, and now our goal
would be to come up with the rules ourselves; instead of using the standard rules we have the
freedom to choose anything we want. And of course you can play around with this and you can
come up with infinitely many rules, however, let's make this a little bit more interesting. Let's
say somebody gives us the target state of the grid that we want to achieve and he or she asks
us: Can you find the rules that given this simple initial state will get me to the final state?
(In a fixed number of iterations let's say). For example let's say I want the final grid
state to be this beautiful smiley and I'm asking you to find the rules that will get me
there let's say in 100 iterations starting from the simple initial state. And if I simplify that
is exactly what the article is about. It proposes a way to use deep learning (namely the
convolution operation) to learn rules that will result in some specific grid configuration.
Anyway, this was just to give you a very intuitive explanation; however, our setup is
going to be a little bit more complicated. So first of all, instead of having a grid of
booleans, we'll actually be working with an image where each cell corresponds to one pixel, and
therefore each of the cells can have multiple floats associated with it. Additionally,
not only will we have the RGB channels to store cell information, but we will
also have multiple other channels (let's say hidden states) that encode all the
information about a given cell. First of all, let me explain how we can use the convolution
operation to update our grid and define rules. Here I define a two-dimensional grid and one
way you can think about this is just a simple grayscale image and instead of the values being
just True or False we can actually have any number between 0 and 1 as the value. Let me now
define a new tensor that will represent the rules. So these rules are nothing else than a filter we
will convolve the input grid with. Note that this rules filter is 3x3, which means
that we are only allowed to look at the immediate neighbors and the cell itself, exactly the
same constraint that we saw with the Game of Life. And what is this rule actually doing? Well, for
a given cell we just look at the neighbor that is right above and the neighbor that is right below
and we define the new value to be their average. To actually perform the convolution we will use
the following function. So this is the docstring. Feel free to read it and I am just going to
use this function to perform the convolution. So as you can see the two main inputs are the
initial grid and the rules. I played around with the dimensions of those tensors so that things
work out but don't worry about it. This is how the grid looks after one iteration and let's just
verify whether it did what was expected. So if we look at this cell for example we want it to be
an average of this number and this number. Yeah! And now nothing actually prevents us from just
repeating this procedure as many times as we want. Lastly, let me point out that the rules tensor
here is something I created by hand; what we actually want is to turn this tensor into a learnable
parameter and learn it.
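For those following along, here is a minimal sketch of this little demo in plain PyTorch; the variable names (`grid`, `rules`) and the grid size are my own illustrative choices, not taken verbatim from the notebook:

```python
import torch
import torch.nn.functional as F

# A small "grayscale" grid: shape (batch, channels, height, width), values in [0, 1]
grid = torch.rand(1, 1, 6, 6)

# The "rules" are just a 3x3 filter: average the neighbour right above and right below
rules = torch.tensor([
    [0.0, 0.5, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.5, 0.0],
]).reshape(1, 1, 3, 3)  # (out_channels, in_channels, kernel_h, kernel_w)

# One iteration of the rule is a single convolution; padding keeps the grid size fixed
new_grid = F.conv2d(grid, rules, padding=1)

# Nothing prevents us from repeating the procedure as many times as we want
for _ in range(10):
    new_grid = F.conv2d(new_grid, rules, padding=1)
```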
We are now ready to jump into the article. What you see here is
one iteration of the pipeline and it is actually slightly more complicated than sliding a single
3x3 filter over our input image. First of all, our input is going to be an RGBA image, where the
A stands for the alpha channel, plus 12 additional channels to store any other information. The
first step is to convolve the image with three different 3x3 filters. The first of them is going
to be the identity filter, which just copies the input image; the second and the third
filters are Sobel filters in the x and y directions respectively. The idea behind using the Sobel
filter is to approximate the gradient and thus to give ourselves some information on what the
intensities of the neighboring pixels are. The authors claim that this is actually inspired
by real biological cells and chemical gradients which I cannot really comment on, however,
from the machine learning point of view this is an interesting design choice because the
other and maybe more natural approach would be to learn these 3x3 filters from scratch. I
guess the main benefit of hard coding these filters is having fewer learnable parameters
and also that we introduce a very reasonable prior into the neural network. Anyway, in our
implementation we'll follow what the paper did and hardcode these filters. However, note that in
the other video that I mentioned at the beginning the author actually just learns the 3x3 filters
from scratch. After applying our three filters to the 16 channel image we end up with a 48 channel
image. We then apply two 1x1 convolutions which is nothing else than applying a linear model
for each pixel over all channels; I will describe this in detail later. The last
operation is called the stochastic update and it is more or less a pixel-wise dropout. We then
take this image and add it to the input one, which is nothing else than a residual block.
Finally, we check the alpha channel of the image: if it is below 0.1, we consider the given cell
dead and manually set its channels to 0. This process is called live masking. What we saw before
was a single iteration of applying the rule; however, we actually
want to take our input image and run it through the same pipeline multiple times. In this diagram
you can clearly see that once we take multiple steps we simply take our predicted image namely
the first four channels and the target image and compute the L2 loss and there you go.
We have our deep learning pipeline.
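Before diving into the code, here is a very condensed, hedged sketch of this training objective; `nca_step` is a placeholder for whatever implements one iteration of the rule:

```python
import torch

def rollout_loss(x, target_rgba, nca_step, n_steps):
    """Run the rule for `n_steps` iterations and compare the visible channels to the target.

    x           : (batch, 16, height, width) grid (RGBA + hidden channels)
    target_rgba : (batch, 4, height, width) target image
    nca_step    : callable implementing one iteration of the learned rule
    """
    for _ in range(n_steps):
        x = nca_step(x)
    # Only the first four channels (RGBA) enter the L2 loss
    return ((x[:, :4] - target_rgba) ** 2).mean()
```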
Anyway, I guess that's it for the explanations, and in my opinion the implementation is pretty
straightforward, so let's just get started. First of all, we implement the model. The first
parameter determines the number of
channels of the input image. In the article this is actually equal to 16. Since we are going
to run the 1x1 convolution twice, we can decide on any number of hidden channels. The
fire rate determines how probable it is that a given cell is going to
go through the update process, and finally we provide the device so that we
can easily switch between the CPU and the GPU. Internally, we will create the update module,
which is nothing else than two 1x1 convolutions. What's really important is that this will be
the only block of our pipeline that is going to have learnable parameters. Internally, we also
store this filter tensor that will represent the identity filter and the Sobel
filters in the x and y directions. As always, we call the constructor of the parent class. If the user doesn't provide a device
specifically we will default to a CPU. So first of all we need to prepare the
filters for the so-called perceive step. This step is nothing else than a 3x3 convolution. We manually define the so-called Sobel
filter, which approximates the gradient; again, the idea behind it
is to tell the current cell what is happening around it and in what direction
the intensity increases or decreases. Here we define an identity filter
and if we slide this filter over any image we will actually get exactly the same
image. So here we take the three filters that we defined and we just stack them together
along the zero-th dimension. Our ultimate goal is to take these three filters and apply
them to each of the channels of the input image and therefore we will end up with a new image
that will have three times as many channels. So here we just repeated the filters over all
channels and we send them to the right device and finally we store them internally as an attribute
because we will use them in the forward pass. Let me stress again one very important
thing: these filters are not learnable; we hardcoded them manually.
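To make this concrete, the filter preparation could look roughly like the following sketch (variable names and the `/ 8` normalization are my choices; the idea is identical):

```python
import torch

n_channels = 16

# Sobel filter approximating the gradient in the x direction; its transpose gives the y direction
sobel_x = torch.tensor([
    [-1.0, 0.0, 1.0],
    [-2.0, 0.0, 2.0],
    [-1.0, 0.0, 1.0],
]) / 8.0
sobel_y = sobel_x.t()

# Identity filter: convolving with it simply copies the image
identity = torch.zeros(3, 3)
identity[1, 1] = 1.0

# Stack the three filters and repeat them for every input channel; the result is a
# (3 * n_channels, 1, 3, 3) weight tensor that we can later use for a depthwise convolution
filters = torch.stack([identity, sobel_x, sobel_y])   # (3, 3, 3)
filters = filters.repeat(n_channels, 1, 1)            # (3 * n_channels, 3, 3)
filters = filters[:, None, :, :]                      # (3 * n_channels, 1, 3, 3)
```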
Now we want to prepare the so-called update step; this is the only place where we will have
trainable parameters. We use the Sequential module to define three
consecutive steps. We apply the 1x1 convolution, then the ReLU activation and finally
again another 1x1 convolution. Let me quickly explain the relation between the
linear model and the 1x1 convolution. So we defined a couple of constants and here I
define a random tensor that represents a batch of images. Let me now
instantiate two different torch modules. Here I created the 1x1 convolution layer and here
I created a simple linear model. First of all, let me just check the number of parameters each
of them has. So as you can see they have the same number of parameters. These parameters are
actually stored in the following two attributes. So as we can see the bias and the weight of the
linear and the convolution layer are more or less matching except for some extra dimensions
and I guess at this point you realize that what I want to say or what I want to show is that
these two modules are more or less doing the same thing. When we do 1x1 convolution it's
nothing else than iterating over all pixels and applying the same linear model across all
the channels. Let me just prove it to you. So what I did here was to make sure that the
bias and the weight of the convolutional layer are exactly the same as the weight and the
bias of the linear layer. Note that when we constructed them these parameters were just
initialized randomly. And now the idea is to run the forward pass with our random tensor
and see whether we get the same result. Note that I actually had to permute the dimensions
of the input tensor in order for it to be usable with the linear module; I then undid the
permutation after the forward pass. First of all, let us check the shapes. They seem to be the
same, and the two tensors also seem to be the same element-wise, if we disregard tiny numerical differences.
To summarize, a 1x1 convolution is nothing else than a linear model applied to every pixel across the channels.
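If you want to verify this equivalence yourself, a small experiment along these lines does the trick (shapes and names are illustrative):

```python
import torch

batch, in_channels, out_channels, h, w = 2, 16, 32, 8, 8
x = torch.rand(batch, in_channels, h, w)

conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
linear = torch.nn.Linear(in_channels, out_channels)

# Force both modules to use exactly the same weight and bias
with torch.no_grad():
    conv.weight.copy_(linear.weight[:, :, None, None])
    conv.bias.copy_(linear.bias)

out_conv = conv(x)
# The linear layer expects channels last, so we permute and then undo the permutation
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(out_conv.shape == out_linear.shape)               # True
print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True, up to tiny differences
```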
We are back in the implementation. Since we are using 1x1 convolutions we will never be looking at the neighbors, and we are hoping that
by now all the information is already encoded in the channels and that is actually a reasonable
assumption because as we saw in the previous step we already included a lot of information
about the neighbors via the Sobel filters. To understand what I'm doing here let me just
remind you that our seed starting image is going to be a single bright pixel in the middle of
the image. All the other pixels (or cells, if you like) will be inactive, and by adjusting the
weight and the bias of this second 1x1 convolution we make sure it will actually take a couple
of iterations of the rule to populate the pixels that are further away from the center. I guess
the main motivation behind this is to make the training simpler and just make sure we don't
end up with some crazy complicated pattern just after the first iteration. Finally, we recursively
send all the parameters of this module to our desired device. All right so now we're
done with our constructor and we can write a couple of helper methods that will finally
be put together to create the forward pass. So here we implement the perceive step. Its goal
is to look at the surrounding pixels or cells and understand how the intensity changes. There are
no learnable parameters here. When it comes to the input and output shapes as you can see they're
the same except for the number of channels. We actually multiply the number of channels by
3 because we apply 3 filters to each of them. So we take the filters we prepared in
the constructor and we just perform a so-called depth-wise convolution and we
achieve this by setting groups equal to the number of input channels.
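In code, the perceive step could look something like this minimal sketch (I assume `filters` is the (3 * n_channels, 1, 3, 3) tensor prepared in the constructor):

```python
import torch.nn.functional as F

def perceive(x, filters, n_channels=16):
    """Apply the identity and Sobel filters to every channel separately (depthwise convolution).

    x       : (batch, n_channels, height, width)
    filters : (3 * n_channels, 1, 3, 3) hardcoded filters from the constructor
    returns : (batch, 3 * n_channels, height, width)
    """
    return F.conv2d(x, filters, padding=1, groups=n_channels)
```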
Let us now implement the update step. Again, the update step is the only place where
we have trainable parameters, namely those inside the two 1x1 convolution layers, and it is just
a one-liner because we prepared everything in the constructor.
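Here is a sketch of how the update module and the update step might fit together; the number of hidden channels and the zero initialization of the second convolution reflect my reading of the setup described above:

```python
import torch.nn as nn

n_channels, hidden_channels = 16, 128

update_module = nn.Sequential(
    nn.Conv2d(3 * n_channels, hidden_channels, kernel_size=1),  # mix the 48 perceived channels
    nn.ReLU(),
    nn.Conv2d(hidden_channels, n_channels, kernel_size=1),
)

# Zero the weight and bias of the second 1x1 convolution so that the first updates are tiny
# and the pattern needs several iterations to grow away from the seed pixel
nn.init.zeros_(update_module[2].weight)
nn.init.zeros_(update_module[2].bias)

def update(x_perceived):
    """x_perceived: (batch, 3 * n_channels, H, W) -> update of shape (batch, n_channels, H, W)."""
    return update_module(x_perceived)
```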
The next step to implement is the stochastic update. The stochastic update is nothing else
than a pixel-wise dropout; however, note that we're not actually rescaling the remaining values by
any factor. Let me just point out that this step, like the others, has a biological rationale.
We don't want all the cells to be updated with each iteration which would kind of imply that
there's this global clock and with each iteration everybody updates. What we want is for this
process to be more or less random. Let's say I focus on a given cell. I want it to update only
80% of the time independently of its neighbors. First of all we create a boolean
mask for each pixel and then we just element wise multiply the original
tensor with the mask. Since this mask is broadcast over all the channels, it cannot happen that
some channels of a given pixel are active while the remaining ones are inactive.
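A sketch of the stochastic update, assuming a fire rate of 0.5 as the default:

```python
import torch

def stochastic_update(x, fire_rate=0.5):
    """Zero out the update for a random subset of pixels (pixel-wise dropout without rescaling).

    x : (batch, n_channels, height, width) proposed update
    """
    # One random value per pixel; the comparison yields a boolean mask of shape (batch, 1, H, W)
    mask = torch.rand(x.shape[0], 1, x.shape[2], x.shape[3], device=x.device) < fire_rate
    # Broadcasting over the channel dimension keeps or drops all channels of a pixel together
    return x * mask
```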
The next utility function will take the alpha channel of our image, which is the fourth channel, and it will use it
to determine whether a given cell is alive or not. The criterion is that if the cell itself or
any cell in its neighborhood has an alpha channel higher than 0.1, this cell will be considered alive.
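One convenient way to implement the neighborhood check is a 3x3 max pooling over the alpha channel; here is a sketch:

```python
import torch.nn.functional as F

def get_living_mask(x, alpha_threshold=0.1):
    """A cell counts as alive if its own alpha or any neighbour's alpha exceeds the threshold.

    x : (batch, n_channels, height, width) with the alpha channel at index 3
    returns a boolean mask of shape (batch, 1, height, width)
    """
    alpha = x[:, 3:4, :, :]
    # Max pooling over the 3x3 neighbourhood implements "the cell itself or any neighbour"
    return F.max_pool2d(alpha, kernel_size=3, stride=1, padding=1) > alpha_threshold
```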
And now we have all we need to implement the forward method. Let me just remind you that calling the forward
method once in our case will mean nothing else than one iteration of the rule. What we will
actually do while training is to call the forward method multiple times to simulate multiple
iterations. First of all we will create a pre-live mask which will be a tensor of booleans. We
take our input tensor and run the perceive step, which applies the identity and the two Sobel
filters. Then we run the update step that contains learnable parameters. We run the stochastic update
and the goal of it is to make sure that some cells don't get updated during this forward pass and
thus making it more biologically plausible. Here we actually use a residual block and it's
really important because the new image is nothing else than the previous image plus some delta image
and I guess here one can make the same argument as with the ResNet. We will run this forward
method multiple times and one way to think about this is that you're just creating a very deep
architecture. We compute the post-life mask. The final life mask is the element-wise AND of the
pre-life mask and the post-life mask. That is it for the forward pass!
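Composed from the helper sketches above, one iteration of the rule could look like this (written as a free function rather than a method, just to keep the sketch self-contained):

```python
def nca_step(x, perceive, update_module, stochastic_update, get_living_mask):
    """One iteration of the learned rule. x: (batch, n_channels, height, width)."""
    pre_life_mask = get_living_mask(x)

    y = perceive(x)                  # identity + Sobel filters, no learnable parameters
    dx = update_module(y)            # the only learnable part: the two 1x1 convolutions
    dx = stochastic_update(dx)       # randomly skip the update for some of the pixels

    x = x + dx                       # residual update: new state = old state + delta

    post_life_mask = get_living_mask(x)
    life_mask = pre_life_mask & post_life_mask   # alive before AND after the update
    return x * life_mask             # channels of dead cells are zeroed out
```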
Right now we want to write a training script. Here we load an RGBA image and we pre-multiply
the RGB channels with the alpha channel and finally we turn it into a torch tensor and make
sure that the channels are the second dimension. The next function takes an RGBA image and turns it into an RGB
image. Note that we use the torch clamp to make sure we are not falling outside of the range
[0, 1], and that we want the background to be white. Here we create our initial grid state. It is
nothing else than a blank image; we take its center pixel and set all of its channels except for
RGB equal to one.
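A hedged sketch of these preprocessing helpers, assuming PIL and NumPy are available (function names and the image size are mine):

```python
import numpy as np
import torch
from PIL import Image

def load_target(path, size=40):
    """Load an RGBA image, premultiply RGB by alpha and return a (1, 4, size, size) float tensor."""
    img = Image.open(path).convert("RGBA").resize((size, size))
    img = np.float32(img) / 255.0
    img[..., :3] *= img[..., 3:]                         # premultiply RGB by the alpha channel
    return torch.from_numpy(img).permute(2, 0, 1)[None]  # channels become the second dimension

def to_rgb(img_rgba):
    """Turn a premultiplied RGBA tensor (batch, 4, H, W) into RGB on a white background."""
    rgb, alpha = img_rgba[:, :3], img_rgba[:, 3:4].clamp(0, 1)
    return (1.0 - alpha + rgb).clamp(0, 1)

def make_seed(size=40, n_channels=16):
    """A blank grid with a single active pixel in the middle (all channels except RGB set to 1)."""
    x = torch.zeros(1, n_channels, size, size)
    x[:, 3:, size // 2, size // 2] = 1.0
    return x
```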
Right now we would like to create a command line interface because there are multiple parameters
that one can play around with. I'm going to explain some of these arguments
when we actually use them in the code. We parse the arguments and we just
print them out to see the configuration. We instantiate the device based on the CLI
option. Here we prepare the TensorBoard writer. Here we load the target image and we pad it on all
four of the borders. We do this to kind of prevent overfitting since we don't want the network to
rely on the fact that there are borders nearby. Finally, we just take the same image and repeat
it along the batch dimension, because we want to do batch training, and we also add this target
image to TensorBoard.
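Sketched out, this preparation of the target might look as follows, reusing the `load_target` helper from the previous sketch (the padding size, batch size and file name are purely illustrative):

```python
import torch.nn.functional as F

batch_size, padding = 8, 16

target = load_target("target_rgba.png")           # hypothetical path to the RGBA target image
target = F.pad(target, (padding,) * 4)            # pad all four borders with zeros
target_batch = target.repeat(batch_size, 1, 1, 1) # repeat along the batch dimension
```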
We instantiate the model that we wrote and we also create an optimizer. Here I need to provide a bit more explanation. Instead
of always starting from the seed image and then trying to get to the target image we will create
a pool of images that should ideally contain all in-between states together with the target one
and also the seed image. The main idea of this pool is to make sure that once we reach the final
pattern more iterations are not going to degrade this pattern. You will see how the pool is being
updated in a couple of lines. Now we write our training loop. Most importantly, we will take a
number of gradient steps, and at each step we will randomly select a couple of samples from
our pool, and that way we'll create a batch. This part is really important because we will take our
batch and we will just run it through our forward pass multiple times and the number of iterations
is actually not going to be deterministic. It's going to be just randomly sampled from the
interval 64 to 96. We are hoping that around 70 iterations should be enough to go from the
seed image all the way to the target image. Here we compute per sample mean squared error.
Note that we are only extracting the first four channels out of our predicted image
and that is because the target image itself only contains the RGBA channels. We compute
an average loss over all samples in the batch, and then we just take the gradient
step and log the loss with TensorBoard. Here we're trying to update our pool. First of
all, we find a sample for which the loss was the highest and we make an assumption that this
sample was terrible and that we actually do not want to keep it in that pool. So what we do is
that we just replace this bad sample with the initial seed sample. When it comes to the other
samples in the batch, we actually throw them back into the pool, but importantly in their
updated form, not the initial one. This way we are hoping to
create a pool that contains all kinds of images representing different stages of the development
of our final pattern.
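Putting the pool logic and the training loop together, a condensed sketch might read like this (hyperparameters are illustrative, and I assume `seed` has the same spatial size as the padded target):

```python
import torch

def train(model, seed, target_batch, device, n_steps=5000, pool_size=1024, batch_size=8):
    """A condensed sketch of the pool-based training loop."""
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    pool = seed.repeat(pool_size, 1, 1, 1)               # start with a pool full of seeds
    target_batch = target_batch.to(device)

    for step in range(n_steps):
        # Sample a batch from the pool
        idx = torch.randperm(pool_size)[:batch_size]
        x = pool[idx].to(device)

        # Apply the rule a random number of times; ~70 iterations should grow the full pattern
        for _ in range(int(torch.randint(64, 96, (1,)))):
            x = model(x)

        # Per-sample L2 loss between the visible RGBA channels and the target
        loss_per_sample = ((x[:, :4] - target_batch) ** 2).mean(dim=(1, 2, 3))
        loss = loss_per_sample.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Pool update: the worst sample is replaced by a fresh seed, the rest go back evolved
        x = x.detach().cpu()
        x[loss_per_sample.argmax().item()] = seed[0]
        pool[idx] = x
```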
Next comes a logging block that will create a video for TensorBoard, where each frame represents a different iteration. However,
we would like to run it for way more iterations than we trained with: 300 in our case, compared
to the 64 to 96 iterations used during training. This way we'll be able to assess whether the
pattern stays stable once it reaches its final form, the target image.
Right, and that's it. Now we just need to train it. First of all, let us verify whether
the CLI was created correctly. It seems to be the case. Note that to train we need to have a
target image which I have here. What's important is that it's an RGBA image so that it has the
alpha channel. So now I will just launch the training. In order to get decent results in
a matter of minutes, one needs to use a GPU. I'll just let it train, and once it's done I'll
show you the results in TensorBoard. First of all, you can see that the
loss is consistently going down. When it comes to the videos we can see that
after 500 gradient steps the rule is kind of able to reproduce the general shape of the
rabbit's face. However, it is far from perfect. Additionally, it seems like artifacts appear after
a certain number of steps and therefore it is not stable. If we look at the rule towards the end of
the training we can see that it is pretty good and stable. Let me just point out that I did not cover
regeneration. One can actually perturb the image during the training and that way we can force the
rule to be able to deal with damaged images. Also, I did not cover rotating the Sobel
filters: once the model is trained, rotating the filters makes the grown image rotate too, which is
really impressive. Anyway, that's it for the video. All the credit goes to the authors. I hope
I managed to interpret their research correctly. Additionally I made a lot of modifications and
simplifications in the code, and I hope that I did not introduce too many mistakes. Anyway, I
hope you managed to learn new things and that you found this video interesting. I will continue
creating similar content in the future so do not hesitate to subscribe. I wish you a nice
rest of the day, and I will see you next time.