onnx转换为K230支持模型的问题

Viewed 48

问题描述


我尝试将timm库中的预训练分类模型mobilenetv2转换为K230支持的格式,请问以下这个报错是什么问题呢?(报错截图:image.png)
我的python、torch以及onnx、onnxruntime、onnxsim的版本如下所示(截图:image.png):
python == 3.10;
torch == 2.9.1;
onnx == 1.16.0;
onnxruntime == 1.19.2;
onnxscript == 0.5.6;

2 Answers

模型内部有些计算不符合标准,这个可能是模型的问题。你是怎么转的?

感谢您的回复,我已经附上我的运行代码和运行结果截图。

以下是我的运行代码和运行结果的截图(image.png):

import torch
import timm
import onnx
from onnxsim import simplify

# 1. Define the Model Architecture (Same as your training code)
def build_stairnet_timm():
    """Create the classifier network that the exported weights belong to.

    The architecture name and the number of classes are fixed here and have
    to agree with the configuration used during training, otherwise loading
    the saved state dict will not line up. No pretrained weights are
    requested; the caller is expected to load its own.

    Returns:
        The model object produced by ``timm.create_model``.
    """
    arch_name = 'mobilenetv2_100'
    create_kwargs = {
        'pretrained': False,
        'num_classes': 4,
    }
    return timm.create_model(arch_name, **create_kwargs)

# 2. Configuration
# Checkpoint file containing the trained state dict.
WEIGHTS_PATH: str = "stairnet_best.pth"
# Output location of the exported ONNX graph.
ONNX_PATH: str = "stairnet.onnx"
# Fixed export shape, ordered (batch, channels, height, width).
INPUT_SHAPE: tuple[int, int, int, int] = (1, 3, 224, 224)

def main():
    """Export the trained classifier to a static-shape ONNX file and simplify it.

    Steps:
        1. Build the network with ``build_stairnet_timm()`` and load the state
           dict stored at ``WEIGHTS_PATH`` onto the CPU.
        2. Switch the model to eval mode and export it to ``ONNX_PATH`` using
           a random tensor of shape ``INPUT_SHAPE`` as the example input.
        3. Run onnxsim on the exported file and overwrite ``ONNX_PATH`` with
           the simplified graph when the simplifier's own check passes.

    Side effects: writes ``ONNX_PATH`` and prints progress messages.
    Exceptions raised by torch, onnx or onnxsim are not caught here.
    """
    device = torch.device('cpu')

    # Load Model
    print(f"Loading weights from {WEIGHTS_PATH}...")
    model = build_stairnet_timm()

    # Map to CPU to avoid CUDA requirements during export.
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state dict needs, so a tampered checkpoint cannot run
    # arbitrary code while being loaded.
    state_dict = torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()  # CRITICAL: Switch to eval mode (fixes BatchNorm/Dropout)

    # Example input; its shape becomes the fixed input shape of the graph.
    dummy_input = torch.randn(INPUT_SHAPE, device=device)

    # Export to ONNX
    print("Exporting to ONNX...")
    torch.onnx.export(
        model,
        dummy_input,
        ONNX_PATH,
        # The original call requested opset 18 although its own comment
        # recommended opset 11 or 12 for embedded targets; follow the comment.
        # NOTE(review): confirm the supported opset range against the nncase
        # release used for the K230 conversion.
        opset_version=11,
        # Recent PyTorch releases default to the dynamo-based exporter, which
        # targets newer opsets; the legacy exporter is needed to emit a low
        # opset directly. Requires a PyTorch version whose torch.onnx.export
        # accepts the `dynamo` keyword.
        dynamo=False,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=None,  # K230 prefers static shapes
    )

    # Simplify ONNX (Recommended for NPU compilation)
    print("Simplifying ONNX...")
    onnx_model = onnx.load(ONNX_PATH)
    model_simp, check = simplify(onnx_model)

    if check:
        onnx.save(model_simp, ONNX_PATH)
        print(f"Success! Model saved to {ONNX_PATH}")
    else:
        # The unsimplified export written above is left in place at ONNX_PATH.
        print("Simplification failed!")

# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()