The Beautiful Future

TensorFlow CIFAR10 Example 분석1, Write TFRecord 본문

DNN

TensorFlow CIFAR10 Example 분석1, Write TFRecord

Small Octopus 2018. 4. 13. 22:26

cifar 10 estimator 예제를 살펴보자!

https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator

우선 READ.md부터 읽어보면  TF 1.2.1 이상버전을 설치하라고 한다.

그리고 cifar10 db를 받아서 TFRecord file을 만들라고 한다. TFRecord은 학습시 대량의 이미지를 

빠르게 읽을 수 있게해주는 포멧이다.

$python generate_cifar10_tfrecords.py --data-dir=${PWD}/cifar-10-data

을 실행하면 이미지를 다운로드하고 TFRecord도 만들어 준다는데 일단 이것부터 살펴보자


1. CIFAE 이미지 다운로드

 텐서플로우에서 다운로드는 function으로 제공한다.

import tensorflow as tf

tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir, CIFAR_DOWNLOAD_URL)

CIFAR_FILENAME = 'cifar-10-python.tar.gz'

CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME

data_dir = 저장되는 경로


2.  압축풀기

이렇게 다운 받아 보면 cifar-10-python.tar.gz 파일이 다운 되어 있고 압축을 풀려면 

import tarfile

import os

tarfile.open(os.path.join(data_dir, CIFAR_FILENAME),'r:gz').extractall(data_dir)

압축을 풀면 cifar-10-batches-py란 폴더안에 data_batch_1~5, batches.meta, test_batch 이렇게 7개의 데이터가 있다.(readme.html 빼고)

data_batch_1~4는 Train용, data_batch_5는 Validation용, test_batch는 Evaluation용.


3. pickle 풀기

현재 데이터들이 이상한 바이러니 파일 같이 보인다.. 그런데 이게 'pickle'을 이용해서 만들어진 것이다.

'pickle'을 이용해서 풀어보자, 일단 바이러니로 로드후 pickle을 이용 풀수 있다.

import pickle

f = tf.gfile.Open( './cifar-10-batches-py/data_batch_1', 'rb')

data_dict = pickle.load(f)

print data_dict['data'], data_dict['data'].shape

print data_dict['labels']data_dict['labels'].shape

이렇게 해보면 10000개의 이미지가 10000 x 3072의 형태로  정장되어있음을 알 수 있다.


4. TFRecord로 저장

이제 for loop을 10000번 돌면서 TFRecord로 저장하면 된다.

TFRecord로 파일 열기

tfr = tf.python_io.TFRecordWriter('./xxx.tfrecords')

TFRecord로 파일에 쓰기는 방법은 파이선 딕셔너리 형태를 사용하면된다.

feature = {

'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data[0].tobytes()]))

'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[0]]))

}

원래 data[0]는 numpy.ndarray, 넘파이 uint8 n dimensional array.. 였는데 이거를 바이트로 바꾸고 

또 텐서플로우의 바이트리스트로 바꾸고 그리고 또 다시 텐서의 Feature로 바꾼다... 으.... 

labels는 int 리스트 , labels[0]은 그냥 int 이고 이거를 다시 [labels[0]]로 바꾸고 텐서 Int64List로 바꾸고 

그리고 텐서 Feature로 바꾼다...

위에서 만든 딕셔너리를 tf.train.Example로 바꿔야한다.

example = tf.train.Example(features=tf.train.Features( feature=feature ))

위에서 만들었던 feature를 또 텐서의 tf.train.Features로 변환해서 드디어 tf.train.Example(로 변환

이제 드디어 TF Record에 쓴다.

tfr.write( example.SerializeToString())

다썻으면 닫아줘야지...

tfr.close()





'DNN' 카테고리의 다른 글

tensorflow obj detection  (0) 2018.06.04
TensorFlow CIFAR10 Example 분석2, Read TFRecord  (0) 2018.04.16
transpose convolution visualization  (0) 2018.02.05
caffe conv layer  (0) 2017.11.24
minimal filtering algorithm, Shmuel Winograd  (0) 2017.10.12
Comments