Understanding Mamba and State Space Models

Video Statistics and Information

Captions
At the end of last year, a new architecture called Mamba was released. Mamba allows for the same token generation speed irrespective of context length. It works by compressing each token into storage, so that it avoids having to pay attention to each of the previous tokens when generating the next one. I'm going to describe some of the basics of Mamba, give you some basic equations, and share some notebooks that you can run through. We're still in the early days, but the idea of compressing the context is, I think, a good one, and it should allow us to move beyond the current Transformer design.

For the agenda: I'll start by describing what a state space model is (that's what Mamba is), and I'll frame it by comparing it with a Transformer, which many of you will be familiar with. I'll then go through some very simple Mamba maths, just to show the basics of how it works, before moving to a notebook (I'll put a link below) that you can work through to run inference on Mamba yourself. I'll then make a few comments on training Mamba. The tools are really not mature, so it's quite difficult to train at the moment: for example, there's nothing like LoRA yet, and only basic functionality is supported by the Hugging Face trainer. Nonetheless, I'll quickly give you a look at that, and I've put a copy of the training notebook into the advanced fine-tuning repo for those who have purchased lifetime access. I'll finish with a few comments on practical issues around the Mamba architecture today, and where I think it can improve over current Transformers.

To understand Mamba, let's start by reviewing how Transformers work. Here I have a sequence of colours, and my goal is to predict the next colour. Quite simply, I have red and then blue, and I'm going to predict the next colour. With Transformers, this works by paying attention: calculating the relationship between the previous token and each of the tokens before it. So here, given that we have two known tokens, we run attention between them and then predict the third token. The process continues: when predicting the fourth token, we compute attention between the third and the second, and between the third and the first, and so on. When we move to predict the fifth token, we're now computing the relationships between 4 and 3, 4 and 2, and 4 and 1. We also still make use of the relationships between 3 and 2, 3 and 1, and 2 and 1, but those values were stored from the previous steps. So at every step I'm showing the incremental relationships that need to be calculated through the mechanism of attention. This is how a Transformer works, and as you can see, the further we get into a sequence (or the longer the starting sequence; here I started with just two tokens), the more calculations need to be done in attention. This is what makes Transformers computationally expensive once you get to longer and longer sequences.
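To make that scaling concrete, here's a tiny counting sketch in Python (a toy illustration of the idea just described, not a real attention implementation):

```python
# Each time we predict the next token, the newest token has to be compared
# against every token before it, so the incremental work keeps growing.
tokens = ["red", "blue", "green", "yellow", "purple"]

total_comparisons = 0
for t in range(1, len(tokens)):
    new_comparisons = t  # token t attends to tokens 0 .. t-1
    total_comparisons += new_comparisons
    print(f"predicting token {t + 1}: {new_comparisons} new attention scores, "
          f"{total_comparisons} computed in total")

# For a sequence of n tokens the total is n*(n-1)/2 pairs: quadratic growth.
```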
So let's now move to a state space model. In this case I'm going to start with one token, red, and predict the next token. The way this works is by having a little piece of storage called the state space. You can think of it as a vector, or as a matrix with two dimensions, but either way it has a certain amount of storage. You take the input token, and through some kind of function that token is compressed and stored within the state space — stored inside this vector — and then the state space is used to calculate the prediction. So this is how Mamba would take in a first token: it uses that token to update the state space (which is just storage), and then uses that storage to make a prediction of the next token.

Notice that when we move to the second token, Mamba completely throws out the first token. We no longer pay any attention to it; any information we need from the first token is stored within the state space in compressed form. So again, for the second token, we take it in through a mathematical function and update the state space — we basically compress the information in token 2 into this storage — and then we use this storage (which, again, you can think of as a vector) to calculate a prediction of the third token. When we come to the third token and predicting the fourth, you can see we've thrown out the first and second tokens, because they are stored in compressed form within the state space, which again is just a big vector. Token 3 is transformed so that we update the state space H, and we then use this state space to predict token 4. In every case with the state space model — and we can move on to step four, where we compress token 4 and predict token 5 — at every step we throw away all of the old tokens. So unlike a Transformer, in the state space model we have the benefit that we don't need to do new calculations over every previous token on each generation step. Just to say that once more: when predicting token 5 with a Transformer, we do have to calculate the attention between the newly available token and all of the previous tokens, whereas in the state space model, because we're compressing each new token as we go, we don't need those calculations — we just update the state space and then calculate the prediction.

Looking at it in terms of the vectors we need to save: in a Transformer, we have all of the input tokens — say input token x1 represented by a vector, then x2, x3, and so on for each input in the sequence — and we need to store each of those. When we generate a new token, we need to run attention between the most recently available token and all of those previous vectors. In a state space model, all we consider is the most recent token we've received — we forget about everything in the past — plus H, the state space. H contains compressed information on all of the previous tokens.

Here is the crux of Transformers versus Mamba. In Transformers, as the sequence gets longer and longer, you have more calculations to make. In Mamba, you don't, because you're only ever updating a fixed-size state space. However, because you need to compress all of the information and store it, this storage H needs to be bigger: it has to hold the information content of many vectors, so it's going to be much bigger than a single input vector x1. If you have a really short sequence (and this is looking at it in simplified terms), it might even be punitive, because you still have to pay for the calculations on this storage. But as your sequences become longer and longer, this is where Mamba is really beneficial: you still only have to do the computation involving H; you don't have to redo calculations over all of the previous tokens.
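Here's a minimal Python sketch of that idea (a toy, not the actual Mamba architecture: the matrix names, sizes and random inputs below are stand-ins, just to show that the per-token work stays constant):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_state = 4, 8                                 # arbitrary toy sizes
A = rng.normal(scale=0.1, size=(d_state, d_state))      # stand-ins for trained matrices
B = rng.normal(scale=0.1, size=(d_state, d_model))
C = rng.normal(scale=0.1, size=(d_model, d_state))

h = np.zeros(d_state)                        # the fixed-size "storage" (state space)
for x in rng.normal(size=(1000, d_model)):   # a stream of token embeddings
    h = h + (A @ h + B @ x)                  # compress the new token into the state
    y = C @ h                                # predict the next token from the state alone

# No matter how long the stream gets, each step touches only h and x:
# the old tokens are never revisited.
```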
So with that, let's look in mathematical terms at how we can achieve this behaviour: how can we develop some equations for the state space so that we can update it when we get a new token, and then how can we calculate a prediction for the next token? To help explain the maths I'm going to use ChatGPT, with a chat I developed earlier. We're going to talk about state space language models: how to update the state (the storage H) based on an incoming vector x1, and then how to calculate the new vector x2.

So we'll define a state H, and we'll talk about updating it. Given an input x, the question is how to update H. When I say "update H", what I mean is: what is ΔH, the change in H we want to bring about when there's a new x, and what should that change depend on? There are two things it makes sense to depend on: the input that's coming in, but also the current state. If you think of the memory as saying "I don't want to know anything about dogs", but the input is "dog", then you probably want to ignore "dog" — so it's important to consider not just the latest input but also the memory. So we'll make ΔH depend on x, and we'll make it depend on the current state H.

If you want a very simple relationship where the update depends on H and x, the simplest is always linear. So here we have a linear expression for ΔH — it says "equals", but what I really mean is that it scales with A (which is just a constant matrix) times H, plus B times x. This is the simplest way you could update the state while making it dependent on the storage and on the input. There's actually one thing missing from the right-hand side: if this is a ΔH, then we need an increment on the right-hand side. So here I've got this small δ, and what it means is that we update the state in a way that depends linearly on the storage itself and on the input, and the size of the update we make — which you can very loosely think of as between zero and one — depends on how δ is set.

Now we make some of these parameters trainable. The idea of a state space model is that you train on a bunch of data with δ, A and B trainable, and you figure out these matrices, which then remain constant when you go to inference. That gives you a state space model that lets you update H. Now that we know how to calculate the update on H, we can calculate H2: quite simply, H2 equals H1 (the previous state) plus the update, for which we already have the formula.

Now that we have the new state, we need to ask how to calculate the new value of x — the output. The very simplest way is to take the new state H2 and multiply it by a trainable matrix C to get the output. So if we pull these equations together for a state space model: first we have an equation to get an updated state, which is linearly dependent on the state itself and on the input vector, and then we have an equation to get the output vector. That's our state space model.
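Written out in symbols, the simplified (discrete-time) state space model described above looks roughly like this — a loose sketch of what was just discussed, not the exact discretization used in the Mamba paper:

```latex
\begin{aligned}
\Delta h_t &= \delta \,\bigl(A\,h_{t-1} + B\,x_t\bigr) && \text{update: linear in the current state and the input}\\
h_t &= h_{t-1} + \Delta h_t && \text{new state = old state + update}\\
y_t &= C\,h_t && \text{output read off the new state}
\end{aligned}
```

Here $A$, $B$, $C$ and $\delta$ are learned during training and then held fixed at inference.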
Now, this is not yet a selective state space model like Mamba. That's actually part of why Mamba is called Mamba: it's reminiscent of snakes, snakes hiss, and there's an extra "S" in there for "selective". Mamba does not use constant values everywhere: it takes these trainable matrices and actually makes them linear functions of the input. What this means is that when Mamba is compressing the input context, the compression itself depends linearly on the input. So if a token comes in that says "dog", the way it gets compressed will depend on that token "dog". In principle, all of these could be made input-dependent, but in the Mamba paper A remains a matrix of constants, and only δ, B and C end up depending on the input. So I'll show you those equations in full: it's the very same set of equations, but δ now depends on the input x, B depends on the input x, and C depends on the input x, while A remains a matrix of constants.
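In the same loose notation, the selective (Mamba) version makes δ, B and C functions of the current input token, while A stays fixed — again a sketch of the structure rather than the paper's exact parameterization:

```latex
\begin{aligned}
\delta_t &= s_\delta(x_t), \qquad B_t = s_B(x_t), \qquad C_t = s_C(x_t) && \text{input-dependent (``selective'') parameters}\\
h_t &= h_{t-1} + \delta_t\,\bigl(A\,h_{t-1} + B_t\,x_t\bigr) && \text{$A$ is not input-dependent}\\
y_t &= C_t\,h_t
\end{aligned}
```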
And there you have it: this is the set of equations that describes Mamba, in a very high-level, loose sense. You can see that, unlike Transformers, there isn't a long chain of calculations over all of the previous tokens, because you're only ever concerned with the input token and the current state space. The drawback, as I mentioned earlier, is that the state space — which needs to be large enough to store the compressed information — is going to be larger, and that increases the compute you have to do to some degree.

With that quick introduction, let's take a look at some real models using a notebook. There are two models we'll look at. The first is a more original model from the state-spaces team: a Mamba 2.8 billion parameter model, so a bit less than half the parameters of a Llama 7B model, trained additionally on the SlimPajama dataset. It's a base model — not chat fine-tuned — so it will keep blabbing on, and you have to give it a nice continuation prompt to get a meaningful answer. The other model is a fine-tune of this: the OpenHermes instruct model, which is chat fine-tuned and actually responds quite nicely. It's fine-tuned on a wide variety of data, including some coding examples, although as we'll see, the coding performance isn't great on this model yet.

So we'll head over to the Mamba notebook — I'll put a link in the description — and get started with some installations. There are a few packages to install, including mamba-ssm, which is the main package required. We then clone the Mamba repo so we can make use of some of its scripts, and I cd into the Mamba directory that's created. By the way, you can run this on a T4 if you like; I'm running on a V100, and you could also run on an A100 — the T4, of course, being the free one. There are a few more setup commands here; otherwise you run into some bugs later on with inference.

First I'll show an example of inference with the Mamba SlimPajama model. We import some modules, including the MambaLMHeadModel, and then set the model name. We're actually going to load a model from the Trelis repo — the only reason for this is that the SlimPajama model is saved in 32-bit, so I've pushed a 16-bit version that's a little quicker to download. We load it onto the GPU (device set to cuda), and I've set the dtype to float16 — you can use bfloat16 if you're on an A100, but float16 runs on any device, including the T4. The model then loads; it has just under 2.8 billion parameters, as you can see, and it's about 5.6 GB in size because it's in 16-bit (2 bytes × 2.8 billion ≈ 5.6 GB).

Here we have the model config. The configuration is different from a traditional Transformer: you won't see, for example, a context length, because that's not a relevant parameter when you're only ever looking at the most recent token and the state. You can see a description of the modules: this model has 64 layers, which is quite a lot for a small model, and within each block you have the Mamba block, defined by a series of linear projections. You can also see a conv1d, and this relates to δ — the sizing of the step, i.e. the adjustment we make to the state. The size of that adjustment is trainable, which means that depending on the input, the model will make a smaller or larger adjustment to the state. Then there are aspects of the model that are quite similar to Transformers, like the embedding module, and we also have LM head and norm modules.

We'll start with a quick example: "What planets are in our solar system?" Notice this is not a chat fine-tuned model, so I'm helping it by not just writing a newline and "Answer", but writing "Here is the answer" followed by a colon. We run this through the code — you can take your time going through the series of steps that prepare the tokens and submit them to the model — and here's the output: "What planets are in our solar system? Here is the answer: the planets in our solar system are Mercury, Earth, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune and Pluto." So not exactly right, but very close. The model does blab on, because it's not fine-tuned to stop — that's the nature of base models — but it gives a somewhat reasonable answer. Not as good, of course, as stronger models; even Llama 7B would probably get the exact eight planets.
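Pulled together, the cells we just walked through look roughly like the sketch below. This is a reconstruction rather than the exact notebook: the model name is the original state-spaces SlimPajama checkpoint (the video loads a 16-bit copy from the Trelis account instead), the tokenizer choice follows the Mamba model cards, and the generate arguments follow the mamba repo's generation script, so details may differ:

```python
# In the notebook these would be shell cells:
#   pip install mamba-ssm
#   git clone https://github.com/state-spaces/mamba.git

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
dtype = torch.float16          # bfloat16 is fine on an A100/Ampere GPU

# Mamba checkpoints reuse the GPT-NeoX tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b-slimpj",   # base model; swap in the chat fine-tune for OpenHermes
    device=device,
    dtype=dtype,
)

prompt = "What planets are in our solar system?\n\nHere is the answer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[1] + 60,
    temperature=0.7,
    top_k=40,
    top_p=0.9,
    return_dict_in_generate=True,
)
print(tokenizer.decode(out.sequences[0]))
```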
Next we'll take a look at passkey retrieval. I've uploaded a long file — a Berkshire Hathaway transcript — and embedded a passkey halfway through, at the 0.5 mark. I'm going to feed in a chunk of that transcript of almost 16,000 characters, so a little under 4,000 tokens; around 3,600 tokens is the largest I've been able to set it to before you start running into issues. Interestingly, Mamba has been trained on a context length of 2,000 tokens, so it's notable that it manages passkey retrieval somewhat beyond that. When we run the code, the first thing it responds with is actually the passkey. Just to show you the prompt: it's the text followed by "Respond with the passkey contained in the above text. The passkey is" — so we're really helping the model, because the very next word should be the passkey, and indeed it is, and it's correct. The model does blab on afterwards, which is understandable because it's not a chat fine-tuned model, but overall passkey retrieval is very good for such a small model.

Next, some inference on OpenHermes, the model that's a little stronger because it has the benefit of further fine-tuning. I've uncommented the relevant line so that we load that model; you can see the weights downloading and the shards being loaded. Here's the result: "The planets in our solar system are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus and Neptune." So this model is already a bit stronger because of that fine-tuning. Let's put in another question — I won't reload the model, just update the prompt — and ask it to write a piece of Python code to add the first five Fibonacci numbers. Here's the code it returns; we can create a code cell and see if it runs, and indeed it produces the Fibonacci sequence, but it hasn't added the numbers together, so it's missing a key piece. Generally, in my testing, the coding performance is not great — probably not as good as, say, Phi-2, which is a Transformer model of roughly similar parameter count. That said, the amount of training this model has undergone, and the total number of tokens used, is probably not comparable to Phi-2, so that's worth considering as well.

I did say I'd talk a little about training this model, and here I have a short notebook I've created for training Mamba; it's available to those who've purchased lifetime access to the advanced fine-tuning repo. Broadly speaking, it's quite similar to training a normal model, except a lot of the usual tools aren't available. It's not yet possible — though I believe it will be soon — to do LoRA fine-tuning, which would reduce the memory requirements significantly; LoRA also tends to be more stable than training all of the parameters. I'll just show you that we use the trainer: the data setup is very similar to normal, and you'll recognise many of the trainer arguments as the same as what we'd use for training a Transformer — gradient accumulation, batch size, logging steps, saving steps, and whether we're using bfloat16 (if running on an Ampere-type GPU) or just fp16 (on a T4). I was able to train the model reasonably stably. I did decide to exclude some modules from training, so I wasn't training all of the parameters. I fine-tuned the model on an Open Assistant dataset — you can check out the Trelis repo on Hugging Face to find the fine-tuned model — and I was able to get the training loss to reduce fairly consistently and to run an evaluation. However, because of the setup, I wasn't able to run the evaluation set in parallel with training, so I couldn't check my evaluation loss as training progressed; the classes need to be updated to be compatible with how we'd do things with Transformers. So I'd say it's not an easy model to train right now, just because of the toolkit, but if we wait even another month, a lot of the tools will probably be available.
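For a flavour of what those training cells look like, here's a rough sketch. It's a guess at the shape rather than the notebook itself: the `MambaTrainer` subclass and its loss computation are hypothetical stand-ins (MambaLMHeadModel returns logits rather than a Hugging Face-style loss), and `train_dataset` is assumed to already yield tokenized `input_ids`:

```python
import torch
from transformers import Trainer, TrainingArguments

class MambaTrainer(Trainer):
    """Hypothetical wrapper: compute the causal-LM loss from the model's logits."""
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs["input_ids"]
        logits = model(input_ids).logits
        # Shift so that position t predicts token t+1.
        loss = torch.nn.functional.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            input_ids[:, 1:].reshape(-1),
        )
        return (loss, logits) if return_outputs else loss

args = TrainingArguments(
    output_dir="mamba-oasst-finetune",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    save_steps=200,
    num_train_epochs=1,
    bf16=True,    # Ampere GPUs; set fp16=True instead on a T4
)

trainer = MambaTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```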
Before I wrap up, I want to make a few last comments — and I expect there will be more videos on Mamba-type architectures. First off, something particularly interesting about Mamba is that, because it has this state, you can run a very long prompt, or even a book, save the state, and then give that state to somebody else for inference. Instead of having to keep re-running very long prompts, as we do with Transformers, with a state space model you can just take the state from somebody who has already run all of that context and start using it. That results in a lot of savings — it's a bit like a form of fine-tuning, but without any back-propagation (see the toy sketch at the end of this section).

Another key thing you can do with state space models is, in principle, unlimited-context training. If you have a batch, you can grab the state at the very end of it and use it at the start of the next batch, provided your batches are consecutive chunks of a longer piece of text. It seems the current Mamba models are trained on separate 2,000-token context windows, but in principle that could be extended to very long contexts just by carrying the saved state from one batch into the next.

Next, the selectivity currently depends only on the input. When I showed you the equations, the parameters B and C (those matrices) and also δ are trainable and have a linear dependence on the input x, but they don't depend on H. Dependences on H are computationally more expensive, so that's a trade-off, but I think introducing relationships between the input and the state might allow for even better performance — somewhat analogous to attention in Transformers, except there wouldn't be attention between each of the previous tokens, only between the state space and the input token.

As I mentioned, a lot of the techniques that work well in Transformers will probably work well in this state space setting too. LoRA is easily applicable: anywhere you have large matrices, you can represent the updates with lower-rank forms and fine-tune those. Mamba MoE: the benefits of mixture-of-experts that come with Transformers could potentially also apply to Mamba architectures — you'd train several Mamba models side by side with a router that decides which one each input goes to, and that should improve inference time by allowing each expert to be smaller than one large model.
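As a toy illustration of the state-reuse idea (using the same kind of simplified recurrence as the earlier sketch — not the real mamba_ssm caching API, and the file name is just for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_state = 4, 8
A = rng.normal(scale=0.1, size=(d_state, d_state))
B = rng.normal(scale=0.1, size=(d_state, d_model))

def advance(h, xs):
    """Fold a stream of token embeddings into the fixed-size state."""
    for x in xs:
        h = h + (A @ h + B @ x)
    return h

# One party runs the whole book / long prompt once and saves the state...
book = rng.normal(size=(10_000, d_model))
np.save("book_state.npy", advance(np.zeros(d_state), book))

# ...everyone else starts from that saved state and only pays for their own new
# tokens. The same trick chains training batches: the state at the end of one
# chunk seeds the start of the next chunk of the same document.
h = advance(np.load("book_state.npy"), rng.normal(size=(20, d_model)))
```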
And that's it for this first video on Mamba and state space models. I think it's an interesting direction because of how it lets you compress the context; the way we currently attend to every previous token with a Transformer is fundamentally inefficient, and that's something I think will improve as we move forward. The Mamba architecture probably has quite a bit of room for improvement from where it is, so I expect things to evolve, and some of the tricks we've learned from Transformers to be carried over to Mamba. The largest Mamba model right now is about 2.8 billion parameters; pretty soon I'd expect somebody to come out with one that's at least 7 billion in size, and it will be very interesting, when we get up to sizes similar to, say, Llama 70B, to see how the performance compares and what kind of speed-ups we get, particularly with longer contexts. I think it can also change the mindset around fine-tuning: because you can just run inference and save the state, that changes how people think about fine-tuning and adds an extra tool to the toolkit for optimising models for deployment. Leave your comments below. Cheers, folks.
Info
Channel: Trelis Research
Views: 2,831
Keywords: mamba, understanding mamba, understanding SSMs, understanding state space models, mamba tutorial, mamba tutorial colab, training mamba ssm, training mamba
Id: iskuX3Ak9Uk
Length: 27min 40sec (1660 seconds)
Published: Fri Feb 02 2024