- toc {:toc}
Sequence ์ด์ฉํด Custom Dataset ๋ง๋ค๊ธฐ
Pytorch์ Dataset์ ์์๋ฐ์ Custom Dataset์ ๋ง๋๋ ๋ฐฉ์๊ณผ ์ ์ฌํ๋ค.
Tensorflow 2.x ๋ฒ์ ๋ถํฐ custom dataset loader๋ฅผ ๋ง๋๋ ๋ฐฉ๋ฒ์ด ์๊ฒผ๋ค.
tensorflow.keras.utils.Sequence๋ฅผ ์ฌ์ฉํ๋ค.
init ํจ์ ์ ์
class CustomDataset(Sequence):
def __init__(self, img, labels, batch_size=BATCH_SIZE, augmentor=None, shuffle=False):
self.img = img
self.labels = labels
self.batch_size = BATCH_SIZE
self.augmentor = augmentor
self.shuffle = shuffle
if self.shuffle:
self.on_epoch_end()- img : ์ด๋ฏธ์ง ํ์ผ์ด ์๋ directory ๊ฒฝ๋ก, ์ด์ธ์ ๊ฐ ํฝ์ ๊ฐ์ ๋ด๋ numpy array์ ๊ฒฝ์ฐ์๋ NonImplmentedError๊ฐ ๋ฐ์ํ๋ค.
- labels : ์ด๋ฏธ์ง label์ ๋ด๋๋ค.
len ํจ์ ์ ์
def __len__(self):
return int(np.ceil(len(self.labels)/self.batch_size))- step์ด ๋ช ๋ฒ ๋ฐ์ํ๋์ง๋ฅผ ์๋ฏธํ๋ค.
- ์ฆ, ์ ์ฒด๋ฐ์ดํฐ๊ฐ 60000์ด๊ณ batch_size๊ฐ 600์ด๋ผ๋ฉด 100๋ฒ ๋์ step์ ์งํํจ์ ์๋ฏธํ๋ค.
- np.ceil์ ๋ง์ฝ batch_size๊ฐ 599๋ผ๋ฉด 100.xxxxx๋ฒ ํ๋ ๊ฒ์ด ์๋ 101๋ฒ์ ์งํํด์ผ ํ๊ธฐ ๋๋ฌธ์ ์ฌ๋ฆผ ์ฒ๋ฆฌ๋ก ์ฌ์ฉํ๋ค.
getitem ํจ์ ์ ์
def __getitem__(self, index):
img_batch = self.img[index*self.batch_size:(index+1)*self.batch_size]
if self.labels is not None:
label_batch = self.labels[index*self.batch_size:(index+1)*self.batch_size]
image_batch = np.zeros((img_batch.shape[0], IMAGE_SIZE, IMAGE_SIZE, 3))
for image_index in range(img_batch.shape[0]):
image = cv2.cvtColor(cv2.imread(img_batch[image_index]), cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
if self.augmentor is not None:
image = self.augmentor(image=image)['image']
img_batch[image_index] = image
return img_batch, label_batch- index์ ๋ฐ๋ผ์ ๋ฐ์ดํฐ์์ batch_size๋งํผ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ ธ์ค๋ ํจ์์ด๋ค.
- ํ ์คํธ ์ธํธ์ ๊ฒฝ์ฐ label์ด ์๊ธฐ ๋๋ฌธ์ ๋ฐ๋ก label์ ์ฒ๋ฆฌํด์ค๋ค.
- img_batch๊ฐ ๊ฐ์ง๊ณ ์๋ ๊ฐ์ด directory path์ ๊ฐ์ด๊ธฐ ๋๋ฌธ์ cv2๋ฅผ ํตํด numpy array๋ก ๋ณ๊ฒฝํด resizeํ๋ค.
- augmentor๊ฐ ์กด์ฌํ๋ ๊ฒฝ์ฐ ์ด๋ฏธ์ง ๊ฐ๊ฐ์ ์ ์ฉํ๊ณ img_batch์ ์ ์ฅํ๋ค.
- img_batch์ label_batch๋ฅผ ๋ฐํํด iteration๋ง๋ค batch๋ฅผ ๊ฐ์ ธ์ค๊ฒ ํ๋ค.
On_epoch_end ํจ์ ์ ์
def on_epoch_end(self):
if(self.shuffle):
self.image_filenames, self.labels = sklearn.utils.shuffle(self.image_filenames, self.labels)
else:
pass- On_epoch_end ํจ์๋ ์ ํ์ฌํญ์ด๋ค.
- shuffle์ ์ํด ์ฌ์ฉํ๋ค. sklearn.utils.shuffle ์ ์ฌ์ฉํด์ ์์์ ๋ฐ๋ผ shuffleํ๋ค.
- sklearn.utils.shuffle() : ๋ฐ์ดํฐ๋ฅผ ๋์ผํ ์์๋ก ์์ด์ค๋ค.