Yolov7

This guide shows how to export a trained Yolov7 model to an ONNX file that can be directly uploaded and deployed on servers.

Unlike most object detection models, Yolov7 can be exported with the necessary post-processing (Non-Max-Suppression) directly using the exporter in the original repository.

Requirements

Make sure to install the required packages:

pip install -r requirements.txt
requirements.txt
onnx
onnxsim
sclblonnx
onnxruntime
opencv-python

Exporting the model to ONNX

  1. First, we need to clone the Yolov7 repository, as it contains the Python script that converts the model from PyTorch to an ONNX model that includes NMS.

git clone https://github.com/WongKinYiu/yolov7
  1. The next step is to export the model to ONNX and slightly adjust its IO:

bash model-to-onnx.sh
model-to-onnx.sh
# Run from the directory that contains this script so relative paths resolve.
cd "$(dirname "$0")" || exit

# Export settings: network input resolution and base name of the weights file.
width=640
height=640
model_name=yolov7-tiny

# export.py uses --max-wh as a per-class coordinate offset for NMS, so it has
# to cover the larger image side (identical to $height for square inputs).
max_wh=$(( height > width ? height : width ))

# export.py lives in the cloned yolov7 repository; stop if it is missing.
cd yolov7 || exit

python export.py --weights "$model_name.pt" --grid --end2end --simplify --include-nms \
        --topk-all 100 --iou-thres 0.5 --conf-thres 0.35 --img-size "$height" "$width" --max-wh "$max_wh" || exit

# Bring the exported model next to this script, then post-process it.
cp "$model_name.onnx" "../$model_name.onnx" || exit
cd .. || exit

python complete-onnx.py "$model_name.onnx"
complete-onnx.py
from os.path import splitext

from utils import add_pre_post_processing_to_onnx, rename_io, update_onnx_doc_string

if __name__ == '__main__':
    import sys

    # Exactly one argument is expected: the ONNX file to complete.
    if len(sys.argv) != 2:
        print("Usage: python complete-onnx.py model.onnx")
        # sys.exit is used instead of the site-provided exit() helper, which
        # is meant for interactive sessions and may be absent (python -S).
        sys.exit(1)
    onnx_path = sys.argv[1]
    # "model.onnx" -> "model-complete.onnx"
    output_onnx_path = splitext(onnx_path)[0] + "-complete.onnx"
    # Class names in index order; update this list for a custom-trained model.
    classes = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
               'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
               'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv monitor', 'laptop', 'mouse', 'remote', 'keyboard',
               'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
               'scissors', 'teddy bear', 'hair drier', 'toothbrush']
    # Encoded as "0:person;1:bicycle;..." and embedded in the output name below.
    classes_str = ';'.join(f'{i}:{c}' for i, c in enumerate(classes))

    add_pre_post_processing_to_onnx(onnx_path, output_onnx_path)

    # Keys are regex patterns matched against the current IO names; values
    # are the names they are renamed to.
    rename_io(output_onnx_path, output_onnx_path, **{'image': 'image-',
                                                     'unmasked_bboxes': f'bboxes-format:xyxysc;{classes_str}',
                                                     })

    # Store means of 0 and standard deviations of 1 in the graph description.
    update_onnx_doc_string(output_onnx_path, [0, 0, 0], [1, 1, 1])
utils.py
import json
import re

import numpy as np
import onnx
import sclblonnx as so


def add_pre_post_processing_to_onnx(onnx_path: str, output_onnx_path: str):
    """Wrap an exported model with input scaling, a runtime NMS input and mask filtering.

    The graph read from ``onnx_path`` is modified as follows and written to
    ``output_onnx_path`` using the opset version of the source model:

    * the last input of the first ``NonMaxSuppression`` node is rewired to a
      new graph input named ``nms_sensitivity-``;
    * the original graph input is replaced by ``image-``, which is divided by
      255 before it reaches the rest of the network;
    * columns 1-6 of the original output become ``bboxes-``, which is then
      filtered through a boolean ``mask-`` input to produce the final
      ``unmasked_bboxes`` output.

    Args:
        onnx_path: Path of the ONNX file to read.
        output_onnx_path: Path the completed ONNX file is written to.
    """
    base_graph = so.graph_from_file(onnx_path)

    output_name = base_graph.output[0].name
    input_name = base_graph.input[0].name

    # Change the NMS node to use the sensitivity input instead of the constant
    for node in base_graph.node:
        if node.op_type == 'NonMaxSuppression':
            node.input[-1] = 'nms_sensitivity-'
            break

    # Read the static input shape. In both supported layouts the height
    # dimension precedes the width dimension (N, C, H, W / N, H, W, C).
    input_shape = base_graph.input[0].type.tensor_type.shape.dim
    input_shape = [d.dim_value for d in input_shape]
    if input_shape[1] <= 3:  # NCHW
        height, width = input_shape[2], input_shape[3]
    else:  # NHWC
        height, width = input_shape[1], input_shape[2]

    # cleanup useless IO
    so.delete_output(base_graph, output_name)
    so.delete_input(base_graph, input_name)

    # Normalize the input by dividing by 255; the Div node writes to the
    # tensor name the original graph input used, so downstream nodes are
    # left untouched.
    so.add_constant(base_graph, 'c_255', np.array([255], dtype=np.float32), 'FLOAT')
    div = so.node('Div', inputs=['image-', 'c_255'], outputs=[input_name])
    base_graph.node.insert(0, div)
    so.add_input(base_graph, name='image-', dimensions=input_shape, data_type='FLOAT')

    # move constant nodes to the beginning of the graph so every constant is
    # defined before the nodes that consume it (including the Div above)
    constant_nodes = [n for n in base_graph.node if n.op_type == 'Constant']
    for n in constant_nodes:
        base_graph.node.remove(n)
        base_graph.node.insert(0, n)

    # Declare the 'nms_sensitivity-' input and expose columns 1-6 of the
    # original output as 'bboxes-'
    make_yolov7_complementary_graph(base_graph, output_name)

    # Add mask to the model: 'bboxes-' becomes an internal tensor and the
    # mask-filtered boxes become the only output
    so.delete_output(base_graph, 'bboxes-')
    mask_bboxes(base_graph, 'bboxes-', 'mask-', width, height)
    # NOTE(review): the output is declared as [20, 6]; the number of rows that
    # survive the mask looks data-dependent -- confirm this declaration is intended.
    so.add_output(base_graph, 'unmasked_bboxes', 'FLOAT', dimensions=[20, 6])

    # Save the model
    so.graph_to_file(base_graph, output_onnx_path, onnx_opset_version=get_onnx_opset_version(onnx_path))


def mask_bboxes(graph, bboxes_name, mask_name, w, h):
    """Append nodes that keep only the boxes whose reference point lies on a True mask cell.

    A boolean graph input ``mask_name`` of shape ``[h, w]`` is added. For each
    row of ``bboxes_name`` a point is built from the mean of columns 0 and 2
    and from column 3 (presumably the horizontal centre and bottom edge of an
    xyxy box -- confirm against the box layout). The point is clipped to the
    mask bounds, cast to integer and looked up in the mask; rows whose lookup
    is non-zero are gathered into the ``unmasked_bboxes`` tensor.

    Args:
        graph: Graph to extend in place.
        bboxes_name: Name of the tensor holding the boxes to filter.
        mask_name: Name given to the new boolean mask input.
        w: Mask width; column indices are clipped to ``[0, w - 1]``.
        h: Mask height; row indices are clipped to ``[0, h - 1]``.

    Returns:
        The same ``graph`` object.
    """
    so.add_input(graph, name=mask_name, dimensions=[h, w], data_type='BOOL')

    # Constants are registered one by one, in this exact order.
    constant_specs = (
        ('index_one_three', np.array([0, 2]), 'INT64'),
        ('index_four', np.array([3, 3]), 'INT64'),
        ('hw_clip_min', np.array(0), 'FLOAT'),
        ('w_clip_max', np.array(w - 1), 'FLOAT'),
        ('h_clip_max', np.array(h - 1), 'FLOAT'),
    )
    for const_name, const_value, const_type in constant_specs:
        so.add_constant(graph, const_name, const_value, const_type)

    # Column selection: two columns for x, column 3 (taken twice) for y.
    pick_x = so.node('Gather', inputs=[bboxes_name, 'index_one_three'], outputs=['x_coordinates'], axis=1)
    pick_y = so.node('Gather', inputs=[bboxes_name, 'index_four'], outputs=['y_coordinates'], axis=1)

    # Collapse each selection to a single column per box.
    mean_x = so.node('ReduceMean', inputs=['x_coordinates'], outputs=['x_reducemean'], axes=(1,), keepdims=1)
    mean_y = so.node('ReduceMean', inputs=['y_coordinates'], outputs=['y_coordinate'], axes=(1,), keepdims=1)

    # Keep the point inside the mask so the lookup cannot go out of range.
    clip_x = so.node('Clip', inputs=['x_reducemean', 'hw_clip_min', 'w_clip_max'], outputs=['x_clipped'])
    clip_y = so.node('Clip', inputs=['y_coordinate', 'hw_clip_min', 'h_clip_max'], outputs=['y_clipped'])

    # (row, column) order, matching the [h, w] layout of the mask.
    make_point = so.node('Concat', inputs=['y_clipped', 'x_clipped'], outputs=['bottom_center_corner'],
                         axis=1)
    # 7 is the ONNX tensor data type code for INT64.
    point_as_int = so.node('Cast', inputs=['bottom_center_corner'], outputs=['bottom_center_corner_int'],
                           to=7)

    # Mask value per box, then the indices of the boxes to keep.
    lookup = so.node('GatherND', inputs=[mask_name, 'bottom_center_corner_int'], outputs=['bboxes_mask1'])
    kept_indices = so.node('NonZero', inputs=['bboxes_mask1'], outputs=['bboxes_indices1'])
    kept_indices_flat = so.node('Squeeze', inputs=['bboxes_indices1'], outputs=['bboxes_indices1_squeezed'],
                                axes=(0,))
    kept_boxes = so.node('Gather', inputs=[bboxes_name, 'bboxes_indices1_squeezed'], outputs=['unmasked_bboxes'],
                         axis=0)

    ordered_nodes = [
        pick_x, pick_y,
        mean_x, mean_y,
        clip_y, clip_x,
        make_point, point_as_int,
        lookup, kept_indices,
        kept_indices_flat, kept_boxes,
    ]
    so.add_nodes(graph, ordered_nodes)
    return graph


def make_yolov7_complementary_graph(g, output_name):
    """Declare the NMS sensitivity input and expose columns 1-6 of the model output.

    Adds a float graph input ``nms_sensitivity-`` of shape ``[1]``, gathers
    columns 1 to 6 (axis 1) of ``output_name`` into a tensor named
    ``bboxes-`` and registers that tensor as a graph output. Column 0 is
    dropped (presumably a batch index -- confirm against the exporter).

    Args:
        g: Graph to extend in place.
        output_name: Name of the tensor whose columns are selected.

    Returns:
        The same graph object ``g``.
    """
    so.add_input(g, name='nms_sensitivity-', dimensions=[1], data_type='FLOAT')

    # Indices of the columns that are kept.
    kept_columns = np.array([1, 2, 3, 4, 5, 6])
    so.add_constant(g, name='c_123456', value=kept_columns, data_type='INT64')

    column_selector = so.node('Gather', inputs=[output_name, 'c_123456'], outputs=['bboxes-'], axis=1)
    so.add_node(g, column_selector)

    so.add_output(g, 'bboxes-', 'FLOAT', dimensions=[20, 6])

    return g


def rename_io(model_path, new_model_path=None, **io_names):
    """Rename graph inputs/outputs of an ONNX file using regex patterns.

    Each keyword argument maps a regular expression to a new name. The
    pattern is matched (from the start of the string) against the names the
    graph had when it was loaded: the first matching input is renamed;
    if no input matches, the first matching output is renamed. A pattern
    that matches nothing is skipped silently.

    Args:
        model_path: Path of the ONNX file to read.
        new_model_path: Path to write the result to; defaults to
            ``model_path`` (in-place update).
        **io_names: ``pattern=new_name`` pairs.

    Returns:
        None. When no ``io_names`` are given the function returns
        immediately without reading or writing any file.

    Raises:
        AssertionError: If listing the resulting inputs or outputs fails.
    """
    # Nothing to rename: skip loading the model altogether.
    if not io_names:
        return

    if new_model_path is None:
        new_model_path = model_path

    g = so.graph_from_file(model_path)

    # Snapshot the names up front so a rename cannot be matched again by a
    # later pattern.
    inputs = [i.name for i in g.input]
    outputs = [o.name for o in g.output]

    for k, v in io_names.items():
        pattern = re.compile(k)

        matched_input = next((name for name in inputs if pattern.match(name)), None)
        if matched_input is not None:
            so.rename_input(g, matched_input, v)
            continue

        matched_output = next((name for name in outputs if pattern.match(name)), None)
        if matched_output is not None:
            so.rename_output(g, matched_output, v)

    # Report the resulting IO. These calls used to live inside `assert`
    # statements, which are removed under `python -O`; an explicit check keeps
    # them running and still raises AssertionError on failure.
    if not so.list_inputs(g) or not so.list_outputs(g):
        raise AssertionError("could not list the inputs/outputs of the renamed graph")

    so.graph_to_file(g, new_model_path, onnx_opset_version=get_onnx_opset_version(model_path))


def get_onnx_opset_version(onnx_path):
    """Return the ONNX operator-set version the model at ``onnx_path`` imports.

    The version of the default ONNX domain (``''`` or its alias ``'ai.onnx'``)
    is preferred, because ``opset_import`` may also list other operator
    domains and their order is not guaranteed. If no default-domain entry
    exists the first entry is used; if the model declares no opset at all,
    0 is returned.

    Args:
        onnx_path: Path of the ONNX file to inspect.

    Returns:
        The opset version as an integer, or 0 when none is declared.
    """
    model = onnx.load(onnx_path)
    if len(model.opset_import) == 0:
        return 0

    for opset in model.opset_import:
        if opset.domain in ('', 'ai.onnx'):
            return opset.version

    # No default-domain entry: fall back to the first declared operator set.
    return model.opset_import[0].version


def update_onnx_doc_string(onnx_path: str, model_means: list[float], model_stds: list[float]):
    """Store the normalisation settings in the graph description of an ONNX file.

    The file at ``onnx_path`` is rewritten in place with its graph
    ``doc_string`` set to a JSON object holding ``model_means`` under the key
    ``means`` and ``model_stds`` under the key ``vars``. The toolchain reads
    that description to populate some of its settings.

    Args:
        onnx_path: Path of the ONNX file to update.
        model_means: Values stored under ``means``.
        model_stds: Values stored under ``vars``.
    """
    model_graph = so.graph_from_file(onnx_path)

    settings = {'means': model_means, 'vars': model_stds}
    model_graph.doc_string = json.dumps(settings)

    # Re-save with the opset version the file already declares.
    opset_version = get_onnx_opset_version(onnx_path)
    so.graph_to_file(model_graph, onnx_path, onnx_opset_version=opset_version)
  1. Finally, to test the ONNX model, we can use the following command:

python test_onnx.py
test_onnx.py
import cv2
import onnxruntime as rt
import numpy as np
from os.path import join, dirname, abspath

# Directory containing this script; the model file and the test image are looked up in it.
PATH = dirname(abspath(__file__))


def test_model(model_path, img_path):
    """Run the completed ONNX model on one image and print the detected boxes.

    The image is resized to the model's input resolution, converted to RGB
    and fed as a float32 NCHW batch of one. The NMS sensitivity is set to
    0.15 and the mask is False on the left half of the image, so only boxes
    whose mask lookup lands on the right half are returned.

    Args:
        model_path: Path of the ``*-complete.onnx`` model.
        img_path: Path of the image to run inference on.

    Returns:
        A tuple ``(bboxes, (width, height))``: the first model output and the
        input resolution the image was resized to.
    """
    sess = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    # The model's inputs are used positionally: image, NMS sensitivity, mask.
    input_name1 = sess.get_inputs()[0].name
    input_name2 = sess.get_inputs()[1].name
    input_name3 = sess.get_inputs()[2].name

    # Get the input resolution. In both layouts the height dimension comes
    # before the width dimension (N, C, H, W / N, H, W, C).
    input_shape = sess.get_inputs()[0].shape
    if input_shape[1] <= 3:  # nchw
        height, width = input_shape[2], input_shape[3]
    else:  # nhwc
        height, width = input_shape[1], input_shape[2]

    img = cv2.imread(img_path)
    # cv2.resize expects the target size as (width, height).
    img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC -> CHW, then add the batch dimension.
    img = np.transpose(img, (2, 0, 1)).astype('float32')
    img = np.expand_dims(img, axis=0)

    # Boolean mask laid out as (rows, columns) = (height, width).
    mask_area = np.ones((height, width), dtype=bool)
    mask_area[:, :width // 2] = False  # mask the left half of the image

    bboxes = sess.run(None, {
        input_name1: img,
        input_name2: np.array([0.15]).astype('float32'),
        input_name3: mask_area
    })[0]

    print(bboxes)

    return bboxes, (width, height)


def visualize_bboxes(bboxes, img_path, width, height):
    """Draw the given boxes on the image and show it until a key is pressed.

    The image is loaded from ``img_path`` and resized to ``(width, height)``.
    Every row of ``bboxes`` is unpacked as
    ``x1, y1, x2, y2, score, class_id``; a green rectangle is drawn for the
    box and the integer class id is written in red at its top-left corner.

    Args:
        bboxes: Iterable of 6-element rows describing the boxes.
        img_path: Path of the image to draw on.
        width: Width the image is resized to.
        height: Height the image is resized to.
    """
    box_colour = (0, 255, 0)
    label_colour = (0, 0, 255)

    canvas = cv2.resize(cv2.imread(img_path), (width, height), interpolation=cv2.INTER_AREA)

    for x1, y1, x2, y2, score, class_id in bboxes:
        # Drawing functions need integer pixel coordinates.
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(canvas, top_left, bottom_right, box_colour, 2)
        cv2.putText(canvas, f'{int(class_id)}', top_left, cv2.FONT_HERSHEY_SIMPLEX, 0.5, label_colour, 2)

    cv2.imshow('img', canvas)
    cv2.waitKey(0)


if __name__ == '__main__':
    from glob import glob

    # Use the first completed model found next to this script; stop with a
    # clear message instead of an IndexError when there is none.
    candidates = glob(join(PATH, '*-complete.onnx'))
    if not candidates:
        raise SystemExit(f"No '*-complete.onnx' model found in {PATH}; export the model first.")
    model_path = candidates[0]

    img_path = join(PATH, 'pedestrians.jpg')
    bboxes, (width, height) = test_model(model_path, img_path)
    visualize_bboxes(bboxes, img_path, width, height)

The exported ONNX can be uploaded on the platform and used for inference directly.

Beyond this example

This example is just a starting point for exporting Yolov7 models to ONNX. Any model based on Yolov7 can be deployed the same way, with only slight adaptations to the ONNX model.

To adapt this example to your own model, you need to:

  • Change the parameter values in the model-to-onnx.sh script

  • Update the complete-onnx.py script to add the classes your model is trained on.

Last updated