Introduction to domain adaptation

Domain adaptation is one of the most common problems in transfer learning. The source and target domains share the same task, but the source domain data is labeled while the target domain data has no labels (or only very few). By projecting the features of the source and target domains into a similar feature space, the target domain can be classified with the classifier trained on the source domain.

Consider the binary classification illustrated in the figure below. The red circle is the source domain, the blue circle is the target domain, and the circles and crosses are data with different characteristics. The source-domain classifier separates the source-domain data into two categories, as shown by the dotted line. If this classifier is applied directly to the target domain, the figure shows that the result is very poor. One remedy is to align the distributions of the source and target domains: as shown on the right of the figure, once the two distributions are similar (that is, data with similar features lie in similar locations), the target domain can be classified directly by the source-domain classifier.

Adversarial domain adaptation resembles the training process of a generative adversarial network (GAN). Two models are trained at the same time: a target domain feature extractor MT, and a domain discriminator D that judges whether a feature comes from the source domain or the target domain. MT is trained to maximize D's error, that is, to extract features that D cannot tell apart from those of the source domain.

The target domain feature extractor MT and the domain discriminator D are rivals: D learns to distinguish whether features come from the source domain or the target domain, while MT learns to make its extracted features look like those extracted from the source domain. MT can be thought of as a team of counterfeiters trying to produce fakes and pass them undetected, while D plays the police trying to detect the counterfeit money. Competition in this game drives both sides to refine their methods until the counterfeits are indistinguishable from the genuine article.
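Written out in code, this adversarial game boils down to two loss terms. Below is a minimal sketch, assuming D outputs two logits (0 = source, 1 = target) and f_s, f_t are features produced by MS and MT; the helper function names here are mine and do not appear in the code later in this post.

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()

    def discriminator_loss(D, f_s, f_t):
        # D is trained to output 0 for source features and 1 for target features
        logits = D(torch.cat((f_s, f_t), 0).detach())
        labels = torch.cat((torch.zeros(len(f_s)), torch.ones(len(f_t)))).long()
        return criterion(logits, labels)

    def target_extractor_loss(D, f_t):
        # MT is trained so that D mistakes its features for source features,
        # i.e. the target features are given the "source" label 0
        logits = D(f_t)
        labels = torch.zeros(len(f_t)).long()
        return criterion(logits, labels)

Minimizing the second loss with respect to MT's parameters (while D is held fixed) is exactly the "maximize D's error" step described above.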

Adversarial domain adaptation

Data selection

To get good results and keep training easy, I use the digits 0 and 1 from the MNIST dataset as the source domain and the digits 2 and 3 as the target domain, with 10000 samples in each domain. During training, the source domain provides both data and labels, while the target domain provides only data without labels, to simulate the domain adaptation setting. The target-domain labels are used only to measure accuracy.
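The post does not show how the two subsets are built, so here is a minimal sketch of one way to do it with torchvision. The make_loader helper, the batch size, and the relabelling of the digits to class indices 0/1 are my own assumptions rather than the original code.

    from torch.utils.data import DataLoader, Subset
    from torchvision import datasets, transforms

    def make_loader(digits, batch_size=128, train=True):
        """A DataLoader containing only the given MNIST digits, relabelled to 0/1."""
        mnist = datasets.MNIST(root='./data', train=train, download=True,
                               transform=transforms.ToTensor(),
                               # assumption: map e.g. digits 2/3 to class indices 0/1
                               target_transform=lambda t: digits.index(t))
        idx = [i for i, t in enumerate(mnist.targets) if int(t) in digits]
        return DataLoader(Subset(mnist, idx), batch_size=batch_size, shuffle=True)

    loader_ms = make_loader((0, 1))  # source domain: digits 0 and 1
    loader_mt = make_loader((2, 3))  # target domain: digits 2 and 3, labels kept only for evaluation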

Network

1. Source domain feature extractor MS and target domain feature extractor MT. The so-called feature extractor is simply an MNIST classification network with its final classification layer removed.

		(encoder): Sequential (
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (2): ReLU ()
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): ReLU ()
    )
    (fc1): Linear (64 * 4 * 4 -> 512)

Think of the output of this network as the extracted features.
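For reference, the printed structure above corresponds roughly to the module sketched below. The get_w/update_w helpers are my guess at what the np.save/np.load calls in the training code expect (a plain dict round-trip of the state_dict); they are not shown in the original post.

    import torch
    import torch.nn as nn

    class Encoder(nn.Module):
        """Feature extractor: an MNIST classifier with the last classification layer removed."""
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=5),   # 1x28x28 -> 32x24x24
                nn.MaxPool2d(2),                   # -> 32x12x12
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=5),  # -> 64x8x8
                nn.MaxPool2d(2),                   # -> 64x4x4
                nn.ReLU(),
            )
            self.fc1 = nn.Linear(64 * 4 * 4, 512)

        def forward(self, x):
            x = self.encoder(x)
            return self.fc1(x.view(x.size(0), -1))

        # assumed helpers: serialise the weights to/from a plain dict so they can be
        # stored with np.save and restored with np.load(...).item()
        def get_w(self):
            return {k: v.cpu().numpy() for k, v in self.state_dict().items()}

        def update_w(self, w):
            self.load_state_dict({k: torch.as_tensor(v) for k, v in w.items()})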

2. Classifier C. This is simply the final classification layer of the MNIST network, a single fully connected layer.

		Classifier (
    (fc2): Linear (512 -> 2)
    )

3. Domain discriminator D. Based on the output of the feature extractor, it judges whether the data comes from the source domain or the target domain: output 0 represents the source domain, and output 1 represents the target domain.

		Discriminator (
     (layer): Sequential (
    (0): Linear (512 -> 512)
    (1): Linear (512 -> 512)
    (2): Linear (512 -> 2)
    ))
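For completeness, here is a minimal sketch of the remaining two modules, matching the printed structures above (the discriminator is reproduced exactly as printed, i.e. three linear layers with no activations in between); as with Encoder, get_w/update_w would follow the same state_dict pattern.

    import torch.nn as nn

    class Classifier(nn.Module):
        """Final classification layer: 512-d feature -> 2 classes."""
        def __init__(self):
            super().__init__()
            self.fc2 = nn.Linear(512, 2)

        def forward(self, x):
            return self.fc2(x)

    class Discriminator(nn.Module):
        """Judges whether a 512-d feature comes from the source (0) or target (1) domain."""
        def __init__(self):
            super().__init__()
            self.layer = nn.Sequential(
                nn.Linear(512, 512),
                nn.Linear(512, 512),
                nn.Linear(512, 2),
            )

        def forward(self, x):
            return self.layer(x)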

 

The training process

Training MS, C

First, the feature extractor MS and classifier C are trained on the source domain. The procedure is like an ordinary training loop, except that the network is split into two parts (MS and C), each with its own optimizer.

    def train_MS_C(loader_ms):
        # torch.optim as optim, torch.nn as nn, numpy as np and the params config
        # are assumed to be imported at module level, as in the full source file
        o_ms = optim.SGD(MS.parameters(), lr=0.03)
        o_c = optim.SGD(C.parameters(), lr=0.03)
        criterion = nn.CrossEntropyLoss()  # calculate loss
        for j in range(1):
            print(j)
            # train
            for i, (images, labels) in enumerate(loader_ms):
                o_ms.zero_grad()
                o_c.zero_grad()
                outputs_mid = MS(images)
                outputs = C(outputs_mid)
                loss = criterion(outputs, labels)
                loss.backward()
                o_ms.step()  # optimize parameters
                o_c.step()
                if i % 100 == 0:
                    print(i)
                    print('current loss: %.5f' % loss.data.item())
        # save the models
        np.save(params.MS_save_dir, MS.get_w())
        np.save(params.C_save_dir, C.get_w())

After training, the accuracy on the source domain is 0.9985. If the target domain is classified directly with the source-domain feature extractor and classifier, the accuracy is only 0.5840.

Fix MS and C, train MT and D

Next, MS and C are fixed, that is, their network weights are no longer updated, and the target domain feature extractor MT and the domain discriminator D are trained adversarially on the source and target domains. MT is initialized with MS's weights, so the target domain already starts at the 0.5840 accuracy above; training from this starting point converges toward a good solution more easily and more quickly.

MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())

    def train_MT_D(loader_ms, loader_mt):
        MS = Encoder()
        MT = Encoder()
        D = Discriminator()
        # load the trained source extractor; its weights stay fixed
        MS.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())
        if params.first_train:
            # first round: initialize MT with the weights of MS
            MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())
        else:
            # later rounds: continue from the previously saved MT and D
            MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())
            D.update_w(np.load(params.D_save_dir, encoding='bytes', allow_pickle=True).item())
        # optimizers
        o_mt = optim.SGD(MT.parameters(), lr=0.00001)
        o_d = optim.SGD(D.parameters(), lr=0.00001)
        criterion = nn.CrossEntropyLoss()  # calculate loss
        # train
        for j in range(1):
            print(j)
            data_zip = zip(loader_ms, loader_mt)
            for i, ((images_s, labels_s), (images_t, labels_t)) in enumerate(data_zip):
                ################ train the domain discriminator D ################
                # extract features
                f_s = MS(images_s)
                f_t = MT(images_t)
                f_cat = torch.cat((f_s, f_t), 0)
                out_D = D(f_cat.detach())
                predicts_D = torch.max(out_D, 1)[1]
                if i == 0:
                    # domain labels: 0 for the source domain, 1 for the target domain
                    len_s = len(labels_s)
                    len_t = len(labels_t)
                    temp1 = torch.zeros(len_s)
                    temp2 = torch.ones(len_t)
                    lab_D = torch.cat((temp1, temp2), 0).long()
                # reset gradients
                o_d.zero_grad()
                # calculate loss
                loss_D = criterion(out_D, lab_D)
                # back propagation
                loss_D.backward()
                # optimize the network
                o_d.step()
                ################ train the target feature extractor MT ################
                # extract the target features again
                f_t = MT(images_t)
                out_MT = D(f_t)
                predicts_MT = torch.max(out_MT, 1)[1]
                # MT wants D to label its features as "source" (0)
                lab_MT = torch.zeros(len_t).long()
                # reset gradients
                o_mt.zero_grad()
                # calculate loss
                loss_MT = criterion(out_MT, lab_MT)
                loss_MT.backward()
                o_mt.step()
                if i % 100 == 0:
                    print(i)
                    print('current loss_D : %.5f' % loss_D.data.item())
                    print('current loss_MT : %.5f' % loss_MT.data.item())
        # save the models
        np.save(params.MT_save_dir, MT.get_w())
        np.save(params.D_save_dir, D.get_w())

Classify the target domain using MT and C

Finally, the trained feature extractor MT and classifier C are used to classify the target domain.

    def test_MT_C(loader_mt):
        MT = Encoder()
        C = Classifier()
        # load the trained weights
        MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())
        C.update_w(np.load(params.C_save_dir, encoding='bytes', allow_pickle=True).item())
        correct = 0
        for images, labels in loader_mt:
            outputs_mid = MT(images)
            outputs = C(outputs_mid)
            _, predicts = torch.max(outputs.data, 1)
            correct += (predicts == labels).sum()
        total = len(loader_mt.dataset)
        print('MT+C Accuracy: %.4f' % (1.0 * correct / total))
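Putting the pieces together, a rough end-to-end driver might look like the sketch below. It assumes the make_loader helper from the data-selection sketch earlier, that MS and C are the module-level models used inside train_MS_C, and that params is the configuration module from the original code.

    MS = Encoder()
    C = Classifier()

    loader_ms = make_loader((0, 1))   # source domain: digits 0 and 1, with labels
    loader_mt = make_loader((2, 3))   # target domain: digits 2 and 3, labels only for evaluation

    train_MS_C(loader_ms)             # step 1: train MS and C on the source domain
    train_MT_D(loader_ms, loader_mt)  # step 2: fix MS and C, adversarially train MT and D
    test_MT_C(loader_mt)              # step 3: classify the target domain with MT + C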

The experimental results

With the source-domain feature extractor and classifier, the accuracy on the target domain is only 0.5840. The figure below shows the output of the domain discriminator D: the first half of the inputs are source-domain features and the second half are target-domain features. At this stage, D judges most of them correctly.

After a few rounds of training, the accuracy rises a little. D's ability to tell the domains apart decreases, and most target-domain inputs are judged to be from the source domain.

After 40 rounds of training, the accuracy fluctuates around 0.9, much better than the initial 0.5840. D can no longer distinguish the source domain from the target domain and identifies all inputs as coming from the source domain.

The code address

momodel.cn/explore/5f1…

References

Adversarial Discriminative Domain Adaptation: blog.csdn.net/sinat_29381…
github.com/corenel/pyt…