import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
class SiameseNet(nn.Module):
def __init__(self):
super(SiameseNet, self).__init__()
self.backbone = EfficientNet.from_pretrained('efficientnet-b0')
self.fc = nn.Linear(1000, 1)
def forward(self, x1, x2):
x1 = self.backbone(x1)
x2 = self.backbone(x2)
x = torch.abs(x1 - x2)
x = F.relu(x)
x = self.fc(x)
x = torch.sigmoid(x)
return x
در این مثال، ما از مدل EfficientNet-B0 به عنوانbackboneشبکه siamese مون استفاده می کنیم، و خروجی نهایی یک امتیاز شباهت بین دو ورودی است که به صورت مقداری بین 0 و 1 نمایش داده می شود. ورودی های x1 و x2 هر کدام از طریق همان بکبون EfficientNet، و سپس تفاوت بین ویژگی های آنها از یک لایه کاملاً متصل عبور داده می شود تا امتیاز شباهت نهایی به دست آید. توجه داشته باشید که این تنها یک پیاده سازی ممکن از یک شبکه siamese با EfficientNet است و ممکن است بهترین انتخاب برای مورد استفاده خاص شما نباشد.
import torch
import torchvision
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
class ImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = []
for class_dir in os.listdir(root_dir):
class_dir = os.path.join(root_dir, class_dir)
for img_path in os.listdir(class_dir):
self.images.append((os.path.join(class_dir, img_path), class_dir))
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path, class_dir = self.images[idx]
img = Image.open(img_path)
if self.transform:
img = self.transform(img)
label = os.path.basename(class_dir)
return img, label
def get_pair(dataset):
positive = torch.randint(0, len(dataset), (1,)).item()
negative = torch.randint(0, len(dataset), (1,)).item()
while dataset[positive][1] == dataset[negative][1]:
negative = torch.randint(0, len(dataset), (1,)).item()
return positive, negative
def contrastive_loss(output1, output2, label):
euclidean_distance = F.pairwise_distance(output1, output2)
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
(label) * torch.pow(torch.clamp(2 - euclidean_distance, min=0.0), 2))
return loss_contrastive
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = ImageDataset('./data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
model = SiameseNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
for epoch in range(100):
for i, (img, label) in enumerate(dataloader):
positive, negative = get_pair(dataset)
img1, label1 = dataset[positive]
img2, label2 = dataset[negative]
if label1 == label2:
label = torch.tensor([1], dtype=torch.float32)
else:
label = torch.tensor([0],
img1, img2 = img1.unsqueeze(0), img2.unsqueeze(0)
output1, output2 = model(img1, img2)
loss = contrastive_loss(output1, output2, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch {}/{} Loss: {:.4f}'.format(epoch+1, 100, loss.item()))
print('Training finished.')
در این حلقه، برای هر دسته از داده ها، به طور تصادفی یک نمونه مثبت و یک نمونه منفی از مجموعه داده انتخاب می کنیم. سپس این دو نمونه را از طریق شبکه siamese عبور می دهیم تا ویژگی های آنها را بدست آوریم، که سپس برای محاسبه contrastive_loss استفاده می شود. سپس گرادیان ها محاسبه شده و پارامترهای مدل با استفاده از بهینه ساز Adam به روز می شوند. ضایعات پس از هر دوره برای epoch پیشرفت آموزش چاپ می شود. توجه داشته باشید که این تنها یک راه ممکن برای آموزش شبکه siamese است و پارامترهایی مانند نرخ یادگیری و اندازه دسته ای را می توان در صورت نیاز تنظیم کرد.