Austin's_Lab

[Tensorflow] Training data를 load하는 방법과 model을 save하다 발생한 문제 본문

Machine Learning

[Tensorflow] Training data를 load하는 방법과 model을 save하다 발생한 문제

Ausome_(Austin_is_Awesome) 2017. 3. 16. 22:09

-Tensorflow 0.12.0 사용-


Tensorflow에서 training image data를 load하기 위해서는 tf.train.string_input_producer()를 이용해 queue에 넣어줘야 한다. 외에도 TFRecord를 쓴다던지, binary file을 만들어 쓰거나 csv file format을 쓰는 등 다양한 방법이 있지만, 처음 Tensorflow를 접했을 땐 tf.train.string_input_producer를 찾은 것만으로도 굉장한 수확이었다.


Data가 많아질수록 TFRecord를 사용하는게 훨씬 빠르다고는 하던데 학습 모델을 만들고 hyper parameter 조절이나 데이터 전처리 작업 연구에 집중하다보니 TFRecord를 만들고 사용하는 법에 대해 찾아볼 시간이 없었다. 여태까지는 주어진 데이터가 워낙 적어서 tf.train.string_input_producer로도 빠른 학습이 가능했었지만,  Data augmentation과 전처리 작업을 진행하고 보니 Training data가 너무 많아져서 학습속도가 미친듯이 느려졌기 때문에 이제는 TFRecord를 써야할 것 같다. 조만간 TFRecord에 대해 공부해보고, 속도 차이를 비교해봐도 재밌을 것 같다.

(열심히 Data augmentation을 하고보니 tf.image에 data augmentation에 대한 함수들이 있었다. insert TODO.)


MNIST나 Cifar같은 오픈된 training set이 아닌 자연의 training data를 load해서 쓰면서 발생했던 실수에 대해 말하고자 한다.


Training data들은 각 class별로 서로 다른 directory에 있다. 지금은 모든 image data의 절대경로와 label을 한 줄에 한 쌍씩 짝지어서 텍스트 파일로 만들어 data를 load하고 있지만(이것도 csv 할줄 몰라서 어영부영 흉내낸..), 처음 data load 방법을 알았을 때는 정말 무식하게 각 directory를 통째로 넘겨줌으로써 data를 queue에 넣었다. 

이런식으로 image를 load해서 batch로 묶은 뒤 학습을 진행했다.(함수로 만들어 사용)

학습을 마치고 tf.saver로 모델과 변수들을 저장한뒤, 다시 restore해서 training data와 전혀 다른 distribution의 wild image로 테스트를 진행했는데 놀랍게도 1-2%정도의 에러율을 보였다. 이후 에러가 나는 이미지가 어떤 특징을 갖고있는 지 확인해보기위해 pyplot으로 해당 이미지를 띄워보니 화면의 나타난 이미지는 test에 쓰인 wild image가 아니라 cross validation을 위해 load했던 training data중의 하나였다. 이상함을 느끼고 test data를 다른 이미지로 바꿔서 load해봤지만 결과는 마찬가지로 training data가 화면에 띄워졌다. 결론적으로 wild image가 아니라 training data를 그대로 가져다 테스트를 진행한 셈이다. 뭐 overfitting은 확실하다는 거니까 그걸로라도 작은 위안을 삼았다.

Training phase에서 load한 데이터가 tf.saver에 함께 저장돼버려서 다시 restore할 때 그대로 queue에 들어가는 것 같았다. 정말 그런지 확인해보기위해 training phase에서 validation data를 wild image로 load해서 학습을 진행하고 test phase에서 training data를 load해보니 역시 training phase에서 load되었던 데이터가 그대로 queue에 들어갔다. wild image를 통해 확인한 학습 결과는 말할 것도 없이 참담했다.


여기저기 물어보고 찾아본 결과 tf.saver나 queue 둘 중 하나의 문제인 것 같았다. 구글에 널려있는 model save & restore 예제들에서는 tf.saver에 대한 다른 특수한 옵션을 주지 않았었기 때문에 아마도 queue의 문제일 것이라고 생각했다. 그래서 training phase에서 validation data를 쓰지 않고 학습을 시킨 뒤 test phase에서 test data를 load하려고 하니 matched file이 없다는 에러가 났다. 에러 메시지가 묘하게 익숙한 느낌이어서 코드를 천천히 살펴보니 tf.train.string_input_producer에 데이터를 전달해주는 일을 tf.train.match_filenames_once()라는 함수가 하고있었다. 혹시나 해서 해당 함수를 이렇게 바꿨더니 문제가 해결되었다.


역시나 tf.train.match_filenames_once가 문제였다. tf.saver가 저 함수도 같이 저장해버리나보다. 어쩐지 once라는 단어가 찝찝하긴했다. Tensorflow 공식 홈페이지에서는 tf.train.match_filenames_once를 쓰라고 설명해줬었는데. 그냥 본인들이 만든 TFRecord 형식 쓰라는 말인가보다.

다음엔 데이터들을 텍스트 파일로 저장해서 관리하고, load하는 방법에 대해서 올려야겠다.


Comments