How to Run PyTorch Models in the Browser With ONNX.js

Captions
What is up, guys. Elliot Waite here, and in today's video I want to show you how you can take a model you've trained in PyTorch and use it directly in the browser using JavaScript, like this. Here's an example of something I built: it's basically the MNIST model predicting what number you draw. You can see how quickly it's guessing; each time the picture updates, it makes a new prediction.

That's the first advantage of having a model in the browser: you avoid the latency of sending the data to a server and back, so you get a really responsive app with really quick predictions. The second benefit is that it's really easy to set up your website. You can do it with just static files, and that makes it really easy to scale: if you have a ton of users, you only have to serve those static files, and you don't have to spin up new nodes of your back-end server and handle load balancing as you get more and more requests. The third benefit of running your model in the browser is that it will work offline, because you're not communicating with a back-end server. All the code you need to run your input through the model and calculate the predictions is already loaded the first time you load the website. Or, if you're using the JavaScript files in an app, then as long as users have already installed the app, they'll be able to use it offline. The fourth main benefit is privacy. Since you don't have to send the user's data off to a server, you can keep it right there on the local device, and you can let users know that if you're handling sensitive data, you're not going to send it to a server where someone could potentially intercept it. Depending on what kind of data you're processing, that could be a really nice feature.

With all those benefits, there are only a few reasons you probably wouldn't want to run a model in the browser. The first is if you have a really big model: if it takes a long time to initially download the page or
download your app, that probably won't work, and you'll just have to keep the model on the back end and send data to and from that server to run your predictions. The second reason you wouldn't want to run a model in the browser is that the computation may take too much time. Even if the model is small enough to download in a reasonable time, it might still take too long to run predictions on the user's device, especially since devices can range widely in performance, and you may want the app to work well across all of them. In that case you may get better overall performance by hosting the model on a back-end server. The third reason is that you may want to keep the model private. If you run the model in the browser, users can potentially inspect the graph, inspect the parameters, and use that model themselves. So if you want to keep it as your company's trade secret, you probably don't want to ship your model to the browser. That's the other side of privacy: in the browser you get user privacy, since you don't have to send the user's data anywhere, but if you host the model on a back end, you get model privacy.

So now I'm going to show you how to actually convert your PyTorch models into JavaScript, and we're going to start with the PyTorch example MNIST code. MNIST is just a model for recognizing handwritten digits. I'm not going to go over all the code in the script (you can check it out on GitHub; I'll link to it in the video description), but we'll go over a high-level overview of what the model is doing. We start with our image, which is 28 pixels wide by 28 pixels tall, and since it's in black and white, we only have one color channel. We do a convolution on this image to get 32 channels in the second tensor, then a ReLU activation on this tensor, then another convolution to get 64 channels. We'll then do a max pool on this tensor, cutting the height and width in half. We'll
then perform a dropout of 25% on this tensor, then flatten it and do a fully connected matrix multiply to get an intermediate vector of 128 values. Then again a ReLU activation, another dropout, 50% this time, and then another fully connected matrix multiply to get our final ten values, which we run a log softmax on to get our ten predictions, for the digits zero through nine, of what we think the input image was.

So now to the code. The first thing we're going to change is where the data is imported, so that it'll be in the same folder, and then I'm going to make it always save the model at the end; I'll just call it pytorch_model.

Now that we have our trained model, we have to convert it into something that can be used in JavaScript. There are two ways to do this: we can use TensorFlow.js or ONNX.js. ONNX.js is simpler, but it only allows you to do inference. So if you want to do training in the browser, I'd suggest going with TensorFlow.js, but if you're only doing inference, only running data through the model, then ONNX.js is simpler and I think it's more performant. In this benchmark, ONNX.js is the yellow, blue is TensorFlow.js, and orange is Keras.js; the height is the time to run inference, the left side is on a CPU, the right side is on a GPU, and the benchmark is the ResNet-50 model.

So what is ONNX? It stands for Open Neural Network Exchange, and the short description on their website is actually pretty good: "ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators, the building blocks of machine learning and deep learning models, and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers." They basically define a format in which you can specify your model and the parameters used in your model so that you can transfer it around. So if we save our PyTorch model in this ONNX file format, we
can then load that file into ONNX.js. ONNX.js is just a JavaScript library for loading these ONNX models and running inference on them, and it can run anywhere JavaScript can run: in a browser, using Node.js, or in an app using a WebView.

Here's the code for converting our PyTorch model to an ONNX model, and it's actually pretty simple. First we import our neural network, which we defined in our training script, here; it's just a PyTorch module. Then we instantiate that model and set its parameters to the ones we got when we finished training, using the load_state_dict method. We then put the model in evaluation mode, which is what you want when running inference on the model; in this case it will essentially bypass those dropout operations. We then create a dummy input, which is just a tensor with the same shape as the input your model expects. It can have any values; in this case we'll just use zeros. We then call the torch.onnx.export method, which expects an instance of the model, a dummy input tensor, the path where you want to save the output model, and an optional verbose argument which, if true, prints out a string representation of the graph of your model, as you'll see in a second. So we run this, and it creates our ONNX model, and here's the output string, which is a description of the graph of your model and kind of shows you the format that ONNX uses.

I created this little demo script to test out our model, just to make sure everything works before using it in a bigger application. It's just an HTML document with some JavaScript inside. We start off by importing the ONNX.js library, which gives us access to these onnx variables. Then we declare this test function, and we'll call it right after we declare it, but we want it to be an asynchronous function because we're going
to be using these await operators. So we start by creating a session using new onnx.InferenceSession. I'm not going to go over all the details of how the API works (you can check it out online for the specifics), but I'll walk you through this demo to give you an idea of how it works. First we create a session. Next we load the model into the session by specifying the path to the model; this returns a promise that resolves once the model has been loaded, which is why we use the await operator. Then we create the input that we'll pass through the model, using the onnx.Tensor class: we specify a Float32Array of this size, of this type, and of this shape, which our model is expecting. We pass it through the model, and that returns a promise that resolves once the model has calculated the output, so we use the await operator here, and it gives us an output map. The reason this is an array is that we could have multiple inputs; if you look back at our forward pass, we're only passing in a single value, but if for some reason we were passing in multiple values, we could do that here. This value will be a read-only map, and again, if we had multiple outputs we could access them through this map, but since we only have one output, we'll just access the first value by calling .values().next() and .value. That will be a tensor, and to access the underlying data we just call outputTensor.data, which returns a JavaScript Float32Array, and we'll log that to the console.

Now, if we run this in a browser, we see we get an error: "cannot resolve operator LogSoftmax". What this error message is saying is that ONNX.js doesn't currently support the LogSoftmax operator. If we go to the operators page in the ONNX.js docs inside their GitHub repo, we can see the operators that are supported, and if we
go down to LogSoftmax, we'll see that it currently isn't supported by any of the backends. If we look back at our model, we can see that at the very end of our forward pass we call log_softmax, right here. But to actually get the probabilities we want, we don't need to call log_softmax; we only need to call softmax, and as we can see, ONNX.js does support the Softmax operator. So we can just update the model, create a new ONNX model from it, and try that out. I'm not going to do that in this file; I'm going to keep this file as is and create a new Python file, which I'll call inference_mnist_model, and here we'll update the model to use softmax. Then, in our convert-to-ONNX script, we can change this to our inference model, and since changing this operator doesn't change any of the modules that require parameters, we can still use our old trained weights and just reload them into this new model. So if we rerun the script, we'll get our new ONNX model; we load it into here, overwrite the old version, come back to this page, reload, and now we see our output tensor: ten probabilities, just what we want, based on an input of all zeros.

Now, to be able to test actually sending handwritten digits through our model, I've created this demo website. It's just a canvas where you can draw numbers, and it prints out the predictions down here. Right now it's not fully working, because there's something in the code that we still have to fix, which I'll show you. I'm not going to go over all the code (there'll be a link in the description to this CodeSandbox, and you can look it over), but at a high level, we have all these HTML elements, then we load in our ONNX.js library and call our script file. Our script file does some things to set up the canvas and set up the drawing, but what we're going to look at is the same code we used before, where we create a session and then load the model into that session, and this will return a promise that will just
resolve once the model has actually been loaded. Lower down in our code, where we actually make our predictions, we make sure the model has loaded before trying to run any data through it. But up here is where we actually create the input tensor, and you can look at all the code around it: it basically says that any time the drawing is updated by adding a new line, we update our predictions. So we come in here, we get our data from the canvas, and basically we get this entire canvas, which is 280 pixels wide by 280 pixels tall, and then we try to run that through our model.

Now, you could try to use JavaScript to convert that into the input format we're looking for, which is a 28 by 28 pixel image, but there's actually another way to do it that I think is simpler, and that is to change our model again, our inference model. Right here, the input it expects has a shape of 1, 1, 28, 28, but the actual shape we're going to get from the image data is just one long list of all the numbers that make up this image: we have 280 by 280 pixels, and for each pixel we have four channels, red, green, blue, and an alpha channel. So we have to update our code to handle that. If we look back at the model, I've added in this code. The input x value is just going to be the flat list of numbers from the image data, so first we reshape it into the shape 280 by 280 by 4: the height, the width, and then the number of channels (red, green, blue, and then the alpha channel). Whenever you draw on a canvas, wherever you've drawn gets an alpha value of 255, and wherever you haven't drawn has an alpha value of 0, so we can use that value to figure out where we've drawn. So we extract the fourth channel, which has an index of 3, and then in the next line we reshape this tensor into the shape PyTorch expects for images, where the first
value is the batch size, then the number of channels, and then the height and the width of the image. We then call the average pool operator, which essentially resizes it down by a factor of ten, so we go from 280 by 280 to 28 by 28, which is the input size our model expects. Then, since the max alpha channel value is 255, we divide the tensor by 255, which brings our image values into the range 0 to 1.

But there's actually one more thing we have to do. If we look back at the script we used to train the model (I didn't mention this earlier), when we create the dataset we actually normalize it before passing it through the model, which means we take the initial values, subtract the mean of the dataset, which is this value here, and divide by the standard deviation of the dataset, which is this number here. So we have to do that with the data we pass into the model as well. Looking back at the model code, I've copied over those values, the mean and the standard deviation, and added this last line where we subtract the mean and then divide by the standard deviation.

Now we just need to recreate the ONNX file with this new model format. We come over here, we run our script, but we get an error: "shape 280 by 280 by 4 is invalid for an input of size 784". The reason we got this error message is that we need to update our dummy input: it still has the shape of the original data we trained on, and we want to change it to the shape of the canvas data, which is 280 by 280 by 4, all in one long list, so it's just a single dimension. And now it works: we get our new ONNX model, and we'll upload that into our demo code.

So, I just tried out the model in the demo, and it wasn't working; I was getting some errors from ONNX.js. I had to come back into the model and try a couple of different things, and what I found actually fixed it was
changing these two lines. Originally I was taking a slice like this, and when I changed it to taking a slice using narrow, that solved one issue. And then on the average pool, I had to specify the stride. If you look at the PyTorch documentation, if stride is left as None it should default to the kernel size, but sometimes the conversion from PyTorch to ONNX.js can be a little finicky, so sometimes you just have to try debugging your model by creating a small test like this, incrementally changing things, seeing where it breaks, and seeing if there's an alternative way of having the same operation be performed in your model using a different function.

And now, with this updated model, if we look back at our demo, it's working. We get numbers: that's a one... but it thought it was a six. That was a four... all right, so what's going on? Okay, it's doing something, but for some reason it's not getting the best results. Nine... but the last one it thought was a four. So we're going to do something to improve the model a little bit by using data augmentation. Basically, that means we're going to take the training data and shift it around a bunch before sending it through the model, so that the model will be a bit more robust to input that looks a little different than what it's expecting. If we look at the data in the MNIST dataset, we can see that these numbers are pretty varied, but they're all mostly centered and at full height. We'll start augmenting by adding in some rotations, so you can see they got a bit more rotated. Then we'll see what happens if we add in a translation: that translated them vertically; we could also translate both horizontally and vertically, which looks like this. Then we could also experiment with changing the scale, which gives us some big numbers and some small numbers, and we can also apply some shear transformations, which look like this. And then, if we combine all those
together, we end up getting some pretty tough samples to learn from. But if the model can learn to recognize these digits, it will be pretty robust and should work better in our demo. So I've added those transformations to our training script, right here where we define our training data, and now we're going to retrain the model from scratch.

All right guys, I've uploaded the new model, so let's test it out and see if it's more robust. Got that one... got the eight... nine... got the one... zero. Let's see, if we even put a little zero down here... a little more confused, but it's still doing well. Two... thinks it's a three. Two... oh well, it thinks that's a seven. Okay, so it could be better, but we're just using the demo code from the PyTorch examples and adding some data augmentation to make it a bit more robust. I'll put a link to the sandbox in the description so you can try it out yourself, and I'll also put links to the other code in this video. And that's it for this one. I'll see you guys next time. [Music]
Info
Channel: Elliot Waite
Views: 28,019
Keywords: run pytorch in the browser, convert pytorch to javascript, pytorch to onnx, pytorch to onnx.js, onnx.js, onnxjs, pytorch to javascript, convert pytorch to onnx, run pytorch in an app, machine learning in the browser, neural network in the browser, learn pytorch, run pytorch in chrome, run pytorch model in the browser, convert pytorch model to javascript model, convert pytorch model to onnx model, learn onnx.js, machine learning in javascript, Elliot Waite, elliotwaite
Id: Vs730jsRgO8
Length: 22min 21sec (1341 seconds)
Published: Thu Feb 13 2020