问题描述
我尝试将timm库中的预训练分类模型mobilenetv2转换为K230支持的格式，请问以下报错是什么原因？
我的Python、torch、onnx、onnxruntime以及onnxscript的版本如下所示：
python=3.10;
torch==2.9.1;
onnx==1.16.0;
onnxruntime==1.19.2;
onnxscript==0.5.6;
以下是我的运行代码和运行的截图
import torch
import timm
import onnx
from onnxsim import simplify
# 1. Define the Model Architecture (Same as your training code)
def build_stairnet_timm(num_classes: int = 4) -> torch.nn.Module:
    """Build the MobileNetV2 classifier whose weights this script exports.

    The architecture must match the one the checkpoint was trained with,
    otherwise ``load_state_dict`` rejects the weights (e.g. a head-size
    mismatch).

    Args:
        num_classes: Number of output classes of the classification head.
            Defaults to 4, the value this script was written for; change it
            only if the checkpoint was trained with a different head.

    Returns:
        A ``timm`` ``mobilenetv2_100`` model created with
        ``pretrained=False``, i.e. without ImageNet weights; the trained
        weights must be loaded separately.
    """
    # pretrained=False: the trained checkpoint supplies every parameter, so
    # the ImageNet weights would only be overwritten.
    # drop_rate does not need to match training here: the exporter runs the
    # model in eval() mode, where dropout layers are a no-op.
    model = timm.create_model(
        'mobilenetv2_100',
        pretrained=False,
        num_classes=num_classes,
    )
    return model
# 2. Configuration
# Relative paths below resolve against the current working directory.
WEIGHTS_PATH = "stairnet_best.pth"  # trained state_dict to be exported
ONNX_PATH = "stairnet.onnx"  # where the exported ONNX model is written
# Static input layout (Batch, Channels, Height, Width): one 3x224x224 image.
INPUT_SHAPE = (1, 3, 224, 224)
def _load_model(weights_path: str, device: torch.device) -> torch.nn.Module:
    """Build the architecture, load the trained weights and switch to eval."""
    model = build_stairnet_timm()
    # Map to the given device (CPU here) so export needs no CUDA.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run code. It is already the default on
    # torch>=2.6; passing it explicitly keeps older torch versions safe too.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # CRITICAL: eval mode uses BatchNorm running stats and disables Dropout.
    model.eval()
    return model


def _export_onnx(
    model: torch.nn.Module,
    dummy_input: torch.Tensor,
    onnx_path: str,
    opset_version: int,
    dynamo: bool | None,
) -> None:
    """Export *model* to *onnx_path* with a static input shape."""
    export_kwargs = {}
    if dynamo is not None:
        # Recent torch releases (2.9+) default to the dynamo exporter, which
        # relies on onnxscript. dynamo=False selects the legacy TorchScript
        # exporter, which emits the requested opset directly.
        export_kwargs['dynamo'] = dynamo
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=opset_version,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=None,  # K230 prefers static shapes
        **export_kwargs,
    )


def _simplify_onnx(onnx_path: str) -> None:
    """Simplify the ONNX file in place; keep the original if the check fails."""
    onnx_model = onnx.load(onnx_path)
    # `check` is onnxsim's validation of the simplified model against the
    # original; only overwrite the exported file when it passed.
    model_simp, check = simplify(onnx_model)
    if check:
        onnx.save(model_simp, onnx_path)
        print(f"Success! Model saved to {onnx_path}")
    else:
        print(f"Simplification failed! Unsimplified model kept at {onnx_path}")


def main(
    weights_path: str | None = None,
    onnx_path: str | None = None,
    opset_version: int = 18,
    dynamo: bool | None = None,
) -> None:
    """Load the trained weights, export to ONNX and simplify the result.

    Args:
        weights_path: Checkpoint to load; defaults to ``WEIGHTS_PATH``.
        onnx_path: Output ONNX file; defaults to ``ONNX_PATH``.
        opset_version: ONNX opset requested from the exporter. The default
            18 matches the original script. NOTE(review): confirm which
            opsets the K230 (nncase) compiler accepts; older embedded
            toolchains often expect 11-13, typically with ``dynamo=False``.
        dynamo: ``None`` keeps torch's default exporter; ``True``/``False``
            force the dynamo or the legacy TorchScript exporter.
    """
    weights_path = WEIGHTS_PATH if weights_path is None else weights_path
    onnx_path = ONNX_PATH if onnx_path is None else onnx_path
    device = torch.device('cpu')

    print(f"Loading weights from {weights_path}...")
    model = _load_model(weights_path, device)

    # Dummy input only fixes the traced input shape; its values don't matter.
    dummy_input = torch.randn(INPUT_SHAPE, device=device)

    print("Exporting to ONNX...")
    _export_onnx(model, dummy_input, onnx_path, opset_version, dynamo)

    # Simplify ONNX (recommended before NPU compilation).
    print("Simplifying ONNX...")
    _simplify_onnx(onnx_path)


if __name__ == "__main__":
    main()