The Beautiful Future
TensorFlow CIFAR10 Example 분석1, Write TFRecord 본문
cifar 10 estimator 예제를 살펴보자!
https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator
우선 READ.md부터 읽어보면 TF 1.2.1 이상버전을 설치하라고 한다.
그리고 cifar10 db를 받아서 TFRecord file을 만들라고 한다. TFRecord은 학습시 대량의 이미지를
빠르게 읽을 수 있게해주는 포멧이다.
$python generate_cifar10_tfrecords.py --data-dir=${PWD}/cifar-10-data
을 실행하면 이미지를 다운로드하고 TFRecord도 만들어 준다는데 일단 이것부터 살펴보자
1. CIFAE 이미지 다운로드
텐서플로우에서 다운로드는 function으로 제공한다.
import tensorflow as tf
tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir, CIFAR_DOWNLOAD_URL)
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
data_dir = 저장되는 경로
2. 압축풀기
이렇게 다운 받아 보면 cifar-10-python.tar.gz 파일이 다운 되어 있고 압축을 풀려면
import tarfile
import os
tarfile.open(os.path.join(data_dir, CIFAR_FILENAME),'r:gz').extractall(data_dir)
압축을 풀면 cifar-10-batches-py란 폴더안에 data_batch_1~5, batches.meta, test_batch 이렇게 7개의 데이터가 있다.(readme.html 빼고)
data_batch_1~4는 Train용, data_batch_5는 Validation용, test_batch는 Evaluation용.
3. pickle 풀기
현재 데이터들이 이상한 바이러니 파일 같이 보인다.. 그런데 이게 'pickle'을 이용해서 만들어진 것이다.
'pickle'을 이용해서 풀어보자, 일단 바이러니로 로드후 pickle을 이용 풀수 있다.
import pickle
f = tf.gfile.Open( './cifar-10-batches-py/data_batch_1', 'rb')
data_dict = pickle.load(f)
print data_dict['data'], data_dict['data'].shape
print data_dict['labels'], data_dict['labels'].shape
이렇게 해보면 10000개의 이미지가 10000 x 3072의 형태로 정장되어있음을 알 수 있다.
4. TFRecord로 저장
이제 for loop을 10000번 돌면서 TFRecord로 저장하면 된다.
TFRecord로 파일 열기
tfr = tf.python_io.TFRecordWriter('./xxx.tfrecords')
TFRecord로 파일에 쓰기는 방법은 파이선 딕셔너리 형태를 사용하면된다.
feature = {
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data[0].tobytes()]))
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[0]]))
}
원래 data[0]는 numpy.ndarray, 넘파이 uint8 n dimensional array.. 였는데 이거를 바이트로 바꾸고
또 텐서플로우의 바이트리스트로 바꾸고 그리고 또 다시 텐서의 Feature로 바꾼다... 으....
labels는 int 리스트 , labels[0]은 그냥 int 이고 이거를 다시 [labels[0]]로 바꾸고 텐서 Int64List로 바꾸고
그리고 텐서 Feature로 바꾼다...
위에서 만든 딕셔너리를 tf.train.Example로 바꿔야한다.
example = tf.train.Example(features=tf.train.Features( feature=feature ))
위에서 만들었던 feature를 또 텐서의 tf.train.Features로 변환해서 드디어 tf.train.Example(로 변환
이제 드디어 TF Record에 쓴다.
tfr.write( example.SerializeToString())
다썻으면 닫아줘야지...
tfr.close()
'DNN' 카테고리의 다른 글
tensorflow obj detection (0) | 2018.06.04 |
---|---|
TensorFlow CIFAR10 Example 분석2, Read TFRecord (0) | 2018.04.16 |
transpose convolution visualization (0) | 2018.02.05 |
caffe conv layer (0) | 2017.11.24 |
minimal filtering algorithm, Shmuel Winograd (0) | 2017.10.12 |