tflite模型PC端与嵌入式交叉验证

TFLM(Tensorflow lite micro)验证嵌入式端模型运行,直截了当做法是:对比PC端和嵌入式端运行的tflite模型的输入输出。笔者就TinyML的HelloWorld例程,实践了PC端tflite模型运行情况和MCU端RT1062部署后运行情况。

  • mcu: nxp rt1062
  • ide: MCUXpresso IDE
  • TinyML framework: TensorflowLiteMicro

1. 嵌入式端:输入输出打印(float print)

1) IDE支持浮点数打印

NXP IDE MCUXpresso支持float打印比较简单,如下设置:

设置浮点打印支持

2) 打印输出代码

void model_run()
{
	 float position = static_cast<float>(inference_count) /
		                       static_cast<float>(kInferencesPerCycle);
	float x = position * kXrange;
	、
	int8_t x_quantized = x / input->params.scale + input->params.zero_point;
	input->data.int8[0] = x_quantized;

	MODEL_RunInference();
	int8_t y_quantized = output->data.int8[0];

	float y = (y_quantized - output->params.zero_point) * output->params.scale;
	printf("inference: %d > x: %f > x_q: %d, y_q: %d > y: %f \r\n",
			inference_count,x,x_quantized,y_quantized,			y);
	inference_count += 1;
	if (inference_count >= kInferencesPerCycle) inference_count = 0;
}

运行后串口输出日志信息:

input  zero_porint: -128 , scale: 0.024574
output zero_porint: 4 , scale: 0.008472
inference: 0 > x: 0.000000 > x_q: -128, y_q: 4 > y: 0.000000
inference: 1 > x: 0.314159 > x_q: -115, y_q: 48 > y: 0.372770
inference: 2 > x: 0.628319 > x_q: -102, y_q: 70 > y: 0.559154
inference: 3 > x: 0.942478 > x_q: -89, y_q: 103 > y: 0.838731
inference: 4 > x: 1.256637 > x_q: -76, y_q: 118 > y: 0.965812
inference: 5 > x: 1.570796 > x_q: -64, y_q: 127 > y: 1.042060
inference: 6 > x: 1.884956 > x_q: -51, y_q: 117 > y: 0.957340
inference: 7 > x: 2.199115 > x_q: -38, y_q: 101 > y: 0.821787
inference: 8 > x: 2.513274 > x_q: -25, y_q: 67 > y: 0.533738
inference: 9 > x: 2.827433 > x_q: -12, y_q: 32 > y: 0.237217
inference: 10 > x: 3.141593 > x_q: 0, y_q: 5 > y: 0.008472
inference: 11 > x: 3.455752 > x_q: 12, y_q: -32 > y: -0.304993
inference: 12 > x: 3.769912 > x_q: 25, y_q: -59 > y: -0.533738
inference: 13 > x: 4.084070 > x_q: 38, y_q: -88 > y: -0.779427
inference: 14 > x: 4.398230 > x_q: 50, y_q: -110 > y: -0.965812
inference: 15 > x: 4.712389 > x_q: 63, y_q: -127 > y: -1.109837
inference: 16 > x: 5.026548 > x_q: 76, y_q: -112 > y: -0.982756
inference: 17 > x: 5.340708 > x_q: 89, y_q: -84 > y: -0.745539
inference: 18 > x: 5.654867 > x_q: 102, y_q: -59 > y: -0.533738
inference: 19 > x: 5.969026 > x_q: 114, y_q: -38 > y: -0.355825
inference: 0 > x: 0.000000 > x_q: -128, y_q: 4 > y: 0.000000
inference: 1 > x: 0.314159 > x_q: -115, y_q: 48 > y: 0.372770

2. PC端:串口日志解析

import numpy as np
import pandas as pd
import re
def log_to_df(log_file, pat, x_match_idx, y_match_idx):
    x = []
    y = []

    with open(log_file,'r') as fs:
        lines = fs.readlines()
        for line in lines:
            rlt = re.search(pat,line)
            if rlt:
                x.append(rlt.group(x_match_idx))
                y.append(rlt.group(y_match_idx))
    
    df = pd.DataFrame({
            'x':x,
            'y':y
        })
    df['x'] = df['x'].astype(float)
    df['y'] = df['y'].astype(float)
    return df

log_file = '../data/sine_model_out.txt'
pat ='inference: (\d+) > x: ([-]?\d+\.\d+) > x_q: ([-]?\d+), y_q: ([-]?\d+) > y: ([-]?\d\.\d+)'
x_match_idx = 2
y_match_idx = 5

df = log_to_df(log_file, pat, x_match_idx, y_match_idx)
df

log_to_df使用re解析得到输入输出

df

3. PC端:输入输出对比验证

import tensorflow as tf
import numpy as np
def predict_tflite(tflite_model, x_test):
    # Prepare the test data
    x_test_ = x_test.copy()
    x_test_ = x_test_.reshape((x_test.size, 1))
    x_test_ = x_test_.astype(np.float32)

    # Initialize the TFLite interpreter
    if isinstance(tflite_model, str):
        interpreter = tf.lite.Interpreter(model_path=tflite_model)
    else:
        interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # If required, quantize the input layer (from float to integer)
    input_scale, input_zero_point = input_details["quantization"]
    if (input_scale, input_zero_point) != (0.0, 0):
        x_test_ = x_test_ / input_scale + input_zero_point
        x_test_ = x_test_.astype(input_details["dtype"])

    # Invoke the interpreter
    y_pred = np.empty(x_test_.size, dtype=output_details["dtype"])
    for i in range(len(x_test_)):
        interpreter.set_tensor(input_details["index"], [x_test_[i]])
        interpreter.invoke()
        y_pred[i] = interpreter.get_tensor(output_details["index"])[0]

    # If required, dequantized the output layer (from integer to float)
    output_scale, output_zero_point = output_details["quantization"]
    if (output_scale, output_zero_point) != (0.0, 0):
        y_pred = y_pred.astype(np.float32)
        y_pred = (y_pred - output_zero_point) * output_scale

    return y_pred

def cross_validate_tflite(df, tflite_model_path):
    df['pred_pc'] = predict_tflite(tflite_model_path,df['x'].to_numpy().astype(np.float32))
    df[:20].set_index('x',drop=True).plot()
    return df
    
tflite_model_path = 'hello_world/train/models/model.tflite'    
cross_validate_tflite(df, tflite_model_path)  

输出完美匹配

其他

helloworld例程是简单回归模型,输入输出都是一维,相对简单,直接。对于图像分类问题,输入图像矩阵,输出各个类别概率,只能比较输出。另外对比基于相同的输入才有意义,这就要把样例保存到flash或者sd卡,通过fatfs相同在运行时进行识别,这样也只是覆盖小部分样例测试,效率较低。完善的解决方案,应该是通过网络进行pc端和嵌入式端的实时通讯和调试,实现硬件在环的测试。NXP eIQ AI平台工具链实现了这一过程,其本质应该是嵌入式端实现了tcp server之类的。

合智互联客户成功服务热线:400-1565-661

admin
admin管理员

上一篇:客快物流大数据项目(六十五):仓库主题
下一篇:【AIoT应用创新大赛-基于TencentOS Tiny 的智能取暖器】

留言评论

暂无留言