일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |
- Homomorphic Filter
- tf.train.match_filenames_once()
- porting
- Raspberry Pi
- Python
- Facial expression recognition
- ARM Processor
- deep-learning
- tf.train.string_input_producer()
- preprocessing
- I.MX6Q
- IOT
- tf.saver()
- TensorFlow
- cross compile
- Embedded System
- OpenCV
- Data Load
- VGGnet
- CNN
- Machine Vision
- Machine learning
- Today
- Total
Austin's_Lab
[Tensorflow] Training data를 load하는 방법과 model을 save하다 발생한 문제 본문
[Tensorflow] Training data를 load하는 방법과 model을 save하다 발생한 문제
Ausome_(Austin_is_Awesome) 2017. 3. 16. 22:09-Tensorflow 0.12.0 사용-
Tensorflow에서 training image data를 load하기 위해서는 tf.train.string_input_producer()를 이용해 queue에 넣어줘야 한다. 외에도 TFRecord를 쓴다던지, binary file을 만들어 쓰거나 csv file format을 쓰는 등 다양한 방법이 있지만, 처음 Tensorflow를 접했을 땐 tf.train.string_input_producer를 찾은 것만으로도 굉장한 수확이었다.
Data가 많아질수록 TFRecord를 사용하는게 훨씬 빠르다고는 하던데 학습 모델을 만들고 hyper parameter 조절이나 데이터 전처리 작업 연구에 집중하다보니 TFRecord를 만들고 사용하는 법에 대해 찾아볼 시간이 없었다. 여태까지는 주어진 데이터가 워낙 적어서 tf.train.string_input_producer로도 빠른 학습이 가능했었지만, Data augmentation과 전처리 작업을 진행하고 보니 Training data가 너무 많아져서 학습속도가 미친듯이 느려졌기 때문에 이제는 TFRecord를 써야할 것 같다. 조만간 TFRecord에 대해 공부해보고, 속도 차이를 비교해봐도 재밌을 것 같다.
(열심히 Data augmentation을 하고보니 tf.image에 data augmentation에 대한 함수들이 있었다. insert TODO.)
MNIST나 Cifar같은 오픈된 training set이 아닌 자연의 training data를 load해서 쓰면서 발생했던 실수에 대해 말하고자 한다.
Training data들은 각 class별로 서로 다른 directory에 있다. 지금은 모든 image data의 절대경로와 label을 한 줄에 한 쌍씩 짝지어서 텍스트 파일로 만들어 data를 load하고 있지만(이것도 csv 할줄 몰라서 어영부영 흉내낸..), 처음 data load 방법을 알았을 때는 정말 무식하게 각 directory를 통째로 넘겨줌으로써 data를 queue에 넣었다.
여기저기 물어보고 찾아본 결과 tf.saver나 queue 둘 중 하나의 문제인 것 같았다. 구글에 널려있는 model save & restore 예제들에서는 tf.saver에 대한 다른 특수한 옵션을 주지 않았었기 때문에 아마도 queue의 문제일 것이라고 생각했다. 그래서 training phase에서 validation data를 쓰지 않고 학습을 시킨 뒤 test phase에서 test data를 load하려고 하니 matched file이 없다는 에러가 났다. 에러 메시지가 묘하게 익숙한 느낌이어서 코드를 천천히 살펴보니 tf.train.string_input_producer에 데이터를 전달해주는 일을 tf.train.match_filenames_once()라는 함수가 하고있었다. 혹시나 해서 해당 함수를 이렇게 바꿨더니 문제가 해결되었다.