Problem description
Is it normal for the cosine similarity between the original ONNX model and the converted kmodel to be only around 0.3? Changing the calibration set makes almost no difference: whether I use real images or randomly generated data for calibration, the converted kmodel's cosine similarity stays at around 0.3. Without quantization, the converted kmodel reaches a cosine similarity of about 0.9. The model is a CRNN text-recognition ONNX model exported from the Paddle framework.
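For reference, a minimal sketch of the kind of cosine-similarity comparison described here (flattened outputs; the .npy file names are placeholders for wherever the two output tensors are dumped):

import numpy as np

def cosine_similarity(a, b):
    # flattened cosine similarity between two output tensors
    a = a.astype(np.float32).ravel()
    b = b.astype(np.float32).ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

onnx_out = np.load("onnx_output.npy")      # placeholder: dumped float ONNX output
kmodel_out = np.load("kmodel_output.npy")  # placeholder: dumped kmodel output
print("cosine similarity:", cosine_similarity(onnx_out, kmodel_out))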
Reproduction steps
python3 script.py --target k230 --model ./inference.onnx --dataset_path ./rec_imgs4 --input_width 100 --input_height 32 --ptq_option 0
Conversion script:
import os
import argparse
import numpy as np
from PIL import Image
import onnxsim
import onnx
import nncase
import shutil
import math
def parse_model_input_output(model_file, input_shape):
    onnx_model = onnx.load(model_file)
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    input_names = list(set(input_all) - set(input_initializer))
    input_tensors = [
        node for node in onnx_model.graph.input if node.name in input_names]
    # input
    inputs = []
    for _, e in enumerate(input_tensors):
        onnx_type = e.type.tensor_type
        input_dict = {}
        input_dict['name'] = e.name
        input_dict['dtype'] = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_type.elem_type]
        input_dict['shape'] = [(i.dim_value if i.dim_value != 0 else d) for i, d in zip(
            onnx_type.shape.dim, input_shape)]
        inputs.append(input_dict)
    return onnx_model, inputs
def onnx_simplify(model_file, dump_dir, input_shape):
    onnx_model, inputs = parse_model_input_output(model_file, input_shape)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

    # The installed onnx/onnxsim is too old and its checker does not understand
    # the model's IR version, so try downgrading the IR version to one the
    # current onnx package recognizes.
    try:
        if onnx_model.ir_version > onnx.IR_VERSION:
            print(f"[warn] downgrade ir_version {onnx_model.ir_version} -> {onnx.IR_VERSION}")
            onnx_model.ir_version = onnx.IR_VERSION
    except Exception as e:
        print("[warn] fail to adjust ir_version:", e)

    input_shapes = {}
    for input in inputs:
        input_shapes[input['name']] = input['shape']
    onnx_model, check = onnxsim.simplify(onnx_model, input_shapes=input_shapes)
    assert check, "Simplified ONNX model could not be validated"
    model_file = os.path.join(dump_dir, 'simplified.onnx')
    onnx.save_model(onnx_model, model_file)
    return model_file
def read_model_file(model_file):
    with open(model_file, 'rb') as f:
        model_content = f.read()
    return model_content
def generate_data(shape, batch, calib_dir):
    # Calibration samples: uint8 RGB images resized to the model input size and
    # laid out as NCHW (normalization is done inside the kmodel, preprocess=True).
    img_paths = [os.path.join(calib_dir, p) for p in os.listdir(calib_dir)]
    data = []
    for i in range(batch):
        assert i < len(img_paths), "calibration images not enough."
        img_data = Image.open(img_paths[i]).convert('RGB')
        img_data = img_data.resize((shape[3], shape[2]), Image.BILINEAR)
        img_data = np.asarray(img_data, dtype=np.uint8)
        img_data = np.transpose(img_data, (2, 0, 1))
        data.append([img_data[np.newaxis, ...]])
    return np.array(data)
def main():
    parser = argparse.ArgumentParser(prog="nncase")
    parser.add_argument("--target", default="k230", type=str, help='target to run, k230/cpu')
    parser.add_argument("--model", type=str, help='model file')
    parser.add_argument("--dataset_path", type=str, help='calibration_dataset')
    parser.add_argument("--input_width", type=int, default=320, help='model input_width')
    parser.add_argument("--input_height", type=int, default=320, help='model input_height')
    parser.add_argument("--ptq_option", type=int, default=0, help='ptq_option: 0,1,2,3,4,5')
    args = parser.parse_args()

    # # Round the input size up to a multiple of 32
    # input_width = int(math.ceil(args.input_width / 32.0)) * 32
    # input_height = int(math.ceil(args.input_height / 32.0)) * 32

    # CRNN: height is fixed at 32; width is aligned to a multiple of the horizontal stride.
    # The CRNN recognition model only downsamples the width twice (stride=2, stride=2),
    # so the total horizontal stride is 4 and the input width should be a multiple of 4.
    stride_x = 4  # this model outputs 25 time steps, i.e. W=100 / 4, so stride_x = 4
    input_height = 32
    input_width = int(math.ceil(args.input_width / float(stride_x))) * stride_x

    # Model input shape; the dimension order must match input_layout
    input_shape = [1, 3, input_height, input_width]

    dump_dir = 'tmp'
    if not os.path.exists(dump_dir):
        os.makedirs(dump_dir)

    # onnx simplify
    model_file = onnx_simplify(args.model, dump_dir, input_shape)
    # Set CompileOptions
    compile_options = nncase.CompileOptions()
    compile_options.target = args.target
    # Do the preprocessing inside the kmodel
    compile_options.preprocess = True
    # The ONNX model expects RGB, and the camera on the K230 also outputs RGB,
    # so there is no need to swap R and B
    compile_options.swapRB = False
    # Shape of the input image
    compile_options.input_shape = input_shape
    # Input format, 'uint8' or 'float32'
    compile_options.input_type = 'uint8'
    # If the input is 'uint8', the range after dequantization
    compile_options.input_range = [0, 1]
    # # Preprocessing mean/std, one value per channel, taken from the YOLOv8 source
    # compile_options.mean = [0, 0, 0]
    # compile_options.std = [1, 1, 1]
    # Preprocessing mean/std, one value per channel
    compile_options.mean = [0.5, 0.5, 0.5]  # trial value
    compile_options.std = [0.5, 0.5, 0.5]   # trial value
    # Input layout; 'NCHW' is the default for ONNX
    compile_options.input_layout = "NCHW"
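    # Note (my assumption about how nncase combines the options above): with
    # input_type='uint8' and input_range=[0, 1], the kmodel first dequantizes the
    # uint8 input back to [0, 1] and then applies (x - mean) / std, so the
    # effective preprocessing is roughly (pixel / 255.0 - 0.5) / 0.5 -> [-1, 1],
    # which is what the PaddleOCR CRNN rec models appear to expect.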
    # Create the Compiler instance
    compiler = nncase.Compiler(compile_options)

    # Import the ONNX model
    model_content = read_model_file(model_file)
    import_options = nncase.ImportOptions()
    compiler.import_onnx(model_content, import_options)
    # Configure quantization
    ptq_options = nncase.PTQTensorOptions()
    ptq_options.samples_count = 10
    if args.ptq_option == 0:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.quant_type = 'uint8'
        ptq_options.w_quant_type = 'uint8'
    elif args.ptq_option == 1:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.quant_type = 'uint8'
        ptq_options.w_quant_type = 'int16'
    elif args.ptq_option == 2:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.quant_type = 'int16'
        ptq_options.w_quant_type = 'uint8'
    elif args.ptq_option == 3:
        ptq_options.calibrate_method = 'Kld'
        ptq_options.quant_type = 'uint8'
        ptq_options.w_quant_type = 'uint8'
    elif args.ptq_option == 4:
        ptq_options.calibrate_method = 'Kld'
        ptq_options.quant_type = 'uint8'
        ptq_options.w_quant_type = 'int16'
    elif args.ptq_option == 5:
        ptq_options.calibrate_method = 'Kld'
        ptq_options.quant_type = 'int16'
        ptq_options.w_quant_type = 'uint8'
    else:
        pass

    # Set calibration data
    ptq_options.set_tensor_data(generate_data(input_shape, ptq_options.samples_count, args.dataset_path))
    compiler.use_ptq(ptq_options)
    # Compile
    compiler.compile()

    # Write the kmodel file
    kmodel = compiler.gencode_tobytes()
    base, ext = os.path.splitext(args.model)
    kmodel_name = base + ".kmodel"
    with open(kmodel_name, 'wb') as f:
        f.write(kmodel)

if __name__ == '__main__':
    main()
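To sanity-check that the mean/std above match what the float model expects, the original (simplified) ONNX model can be run on one calibration image with the same normalization applied by hand. A minimal onnxruntime sketch, assuming the simplified model written by the script and a sample image from the calibration folder (the image file name is a placeholder):

import numpy as np
import onnxruntime as ort
from PIL import Image

sess = ort.InferenceSession("tmp/simplified.onnx")
input_name = sess.get_inputs()[0].name

img = Image.open("rec_imgs4/sample.jpg").convert('RGB').resize((100, 32), Image.BILINEAR)
x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)[np.newaxis, ...]  # NCHW
x = (x / 255.0 - 0.5) / 0.5  # same normalization the kmodel is configured to perform

out = sess.run(None, {input_name: x})[0]
print(out.shape)  # for W=100 this should give 25 time steps, e.g. (1, 25, num_classes)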
If I try int16 quantization with --ptq_option 1, it fails with:
warn: Nncase.Hosting.PluginLoader[0]
NNCASE_PLUGIN_PATH is not set.
[warn] downgrade ir_version 10 -> 7
Unhandled exception. System.AggregateException: One or more errors occurred. (assert(allocation.IsOk) error!
File "/home/gitlab-runner/builds/zaC7hZ1H/1/maix2-ai-sw/k510-gnne-compiler/modules/Nncase.Modules.K230/Transform/Rules/Tile/TileLSTM.cs", line 201 .)
---> System.InvalidOperationException: assert(allocation.IsOk) error!
File "/home/gitlab-runner/builds/zaC7hZ1H/1/maix2-ai-sw/k510-gnne-compiler/modules/Nncase.Modules.K230/Transform/Rules/Tile/TileLSTM.cs", line 201 .
at Nncase.Passes.Rules.K230.TileUtilities.Assert(Boolean v, String vStr, String path, Int32 line)
at Nncase.Passes.Rules.K230.TileLSTM.SearchGlbParameters()
at Nncase.Passes.Rules.K230.TileLSTM.GetReplace(Expr output, Call midCall, IReadOnlyList`1 midCallParams)
at Nncase.Passes.Rules.K230.TileLSTM.GetReplace(IMatchResult __result, RunPassContext __context)
at Nncase.Passes.Rules.Tile.K230FusionConvertVisitor.Process(Fusion fusion)
at Nncase.Passes.Rules.Tile.K230FusionConvertVisitor.RewriteLeafFusion(Fusion expr)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.Rewrite(Expr expr, TContext context)
at Nncase.IR.ExprRewriter.Rewrite(Expr expr)
at Nncase.Passes.Rules.Tile.CheckedConvertMutator.RewriteLeafFusion(Fusion expr)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitTuple(Tuple expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitCall(Call expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.VisitOperands(Expr expr, TContext context)
at Nncase.IR.ExprVisitor`3.VisitFunction(Function expr, TContext context)
at Nncase.IR.ExprVisitor`3.DispatchVisit(Expr expr, TContext context)
at Nncase.IR.ExprRewriter`1.Rewrite(Expr expr, TContext context)
at Nncase.IR.ExprRewriter.Rewrite(Expr expr)
at Nncase.Passes.Rules.Tile.K230FusionToTirPass.RunCoreAsync(IRModule module, RunPassContext options)
at Nncase.Passes.Pass`2.RunAsync(TInput input, RunPassContext context)
at Nncase.Passes.PassManager.ModulePassGroup.RunAsync(IRModule module)
at Nncase.Passes.PassManager.RunAsync(IRModule module)
at Nncase.Compiler.Compiler.RunPassAsync(Action`1 register, String name, IProgress`1 progress, CancellationToken token)
at Nncase.Compiler.Compiler.CompileAsync(IProgress`1 progress, CancellationToken token)
--- End of inner exception stack trace ---
at System.Threading.Tasks.Task.ThrowIfExceptional(Boolean includeTaskCanceledExceptions)
at System.Threading.Tasks.Task.Wait(Int32 millisecondsTimeout, CancellationToken cancellationToken)
at Nncase.Compiler.Interop.CApi.CompilerCompile(IntPtr compilerHandle)
Hardware board
创乐博 K230
Software version
CanMV-K230-V3_sdcard__nncase_v2.9.0.img.gz