PyTorch performs one-hot encoding of category tensors

This article has been authorized for the Jishi platform and was first published on the Jishi platform's public account. It may not be reposted without permission.

  • The original document: www.yuque.com/lart/ugkv9f…
  • Code repository: github.com/lartpang/Co…

Preface

One-hot encoding is very common in deep learning tasks, but it is not a natural way to store data, so most of the time we have to perform the conversion ourselves. The idea is straightforward: map each category to its own 0-1 vector. The implementation, however, takes some thought. In fact, PyTorch already provides the one_hot function in torch.nn.functional for quick use, but that should not stop us from thinking and practicing :>! This article therefore implements one-hot encoding with several common PyTorch operations, in the hope that it is useful.
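For reference, here is a minimal sketch of the built-in function mentioned above. The toy label tensor and the class count are assumptions chosen for illustration, not values from the original article.

```python
import torch
import torch.nn.functional as F

# A toy label map of shape (b=1, h=2, w=2); F.one_hot expects an int64 (long) tensor.
labels = torch.tensor([[[0, 2], [1, 2]]])
num_classes = 3

# The class dimension is appended as the last dimension: (b, h, w, num_classes).
one_hot = F.one_hot(labels, num_classes=num_classes)
print(one_hot.shape)     # torch.Size([1, 2, 2, 3])
print(one_hot[0, 0, 1])  # tensor([0, 0, 1])
```

The implementations in the rest of the article aim to reproduce this layout by hand.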

The main ways are as follows:

  • for loop
  • scatter
  • index_select

for loop

This method is very straightforward: in plain terms, it assigns 1 at the specified positions of a blank (all-zero) tensor. The key is how to construct the index. Below are two variants that are essentially the same but differ slightly in which dimension holds the categories.

import torch


def bhw_to_onehot_by_for(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = bhw_tensor.new_zeros(size=(num_classes, *bhw_tensor.shape))
    for i in range(num_classes):
        one_hot[i, bhw_tensor == i] = 1
    # Move the class dimension to the end: (b, h, w, num_classes).
    one_hot = one_hot.permute(1, 2, 3, 0)
    return one_hot


def bhw_to_onehot_by_for_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = bhw_tensor.new_zeros(size=(*bhw_tensor.shape, num_classes))
    for i in range(num_classes):
        # one_hot[..., i] is a view, so the masked assignment writes into one_hot.
        one_hot[..., i][bhw_tensor == i] = 1
    return one_hot
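To make the masked assignment concrete, the following sketch runs both layouts on a toy tensor and checks that they agree once the class dimension is moved to the end. The tensor values and the names class_first and class_last are assumptions chosen for illustration.

```python
import torch

labels = torch.tensor([[[0, 2], [1, 2]]])  # shape (b=1, h=2, w=2)
num_classes = 3

# Class-first layout: (num_classes, b, h, w).
class_first = labels.new_zeros((num_classes, *labels.shape))
# Class-last layout: (b, h, w, num_classes).
class_last = labels.new_zeros((*labels.shape, num_classes))

for i in range(num_classes):
    mask = labels == i            # boolean mask of shape (b, h, w)
    class_first[i, mask] = 1      # pick class plane i, then write 1 where the mask is True
    class_last[..., i][mask] = 1  # same idea, with the class plane taken from the last dim

# After moving the class dimension to the end, both layouts hold the same values.
same = torch.equal(class_first.permute(1, 2, 3, 0), class_last)
print(same)              # True
print(class_last[0, 0])  # one-hot rows for the labels 0 and 2
```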

scatter

This is probably the most common form among the concise one-hot implementations found online. What scatter_ mainly does is assign values at the positions you specify in a tensor.

Because it can take a specially constructed index tensor, it is more flexible. That flexibility, of course, also makes it harder to understand. The explanation in the official documentation is straightforward:

'''
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html

* (int dim, Tensor index, Tensor src)
* (int dim, Tensor index, Tensor src, *, str reduce)
* (int dim, Tensor index, Number value)
* (int dim, Tensor index, Number value, *, str reduce)
'''

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

The documentation describes the in-place version and explains it for the case where the replacement value src is a tensor. Our use case is the in-place version with a scalar as the replacement value.

From the form above, we can see that the index tensor determines where each value src[i][j][k] lands in the method caller (here self): the target position is (i, j, k) with the coordinate along dim replaced by index[i][j][k]. The index tensor must have the same number of dimensions as self and src. In our case, the scalar value 1 takes the place of src. This fits the concept of one-hot well, because the one-hot vector for a sample of category c has 1 at position c and 0 everywhere else. So it is easy to build a one-hot tensor by calling scatter_ on an all-zero tensor and placing a 1 at the position given by the category number.
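The rule can be checked on a small 2-D case. The sketch below, with values assumed purely for illustration, scatters the scalar 1 along dim=1 and compares the result with the same assignment written as explicit loops.

```python
import torch

index = torch.tensor([[2], [0], [1]])  # one category number per row, shape (3, 1)
out = torch.zeros(3, 4)

# With dim=1 the rule reads: out[i][index[i][j]] = value for every (i, j) in index.
out.scatter_(dim=1, index=index, value=1.0)
print(out)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])

# The same assignment written out as loops over the index tensor.
expected = torch.zeros(3, 4)
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        expected[i, index[i, j]] = 1.0
```

Each row of index holds a single category number, so each row of out ends up with exactly one 1.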

For our problem, the input tensor of category numbers with shape B,H,W serves well as the index. Based on this, two different strategies come to mind:

import math


def bhw_to_onehot_by_scatter(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)  # index must be int64
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot


def bhw_to_onehot_by_scatter_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)  # index shape: b,h,w,1
    return one_hot

The two forms differ fundamentally in how they handle the shape, which leads to different uses of scatter_.

For the first form, merging the B, H, and W dimensions has the benefit of making the indexing of channels (categories) intuitive.

    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)

Here the category dimension is separated from the other dimensions and placed last. With dim specifying that dimension, the correspondence is:

zero_tensor[abc, index[abc][d]] = value  # d=0

In the second case, the first three dimensions are still retained and the category dimension is still moved to the last place.

    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)

The corresponding relationship is as follows:

zero_tensor[a,b,c, index[a][b][c][d]] = value # d=0
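As a sanity check on the two correspondences, the sketch below builds both forms on random labels and compares them with F.one_hot. The shapes, the seed, and the variable names are assumptions chosen for illustration.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 5
labels = torch.randint(0, num_classes, (2, 3, 4))  # int64 labels of shape (b, h, w)

# Form 1: merge b, h, w into one axis, scatter along dim=1, then restore the shape.
flat = torch.zeros(math.prod(labels.shape), num_classes)
flat.scatter_(dim=1, index=labels.reshape(-1, 1), value=1)
form1 = flat.reshape(*labels.shape, num_classes)

# Form 2: keep b, h, w and scatter along the last dim with an index of shape (b, h, w, 1).
form2 = torch.zeros(*labels.shape, num_classes)
form2.scatter_(dim=-1, index=labels[..., None], value=1)

# F.one_hot returns int64, so cast it to float before comparing.
reference = F.one_hot(labels, num_classes).float()
print(torch.equal(form1, form2), torch.equal(form1, reference))  # True True
```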

A similar approach is used in timm, the PyTorch classification model library:

# https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

index_select

torch.index_select(input, dim, index, *, out=None) → Tensor

- input (Tensor) -- the input tensor.
- dim (int) -- the dimension in which we index
- index (IntTensor or LongTensor) -- the 1-D tensor containing the indices to index

As its name suggests, this function uses an index to select sub-tensors along a given dimension.

To understand the motivation for this approach, you need to look at one-hot encoding in reverse, from the perspective of the category labels.

If the original category numbers are arranged in ascending order, the matrix formed by their one-hot encodings is an identity matrix. Each category therefore corresponds to a specific row (or column) of the identity matrix. This fits the functionality of index_select neatly: we can implement one-hot encoding by using the category numbers to index specific rows or columns. Here's an example:

def bhw_to_onehot_by_index_select(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    # Row c of the identity matrix is the one-hot vector of category c.
    one_hot = torch.eye(num_classes).index_select(dim=0, index=bhw_tensor.reshape(-1))
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot
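The identity-matrix view can be illustrated on a toy tensor. The sketch below also shows plain advanced indexing, which is a variant not used in the original article; it is included only for comparison, and all values are assumed for illustration.

```python
import torch

num_classes = 4
eye = torch.eye(num_classes)               # row c is the one-hot vector of category c
labels = torch.tensor([[[3, 0], [1, 1]]])  # shape (b=1, h=2, w=2)

# index_select needs a 1-D index, so flatten first and restore the shape afterwards.
by_index_select = eye.index_select(dim=0, index=labels.reshape(-1))
by_index_select = by_index_select.reshape(*labels.shape, num_classes)

# Advanced indexing accepts the multi-dimensional label tensor directly.
by_indexing = eye[labels]

print(by_indexing[0, 0, 0])                     # tensor([0., 0., 0., 1.])
print(torch.equal(by_index_select, by_indexing))  # True
```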

Performance comparison

The full code is available on GitHub.

The following shows the relative performance of the different methods (other processes were running in the background, so the numbers may not be very accurate; I recommend running the test yourself). As you can see, PyTorch's built-in function is not especially efficient on the CPU but performs well on the GPU. Interestingly, the index_select-based version performs very well.

1.10.0 GeForce RTX 2080 Ti
CPU
('bhw_to_onehot_by_for', 0.5411529541015625)
('bhw_to_onehot_by_for_V1', 0.4515676498413086)
('bhw_to_onehot_by_scatter', 0.0686192512512207)
('bhw_to_onehot_by_scatter_V1', 0.08529376983642578)
('bhw_to_onehot_by_index_select', 0.05156970024108887)
('F.one_hot', 0.07366824150085449)
GPU
('bhw_to_onehot_by_for', 0.005235433578491211)
('bhw_to_onehot_by_for_V1', 0.045584678649902344)
('bhw_to_onehot_by_scatter', 0.0025513172149658203)
('bhw_to_onehot_by_scatter_V1', …)
('bhw_to_onehot_by_index_select', 0.002012014389038086)
('F.one_hot', 0.0024051666259765625)
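For readers who want to repeat the measurement, here is a rough timing sketch. It is not the benchmark script from the repository: the benchmark and onehot_by_eye helpers, the input size, and the repeat count are all assumptions made for illustration.

```python
import time

import torch
import torch.nn.functional as F


def benchmark(fn, labels, num_classes, repeats=10):
    """Return the mean wall-clock time in seconds of fn(labels, num_classes)."""
    fn(labels, num_classes)  # warm-up call, excluded from the timing
    if labels.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels run asynchronously
    start = time.perf_counter()
    for _ in range(repeats):
        fn(labels, num_classes)
    if labels.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats


def onehot_by_eye(labels, num_classes):
    # Hypothetical helper: identity-matrix lookup on the same device as the labels.
    return torch.eye(num_classes, device=labels.device)[labels]


num_classes = 20
labels = torch.randint(0, num_classes, (4, 64, 64))  # move to GPU with labels.cuda() if available
timings = {
    "F.one_hot": benchmark(F.one_hot, labels, num_classes),
    "eye lookup": benchmark(onehot_by_eye, labels, num_classes),
}
for name, seconds in timings.items():
    print(f"{name}: {seconds:.6f} s")
```

Synchronizing before reading the clock matters on the GPU, since otherwise only the kernel launch time would be measured.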