PaddlePaddle AI advanced hands-on: industrial meter reading (notes)

Preparation stage

Step 1: create a Notebook model task

1. Open the BML home page and click "Use now"

🔗: https://ai.baidu.com/bml/

2. Click Notebook and create a "general task"

3. Fill in the task information

Step 2: download the Notebook task template

Download link: https://aistudio.baidu.com/aistudio/datasetdetail/120387

Object detection model training

Step 1: configure Notebook

1. Find the Notebook task created yesterday and click Configure

2. Configuration selection

  • Development language: Python 3.7
  • AI framework: PaddlePaddle 2.0
  • Resource specification: GPU V100

3. Open the Notebook

4. Upload the Notebook task template

If you have not downloaded the template yet, get it here: https://aistudio.baidu.com/aistudio/datasetdetail/120387

Step 2: environment preparation

1. Install filelock

!pip install filelock

2. Install PaddleX

!pip install paddlex==2.0.0

Note: pin the version explicitly when installing PaddleX.

3. Upgrade paddlepaddle-gpu to version 2.1.3

!pip install paddlepaddle-gpu==2.1.3.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html

Step 3: object detection model training

Training workflow:

Define data preprocessing -> define dataset paths -> initialize model -> train model

1. Import PaddleX

import paddlex as pdx
from paddlex import transforms as T

2. Define the transforms used for training and validation

API details: https://github.com/PaddlePaddle/PaddleX/blob/release/2.0-rc/paddlex/cv/transforms/operators.py

train_transforms = T.Compose([
    T.MixupImage(mixup_epoch=250),
    T.RandomDistort(),
    T.RandomExpand(im_padding_value=[123.675, 116.28, 103.53]),
    T.RandomCrop(),
    T.RandomHorizontalFlip(),
    T.BatchRandomResize(
        target_sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
        interp='RANDOM'),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

eval_transforms = T.Compose([
    T.Resize(608, interp='CUBIC'),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

3. Download the meter dataset for object detection training

meter_det_dataset = 'https://bj.bcebos.com/paddlex/examples/meter_reader/datasets/meter_det.tar.gz'
pdx.utils.download_and_decompress(meter_det_dataset, path='./')

You can view the dataset in the folder area on the left

4. Define the training and validation datasets

Detailed API description: https://github.com/PaddlePaddle/PaddleX/blob/develop/paddlex/cv/datasets/coco.py#L26

train_dataset = pdx.datasets.CocoDetection(
    data_dir='meter_det/train/',
    ann_file='meter_det/annotations/instance_train.json',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.CocoDetection(
    data_dir='meter_det/test/',
    ann_file='meter_det/annotations/instance_test.json',
    transforms=eval_transforms)
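The workflow described earlier also lists "initialize model" and "model training", but these notes do not record the code for those two stages of the detection model. A minimal sketch of what they could look like follows. It assumes PaddleX 2.0's pdx.det.PPYOLOv2 detector, chosen only because the prediction command later in these notes loads the model from output/ppyolov2_r50vd_dcn/best_model. The hyperparameter values are illustrative assumptions, not the settings used for this record, and the helper name train_detector is hypothetical.

```python
# Illustrative hyperparameters (assumptions, not the values used in these notes).
DET_TRAIN_CONFIG = {
    'num_epochs': 170,
    'train_batch_size': 8,
    'pretrain_weights': 'COCO',
    'learning_rate': 0.005 / 12,
    'warmup_steps': 1000,
    'warmup_start_lr': 0.0,
    'lr_decay_epochs': [105, 135, 150],
    'save_interval_epochs': 5,
    # Matches the --det_model_dir passed to reader_infer.py in the prediction step.
    'save_dir': 'output/ppyolov2_r50vd_dcn',
}


def train_detector(train_dataset, eval_dataset, config=DET_TRAIN_CONFIG):
    """Hypothetical helper: initialize PP-YOLOv2 and train it on the meter dataset."""
    # Imported here so the sketch can be loaded without PaddleX installed.
    import paddlex as pdx

    # One class per label found in the COCO annotation file.
    num_classes = len(train_dataset.labels)
    model = pdx.det.PPYOLOv2(
        num_classes=num_classes, backbone='ResNet50_vd_dcn')
    model.train(
        train_dataset=train_dataset, eval_dataset=eval_dataset, **config)
    return model
```

In the Notebook this would be called as model = train_detector(train_dataset, eval_dataset) once the two datasets above are defined. For a quick trial run, lowering num_epochs keeps GPU time and billing down.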

5. After training, check the best_model directory

Step 4: save the Notebook, close it and stop running

Semantic segmentation model training

Step 1: reinstall the environment

1. Start the Notebook and open it

2. Re-run the three installation commands from the environment preparation step

Step 2: pointer and scale segmentation model training

1. Import PaddleX

import paddlex as pdx
from paddlex import transforms as T

2. Define the transforms used for training and validation

Detailed API description: https://github.com/PaddlePaddle/PaddleX/blob/release/2.0-rc/paddlex/cv/transforms/operators.py

train_transforms = T.Compose([
    T.Resize(target_size=512),
    T.RandomHorizontalFlip(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

eval_transforms = T.Compose([
    T.Resize(target_size=512),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

3. Download and decompress the pointer and scale segmentation dataset

meter_seg_dataset = 'https://bj.bcebos.com/paddlex/examples/meter_reader/datasets/meter_seg.tar.gz'
pdx.utils.download_and_decompress(meter_seg_dataset, path='./')

4. Define the training and validation datasets and configure their paths

Detailed API description: https://github.com/PaddlePaddle/PaddleX/blob/release/2.0-rc/paddlex/cv/datasets/seg_dataset.py#L22

train_dataset = pdx.datasets.SegDataset(
    data_dir='meter_seg',
    file_list='meter_seg/train.txt',
    label_list='meter_seg/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
    data_dir='meter_seg',
    file_list='meter_seg/val.txt',
    label_list='meter_seg/labels.txt',
    transforms=eval_transforms,
    shuffle=False)

5. Train with PaddleX's built-in DeepLabV3P model

API description: https://github.com/PaddlePaddle/PaddleX/blob/release/2.0-rc/paddlex/cv/models/segmenter.py#L150

num_classes = len(train_dataset.labels)
model = pdx.seg.DeepLabV3P(
    num_classes=num_classes, backbone='ResNet50_vd', use_mixed_loss=True)

6. Set the training parameters

Description of each parameter and tuning advice: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html

model.train(
    num_epochs=2,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    pretrain_weights='IMAGENET',
    learning_rate=0.1,
    save_dir='output/deeplabv3p_r50vd')

7. After training, check the best_model directory

Step 3: save the Notebook, close it and stop running

Tip: a Notebook is billed for as long as it is running. Stop it promptly when you are not using it so that you do not waste the free quota.

Model prediction

Step 1: reinstall the environment

1. Start the Notebook and open it

2. Re-run the three installation commands from the environment preparation step

Step 2: upload the prediction script

1. Download the .py file from the link below

https://aistudio.baidu.com/aistudio/datasetdetail/120795

2. Upload to Notebook

Step 3: model prediction

1. With reader_infer.py uploaded, run the following command to predict

!python work/meter_reader/reader_infer.py --det_model_dir output/ppyolov2_r50vd_dcn/best_model --seg_model_dir output/deeplabv3p_r50vd/best_model/ --image meter_det/test/20190822_105.jpg

2. Open output/result to view the prediction results

Step 4: save the Notebook, close it and stop running
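
Before computing a reading, the reader_infer.py script listed below unrolls each circular dial into a rectangle (its circle_to_rectangle method). The sketch here reproduces that coordinate mapping in plain Python on a toy 9 x 9 label map, so the geometry can be checked by hand. The function name unroll_dial, the list-of-lists input and the toy sizes are assumptions for illustration; the script itself works on 512 x 512 NumPy arrays.

```python
import math


def unroll_dial(label_map, center, radius, rect_height, rect_width):
    """Unroll a circular label map (list of rows) into a rect_height x rect_width grid."""
    center_y, center_x = center
    rectangle = [[0] * rect_width for _ in range(rect_height)]
    for row in range(rect_height):
        rho = radius - row - 1  # row 0 samples the outermost ring
        for col in range(rect_width):
            # theta = 0 points straight down from the center
            theta = 2 * math.pi * (col + 1) / rect_width
            y = int(center_y + rho * math.cos(theta) + 0.5)
            x = int(center_x - rho * math.sin(theta) + 0.5)
            rectangle[row][col] = label_map[y][x]
    return rectangle


# Toy label map: one "pointer" pixel (class 1) directly left of the center (4, 4).
size = 9
label_map = [[0] * size for _ in range(size)]
label_map[4][1] = 1

rect = unroll_dial(
    label_map, center=(4, 4), radius=4, rect_height=1, rect_width=8)
print(rect)  # [[0, 1, 0, 0, 0, 0, 0, 0]]
```

The marked pixel lands in column 1 of 8, a quarter turn from the start, which shows that the unrolling begins at the bottom of the dial and sweeps through the left side first. That matches dials whose scale starts at the lower left.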

Inference Code:

# coding: utf8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import numpy as np
import math
import cv2
import argparse

from paddlex import transforms as T
import paddlex as pdx

# In the post-processing of reading, there is the operation of turning the circular dial into a rectangle. The width of the rectangle is the outer circumference of the circle
# Therefore, the dial image size is required to be fixed, which is set as [512, 512] here
METER_SHAPE = [512, 512]  # Height x width
# Center point of round dial
CIRCLE_CENTER = [256, 256]  # Height x width
# Radius of circular dial
CIRCLE_RADIUS = 250
# PI
PI = 3.1415926536
# The height of the rectangle after turning the round dial into a rectangle
# The current setting value is about half of the radius because the central area of the circular dial is the background except the root of the pointer
# We only need to save the peripheral scale and the tip of the pointer to locate the scale pointed by the pointer
RECTANGLE_HEIGHT = 120
# The width of a rectangular dial, that is, the outer circumference of a circular dial
RECTANGLE_WIDTH = 1570
# In the current case, only two types of dials are used. The number of scale elements of the first dial is 50
# The number of scale elements of the second dial is 32. Therefore, we judge the dial type by the predicted number of scale elements
# If the number of scale elements exceeds the threshold, it is the first, otherwise it is the second
TYPE_THRESHOLD = 40
# Configuration information of the two dials, including the value, range and unit of each scale
METER_CONFIG = [{
    'scale_interval_value': 25.0 / 50.0,
    'range': 25.0,
    'unit': "(MPa)"
}, {
    'scale_interval_value': 1.6 / 32.0,
    'range': 1.6,
    'unit': "(MPa)"
}]
# The segmentation model predicts the correspondence between class id and class alias
SEG_CNAME2CLSID = {'background': 0, 'pointer': 1, 'scale': 2}


def parse_args():
    parser = argparse.ArgumentParser(description='Meter Reader Inferring')
    parser.add_argument(
        '--det_model_dir',
        dest='det_model_dir',
        help='The directory of the detection model',
        type=str)
    parser.add_argument(
        '--seg_model_dir',
        dest='seg_model_dir',
        help='The directory of the segmentation model',
        type=str)
    parser.add_argument(
        '--image_dir',
        dest='image_dir',
        help='The directory of images to be inferred',
        type=str,
        default=None)
    parser.add_argument(
        '--image',
        dest='image',
        help='The image to be inferred',
        type=str,
        default=None)
    parser.add_argument(
        '--use_erode',
        dest='use_erode',
        help='Whether to erode the label map predicted by the segmentation model',
        action='store_true')
    parser.add_argument(
        '--erode_kernel',
        dest='erode_kernel',
        help='Erode kernel size',
        type=int,
        default=4)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predicted results',
        type=str,
        default='./output/result')
    parser.add_argument(
        '--score_threshold',
        dest='score_threshold',
        help="Predicted bounding boxes whose scores are lower than this threshold are filtered",
        type=float,
        default=0.5)
    parser.add_argument(
        '--seg_batch_size',
        dest='seg_batch_size',
        help="The number of images fed into the segmentation model during one forward propagation",
        type=int,
        default=2)

    return parser.parse_args()


def is_pic(img_name):
    """Determine whether it is a picture

    Parameters:
        img_name (str): Picture path

    return:
        flag (bool): Judgment value.
    """
    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
    suffix = img_name.split('.')[-1]
    flag = True
    if suffix not in valid_suffix:
        flag = False
    return flag


class MeterReader:
    """Detect the dials in an image, segment the scale marks and pointer of each dial, and compute each dial's reading from the segmentation results.

    Parameters:
        det_model_dir (str): Path of the detection model used to locate the dials.
        seg_model_dir (str): Path of the segmentation model used to segment the scale marks and pointer.

    """

    def __init__(self, det_model_dir, seg_model_dir):
        if not osp.exists(det_model_dir):
            raise Exception("Model path {} does not exist".format(
                det_model_dir))
        if not osp.exists(seg_model_dir):
            raise Exception("Model path {} does not exist".format(
                seg_model_dir))
        self.detector = pdx.load_model(det_model_dir)
        self.segmenter = pdx.load_model(seg_model_dir)

    def decode(self, img_file):
        """Image decoding

        Parameters:
            img_file (str|np.array): Image path, or decoded BGR Image array.

        return:
            img (np.array): BGR Image array.
        """

        if isinstance(img_file, str):
            img = cv2.imread(img_file).astype('float32')
        else:
            img = img_file.copy()
        return img

    def filter_bboxes(self, det_results, score_threshold):
        """Filter out detection boxes whose confidence does not exceed the threshold

        Parameters:
            det_results (list[dict]): Return value of the detection model's predict interface.
            score_threshold (float): Confidence threshold.

        return:
            filtered_results (list[dict]): Detection results that pass the threshold.

        """
        filtered_results = list()
        for res in det_results:
            if res['score'] > score_threshold:
                filtered_results.append(res)
        return filtered_results

    def roi_crop(self, img, det_results):
        """Crop the image region of each detection box from the image

        Parameters:
            img (np.array): BGR image array.
            det_results (list[dict]): Return value of the detection model's predict interface.

        return:
            sub_imgs (list[np.array]): The image region of each detection box.

        """
        sub_imgs = []
        for res in det_results:
            # Crop the bbox area
            xmin, ymin, w, h = res['bbox']
            xmin = max(0, int(xmin))
            ymin = max(0, int(ymin))
            xmax = min(img.shape[1], int(xmin + w - 1))
            ymax = min(img.shape[0], int(ymin + h - 1))
            sub_img = img[ymin:(ymax + 1), xmin:(xmax + 1), :]
            sub_imgs.append(sub_img)
        return sub_imgs

    def resize(self, imgs, target_size, interp=cv2.INTER_LINEAR):
        """Resize images to a fixed size

        Parameters:
            imgs (list[np.array]): Batch of BGR image arrays.
            target_size (list|tuple): Size of the resized image, in the format [height, width].
            interp (int): Interpolation method. The default value is cv2.INTER_LINEAR.

        return:
            resized_imgs (list[np.array]): Batch of resized BGR image arrays.

        """

        resized_imgs = list()
        for img in imgs:
            img_shape = img.shape
            scale_x = float(target_size[1]) / float(img_shape[1])
            scale_y = float(target_size[0]) / float(img_shape[0])
            resize_img = cv2.resize(
                img, None, None, fx=scale_x, fy=scale_y, interpolation=interp)
            resized_imgs.append(resize_img)
        return resized_imgs

    def seg_predict(self, segmenter, imgs, batch_size):
        """Run the segmentation model on a list of images, batch by batch

        Parameters:
            segmenter (pdx.seg.model): Loaded segmentation model.
            imgs (list[np.array]): BGR image arrays to be predicted.
            batch_size (int): Number of images fed to the segmentation model per prediction call.

        return:
            seg_results (list[dict]): Prediction results for the input images.

        """
        seg_results = list()
        num_imgs = len(imgs)
        for i in range(0, num_imgs, batch_size):
            batch = imgs[i:min(num_imgs, i + batch_size)]
            result = segmenter.predict(batch)
            seg_results.extend(result)
        return seg_results

    def erode(self, seg_results, erode_kernel):
        """Erode the label_map in each prediction result of the segmentation model

        Parameters:
            seg_results (list[dict]): The prediction results of the segmentation model.
            erode_kernel (int): Size of the erosion kernel.

        return:
            eroded_results (list[dict]): Prediction results whose label_map has been eroded.

        """
        kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
        # Note: this is the same list object, so seg_results is modified in place
        eroded_results = seg_results
        for i in range(len(seg_results)):
            label_map = seg_results[i]['label_map']
            # Cast to uint8, a depth that cv2.erode supports, before eroding
            eroded_results[i]['label_map'] = cv2.erode(
                label_map.astype('uint8'), kernel)
        return eroded_results

    def circle_to_rectangle(self, seg_results):
        """Convert the label_map predicted for each circular dial into a rectangle

        How the circle is unrolled into a rectangle:
            In this case the scale of both dials starts at the lower left, so the unrolling starts at the
            bottom of the circle (directly below the center) and sweeps through the left side first.
            With the circle center (cy, cx) as the origin, each rectangle cell maps to the image pixel:
              y = cy + rho * cos(theta)
              x = cx - rho * sin(theta)
            Notes:
                1. theta = 0 points straight down in the image, which is why rho * cos(theta) is added
                   to y and rho * sin(theta) is subtracted from x.
                2. The rows of the rectangle run from the outside of the circle to the inside: imagine
                   cutting the circle at the bottom and pulling it open to the left and right, so the rim
                   ends up at the top and the center at the bottom.

        Parameters:
            seg_results (list[dict]): The prediction results of the segmentation model.

        Return value:
            rectangle_meters (list[np.array]): The label_map of each dial, unrolled into a rectangle.

        """
        rectangle_meters = list()
        for i, seg_result in enumerate(seg_results):
            label_map = seg_result['label_map']
            # The size of rectangle_meter is fixed by the global constants RECTANGLE_HEIGHT and RECTANGLE_WIDTH
            rectangle_meter = np.zeros(
                (RECTANGLE_HEIGHT, RECTANGLE_WIDTH), dtype=np.uint8)
            for row in range(RECTANGLE_HEIGHT):
                for col in range(RECTANGLE_WIDTH):
                    theta = PI * 2 * (col + 1) / RECTANGLE_WIDTH
                    # The rectangle corresponds to the circle from top to bottom, from outside to inside
                    rho = CIRCLE_RADIUS - row - 1
                    y = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
                    x = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
                    rectangle_meter[row, col] = label_map[y, x]
            rectangle_meters.append(rectangle_meter)
        return rectangle_meters

    def rectangle_to_line(self, rectangle_meters):
        """The pointer and scale prediction results are extracted from the prediction results of the rectangular dial and compressed into a linear format along the height direction.

        Parameters:
            rectangle_meters (list[np.array]): Prediction results of rectangular dial label_map. 

        return:
            line_scales (list[np.array]): Linear prediction results of scale.
            line_pointers (list[np.array]): Linear prediction result of pointer.

        """
        line_scales = list()
        line_pointers = list()

        for rectangle_meter in rectangle_meters:
            height, width = rectangle_meter.shape[0:2]
            line_scale = np.zeros((width), dtype=np.uint8)
            line_pointer = np.zeros((width), dtype=np.uint8)
            for col in range(width):
                for row in range(height):
                    if rectangle_meter[row, col] == SEG_CNAME2CLSID['pointer']:
                        line_pointer[col] += 1
                    elif rectangle_meter[row, col] == SEG_CNAME2CLSID['scale']:
                        line_scale[col] += 1
            line_scales.append(line_scale)
            line_pointers.append(line_pointer)
        return line_scales, line_pointers

    def mean_binarization(self, data_list):
        """Perform mean binarization on the image

        Parameters:
            data_list (list[np.array]): Batch array to be binarized.

        return:
            binaried_data_list (list[np.array]): Binary batch array.

        """
        batch_size = len(data_list)
        binaried_data_list = data_list
        for i in range(batch_size):
            mean_data = np.mean(data_list[i])
            width = data_list[i].shape[0]
            for col in range(width):
                if data_list[i][col] < mean_data:
                    binaried_data_list[i][col] = 0
                else:
                    binaried_data_list[i][col] = 1
        return binaried_data_list

    def locate_scale(self, line_scales):
        """Find the center position of each scale in the linear prediction results

        Parameters:
            line_scales (list[np.array]): Scale linear prediction results after batch binarization.

        return:
            scale_locations (list[list]): The center position of each scale in each image.

        """
        batch_size = len(line_scales)
        scale_locations = list()
        for i in range(batch_size):
            line_scale = line_scales[i]
            width = line_scale.shape[0]
            find_start = False
            one_scale_start = 0
            one_scale_end = 0
            locations = list()
            for j in range(width - 1):
                if line_scale[j] > 0 and line_scale[j + 1] > 0:
                    if find_start == False:
                        one_scale_start = j
                        find_start = True
                if find_start:
                    if line_scale[j] == 0 and line_scale[j + 1] == 0:
                        one_scale_end = j - 1
                        one_scale_location = (
                            one_scale_start + one_scale_end) / 2
                        locations.append(one_scale_location)
                        one_scale_start = 0
                        one_scale_end = 0
                        find_start = False
            scale_locations.append(locations)
        return scale_locations

    def locate_pointer(self, line_pointers):
        """Find the center position of the pointer in the linear prediction result

        Parameters:
            line_pointers (list[np.array]): Batch of binarized linear pointer predictions.

        return:
            pointer_locations (list[float]): The center position of the pointer in each image.

        """
        batch_size = len(line_pointers)
        pointer_locations = list()
        for i in range(batch_size):
            line_pointer = line_pointers[i]
            find_start = False
            pointer_start = 0
            pointer_end = 0
            location = 0
            width = line_pointer.shape[0]
            for j in range(width - 1):
                if line_pointer[j] > 0 and line_pointer[j + 1] > 0:
                    if find_start == False:
                        pointer_start = j
                        find_start = True
                if find_start:
                    if line_pointer[j] == 0 and line_pointer[j + 1] == 0:
                        pointer_end = j - 1
                        location = (pointer_start + pointer_end) / 2
                        find_start = False
                        break
            pointer_locations.append(location)
        return pointer_locations

    def get_relative_location(self, scale_locations, pointer_locations):
        """Find which scale mark the pointer points to

        Parameters:
            scale_locations (list[list]): Center positions of the scale marks for each dial in the batch.
            pointer_locations (list[float]): Center position of the pointer for each dial in the batch.

        return:
            pointed_scales (list[dict]): One dictionary per dial with two keys:
                'num_scales', the predicted number of scale marks, and
                'pointed_scale', the (fractional) scale mark the pointer points to, or -1 if none was found.

        """

        pointed_scales = list()
        for scale_location, pointer_location in zip(scale_locations,
                                                    pointer_locations):
            num_scales = len(scale_location)
            pointed_scale = -1
            if num_scales > 0:
                for i in range(num_scales - 1):
                    if scale_location[
                            i] <= pointer_location and pointer_location < scale_location[
                                i + 1]:
                        pointed_scale = i + (
                            pointer_location - scale_location[i]
                        ) / (scale_location[i + 1] - scale_location[i] + 1e-05
                             ) + 1
            result = {'num_scales': num_scales, 'pointed_scale': pointed_scale}
            pointed_scales.append(result)
        return pointed_scales

    def calculate_reading(self, pointed_scales):
        """Calculate the reading of the dial according to the interval value of the scale and the number of scales pointed by the pointer
        """
        readings = list()
        batch_size = len(pointed_scales)
        for i in range(batch_size):
            pointed_scale = pointed_scales[i]
            # If the number of scale elements is greater than the threshold, it is the first dial
            if pointed_scale['num_scales'] > TYPE_THRESHOLD:
                reading = pointed_scale['pointed_scale'] * METER_CONFIG[0][
                    'scale_interval_value']
            else:
                reading = pointed_scale['pointed_scale'] * METER_CONFIG[1][
                    'scale_interval_value']
            readings.append(reading)

        return readings

    def get_meter_reading(self, seg_results):
        """After the segmentation results are processed, the readings of each dial are obtained

        Parameters:
            seg_results (list[dict]): The prediction results of the segmentation model.

        return:
            meter_readings (list[float]): The reading of each dial.

        """

        rectangle_meters = self.circle_to_rectangle(seg_results)
        line_scales, line_pointers = self.rectangle_to_line(rectangle_meters)
        binaried_scales = self.mean_binarization(line_scales)
        binaried_pointers = self.mean_binarization(line_pointers)
        scale_locations = self.locate_scale(binaried_scales)
        pointer_locations = self.locate_pointer(binaried_pointers)
        pointed_scales = self.get_relative_location(scale_locations,
                                                    pointer_locations)
        meter_readings = self.calculate_reading(pointed_scales)
        return meter_readings

    def print_meter_readings(self, meter_readings):
        """Print the readings of each dial

        Parameters:
            meter_readings (list[float]): Reading of each dial
        """
        for i in range(len(meter_readings)):
            print("Meter {}: {}".format(i + 1, meter_readings[i]))

    def visualize(self, img, det_results, meter_readings, save_dir="./"):
        """Visualize the position and reading of each dial in the image

        Parameters:
            img (str|np.array): Image path, or decoded BGR Image array.
            det_results (dict): The prediction results of the detection model.
            meter_readings (list): The reading of each dial.
            save_dir (str): The path to save the visualized picture.

        """
        vis_results = list()
        for i, res in enumerate(det_results):
            # Replace the 'score' field of each detection result with the reading so that pdx.det.visualize draws it
            res['score'] = meter_readings[i]
            vis_results.append(res)
        # pdx.det.visualize filters out boxes whose score is below threshold. Readings are >= -1, so set threshold=-1
        pdx.det.visualize(img, vis_results, threshold=-1, save_dir=save_dir)

    def predict(self,
                img_file,
                save_dir='./',
                use_erode=True,
                erode_kernel=4,
                score_threshold=0.5,
                seg_batch_size=2):
        """Detect the dials in the image, segment the pointer and scale marks in each dial, and post-process the segmentation results to obtain the reading of each dial.

        Parameters:
            img_file (str): The path of the picture to be predicted.
            save_dir (str): Save path of visualization results.
            use_erode (bool, optional): Whether to erode the segmentation prediction results. Default: True.
            erode_kernel (int, optional): Size of the erosion kernel. Default: 4.
            score_threshold (float, optional): Confidence threshold used to filter the detection boxes. Default: 0.5.
            seg_batch_size (int, optional): Number of dial images fed to the segmentation model per forward pass. Default: 2.
        """

        img = self.decode(img_file)
        det_results = self.detector.predict(img)
        filtered_results = self.filter_bboxes(det_results, score_threshold)
        sub_imgs = self.roi_crop(img, filtered_results)
        sub_imgs = self.resize(sub_imgs, METER_SHAPE)
        seg_results = self.seg_predict(self.segmenter, sub_imgs,
                                       seg_batch_size)
        # Erode the label maps only when use_erode is set
        if use_erode:
            seg_results = self.erode(seg_results, erode_kernel)
        meter_readings = self.get_meter_reading(seg_results)
        self.print_meter_readings(meter_readings)
        self.visualize(img, filtered_results, meter_readings, save_dir)


def infer(args):
    image_lists = list()
    if args.image is not None:
        if not osp.exists(args.image):
            raise Exception("Image {} does not exist.".format(args.image))
        if not is_pic(args.image):
            raise Exception("{} is not a picture.".format(args.image))
        image_lists.append(args.image)
    elif args.image_dir is not None:
        if not osp.exists(args.image_dir):
            raise Exception("Directory {} does not exist.".format(
                args.image_dir))
        for im_file in os.listdir(args.image_dir):
            if not is_pic(im_file):
                continue
            im_file = osp.join(args.image_dir, im_file)
            image_lists.append(im_file)

    meter_reader = MeterReader(args.det_model_dir, args.seg_model_dir)
    if len(image_lists) > 0:
        for image in image_lists:
            meter_reader.predict(image, args.save_dir, args.use_erode,
                                 args.erode_kernel, args.score_threshold,
                                 args.seg_batch_size)


if __name__ == '__main__':
    args = parse_args()
    infer(args)
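
The last stages of the script (locate_scale, locate_pointer, get_relative_location and calculate_reading) turn the unrolled, binarized dial into a number. The following plain-Python sketch walks through that arithmetic on a toy example. It is a simplification, not the script's exact logic: it treats any run of non-zero columns as one mark, whereas the script requires two adjacent non-zero columns to start a run. The helper names run_centers and pointed_scale and the toy arrays are assumptions for illustration.

```python
def run_centers(line):
    """Return the center index of every run of consecutive non-zero values."""
    centers = []
    start = None
    for idx, value in enumerate(line):
        if value and start is None:
            start = idx  # a run begins
        elif not value and start is not None:
            centers.append((start + idx - 1) / 2)  # run covered start..idx-1
            start = None
    if start is not None:  # run touching the right edge
        centers.append((start + len(line) - 1) / 2)
    return centers


def pointed_scale(scale_centers, pointer_center):
    """Fractional count of scale marks passed by the pointer, or -1 if out of range."""
    for i in range(len(scale_centers) - 1):
        left, right = scale_centers[i], scale_centers[i + 1]
        if left <= pointer_center < right:
            return i + 1 + (pointer_center - left) / (right - left)
    return -1


# Toy unrolled dial: scale marks at columns 2, 6, 10, 14 and the pointer at column 8.
line_scale = [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0]
line_pointer = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

scales = run_centers(line_scale)
pointer = run_centers(line_pointer)[0]
position = pointed_scale(scales, pointer)

# Interval value of the second dial in METER_CONFIG (1.6 MPa over 32 marks).
reading = position * (1.6 / 32.0)
print(scales, pointer, position)
```

Here the pointer sits halfway between the second and third scale marks, so the pointed scale is 2.5 and, with 1.6 / 32 MPa per mark, the reading is about 0.125 MPa.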

Keywords: AI Deep Learning paddlepaddle

Added by vtroubled on Tue, 14 Dec 2021 00:07:58 +0200