Image classification is a fundamental task in visual recognition that aims to understand and categorize an image as a whole under a specific label. Unlike object detection, which involves both classifying and localizing multiple objects within an image, image classification typically pertains to single-object images.
In this section, we'll go through the process of exporting a trained image classification model to an ONNX file that conforms to the requirements mentioned before.
For the sake of example, we'll use a model available in the PyTorch Image Models (timm) library, specifically a ResNet-18 model.
This guide shows how a trained image classification model implemented in PyTorch can be exported to ONNX with the expected input and output names, probability outputs, and normalization values. We'll be using image classification models implemented in the timm library.
Requirements
Make sure to install the required packages:
pip install -r requirements.txt
requirements.txt
timm
sclblonnx
torch
onnxsim
Exporting the model to ONNX
The following command exports a trained image classification model to ONNX and saves it locally:
python export_to_onnx.py
export_to_onnx.py
import json
from os.path import join, dirname, abspath

import onnx
import sclblonnx as so
import timm
import torch
from torch import nn
from onnxsim import simplify

PATH = dirname(abspath(__file__))
# https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json
classes_path = join(PATH, 'imagenet-classes.json')
model_name = 'resnet18'
# Replace with the classes that the model was trained on.
with open(classes_path) as f:
    classes = json.load(f)
concatenated_classes = ';'.join([f'{i}:{c}' for i, c in enumerate(classes)])
onnx_opset_version = 12
output_onnx_path = join(PATH, f'{model_name}.onnx')
input_width = 224  # Replace with your input width
input_height = 224  # Replace with your input height
model_means = [0.485, 0.456, 0.406]  # Replace with your model means
model_means = [255 * m for m in model_means]  # Convert to 0-255 range
model_stds = [0.229, 0.224, 0.225]  # Replace with your model standard deviations
model_stds = [255 * s for s in model_stds]  # Convert to 0-255 range


# Define the model since the model's output needs to be softmaxed
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = timm.create_model(model_name, pretrained=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.model(x)
        x = self.softmax(x)
        return x


# Load the model
model = Model()
# Set the model to evaluation mode
model.eval()
# Define onnx IO names
input_names = ['image-']
output_names = [f'scores-{concatenated_classes}']
# Dummy input in NCHW layout: (batch, channels, height, width)
dummy_input = torch.rand(1, 3, input_height, input_width)
# Export model to ONNX
torch.onnx.export(model, dummy_input, output_onnx_path, input_names=input_names,
                  output_names=output_names, opset_version=onnx_opset_version)
# Update the ONNX description
graph = so.graph_from_file(output_onnx_path)
# Add the model means and standard deviations to the ONNX graph description,
# because that's used by the toolchain to populate some settings.
graph.doc_string = json.dumps({'means': model_means, 'vars': model_stds})
so.graph_to_file(graph, output_onnx_path, onnx_opset_version=onnx_opset_version)
# Simplify the ONNX model
# This step is optional, but it is recommended to reduce the size of the model
# and optimize the model for inference
try:
    model = onnx.load(output_onnx_path)
    model, check = simplify(model, check_n=10)
    assert check, "Couldn't simplify the ONNX model"
    onnx.save_model(model, output_onnx_path)
except Exception as e:
    print(f'Simplification failed: {e}')
    exit(1)
The exported ONNX file can be uploaded to the platform and used for inference.
In-depth inspection of the export script
The ONNX model is expected to output probabilities (values in range [0, 1]). However, most image classification models don't have the Softmax activation by default. Therefore, we need to append Softmax to the model manually.
# Define the model since the model's output needs to be softmaxed
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = timm.create_model(model_name, pretrained=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.model(x)
        x = self.softmax(x)
        return x
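To see what appending Softmax changes, here is a minimal pure-Python sketch of the operation itself. The `softmax` helper and the example logits are illustrative assumptions, not part of the export script; in the script, `nn.Softmax(dim=1)` performs the same computation on tensors.

```python
import math


def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    # Subtracting the maximum keeps exp() numerically stable
    # and does not change the result.
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw outputs of a 3-class classifier.
logits = [2.0, 1.0, -1.0]
probs = softmax(logits)
print(probs)
```

Raw logits can be negative or larger than 1, whereas every value in `probs` lies in [0, 1] and the values sum to 1, which is the form the ONNX output is expected to have.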
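The ONNX IO names must adhere to the convention that was detailed in the requirements section. For image classification models, the input name must be "image-" and the output name must be "scores-" followed by the class list, as shown in the Python code below:
# Define onnx IO names
input_names = ['image-']
output_names = [f'scores-{concatenated_classes}']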
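As a rough illustration of how the output name encodes the class list, the sketch below builds a "scores-..." name the same way the export script does (index and label joined by `:`, entries joined by `;`) and then decodes it again. Both helper names are hypothetical, and the decoder is only an assumption about how such a name could be read back, not the platform's actual parser.

```python
def build_output_name(classes):
    """Build the 'scores-...' output name from an ordered list of class labels."""
    return 'scores-' + ';'.join(f'{i}:{c}' for i, c in enumerate(classes))


def parse_output_name(name):
    """Recover the index-to-label mapping from a 'scores-...' output name."""
    prefix = 'scores-'
    if not name.startswith(prefix):
        raise ValueError(f'Unexpected output name: {name!r}')
    mapping = {}
    for entry in name[len(prefix):].split(';'):
        # Split on the first ':' only, so labels may themselves contain ':'.
        index, label = entry.split(':', 1)
        mapping[int(index)] = label
    return mapping


classes = ['cat', 'dog', 'bird']  # hypothetical labels
output_name = build_output_name(classes)
print(output_name)  # scores-0:cat;1:dog;2:bird
print(parse_output_name(output_name))
```

One consequence of this encoding is that a label containing `;` would be split into two entries, so it may be worth checking your class names for that character before exporting.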
We export the model to ONNX using PyTorch's ONNX exporter (torch.onnx.export).
Please note that the operator set (opset) version is set to 12, which is below the upper bound of 18.
The model normalization values can be set in two different ways. The first option is to specify those values in Nx cloud when uploading the ONNX model.
The second option is to embed those values in the ONNX file by storing them as JSON in the ONNX doc_string attribute, as shown below:
# Update the ONNX description
graph = so.graph_from_file(output_onnx_path)
# Add the model means and standard deviations to the ONNX graph description,
# because that's used by the toolchain to populate some settings.
graph.doc_string = json.dumps({'means': model_means, 'vars': model_stds})
so.graph_to_file(graph, output_onnx_path, onnx_opset_version=onnx_opset_version)
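The sketch below shows, without any ONNX dependency, what the JSON stored in doc_string looks like and why the script scales the values by 255. It assumes normalization is applied as (pixel - mean) / std, which this guide does not state explicitly; the helper names are hypothetical, while the 'means' and 'vars' keys are taken from the script.

```python
import json


def to_pixel_range(values):
    """Scale per-channel statistics from the 0-1 range to the 0-255 range."""
    return [255 * v for v in values]


def normalize(pixel, mean, std):
    """Assumed normalization formula: (pixel - mean) / std."""
    return (pixel - mean) / std


means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

# Same keys as in the export script: 'means' and 'vars'.
doc_string = json.dumps({'means': to_pixel_range(means), 'vars': to_pixel_range(stds)})
decoded = json.loads(doc_string)
print(decoded)

# Normalizing a 0-255 pixel with the scaled values gives the same result
# as normalizing the 0-1 pixel with the original values.
raw_pixel = 128
scaled = normalize(raw_pixel, decoded['means'][0], decoded['vars'][0])
unit = normalize(raw_pixel / 255, means[0], stds[0])
print(scaled, unit)
```

Under that assumption, scaling the means and standard deviations lets the normalization run directly on 0-255 pixel values without first dividing the image by 255.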
Optimizing the model is an optional step, as it's always done in the Nx cloud during the conversion. However, it's useful to run it locally as a validation step to check whether the ONNX model conforms to the ONNX specification:
# Simplify the ONNX model
# This step is optional, but it is recommended to reduce the size of the model
# and optimize the model for inference
try:
    model = onnx.load(output_onnx_path)
    model, check = simplify(model, check_n=10)
    assert check, "Couldn't simplify the ONNX model"
    onnx.save_model(model, output_onnx_path)
except Exception as e:
    print(f'Simplification failed: {e}')
    exit(1)
Beyond this example
This example is a starting point for exporting image classification models to ONNX. The same approach is valid for any image classification model implemented in PyTorch, provided that the model has one input (the image) and one output (the class probabilities).
To adapt this example to your own model, update these variables in the Python script: classes, input_width, input_height, model_means, and model_stds.