Explaining the Segment Anything Model - Network architecture, Dataset, Training

Captions
Hello guys, meet SAM, Meta's new Segment Anything Model, which can literally segment anything. SAM is the world's first massive-scale, promptable, interactive foundation model for image segmentation. You can upload a new image and prompt SAM with clicks and bounding boxes, and the model is able to segment an arbitrary number of objects at multiple levels of granularity: the whole object, a part of it, or even a subpart of it.

Acquiring a dataset that is large enough to train a foundation image segmentation model is a challenging task, because even though there are a lot of images on the internet, they are not labeled with segmentation masks. This paper not only comes up with a robust network architecture to create a truly interactive AI experience; the authors also open-source the largest image segmentation dataset ever released, with 11 million images and over 1 billion masks.

By the way, if you're finding me for the first time, I am a new YouTuber, and I make videos on deep learning research as well as Unity devlogs about using reinforcement learning to train competitive adversarial AI behavior. If you like the content on this channel, please consider subscribing, leave a like, and share it within your community; it helps my channel grow way more than you can imagine. You're magnificent. There's a lot to cover in this paper, so let's jump into the second episode of Neural Breakdown.

First, we have an image encoder that takes in a 1024x1024 image and embeds it into a 64 by 64 feature map with 256 channels. A prompt encoder then embeds all the prompts, which can come in the form of multiple point clicks, bounding boxes, and text, into a list of vector embeddings. The mask decoder then combines the image embeddings and the prompt embeddings and outputs three segmentation masks. These three masks capture segments at three nested levels of depth: the whole, the part, and the subpart. Allowing the model to output multiple masks lets it produce valid segments even when the input
prompt is ambiguous. In the case of the scissors, for example, the three masks the model returns can correspond to the whole pair of scissors, the two handles, or just one handle.

To calculate the loss, the authors first compute the IoU, or intersection over union, of each predicted mask with respect to the ground truth. The mask with the highest IoU is then used to compute the loss. They use a combination of two losses to train the mask output: the focal loss and the dice loss. The dice loss is similar to an IoU loss in the sense that it also measures how well the predicted segment overlaps with the target segment. The focal loss is a modification of the cross-entropy loss that assigns higher weights to examples the model is performing poorly on and lowers the weights of easier examples that the model already handles well.

SAM additionally predicts an IoU score for each of the three masks. This is simply trained with a mean squared error loss between the predicted IoU and the actual IoU computed between each mask and the ground truth. This is super useful because it lets SAM output its own confidence score, which can be used to rank all the masks it predicts and also to detect when SAM is unconfident about all of its predicted masks.

Okay, now here is something really interesting about this architecture: the image embeddings and the prompt embeddings are not conditioned on each other. They come from separate, independent heads of the network, and all of their interactions happen only inside the mask decoder. This means two things. First, the image embedding network only gets to look at the image, without any context from the prompt, so it has to generate an embedding that is rich and expressive enough to produce valid masks for any arbitrary prompt the user may supply. Second, since the image embedding is the same for every prompt, we can pre-compute it at the
start, when a new image is loaded, cache it, and keep reusing it for all future user prompts. The ability to cache image embeddings is a game changer. The authors use a very heavy Vision Transformer-based image encoder containing hundreds of millions of parameters, which they run on a cloud-based GPU server once every time a new image is loaded. The rest of the architecture consists of relatively small networks that need to run at interactive latency for every user input; these can run right in the browser and keep reusing the cached image embedding.

So imagine you already have a dataset of images and their annotations. How do you then train a promptable model that outputs valid masks given a user prompt? The authors adopt an interactive, iterative strategy. Given an image and a segmentation target, they first randomly create a prompt. This could be a point prompt, which they simulate by randomly choosing a point inside the ground truth mask, or a bounding box prompt, which they create by adding random jitter to the corners of the ground truth mask's bounding box. Next, they feed the image and the random prompt forward through the network and output the three masks. For the mask that has the highest IoU, they extract its error region, and a new point is sampled randomly from this error region and fed back into the network as a new point prompt. If the newly sampled point comes from a false positive region, it is input as a background point; if it comes from a false negative region, it is input as a foreground point. In the second round, along with this newly sampled point prompt, they also, crucially, pass the best mask output by the network in round one as a dense prompt, and then predict a new mask for round two. This loop of sampling new points, generating a new mask, and feeding it back in the next round as a dense prompt continues for a number of rounds. Each iteration guides the model towards
better segmentations by constantly feeding in the previous output as a dense guiding signal to improve the next predicted mask.

In the interest of time, I'm not going to dive deep into how the dataset is generated, but at a high level the data generation process is divided into three stages. In the earlier stages, a team of hired professional annotators labels segments with the help of early versions of SAM. In the final stage, everything is fully automatic: SAM, trained on the roughly 300,000 images with 10.2 million masks collected so far, goes on to annotate a new dataset of 11 million images and 1.1 billion masks. To segment an entire image, SAM is prompted with a 32 by 32 regular grid of points; only the stable, high-confidence masks are kept, and duplicate masks are removed using non-maximum suppression. This auto-generated dataset is now open-sourced as the SA-1B dataset and is available for future research, which is really, really exciting.

Now let me dive deeper into the model architecture and explain how each of these blocks works together. First up, the image encoder. Its input is a 1024x1024 image, which gets mapped to a 64x64 feature map with 256 channels. They use a Vision Transformer pre-trained as a masked autoencoder (MAE), a technique introduced in an earlier paper. At a really high level, these networks are pre-trained as follows. The input image is divided into patches, and most of the patches (75 percent in the original MAE setup) are masked out at random. The remaining unmasked patches are passed through an encoder Vision Transformer, which outputs an embedding for each unmasked patch. The decoder then takes the encoder's embeddings of the unmasked patches, together with mask tokens carrying the positional encodings of the masked patches, and reconstructs the original image. It is pretty crazy that masked autoencoders can reconstruct good images from just a quarter of the original patches, and this is one of the most robust, state-of-the-art techniques for deriving
generalized image embeddings.

Next up, let's talk about the mask input. The input mask is also mapped to a 64 by 64 embedding with 256 channels, the same shape as the image embedding, through two strided convolutional layers, and this embedding is added to the image embedding before it is passed into the mask decoder. The model then encodes each of the sparse prompts. Each clicked point is encoded as the positional encoding of its location summed with one of two learned embeddings that marks it as a foreground or a background point, that is, whether the point came from a false negative or a false positive error region, as discussed earlier. Bounding boxes are encoded with two prompt tokens: the first is the sum of the positional encoding of the top-left corner and a learned embedding representing the top-left corner, and the second is the same thing for the bottom-right corner. For text, they use the text encoder of a pre-trained CLIP model. CLIP, introduced by OpenAI, is trained to take an image and one or more text captions and output a similarity score for each image-caption pair, using a technique known as contrastive learning.

Now it's the mask decoder's turn. First, all prompt token embeddings are stacked into a 2D matrix, and three new output token embeddings are appended to them. These output tokens are there to store the final representation of the entire prompt with respect to each of the three masks. Inside the decoder, three types of attention are used. First, a self-attention layer is applied to the prompt embeddings, which enriches each prompt embedding to be more context-aware about all the other prompt embeddings in the list. Second, these updated prompt embeddings attend to the image embedding through cross-attention. Third, the image embedding is updated with cross-attention to the latest prompt embeddings. The whole layer is repeated twice, with the second layer's input being the final
context-aware embeddings of the first layer. Every time the prompt embeddings are updated, whether through self-attention or cross-attention, the original input prompt embeddings are added back to the updated embeddings. This introduces a stronger inductive bias and makes sure the model never forgets the geometric properties of the user's original prompt. In SAM, the authors use cross-attention between the image and prompt embeddings twice: once with the prompt embeddings as the queries, and once with the image embeddings as the queries. What this achieves is that it mixes the prompt embeddings and the image embeddings by making them contextually aware of each other, which compensates for the lack of communication between the two in the outer parts of the network, the image encoder and the prompt encoder.

So after the two layers of all this attention, we arrive at a context-aware image embedding and context-aware prompt embeddings. The image embedding is then upsampled by a factor of four, from a 64 by 64 to a 256 by 256 feature map. From the final prompt embeddings, we extract the embeddings at the indices of the three output tokens we placed in the input. These three output embeddings are each passed through an MLP that maps them to the same dimension as the number of channels in the upsampled image embedding. Finally, at every pixel, we take the dot product between the upsampled image embedding and each of the three output vectors, which gives the three final masks. A separate MLP head on the output tokens predicts the IoU score for each predicted mask. As discussed earlier, the masks are trained with a combination of focal and dice loss against the best of the three masks, while the IoU output is trained with a mean squared error loss. The authors also show how well this model does on zero-shot tasks such as single-point mask evaluation on other publicly available datasets,
edge detection, object proposals, and instance segmentation, and SAM seems to do a great job at each of them without needing to retrain its weights on these downstream tasks. That is amazing. With that, I'm going to end this video. If you liked it, don't forget to share it with others and subscribe for my next one. Thanks for watching till the end. Bye!
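The multi-mask loss described in the video can be sketched in plain Python. This is a minimal sketch under assumptions, not the authors' training code: masks are tiny nested lists of logits, the focal-loss `alpha=0.25, gamma=2.0` defaults are the commonly used ones and are assumed here, and the 20:1 focal-to-dice weighting is the ratio I recall the SAM paper reporting. Following the video, the mask with the highest IoU is the one that gets supervised; the paper phrases this as backpropagating only the lowest of the per-mask losses, which will often select the same mask.

```python
import math
from typing import List, Sequence, Tuple

Grid = List[List[float]]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def focal_loss(logits: Grid, target: Grid, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Mean per-pixel sigmoid focal loss: cross-entropy down-weighted by (1 - p_t) ** gamma."""
    total, n = 0.0, 0
    for logit_row, target_row in zip(logits, target):
        for z, y in zip(logit_row, target_row):
            p = sigmoid(z)
            p_t = p if y == 1 else 1.0 - p  # probability assigned to the true class
            alpha_t = alpha if y == 1 else 1.0 - alpha
            total += -alpha_t * (1.0 - p_t) ** gamma * math.log(max(p_t, 1e-12))
            n += 1
    return total / n


def dice_loss(logits: Grid, target: Grid, eps: float = 1.0) -> float:
    """1 - soft Dice coefficient between sigmoid probabilities and the target mask."""
    inter = pred_sum = target_sum = 0.0
    for logit_row, target_row in zip(logits, target):
        for z, y in zip(logit_row, target_row):
            p = sigmoid(z)
            inter += p * y
            pred_sum += p
            target_sum += y
    return 1.0 - (2.0 * inter + eps) / (pred_sum + target_sum + eps)


def hard_iou(logits: Grid, target: Grid) -> float:
    """IoU after thresholding the logits at 0 (probability 0.5)."""
    inter = union = 0
    for logit_row, target_row in zip(logits, target):
        for z, y in zip(logit_row, target_row):
            pred, gt = z > 0, y == 1
            inter += pred and gt
            union += pred or gt
    return inter / union if union else 1.0


def multimask_loss(
    mask_logits: Sequence[Grid],
    predicted_ious: Sequence[float],
    target: Grid,
    focal_weight: float = 20.0,
    dice_weight: float = 1.0,
) -> Tuple[float, int]:
    """Supervise only the best candidate mask; regress every IoU prediction with MSE."""
    actual_ious = [hard_iou(m, target) for m in mask_logits]
    best = max(range(len(mask_logits)), key=lambda i: actual_ious[i])
    mask_term = focal_weight * focal_loss(mask_logits[best], target) + dice_weight * dice_loss(
        mask_logits[best], target
    )
    iou_term = sum((p - a) ** 2 for p, a in zip(predicted_ious, actual_ious)) / len(actual_ious)
    return mask_term + iou_term, best
```

The actual IoU used as the regression target is computed from thresholded masks, so it is not differentiable; that is fine here because it only serves as a label for the IoU head.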
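The caching argument from the video boils down to a simple pattern: run the heavy encoder once per image, and pay only for the light decoder on every click. In this sketch, `CachedSegmenter`, `encode_image`, and `decode` are hypothetical names; the two callables stand in for SAM's ViT image encoder and for its prompt encoder plus mask decoder.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Embedding = List[float]


class CachedSegmenter:
    """Runs the expensive image encoder once per image and reuses the result for every prompt."""

    def __init__(
        self,
        encode_image: Callable[[str], Embedding],
        decode: Callable[[Embedding, Sequence[Tuple[float, float]]], str],
    ):
        self.encode_image = encode_image  # heavy: e.g. a large ViT on a GPU server
        self.decode = decode  # light: prompt encoder + mask decoder, run per click
        self._cache: Dict[str, Embedding] = {}
        self.encoder_calls = 0

    def embedding(self, image_id: str) -> Embedding:
        if image_id not in self._cache:
            self.encoder_calls += 1
            self._cache[image_id] = self.encode_image(image_id)
        return self._cache[image_id]

    def segment(self, image_id: str, points: Sequence[Tuple[float, float]]) -> str:
        # Every new prompt reuses the cached embedding instead of re-encoding the image.
        return self.decode(self.embedding(image_id), points)
```

However many prompts arrive for the same image, `encoder_calls` only grows when a new image is loaded.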
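The interactive training loop from the video can be simulated on a toy grid. `toy_model` is a hypothetical stand-in for SAM that starts from the previous mask (the dense prompt) and applies every click, so this sketch shows the structure of the loop rather than a real model. As described in the video, clicks are sampled from the error region: false negatives become foreground clicks and false positives become background clicks. The box-jitter noise scale (`noise_frac`) is an assumption, and the number of rounds is arbitrary.

```python
import random
from typing import List, Optional, Tuple

Mask = List[List[int]]
Point = Tuple[Tuple[int, int], int]  # ((row, col), label) with label 1 = foreground, 0 = background


def jittered_box(gt: Mask, rng: random.Random, noise_frac: float = 0.1) -> Tuple[float, float, float, float]:
    """Ground-truth box (x0, y0, x1, y1) with Gaussian noise; noise_frac is an assumed scale."""
    rows = [r for r, row in enumerate(gt) if any(row)]
    cols = [c for c in range(len(gt[0])) if any(row[c] for row in gt)]
    x0, x1, y0, y1 = min(cols), max(cols), min(rows), max(rows)
    sx, sy = noise_frac * (x1 - x0 + 1), noise_frac * (y1 - y0 + 1)
    return (x0 + rng.gauss(0, sx), y0 + rng.gauss(0, sy), x1 + rng.gauss(0, sx), y1 + rng.gauss(0, sy))


def sample_correction(pred: Mask, gt: Mask, rng: random.Random) -> Optional[Point]:
    """Sample a new click uniformly from the error region of `pred`."""
    candidates: List[Point] = []
    for r, (pred_row, gt_row) in enumerate(zip(pred, gt)):
        for c, (p, g) in enumerate(zip(pred_row, gt_row)):
            if g == 1 and p == 0:
                candidates.append(((r, c), 1))  # false negative -> foreground click
            elif g == 0 and p == 1:
                candidates.append(((r, c), 0))  # false positive -> background click
    return rng.choice(candidates) if candidates else None


def toy_model(points: List[Point], prev_mask: Optional[Mask], shape: Tuple[int, int]) -> Mask:
    """Hypothetical stand-in for SAM: start from the dense prompt, then apply every click."""
    h, w = shape
    mask = [row[:] for row in prev_mask] if prev_mask else [[0] * w for _ in range(h)]
    for (r, c), label in points:
        mask[r][c] = label
    return mask


def simulate_interactive_training(gt: Mask, rounds: int, rng: random.Random) -> Tuple[Mask, List[Point]]:
    shape = (len(gt), len(gt[0]))
    foreground = [(r, c) for r in range(shape[0]) for c in range(shape[1]) if gt[r][c] == 1]
    points: List[Point] = [(rng.choice(foreground), 1)]  # round 1: a random point inside the mask
    mask = toy_model(points, None, shape)
    for _ in range(rounds - 1):
        click = sample_correction(mask, gt, rng)
        if click is None:  # the prediction already matches the ground truth
            break
        points.append(click)
        # The previous mask is fed back in as the dense prompt for the next round.
        mask = toy_model(points, mask, shape)
        # In real training, a loss would be computed on `mask` at every round.
    return mask, points
```

Because every click is taken from the current error region, each round targets exactly the pixels the previous prediction got wrong.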
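The fully automatic annotation stage prompts SAM with a regular grid of points and then filters and deduplicates the resulting masks. This sketch shows those pieces on toy data under stated assumptions: masks are sets of pixel coordinates, stability is measured in the spirit of the paper as the IoU between the masks obtained by thresholding the logits slightly above and slightly below zero, and the `offset` and `iou_thresh` defaults are illustrative rather than the values used to build SA-1B. All function names are hypothetical.

```python
from typing import List, Set, Tuple

Pixels = Set[Tuple[int, int]]


def point_grid(n: int) -> List[Tuple[float, float]]:
    """n x n prompt points as normalized (x, y) in [0, 1], one at the center of each cell."""
    offset = 1.0 / (2 * n)
    return [(offset + j / n, offset + i / n) for i in range(n) for j in range(n)]


def stability_score(logits: List[List[float]], offset: float = 1.0) -> float:
    """IoU between the masks thresholded at +offset and -offset.

    The strict mask is a subset of the loose one, so the IoU is just the ratio of their areas.
    A score near 1 means the mask barely changes when the threshold moves.
    """
    strict = sum(z > offset for row in logits for z in row)
    loose = sum(z > -offset for row in logits for z in row)
    return strict / loose if loose else 0.0


def mask_iou(a: Pixels, b: Pixels) -> float:
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def mask_nms(masks: List[Pixels], scores: List[float], iou_thresh: float = 0.7) -> List[int]:
    """Greedy non-maximum suppression: keep the highest-scoring masks, drop near-duplicates."""
    order = sorted(range(len(masks)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(mask_iou(masks[i], masks[k]) <= iou_thresh for k in keep):
            keep.append(i)
    return keep
```

A 32 by 32 grid gives 1,024 point prompts per image. Filtering on the model's own predicted IoU would be one more threshold check applied before `mask_nms`.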
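The masked-autoencoder pre-training described in the video starts with a random split of the image patches into visible and masked sets. Below, `num_patches`, `random_patch_masking`, and `masked_mse` are hypothetical helpers written for illustration. With 16x16 patches, a 1024x1024 input gives a 64x64 patch grid, which matches the 64 by 64 spatial size of SAM's image embedding. The 75 percent mask ratio is the MAE default, and, as in MAE, the reconstruction loss only counts the masked patches.

```python
import random
from typing import List, Tuple


def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    return (image_size // patch_size) ** 2


def random_patch_masking(num: int, mask_ratio: float, rng: random.Random) -> Tuple[List[int], List[int]]:
    """Split patch indices into (visible, masked), masking `mask_ratio` of them at random."""
    num_visible = round(num * (1.0 - mask_ratio))
    order = list(range(num))
    rng.shuffle(order)
    visible = sorted(order[:num_visible])  # only these go through the ViT encoder
    masked = sorted(order[num_visible:])  # the decoder sees mask tokens + positions for these
    return visible, masked


def masked_mse(pred: List[List[float]], target: List[List[float]], masked: List[int]) -> float:
    """Mean squared reconstruction error, averaged over the masked patches only."""
    per_patch = [
        sum((p - t) ** 2 for p, t in zip(pred[i], target[i])) / len(target[i]) for i in masked
    ]
    return sum(per_patch) / len(per_patch)
```

Only a quarter of the patches pass through the encoder, which is a large part of what makes this pre-training cheap for big ViTs.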
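The sparse-prompt encoding described in the video, a positional encoding of the location plus an embedding for the prompt type, can be sketched as below. The positional encoding here uses random Fourier features (sine and cosine of a fixed random projection of the coordinates); I believe the released SAM code takes a similar approach, but treat that as an assumption. `ToyPromptEncoder` and its embedding names are hypothetical, and the type embeddings are random placeholders for what would be learned parameters in a trained model.

```python
import math
import random
from typing import Dict, List

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


class ToyPromptEncoder:
    """Encodes points and boxes as positional encoding + prompt-type embedding."""

    def __init__(self, dim: int = 8, seed: int = 0):
        if dim % 2:
            raise ValueError("dim must be even (half sine, half cosine features)")
        rng = random.Random(seed)
        self.half = dim // 2
        # Fixed random Gaussian projection used for the Fourier-feature positional encoding.
        self.proj = [[rng.gauss(0.0, 1.0) for _ in range(self.half)] for _ in range(2)]
        # Random placeholders standing in for learned prompt-type embeddings.
        self.type_emb: Dict[str, Vector] = {
            name: [rng.gauss(0.0, 1.0) for _ in range(dim)]
            for name in ("background", "foreground", "box_top_left", "box_bottom_right")
        }

    def positional(self, x: float, y: float) -> Vector:
        """Encode normalized coordinates (x, y) in [0, 1] with sin/cos of random projections."""
        cx, cy = 2.0 * x - 1.0, 2.0 * y - 1.0
        angles = [2.0 * math.pi * (cx * self.proj[0][k] + cy * self.proj[1][k]) for k in range(self.half)]
        return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

    def point(self, x: float, y: float, foreground: bool) -> Vector:
        kind = "foreground" if foreground else "background"
        return add(self.positional(x, y), self.type_emb[kind])

    def box(self, x0: float, y0: float, x1: float, y1: float) -> List[Vector]:
        """A box becomes two tokens, one per corner, each with its own corner embedding."""
        return [
            add(self.positional(x0, y0), self.type_emb["box_top_left"]),
            add(self.positional(x1, y1), self.type_emb["box_bottom_right"]),
        ]
```

The same location yields different tokens for a foreground and a background click, which is how the decoder can tell "include this" from "exclude this".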
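The last step of the mask decoder, upsampling the image embedding and taking a per-pixel dot product with each output vector, can be sketched with toy sizes instead of SAM's 64 by 64 embedding upsampled to 256 by 256. Nearest-neighbour upsampling is a swapped-in simplification of SAM's learned transposed convolutions, and the output vectors are given directly rather than produced by an MLP; `upsample_nearest` and `masks_from_tokens` are hypothetical names.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # [channels][height][width]


def upsample_nearest(fmap: FeatureMap, factor: int) -> FeatureMap:
    """Nearest-neighbour upsampling: repeat every row and column `factor` times."""
    return [
        [[row[j // factor] for j in range(len(row) * factor)] for row in channel for _ in range(factor)]
        for channel in fmap
    ]


def masks_from_tokens(image_emb: FeatureMap, output_vectors: List[List[float]]) -> List[List[List[float]]]:
    """One mask-logit map per output vector: dot product with the channel vector at every pixel."""
    channels, height, width = len(image_emb), len(image_emb[0]), len(image_emb[0][0])
    masks = []
    for vec in output_vectors:
        if len(vec) != channels:
            raise ValueError("output vector size must match the embedding's channel count")
        masks.append([
            [sum(vec[c] * image_emb[c][i][j] for c in range(channels)) for j in range(width)]
            for i in range(height)
        ])
    return masks
```

The dot product is why the MLP has to map each output token to exactly the channel count of the upsampled embedding: each token then acts as a per-mask linear classifier applied at every pixel.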
Info
Channel: Neural Breakdown with AVB
Views: 14,642
Id: OhxJkqD1vuE
Length: 13min 2sec (782 seconds)
Published: Tue May 02 2023