Classification Examples

The classification examples use quantized starter models from TensorFlow Lite:

  • i.MX 8M Plus:
    • mobilenet_v1_1.0_224_quant.tflite

    • mobilenet_v2_1.0_224_quant.tflite

Image Classification

Run the Image Classification Example on MPlus

  1. Retrieve the example, and execute it on the SoM:

curl -LJO https://github.com/varigit/pyvar/raw/master/examples/ml/classification/image_classification_tflite.py
python3 image_classification_tflite.py
  2. The output should be similar to the one below:

Image Example

Image Example Classified

car-plus

car-converted-plus

Image Classification Example Source Code for MPlus: image_classification_tflite.py
 1# Copyright 2021 Variscite LTD
 2# SPDX-License-Identifier: BSD-3-Clause
 3
 4"""
 5This script performs image classification using the TFLiteInterpreter engine.
 6
 7It performs the following steps:
 8
 91. Retrieves the classification package using a HTTPS retriever instance.
102. Loads the labels from the label file.
113. Creates an TFLiteInterpreter engine instance and a resizer instance.
124. Resizes the input image to the engine's input size.
135. Runs inference and gets the result.
146. Creates an overlay instance and draws the output image with the
15   classification result and other information.
167. Shows the output image using the Multimedia helper.
17
18Example:
19
20To run this script:
21    $ python3 image_classification_tflite.py
22
23Args:
24None.
25
26Returns:
27None.
28"""
29
30from argparse import ArgumentParser
31
32from pyvar.ml.engines.tflite import TFLiteInterpreter
33from pyvar.ml.utils.label import Label
34from pyvar.ml.utils.overlay import Overlay
35from pyvar.ml.utils.resizer import Resizer
36from pyvar.ml.utils.retriever_https import HTTPS
37from pyvar.multimedia.helper import Multimedia
38
39https = HTTPS()
40parser = ArgumentParser()
41parser.add_argument('--num_threads', type=int)
42args = parser.parse_args()
43args.num_threads = 2
44
45if https.retrieve_package(category="classification"):
46    model_file_path = https.model
47    label_file_path = https.label
48    image_file_path = https.image
49
50labels = Label(label_file_path)
51labels.read_labels("classification")
52
53engine = TFLiteInterpreter(model_file_path=model_file_path,
54                           num_threads=args.num_threads)
55
56resizer = Resizer()
57resizer.set_sizes(engine_input_details=engine.input_details)
58
59image = Multimedia(image_file_path)
60resizer.resize_image(image.video_src)
61
62engine.set_input(resizer.image_resized)
63engine.run_inference()
64engine.get_result("classification")
65
66draw = Overlay()
67
68output_image = draw.info(category="classification",
69                         image=resizer.image,
70                         top_result=engine.result,
71                         labels=labels.list,
72                         inference_time=engine.inference_time,
73                         model_name=model_file_path,
74                         source_file=resizer.image_path)
75
76image.show_image("TFLite: Image Classification", output_image)



Video Classification

Run the Video Classification Example on MPlus

  1. Retrieve the example, and execute it on the SoM:

curl -LJO https://github.com/varigit/pyvar/raw/master/examples/ml/classification/video_classification_tflite.py
python3 video_classification_tflite.py
  2. The output should be similar to the one below:

Video Example

Video Example Classified

street-plus

street-classified-plus

Video Classification Example Source Code for MPlus: video_classification_tflite.py
 1# Copyright 2021 Variscite LTD
 2# SPDX-License-Identifier: BSD-3-Clause
 3
 4"""
 5This script performs video classification using the TFLiteInterpreter engine.
 6
 7It performs the following steps:
 8
 91. Retrieves the classification package using a HTTPS retriever instance.
102. Loads the labels from the label file.
113. Creates an TFLiteInterpreter engine instance and a resizer instance.
124. Resizes each frame of the input video to the engine's input size.
135. Runs inference and gets the result for each frame.
146. Creates an overlay instance and draws the output image with the
15   classification result and other information for each frame.
167. Shows the output video using the Multimedia helper.
17
18Example:
19
20To run this script:
21    $ python3 video_classification_tflite.py
22
23Args:
24None.
25
26Returns:
27None.
28"""
29
30from argparse import ArgumentParser
31
32from pyvar.ml.engines.tflite import TFLiteInterpreter
33from pyvar.ml.utils.label import Label
34from pyvar.ml.utils.overlay import Overlay
35from pyvar.ml.utils.resizer import Resizer
36from pyvar.ml.utils.retriever_https import HTTPS
37from pyvar.multimedia.helper import Multimedia
38
39https = HTTPS()
40parser = ArgumentParser()
41parser.add_argument('--num_threads', type=int)
42args = parser.parse_args()
43args.num_threads = 2
44
45if https.retrieve_package(category="classification"):
46    model_file_path = https.model
47    label_file_path = https.label
48    video_file_path = https.video
49
50labels = Label(label_file_path)
51labels.read_labels("classification")
52
53engine = TFLiteInterpreter(model_file_path=model_file_path,
54                           num_threads=args.num_threads)
55
56resizer = Resizer()
57resizer.set_sizes(engine_input_details=engine.input_details)
58
59video = Multimedia(video_file_path)
60video.set_v4l2_config()
61
62draw = Overlay()
63
64while video.loop:
65    frame = video.get_frame()
66    resizer.resize_frame(frame)
67
68    engine.set_input(resizer.frame_resized)
69    engine.run_inference()
70    engine.get_result("classification")
71
72    output_frame = draw.info(category="classification",
73                             image=resizer.frame,
74                             top_result=engine.result,
75                             labels=labels.list,
76                             inference_time=engine.inference_time,
77                             model_name=model_file_path,
78                             source_file=video.video_src)
79
80    video.show("TFLite: Video Classification", output_frame)
81
82video.destroy()



Real Time Classification

Run the Real Time Classification Example

  1. Retrieve the example, and execute it on the SoM:

curl -LJO https://github.com/varigit/pyvar/raw/master/examples/ml/classification/realtime_classification_tflite.py
python3 realtime_classification_tflite.py
Real Time Classification Example Source Code: realtime_classification_tflite.py
 1# Copyright 2021 Variscite LTD
 2# SPDX-License-Identifier: BSD-3-Clause
 3
 4"""
 5This script performs real-time video classification using the TFLiteInterpreter
 6engine.
 7
 8It performs the following steps:
 9
101. Retrieves the classification package using a HTTPS retriever instance.
112. Loads the labels from the label file.
123. Creates a Multimedia instance with the video device path and resolution, and
13   sets the video for Linux (V4L2) configuration.
144. Creates a Framerate instance to calculate the frames per second (FPS) of the
15   camera stream.
165. Creates an TFLiteInterpreter engine instance and a resizer instance.
176. Resizes each frame of the input video to the engine's input size.
187. Runs inference and gets the result for each frame.
198. Creates an overlay instance and draws the output image with the
20   classification result and other information for each frame.
219. Shows the output video using the Multimedia helper.
22
23Example:
24
25To run this script:
26    $ python3 video_classification_tflite.py
27
28Args:
29None.
30
31Returns:
32None.
33"""
34
35from argparse import ArgumentParser
36
37from pyvar.ml.engines.tflite import TFLiteInterpreter
38from pyvar.ml.utils.framerate import Framerate
39from pyvar.ml.utils.label import Label
40from pyvar.ml.utils.overlay import Overlay
41from pyvar.ml.utils.resizer import Resizer
42from pyvar.ml.utils.retriever_https import HTTPS
43from pyvar.multimedia.helper import Multimedia
44
45https = HTTPS()
46parser = ArgumentParser()
47parser.add_argument('--num_threads', type=int)
48args = parser.parse_args()
49args.num_threads = 2
50
51if https.retrieve_package(category="classification"):
52    model_file_path = https.model
53    label_file_path = https.label
54
55labels = Label(label_file_path)
56labels.read_labels("classification")
57
58engine = TFLiteInterpreter(model_file_path=model_file_path,
59                           num_threads=args.num_threads)
60
61resizer = Resizer()
62resizer.set_sizes(engine_input_details=engine.input_details)
63
64camera = Multimedia("/dev/video1", resolution="vga")
65camera.set_v4l2_config()
66
67framerate = Framerate()
68
69draw = Overlay()
70draw.framerate_info = True
71
72while camera.loop:
73    with framerate.fpsit():
74        frame = camera.get_frame()
75        resizer.resize_frame(frame)
76
77        engine.set_input(resizer.frame_resized)
78        engine.run_inference()
79        engine.get_result("classification")
80
81        output_frame = draw.info(category="classification",
82                                 image=resizer.frame,
83                                 top_result=engine.result,
84                                 labels=labels.list,
85                                 inference_time=engine.inference_time,
86                                 model_name=model_file_path,
87                                 source_file=camera.dev.name,
88                                 fps=framerate.fps)
89
90        camera.show("TFLite: Real Time Classification", output_frame)
91
92camera.destroy()