Growing neural cellular automata in PyTorch

Video Statistics and Information

Reddit Comments

Love this, I've been obsessed with CA for months now lol. Thanks for posting!

👍 1 · u/Lifeisagarden_Digit · Apr 10 2021 · replies
Captions
Hey there and welcome to this video. Today we will implement a really cool article called "Growing Neural Cellular Automata". It was published on Distill and you can see it on the screen. First of all, the article itself is absolutely amazing, both because of the presented methods and because it contains interactive visualizations. Before I start coding I will give you a brief summary of what the article is about; however, I would definitely encourage you to pause the video and go through it yourself, because it is a pleasure to read. Also note that I won't cover all the topics, so reading it is the best way to make sure you don't miss out on anything. Anyway, the authors published their code as a Colab notebook, which you can see here. It is written in TensorFlow and I used it heavily in my PyTorch implementation. Additionally, I used two other resources that I will link in the description box below: the first one is a PyTorch implementation that I found on GitHub, and the second is a YouTube tutorial from one of the authors, Alexander Mordvintsev. I would definitely recommend that you check these resources out for yourself, and if you never come back to my video, that's fine, I'm not going to take it personally. Anyway, all the credit goes to the authors and the aforementioned resources.

Let me now try to give you a very rough, high-level explanation of what this article is trying to do. Personally, I relate it to Conway's Game of Life. As you probably know, it takes place on a two-dimensional grid, and each of the grid cells can either be dead or alive. First of all, we are supposed to design an initial state of the entire grid; let's say it will look like this. Second of all, there are very simple rules that determine what happens to each grid cell in the next iteration. For example, if a given cell is dead and it happens to have exactly three neighbors that are alive, the cell itself will become alive. Note that the next state of a given cell is fully determined by the cell's current state and the state of its eight immediate neighbors; in other words, the cell cannot look any further to determine its future state. We have the initial state and we have the rules, and now we can just see what happens. As you saw, there was a lot of movement going on, and at one point the pattern became stable. Let us now try a different initial state. We get a completely new pattern and, as you can imagine, you can play around with this forever. Let me just stress that in this Game of Life we are given the rules and we are supposed to choose the initial state.
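As a quick aside (this snippet is mine, not from the video), local rules like these translate very naturally into code: you can count each cell's live neighbours with a single 3x3 convolution and then apply Conway's birth/survival rules on top. A minimal PyTorch sketch:

```python
import torch
import torch.nn.functional as F

def game_of_life_step(grid):
    """One Game of Life update; `grid` is an (H, W) float tensor of 0s and 1s."""
    # Count the 8 neighbours of every cell with a 3x3 convolution.
    kernel = torch.tensor([[1., 1., 1.],
                           [1., 0., 1.],
                           [1., 1., 1.]]).view(1, 1, 3, 3)
    neighbours = F.conv2d(grid.view(1, 1, *grid.shape), kernel, padding=1).view(*grid.shape)

    # Conway's rules: a dead cell with exactly 3 live neighbours becomes alive,
    # a live cell survives with 2 or 3 live neighbours, everything else dies.
    birth = (grid == 0) & (neighbours == 3)
    survive = (grid == 1) & ((neighbours == 2) | (neighbours == 3))
    return (birth | survive).float()

# A glider on an 8x8 grid, stepped forward a few times.
grid = torch.zeros(8, 8)
grid[1, 2] = grid[2, 3] = grid[3, 1] = grid[3, 2] = grid[3, 3] = 1.0
for _ in range(4):
    grid = game_of_life_step(grid)
```

The point is only that "look at your eight neighbours and decide your next state" is a convolution-shaped computation, which is exactly what the article exploits.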
However, what if someone fixed the initial state and we were supposed to design the rules? More specifically, imagine somebody told us that the initial state looks like this: we have only one cell that is alive, exactly in the middle of our grid. Our goal would now be to design the update rules ourselves; instead of using the standard Game of Life rules, we have the freedom to choose anything we want. Of course you can play around with this and come up with infinitely many rules; however, let's make it a little more interesting. Let's say somebody gives us a target state of the grid that we want to achieve, and he or she asks us: can you find the rules that, given this simple initial state, will get me to the final state (in a fixed number of iterations, say)? For example, let's say I want the final grid state to be this beautiful smiley, and I'm asking you to find the rules that will get me there in, say, 100 iterations starting from the simple initial state. If I simplify, that is exactly what the article is about: it proposes a way to use deep learning (namely the convolution operation) to learn rules that result in some specific grid configuration.

Anyway, this was just to give you a very intuitive explanation; our actual setup is going to be a little more complicated. First of all, instead of having a grid of booleans we will be working with an image, where each cell corresponds to one pixel and therefore each cell can have multiple floats associated with it. Additionally, not only will we have the RGB channels to store cell information, but we will also have multiple other channels (hidden states, let's say) that encode all the information about a given cell.

First of all, let me explain how we can use the convolution operation to update our grid and define rules. Here I define a two-dimensional grid; one way to think about it is as a simple grayscale image where, instead of the values being just True or False, we can have any number between 0 and 1. Let me now define a new tensor that will represent the rules. These rules are nothing else than a filter we will convolve the input grid with. Note that this rules filter is 3x3, which means we are only allowed to look at the immediate neighbors and the cell itself, exactly the same constraint we saw in the Game of Life. And what is this rule actually doing? Well, for a given cell we look at the neighbor right above and the neighbor right below, and we define the new value to be their average. To actually perform the convolution we will use the following function; this is the docstring, feel free to read it, and I am just going to use this function to perform the convolution. As you can see, the two main inputs are the initial grid and the rules. I played around with the dimensions of those tensors so that things work out, but don't worry about it. This is how the grid looks after one iteration; let's verify that it did what was expected. If we look at this cell, for example, we want it to be the average of this number and this number. Yeah! And now nothing prevents us from repeating this procedure as many times as we want. Lastly, let me point out that the rules tensor is something I created by hand; what we actually want is to turn this tensor into a learnable parameter and just learn it.
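Here is a minimal, self-contained reconstruction of that demo (the grid size, the variable names, and the direct use of `F.conv2d` are my choices; in the video the convolution is wrapped in a small helper function):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A small "grayscale image" grid with values between 0 and 1.
grid = torch.rand(1, 1, 5, 5)

# The "rules": a 3x3 filter that sets each cell to the average of the
# cell directly above it and the cell directly below it.
rules = torch.tensor([[0.0, 0.5, 0.0],
                      [0.0, 0.0, 0.0],
                      [0.0, 0.5, 0.0]]).view(1, 1, 3, 3)

# One iteration of the rule = one convolution (padding keeps the shape).
new_grid = F.conv2d(grid, rules, padding=1)

# Sanity check for an inner cell: it should equal the average of its
# vertical neighbours in the previous grid.
print(new_grid[0, 0, 2, 2], (grid[0, 0, 1, 2] + grid[0, 0, 3, 2]) / 2)

# Nothing stops us from applying the rule repeatedly.
for _ in range(10):
    grid = F.conv2d(grid, rules, padding=1)
```

Turning a tensor like `rules` into a learnable parameter and optimizing it with gradient descent is, conceptually, all that the rest of the pipeline does.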
We are now ready to jump into the article. What you see here is one iteration of the pipeline, and it is slightly more complicated than sliding a single 3x3 filter over our input image. First of all, our input is going to be an RGBA image, where the A stands for the alpha channel, plus 12 additional channels to store any additional information. The first step is to convolve the image with three different 3x3 filters: the first is the identity filter, which simply copies the input image, and the second and third are Sobel filters in the x and y directions respectively. The idea behind using the Sobel filters is to approximate the gradient and thus give ourselves some information about the intensities of the neighboring pixels. The authors claim that this is inspired by real biological cells and chemical gradients, which I cannot really comment on; however, from the machine learning point of view it is an interesting design choice, because the other and perhaps more natural approach would be to learn these 3x3 filters from scratch. I guess the main benefit of hardcoding them is having fewer learnable parameters, and also that we introduce a very reasonable prior into the neural network. Anyway, in our implementation we'll follow what the paper did and hardcode these filters. Note, however, that in the other video I mentioned at the beginning, the author actually learns the 3x3 filters from scratch.

After applying our three filters to the 16-channel image we end up with a 48-channel image. We then apply two 1x1 convolutions, which is nothing else than applying a linear model to each pixel over all channels; I will describe this in detail later. The next operation is called the stochastic update, and it is more or less a pixel-wise dropout. We then take this image and add it to the input one, which is nothing else than a residual block. Finally, we check the alpha channel of the image: if it's below 0.1, we consider that cell dead and manually set its channels to 0. This process is called alive masking. What we saw was a single iteration of applying the rule; however, we actually want to take our input image and run it through the same pipeline multiple times. In this diagram you can clearly see that once we take multiple steps, we simply take our predicted image (namely the first four channels) and the target image, and compute the L2 loss. And there you go: we have our deep learning pipeline. I guess that's it for the explanations, and in my opinion the implementation is pretty straightforward, so let's get started.

First of all, we implement the model. The first parameter determines the number of channels of the input image; in the article this is equal to 16. Since we are going to run the 1x1 convolution twice, we can decide on any number of hidden channels. The fire rate determines how probable it is that a given cell goes through the update process, and finally we provide the device so that we can easily switch between the CPU and the GPU.

Internally, we create the update module, which is nothing else than two 1x1 convolutions. What's really important is that this will be the only block of our pipeline that has learnable parameters. Internally, we also store a filter tensor that represents the identity filter and the Sobel filters in the x and y directions. As always, we call the constructor of the parent class. If the user doesn't provide a device explicitly, we default to the CPU. First of all, we need to prepare the filters for the so-called perceive step; this step is nothing else than a 3x3 convolution. We define the so-called Sobel filter manually; what it does is approximate the gradient, and again the idea behind it is to tell the current cell what is happening around it and in which direction it would need to go to maximize or minimize the intensity.
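Below is a sketch of what this filter preparation might look like (the 1/8 scaling and the variable names are assumptions on my side; the kernels themselves are the standard Sobel and identity filters):

```python
import torch

n_channels = 16  # RGBA + 12 hidden channels

# Sobel filter in the x direction (approximates the horizontal gradient);
# its transpose gives the y direction. The 1/8 scaling is a common choice.
sobel_x = torch.tensor([[-1.0, 0.0, 1.0],
                        [-2.0, 0.0, 2.0],
                        [-1.0, 0.0, 1.0]]) / 8.0
sobel_y = sobel_x.t()

# Identity filter: convolving with it returns the image unchanged.
identity = torch.tensor([[0.0, 0.0, 0.0],
                         [0.0, 1.0, 0.0],
                         [0.0, 0.0, 0.0]])

# Stack the three filters along the zeroth dimension: shape (3, 3, 3).
filters = torch.stack([identity, sobel_x, sobel_y])

# Repeat the trio for every input channel so each channel gets all three
# filters applied to it: shape (3 * n_channels, 1, 3, 3), which is exactly
# the weight layout a depthwise convolution expects in the perceive step.
filters = filters.repeat(n_channels, 1, 1).view(3 * n_channels, 1, 3, 3)
```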
Here we define an identity filter; if we slide this filter over any image we get exactly the same image back. Then we take the three filters we defined and stack them together along the zeroth dimension. Our ultimate goal is to take these three filters and apply them to each of the channels of the input image, so we end up with a new image that has three times as many channels. Here we repeat the filters over all channels, send them to the right device, and finally store them internally as an attribute because we will use them in the forward pass. Let me stress one very important thing again: these filters are not learnable; we manually hardcoded them.

Now we prepare the so-called update step; this is the only place where we have trainable parameters. We use the Sequential module to define three consecutive steps: a 1x1 convolution, then the ReLU activation, and finally another 1x1 convolution.

Let me quickly explain the relation between the linear model and the 1x1 convolution. We define a couple of constants, and here I create a random tensor that represents a batch of images. Let me now instantiate two different torch modules: here I create the 1x1 convolution layer and here I create a simple linear model. First of all, let me check the number of parameters each of them has. As you can see, they have the same number of parameters, and these parameters are stored in the following two attributes. As we can see, the bias and the weight of the linear and the convolution layer more or less match, except for some extra dimensions, and I guess at this point you realize what I want to show: these two modules are more or less doing the same thing. A 1x1 convolution is nothing else than iterating over all pixels and applying the same linear model across all the channels. Let me prove it to you. What I did here was make sure that the bias and the weight of the convolution layer are exactly the same as the weight and the bias of the linear layer; note that when we constructed them, these parameters were initialized randomly. Now the idea is to run the forward pass with our random tensor and see whether we get the same result. Note that I had to permute the dimensions of the input tensor in order for it to be usable with the linear module; however, I then undid it after the forward pass. First of all, let us check the shapes: they seem to be the same, and the two tensors also seem to be the same element-wise, if you disregard tiny numerical differences. To summarize: a 1x1 convolution is nothing else than a linear model applied to all pixels across the channels.
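Here is a condensed version of that experiment, with shapes and names of my own choosing:

```python
import torch

n_channels, hidden, batch, height, width = 16, 32, 2, 8, 8
x = torch.rand(batch, n_channels, height, width)

conv = torch.nn.Conv2d(n_channels, hidden, kernel_size=1)
linear = torch.nn.Linear(n_channels, hidden)

# Same number of parameters: (16 * 32) weights + 32 biases in both cases.
print(sum(p.numel() for p in conv.parameters()),
      sum(p.numel() for p in linear.parameters()))

# Copy the (randomly initialised) linear parameters into the convolution.
with torch.no_grad():
    conv.weight.copy_(linear.weight[:, :, None, None])
    conv.bias.copy_(linear.bias)

out_conv = conv(x)
# For the linear layer the channels need to be the last dimension,
# so we permute, apply the layer, and permute back.
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True, up to tiny differences
```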
We are back in the implementation. Since we are using 1x1 convolutions, we will never be looking at the neighbors, and we are hoping that by now all the relevant information is already encoded in the channels. That is a reasonable assumption because, as we saw in the previous step, we already included a lot of information about the neighbors via the Sobel filters.

To understand what I'm doing here, let me remind you that our seed starting image is going to be a single bright pixel in the middle of the image. All the other pixels, or cells if you like, will be inactive, and by adjusting the weight and the bias of this second 1x1 convolution we make sure it will take a couple of iterations of the rule to populate pixels that are further away from the center. I guess the main motivation behind this is to make training simpler and to make sure we don't end up with some crazy, complicated pattern right after the first iteration. Finally, we recursively send all the parameters of this module to our desired device.

All right, now we're done with the constructor and we can write a couple of helper methods that will finally be put together to create the forward pass. Here we implement the perceive step. Its goal is to look at the surrounding pixels, or cells, and understand how the intensity changes; there are no learnable parameters here. When it comes to the input and output shapes, they are the same except for the number of channels: we multiply the number of channels by 3 because we apply 3 filters to each of them. We take the filters we prepared in the constructor and perform a so-called depthwise convolution, which we achieve by setting groups equal to the number of input channels. Let us now implement the update step. Again, the update step is the only place where we have trainable parameters, namely those inside the two 1x1 convolution layers, and it's just a one-liner because we prepared everything in the constructor.

The next step to implement is the stochastic update. The stochastic update is nothing else than a pixel-wise dropout; note, however, that we are not scaling the remaining values by any scalar. Let me point out that this step, like the others, has a biological rationale: we don't want all the cells to be updated with each iteration, which would imply that there is a global clock and that everybody updates with every tick. What we want is for this process to be more or less random. Say I focus on a given cell: I want it to update only, say, 80% of the time, independently of its neighbors. First we create a boolean mask for each pixel, and then we element-wise multiply the original tensor with the mask. Note that since this mask is broadcast over all the channels, it cannot happen that some channels of a given pixel are active while the remaining ones are inactive.

The next utility function takes the alpha channel of our image, which is the fourth one, and uses it to determine whether a given cell is alive or not. The criterion is that if the cell itself, or any cell in its neighborhood, has an alpha value higher than 0.1, the cell is considered alive. And now we have all we need to implement the forward method.
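Before the forward method, here is a rough sketch of the helper pieces just described. The function names, default values, and the max-pooling trick used for the neighborhood check are my reconstruction; they follow the reference implementations in spirit but are not a copy of the original code.

```python
import torch
import torch.nn.functional as F

def perceive(x, filters, n_channels=16):
    """Apply identity + Sobel filters to every channel (no learnable parameters).

    x: (batch, n_channels, H, W) -> (batch, 3 * n_channels, H, W)
    """
    # groups=n_channels makes this a depthwise convolution: each input
    # channel is convolved with its own three filters.
    return F.conv2d(x, filters, padding=1, groups=n_channels)

def make_update_module(n_channels=16, hidden_channels=128):
    """The only learnable part: two 1x1 convolutions with a ReLU in between."""
    module = torch.nn.Sequential(
        torch.nn.Conv2d(3 * n_channels, hidden_channels, kernel_size=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(hidden_channels, n_channels, kernel_size=1),
    )
    # Start from a "do nothing" rule: with the last layer zeroed out, the
    # residual update is zero at the beginning of training, so patterns
    # have to grow gradually from the seed.
    torch.nn.init.zeros_(module[2].weight)
    torch.nn.init.zeros_(module[2].bias)
    return module

def stochastic_update(x, fire_rate=0.5):
    """Zero out the update for a random subset of cells (pixel-wise dropout
    without rescaling)."""
    # One boolean per pixel, broadcast over all channels, so a cell is
    # either fully updated or not updated at all.
    mask = torch.rand(x.shape[0], 1, *x.shape[2:], device=x.device) <= fire_rate
    return x * mask.to(x.dtype)

def get_alive_mask(x):
    """A cell is alive if it, or any of its 8 neighbours, has alpha > 0.1.

    The alpha channel is assumed to be channel 3 (RGBA come first).
    """
    alpha = x[:, 3:4, :, :]
    return F.max_pool2d(alpha, kernel_size=3, stride=1, padding=1) > 0.1
```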
Let me remind you that calling the forward method once means nothing else than one iteration of the rule; what we actually do while training is call the forward method multiple times to simulate multiple iterations. First of all, we create a pre-life mask, which is a tensor of booleans. We take our input tensor and run the perceive step, which applies the identity and the two Sobel filters. Then we run the update step that contains the learnable parameters. Next we run the stochastic update, whose goal is to make sure that some cells don't get updated during this forward pass, making the whole process more biologically plausible. Here we use a residual connection, and it's really important: the new image is nothing else than the previous image plus some delta image, and one can make the same argument here as with ResNets, because we will run this forward method multiple times and one way to think about it is that we are creating a very deep architecture. We then compute the post-life mask, and the final life mask is an element-wise AND between the pre-life mask and the post-life mask. That is it for the forward pass!

Now we want to write a training script. Here we load an RGBA image, pre-multiply the RGB channels with the alpha channel, turn it into a torch tensor, and make sure that the channels are the second dimension. Next we take an RGBA image and turn it into an RGB image; note that we use torch.clamp to make sure we don't fall outside of the range [0, 1], and we want the background to be white. Here we create our initial grid state: it is nothing else than a blank image, except that we take its center pixel and set all of its channels except for RGB equal to one.

Now we would like to create a command line interface, because there are multiple parameters one can play around with; I will explain some of these arguments when we actually use them in the code. We parse the arguments and print them out to see the configuration. We instantiate the device based on the CLI option, and here we prepare the TensorBoard writer. Here we load the target image and pad it on all four borders. We do this to prevent a form of overfitting, since we don't want the network to rely on the fact that there are borders nearby. Finally, we take the same image and repeat it, because we want to do batch training, and we also add this target image to TensorBoard. We instantiate the model that we wrote and we also create an optimizer.

Here I need to provide more explanation. Instead of always starting from the seed image and then trying to get to the target image, we create a pool of images that should ideally contain all the in-between states, together with the target one and also the seed image. The main idea of this pool is to make sure that once we reach the final pattern, more iterations are not going to degrade it. You will see how the pool is updated in a couple of lines.

Now we write the training loop. Most importantly, we will take a given number of gradient steps (one of the CLI arguments), and for each of them we randomly select a couple of samples from our pool; that is how we create a batch. This part is really important: we take our batch and run it through our forward pass multiple times, and the number of iterations is not deterministic, it is randomly sampled from the interval 64 to 96. We are hoping that around 70 iterations should be enough to go from the seed image all the way to the target image. Here we compute a per-sample mean squared error. Note that we only extract the first four channels of our predicted image, because the target image itself only contains the RGBA channels.
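Before moving on, here is a compressed sketch of the training loop just described (the argument names, defaults, and learning rate are placeholders of mine; the real script reads most of them from the CLI). It also includes the pool update, which is explained in more detail right below.

```python
import torch

def train(model, target, n_channels=16, pool_size=1024, batch_size=8,
          n_steps=5000, device="cpu"):
    """Sketch of the training loop; `target` is a (4, H, W) RGBA tensor of
    the (already padded) target image."""
    _, h, w = target.shape

    # Seed: a single "bright" pixel in the middle - every channel except RGB set to 1.
    seed = torch.zeros(n_channels, h, w)
    seed[3:, h // 2, w // 2] = 1.0

    pool = seed.clone().repeat(pool_size, 1, 1, 1)                 # pool of grids
    target_batch = target.repeat(batch_size, 1, 1, 1).to(device)   # batch training
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

    for step in range(n_steps):
        idx = torch.randint(0, pool_size, (batch_size,))           # sample a batch from the pool
        x = pool[idx].to(device)

        # Apply the rule a random number of times (64 to 96 iterations).
        for _ in range(torch.randint(64, 96, (1,)).item()):
            x = model(x)

        # Per-sample MSE over the first four (RGBA) channels only.
        loss_per_sample = ((x[:, :4] - target_batch) ** 2).mean(dim=(1, 2, 3))
        loss = loss_per_sample.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Pool update (explained below): the worst sample is replaced by the
        # seed, the others go back into the pool in their updated form.
        x = x.detach().cpu()
        x[loss_per_sample.argmax().item()] = seed
        pool[idx] = x
```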
We compute the average loss over all samples in the batch, then we take the gradient step and log the loss with TensorBoard.

Here we update our pool. First of all, we find the sample for which the loss was the highest and we make the assumption that this sample was terrible and that we do not want to keep it in the pool, so we simply replace this bad sample with the initial seed sample. The other samples in the batch we throw back into the pool, but, importantly, in their updated version rather than the initial one. This way we are hoping to build a pool that contains all kinds of images representing different stages of the development of our final pattern.

This is just a logging block that creates a video for TensorBoard, where each frame represents a different iteration. However, we would like to run it for many more iterations than we trained with: in our case it will be 300 iterations, compared to the 64 to 96 iterations used during training. This way we will be able to assess whether the pattern, once it reaches its final form (the target image), stays stable.

Right, and that's it; now we just need to train it. First of all, let us verify that the CLI was created correctly, which seems to be the case. Note that to train we need a target image, which I have here; what's important is that it is an RGBA image, so that it has the alpha channel. Now I will just launch the training. In order to get decent results in a matter of minutes, one needs to use a GPU. I'll let it train and once it's done I'll show you the results in TensorBoard.

First of all, you can see that the loss is consistently going down. When it comes to the videos, we can see that after 500 gradient steps the rule is kind of able to reproduce the general shape of the rabbit's face; however, it is far from perfect. Additionally, it seems like artifacts appear after a certain number of steps, and therefore it is not stable. If we look at the rule towards the end of the training, we can see that it is pretty good and stable.

Let me point out that I did not cover regeneration: one can perturb the image during training and that way force the rule to be able to deal with damaged images. I also did not cover rotating the Sobel filters: once the model is trained, rotating the filters makes the grown image rotate too, which is really impressive.

Anyway, that's it for the video. All the credit goes to the authors, and I hope I managed to interpret their research correctly. Additionally, I made a lot of modifications and simplifications in the code, and I hope I did not introduce too many mistakes. I hope you managed to learn new things and that you found this video interesting. I will continue creating similar content in the future, so do not hesitate to subscribe. I wish you a nice rest of the day and I will see you next time.
Info
Channel: mildlyoverfitted
Views: 1,291
Rating: 4.9402986 out of 5
Keywords: neural cellular automata, growing neural cellular automata, cellular automata, deep learning, machine learning, pytorch, implementation, github, python, state-of-the-art, from scratch, computer vision, data science, research, artificial intelligence, Differentiable Model of Morphogenesis
Id: 21ACbWoF2Oo
Length: 26min 27sec (1587 seconds)
Published: Sat Apr 10 2021