Source code for pyvar.ml.utils.retriever

# Copyright 2021 Variscite LTD
# SPDX-License-Identifier: BSD-3-Clause

"""
:platform: Unix/Yocto
:synopsis: Class to retrieve packages from Variscite FTP server.

.. moduleauthor:: Diego Dorta <diego.d@variscite.com>
"""

import ftplib
import glob
import os
import shutil
import socket
import sys

from pyvar.config import CACHEDIR
from pyvar.ml.config import CLASSIFICATION
from pyvar.ml.config import CLASSIFICATION_93
from pyvar.ml.config import DETECTION
from pyvar.ml.config import SEGMENTATION
from pyvar.ml.utils.config import DEFAULT_PACKAGES
from pyvar.ml.utils.config import FTP_HOST, FTP_PASS, FTP_USER
from pyvar.ml.utils.config import JPG, MP4, PNG, TFLITE, TXT, ZIP


[docs]class FTP: """ **This class can be used as reference only. It is not for production-ready.** """ def __init__(self, host=None, user=None, passwd=None): """ Constructor method for the FTP class. """ self.host = FTP_HOST if host is None else host self.user = FTP_USER if user is None else user self.passwd = FTP_PASS if passwd is None else passwd self.cachedir = CACHEDIR self.retrieved_package = None self.model = None self.label = None self.image = None self.video = None try: self.ftp = ftplib.FTP(self.host, self.user, self.passwd) try: os.mkdir(self.cachedir) except FileExistsError: pass except ftplib.all_errors as error: sys.exit(f"Error: {error}")
[docs] def retrieve_package(self, package_dir=None, package_filename=None, category=None): """ Retrieve package from the FTP server. Args: package_dir (str): package directory; package_filename (str): model package file name; category (str): model category (classification or detection). Returns: True if the package file was downloaded successfully. False if not. """ host_name = socket.gethostname() host_93 = False if host_name.startswith("imx93"): host_93 = True if category is not None: if category is CLASSIFICATION: package_dir = DEFAULT_PACKAGES[CLASSIFICATION][0] package_filename = DEFAULT_PACKAGES[CLASSIFICATION][1] if host_93 is True: package_filename = DEFAULT_PACKAGES[CLASSIFICATION_93][1] elif category is DETECTION: package_dir = DEFAULT_PACKAGES[DETECTION][0] package_filename = DEFAULT_PACKAGES[DETECTION][1] elif category is SEGMENTATION: package_dir = DEFAULT_PACKAGES[SEGMENTATION][0] package_filename = DEFAULT_PACKAGES[SEGMENTATION][1] package_file = os.path.join(self.cachedir, package_filename) try: self.ftp.cwd(package_dir) with open(package_file, "wb") as f: r = self.ftp.retrbinary(f"RETR {package_filename}", f.write) if not r.startswith("226 Transfer complete"): os.remove(package_file) return False else: self.retrieved_package = package_file self.ftp.cwd("/") except Exception as ex: print(f"Exc: {ex}") return False if self.retrieved_package.endswith(ZIP): package_name_path = self.retrieved_package[:-4] try: shutil.unpack_archive(self.retrieved_package, self.cachedir) self._get_package_names(package_name_path, category) os.remove(self.retrieved_package) except Exception as ex: print(f"Exc: {ex}") return False self._disconnect() return True
def _disconnect(self): """ Send a quit command to the server and close the connection. """ self.ftp.quit() def _get_package_names(self, package_name_path, category): """ Get the model and label names from the downloaded package. """ model_list = glob.glob(os.path.join(package_name_path, TFLITE)) self.model = model_list[0] label_list = glob.glob(os.path.join(package_name_path, TXT)) self.label = label_list[0] if category is CLASSIFICATION: image_list = glob.glob(os.path.join(package_name_path, JPG)) self.image = image_list[0] if category is DETECTION: image_list = glob.glob(os.path.join(package_name_path, PNG)) self.image = image_list[0] video_list = glob.glob(os.path.join(package_name_path, MP4)) self.video = video_list[0]