
Tuesday, December 13, 2022

Blue Bees, Twin Neural Networks, and more!

For the past few months I’ve been working on an algorithm to automate pollen species detection. This is part of a field of research called palynology, the study of pollen grains and other spores. It usually takes a lot of time and expertise for someone to identify what plant species pollen grains came from while scanning through a microscope slide. The question I’m trying to answer is whether we can automate this process (using machine learning) for samples of pollen collected by bees at the Bernard Field Station, a restored habitat across the street from Harvey Mudd's campus.

One approach is to use Convolutional Neural Networks (CNNs), a type of machine learning model that has revolutionized the field of image recognition and classification. CNN-based algorithms such as ResNet and YOLO can identify and classify hundreds of kinds of objects in images and videos with human-like accuracy. However, CNNs are not all-powerful; their one fatal flaw is how much data they require. CNNs need enormous amounts of data to learn from (for example, ResNet was trained on the 14 million+ images in the ImageNet dataset). If you want to learn more about CNNs I would recommend this article. But at a high level, they are just really, really good at pattern recognition, and they need so much data because the pattern of “dog” or “cat” or “house” is a really complicated pattern (think about how many species/colors/styles there are for each of those).
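As a rough sketch of what “a CNN classifying an image” looks like in practice (this assumes PyTorch and torchvision, which aren’t part of the original post, and the file name is just a placeholder), classifying a single picture with a pretrained ResNet takes only a few lines:

```python
import torch
from torchvision import models
from PIL import Image

# Load a ResNet pretrained on ImageNet (14 million+ labeled images)
weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)
model.eval()

# "bee.jpg" is a placeholder file name, not a real file from this project
img = weights.transforms()(Image.open("bee.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    scores = model(img)                          # one score per ImageNet class
label = weights.meta["categories"][scores.argmax().item()]
print(label)                                     # e.g. "bee" (one of 1,000 ImageNet classes)
```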

Humans, on the other hand, only need a few examples to learn. Imagine I told you about a new species of bee you’d never heard of before: Xylocopa caerulea (a blue carpenter bee). After seeing only a few sample images, you would be able to identify it without trouble.

[1] Image of a Xylocopa caerulea (blue carpenter bee)


A CNN, on the other hand, would require tens of thousands of images of this bee before it could accurately identify it. It’s never seen a blue carpenter bee before, so it might think that the bee is a flower, or a blueberry, etc.

You might wonder: is there any way to create a neural network that can learn to recognize something without a lot of training data? The answer is yes, and the solution is Twin Neural Networks (also called Siamese Neural Networks). Twin Neural Networks (TNNs) are part of a field of machine learning called “Meta Learning”: in other words, training a network to learn how to learn. While a CNN learns to predict which class an image belongs to, a TNN learns whether two images are in the same class. For example, you could train a ResNet CNN with many images of bees and ants, give it a single picture, and it could tell you whether there is a bee or an ant in that picture. The problem is that if you give the CNN an image of a type it has never seen before (for example, a wasp), it will still try to guess whether that picture of a wasp is more similar to a bee or an ant. A ResNet TNN, on the other hand, could take two pictures (one bee and one ant) and tell you that the two pictures don’t contain the same object.

In order to do this, the TNN has to learn which features are useful for differentiating classes (wings, hairiness, colored stripes, body shape, etc.). Because a TNN learns what makes a class unique, it can even tell whether two images it’s never seen before are in the same class! If you gave it a picture of a wasp, it wouldn’t be able to tell you that it's a wasp, but it would be able to tell you that it isn’t a bee or an ant!


[2] Basic structure of a TNN


A TNN works in three main steps (sketched in code below):
  1. It takes in a pair of images (either from the same class or from different classes)
  2. It extracts the most important pieces of each image as a list of numbers (known as a feature vector)
  3. It compares the distance (for now, Euclidean distance) between the two feature vectors to determine whether the images are in the same class or not
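Here is a minimal sketch of those three steps (in PyTorch, with a toy stand-in for the CNN encoder; this is not my actual research code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwinNetwork(nn.Module):
    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder                      # the SAME encoder is applied to both images

    def forward(self, image_a, image_b):
        # Step 2: extract a feature vector from each image
        features_a = self.encoder(image_a)
        features_b = self.encoder(image_b)
        # Step 3: compare the two feature vectors (Euclidean distance for now)
        return F.pairwise_distance(features_a, features_b)

# Step 1: feed in a pair of images (random tensors here, standing in for real photos)
toy_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 128))  # stand-in for a real CNN
twin = TwinNetwork(toy_encoder)
distance = twin(torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64))
# A small distance suggests the same class; a large distance suggests different classes.
```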
For the rest of the post I’m going to focus on the big question here: how can a TNN determine what the most important aspects of an image are? In other words, how do we extract a feature vector from an image? Think of the original image and its feature vector as two different ways of representing the same information (the content of the image). We basically want a machine that can convert from the image to the feature vector.

To do this we need to step back and think about how we represent images on computers. The most common way of representing an image is with a list of every pixel in the image and its corresponding color. This image format is really easy for computers to process and for humans to look at, but it takes up a lot of space (both in physical storage and in sheer quantity of numbers: for example, a 256px x 256px RGB image has 256 × 256 × 3 = 196,608 numbers). It’s very easy for a human to interpret the content of a photo and describe it with words, but much harder for a computer, because there are so many numbers. (As a side note: we can think of these numbers as similar to the signals of the millions of cells in a human retina. We humans can’t understand an image in that format either. That’s why we need our neural networks, our brains, to extract the useful features.)
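To make this concrete, here is a tiny sketch (NumPy and Pillow, with a placeholder file name) showing that a raw RGB image really is just a big array of numbers:

```python
import numpy as np
from PIL import Image

img = Image.open("bee.jpg").convert("RGB").resize((256, 256))  # "bee.jpg" is a placeholder
pixels = np.array(img)       # shape: (256, 256, 3) -- one number per color channel per pixel

print(pixels.shape)          # (256, 256, 3)
print(pixels.size)           # 256 * 256 * 3 = 196,608 numbers
flat = pixels.flatten()      # the "list of every pixel and its corresponding color"
```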


Another way of storing an image is in a compressed format like JPG. This format takes up less physical space for storage, but is still hard for a computer to understand the contents of, and a bit of detail is lost due to compression. In order for a human to understand a JPG, it needs to first be decompressed back into a big list of pixels, and then displayed on screen.
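As a quick sketch of the difference (again NumPy and Pillow, with a placeholder file name), you can compare the size of the raw pixel array to the size of the same image compressed as a JPEG:

```python
import io
import numpy as np
from PIL import Image

img = Image.open("bee.jpg").convert("RGB")       # "bee.jpg" is a placeholder file name
raw_bytes = np.array(img).nbytes                 # size of the uncompressed list of pixels

buffer = io.BytesIO()
img.save(buffer, format="JPEG", quality=85)      # compress (a bit of detail is lost here)
jpeg_bytes = buffer.getbuffer().nbytes           # size of the compressed version

print(raw_bytes, jpeg_bytes)                     # the JPEG is usually many times smaller

# To look at or analyze the JPEG, it first has to be decompressed back into pixels:
buffer.seek(0)
decoded_pixels = np.array(Image.open(buffer))
```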


A much simpler (and probably less useful) way of representing images would be to just take the average color of the image and ignore everything else. This is very easy for both humans and computers to understand (and it’s easy to turn other image formats into this one). But you lose a lot of important information in this format, and it’s very hard to go back to the original image.
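Here’s what that looks like as a sketch (NumPy and Pillow, placeholder file): the entire image collapses down to three numbers, and the best “reconstruction” we can manage is a flat block of that one color.

```python
import numpy as np
from PIL import Image

pixels = np.array(Image.open("bee.jpg").convert("RGB"), dtype=np.float32)  # placeholder file
average_color = pixels.mean(axis=(0, 1))     # average over height and width -> [R, G, B]
print(average_color)                         # just three numbers instead of ~200,000

# Going "back" is very lossy: the best we can do is a flat image of that one color.
reconstruction = np.ones_like(pixels) * average_color
```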


The final way of representing images is by extracting the important features. This is something that CNNs are excellent at. Take this photo of a face:


[3] Face generated using a VAE


Using CNNs, we can create an algorithm that can extract the person’s skin color, hair color, eye color, how much they are smiling, which direction they are looking, etc., and represent all of these features as numbers. This is not a particularly useful way of representing images for people (unless we know exactly what each number represents and have examples), but it is great for computers. It’s even possible to take a feature vector and turn it back into an image!
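As a sketch of the general idea (using a pretrained torchvision ResNet as the feature extractor, which is a convenient stand-in rather than the exact network I use), you can chop the classification layer off a CNN, and what’s left maps an image to a feature vector:

```python
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
backbone.fc = nn.Identity()                  # remove the final classification layer
backbone.eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

img = preprocess(Image.open("face.jpg").convert("RGB")).unsqueeze(0)   # placeholder file
with torch.no_grad():
    feature_vector = backbone(img)           # shape: (1, 512) -- 512 numbers describing the image
```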


In all the image formats I mentioned above, notice how there is a way to take a list of pixel colors and convert it to that format? Let's call this an “encoder”, since it encodes a raw image into the new format. And there is also a way to go from that image format back to a list of pixels. Let's call this the “decoder” since it decodes that more compressed image format back into a raw image.

Think of the JPG example from before in these terms: the JPG compression algorithm is the encoder, and the decompression algorithm is the decoder.


Now imagine that the encoder and the decoder are both neural networks. The cool thing about neural networks is that they can basically learn anything (how well they learn it depends on how much data you have, how complicated the task is, and a bunch of other things, but they will try their best to learn whatever you ask them to). When you have an encoder and a decoder that are both neural networks, it’s not actually a TNN, but rather a Variational Autoencoder (VAE). VAEs are a whole other field of research, but they are useful to think about here.


[4] Diagram of Variational Autoencoder network


In this image, the “latent space representation” is kind of like the compressed image format. Remember what I said before about how CNNs can extract features from images and how neural networks will try to learn whatever we want them to? Well, we can tell a CNN to extract exactly 10, or 20, or 128 features from an image (this is our latent space representation), and then train another network (our decoder) to take these features and turn them back into an image!
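Here’s a minimal sketch of that encoder/decoder pair in PyTorch. To keep it short this is really a plain autoencoder; a full VAE would additionally have the encoder output a distribution (a mean and a variance) over the latent space rather than a single vector.

```python
import torch
import torch.nn as nn

LATENT_DIM = 128                      # "extract exactly 128 features"

encoder = nn.Sequential(              # image -> feature vector (latent space representation)
    nn.Conv2d(3, 32, 3, stride=2, padding=1),  nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 16 * 16, LATENT_DIM),
)

decoder = nn.Sequential(              # feature vector -> image
    nn.Linear(LATENT_DIM, 64 * 16 * 16), nn.ReLU(),
    nn.Unflatten(1, (64, 16, 16)),
    nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  nn.Sigmoid(),
)

images = torch.rand(8, 3, 64, 64)     # a fake batch of 64x64 RGB images
latent = encoder(images)              # shape: (8, 128)
reconstructed = decoder(latent)       # shape: (8, 3, 64, 64)
# Training would minimize the difference between `images` and `reconstructed`.
```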


The really cool thing here is that because our latent space representation only has a few numbers (as opposed to the hundreds of thousands in a normal RGB image), each of those numbers needs to be really important. In the example of faces, those numbers end up representing physical things like skin/hair color, facial expression, etc. Since the neural network is learning all this stuff on its own, we don’t get to control which features it decides are important. The network decides on its own which features it wants to use to make it as easy as possible to go from the feature vector back to a full image.

[6] Change in generated faces as "smile" component is modulated

0.0 smile on left transitions to 1.0 smile on right


This is a real-world example where one of the items in the feature vector represented how much the person was smiling. By leaving every other feature the same and changing just that one, we can create a bunch of images where the only difference is how much the person is smiling.
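As a sketch of how you would do that (continuing the toy encoder/decoder above, and assuming, purely for illustration, that component 42 of the feature vector happened to be the “smile” feature; in practice you would have to search the latent space to find it):

```python
import torch

SMILE_INDEX = 42                                   # hypothetical position of the "smile" feature

with torch.no_grad():
    base = encoder(images[:1])                     # feature vector for one face image
    frames = []
    for amount in torch.linspace(0.0, 1.0, steps=5):
        tweaked = base.clone()
        tweaked[0, SMILE_INDEX] = amount           # change only the "smile" number...
        frames.append(decoder(tweaked))            # ...and decode back into an image
```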


To reinforce the example of a feature vector, let’s think about pollen (which is what my research is trying to classify) and go back to TNNs. Remember that a TNN is just the “encoder” section of a VAE.

[7] Various pollen species with hand labeled attributes

This is an example of what a TNN might decide the important features are. Remember again that it gets to learn or “pick” which ones are important (and it naturally picks good ones, because that’s the only way for it to accurately tell the difference between images).


In order to figure out whether two images contain the same object or not, a TNN takes two images, turns them into feature vectors, and then compares the Euclidean distance between the two feature vectors to see how similar the images are.


[8]


That isn’t quite enough to tell what class an image belongs to, though. In order to do that, we need something to compare against. Imagine that we have a picture of pollen from an unknown species, and a bunch of pictures of pollen from known species. We can generate feature vectors for all our images, known and unknown, and see which known image the unknown image is closest to. Then we can assume that the unknown image is in that class, or from that species of plant.
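Here’s a sketch of that comparison step (PyTorch, with random tensors standing in for real pollen images, made-up species names, and a placeholder encoder standing in for a trained TNN encoder):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

encode = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 128))  # placeholder for a trained encoder

known_images = torch.rand(5, 3, 64, 64)        # one reference image per known species
known_species = ["oak", "sage", "buckwheat", "mustard", "poppy"]   # hypothetical labels
unknown_image = torch.rand(1, 3, 64, 64)       # the pollen grain we want to identify

with torch.no_grad():
    known_features = encode(known_images)       # feature vector for each known image
    unknown_features = encode(unknown_image)    # feature vector for the unknown image

distances = F.pairwise_distance(unknown_features, known_features)  # one distance per species
prediction = known_species[distances.argmin().item()]
print(prediction)                               # closest known species = our best guess
```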


There are a bunch of other really cool algorithms worth checking out, as well as variations of TNNs. You might wonder how we pick which images to compare our unknown image to, or how we calculate the distance between two images (hint: you can do a lot better than just Euclidean distance!). All of these questions have super interesting answers, and I would encourage you to read more using the links below.


Further Reading:


Media Credits:
[2] Diagram created by Kavi Dey
[3] White, T. "Sampling Generative Networks: Notes on a Few Effective Techniques." arXiv preprint arXiv:1609.04468 (2016).
[5] White, T. "Sampling Generative Networks: Notes on a Few Effective Techniques." arXiv preprint arXiv:1609.04468 (2016).
[6] Diagram created by Kavi Dey
[7] Diagram created by Kavi Dey
