Hi. In the previous post a kind member mentioned that when our training data is small, one useful technique is oversampling. Has anyone done this for classification? I use the code below for training; my rough idea of how oversampling could plug into it is sketched after the code.
# Module-level imports this method relies on (added for completeness):
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

def train(self, data_dir, model_name="squeezenet", num_classes=4, batch_size=8, num_epochs=15,
          feature_extract=False, pre_checkpoint='', input_size_=(224, 224), checkpoints_path='./checkpoints'):
    self.data_dir = data_dir
    self.model_name = model_name
    self.num_classes = num_classes
    self.batch_size = batch_size
    self.num_epochs = num_epochs
    self.feature_extract = feature_extract
    if pre_checkpoint == '':
        # No checkpoint given: start from a pretrained model
        model_ft, input_size = self._initialize_model(use_pretrained=True)
    else:
        # Resume from a previously saved model checkpoint
        model_ft = torch.load(pre_checkpoint)
        input_size = input_size_
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomRotation(degrees=(-2, 2), expand=True, fill=(255, 255, 255)),
            transforms.RandomResizedCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            # transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
print("Initializing Datasets and Dataloaders...")
image_datasets = {x: datasets.ImageFolder(os.path.join(self.data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders_dict = {
dataloaders_dict = {
x: torch.utils.data.DataLoader(image_datasets[x], batch_size=self.batch_size, shuffle=True, num_workers=4) for x in
['train', 'val']}
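    # To oversample here, a WeightedRandomSampler could be passed to the train
    # DataLoader in place of shuffle=True (see the sketch after this code).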
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = model_ft.to(self.device)

    # Collect the parameters to optimize: only the newly added layers when
    # feature extracting, otherwise every parameter is fine-tuned.
    params_to_update = model_ft.parameters()
    print("Params to learn:")
    if self.feature_extract:
        params_to_update = []
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
                print("\t", name)
    else:
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                print("\t", name)
    optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
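    # Aside (my assumption, not from the original code): with imbalanced
    # classes, a class-weighted loss is a lightweight alternative to
    # oversampling; class_counts here would be a hypothetical per-class
    # count list.
    # class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
    # criterion = nn.CrossEntropyLoss(weight=class_weights.to(self.device))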
    model_ft, hist = self._train_model(model_ft, dataloaders_dict, criterion, optimizer_ft,
                                       num_epochs=self.num_epochs,
                                       is_inception=(self.model_name == "inception"),
                                       checkpoints_path=checkpoints_path)

    # Optionally train a non-pretrained copy of the model for comparison
    scratch_flag = False
    if scratch_flag:
        self.feature_extract = False
        scratch_model, _ = self._initialize_model(use_pretrained=False)
        scratch_model = scratch_model.to(self.device)
        scratch_optimizer = optim.SGD(scratch_model.parameters(), lr=0.001, momentum=0.9)
        scratch_criterion = nn.CrossEntropyLoss()
        _, scratch_hist = self._train_model(scratch_model, dataloaders_dict, scratch_criterion,
                                            scratch_optimizer, num_epochs=self.num_epochs,
                                            is_inception=(self.model_name == "inception"),
                                            checkpoints_path=checkpoints_path)
    # Plot validation accuracy vs. number of training epochs for the
    # transfer-learning run (and the from-scratch run, when it was trained).
    ohist = [h.cpu().numpy() for h in hist]
    plt.title("Validation Accuracy vs. Number of Training Epochs")
    plt.xlabel("Training Epochs")
    plt.ylabel("Validation Accuracy")
    plt.plot(range(1, self.num_epochs + 1), ohist, label="Pretrained")
    if scratch_flag:
        shist = [h.cpu().numpy() for h in scratch_hist]
        plt.plot(range(1, self.num_epochs + 1), shist, label="Scratch")
    plt.ylim((0, 1.))
    plt.xticks(np.arange(1, self.num_epochs + 1, 1.0))
    plt.legend()
    plt.show()
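For the oversampling itself, what I had in mind is something like the sketch below with torch.utils.data.WeightedRandomSampler, which draws minority-class images more often instead of duplicating files on disk. The inverse-frequency weighting and the 'data/train' path are just my placeholders, not part of the code above:

from collections import Counter

import torch
from torchvision import datasets, transforms

# Build the train set the same way as above; the path is a placeholder
# for os.path.join(self.data_dir, 'train').
train_dataset = datasets.ImageFolder('data/train', transforms.ToTensor())

# ImageFolder keeps the class index of every image in .targets.
class_counts = Counter(train_dataset.targets)

# Weight each sample inversely to its class frequency, so rare classes
# are drawn about as often as common ones on average.
sample_weights = [1.0 / class_counts[t] for t in train_dataset.targets]

# Sampling with replacement lets minority-class images repeat within
# one epoch, which is what oversampling means here.
sampler = torch.utils.data.WeightedRandomSampler(sample_weights,
                                                 num_samples=len(sample_weights),
                                                 replacement=True)

# A custom sampler replaces shuffle=True in the DataLoader.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8,
                                           sampler=sampler, num_workers=4)

Since the train transforms above already include random rotation and RandomResizedCrop, the repeated minority images would not be exact duplicates, so this seems like a reasonable fit. Does this look right, or is there a better way?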