• toc {:toc}

Sequence 이용해 Custom Dataset 만들기

Pytorch의 Dataset을 상속받아 Custom Dataset을 만드는 방식과 유사하다.

Tensorflow 2.x 버전부터 custom dataset loader를 만드는 방법이 생겼다.

tensorflow.keras.utils.Sequence를 사용한다.

init 함수 정의

class CustomDataset(Sequence):
	def __init__(self, img, labels, batch_size=BATCH_SIZE, augmentor=None, shuffle=False):
	    self.img = img
	    self.labels = labels
	    self.batch_size = BATCH_SIZE
	    self.augmentor = augmentor
	    self.shuffle = shuffle
			
			if self.shuffle:
				self.on_epoch_end()
  • img : 이미지 파일이 있는 directory 경로, 이외의 각 픽셀값을 담는 numpy array의 경우에는 NonImplmentedError가 발생했다.
  • labels : 이미지 label을 담는다.

len 함수 정의

def __len__(self):
    return int(np.ceil(len(self.labels)/self.batch_size))
  • step이 몇 번 발생하는지를 의미한다.
  • 즉, 전체데이터가 60000이고 batch_size가 600이라면 100번 동안 step을 진행함을 의미한다.
  • np.ceil은 만약 batch_size가 599라면 100.xxxxx번 하는 것이 아닌 101번을 진행해야 하기 때문에 올림 처리로 사용한다.

getitem 함수 정의

def __getitem__(self, index):
    img_batch = self.img[index*self.batch_size:(index+1)*self.batch_size]
    if self.labels is not None:
        label_batch = self.labels[index*self.batch_size:(index+1)*self.batch_size]
 
    image_batch = np.zeros((img_batch.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3))
 
    for image_index in range(img_batch.shape[0]):
        image = cv2.cvtColor(cv2.imread(img_batch[image_index]), cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
        if self.augmentor is not None:
            image = self.augmentor(image=image)['image']
    
        img_batch[image_index] = image
   
    return img_batch, label_batch
  • index에 따라서 데이터에서 batch_size만큼 데이터를 가져오는 함수이다.
  • 테스트 세트의 경우 label이 없기 때문에 따로 label을 처리해준다.
  • img_batch가 가지고 있는 값이 directory path의 값이기 때문에 cv2를 통해 numpy array로 변경해 resize한다.
  • augmentor가 존재하는 경우 이미지 각각에 적용하고 img_batch에 저장한다.
  • img_batch와 label_batch를 반환해 iteration마다 batch를 가져오게 한다.

On_epoch_end 함수 정의

def on_epoch_end(self):
        if(self.shuffle):
            self.image_filenames, self.labels = sklearn.utils.shuffle(self.image_filenames, self.labels)
        else:
            pass
  • On_epoch_end 함수는 선택사항이다.
  • shuffle을 위해 사용한다. sklearn.utils.shuffle 을 사용해서 순서에 따라 shuffle한다.
  • sklearn.utils.shuffle() : 데이터를 동일한 순서로 섞어준다.