YOLOv8モデルをTensorFlowで動作させるには、PyTorchでトレーニング済みのモデルをTensorFlow Lite(tflite)形式に変換し、独自の入力前処理・推論後処理を実装する必要があります。
YOLOv8は、ultralyticsのライブラリを使うと簡単に変換できます。
import yaml
from ultralytics import YOLO
# yolo8n.pt to tensorflow lite
model = YOLO('yolov8n.pt')
model.export(format='tflite')
# metadata.yaml to labels.txt
metadata = yaml.load(open('yolov8n_saved_model/metadata.yaml', 'r'), Loader=yaml.SafeLoader)
labels = metadata['names']
with open('labels.txt', 'w') as f:
for label in labels:
f.write(labels[label] + '\n')
これで、yolov8n_saved_model/yolov8n_float[16or32].tflite というTensorFlow Liteモデルと、labels.txtが生成されます。
tfliteモデルにデータを入力する際には、正確な前処理が必要です。
特殊な入力形式
YOLOv8のtfliteモデルは、BRGの順序で、640×640の画像データを入力する必要があります。
前処理コード
以下は、画像をtfliteモデルに適した形式に変換する例です。
// 画像をリサイズ。スレッドで実行
final cmd = (Command()
..image(image)
..copyResize(
width: tensorWidth, // tfliteの入力幅(640など)
height: tensorHeight, // tfliteの入力の高さ(640など)
interpolation: Interpolation.nearest,
));
var resizedImage = await cmd.getImageThread();
// BRGの順に並び替え、正規化した配列を取得
var tensor = Float32List(resizedImage.width * resizedImage.height * 3);
var i = 0;
for (var y = 0; y [];
for (var e = 0; e maxConfidence) {
maxClass = c;
maxConfidence = confidence;
}
}
// 確信度が閾値を超えていない場合はスキップ
if (maxClass == -1) {
continue;
}
// 検出結果を追加
final rect = _getRectFromTensor(t, e);
if (rect.left 1 || rect.top 1) {
// 画像の範囲外の場合はスキップ
continue;
}
results.add(Detection(
label: _labels[maxClass],
labelId: maxClass,
confidence: maxConfidence,
rect: rect,
));
}
return results;
}
/// テンソルから指定したクラスの信頼度を取得
/// @param t テンソル
/// @param e 要素のインデックス
/// @param c クラスのインデックス
double _getConfidenceFromTensor(Float32List t, int e, int c) {
return t[e + (c + 4) * numElements];
}
/// テンソルから中心座標を取得
///
/// @param t テンソル
/// @param e 要素のインデックス
Rect _getRectFromTensor(Float32List t, int e) {
return Rect.fromCenter(
center: Offset(
t[e],
t[e + numElements],
),
width: t[e + 2 * numElements],
height: t[e + 3 * numElements]);
}
class Detection {
final String label;
final int labelId;
final double confidence;
final Rect rect;
Detection({
required this.label,
required this.labelId,
required this.confidence,
required this.rect,
});
Rect scaledRect(double width, double height) {
return Rect.fromLTRB(
rect.left * width,
rect.top * height,
rect.right * width,
rect.bottom * height,
);
}
}
全体のコードサンプルは、こちらにアップしています。
https://github.com/kobesoft-inc/yolov8_flutter_example
これで高精度なYOLOv8モデルをスマートフォン上で実行することができました。
スマホで各種物体検出を行うAIアプリの開発なら、ぜひ一度、ご相談ください。