Learning to Ponder: Memory in Deep Neural Networks - #528

Captions
All right, everyone, I am here with Andrea Banino. Andrea is a research scientist at DeepMind. Andrea, welcome to the TWIML AI Podcast.

Thanks, Sam. Thanks for having me.

So you are working on artificial general intelligence at DeepMind. Tell us a little bit about where your interest in AGI comes from.

My interest in AGI comes from my background. I'm actually a neuroscientist who decided to study how the brain works, and in particular memory: how we create our memories and why we create them. Specifically, I'm interested in a subset of the memory field called episodic memory. Episodic memories are those that relate to episodes you experienced yourself in the past, sometimes in a specific location or with specific people, so they are also referred to as autobiographical memories.

And episodic is in contrast to what other kinds of memory?

Another example could be semantic memory. Semantic memory is the meaning of something: knowing what a bicycle is, for example, would be a semantic memory. But if I ask you to remember specifically when you went cycling with someone, then maybe now you're thinking about a specific episode in your life, and that would be an episodic memory. You relate something to a specific moment in time. It's also described as the what, when, and where of a specific event.

To some degree, the relationship between memory and intelligence is kind of an obvious one, in the sense that we use our memory and our prior experiences when interacting with the world, making decisions, and all that. But is there a broader significance to memory in the development of AGI?

Yeah, as you say, it's kind of an obvious one. We live in a world that is consistent, so if we gain some experience in the world, we want to reuse that experience rather than relearn everything from scratch every time. So you can already see how that's valuable, but
something I actually studied during my PhD is that episodic memory is also something that enables generalization. In particular, we studied the following: if you experience two events, A and B, that are related, and later you experience two other events, B and C, that are again related, your brain, without any effort on your part, directly relates A to C. You do this kind of inference through your episodic memory. To give a more concrete example: if you see someone going out with a dog in the morning, and then you see the same dog with a different person in the afternoon, your brain immediately tries to connect the two people who were out with the dog. That's the kind of inference we know episodic memory supports, and it's quite important.

Right. And what is the history of memory in this effort to get us closer to AGI? How have we used memory in the field to facilitate intelligence?

That's a nice question, because in some sense how to do the kind of memory I'm talking about properly in AI is still an open question. We currently have different forms of memory that we know how to play with. In particular, recurrent networks, both RNNs and LSTMs, act as so-called working memory. Imagine a whiteboard where you can write something, reason about it, and then erase it, because it won't stay there forever; it's a kind of canvas on which you can make predictions. That's what we know. Then we also have memory-augmented neural networks. In that case we give neural networks an external memory where they can write previous computations they performed, then read back from it and reuse those computations. This gets a
little bit closer to the kind of memory I hinted at before, but we are not there yet. And then we have retrieval-augmented models, which go back to a table, if you like, where we store almost everything we have seen in the past, like a dictionary, and try to look things up. But most of the time they use that lookup to answer without consolidating the knowledge back into the weights of the system, so they need to do the lookup all the time, which in some sense is a waste. As I said before, we don't do that: we make sense of what we retrieve and use it later.

Right. So in that latter case you can think of it as a fixed boundary between the memory and the computation, and the computation is never updated. The way the model thinks about its inputs is never updated with regard to the things it learns in the memory; it's just checking the memory constantly with each input.

Yeah, although a model now comes to mind where that computation is updated, but it's very difficult to scale up those models because they require a lot of computation, so we don't know yet how to make them big enough to be useful on very complex problems. And finally, the last one would be a very recent advancement in AI, something I'm sure everyone knows: transformers. These models have shown some of the properties of this kind of memory in language. They sometimes show few-shot learning, which is a prerogative of the memory models I'm talking about. But that happens in language only, not in other domains, so that raises the question: is it because the model is capable of doing that, or because of the specific domain where you train the model that
allows that? No one has done that experiment yet, but I think it could be interesting.

Can you drill down on that? How should we think about transformers from a memory perspective, and about them exhibiting these properties you mentioned?

First of all, I think most of the transformers we use now in language are Transformer-XL, which is a specific instantiation of the transformer where you add a sort of external memory at each layer. So we already know that that's necessary, in some sense, to overcome some of the limits, like not having recurrence. I would say they are limited in terms of memory also because, given that they do this all-to-all comparison over the sequence, we cannot process very long contexts. There are now papers trying to deal with that and reduce the complexity to linear, but I still don't think we are at the point where we can process several books and ask inference questions about them without doing this sort of external retrieval all the time.

Oftentimes there are ideas we might want to apply in the context of neural networks, and a big challenge is: are they differentiable, so that we can apply techniques like gradient descent? Is memory typically differentiable or non-differentiable, and how does that play into this?

You touched on the perfect downside of this kind of external memory, where we do what is called k-nearest-neighbor retrieval, which is a non-differentiable operation. That obviously limits what you can get. And again, I don't think that with the current systems we know how to do backpropagation over a very large memory, which is an open problem. I think it's a very nice problem to tackle, and we should probably start now; I think we might be able to do it.

Why does the size of the memory play into
whether we can backprop over it or not?

Because most of the time you use a softmax operation over the memory dimension, and it quickly becomes impractical to send gradients over a very large memory. That's because it's computationally infeasible: it doesn't fit in memory. And with very large softmaxes you also have problems with the gradients.

There's another aspect of memory that comes to mind. We've talked about specific features and architectures used to emulate memory, but one of the frequent critiques of deep learning is that the networks themselves remember stuff, and that causes problems with generalization. Are there ways that can be harnessed more directly to achieve some of your goals for episodic memory?

Yeah, I think memorization is definitely an issue, like overfitting. Although most of what we do is also kind of memorized: we enlarge our dataset of experience as we grow, so in some sense we tend to overfit as well in most of what we do. So I don't think that's a huge problem. There was actually a nice paper last year called "Direct Fit to Nature," from Uri Hasson, which I recommend your listeners look up. It poses the same question you are posing to me and answers it this way: as we have grown as a species through evolution, and also during our lives, we kept enlarging the amount of experience we used to grow our brain, our network. So in some sense we kept memorizing; it's just that we don't always start from scratch, and we also have this ability to generalize that I think is still a little bit missing from deep learning. Although we should recognize that we can't generalize to everything either, so we
are in some sense limited in generalization as well. Maybe neural networks are slightly more limited than us, but we already see some examples of generalization starting to emerge in neural networks, and I think that's a good thing and something we cannot neglect. For instance, I had a paper in Nature in 2018, so three years ago, where we imbued a neural network with a representation like the one we have in the brain, in the hippocampus. It was a navigation task, and with the right representation the agent was able to take shortcuts and traverse parts of the environment that had previously been blocked. So I don't think the problem is backpropagation or the models themselves; it might be how we train them, in terms of both the data we use and the representations we force to emerge.

You mentioned this paper that implemented something akin to the representation used in the hippocampus. What is that representation, and how does it work?

In the paper we studied something called grid cells.

Grid cells?

Yes. The hippocampus is the memory machine of the brain, but it's also the spatial machine. You can also see space as memory, but that's a whole field I don't want to go into. Basically, there are three types of cells that are probably the best known. First, head direction cells: cells that fire every time you look in a certain direction, and I'm talking about allocentric directions. So every time you face north, there will be a cell firing with a certain probability distribution over north, and the same for the other allocentric directions. And this north is not the cardinal north; it is defined relative to a certain
reference point in the environment. Then we have place cells: neurons that fire every time you are in a particular location, independently of where you're looking. And then we have grid cells, these visually and mathematically beautiful cells that fire following a hexagonal lattice with a 60-degree offset. There have been several theories trying to explain why we have them, and one of them was that we can use them to calculate shortcuts, that is, the shortest vector between two points. We managed to do two things in that paper. First, we made the representation emerge in a neural network trained to do path integration, so to do a navigation task. Second, we used those representations in our reinforcement learning agent, and we showed through several ablations that it was the only agent able to take shortcuts; if we lesioned some grid cells, the agent's ability to take shortcuts went down. So it was an empirical paper to show what grid cells are for.

In terms of the implementation, I'm imagining you're adding sine and cosine elements to your loss function, or something like that?

No, no, it's data-driven. Our goal was really to be super data-driven, and we achieved that with a specific recurrent architecture. Two things were really important: one was to introduce dropout, such that not all neurons were able to fire at the same time, and the second was to introduce noise in the gradients, and that basically helped. But those are two things we always want to do anyway; they're not particularly specific to this problem. That's probably why people liked that paper so much, because it was a fairly general approach to
making that representation emerge, although I have to say it was kind of difficult to analyze. One of the reasons noise helps is that it helps you move away from a certain solution in the loss landscape, and our way of working was to help the network go down to the solution we liked, the grid cells. That actually wasn't too difficult to achieve, and that's why I think it was a nice piece of work.

Interesting. Speaking of memory, I'm trying to remember how we got here, actually.

The ponder stuff. I guess you invited me to talk about PonderNet, and it's inspired by memory, exactly by the study I mentioned at the beginning about doing associative inference. That was part of my PhD, and one of the papers that came out during that time was about this ability of the hippocampus to do a kind of recursive pondering before being able to do this sort of inference. I remember doing this experiment with people: when you ask them to relate A and C, to tell you a story about A and C, they really spend more time thinking compared to A-B and B-C. That's how I got inspired, because in the fMRI study we did, the same mechanism was involved; it's just that the answer was going out of the hippocampus and then back into the hippocampus a few times, but it was processed by the same system, if you like. I think that's a nice property to have in an algorithm.

I'm probably taking this too far, but when I hear you say that, I make associations, like: the brain knows how to do a scan, but it doesn't have something like an inner join?

Okay, no. I mean, the problem there is: what does "the brain" mean? I guess from a computational
perspective, I could ask you what loss the brain is minimizing to do that, which I don't know. Maybe it could be uncertainty reduction, because at the end of the day we want to make better predictions, to better predict the world, since then it would be less uncertain and so less risky. But there are people much better than me who could explain that.

Yeah, I think that was going in a slightly different direction. What I inferred from what you said was that when we ask people to do these kinds of associative inferences, where they need to get from A to B and B to C, it takes longer, which kind of suggests that there's not some built-in associative thing in the brain.

No, because in the fMRI study we did, we saw that even when we ask for longer inferences, people are still able to make them; it just takes even more time. So I think the algorithm we apply is the same, namely the ability to make associations; it's just that the longer the jump I ask you to make with this associative mechanism, the more you need to do it a bit hierarchically. You compute the things that are just two steps apart, then three steps, and then you might be able to put two and three together and do five.

So we went down this particular rat hole trying to provide some context for PonderNet, which I'll have you explain, but I still don't really see the connection with memory. When I read the abstract for PonderNet, I think about things like hyperparameter tuning and early stopping, the way we train networks, and pulling that into the network, as opposed to anything having to do with memory and the stuff we were just talking about. So tell me more about the connection.

There's a two-part answer about the connection:
the first one is the more high-level, maybe hand-wavy one: generalization. We believe a mechanism like PonderNet can get you a little further in terms of generalization than not pondering at all in deep neural networks, and that's the kind of thing we discussed before: one of the benefits of memory is that it gives you the right input to generalize, let me put it this way. The second is that a precursor to PonderNet was another work I did during my PhD, called MEMO, a memory-augmented neural network in which we studied exactly this mechanism of recirculation: the network spits out an answer, you compare it with the previous answer, and you only give the final answer to a problem when you are satisfied. So we were already implementing this sort of pondering mechanism, but it was at a really early stage, because we used reinforcement learning to train a Bernoulli variable that basically said "go" or "stop", and we saw that this was really hard to train and had very high variance. So we decided to do something more principled, and that's what led us to PonderNet.

Okay, so taking a step back, what's the problem that you're trying to solve with PonderNet?

I think that's the first line of the abstract, which basically says that in standard neural networks the amount of computation grows with the size of the input, but not with the complexity of the problem. But as we mentioned a few minutes ago, that's not how we reason: the more complex the problem, the more time we spend on it. That's essentially what we wanted to capture with this work. We also wanted to make it fairly general, so that it could be applied to any architecture, to be architecture-agnostic.

Got it. And the size of the input, we know what that means; you
know, we're talking about feature dimensionality, and we know the problems that come along with that. The complexity of the problem: what exactly does that mean, and how do we measure it?

Empirically, the example I like in the paper is that it takes longer to divide than to add. Mathematically speaking it's the same kind of problem, but for some reason we spend more time dividing than adding.

Is it fair to say that we're talking about the computational complexity of a given problem, as opposed to some conceptual type of complexity?

Yeah. The algorithm is the same; it's just that the specific instance where you have to apply it requires more compute time. It's like a computer: you apply the same function, but in some circumstances it takes more time, so the computer thinks more. That's not how neural networks work now: they apply the same amount of thinking to each input. I can give you another example: when we read a sentence, we don't spend the same amount of time gazing at each word. That's been shown in psychology: we tend to focus our attention on the few words in the sentence that are most important for processing the whole thing. So that's another practical example.

Okay. And so you want to create a neural network that, would you describe it as budgeting its computational investment in solving a problem according to the inherent complexity of the problem? Is that a way to think of it?

Yeah, I would say that's a fair description.

And so how did you do that in PonderNet?

First of all, this is based on a previous work called Adaptive Computation Time (ACT), and the problem there is that they directly minimize the number of pondering steps, so the
number of steps the network took, whereas in our case we made this probabilistic. To be a bit more practical: for each step in the sequence, we calculate the prediction, the probability of halting, and the next state. The probability of halting is just a Bernoulli random variable, which gives the probability of halting at this particular step given that you have not halted at previous steps. From there we can calculate a probability distribution by multiplying these probabilities across steps, which forms a proper distribution, a generalized geometric distribution. Once we have that, we calculate the loss for each prediction in the sequence, and each loss is then weighted by the probability of having halted at that particular step. That's a critical difference between us and ACT, because ACT instead outputs a weighted average of the predictions: it doesn't output a specific prediction but a weighted average of all of them, and that creates a bias in the gradient. In our case we can just take the exact loss at each step, given the halting decision. At training time we unroll up to the step at which the cumulative halting probability surpasses the threshold we set, whereas at test time we just sample from the probability distribution we learned.

Okay. I mentioned earlier that halting calls to mind early stopping. Early stopping is about conserving training time; is the idea here that we're trying to conserve inference time? We're trying to make a prediction, and instead of going through the whole thing, we keep going until we're sure what the answer is, and then
stop early, and this is an approach to getting there?

Yeah, I think our hypothesis was that, say you want to run an algorithm on your phone: you can train it in the cloud, no problem, since at training time the amount of compute is not an issue, but at inference time you want to be quick, and that's important. Now, I'm laughing a little bit, but if done properly I think this could also help reduce the resources you spend at training time, because we actually have an experiment in the paper where the total number of gradient updates for PonderNet is smaller than for other methods, given the same final performance. So this could also help reduce the resources needed, which matters for certain people, I guess.

Okay. And then, circling back, is the connection to memory here in storing these probabilities and the Bernoulli variables, that kind of stuff, or is there a different connection?

No. I guess we left the connection with memory a little bit behind. However, I remember one tweet from Andrej Karpathy that actually also inspired this work a little bit, which basically said that one of the limitations of transformers is that they spend the same amount of compute on each token in the sequence. So if, broadly speaking, we treat the transformer as a preliminary form of memory, you can think of applying this on top of a transformer and seeing if it helps, maybe by spending a different amount of compute per position in the sequence.

Got it.

I understand it's a bit of a stretch.

Yeah, it's an analogy of some sort; it's not necessarily an implementation of a memory system that we're talking about here. Is PonderNet a specific network architecture, as opposed to a technique that you could apply to different architectures, or is it
the latter?

It's a technique. Indeed, if you look at the paper, the step function, what we call s in the paper, could be anything: an RNN, a CNN, a transformer, an RL agent, whatever. As long as you add this extra unit that calculates the probability of halting, you can apply it to everything. That was important for us, and indeed in the paper we applied it to three different architectures.

Okay. And how did you evaluate the results?

We used one task from the ACT paper called the parity task. You have a string of ones and zeros, or ones and minus ones, whatever, and you need to calculate the parity of that string, so whether it's odd or even. It's a nice task because you can train on parity up to, in our case, 48 elements, but then test on strings up to twice as long, to test extrapolation a bit, and indeed we see that our network extrapolates much better than the baselines. Then we applied it to a reasoning task called bAbI, where you have 20 tasks that you train on all in parallel. It's a language task where you get asked questions...

Actually, can I pause you and go back to parity? I'm trying to work this through in my head. The first thought that occurred to me is that, for some number of bits you're trying to calculate the parity of, it's not like you're ANDing the bits together, where as soon as you hit a zero you know the answer. You still need to look at all of the bits in the input for parity.

Yes.

But the premise is that, independent of that particular fact, inside the network you could still stop early, relative to going through some fixed number of computations, and still answer the question correctly?

We have a
baseline. We picked these tasks exactly because we know this is a task that a normal RNN has trouble with; that's a well-known issue in this literature.

Okay, got it. On the topic of transformers, which have come up a couple of times: you did a workshop at ICML talking about transformers and reinforcement learning. Can you talk about taking two topics that people are really excited about and combining them? Tell us a little bit about the goals of that work.

I really got inspired by the BERT work, where they have this non-causal masking. In RL we keep using LSTMs for most tasks, but we know that they suffer from what's called recency bias: they tend to pay attention only to the last bit of the sequence they are trained on. To be honest, for most of the environments we use nowadays that's fine, but if the environments grow in size and the context you want to pay attention to is quite long, LSTMs start to suffer. One option would be to use transformers, because we know they can handle longer contexts than LSTMs. However, the problem in RL is that rewards are sparse most of the time and the gradients have been shown to be noisier, so it's difficult to train so many weights. So what we did in that work was to generalize BERT training, which is done on tokens, that is, categorical values to which you can apply a softmax on the output side, to real-valued inputs, basically features. We send the features from the CNN into the transformer, we mask some of them, and then on the output side we use a contrastive approach to reconstruct the masked input: we give it some negatives taken from the same sequence and a positive, and the network has to
discriminate between them. What we also did in that work was to combine the LSTM and the transformer in the same architecture. The good thing about that is that it lets you reduce the size of the transformer to gain speed. And we let the agent actually learn when to use the LSTM alone and when to combine the transformer and the LSTM, so that in some tasks it can avoid the extra complexity of the transformer and just use the LSTM, whereas in other, more complex tasks it will focus on using both.

Got it. So in a sense there are echoes of the PonderNet paper, in that you're letting the agent manage its computational investment based on an assessment of complexity?

In some sense, yes, through the RL gradients: it was really the agent, by playing with the environment, that was deciding what to use. Although I would love to actually have agents that stop and ponder; I think that would be nice.

What types of problems did you experiment with for this paper?

Three different domains: the whole Atari suite; the DeepMind Control Suite, which I'm particularly proud of, because not much work has been done there with off-policy methods like the one we used; and DeepMind Lab, which is a suite of thirty complex 3D tasks that you play all at the same time, so it also has a flavor of multitask learning at scale.

And from a performance perspective, what kind of results did you see? Was this promising enough to keep poking around at, or was it really good performance that challenges the state of the art?

First of all, we always improved data efficiency massively compared to the baseline, and in many domains, especially in DMLab, the DeepMind Lab suite, we actually also got state-of-the-art performance. That's good, because that was the more
complex domain we played in. Something we could have done slightly better was to focus even more on complex domains, because I think that's where these kinds of methods will shine: complex architectures will probably benefit more from complex settings. It's still something I'm quite passionate about. I think there's lots of stuff we can do to improve transformers, and memory in general, in reinforcement learning, especially with respect to the length of the context we can process. That's important, and in my opinion it's a bottleneck.

Meaning the approach you took in this work, coupling the LSTM and the transformer and allowing the agent to choose, is kind of a starting point, but there's a lot more to be done there?

I think so. As humans, we have different sorts of memory, as I said before: very long-term memory and short-term memory. I'm not the first one to say this; there are a few papers out there already arguing, as I do, that agents should be equipped with memory at these different time scales.

Awesome. Well, Andrea, thanks so much for taking the time to share a bit about what you're working on.

Yeah, it's been a pleasure, Sam. Again, thanks a lot for having me.

Absolutely. Thank you so much.
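The associative inference described in the interview (experience A with B, later B with C, and infer a link between A and C; make longer jumps hierarchically, e.g. five steps from two plus three) can be illustrated with relation composition. This is a minimal, hypothetical sketch, not the model from the fMRI work or from MEMO: experienced pairs are stored as a boolean relation, one associative jump is a boolean matrix product, and because composition is associative, a five-step link can be assembled from a two-step and a three-step link. The helper names `from_pairs`, `compose`, and `jumps` are illustrative.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical helpers for illustration only; not taken from any published code.
Relation = List[List[bool]]

def from_pairs(items: Sequence[str], pairs: Sequence[Tuple[str, str]]) -> Relation:
    """Boolean relation built from directly experienced pairs (treated as symmetric)."""
    index: Dict[str, int] = {name: i for i, name in enumerate(items)}
    rel = [[False] * len(items) for _ in items]
    for a, b in pairs:
        rel[index[a]][index[b]] = True
        rel[index[b]][index[a]] = True
    return rel

def compose(r: Relation, s: Relation) -> Relation:
    """One associative jump: (r o s)[i][k] is True if some j links i to j in r and j to k in s."""
    n = len(r)
    return [[any(r[i][j] and s[j][k] for j in range(n)) for k in range(n)] for i in range(n)]

def jumps(r: Relation, steps: int) -> Relation:
    """Items linked by a chain of exactly `steps` experienced pairs (steps >= 1)."""
    result = r
    for _ in range(steps - 1):
        result = compose(result, r)
    return result

items = ["A", "B", "C", "D", "E", "F"]
experienced = [("A", "B"), ("B", "C"), ("C", "D"), ("D", "E"), ("E", "F")]
rel = from_pairs(items, experienced)
a, c, f = items.index("A"), items.index("C"), items.index("F")

print(rel[a][c])            # False: A and C were never seen together
print(jumps(rel, 2)[a][c])  # True: inferred through the shared element B
# A five-step link assembled hierarchically from a two-step and a three-step link.
print(compose(jumps(rel, 2), jumps(rel, 3)) == jumps(rel, 5))  # True
```

The sketch only captures the logic of chaining associations; it says nothing about how the hippocampus implements it or why longer chains take people longer.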
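The differentiability point in the interview (k-nearest-neighbor lookup is non-differentiable, while a softmax over the whole memory is differentiable but costly for large memories) can be made concrete with a toy key-value memory. In this minimal sketch (plain Python, dot-product similarity, a single query, hypothetical function names), a softmax read mixes every slot and responds smoothly to the query. A k-nearest-neighbor read picks slots by ranking, so small changes to the query leave it unchanged and no gradient reaches the query. A finite-difference estimate stands in for automatic differentiation.

```python
import math
from typing import List, Sequence

# Illustrative sketch: function names and the dot-product similarity are assumptions.
Vector = List[float]

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def softmax(scores: Sequence[float]) -> Vector:
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_read(query: Sequence[float], keys, values) -> Vector:
    """Differentiable read: every slot contributes, weighted by softmax(query . key)."""
    weights = softmax([dot(query, k) for k in keys])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]

def knn_read(query: Sequence[float], keys, values, k: int = 1) -> Vector:
    """Hard read: average the values of the k best-matching slots.
    The top-k selection is piecewise constant in the query, so no gradient reaches it."""
    best = sorted(range(len(keys)), key=lambda i: dot(query, keys[i]), reverse=True)[:k]
    return [sum(values[i][d] for i in best) / k for d in range(len(values[0]))]

def numerical_grad(read_fn, query, keys, values, eps: float = 1e-5) -> Vector:
    """Finite-difference estimate of d read[0] / d query[j] for each coordinate j."""
    base = read_fn(query, keys, values)[0]
    grads = []
    for j in range(len(query)):
        bumped = list(query)
        bumped[j] += eps
        grads.append((read_fn(bumped, keys, values)[0] - base) / eps)
    return grads

keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[1.0], [2.0], [3.0]]
query = [0.5, 0.2]
print(numerical_grad(soft_read, query, keys, values))  # non-zero in both coordinates
print(numerical_grad(knn_read, query, keys, values))   # [0.0, 0.0]
```

Note that `soft_read` touches every slot on every read. That per-read cost, and the activations that must be kept for backpropagation, is what grows with memory size.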
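For readers unfamiliar with grid cells, the hexagonal firing pattern with a 60-degree offset is commonly idealized as the sum of three cosine plane waves whose directions are 60 degrees apart. This is not how the 2018 network obtained its grid-like units: as described in the interview, those emerged from training a recurrent network on path integration with dropout and gradient noise. The sketch below is only a hand-written picture of the target pattern. The `spacing`, `orientation`, and `phase` parameters are illustrative assumptions.

```python
import math

# Idealized textbook-style model for illustration; NOT the learned units from the paper.

def grid_cell_rate(x: float, y: float, spacing: float = 1.0,
                   orientation: float = 0.0, phase=(0.0, 0.0)) -> float:
    """Idealized grid-cell activity at (x, y): three cosine plane waves 60 degrees apart.
    Peaks (value 3) lie on a hexagonal lattice with nearest-neighbor distance `spacing`;
    the minimum value is -1.5."""
    kappa = 4.0 * math.pi / (math.sqrt(3.0) * spacing)  # wave number giving that spacing
    dx, dy = x - phase[0], y - phase[1]
    total = 0.0
    for j in range(3):
        theta = orientation + j * math.pi / 3.0  # 0, 60 and 120 degrees, plus orientation
        total += math.cos(kappa * (math.cos(theta) * dx + math.sin(theta) * dy))
    return total

# Peaks at the origin and at two neighbors 60 degrees apart; the midpoint between peaks is low.
print(round(grid_cell_rate(0.0, 0.0), 6))                                      # 3.0
print(round(grid_cell_rate(math.cos(math.pi / 6), math.sin(math.pi / 6)), 6))  # 3.0
print(round(grid_cell_rate(0.0, 1.0), 6))                                      # 3.0
print(round(grid_cell_rate(0.0, 0.5), 6))                                      # -1.0
```

With this parameterization the nearest peaks sit at 30 and 90 degrees from the origin, both at distance `spacing`. That is the 60-degree offset of the hexagonal lattice.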
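The PonderNet halting mechanism described in the interview can be written compactly. Given per-step halting probabilities lambda_n, the probability of halting exactly at step n is p_n = lambda_n * prod_{j<n} (1 - lambda_j). The training loss weights each step's loss L_n by it, giving sum_n p_n * L_n. The sketch below implements those pieces, plus the training-time unroll threshold and test-time sampling mentioned above. The KL term toward a truncated geometric prior (`lambda_p`, `beta`), the default values, and all function names are assumptions added for illustration. This is a plain-Python sketch, not DeepMind's implementation.

```python
import math
import random
from typing import List, Sequence

# Illustrative sketch only: function names, lambda_p, beta and epsilon are assumptions.

def halting_distribution(lambdas: Sequence[float]) -> List[float]:
    """p_n = lambda_n * prod_{j<n} (1 - lambda_j): probability of halting exactly at step n."""
    probs: List[float] = []
    not_halted_yet = 1.0
    for lam in lambdas:
        probs.append(not_halted_yet * lam)
        not_halted_yet *= 1.0 - lam
    return probs

def steps_to_unroll(lambdas: Sequence[float], epsilon: float = 0.05) -> int:
    """Training-time unroll length: first step whose cumulative halting probability exceeds 1 - epsilon."""
    cumulative = 0.0
    for n, p in enumerate(halting_distribution(lambdas), start=1):
        cumulative += p
        if cumulative > 1.0 - epsilon:
            return n
    return len(lambdas)

def geometric_prior(lambda_p: float, n_steps: int) -> List[float]:
    """Geometric prior over halting steps, truncated to n_steps and renormalized."""
    raw = [lambda_p * (1.0 - lambda_p) ** (n - 1) for n in range(1, n_steps + 1)]
    total = sum(raw)
    return [r / total for r in raw]

def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_n = 0 contribute nothing."""
    return sum(pn * math.log(pn / qn) for pn, qn in zip(p, q) if pn > 0.0)

def ponder_loss(step_losses: Sequence[float], lambdas: Sequence[float],
                lambda_p: float = 0.2, beta: float = 0.01) -> float:
    """Per-step losses weighted by the halting distribution, plus beta * KL(p || prior).
    ACT, by contrast, would score a single weighted average of the predictions."""
    p = halting_distribution(lambdas)
    expected_loss = sum(pn * ln for pn, ln in zip(p, step_losses))
    return expected_loss + beta * kl_divergence(p, geometric_prior(lambda_p, len(p)))

def sample_halting_step(lambdas: Sequence[float], rng: random.Random) -> int:
    """Test time: flip a Bernoulli(lambda_n) coin at each step and stop at the first success."""
    for n, lam in enumerate(lambdas, start=1):
        if rng.random() < lam:
            return n
    return len(lambdas)  # fall back to the last step if no coin came up heads

lambdas = [0.5, 0.5, 1.0]  # last lambda forced to 1 so that p sums to one (a sketch choice)
print(halting_distribution(lambdas))                       # [0.5, 0.25, 0.25]
print(ponder_loss([1.0, 2.0, 4.0], lambdas, beta=0.0))     # 2.0
print(steps_to_unroll([0.5, 0.5, 0.5, 0.5], epsilon=0.1))  # 4
print(sample_halting_step(lambdas, random.Random(0)))      # varies with the seed
```

Because each step's loss is weighted by the probability of halting there, an input whose lambda values rise quickly concentrates probability on early steps. At test time that input then tends to halt early.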
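The parity extrapolation setup from the interview (train on strings up to 48 elements, test on strings up to twice as long) can be sketched as a small data generator. The exact input encoding used in the paper is not described here. The interview only mentions strings of ones and zeros, or ones and minus ones. This sketch therefore assumes variable-length strings of +1/-1 entries, labeled by the parity of the number of +1 entries. The function names are hypothetical.

```python
import random
from typing import List, Tuple

# Hypothetical data generator; the encoding and length ranges are illustrative assumptions.
Example = Tuple[List[int], int]

def parity_example(length: int, rng: random.Random) -> Example:
    """A string of +1/-1 entries; label 1 if the number of +1 entries is odd, else 0."""
    bits = [rng.choice((1, -1)) for _ in range(length)]
    return bits, sum(1 for b in bits if b == 1) % 2

def make_split(min_len: int, max_len: int, size: int, seed: int) -> List[Example]:
    rng = random.Random(seed)
    return [parity_example(rng.randint(min_len, max_len), rng) for _ in range(size)]

def running_parity(bits: List[int]) -> int:
    """Sequential reference solver: a running XOR that must visit every element."""
    state = 0
    for b in bits:
        state ^= int(b == 1)
    return state

# Extrapolation split: train on lengths 1-48, test on strings up to twice as long (49-96).
train = make_split(1, 48, size=1000, seed=0)
test = make_split(49, 96, size=200, seed=1)
print(len(train), len(test))  # 1000 200
```

The `running_parity` reference shows why the intuition raised in the conversation holds. Unlike AND, no prefix of the string determines the answer, so a sequential solver has to read every element.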
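The BERT-style objective for RL described in the interview (mask some CNN features, then recover each masked feature by discriminating it from negatives drawn from the same sequence) is a contrastive loss. The sketch below is a generic InfoNCE-style version for a single masked step. The cosine similarity, the `temperature` value, and the function names are assumptions, and the paper's actual loss may differ in its details.

```python
import math
from typing import Sequence

# Illustrative sketch only: the names, cosine similarity and temperature are assumptions.

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two non-zero vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)

def masked_contrastive_loss(prediction: Sequence[float],
                            sequence_features: Sequence[Sequence[float]],
                            masked_index: int, temperature: float = 0.1) -> float:
    """InfoNCE-style loss for one masked time step.

    prediction: the transformer's output at the masked position.
    sequence_features: the original, unmasked input features for every step of the same sequence.
    The feature at masked_index is the positive; all other steps serve as negatives.
    """
    logits = [cosine(prediction, f) / temperature for f in sequence_features]
    top = max(logits)  # log-sum-exp with the max subtracted for numerical stability
    log_partition = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_partition - logits[masked_index]  # = -log softmax(logits)[masked_index]

features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # toy 3-step sequence
print(masked_contrastive_loss([0.9, 0.1, 0.0], features, masked_index=0))  # low: matches step 0
print(masked_contrastive_loss([0.1, 0.9, 0.0], features, masked_index=0))  # high: matches a negative
```

Drawing negatives from the same sequence forces the model to tell apart time steps that share context, rather than separating unrelated episodes.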
Info
Channel: The TWIML AI Podcast with Sam Charrington
Views: 231
Keywords: TWiML & AI, Podcast, Tech, Technology, ML, AI, Machine Learning, Artificial Intelligence, Sam Charrington, data, science, computer science, deep learning, neuroscience, deepmind, andrea banino, pondernet, memory, episodic memory, semantic memory, andrej karpathy
Id: Na0lm-y81lE
Length: 41min 37sec (2497 seconds)
Published: Mon Oct 18 2021