TensorFlow Lite — Flutter
The TensorFlow Lite Flutter plugin (tflite_flutter) provides a flexible and fast way to access the TensorFlow Lite interpreter and run inference. Its API is similar to the TFLite Java and Swift APIs. It binds directly to the TFLite C API, which keeps inference efficient and low-latency. It supports acceleration through the NNAPI and GPU delegates on Android, the Metal and CoreML delegates on iOS, and the XNNPACK delegate on desktop platforms.
Examples
The official tflite_flutter examples cover various models.
Let's look at some pieces of the object detection example to see how we can use the plugin.
Installation
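Add the dependency to your pubspec.yaml file:
dependencies:
  tflite_flutter: ^0.10.1
The code below also uses the image package (imported as img) to decode, resize and draw on images.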
Dart Code
First, we load the model and the labels from the assets:
import 'dart:io';

import 'package:flutter/services.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

// _interpreter, _labels, _modelPath and _labelPath are fields of the
// enclosing class; both asset paths must be declared in pubspec.yaml.

/// Load the TFLite model from assets
Future<void> _loadModel() async {
  print('Loading interpreter options...');
  final interpreterOptions = InterpreterOptions();

  // Use the XNNPACK delegate on Android
  if (Platform.isAndroid) {
    interpreterOptions.addDelegate(XNNPackDelegate());
  }

  // Use the Metal (GPU) delegate on iOS
  if (Platform.isIOS) {
    interpreterOptions.addDelegate(GpuDelegate());
  }

  print('Loading interpreter...');
  _interpreter =
      await Interpreter.fromAsset(_modelPath, options: interpreterOptions);
}

/// Load the labels from assets (one label per line)
Future<void> _loadLabels() async {
  print('Loading labels...');
  final labelsRaw = await rootBundle.loadString(_labelPath);
  _labels = labelsRaw.split('\n');
}
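Splitting on '\n' works for a typical labels file, but it keeps a trailing empty entry when the file ends with a newline, and it leaves a '\r' on every label if the file has Windows line endings. A slightly more defensive version is sketched below in plain Dart as a hypothetical parseLabels helper (not part of tflite_flutter). It only drops empty entries at the end, so class indices stay aligned with the model's output.

```dart
/// Hypothetical helper: splits the raw contents of a labels file into one
/// label per line.
List<String> parseLabels(String raw) {
  // trimRight() removes the '\r' left behind by Windows (\r\n) line endings
  // (and any trailing spaces on a label).
  final labels = raw.split('\n').map((line) => line.trimRight()).toList();
  // Drop empty entries only at the end, so class indices stay aligned.
  while (labels.isNotEmpty && labels.last.isEmpty) {
    labels.removeLast();
  }
  return labels;
}

void main() {
  print(parseLabels('person\r\nbicycle\r\ncar\r\n')); // [person, bicycle, car]
}
```

In _loadLabels this would become _labels = parseLabels(labelsRaw);.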
Next, we obtain the image as a Uint8List. Once we have it, we resize the image to the fixed input size the model expects (300x300 here), then convert it into a matrix representation of shape [300, 300, 3] (height, width, RGB channels). Finally, we run the matrix through the model to obtain the result.
import 'dart:io';

import 'package:image/image.dart' as img;

// Read the image bytes from the file
final imageData = File(imagePath).readAsBytesSync();

// Decode the image (decodeImage returns null for unsupported formats)
final image = img.decodeImage(imageData);

// Resize the image to the model's input size, [300, 300]
final imageInput = img.copyResize(
  image!,
  width: 300,
  height: 300,
);

// Create the matrix representation, [300, 300, 3]
final imageMatrix = List.generate(
  imageInput.height,
  (y) => List.generate(
    imageInput.width,
    (x) {
      final pixel = imageInput.getPixel(x, y);
      return [pixel.r, pixel.g, pixel.b];
    },
  ),
);

// Pass the imageMatrix to the model
final output = _runInference(imageMatrix);
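The nested list built above is simply height rows of width pixels, each holding three channel values. To check the indexing without Flutter, the same [height, width, 3] layout can be produced from a packed, row-major RGB byte buffer. This is a plain-Dart sketch: rgbBytesToMatrix and the packed-buffer input are assumptions for illustration. Its List<List<List<int>>> result is assignable to the List<List<List<num>>> parameter of _runInference.

```dart
import 'dart:typed_data';

/// Hypothetical helper: converts a packed row-major RGB buffer
/// (r, g, b, r, g, b, ...) into a [height][width][3] nested list.
List<List<List<int>>> rgbBytesToMatrix(Uint8List rgb, int width, int height) {
  if (rgb.length != width * height * 3) {
    throw ArgumentError(
        'Expected ${width * height * 3} bytes, got ${rgb.length}');
  }
  return List.generate(
    height,
    (y) => List.generate(width, (x) {
      final i = (y * width + x) * 3; // offset of pixel (x, y)
      return [rgb[i], rgb[i + 1], rgb[i + 2]];
    }),
  );
}

void main() {
  // A 2x1 image: one red pixel followed by one blue pixel.
  final matrix =
      rgbBytesToMatrix(Uint8List.fromList([255, 0, 0, 0, 0, 255]), 2, 1);
  print(matrix); // [[[255, 0, 0], [0, 0, 255]]]

  // Wrapping in one more list adds the batch dimension: [1, height, width, 3].
  final input = [matrix];
  print(input.length); // 1
}
```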
Running the ImageMatrix via Model Interpreter
List<List<Object>> _runInference(
  List<List<List<num>>> imageMatrix,
) {
  print('Running inference...');

  // Set input tensor [1, 300, 300, 3]
  final input = [imageMatrix];

  // Set output tensors (one buffer per output index):
  // Locations: [1, 10, 4]
  // Classes: [1, 10]
  // Scores: [1, 10]
  // Number of detections: [1]
  final output = {
    0: [List<List<num>>.filled(10, List<num>.filled(4, 0))],
    1: [List<num>.filled(10, 0)],
    2: [List<num>.filled(10, 0)],
    3: [0.0],
  };

  // The interpreter writes the results into the output buffers
  _interpreter!.runForMultipleInputs([input], output);
  return output.values.toList();
}
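Each entry in the output map has to match the shape of the corresponding output tensor. For a different model, a helper that allocates a zero-filled nested list for any shape, plus one that reads a shape back, can save some manual List.filled nesting. Both zeros and shapeOf below are hypothetical plain-Dart helpers, not part of tflite_flutter. They build rows with List.generate, so no two rows share the same list object. Whether the interpreter accepts these untyped (List<Object>) buffers in place of the typed lists above is an assumption worth checking against your tflite_flutter version.

```dart
/// Hypothetical helper: allocates a zero-filled nested list for a tensor
/// shape such as [1, 10, 4]. Each row is a separate list.
Object zeros(List<int> shape) {
  if (shape.isEmpty) {
    throw ArgumentError('shape must not be empty');
  }
  if (shape.length == 1) {
    return List<double>.filled(shape[0], 0.0);
  }
  return List.generate(shape[0], (_) => zeros(shape.sublist(1)));
}

/// Hypothetical helper: reads back the shape of a nested list by following
/// the first element at each level.
List<int> shapeOf(Object value) {
  final dims = <int>[];
  var current = value;
  while (current is List) {
    dims.add(current.length);
    if (current.isEmpty) break;
    current = current.first as Object;
  }
  return dims;
}

void main() {
  // Buffers matching the four outputs listed above.
  final output = {
    0: zeros([1, 10, 4]),
    1: zeros([1, 10]),
    2: zeros([1, 10]),
    3: zeros([1]),
  };
  // Prints: 0: [1, 10, 4] / 1: [1, 10] / 2: [1, 10] / 3: [1]
  output.forEach((index, buffer) => print('$index: ${shapeOf(buffer)}'));
}
```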
Next, we process the output: map each detection to a label and draw a bounding box around each detected object.
print('Processing outputs...');

// Locations: normalized [ymin, xmin, ymax, xmax], scaled to the 300x300 input
final locationsRaw = output.first.first as List<List<double>>;
final locations = locationsRaw.map((list) {
  return list.map((value) => (value * 300).toInt()).toList();
}).toList();
print('Locations: $locations');

// Classes
final classesRaw = output.elementAt(1).first as List<double>;
final classes = classesRaw.map((value) => value.toInt()).toList();
print('Classes: $classes');

// Scores
final scores = output.elementAt(2).first as List<double>;
print('Scores: $scores');

// Number of detections
final numberOfDetectionsRaw = output.last.first as double;
final numberOfDetections = numberOfDetectionsRaw.toInt();
print('Number of detections: $numberOfDetections');

print('Classifying detected objects...');
final List<String> classification = [];
for (var i = 0; i < numberOfDetections; i++) {
  classification.add(_labels![classes[i]]);
}

print('Outlining objects...');
for (var i = 0; i < numberOfDetections; i++) {
  // Only draw detections with a score above 0.6
  if (scores[i] > 0.6) {
    // Rectangle drawing
    img.drawRect(
      imageInput,
      x1: locations[i][1],
      y1: locations[i][0],
      x2: locations[i][3],
      y2: locations[i][2],
      color: img.ColorRgb8(255, 0, 0),
      thickness: 3,
    );

    // Label drawing
    img.drawString(
      imageInput,
      '${classification[i]} ${scores[i]}',
      font: img.arial14,
      x: locations[i][1] + 1,
      y: locations[i][0] + 1,
      color: img.ColorRgb8(255, 0, 0),
    );
  }
}

print('Done.');

// Encode the annotated image as JPEG bytes
final outputImage = img.encodeJpg(imageInput);
Now outputImage holds the JPEG-encoded image with the detected objects outlined and labeled with their scores.
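To make the coordinate handling easier to follow, and testable without Flutter, the same post-processing can be pulled into a plain-Dart function. This is a sketch: Detection and toDetections are hypothetical names. It assumes the SSD-style output layout used above, with boxes as normalized [ymin, xmin, ymax, xmax] and class indices stored as doubles. The index swap is the point to notice: x comes from indices 1 and 3, y from 0 and 2.

```dart
/// Hypothetical result type for one detected object.
class Detection {
  final String label;
  final double score;
  final int x1, y1, x2, y2; // pixel coordinates in the resized input image

  Detection(this.label, this.score, this.x1, this.y1, this.x2, this.y2);

  @override
  String toString() =>
      '$label ${score.toStringAsFixed(2)} ($x1, $y1) -> ($x2, $y2)';
}

/// Hypothetical helper mirroring the processing above.
List<Detection> toDetections({
  required List<List<double>> boxes,
  required List<double> classes,
  required List<double> scores,
  required int count,
  required List<String> labels,
  int inputSize = 300,
  double threshold = 0.6,
}) {
  final results = <Detection>[];
  for (var i = 0; i < count; i++) {
    if (scores[i] <= threshold) continue; // same cut-off as the loop above
    final box = boxes[i];
    results.add(Detection(
      labels[classes[i].toInt()],
      scores[i],
      (box[1] * inputSize).toInt(), // xmin
      (box[0] * inputSize).toInt(), // ymin
      (box[3] * inputSize).toInt(), // xmax
      (box[2] * inputSize).toInt(), // ymax
    ));
  }
  return results;
}

void main() {
  final detections = toDetections(
    boxes: [
      [0.25, 0.5, 0.75, 1.0],
      [0.0, 0.0, 1.0, 1.0],
    ],
    classes: [0.0, 1.0],
    scores: [0.9, 0.3],
    count: 2,
    labels: ['person', 'car'],
  );
  print(detections); // [person 0.90 (150, 75) -> (300, 225)]
}
```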
Run using the live camera feed
To run the same detection on a live camera feed, we can't simply repeat the steps above for every frame: that would put a lot of work on the main (UI) thread. Instead, we use isolates to create another thread and run all the detection-related code inside the newly created isolate.
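A minimal sketch of that pattern in plain Dart: one long-lived worker isolate is spawned once and reused for every frame, with SendPort/ReceivePort carrying frames in and results out. processFrame and FrameWorker are hypothetical names. processFrame only averages the bytes as a stand-in for the real decode, resize and inference steps, and the package's own live-detection example may be structured differently.

```dart
import 'dart:isolate';
import 'dart:typed_data';

/// Hypothetical stand-in for the detection work (decode, resize, infer).
/// Here it just averages the bytes so the example runs in plain Dart.
double processFrame(Uint8List frame) {
  if (frame.isEmpty) return 0;
  var sum = 0;
  for (final value in frame) {
    sum += value;
  }
  return sum / frame.length;
}

/// Entry point of the long-lived worker isolate.
void _workerMain(SendPort toMain) {
  final inbox = ReceivePort();
  // First message: our SendPort, so the main isolate can send us frames.
  toMain.send(inbox.sendPort);
  inbox.listen((message) {
    // Each request is [frameBytes, replyPort].
    final request = message as List;
    final frame = request[0] as Uint8List;
    final replyTo = request[1] as SendPort;
    replyTo.send(processFrame(frame));
  });
}

/// Hypothetical wrapper owning one background isolate reused for every frame.
class FrameWorker {
  Isolate? _isolate;
  SendPort? _toWorker;

  Future<void> start() async {
    final fromWorker = ReceivePort();
    _isolate = await Isolate.spawn(_workerMain, fromWorker.sendPort);
    // Wait for the worker's SendPort (first also closes fromWorker).
    _toWorker = await fromWorker.first as SendPort;
  }

  Future<double> process(Uint8List frame) async {
    final reply = ReceivePort();
    _toWorker!.send([frame, reply.sendPort]);
    // Wait for the single reply (first also closes the reply port).
    return await reply.first as double;
  }

  void stop() => _isolate?.kill(priority: Isolate.immediate);
}

Future<void> main() async {
  final worker = FrameWorker();
  await worker.start();
  final result = await worker.process(Uint8List.fromList([0, 100, 200]));
  print(result); // 100.0
  worker.stop();
}
```

Spawning once and reusing the isolate avoids paying the isolate start-up cost on every camera frame.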
First, let's see how we can get the camera feed.
Camera Package
Add the camera package to pubspec.yaml:
camera: any

import 'package:camera/camera.dart';

CameraController? _cameraController;
late List<CameraDescription> cameras;

Future<void> _initializeCamera() async {
  cameras = await availableCameras();
  // cameras[0] is usually the back camera
  _cameraController = CameraController(
    cameras[0],
    ResolutionPreset.medium,
    enableAudio: false,
  );
  await _cameraController!.initialize();
  // Deliver every frame to onLatestImageAvailable
  await _cameraController!.startImageStream(onLatestImageAvailable);
  setState(() {});

  // previewSize is the size of each image frame captured by the controller,
  // e.g. 352x288 on iOS and 240p (320x240) on Android with ResolutionPreset.low.
  // ScreenParams is a helper class defined elsewhere in the app.
  ScreenParams.previewSize = _cameraController!.value.previewSize!;
}

void onLatestImageAvailable(CameraImage cameraImage) async {
  // We can pass the cameraImage to the isolate to get it processed
}
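startImageStream calls onLatestImageAvailable for every frame, and frames can arrive faster than detection finishes. One common way to keep up is to drop frames while the previous one is still being processed. The sketch below shows that idea in plain Dart with a hypothetical FrameGate class. In the app, offer would be called from onLatestImageAvailable, and the work function would hand the frame to the isolate.

```dart
/// Hypothetical gate: runs [work] for a frame only if the previous frame
/// has finished, so slow inference drops frames instead of queueing them.
class FrameGate<T> {
  bool _busy = false;
  int dropped = 0;

  Future<void> offer(T frame, Future<void> Function(T) work) async {
    if (_busy) {
      dropped++; // previous frame still in progress: skip this one
      return;
    }
    _busy = true;
    try {
      await work(frame);
    } finally {
      _busy = false;
    }
  }
}

Future<void> main() async {
  final gate = FrameGate<int>();
  final processed = <int>[];

  // Simulated inference that takes 50 ms per frame.
  Future<void> slowWork(int frame) async {
    await Future<void>.delayed(const Duration(milliseconds: 50));
    processed.add(frame);
  }

  // Frames 2 and 3 arrive while frame 1 is still being processed.
  final first = gate.offer(1, slowWork);
  await gate.offer(2, slowWork);
  await gate.offer(3, slowWork);
  await first;

  print(processed); // [1]
  print(gate.dropped); // 2
}
```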
To get an Image from the CameraImage so that we can process it, we can use the image_utils file from the package's example repository.
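For reference, the core of such a conversion for YUV420 frames is sketched below in plain Dart. YUV420 is the format the camera plugin typically delivers on Android; on iOS it typically delivers BGRA8888. This is not the code from image_utils. It assumes the JFIF (full-range BT.601) equations, and the plane bytes and strides are passed as plain parameters. With the camera plugin they would come from cameraImage.planes[i].bytes, bytesPerRow and bytesPerPixel.

```dart
import 'dart:typed_data';

/// Rounds and clamps a channel value into the 0..255 byte range.
int _toByte(double value) => value.round().clamp(0, 255).toInt();

/// Hypothetical helper: converts YUV420 planes into a packed RGB buffer
/// (r, g, b, r, g, b, ...). U and V are subsampled 2x2, so each chroma
/// sample is shared by a 2x2 block of luma samples.
Uint8List yuv420ToRgb({
  required Uint8List yPlane,
  required Uint8List uPlane,
  required Uint8List vPlane,
  required int width,
  required int height,
  required int yRowStride,
  required int uvRowStride,
  required int uvPixelStride,
}) {
  final rgb = Uint8List(width * height * 3);
  var out = 0;
  for (var y = 0; y < height; y++) {
    for (var x = 0; x < width; x++) {
      final luma = yPlane[y * yRowStride + x];
      final uvIndex = (y ~/ 2) * uvRowStride + (x ~/ 2) * uvPixelStride;
      final u = uPlane[uvIndex] - 128;
      final v = vPlane[uvIndex] - 128;
      // JFIF (full-range BT.601) YCbCr to RGB
      rgb[out++] = _toByte(luma + 1.402 * v);
      rgb[out++] = _toByte(luma - 0.344136 * u - 0.714136 * v);
      rgb[out++] = _toByte(luma + 1.772 * u);
    }
  }
  return rgb;
}

void main() {
  // A 2x2 mid-gray frame: Y = 128 everywhere, neutral chroma (U = V = 128).
  final rgb = yuv420ToRgb(
    yPlane: Uint8List.fromList([128, 128, 128, 128]),
    uPlane: Uint8List.fromList([128]),
    vPlane: Uint8List.fromList([128]),
    width: 2,
    height: 2,
    yRowStride: 2,
    uvRowStride: 1,
    uvPixelStride: 1,
  );
  print(rgb); // [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
}
```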
Once we have the image, we can process it with the same workflow as before and send the result back from the isolate.
That brings us to the end of the code.
Hope you liked it, thanks for reading :)