کلاس datasets.ImageFolder را باید به صورت زیر بازنویسی کنید exception مدیریت کنید و برای تصاویری که قابل خواندن نیست یه تصویر سیاه ایجاد کنید تا پروسه خواندن batch به مشکل بر نخوره.
from torchvision import datasets
import torch
from PIL import Image
class CustomImageFolder(datasets.ImageFolder):
def __init__(self, root, transform=None, target_transform=None,
loader=datasets.folder.default_loader, is_valid_file=None):
super().__init__(root, transform=transform,
target_transform=target_transform,
loader=loader, is_valid_file=is_valid_file)
def __getitem__(self, index):
path, target = self.samples[index]
try:
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
except (PIL.UnidentifiedImageError, OSError) as e:
print(f"Unable to open image: {path}")
# Return a black image of the same size as the other images in the dataset
sample = Image.new("RGB", (224, 224))
if self.transform is not None:
sample = self.transform(sample)
return sample, target
# Use your custom Dataset class
image_datasets = CustomImageFolder(src_path, transform=data_transforms)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=self.batch_size,
shuffle=False, num_workers=4)