파이썬 이것저것/파이썬 딥러닝 관련

[Python] Tensorflow 커스텀 데이터셋 불러오기

agingcurve 2022. 7. 17. 13:55
반응형

Tensorflow 커스텀 데이터셋 불러오기

 

 

 

 

대용량의 데이터셋은 한번에 메모리에 불러오는 것이 불가능

대용량의 데이터셋을 학습에 사용할 경우에는 해당 데이터를 사용할 때만 메모리에 불러오는 방법을 사용

tensorflow에서는 이 과정을 수행하는 함수를 제공

tensorflow에서 제공하는 ImageDataGenerator는 데이터셋을 불러오는 기능과 데이터 증강을 적용하는 기능을 제공

 

사전에 dataset폴더에 train과 val을 구분하여 데이터를 만들어준다.

train 과 val에 각각 폴더를 나눔

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
data_path = "dataset"

batch_size = 2
img_height = 180
img_width = 180

def get_dataset(path,datagen):    # path의 데이터를 datagen(=ImageDataGenerator)로 불러와주는 함수
    data_set = datagen.flow_from_directory(path,
                                             target_size = (img_width, img_height),
                                             batch_size = batch_size,
                                             class_mode = 'categorical')
    return data_set


first_gen = ImageDataGenerator() # 1-1: 매개변수를 사용하지 않고 ImageDataGenerator를 first_gen에 저장합니다.
first_set = get_dataset(os.path.join(data_path, "val"), first_gen) #1-2 검증 데이터의 경로를 입력합니다.
x,y = first_set.__next__()

print("\n1. 데이터 제너레이터 만들기")
print("first_set")
print("x:",x.shape, "y:",y.shape)
print(x[0][0][0]) # 픽셀이 0~255의 값을 가짐

print("\n2. 데이터 제너레이터에 전처리 추가하기")
second_gen = ImageDataGenerator(rescale = 1/255) # 2-1 픽셀값을 0~1의 값으로 만들도록 매개변수를 추가하세요
second_set = get_dataset(os.path.join(data_path, "val"), second_gen) # 2-2 검증 데이터의 경로를 입력합니다.
x,y = second_set.__next__()    
print("second_set")
print("x:",x.shape, "y:",y.shape)
print(x[0][0][0]) # 픽셀이 0~1의 값을 가지는 것을 확인하세요



# 실제 학습을 위한 제너레이터 작성
print("\n3. 실제 학습을 위한 제너레이터 작성")
# 3-1 학습 데이터를 불러오는 validation_set을 완성하세요
train_gen = ImageDataGenerator(rescale=1/255)
training_set = get_dataset(os.path.join(data_path, "train"), train_gen)

# 3-2 검증 데이터를 불러오는 validation_set을 완성하세요
val_gen = ImageDataGenerator(rescale=1/255)
validation_set = get_dataset(os.path.join(data_path, "val"), train_gen)

print ("학습 데이터의 길이: ",len(training_set))
print ("검증 데이터의 길이: ",len(validation_set))