class TrainDataset(Dataset):
    def __init__(self, idx, transform):
        self.img_labels = label
        path = ''
        img_dir = ''
       
        self.img_dir = img_dir
        print(len(self.img_dir))
       
        self.idx = idx
        self.transform = transform
       
       
    def __len__(self):
       
        return len(self.img_dir)
       
    def __getitem__(self, index):
        indexlist = [[0,1],[1,0]]
       
        image = torch.FloatTensor(np.load(self.img_dir[index]))
        ix = self.img_labels.loc[index][0]
        label1 = torch.Tensor(indexlist[ix])
        #label1 = torch.tensor(list(self.img_labels.index)[index])

        if self.transform:
            image = self.transform(image)
       
        return image, label1
 
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 100, collate_fn=collate_fn,shuffle=True )

'Pytorch Study > Machine Learning' 카테고리의 다른 글

batch_sampler  (0) 2023.04.21

+ Recent posts