Building SearchTransfer in PyTorch

SearchTransfer comes from the paper Learning Texture Transformer Network for Image Super-Resolution.

[paper] [code]

The main idea is similar to self-attention, where a batched matrix multiplication (torch.bmm) of a (B, HW, C) tensor with a (B, C, HW) tensor gives the similarity between every pair of positions. Here, instead of reshaping the input directly into (B, HW, C), we use a convolution-style sliding window to unfold it into (B, number of pixels per block, num_blocks), and then apply hard attention: each position keeps only its single most similar block.
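
A minimal sketch of the hard-attention step on toy tensors, assuming the features have already been unfolded into (B, pixels_per_block, num_blocks). Every block is L2-normalized so that torch.bmm yields cosine similarities, and torch.max keeps only the best-matching reference block for each query block instead of a softmax-weighted sum. The names q and k and all sizes are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, D = 2, 27                  # D = pixels per block, e.g. 3 * 3 * 3
n_query, n_ref = 25, 16       # number of blocks in the query and reference features

q = torch.rand(B, D, n_query)  # query blocks, (B, D, num_blocks)
k = torch.rand(B, D, n_ref)    # reference blocks, (B, D, num_blocks)

# L2-normalize every block (column) so that the dot product is a cosine similarity
q = F.normalize(q, dim=1)
k = F.normalize(k, dim=1)

# (B, n_ref, D) x (B, D, n_query) -> (B, n_ref, n_query)
R = torch.bmm(k.transpose(1, 2), q)

# Hard attention: for every query block keep only the most similar reference block
score, index = torch.max(R, dim=1)  # both (B, n_query)

print(R.shape)      # torch.Size([2, 16, 25])
print(index.shape)  # torch.Size([2, 25])
```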

This article documents the PyTorch operations used when re-implementing the Transformer module.

The key functions:

  • torch.nn.functional.unfold
  • torch.nn.functional.fold
  • torch.Tensor.expand
  • torch.gather

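unfold makes it easy to compute attention between blocks. The resulting similarity map gives, for every LR block, the index of the most similar reference block, and that index is used (via expand and gather) to extract information from ref_unfold. Finally, fold turns the selected blocks of ref_unfold back into a feature map.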

1. unfold

Unfold splits the input into blocks using the same sliding window as nn.Conv2d.

import torch
import torch.nn.functional as F

x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape)          # torch.Size([1, 3, 5, 5])
print(x_unfold.shape)   # torch.Size([1, 27, 25])

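x has shape (batch, channel, H, W), and x_unfold has shape (batch, k x k x channel, num_blocks).

Here k is kernel_size, so k x k x channel is the number of values in one block (27 = 3 x 3 x 3 above).

num_blocks is the number of positions the window can take for the given kernel_size, padding and stride. With dilation=1, each spatial dimension gives floor((size + 2 x padding - kernel_size) / stride) + 1 positions, and num_blocks is the product over H and W (5 x 5 = 25 above).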
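
Both statements can be checked directly. In the sketch below, num_blocks is a hypothetical helper implementing the block-count formula (dilation=1 assumed). The second half illustrates the "same sliding window as nn.Conv2d" claim: multiplying the flattened convolution weights with the unfolded input reproduces F.conv2d, because unfold lays out each block in the same (channel, row, column) order as a conv kernel.

```python
import torch
import torch.nn.functional as F


def num_blocks(H, W, kernel_size, padding, stride):
    """Number of sliding-window positions for F.unfold (dilation=1)."""
    n_h = (H + 2 * padding - kernel_size) // stride + 1
    n_w = (W + 2 * padding - kernel_size) // stride + 1
    return n_h * n_w


x = torch.rand(1, 3, 5, 5)
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)  # (1, 27, 25)
assert x_unfold.shape[-1] == num_blocks(5, 5, kernel_size=3, padding=1, stride=1)

# A convolution is a matrix product between the flattened kernels and the unfolded blocks
weight = torch.rand(8, 3, 3, 3)                     # (out_channels, C, k, k)
out_unfold = weight.view(8, -1) @ x_unfold          # (8, 27) @ (1, 27, 25) -> (1, 8, 25)
out_conv = F.conv2d(x, weight, padding=1, stride=1)  # (1, 8, 5, 5)
print(torch.allclose(out_unfold.view(1, 8, 5, 5), out_conv, atol=1e-4))  # True
```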

2. fold

fold is the counterpart of unfold: it puts the blocks back into a (batch, channel, H, W) tensor.

k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape)   # torch.Size([1, 108, 2500])
print(x_fold.shape)     # torch.Size([1, 3, 100, 100])
print(x.mean())         # tensor(0.5012)
print(x_fold.mean())    # tensor(4.3924)

As you can see, the shape is restored but the values are not. After unfold, a single position (1 x 1 x channel) belongs to several overlapping blocks, and fold sums all of those overlapping contributions. To get back the original value range, x_fold has to be divided by the number of blocks covering each position. With k=6, s=2, an interior position is covered by 3 x 3 = 9 blocks (k / s = 3 window positions vertically times 3 horizontally).
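
The count of 9 only holds away from the border; near the edges fewer windows cover a pixel. A quick way to see this is to fold a tensor of ones with the same settings, which counts the blocks covering each position. The sketch below assumes the same k=6, s=2, p=2 and a 100 x 100 input as above.

```python
import torch
import torch.nn.functional as F

k, s = 6, 2
p = (k - s) // 2
H, W = 100, 100

ones = torch.ones((1, 1, H, W))
counts = F.fold(
    F.unfold(ones, kernel_size=k, stride=s, padding=p),
    output_size=(H, W), kernel_size=k, stride=s, padding=p,
)

print(counts[0, 0, 50, 50].item())  # 9.0  interior: 3 windows vertically x 3 horizontally
print(counts[0, 0, 0, 50].item())   # 6.0  top edge: 2 x 3
print(counts[0, 0, 0, 0].item())    # 4.0  corner: 2 x 2
```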

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3. * 3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean())         # tensor(0.4998)
print(x_fold.mean())    # tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum())  # tensor(189)

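Dividing by a constant 9 is only approximately right. The mean drops (0.4866 vs 0.4998) because positions near the border are covered by fewer than 9 blocks, and sum() shows that even inside the image only 189 of the 300 values in the 10 x 10 x 3 slice match exactly, because summing several float copies and dividing again introduces rounding error. A more reliable way to get the divisor (instead of hard-coding 3. x 3.) is to pass torch.ones through the same unfold and fold, which counts the overlaps at every position.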

k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

x_fold = x_fold / ones_fold
print(x.mean())         # tensor(0.5001)
print(x_fold.mean())    # tensor(0.5001)
print((x == x_fold).sum())  # tensor(30000): every position is restored
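
The ones trick can be wrapped in a small helper. fold_average below is a hypothetical name, not a PyTorch function: it folds the blocks and divides by the folded ones, so every position is averaged over exactly the blocks that cover it, borders included. The check uses torch.allclose rather than ==, since summing several copies of a float can introduce small rounding errors.

```python
import torch
import torch.nn.functional as F


def fold_average(blocks, output_size, kernel_size, stride, padding):
    """Fold (B, C*k*k, L) blocks back to (B, C, H, W), averaging overlapping values.

    Hypothetical helper: the divisor is obtained by folding unfolded ones,
    so it matches the true overlap count at every position.
    """
    C = blocks.size(1) // (kernel_size * kernel_size)
    ones = torch.ones((1, C) + tuple(output_size), dtype=blocks.dtype, device=blocks.device)
    kw = dict(kernel_size=kernel_size, stride=stride, padding=padding)
    divisor = F.fold(F.unfold(ones, **kw), output_size=output_size, **kw)
    return F.fold(blocks, output_size=output_size, **kw) / divisor


x = torch.rand((1, 3, 100, 100))
for k, s in [(6, 2), (5, 3), (3, 1)]:
    p = (k - s) // 2
    blocks = F.unfold(x, kernel_size=k, stride=s, padding=p)
    x_rec = fold_average(blocks, (100, 100), kernel_size=k, stride=s, padding=p)
    print(k, s, torch.allclose(x, x_rec))  # True for each setting
```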

3. expand

Tensor.expand(*sizes) returns a view in which singleton dimensions are repeated to the requested size without copying data; pass -1 for a dimension that should stay unchanged.

x = torch.rand((1, 4))      # x = torch.rand(4) gives the same expanded tensors
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))

print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand1)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand2)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914]])
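
Because expand does not copy data, the expanded dimension gets stride 0 and the result shares storage with the original tensor. The short check below illustrates this and contrasts it with Tensor.repeat, which allocates a real copy. Since several elements of an expanded tensor point to the same memory, treat it as read-only and call .clone() first if you need to write to it.

```python
import torch

x = torch.rand((1, 4))
x_expand = x.expand((3, -1))  # (3, 4) view, no new memory
x_repeat = x.repeat(3, 1)     # (3, 4) real copy

print(x_expand.stride())                    # (0, 1)
print(x_expand.data_ptr() == x.data_ptr())  # True: same storage
print(x_repeat.data_ptr() == x.data_ptr())  # False: new storage
print(torch.equal(x_expand, x_repeat))      # True: same values

x[0, 0] = 42.0                              # writing to x ...
print(x_expand[:, 0])                       # ... shows up in every row: tensor([42., 42., 42.])
```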

4. gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) collects values from input along dim at the positions given by index; out has the same shape as index. For a 3D tensor:

for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            if dim == 0:
                out[i][j][k] = input[index[i][j][k]][j][k]
            elif dim == 1:
                out[i][j][k] = input[i][index[i][j][k]][k]
            else:  # dim == 2
                out[i][j][k] = input[i][j][index[i][j][k]]
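
A small 2D case makes the definition concrete; the tensors follow the example in the torch.gather documentation. With dim=1, out[i][j] = t[i][index[i][j]], and the explicit loop below reproduces the same result.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])

out = torch.gather(t, 1, index)
print(out.tolist())  # [[1, 1], [4, 3]]

# The same result written as the loop from the definition (dim == 1)
loop_out = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        loop_out[i][j] = t[i][index[i][j]]

print(torch.equal(out, loop_out))  # True
```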

To use gather for block selection, the index first has to be expanded (with expand) so that it has the same number of dimensions as the input and the same size along every dimension except dim.

Here index has shape [B, blocks]. View it as [B, 1, blocks] and expand it to [B, k x k x C, blocks]; then index[i, :, n] is a 1D tensor in which every element equals index[i, n] from before the expansion.

Because index[i][j][n] does not change as j changes, gathering along dim=2 gives out[i, j, n] = input[i, j, index[i][j][n]] = input[i, j, index[i, n]] for every j, i.e. out[i, :, n] = input[i, :, index[i, n]]: the whole block (column) selected by index[i, n] is copied.
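
The expand + gather pattern can be checked against plain advanced indexing. In this sketch (toy sizes, illustrative names) each column of unfold is one block; after expanding index from [B, N] to [B, D, N], gathering along dim 2 copies whole blocks, so out[i, :, n] equals unfold[i, :, index[i, n]].

```python
import torch

torch.manual_seed(0)
B, D, L, N = 2, 5, 7, 4              # batch, values per block, blocks in unfold, blocks to select

unfold = torch.rand(B, D, L)          # e.g. unfolded reference features
index = torch.randint(0, L, (B, N))   # which block to take for each output position

# [B, N] -> [B, 1, N] -> [B, D, N]; expand only creates a view
index_expanded = index.view(B, 1, N).expand(-1, D, -1)
out = torch.gather(unfold, 2, index_expanded)  # [B, D, N]

# Reference result with advanced indexing: pick columns index[i] of unfold[i]
ref = torch.stack([unfold[i][:, index[i]] for i in range(B)])
print(out.shape, torch.equal(out, ref))  # torch.Size([2, 5, 4]) True
```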

5. Building the feature Transfer module

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transfer(nn.Module):
    def __init__(self):
        super(Transfer, self).__init__()

    def bis(self, unfold, dim, index):
        """Block index select.

        args:
            unfold: [B, k*k*C, Hr*Wr]
            dim: the dimension that indexes blocks
            index: [B, H*W], values in [0, Hr*Wr-1]
        return:
            [B, k*k*C, H*W]
        """
        views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, unfold.dim())]  # [B, 1, -1(H*W)]
        expanse = list(unfold.size())
        expanse[0] = -1
        expanse[dim] = -1   # [-1, k*k*C, -1]
        index = index.view(views).expand(expanse)   # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
        return torch.gather(unfold, dim, index)     # return[i][j][k] = unfold[i][j][index[i][j][k]]

    def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
        """
        args:
            lrsr_lv3:  [B, C, H, W]
            refsr_lv3: [B, C, Hr, Wr]
            ref_lv1:   [B, C, Hr*4, Wr*4]
            ref_lv2:   [B, C, Hr*2, Wr*2]
            ref_lv3:   [B, C, Hr, Wr]
        """
        H, W = lrsr_lv3.size()[-2:]

        lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1)                   # [B, k*k*C, H*W]
        refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2)  # [B, Hr*Wr, k*k*C]

        # L2-normalize every block so that bmm gives cosine similarity
        lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
        refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)

        R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold)  # [B, Hr*Wr, H*W]
        score, index = torch.max(R, dim=1)                # hard attention, [B, H*W]

        ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1)   # vgg19 features
        ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2)   # lv1 -> lv2 and lv2 -> lv3 each contain a max pooling
        ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4)  # kernel_size is scaled with resolution, not computed from the actual receptive field

        # Divisor: counts how many blocks overlap at each position after fold(unfold)
        divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)

        T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index)   # [B, k*k*C, H*W]
        T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
        T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)

        divisor_lv3 = self.bis(divisor_lv3, 2, index)   # [B, k*k*C, H*W]
        divisor_lv2 = self.bis(divisor_lv2, 2, index)
        divisor_lv1 = self.bis(divisor_lv1, 2, index)

        divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.fold(divisor_lv2, (2*H, 2*W), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.fold(divisor_lv1, (4*H, 4*W), kernel_size=12, padding=4, stride=4)

        T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
        T_lv2 = F.fold(T_lv2_unfold, (2*H, 2*W), kernel_size=6, padding=2, stride=2) / divisor_lv2
        T_lv1 = F.fold(T_lv1_unfold, (4*H, 4*W), kernel_size=12, padding=4, stride=4) / divisor_lv1

        score = score.view(lrsr_lv3.size(0), 1, H, W)   # [B, 1, H, W]

        return score, T_lv1, T_lv2, T_lv3
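
The multi-scale settings in forward rely on every reference level yielding the same number of blocks (Hr*Wr), and on each fold producing (H, W), (2H, 2W) and (4H, 4W). For lv2 and lv1 the kernel is 3s and the padding is s, so a side of n*s pixels gives (n*s + 2s - 3s) // s + 1 = n windows. The sketch below checks this with random tensors; the sizes and channel count are arbitrary assumptions, and random indices stand in for the argmax output.

```python
import torch
import torch.nn.functional as F

# (kernel_size, padding, stride, scale) for lv3, lv2 and lv1, as used in Transfer.forward
settings = {"lv3": (3, 1, 1, 1), "lv2": (6, 2, 2, 2), "lv1": (12, 4, 4, 4)}

B, C = 1, 8
H, W = 10, 12    # LR feature size at lv3 (arbitrary)
Hr, Wr = 9, 11   # Ref feature size at lv3 (arbitrary)

out_shapes = {}
for name, (k, p, s, scale) in settings.items():
    ref = torch.rand(B, C, Hr * scale, Wr * scale)
    ref_unfold = F.unfold(ref, kernel_size=k, padding=p, stride=s)
    assert ref_unfold.shape[-1] == Hr * Wr          # one block per lv3 reference position

    index = torch.randint(0, Hr * Wr, (B, H * W))   # stand-in for the argmax indices
    index = index.view(B, 1, -1).expand(-1, ref_unfold.size(1), -1)
    picked = torch.gather(ref_unfold, 2, index)     # [B, k*k*C, H*W]
    T = F.fold(picked, (H * scale, W * scale), kernel_size=k, padding=p, stride=s)
    out_shapes[name] = tuple(T.shape)

print(out_shapes)
# {'lv3': (1, 8, 10, 12), 'lv2': (1, 8, 20, 24), 'lv1': (1, 8, 40, 48)}
```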


References

pytorch.org/docs/stable…

github.com/researchmm/…