The Beautiful Future

TensorFlow CIFAR10 Example 분석2, Read TFRecord 본문

DNN

TensorFlow CIFAR10 Example 분석2, Read TFRecord

Small Octopus 2018. 4. 16. 11:09

TFRecord 을 읽는방법에 대해 살펴보자!

cifar10 예제에서는 class Cifar10DataSet(object):는 클래스에서 

저장해 두었던 TFRecord 를 알맞게 읽어드리도록 설계되어있다.

CIFAR10을 이미 train.tfrecords, validation.tfrecords, eval.tfrecords로 이미 저장해놨고

필요에 따라 위 세계중 하나를 택해서 사용한다.


1. tfrecords 읽는 방법

dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()

위는 tfrecords와의 인터페이스를 마련했다고 보면 될것 같다.

이제 데이터를 어떻게 읽을지 방법을 제공하는 parser와 몇개씩 읽을지 쓰레드 개수 , 버퍼크기를 map을 이용하여 지정.

 dataset = dataset.map( self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)

parser는 클래스 함수, 쓰래드개수는 배치크기로 지정, 아웃 버퍼는 배치의 2배 크기로 지정

- 랜덤 셔플

현재 선택된 학습데이터셋의 개수에 40% + 배치크기 x 3 = 버퍼크기로 제공해서 셔플 적용

dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)

실제 배치싸이즈 할당 이터레이터 생성 읽어오기

dataset = dataset.batch(batch_size)

iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()


2. parse 함수

parser를 하나의 함수로 만드는데 serialized_example이라는 인자를 받게 만듬

텐서플로우의 parse_single_example 매서드를 이용하여 parsing하면 디셔너리로 나오고 

이것을 잘 디코드해서 사용 

features = tf.parse_single_example(serialized_example,features={'image': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64),})

image = tf.decode_raw(features['image'], tf.uint8)

image.set_shape([DEPTH * HEIGHT * WIDTH])

image = tf.cast(tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]), tf.float32)

label = tf.cast(features['label'], tf.int32)


3. Preprocessing

현재 cifar10 데이터는 32 by 32 로 줄여져있는 상태인데 이것을 40 by 40으로 패딩한다.

그후에 랜덤하게 다시 32 by 32로 크랍, 그 후 좌우 플립


4. 실제 학습시 인풋

위에서 만든 배치를 tf.unstack으로 풀고 tf.parallel_stack으로 변경해서 넘겨줌

전체 배치크기를 shards(파편)의 개수 만큼 나눠서 할당

'DNN' 카테고리의 다른 글

Initialization of DNN  (0) 2019.10.01
tensorflow obj detection  (0) 2018.06.04
TensorFlow CIFAR10 Example 분석1, Write TFRecord  (0) 2018.04.13
transpose convolution visualization  (0) 2018.02.05
caffe conv layer  (0) 2017.11.24
Comments