Building SearchTransfer in PyTorch
SearchTransfer comes from the paper "Learning Texture Transformer Network for Image Super-Resolution".
[paper] [code]
The main idea is similar to self-attention, which uses a batch-wise matrix multiplication between tensors of shape (B, HW, C) and (B, C, HW). Here, instead of reshaping the input directly into (B, HW, C), we expand the input into (B, number of pixels per block, num_blocks) with a convolution-style sliding window, and then apply hard attention.
This article documents some of the PyTorch functions encountered while re-implementing the transformer module.
The key function
torch.nn.functional.unfold
torch.nn.functional.fold
torch.Tensor.expand
torch.gather
unfold makes it easy to compute attention between blocks. The resulting similarity map is then used to compute the indices with which information is extracted from ref_unfold. Finally, the selected blocks of ref_unfold are reassembled into a feature map with fold.
1. unfold
unfold divides the input into blocks using the same sliding window as nn.Conv2d.
import torch
import torch.nn.functional as F
# Random input laid out as (batch, channel, H, W).
x = torch.rand((1, 3, 5, 5))
# Extract every 3x3 sliding-window block (stride 1, padding 1).
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape)  # torch.Size([1, 3, 5, 5])
print(x_unfold.shape)  # torch.Size([1, 27, 25])
Copy the code
The shape of x is (batch, channel, H, W), and the shape of x_unfold is (batch, k x k x channel, number_blocks).
k is kernel_size, and k x k x channel is the number of values in one block.
number_blocks is how many blocks the sliding window produces for the given kernel_size, padding and stride.
2. fold
fold is the opposite of unfold: it reassembles the blocks into a tensor of shape (batch, channel, H, W).
k = 6  # kernel size
s = 2  # stride
p = (k - s) // 2  # padding
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
# fold sums the values of all blocks that overlap at a position.
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape)  # torch.Size([1, 108, 2500])
print(x_fold.shape)  # torch.Size([1, 3, 100, 100])
print(x.mean())  # tensor(0.5012)
print(x_fold.mean())  # tensor(4.3924)
Copy the code
As you can see, although the shape is restored, the value range of x_fold differs from that of x. This is because a single position (1 x 1 x channel) can appear in multiple blocks during unfold, and these overlapping copies are summed during fold, so the values no longer match. After computing x_fold you therefore have to divide by the number of overlaps to get back to the original range. With k=6 and s=2, a position away from the border appears in 3*3=9 blocks (the window slides both vertically and horizontally).
# Reuses k, s, p, H and W from the previous snippet.
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
# Divide the summed result by the hard-coded overlap count 3 * 3 = 9.
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3. * 3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean())  # tensor(0.4998)
print(x_fold.mean())  # tensor(0.4866)
# Count exactly restored elements in a 10 x 10 window.
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum())  # tensor(189)
Copy the code
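The sum() shows that only some of the compared values (189 of 300) are restored exactly, presumably because of floating-point rounding when the overlapping copies are summed and then divided by 9. A way to compute the divisor exactly, instead of hard-coding a value such as 3. * 3., is to pass a torch.ones tensor through the same unfold and fold.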
k = 5  # kernel size
s = 3  # stride
p = (k - s) // 2  # padding
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
# Pass a tensor of ones through the same unfold/fold: the result holds, for
# every position, the number of blocks that overlap there.
ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
x_fold = x_fold / ones_fold
print(x.mean())  # tensor(0.5001)
print(x_fold.mean())  # tensor(0.5001)
print((x == x_fold).sum())  # tensor(30000): every element has been restored
Copy the code
3. expand
Tensor.expand(*sizes): pass -1 for a dimension whose size should stay the same.
x = torch.rand((1, 4))  # x = torch.rand(4) gives the same result
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))  # -1 keeps the size of that dimension
print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand1)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand2)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914],
#         [0.1745, 0.2331, 0.5449, 0.1914]])
Copy the code
4. gather
torch.gather(input, dim, index, *, sparse_grad=False, out=None) gathers values along the axis given by dim. For a 3-D tensor:
# Pseudo-code for a 3-D input: exactly one of the three assignments applies,
# depending on the value of `dim`.
for i in range(dim0):
    for j in range(dim1):
        for k in range(dim2):
            out[i, j, k] = input[index[i][j][k], j, k]  # if dim == 0
            out[i, j, k] = input[i, index[i][j][k], k]  # if dim == 1
            out[i, j, k] = input[i, j, index[i][j][k]]  # if dim == 2
Copy the code
To use gather, first use expand to make index the same size as input.
Initially index.shape == [B, blocks]. After viewing it as [B, 1, blocks] and expanding it to [B, k x k x C, blocks], index[i, :, k] is a 1D tensor in which every element equals the value index[i, k] from before the expand.
In this way, index[i][j][k] does not change when j changes, so out[i, j, k] = input[i, j, index[i][j][k]] selects the same block for every j; that is, out[i, :, k] = input[i, :, b], where b is the pre-expand index[i, k]. gather therefore copies whole blocks.
5. Build Features Transfer
import torch
import torch.nn as nn
import torch.nn.functional as F
class Transfer(nn.Module):
    """Hard-attention feature transfer from reference features.

    Each 3x3 block of ``lrsr_lv3`` is matched against every 3x3 block of
    ``refsr_lv3`` by cosine similarity. The index of the best match is then
    used to copy the corresponding blocks from ``ref_lv1``, ``ref_lv2`` and
    ``ref_lv3``.
    """

    def __init__(self):
        super().__init__()

    def bis(self, unfold, dim, index):
        """Block index select.

        Args:
            unfold: [B, k*k*C, Hr*Wr]
            dim: which dimension of ``unfold`` holds the blocks
            index: [B, H*W], value range is [0, Hr*Wr-1]

        Returns:
            [B, k*k*C, H*W]
        """
        # Keep the batch dim, put -1 on the block dim and 1 elsewhere.
        views = [unfold.size(0)] + [
            -1 if i == dim else 1 for i in range(1, len(unfold.size()))
        ]  # [B, 1, -1(H*W)]
        expanse = list(unfold.size())
        expanse[0] = -1
        expanse[dim] = -1  # [-1, k*k*C, -1]
        # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
        index = index.view(views).expand(expanse)
        # return[i][j][k] = unfold[i][j][index[i][j][k]]
        return torch.gather(unfold, dim, index)

    def _transfer(self, ref, index, output_size, kernel_size, padding, stride):
        """Copy the blocks of ``ref`` selected by ``index`` and fold them."""
        ref_unfold = F.unfold(
            ref, kernel_size=kernel_size, padding=padding, stride=stride
        )
        # Divisor: unfolding ones records how many blocks overlap per position.
        ones_unfold = F.unfold(
            torch.ones_like(ref),
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        selected = self.bis(ref_unfold, 2, index)  # [B, k*k*C, H*W]
        selected_ones = self.bis(ones_unfold, 2, index)  # [B, k*k*C, H*W]
        folded = F.fold(
            selected,
            output_size,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        divisor = F.fold(
            selected_ones,
            output_size,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        # fold sums overlapping blocks, so divide by the overlap count.
        return folded / divisor

    def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
        """Transfer reference features to the layout of ``lrsr_lv3``.

        Args:
            lrsr_lv3: [B, C, H, W]
            refsr_lv3: [B, C, Hr, Wr]
            ref_lv1: [B, C, Hr*4, Wr*4]
            ref_lv2: [B, C, Hr*2, Wr*2]
            ref_lv3: [B, C, Hr, Wr]

        Returns:
            Tuple ``(score, T_lv1, T_lv2, T_lv3)``. ``score`` is [B, 1, H, W];
            ``T_lv1``, ``T_lv2`` and ``T_lv3`` are folded to spatial sizes
            (4*H, 4*W), (2*H, 2*W) and (H, W) respectively.
        """
        H, W = lrsr_lv3.size()[-2:]

        # [B, k*k*C, H*W]
        lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1)
        # [B, Hr*Wr, k*k*C]
        refsr_lv3_unfold = F.unfold(
            refsr_lv3, kernel_size=3, padding=1, stride=1
        ).transpose(1, 2)

        # Normalize each block so the matrix product is a cosine similarity.
        lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
        refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)

        R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold)  # [B, Hr*Wr, H*W]
        score, index = torch.max(R, dim=1)  # [B, H*W]

        # vgg19: lv1 -> lv2 and lv2 -> lv3 each have a max pooling, hence the
        # doubled kernel/padding/stride per level. kernel_size is not
        # calculated according to the actual receptive field.
        T_lv3 = self._transfer(
            ref_lv3, index, (H, W), kernel_size=3, padding=1, stride=1
        )
        T_lv2 = self._transfer(
            ref_lv2, index, (2 * H, 2 * W), kernel_size=6, padding=2, stride=2
        )
        T_lv1 = self._transfer(
            ref_lv1, index, (4 * H, 4 * W), kernel_size=12, padding=4, stride=4
        )

        score = score.view(lrsr_lv3.size(0), 1, H, W)  # [B, 1, H, W]
        return score, T_lv1, T_lv2, T_lv3
Copy the code
Note: to use gather, first use expand to make index the same size as input.
Initially index.shape == [B, blocks]. After viewing it as [B, 1, blocks] and expanding it to [B, k x k x C, blocks], index[i, :, k] is a 1D tensor in which every element equals the value index[i, k] from before the expand.
In this way, index[i][j][k] does not change when j changes, so out[i, j, k] = input[i, j, index[i][j][k]] selects the same block for every j; that is, out[i, :, k] = input[i, :, b], where b is the pre-expand index[i, k]. gather therefore copies whole blocks.
reference
Pytorch.org/docs/stable…
Github.com/researchmm/…