The Beautiful Future
TensorFlow CIFAR10 Example 분석2, Read TFRecord 본문
TFRecord 을 읽는방법에 대해 살펴보자!
cifar10 예제에서는 class Cifar10DataSet(object):는 클래스에서
저장해 두었던 TFRecord 를 알맞게 읽어드리도록 설계되어있다.
CIFAR10을 이미 train.tfrecords, validation.tfrecords, eval.tfrecords로 이미 저장해놨고
필요에 따라 위 세계중 하나를 택해서 사용한다.
1. tfrecords 읽는 방법
dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
위는 tfrecords와의 인터페이스를 마련했다고 보면 될것 같다.
이제 데이터를 어떻게 읽을지 방법을 제공하는 parser와 몇개씩 읽을지 쓰레드 개수 , 버퍼크기를 map을 이용하여 지정.
dataset = dataset.map( self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)
parser는 클래스 함수, 쓰래드개수는 배치크기로 지정, 아웃 버퍼는 배치의 2배 크기로 지정
- 랜덤 셔플
현재 선택된 학습데이터셋의 개수에 40% + 배치크기 x 3 = 버퍼크기로 제공해서 셔플 적용
dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
실제 배치싸이즈 할당 이터레이터 생성 읽어오기
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()
2. parse 함수
parser를 하나의 함수로 만드는데 serialized_example이라는 인자를 받게 만듬
텐서플로우의 parse_single_example 매서드를 이용하여 parsing하면 디셔너리로 나오고
이것을 잘 디코드해서 사용
features = tf.parse_single_example(serialized_example,features={'image': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64),})
image = tf.decode_raw(features['image'], tf.uint8)
image.set_shape([DEPTH * HEIGHT * WIDTH])
image = tf.cast(tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]), tf.float32)
label = tf.cast(features['label'], tf.int32)
3. Preprocessing
현재 cifar10 데이터는 32 by 32 로 줄여져있는 상태인데 이것을 40 by 40으로 패딩한다.
그후에 랜덤하게 다시 32 by 32로 크랍, 그 후 좌우 플립
4. 실제 학습시 인풋
위에서 만든 배치를 tf.unstack으로 풀고 tf.parallel_stack으로 변경해서 넘겨줌
전체 배치크기를 shards(파편)의 개수 만큼 나눠서 할당
'DNN' 카테고리의 다른 글
Initialization of DNN (0) | 2019.10.01 |
---|---|
tensorflow obj detection (0) | 2018.06.04 |
TensorFlow CIFAR10 Example 분석1, Write TFRecord (0) | 2018.04.13 |
transpose convolution visualization (0) | 2018.02.05 |
caffe conv layer (0) | 2017.11.24 |