class TrainDataset(Dataset):
    """Map-style dataset pairing ``.npy`` image files with two-element labels.

    Each item is ``(image, label)``: ``image`` is the array stored at
    ``img_dir[index]`` loaded as a float32 tensor, and ``label`` is the
    two-element float32 vector that ``_LABEL_VECTORS`` assigns to the class
    id found in the label table.
    """

    # Label vector per class id, kept exactly as in the lookup table this
    # class has always used: class 0 -> [0, 1], class 1 -> [1, 0].
    _LABEL_VECTORS = ((0, 1), (1, 0))

    def __init__(self, idx, transform, img_dir="", labels=None):
        """Store the sample sources and the optional transform.

        Args:
            idx: Stored as ``self.idx``; no method of this class reads it.
            transform: Callable applied to every loaded image tensor, or a
                falsy value (e.g. ``None``) to return images unchanged.
            img_dir: Indexable collection of ``.npy`` file paths, one per
                sample. Defaults to the empty string, which yields an
                empty dataset.
            labels: Label table looked up with ``.loc[index]``. When
                omitted, the module-level name ``label`` is used.
        """
        # NOTE(review): the fallback needs a module-level `label` object to
        # exist; it is not defined in this block -- confirm where it is set.
        self.img_labels = label if labels is None else labels
        self.img_dir = img_dir
        # Debug output: number of samples the dataset will report.
        print(len(self.img_dir))
        self.idx = idx
        self.transform = transform

    def __len__(self):
        """Return the number of entries in ``img_dir``."""
        return len(self.img_dir)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        Args:
            index: Position in ``img_dir`` and row label in ``img_labels``.

        Returns:
            Tuple of the (optionally transformed) float32 image tensor and
            the two-element float32 label tensor.
        """
        image = torch.as_tensor(
            np.load(self.img_dir[index]), dtype=torch.float32
        )
        # Explicit positional access: plain `row[0]` on a Series is
        # deprecated in pandas and fails for non-integer column labels.
        # Assumes img_labels is a DataFrame whose first column holds the
        # class id -- TODO confirm against the label file.
        class_id = self.img_labels.loc[index].iloc[0]
        label1 = torch.tensor(
            self._LABEL_VECTORS[class_id], dtype=torch.float32
        )
        if self.transform:
            image = self.transform(image)
        return image, label1
# Loader over `train_dataset`: batches of 100 samples, shuffling enabled,
# each batch assembled by `collate_fn` (both names are defined elsewhere).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=100,
    shuffle=True,
    collate_fn=collate_fn,
)
'Pytorch Study > Machine Learning' 카테고리의 다른 글
| batch_sampler (0) | 2023.04.21 |
|---|---|