Tutorial Pembelajaran Transfer PyTorch dengan Contoh

Apa itu Pembelajaran Transfer?

Transfer Belajar adalah teknik menggunakan model terlatih untuk menyelesaikan tugas terkait lainnya. Ini adalah metode penelitian Pembelajaran Mesin yang menyimpan pengetahuan yang diperoleh saat memecahkan masalah tertentu dan menggunakan pengetahuan yang sama untuk memecahkan masalah lain yang berbeda namun terkait. Hal ini meningkatkan efisiensi dengan menggunakan kembali informasi yang dikumpulkan dari tugas yang dipelajari sebelumnya.

Penggunaan bobot model jaringan lain sangat populer untuk mengurangi waktu pelatihan karena Anda memerlukan banyak data untuk melatih model jaringan. Untuk mengurangi waktu pelatihan, Anda menggunakan jaringan lain dan bobotnya serta memodifikasi lapisan terakhir untuk menyelesaikan masalah kita. Keuntungannya adalah Anda bisa menggunakan kumpulan data kecil untuk melatih lapisan terakhir.

Selanjutnya dalam tutorial pembelajaran PyTorch Transfer ini, kita akan mempelajari cara menggunakan Transfer Learning dengan PyTorch.

Memuat Kumpulan Data

Memuat Kumpulan Data

Sumber: Alien vs. Predator Kaggle

Sebelum Anda mulai menggunakan Transfer Learning PyTorch, Anda perlu memahami dataset yang akan Anda gunakan. Dalam contoh Transfer Learning PyTorch ini, Anda akan mengklasifikasikan Alien dan Predator dari hampir 700 gambar. Untuk teknik ini, Anda tidak terlalu membutuhkan data dalam jumlah besar untuk melatihnya. Anda dapat mengunduh kumpulan data dari Kaggle: Alien vs. Predator.

Bagaimana Cara Menggunakan Pembelajaran Transfer?

Berikut adalah proses langkah demi langkah tentang cara menggunakan Transfer Learning untuk Deep Learning dengan PyTorch:

Langkah 1) Muat Data

Langkah pertama adalah memuat data kita dan melakukan beberapa transformasi pada gambar agar sesuai dengan kebutuhan jaringan.

Anda akan memuat data dari folder dengan torchvision.dataset. Modul akan melakukan iterasi dalam folder untuk membagi data untuk pelatihan dan validasi. Proses transformasi akan memotong gambar dari tengah, melakukan pembalikan horizontal, normalisasi, dan terakhir mengubahnya menjadi tensor menggunakan Deep Learning.

from __future__ import print_function, division
import os
import time
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

data_dir = "alien_pred"
input_shape = 224
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

#data transformation
data_transforms = {
   'train': transforms.Compose([
       transforms.CenterCrop(input_shape),
       transforms.ToTensor(),
       transforms.Normalize(mean, std)
   ]),
   'validation': transforms.Compose([
       transforms.CenterCrop(input_shape),
       transforms.ToTensor(),
       transforms.Normalize(mean, std)
   ]),
}

image_datasets = {
   x: datasets.ImageFolder(
       os.path.join(data_dir, x),
       transform=data_transforms[x]
   )
   for x in ['train', 'validation']
}

dataloaders = {
   x: torch.utils.data.DataLoader(
       image_datasets[x], batch_size=32,
       shuffle=True, num_workers=4
   )
   for x in ['train', 'validation']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'validation']}

print(dataset_sizes)
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Mari visualisasikan kumpulan data kita untuk PyTorch Transfer Learning. Proses visualisasi akan mendapatkan kumpulan gambar berikutnya dari pemuat dan label data kereta dan menampilkannya dengan matplot.

images, labels = next(iter(dataloaders['train']))

rows = 4
columns = 4
fig=plt.figure()
for i in range(16):
   fig.add_subplot(rows, columns, i+1)
   plt.title(class_names[labels[i]])
   img = images[i].numpy().transpose((1, 2, 0))
   img = std * img + mean
   plt.imshow(img)
plt.show()
Kumpulan Gambar
Kumpulan Gambar

Langkah 2) Tentukan Model

Dalam Belajar mendalam prosesnya, Anda akan menggunakan ResNet18 dari modul torchvision.

Anda akan menggunakan torchvision.models untuk memuat resnet18 dengan bobot yang telah dilatih sebelumnya yang ditetapkan menjadi Benar. Setelah itu, Anda akan membekukan lapisan-lapisan tersebut sehingga lapisan-lapisan ini tidak dapat dilatih. Anda juga memodifikasi lapisan terakhir dengan lapisan Linier agar sesuai dengan kebutuhan kita, yaitu 2 kelas. Anda juga menggunakan CrossEntropyLoss untuk fungsi kerugian multikelas dan untuk pengoptimal, Anda akan menggunakan SGD dengan laju pembelajaran 0.0001 dan momentum 0.9 seperti yang ditunjukkan dalam contoh Pembelajaran Transfer PyTorch di bawah ini.

## Load the model based on VGG19
vgg_based = torchvision.models.vgg19(pretrained=True)

## freeze the layers
for param in vgg_based.parameters():
   param.requires_grad = False

# Modify the last layer
number_features = vgg_based.classifier[6].in_features
features = list(vgg_based.classifier.children())[:-1] # Remove last layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
vgg_based.classifier = torch.nn.Sequential(*features)

vgg_based = vgg_based.to(device)

print(vgg_based)

criterion = torch.nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(vgg_based.parameters(), lr=0.001, momentum=0.9)

Struktur model keluaran

VGG(
  (features): Sequential(
	(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(1): ReLU(inplace)
	(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(3): ReLU(inplace)
	(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(6): ReLU(inplace)
	(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(8): ReLU(inplace)
	(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(11): ReLU(inplace)
	(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(13): ReLU(inplace)
	(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(15): ReLU(inplace)
	(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(17): ReLU(inplace)
	(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(20): ReLU(inplace)
	(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(22): ReLU(inplace)
	(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(24): ReLU(inplace)
	(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(26): ReLU(inplace)
	(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(29): ReLU(inplace)
	(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(31): ReLU(inplace)
	(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(33): ReLU(inplace)
	(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(35): ReLU(inplace)
	(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
	(0): Linear(in_features=25088, out_features=4096, bias=True)
	(1): ReLU(inplace)
	(2): Dropout(p=0.5)
	(3): Linear(in_features=4096, out_features=4096, bias=True)
	(4): ReLU(inplace)
	(5): Dropout(p=0.5)
	(6): Linear(in_features=4096, out_features=2, bias=True)
  )
)

Langkah 3) Latih dan Uji Model

Kami akan menggunakan beberapa fungsi dari Transfer Learning Tutorial PyTorch untuk membantu kami melatih dan mengevaluasi model kami.

def train_model(model, criterion, optimizer, num_epochs=25):
   since = time.time()

   for epoch in range(num_epochs):
       print('Epoch {}/{}'.format(epoch, num_epochs - 1))
       print('-' * 10)

       #set model to trainable
       # model.train()

       train_loss = 0

       # Iterate over data.
       for i, data in enumerate(dataloaders['train']):
           inputs , labels = data
           inputs = inputs.to(device)
           labels = labels.to(device)

           optimizer.zero_grad()
          
           with torch.set_grad_enabled(True):
               outputs  = model(inputs)
               loss = criterion(outputs, labels)

           loss.backward()
           optimizer.step()

           train_loss += loss.item() * inputs.size(0)

           print('{} Loss: {:.4f}'.format(
               'train', train_loss / dataset_sizes['train']))
          
   time_elapsed = time.time() - since
   print('Training complete in {:.0f}m {:.0f}s'.format(
       time_elapsed // 60, time_elapsed % 60))

   return model

def visualize_model(model, num_images=6):
   was_training = model.training
   model.eval()
   images_so_far = 0
   fig = plt.figure()

   with torch.no_grad():
       for i, (inputs, labels) in enumerate(dataloaders['validation']):
           inputs = inputs.to(device)
           labels = labels.to(device)

           outputs = model(inputs)
           _, preds = torch.max(outputs, 1)

           for j in range(inputs.size()[0]):
               images_so_far += 1
               ax = plt.subplot(num_images//2, 2, images_so_far)
               ax.axis('off')
               ax.set_title('predicted: {} truth: {}'.format(class_names[preds[j]], class_names[labels[j]]))
               img = inputs.cpu().data[j].numpy().transpose((1, 2, 0))
               img = std * img + mean
               ax.imshow(img)

               if images_so_far == num_images:
                   model.train(mode=was_training)
                   return
       model.train(mode=was_training)

Terakhir, dalam contoh Pembelajaran Transfer di PyTorch ini, mari kita mulai proses pelatihan dengan jumlah epoch yang disetel ke 25 dan evaluasi setelah proses pelatihan. Pada setiap langkah pelatihan, model akan mengambil masukan dan memprediksi keluaran. Setelah itu, keluaran prediksi akan diteruskan ke kriteria untuk menghitung kerugian. Kemudian kerugian akan melakukan perhitungan backprop untuk menghitung gradien dan terakhir menghitung bobot dan mengoptimalkan parameter dengan autograd.

Pada model visualisasi, jaringan terlatih akan diuji dengan sekumpulan gambar untuk memprediksi label. Kemudian akan divisualisasikan dengan bantuan matplotlib.

vgg_based = train_model(vgg_based, criterion, optimizer_ft, num_epochs=25)

visualize_model(vgg_based)

plt.show()

Langkah 4) Hasil

Hasil akhirnya adalah Anda mencapai akurasi 92%.

Epoch 23/24
----------
train Loss: 0.0044
train Loss: 0.0078
train Loss: 0.0141
train Loss: 0.0221
train Loss: 0.0306
train Loss: 0.0336
train Loss: 0.0442
train Loss: 0.0482
train Loss: 0.0557
train Loss: 0.0643
train Loss: 0.0763
train Loss: 0.0779
train Loss: 0.0843
train Loss: 0.0910
train Loss: 0.0990
train Loss: 0.1063
train Loss: 0.1133
train Loss: 0.1220
train Loss: 0.1344
train Loss: 0.1382
train Loss: 0.1429
train Loss: 0.1500
Epoch 24/24
----------
train Loss: 0.0076
train Loss: 0.0115
train Loss: 0.0185
train Loss: 0.0277
train Loss: 0.0345
train Loss: 0.0420
train Loss: 0.0450
train Loss: 0.0490
train Loss: 0.0644
train Loss: 0.0755
train Loss: 0.0813
train Loss: 0.0868
train Loss: 0.0916
train Loss: 0.0980
train Loss: 0.1008
train Loss: 0.1101
train Loss: 0.1176
train Loss: 0.1282
train Loss: 0.1323
train Loss: 0.1397
train Loss: 0.1436
train Loss: 0.1467
Training complete in 2m 47s

Selesai maka keluaran model kita akan divisualisasikan dengan matplot di bawah ini:

Divisualisasikan dengan Matplot
Divisualisasikan dengan Matplot

Ringkasan

Jadi, mari kita rangkum semuanya! Faktor pertama adalah PyTorch adalah framework deep learning yang sedang berkembang untuk pemula atau untuk tujuan penelitian. Ia menawarkan waktu komputasi yang tinggi, Dynamic Graph, dukungan GPU dan sepenuhnya ditulis dalam Python. Anda dapat menentukan modul jaringan Anda sendiri dengan mudah dan melakukan proses pelatihan dengan iterasi yang mudah. Jelas bahwa PyTorch sangat ideal bagi pemula untuk mengetahui deep learning dan bagi peneliti profesional sangat berguna dengan waktu komputasi yang lebih cepat dan juga fungsi autograd yang sangat membantu untuk membantu grafik dinamis.

Ringkaslah postingan ini dengan: