TV Tokyo

Who’s That Pokémon?! Building a Pokémon Identifier In Keras

It’s Clefairy! Wait-

Lincoln Wentz
9 min read · Nov 9, 2021


Project Overview

This project is an attempt to build an image classifier that can identify any first-generation pokémon.

In the pokémon anime, the main character, Ash, has a pokédex: a device that can identify any pokémon it is pointed at and recite information about it. When the show first aired in 1997, any consumer pokédex was little more than a prop that could say a few phrases. But in 2021, with all of our technology, it might actually be possible to make the pokédex a reality.

To do that, we need a model capable of identifying pokémon, and that is the focus of this article. We’ll be working with a dataset from Kaggle containing 7,000 images of almost every pokémon from the first generation. The model will take in these images and output a classification corresponding to the name of the pokémon in question.

Problem Statement

The model will take a single image of size 224 x 224 and output a single-class classification for that image, which should correspond to the name of a pokémon. The model will be trained on limited resources and will have only 30 epochs to train, so the design will have to account for those limitations as well.
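In practice, the "name" comes out of the model as the index of its largest softmax output. When the data is loaded with flow_from_directory, as in the preprocessing code later in this article, the iterator's class_indices attribute maps each class folder name to that index. A minimal sketch of decoding a prediction, assuming such a mapping; decode_prediction and the three-class mapping are hypothetical, for illustration only:

```python
import numpy as np


def decode_prediction(probs, class_indices):
    """Return the class name with the highest softmax probability.

    probs: 1-D array of softmax outputs, one entry per class.
    class_indices: dict mapping class name -> integer index, e.g. the
        class_indices attribute of a Keras DirectoryIterator.
    """
    # Invert the name -> index mapping so names can be looked up by index.
    index_to_name = {index: name for name, index in class_indices.items()}
    return index_to_name[int(np.argmax(probs))]


# Hypothetical three-class mapping and model output.
class_indices = {"Bulbasaur": 0, "Charmander": 1, "Squirtle": 2}
probs = np.array([0.1, 0.7, 0.2])
print(decode_prediction(probs, class_indices))  # Charmander
```

With the real model, probs would be one row of model.predict(...) and class_indices would cover all 149 classes.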

Metrics

Model performance will be measured in two ways. The first is accuracy. Accuracy makes sense because we’re not trying to minimize one specific kind of error, such as false positives or false negatives (situations where precision or recall would be more appropriate); we just want to measure the general error rate.
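The CategoricalAccuracy metric used in the Keras code below counts a prediction as correct when the argmax of the predicted probabilities matches the argmax of the one-hot label. A small NumPy sketch of the same computation, useful for sanity-checking numbers outside Keras; the arrays are made-up examples:

```python
import numpy as np


def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose highest-probability class matches the label.

    y_true: one-hot labels, shape (n_samples, n_classes).
    y_pred: predicted probabilities, same shape.
    """
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
y_pred = np.array([[0.8, 0.1, 0.1],   # correct
                   [0.3, 0.6, 0.1],   # correct
                   [0.5, 0.2, 0.3],   # wrong: predicts class 0, label is 2
                   [0.2, 0.7, 0.1]])  # correct
print(categorical_accuracy(y_true, y_pred))  # 0.75
```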

The second measure is more novelty than practical, but I felt it was worth including. Once completed, the model will play “Who’s That Pokémon?!”: we will show it the silhouettes of 9 pokémon and see how many it can correctly guess. We will also show it the full-color versions of those same pokémon for a more reasonable test of its abilities.

Data Exploration & Visualization

Example from training data for Oddish

The data consists of 7,000 images of 149 different first-generation pokémon. These images come from various sources, including the anime itself, the games, fan art, and even plushies, as shown above. This wide array of sources should help the model generalize better to unseen data.

The dataset has a couple of complications that will need to be dealt with. First, there is no standard image size in this dataset, so preprocessing will also need to resize the images to make them compatible with the model. Second, the data is imbalanced: some classes have more images than others, so this will also need to be accounted for.

Lastly, Nidoran and Nidoran are missing, and Alolan Sandslash is present. The former is not a typo: there are in fact two pokémon named Nidoran, the male and female versions of the same pokémon, and unfortunately I did not collect data to add them to the dataset. The latter, Alolan Sandslash, is not from the first generation, so I removed that class altogether.
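To see how imbalanced the classes are, it is enough to count the images in each class folder. A minimal sketch, assuming the dataset is stored with one subdirectory per class (the layout flow_from_directory expects) and that images use the extensions listed; count_images_per_class and IMAGE_SUFFIXES are hypothetical names:

```python
from collections import Counter
from pathlib import Path

# Extensions treated as images (an assumption; adjust to match the dataset).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Count image files in each class subdirectory of root.

    Assumes one subdirectory per class, the layout that
    ImageDataGenerator.flow_from_directory expects.
    """
    counts = Counter()
    for class_dir in Path(root).iterdir():
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


# Example usage ('directory' is a placeholder dataset path):
# counts = count_images_per_class('directory')
# print(counts.most_common(5))      # five largest classes
# print(counts.most_common()[-5:])  # five smallest classes
```

The smallest count found this way is also the per-class cap used for the train partition in the next section.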

Data Preprocessing

First, before loading the data, the images were split into train and test partitions. Because the data was imbalanced, the number of images per class in the train partition was capped at the size of the smallest class. This left around half of the data in the train partition, while the remainder went to the test partition.
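The splitting code itself isn't shown in the article. Here is a minimal pure-Python sketch of the capped split described above, assuming the file paths have already been grouped by class; capped_split is a hypothetical helper, and copying the files into train/test folders is left out:

```python
import random


def capped_split(files_by_class, seed=0):
    """Split each class so every class has the same number of training files.

    files_by_class: dict mapping class name -> list of file paths.
    The per-class training count is capped at the size of the smallest
    class; every remaining file goes to the test partition.
    """
    rng = random.Random(seed)
    cap = min(len(files) for files in files_by_class.values())
    train, test = {}, {}
    for name, files in files_by_class.items():
        shuffled = list(files)
        rng.shuffle(shuffled)  # random choice of which files go to train
        train[name] = shuffled[:cap]
        test[name] = shuffled[cap:]
    return train, test


files = {"Oddish": ["o1", "o2", "o3", "o4"], "Mew": ["m1", "m2"]}
train, test = capped_split(files)
print({k: len(v) for k, v in train.items()})  # {'Oddish': 2, 'Mew': 2}
print({k: len(v) for k, v in test.items()})   # {'Oddish': 2, 'Mew': 0}
```

This balances the train partition at the cost of leaving the test partition imbalanced, which is why accuracy on the test set is still a fairly coarse measure.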

Data post-augmentation

To construct the train and test datasets, images were sampled with replacement from the train and test partitions. These images were then augmented by randomly altering their vertical and horizontal shift, zoom, and brightness, and some were also flipped horizontally. This added variance lets us artificially extend our dataset and helps the model generalize better to unseen data.

# Imports used by this snippet.
import numpy as np
from tensorflow import keras

# Define ImageDataGenerator with custom augmentation parameters.
data_gen = keras.preprocessing.image.ImageDataGenerator(
    horizontal_flip = True,
    height_shift_range = 0.05,
    width_shift_range = 0.05,
    zoom_range = 0.2,
    brightness_range = [0.9, 1.5],
    rescale = 1./255.)
# Initialize flow from the training data directory. Every image is resized
# to 224 x 224, and labels are one-hot encoded (class_mode='categorical'
# is the default).
flow = data_gen.flow_from_directory('directory',
                                    target_size = (224, 224))
# Initialize loop variables.
X = []
Y = []
size = 5000
# Draw augmented batches until we have the desired number of images.
while len(X) < size:
    next_images, next_labels = next(flow)
    X.extend(next_images)
    Y.extend(next_labels)
# Cut off any extra unwanted images.
X = X[:size]
Y = Y[:size]
# Convert to numpy arrays for model compatibility.
X = np.array(X)
Y = np.array(Y)

The images were also resized to a standard 224 x 224 (via the target_size argument) to ensure compatibility with the model. In total, there were 5000 images in the train set and 2000 images in the test set.

Implementation & Refinement

The first model I built was a simple convolutional neural network consisting, in its entirety, of two convolutional layers and two dense layers, including the output layer. Once I verified that the model worked, I began adding layers and tuning parameters until the model reached the structure outlined below.

from tensorflow import keras

# Initialize model.
model = keras.models.Sequential()
# Convolutional Layers
model.add(keras.layers.Conv2D(32, (3, 3), activation = 'relu',
                              input_shape = (224, 224, 3)))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Conv2D(32, (3, 3), activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Conv2D(64, (3, 3), activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Conv2D(64, (3, 3), activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Conv2D(128, (3, 3), activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Conv2D(128, (3, 3), activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
# Dense Layers
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(128, activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.Dense(128, activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
model.add(keras.layers.Dense(128, activation = 'relu'))
model.add(keras.layers.Dropout(0.1))
# Output Layer: one unit per class (149 pokémon).
model.add(keras.layers.Dense(149, activation = 'softmax'))
# Compile model.
model.compile(optimizer = keras.optimizers.Adam(0.0001),
              loss = keras.losses.CategoricalCrossentropy(),
              metrics = [keras.metrics.CategoricalAccuracy()])
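Because the convolutions use Keras's default 'valid' (unpadded) padding, each 3x3 convolution trims two pixels from each side length, and each 2x2 max pool halves it, rounding down. A short sketch tracing the spatial size through the scratch model above, to see what the Flatten layer actually receives:

```python
def conv_output_size(size, kernel=3):
    """Spatial size after a 'valid' (unpadded) convolution with stride 1."""
    return size - kernel + 1


def pool_output_size(size, pool=2):
    """Spatial size after non-overlapping 'valid' max pooling (rounds down)."""
    return size // pool


size = 224
# Five Conv2D + MaxPooling2D pairs, then a final Conv2D with no pooling.
for _ in range(5):
    size = pool_output_size(conv_output_size(size))
size = conv_output_size(size)

flattened = size * size * 128  # the last Conv2D layer has 128 filters
print(size, flattened)  # 3 1152
```

So the first Dense(128) layer takes 1152 inputs and has 1152 x 128 + 128 = 147,584 parameters; adding more conv/pool pairs would shrink the 3x3 feature map below what another 3x3 convolution can process.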

This model’s performance was somewhat promising, but resource constraints meant I couldn’t increase its complexity much further. So I turned to an alternative solution: transfer learning.

To build the transfer learning model, I loaded the pretrained VGG16 model from the Keras applications module, removed its top (classification) layers, and replaced them with two dense layers that produce the desired classifications. After a bit more tuning, including some alterations to the layers I added and changes to the initial learning rate, this model achieved much better results.

from tensorflow import keras

# Model property shortcuts.
l2 = keras.regularizers.l2(0.00001)  # defined but not applied to any layer below
adam = keras.optimizers.Adam(0.0001)
output_count = 149   # one output per pokémon class
dropout_rate = 0.1   # assumed value; the original snippet does not define it
# Initialize pretrained model without its top (classification) layers.
base_model = keras.applications.VGG16(
    weights = 'imagenet',
    include_top = False,
    input_shape = (224, 224, 3))
# Freeze the pretrained weights so only the new layers are trained.
base_model.trainable = False

# Initialize model.
model = keras.models.Sequential()

# Base
model.add(base_model)
# Dense Layers
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(256, activation = 'relu'))
model.add(keras.layers.Dropout(dropout_rate))

# Output Layer
model.add(keras.layers.Dense(output_count, activation = 'softmax'))
# Compile model.
model.compile(optimizer = adam,
              loss = 'categorical_crossentropy',
              metrics = [keras.metrics.CategoricalAccuracy()])

Additionally, each model was given two callbacks. Callbacks let you hook into a model’s fit method and change its behavior during training. In this case, the two callbacks saved the best model found during training and reduced the learning rate whenever validation loss hadn’t improved for a certain number of epochs, or “plateaued”.

from tensorflow import keras

# Define callbacks.
# Save the model with the best validation accuracy to the provided path.
checkpoint = keras.callbacks.ModelCheckpoint(
    'save_path',
    monitor = 'val_categorical_accuracy',
    mode = 'max',
    save_best_only = True)
# Multiply the learning rate by 0.2 once the validation loss has stopped
# improving for 10 epochs.
learning_rate_reducer = keras.callbacks.ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.2,
    patience = 10)
# Both callbacks monitor validation metrics, so fit needs validation data.
# X_test and Y_test are assumed to be the test arrays, built like X and Y.
model.fit(X, Y,
          validation_data = (X_test, Y_test),
          epochs = 30,
          callbacks = [checkpoint, learning_rate_reducer])

This had two benefits. First, saving the model with the best validation accuracy, rather than the latest one, meant that overfitting late in training couldn’t ruin a session. Second, the models could reach better performance by getting past “plateaus”: points where there is more performance to be had, but the learning rate is so large that the optimizer keeps “overshooting” the minimum of the loss.
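To make the plateau rule concrete, here is a simplified pure-Python model of how ReduceLROnPlateau with factor=0.2 and patience=10 reacts to a validation loss that stops improving. It ignores the callback's min_delta and cooldown options, so treat it as an illustration rather than a reimplementation; simulate_reduce_on_plateau and the loss values are made up:

```python
def simulate_reduce_on_plateau(val_losses, initial_lr=1e-4, factor=0.2,
                               patience=10, min_lr=0.0):
    """Return the learning rate in effect after each epoch.

    Simplified ReduceLROnPlateau (mode='min', cooldown=0, min_delta ignored):
    if the loss fails to beat its best value for `patience` consecutive
    epochs, multiply the learning rate by `factor` and reset the counter.
    """
    lr = initial_lr
    best = float("inf")
    wait = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history


# Loss improves for 5 epochs, then stays flat for the remaining 25.
val_losses = [1.0, 0.9, 0.8, 0.7, 0.6] + [0.6] * 25
history = simulate_reduce_on_plateau(val_losses)
drops = [epoch + 1 for epoch in range(1, len(history))
         if history[epoch] < history[epoch - 1]]
print(drops)  # [15, 25]
```

Under this model, a 30-epoch run with a patience of 10 can cut the learning rate at most twice, since the first epoch always sets a new best.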

Model Evaluation and Validation

Scratch-Built Model

The model built from scratch performed reasonably well given the constraints. Since it could only train for 30 epochs and its complexity was rather limited, reaching ~38% accuracy on unseen data is a good result, especially compared with the under 1% (1 in 149) expected from random guessing. It also showed no signs of plateauing, so this model could certainly achieve better performance with more training time.

VGG16-Based Model

The transfer learning model, however, performed much better. Not only did it achieve a much higher validation accuracy than the scratch model (~73%), it also got there in a much shorter time, making it a viable solution even for hardware more limited than my own.

Now we get to the fun part: “Who’s That Pokémon?!”. Below are the silhouettes of 9 randomly selected pokémon, and it is the models’ job to correctly identify all of them. We will test both the transfer learning model and the scratch model, just to make it more interesting. Remember that the models were saved at their peak validation accuracy, so these are the best models available from our training sessions. Feel free to play along at home! Here are the silhouettes:

And here are the results. Our VGG16-based model got 5 of the 9 pokémon correct! And while it failed to identify Dugtrio, it came close, predicting Diglett, which is Dugtrio’s pre-evolution. Things did not go so smoothly for our scratch model, which did not get a single prediction correct. Now let’s look at how the models perform when they can see the pokémon in question:

Here VGG16 performs even better than before, identifying 8 of the 9 pokémon. Our scratch model also improved, correctly identifying both Kingler and Rhydon. Not bad at all, I’d say. I know this isn’t the most scientific test, but it is a good example of how the amount of visual information available affects a model’s performance.

Justification

So, we’ve seen that transfer learning works much better in this instance, but the question remains: why? The answer comes down to limited resources. A pretrained model is exactly what it sounds like: much of the training has already been done for us.

If you break image recognition down into two steps, identifying patterns (analogous to the convolutional layers) and understanding the meaning of those patterns (analogous to the dense and output layers), then a model trained entirely from scratch has to learn both.

Transfer learning gives us an understanding of patterns baked into the model from the start, usually a much more comprehensive one than we could create on our own. The model then only has to learn what those patterns mean, which substantially reduces how much it needs to learn. This reduction in training effort makes these models easier to train on limited computing resources and lets them train in much less time than they otherwise would.

Reflection & Improvement

So, to recap: we imported and augmented image data depicting most of the first-generation pokémon and used it to train two models for 30 epochs each, one built from scratch and one using transfer learning with a pretrained VGG16. The transfer learning model performed much better than the scratch model, achieving an accuracy of ~73%. We also had some fun along the way!

Of course, models can always be improved, and there is certainly room for improvement here. Possible improvements include using class weights so the full dataset can be used for training, tuning the data augmentation parameters to introduce even more variance, increasing the complexity of the dense layers in the transfer learning model, and increasing the complexity of the scratch model. We could also explore other pretrained architectures to see whether any of them perform better, add pokémon from other generations to the data, and implement the final classifier in a pokédex-like app.
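Class weights are the first improvement listed above. One common heuristic, sketched here assuming integer class labels, weights each class by n_samples / (n_classes x count), the "balanced" scheme scikit-learn's compute_class_weight uses; Keras's Model.fit accepts the resulting dict through its class_weight argument. balanced_class_weights is a hypothetical helper name:

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class inversely to its frequency.

    Uses n_samples / (n_classes * count), the same formula as scikit-learn's
    compute_class_weight(class_weight='balanced'). Returns a dict mapping
    class index -> weight, the format Keras's Model.fit accepts as class_weight.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Made-up integer labels: class 0 has 6 images, class 1 has 2.
weights = balanced_class_weights([0, 0, 0, 0, 0, 0, 1, 1])
print(weights)  # {0: 0.6666666666666666, 1: 2.0}
```

With one-hot labels like Y above, integer labels could be recovered with np.argmax(Y, axis=1), and the weights passed as model.fit(..., class_weight=weights), so rarer pokémon contribute more to the loss instead of the train partition being capped.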

For now, though, we’ve reached the end of our journey. Thank you for reading, and be sure to check the sources for this project’s GitHub repository and to get the dataset for yourself!

Sources


Written by Lincoln Wentz


Data Scientist in training and professional Star Wars historian.