Classification Examples
The classification examples use quantized starter models from TensorFlow Lite:
- i.MX 8M Plus:
  - mobilenet_v1_1.0_224_quant.tflite
  - mobilenet_v2_1.0_224_quant.tflite
Image Classification
Run the Image Classification Example on the i.MX 8M Plus
Retrieve the example and execute it on the SoM:
curl -LJO https://github.com/varigit/pyvar/raw/master/examples/ml/classification/image_classification_tflite.py
python3 image_classification_tflite.py
The output should be similar to the one below:

| Image Example | Image Example Classified |
|---|---|
# Copyright 2021 Variscite LTD
# SPDX-License-Identifier: BSD-3-Clause

"""
This script performs image classification using the TFLiteInterpreter engine.

It performs the following steps:

1. Retrieves the classification package using a HTTPS retriever instance.
2. Loads the labels from the label file.
3. Creates a TFLiteInterpreter engine instance and a resizer instance.
4. Resizes the input image to the engine's input size.
5. Runs inference and gets the result.
6. Creates an overlay instance and draws the output image with the
   classification result and other information.
7. Shows the output image using the Multimedia helper.

Example:

To run this script:
    $ python3 image_classification_tflite.py
    $ python3 image_classification_tflite.py --num_threads 4

Args:
--num_threads (int, optional): Number of threads passed to the
    TFLiteInterpreter engine. Defaults to 2.

Returns:
None. Exits with an error message if the classification package cannot
be retrieved.
"""

import sys
from argparse import ArgumentParser

from pyvar.ml.engines.tflite import TFLiteInterpreter
from pyvar.ml.utils.label import Label
from pyvar.ml.utils.overlay import Overlay
from pyvar.ml.utils.resizer import Resizer
from pyvar.ml.utils.retriever_https import HTTPS
from pyvar.multimedia.helper import Multimedia

# Thread count used when --num_threads is not given on the command line.
DEFAULT_NUM_THREADS = 2

https = HTTPS()
parser = ArgumentParser()
# Keep the default inside argparse so that a value passed on the command
# line is honored rather than overwritten after parsing.
parser.add_argument('--num_threads', type=int, default=DEFAULT_NUM_THREADS,
                    help="number of inference threads (default: %(default)s)")
args = parser.parse_args()

# retrieve_package() signals success through its return value. Without the
# package there is no model, label or image file to work with, so stop here
# with a clear message instead of failing later on an undefined name.
if not https.retrieve_package(category="classification"):
    sys.exit("Error: failed to retrieve the classification package.")

model_file_path = https.model
label_file_path = https.label
image_file_path = https.image

labels = Label(label_file_path)
labels.read_labels("classification")

engine = TFLiteInterpreter(model_file_path=model_file_path,
                           num_threads=args.num_threads)

# Configure the resizer with the input dimensions the model expects.
resizer = Resizer()
resizer.set_sizes(engine_input_details=engine.input_details)

image = Multimedia(image_file_path)
resizer.resize_image(image.video_src)

engine.set_input(resizer.image_resized)
engine.run_inference()
engine.get_result("classification")

draw = Overlay()

output_image = draw.info(category="classification",
                         image=resizer.image,
                         top_result=engine.result,
                         labels=labels.list,
                         inference_time=engine.inference_time,
                         model_name=model_file_path,
                         source_file=resizer.image_path)

image.show_image("TFLite: Image Classification", output_image)
Video Classification
Run the Video Classification Example on the i.MX 8M Plus
Retrieve the example and execute it on the SoM:
curl -LJO https://github.com/varigit/pyvar/raw/master/examples/ml/classification/video_classification_tflite.py
python3 video_classification_tflite.py
The output should be similar to the one below:

| Video Example | Video Example Classified |
|---|---|
# Copyright 2021 Variscite LTD
# SPDX-License-Identifier: BSD-3-Clause

"""
This script performs video classification using the TFLiteInterpreter engine.

It performs the following steps:

1. Retrieves the classification package using a HTTPS retriever instance.
2. Loads the labels from the label file.
3. Creates a TFLiteInterpreter engine instance and a resizer instance.
4. Resizes each frame of the input video to the engine's input size.
5. Runs inference and gets the result for each frame.
6. Creates an overlay instance and draws the output image with the
   classification result and other information for each frame.
7. Shows the output video using the Multimedia helper.

Example:

To run this script:
    $ python3 video_classification_tflite.py
    $ python3 video_classification_tflite.py --num_threads 4

Args:
--num_threads (int, optional): Number of threads passed to the
    TFLiteInterpreter engine. Defaults to 2.

Returns:
None. Exits with an error message if the classification package cannot
be retrieved.
"""

import sys
from argparse import ArgumentParser

from pyvar.ml.engines.tflite import TFLiteInterpreter
from pyvar.ml.utils.label import Label
from pyvar.ml.utils.overlay import Overlay
from pyvar.ml.utils.resizer import Resizer
from pyvar.ml.utils.retriever_https import HTTPS
from pyvar.multimedia.helper import Multimedia

# Thread count used when --num_threads is not given on the command line.
DEFAULT_NUM_THREADS = 2

https = HTTPS()
parser = ArgumentParser()
# Keep the default inside argparse so that a value passed on the command
# line is honored rather than overwritten after parsing.
parser.add_argument('--num_threads', type=int, default=DEFAULT_NUM_THREADS,
                    help="number of inference threads (default: %(default)s)")
args = parser.parse_args()

# retrieve_package() signals success through its return value. Without the
# package there is no model, label or video file to work with, so stop here
# with a clear message instead of failing later on an undefined name.
if not https.retrieve_package(category="classification"):
    sys.exit("Error: failed to retrieve the classification package.")

model_file_path = https.model
label_file_path = https.label
video_file_path = https.video

labels = Label(label_file_path)
labels.read_labels("classification")

engine = TFLiteInterpreter(model_file_path=model_file_path,
                           num_threads=args.num_threads)

# Configure the resizer with the input dimensions the model expects.
resizer = Resizer()
resizer.set_sizes(engine_input_details=engine.input_details)

video = Multimedia(video_file_path)
video.set_v4l2_config()

draw = Overlay()

# Always call destroy(), even if the loop is interrupted by an exception
# or Ctrl+C.
try:
    while video.loop:
        frame = video.get_frame()
        resizer.resize_frame(frame)

        engine.set_input(resizer.frame_resized)
        engine.run_inference()
        engine.get_result("classification")

        output_frame = draw.info(category="classification",
                                 image=resizer.frame,
                                 top_result=engine.result,
                                 labels=labels.list,
                                 inference_time=engine.inference_time,
                                 model_name=model_file_path,
                                 source_file=video.video_src)

        video.show("TFLite: Video Classification", output_frame)
finally:
    video.destroy()
Real-Time Classification
Run the Real-Time Classification Example
Retrieve the example and execute it on the SoM:
curl -LJO https://github.com/varigit/pyvar/raw/master/examples/ml/classification/realtime_classification_tflite.py
python3 realtime_classification_tflite.py
# Copyright 2021 Variscite LTD
# SPDX-License-Identifier: BSD-3-Clause

"""
This script performs real-time video classification using the TFLiteInterpreter
engine.

It performs the following steps:

1. Retrieves the classification package using a HTTPS retriever instance.
2. Loads the labels from the label file.
3. Creates a Multimedia instance with the video device path and resolution, and
   sets the Video4Linux2 (V4L2) configuration.
4. Creates a Framerate instance to calculate the frames per second (FPS) of the
   camera stream.
5. Creates a TFLiteInterpreter engine instance and a resizer instance.
6. Resizes each frame of the input video to the engine's input size.
7. Runs inference and gets the result for each frame.
8. Creates an overlay instance and draws the output image with the
   classification result and other information for each frame.
9. Shows the output video using the Multimedia helper.

Example:

To run this script:
    $ python3 realtime_classification_tflite.py
    $ python3 realtime_classification_tflite.py --camera /dev/video0 --num_threads 4

Args:
--num_threads (int, optional): Number of threads passed to the
    TFLiteInterpreter engine. Defaults to 2.
--camera (str, optional): Video capture device path. Defaults to
    /dev/video1.

Returns:
None. Exits with an error message if the classification package cannot
be retrieved.
"""

import sys
from argparse import ArgumentParser

from pyvar.ml.engines.tflite import TFLiteInterpreter
from pyvar.ml.utils.framerate import Framerate
from pyvar.ml.utils.label import Label
from pyvar.ml.utils.overlay import Overlay
from pyvar.ml.utils.resizer import Resizer
from pyvar.ml.utils.retriever_https import HTTPS
from pyvar.multimedia.helper import Multimedia

# Values used when the corresponding option is not given on the command line.
DEFAULT_NUM_THREADS = 2
DEFAULT_CAMERA = "/dev/video1"

https = HTTPS()
parser = ArgumentParser()
# Keep the defaults inside argparse so that values passed on the command
# line are honored rather than overwritten after parsing.
parser.add_argument('--num_threads', type=int, default=DEFAULT_NUM_THREADS,
                    help="number of inference threads (default: %(default)s)")
parser.add_argument('--camera', default=DEFAULT_CAMERA,
                    help="video capture device (default: %(default)s)")
args = parser.parse_args()

# retrieve_package() signals success through its return value. Without the
# package there is no model or label file to work with, so stop here with a
# clear message instead of failing later on an undefined name.
if not https.retrieve_package(category="classification"):
    sys.exit("Error: failed to retrieve the classification package.")

model_file_path = https.model
label_file_path = https.label

labels = Label(label_file_path)
labels.read_labels("classification")

engine = TFLiteInterpreter(model_file_path=model_file_path,
                           num_threads=args.num_threads)

# Configure the resizer with the input dimensions the model expects.
resizer = Resizer()
resizer.set_sizes(engine_input_details=engine.input_details)

camera = Multimedia(args.camera, resolution="vga")
camera.set_v4l2_config()

framerate = Framerate()

draw = Overlay()
draw.framerate_info = True

# Always call destroy(), even if the loop is interrupted by an exception
# or Ctrl+C.
try:
    while camera.loop:
        # fpsit() wraps one full capture/inference/display iteration so that
        # framerate.fps reflects the end-to-end loop rate.
        with framerate.fpsit():
            frame = camera.get_frame()
            resizer.resize_frame(frame)

            engine.set_input(resizer.frame_resized)
            engine.run_inference()
            engine.get_result("classification")

            output_frame = draw.info(category="classification",
                                     image=resizer.frame,
                                     top_result=engine.result,
                                     labels=labels.list,
                                     inference_time=engine.inference_time,
                                     model_name=model_file_path,
                                     source_file=camera.dev.name,
                                     fps=framerate.fps)

            camera.show("TFLite: Real Time Classification", output_frame)
finally:
    camera.destroy()