DNN

TensorFlow CIFAR10 Example 분석1, Write TFRecord

Small Octopus 2018. 4. 13. 22:26

cifar 10 estimator 예제를 살펴보자!

https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator

우선 READ.md부터 읽어보면  TF 1.2.1 이상버전을 설치하라고 한다.

그리고 cifar10 db를 받아서 TFRecord file을 만들라고 한다. TFRecord은 학습시 대량의 이미지를 

빠르게 읽을 수 있게해주는 포멧이다.

$python generate_cifar10_tfrecords.py --data-dir=${PWD}/cifar-10-data

을 실행하면 이미지를 다운로드하고 TFRecord도 만들어 준다는데 일단 이것부터 살펴보자


1. CIFAE 이미지 다운로드

 텐서플로우에서 다운로드는 function으로 제공한다.

import tensorflow as tf

tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir, CIFAR_DOWNLOAD_URL)

CIFAR_FILENAME = 'cifar-10-python.tar.gz'

CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME

data_dir = 저장되는 경로


2.  압축풀기

이렇게 다운 받아 보면 cifar-10-python.tar.gz 파일이 다운 되어 있고 압축을 풀려면 

import tarfile

import os

tarfile.open(os.path.join(data_dir, CIFAR_FILENAME),'r:gz').extractall(data_dir)

압축을 풀면 cifar-10-batches-py란 폴더안에 data_batch_1~5, batches.meta, test_batch 이렇게 7개의 데이터가 있다.(readme.html 빼고)

data_batch_1~4는 Train용, data_batch_5는 Validation용, test_batch는 Evaluation용.


3. pickle 풀기

현재 데이터들이 이상한 바이러니 파일 같이 보인다.. 그런데 이게 'pickle'을 이용해서 만들어진 것이다.

'pickle'을 이용해서 풀어보자, 일단 바이러니로 로드후 pickle을 이용 풀수 있다.

import pickle

f = tf.gfile.Open( './cifar-10-batches-py/data_batch_1', 'rb')

data_dict = pickle.load(f)

print data_dict['data'], data_dict['data'].shape

print data_dict['labels']data_dict['labels'].shape

이렇게 해보면 10000개의 이미지가 10000 x 3072의 형태로  정장되어있음을 알 수 있다.


4. TFRecord로 저장

이제 for loop을 10000번 돌면서 TFRecord로 저장하면 된다.

TFRecord로 파일 열기

tfr = tf.python_io.TFRecordWriter('./xxx.tfrecords')

TFRecord로 파일에 쓰기는 방법은 파이선 딕셔너리 형태를 사용하면된다.

feature = {

'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data[0].tobytes()]))

'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[0]]))

}

원래 data[0]는 numpy.ndarray, 넘파이 uint8 n dimensional array.. 였는데 이거를 바이트로 바꾸고 

또 텐서플로우의 바이트리스트로 바꾸고 그리고 또 다시 텐서의 Feature로 바꾼다... 으.... 

labels는 int 리스트 , labels[0]은 그냥 int 이고 이거를 다시 [labels[0]]로 바꾸고 텐서 Int64List로 바꾸고 

그리고 텐서 Feature로 바꾼다...

위에서 만든 딕셔너리를 tf.train.Example로 바꿔야한다.

example = tf.train.Example(features=tf.train.Features( feature=feature ))

위에서 만들었던 feature를 또 텐서의 tf.train.Features로 변환해서 드디어 tf.train.Example(로 변환

이제 드디어 TF Record에 쓴다.

tfr.write( example.SerializeToString())

다썻으면 닫아줘야지...

tfr.close()