IT実験のブログ

IT関連のツールの使い方など

freeze_graphツール、summarize_graphツール、label_imageツールを使ってみる

https://github.com/tensorflow/models/tree/r1.13.0/research/slim

こちらのサイトを参考にしながら、前回ビルドしたツールを使ってみます。

slimのディレクトリに移動します。

cd ~/AI/models/research/slim

download_and_convert_flowers.py を少し修正します。

gedit ./datasets/download_and_convert_flowers.py

f:id:itlab7:20200224140015p:plain

ダウンロードするファイルが消えないように、210行目をコメントアウトしました。

flowersのデータをダウンロードして、TFRecordフォーマットに変換します。

python3 download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir="${HOME}/AI/image_classifier/data/flowers"

f:id:itlab7:20200224140541p:plain

フォーマットが変換されました。ダウンロードしたファイルも残っています。 ダウンロードしたファイルの中には5種類の花の画像(jpg)が大量に入っているだけです。

inception v3 のチェックポイントをダウンロードしていきます。

mkdir ${HOME}/AI/image_classifier/checkpoints
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar -xvf inception_v3_2016_08_28.tar.gz
mv inception_v3.ckpt ${HOME}/AI/image_classifier/checkpoints
rm inception_v3_2016_08_28.tar.gz

f:id:itlab7:20200224141612p:plain

チェックポイントのダウンロードが完了しました。

GPUが使えない場合は、訓練用のプログラムを少し修正します。

gedit train_image_classifier.py

f:id:itlab7:20200224142447p:plain

43行目のclone_on_cpuをTrueにしました。 これでCPUでも訓練処理を実行できます。

訓練処理を実行します。ただし、今回はお試しなので、maxステップ数を20としています。

python3 train_image_classifier.py \
    --train_dir=${HOME}/AI/image_classifier/train_logs \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --dataset_dir=${HOME}/AI/image_classifier/data/flowers \
    --model_name=inception_v3 \
    --max_number_of_steps=20

f:id:itlab7:20200224144328p:plain

ステップ数20なので、数分で完了しました。

f:id:itlab7:20200224144457p:plain

チェックポイントとGraphDefが生成されています。

Tensorboardで見てみます。

tensorboard --logdir=${HOME}/AI/image_classifier/train_logs

f:id:itlab7:20200224144810p:plain

ステップ数20なので、Lossが大きいです。ステップ数を増やして、訓練をしっかりするとLossが小さくなります。

f:id:itlab7:20200224144933p:plain

GraphDefも生成されていたので、グラフも表示できました。

f:id:itlab7:20200224145450p:plain

グラフを拡大して眺めることもできますが、この辺は難しくて良くわかりません。

生成したチェックポイントの評価をします。お試しなので、maxバッチ数は10としています。

python3 eval_image_classifier.py \
    --alsologtostderr \
    --checkpoint_path=${HOME}/AI/image_classifier/train_logs/model.ckpt-20 \
    --dataset_dir=${HOME}/AI/image_classifier/data/flowers \
    --dataset_name=flowers \
    --dataset_split_name=validation \
    --model_name=inception_v3 \
    --max_num_batches=10

f:id:itlab7:20200224151022p:plain

評価結果が表示されました。お試しなので、あてになりませんが、
eval/Accuracy[0.214]
eval/Recall_5[1]
と表示されています。

freeze_graphツールを使って、GraphDefとチェックポイントを合わせて、FrozenGraphDef形式に変換します。

cd ~/AI/tensorflow
bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=${HOME}/AI/image_classifier/train_logs/graph.pbtxt \
  --input_checkpoint=${HOME}/AI/image_classifier/train_logs/model.ckpt-20 \
  --input_binary=false \
  --output_graph=${HOME}/AI/image_classifier/frozen_flowers.pb \
  --output_node_names=InceptionV3/Predictions/Reshape_1

f:id:itlab7:20200224172003p:plain

python2でビルドしたせいか、エラーが出ました。

python2 にもTensorflow 1.13.1を入れます。

pip2 install tensorflow==1.13.1

もう一度freeze_graphを実行します。

bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=${HOME}/AI/image_classifier/train_logs/graph.pbtxt \
  --input_checkpoint=${HOME}/AI/image_classifier/train_logs/model.ckpt-20 \
  --input_binary=false \
  --output_graph=${HOME}/AI/image_classifier/frozen_flowers.pb \
  --output_node_names=InceptionV3/Predictions/Reshape_1

f:id:itlab7:20200224172425p:plain

FrozenGraphDefが生成されました。

そのFrozenGraphDefに対して、summarize_graphツールを使ってみます。

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
  --in_graph=${HOME}/AI/image_classifier/frozen_flowers.pb

f:id:itlab7:20200224173046p:plain

frozen_flowers.pb に関する情報が表示されました。 freeze_graphツール使用時に指定したノード名も表示されています。

label_imageツールを使ってみます。

bazel-bin/tensorflow/examples/label_image/label_image \
  --image=${HOME}/AI/imagenet/sunflower.jpg \
  --input_layer=input \
  --output_layer=InceptionV3/Predictions/Reshape_1 \
  --graph=${HOME}/AI/image_classifier/frozen_flowers.pb \
  --labels=${HOME}/AI/image_classifier/data/flowers/labels.txt \
  --input_mean=0 \
  --input_std=255

f:id:itlab7:20200224173930p:plain

inputがどうのこうのと、なんかエラーが出ています。freeze_graphする時にinputのノードも指定してやらないといけないのでしょうか。この辺がまだ良くわかりません。 label_imageツール自体は使えているようです。

今回はここまでです。