Creating a One-Hot Encoding in PyTorch

This article explains how to create a one-hot encoding of categorical values using the PyTorch library. The idea for this post is inspired by “Deep Learning with PyTorch” by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Sooner or later, every data scientist encounters categorical values in a dataset. For example, the size of a t-shirt (small (S), medium (M), large (L), or extra large (XL)) takes one of four categorical values. The problem this post addresses is therefore:

How do we encode these categorical values before we feed them into Machine Learning algorithms?

Suppose that PyTorch is installed on our machine. As an example, we use the training set of the Titanic dataset.

Figure: The RMS Titanic sank in the North Atlantic Ocean in the early morning hours of 15 April 1912. Image taken from Greenscene, some rights reserved.

Let’s get started by reading the training set.

from pathlib import Path
import pandas as pd
import numpy as np
import torch
torch.set_printoptions(edgeitems=2, precision=2, linewidth=75)

# Path to the directory containing train.csv; change it to match your own machine.
TITANIC_DATASET = Path('/home/hbunyamin/Perkuliahan/Pembelajaran-Mesin-Maranatha/projects/UTS/Titanic')
titanic_df = pd.read_csv(TITANIC_DATASET / 'train.csv')


Next, we display the first five rows of the dataframe.

titanic_df.head()

Out: [image: the first five rows of titanic_df]

We also show summary statistics for the Titanic training set:

titanic_df.describe()

Out: [image: summary statistics of titanic_df]

We want to turn the passenger class (Pclass) column into a one-hot encoding. Let’s list each value in the Pclass column together with its frequency.

titanic_df['Pclass'].value_counts()

Out:

3    491
1    216
2    184
Name: Pclass, dtype: int64

The values 3, 1, and 2 occur 491, 216, and 184 times, respectively, for a total of 891 passengers.

Next, we convert 1, 2, and 3 into a one-hot encoding. Since indices in PyTorch start from 0 and the values of the Pclass column start from 1, we need to make an adjustment: we subtract 1 from each value in the Pclass column and view the resulting unique values.

pclass = titanic_df['Pclass'] - 1
pclass.unique()

Out:

array([2, 0, 1])
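Subtracting 1 works here because the Pclass values happen to be the consecutive integers 1, 2, and 3. As a more general sketch (our own addition, not the approach used in this post), torch.unique with return_inverse=True maps arbitrary category values to indices 0, ..., k-1 in sorted order. The ten values below are the first ten Pclass values, i.e. pclass[:10] + 1 from later in this post.

```python
import torch

# First ten raw Pclass values; in general they need not be consecutive or start at 1.
raw = torch.tensor([3, 1, 3, 1, 3, 3, 1, 3, 3, 2])

# `categories` holds the sorted unique values; `indices` gives, for each
# element of `raw`, the position of its value within `categories`.
categories, indices = torch.unique(raw, sorted=True, return_inverse=True)

print(categories)  # tensor([1, 2, 3])
print(indices)     # tensor([2, 0, 2, 0, 2, 2, 0, 2, 2, 1])
```

For Pclass the resulting indices equal raw - 1, so both approaches agree; the torch.unique version would also handle values such as 10, 20, 30.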

Now the values in pclass are 0, 1, and 2.
Next, we convert pclass into a tensor.

pclass = torch.tensor(pclass.to_numpy())  # int64 Series -> torch.int64 tensor
pclass.shape

Out:

torch.Size([891])
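The scatter_ method used below expects its index tensor to have dtype torch.int64 (a LongTensor). The Pclass column is stored as int64 in pandas, so the converted tensor already has the right dtype. A minimal sketch of the conversion with an explicit dtype check, using a small made-up Series as a stand-in for the Titanic column, might look like this:

```python
import pandas as pd
import torch

# Made-up stand-in for titanic_df['Pclass'] - 1.
pclass_series = pd.Series([2, 0, 2, 0, 1])

# .to_numpy() hands torch a plain NumPy array; int64 values become torch.int64.
pclass_t = torch.tensor(pclass_series.to_numpy())
print(pclass_t.dtype)  # torch.int64

# If a column had been read as floats (e.g. because it contained missing values),
# cast it to int64 before using it as an index.
pclass_float = torch.tensor([2.0, 0.0, 1.0])
pclass_idx = pclass_float.long()
print(pclass_idx.dtype)  # torch.int64
```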


Now we are ready to convert each class index into a one-hot vector:

\[ \texttt{0} \Rightarrow \begin{bmatrix} 1 \\ 0 \\ 0 \end{bmatrix}, \quad \texttt{1} \Rightarrow \begin{bmatrix} 0 \\ 1 \\ 0 \end{bmatrix}, \quad \text{and} \quad \texttt{2} \Rightarrow \begin{bmatrix} 0 \\ 0 \\ 1 \end{bmatrix}. \]
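Each target vector above is a column of the 3 × 3 identity matrix (equivalently a row, since the identity matrix is symmetric). As an alternative sketch, not the method this post uses, indexing the rows of torch.eye(3) with the class indices produces the same one-hot rows. The indices below are the first ten values of pclass shown at the end of this post.

```python
import torch

# First ten class indices (0, 1 or 2), as in pclass[:10].
idx = torch.tensor([2, 0, 2, 0, 2, 2, 0, 2, 2, 1])

# Row k of the 3x3 identity matrix is the one-hot vector for class k,
# so advanced indexing picks one such row for every element of idx.
onehot_rows = torch.eye(3)[idx]

print(onehot_rows.shape)  # torch.Size([10, 3])
print(onehot_rows[:3])
```

This allocates a k × k identity matrix, which is cheap for a small number of classes such as the three passenger classes.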

We initialize the one-hot encoding as a zero matrix of size $891 \times 3$: one row per passenger and one column per class.

pclass_onehot = torch.zeros(pclass.shape[0], 3)
pclass_onehot.shape

Out:

torch.Size([891, 3])


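Next, we call the scatter_ method. The trailing underscore follows PyTorch’s naming convention for in-place operations: instead of creating a new tensor, scatter_ modifies pclass_onehot directly. It also returns pclass_onehot itself, which is why the result is displayed.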

pclass_onehot.scatter_(1, pclass.unsqueeze(1), 1.0)

Out:

tensor([[0., 0., 1.],
        [1., 0., 0.],
        ...,
        [1., 0., 0.],
        [0., 0., 1.]])
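To see the in-place convention concretely, here is a small sketch comparing the out-of-place scatter with the in-place scatter_ on a made-up four-element index tensor:

```python
import torch

idx = torch.tensor([2, 0, 1, 2]).unsqueeze(1)  # shape [4, 1]
base = torch.zeros(4, 3)

# Out-of-place: returns a new tensor and leaves `base` untouched.
out = base.scatter(1, idx, 1.0)
print(base.sum().item())  # 0.0
print(out.sum().item())   # 4.0

# In-place: writes into `base` and returns a tensor sharing its storage.
ret = base.scatter_(1, idx, 1.0)
print(base.sum().item())                 # 4.0
print(ret.data_ptr() == base.data_ptr())  # True
```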

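The second argument, pclass.unsqueeze(1), is the index tensor. Calling unsqueeze(1) returns a view of pclass with a new dimension at position 1, so its shape is torch.Size([891, 1]) rather than torch.Size([891]); pclass itself is left unchanged. This step is needed because scatter_ requires the index tensor to have the same number of dimensions as pclass_onehot.
The first argument (1) is the dimension of pclass_onehot along which values are written, namely axis 1 (the columns). For each row i, scatter_ sets pclass_onehot[i, pclass[i]] to the third argument, 1.0, and leaves every other entry at zero. We therefore need pclass_onehot to have as many columns as there are unique values in pclass, with each value serving as a column index. Please write a comment if anything in this explanation is unclear.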
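The rule described above can be written out as an explicit Python loop. The following sketch, with a made-up index tensor, checks that scatter_ along dimension 1 gives the same result as setting onehot[i, idx[i]] = 1.0 row by row. The loop is for illustration only; scatter_ does the same work without a Python-level loop.

```python
import torch

idx = torch.tensor([2, 0, 2, 1])  # made-up class indices
num_classes = 3

# Vectorized version, as in the post.
via_scatter = torch.zeros(idx.shape[0], num_classes)
via_scatter.scatter_(1, idx.unsqueeze(1), 1.0)

# Explicit loop: for each row i, write 1.0 into column idx[i].
via_loop = torch.zeros(idx.shape[0], num_classes)
for i in range(idx.shape[0]):
    via_loop[i, idx[i].item()] = 1.0

print(torch.equal(via_scatter, via_loop))  # True
```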

We conclude the post by checking that the conversion works.
Here are the first ten values of pclass:

pclass[:10]

Out:

tensor([2, 0, 2, 0, 2, 2, 0, 2, 2, 1])

And here are the corresponding first ten rows of pclass_onehot:

pclass_onehot[:10]

Out:

tensor([[0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.]])
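PyTorch also ships a built-in helper, torch.nn.functional.one_hot, which returns an int64 tensor instead of floats. The sketch below, our own addition, applies it to the t-shirt sizes from the introduction; the list of orders and the size_to_index mapping are made up for illustration. It also checks the result against the scatter_ approach and shows that summing the one-hot columns counts each category, much like value_counts did for Pclass.

```python
import torch
import torch.nn.functional as F

# Hypothetical t-shirt orders and a fixed category order.
sizes = ["M", "S", "XL", "M", "L", "M"]
size_to_index = {"S": 0, "M": 1, "L": 2, "XL": 3}

idx = torch.tensor([size_to_index[s] for s in sizes])

# Built-in helper: returns an int64 tensor of shape [6, 4].
onehot_builtin = F.one_hot(idx, num_classes=len(size_to_index))

# The scatter_ approach from this post, producing floats.
onehot_scatter = torch.zeros(len(sizes), len(size_to_index))
onehot_scatter.scatter_(1, idx.unsqueeze(1), 1.0)

# Same pattern; convert to float so the dtypes match before comparing.
print(torch.equal(onehot_builtin.float(), onehot_scatter))  # True

# Column sums count how often each size occurs, in the order S, M, L, XL.
print(onehot_builtin.sum(dim=0))  # tensor([1, 3, 1, 1])
```

Passing num_classes explicitly matters when a batch does not contain the largest category: without it, F.one_hot infers the number of classes as the largest index plus one.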

Written on August 22, 2020