supermachine--tomato-passio.../20240529RGBtest3/brix_model_train&test/predict.py
TG 158cf4d9a9 refactor:修改代码结构,使用config文件进行参数设置;添加必要的注释;
feat:增加糖度模型训练&预测代码;重新训练了新的糖度模型;增加图像处理预热原文件,置于models文件夹下;
fix:修改胜哥糖度模型训练代码的数据切片部分逻辑
2024-06-26 16:29:10 +08:00

87 lines
3.1 KiB
Python

import joblib
import numpy as np
import os
from model_train import prepare_data
def _read_hdr_dimensions(hdr_path):
    """Return ``(height, width, bands)`` parsed from an ENVI-style ``.HDR`` file.

    Only the ``lines`` (height), ``samples`` (width) and ``bands`` entries are
    read; every other header entry is ignored. ``key = value`` is split on the
    first ``=``, so entries written with or without spaces are both accepted.

    Raises:
        ValueError: if any of the three entries is missing, not an integer,
            or not positive.
    """
    dims = {}
    with open(hdr_path, 'r', encoding='latin1') as hdr_file:
        for line in hdr_file:
            key, sep, value = line.partition('=')
            key = key.strip()
            if sep and key in ('lines', 'samples', 'bands'):
                # int() tolerates surrounding whitespace / trailing newline.
                dims[key] = int(value)
    missing = [k for k in ('lines', 'samples', 'bands') if dims.get(k, 0) <= 0]
    if missing:
        raise ValueError(
            f"{hdr_path}: missing or non-positive header field(s): {', '.join(missing)}"
        )
    return dims['lines'], dims['samples'], dims['bands']


def _fit_to_shape(cube, target_shape):
    """Crop (keeping leading entries) or zero-pad (at the end) each axis of *cube* to *target_shape*."""
    cropped = cube[tuple(slice(0, size) for size in target_shape)]
    # After cropping every axis is <= its target, so all pad widths are >= 0.
    padding = [(0, size - current) for current, size in zip(cropped.shape, target_shape)]
    return np.pad(cropped, padding, mode='constant', constant_values=0)


def read_spectral_data(hdr_path, raw_path, target_shape=(30, 30, 224)):
    """Load one hyperspectral cube and crop/zero-pad it to ``target_shape``.

    Args:
        hdr_path: path of the ENVI-style header providing ``lines``,
            ``samples`` and ``bands``.
        raw_path: path of the raw sample file.
        target_shape: ``(height, width, bands)`` of the returned array. Each
            axis is cropped (keeping the first entries) or zero-padded at the
            end. Defaults to the 30x30x224 shape this script was written for.

    Returns:
        A float64 array of shape ``target_shape`` indexed ``[row, column, band]``.

    Raises:
        ValueError: if the header lacks a dimension, ``target_shape`` is not
            3-dimensional, or the raw file holds fewer than
            ``lines * samples * bands`` samples.
    """
    target_shape = tuple(target_shape)
    if len(target_shape) != 3:
        raise ValueError(f"target_shape must have 3 entries, got {target_shape!r}")

    height, width, bands = _read_hdr_dimensions(hdr_path)

    # NOTE(review): the raw file is read as native-endian uint16 starting at
    # byte 0; the header's data type / byte order / header offset are not
    # consulted — confirm this matches the acquisition software.
    raw_image = np.fromfile(raw_path, dtype='uint16')
    expected = height * bands * width
    if raw_image.size < expected:
        raise ValueError(
            f"{raw_path}: expected at least {expected} uint16 samples "
            f"(lines*samples*bands), found {raw_image.size}"
        )

    # Samples are laid out row by row and, within a row, band by band
    # (band-interleaved-by-line): view as (row, band, column), then swap the
    # last two axes to get (row, column, band). Trailing extra samples are ignored.
    cube = raw_image[:expected].reshape(height, bands, width).transpose(0, 2, 1)
    return _fit_to_shape(cube, target_shape).astype(np.float64)
def load_model(model_path):
    """Deserialize the object persisted with joblib at *model_path* and return it.

    Args:
        model_path: filesystem path of the ``.joblib`` file.

    Returns:
        Whatever object was dumped with joblib; callers here use it as an
        estimator exposing ``predict``.

    Security note: ``joblib.load`` unpickles arbitrary Python objects, so only
    load model files that come from a trusted source.
    """
    loaded = joblib.load(model_path)
    return loaded
def predict(model, data):
    """Return the result of ``model.predict(data)``.

    Args:
        model: any object exposing a ``predict`` method (e.g. a fitted estimator).
        data: input accepted by that ``predict`` method.
    """
    predictor = model.predict
    return predictor(data)
def main(*,
         model_path=r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib',
         directory=r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030',
         num_samples=100):
    """Run the prediction pipeline: load model, read cubes, preprocess, predict.

    Sample ``i`` (for ``i`` in ``1..num_samples``) is read from
    ``<directory>/<i>.HDR`` (header) and ``<directory>/<i>`` (raw data). The
    defaults reproduce the original hard-coded paths and count, so ``main()``
    behaves as before.

    Args:
        model_path: path of the joblib-serialized model.
        directory: folder holding the numbered header/raw file pairs.
        num_samples: how many consecutively numbered samples to read.

    Returns:
        The model's predictions, also printed to stdout.

    Raises:
        ValueError: if ``num_samples`` is less than 1 (nothing to stack).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Load the model
    model = load_model(model_path)

    # Read every numbered sample and stack them along a new leading axis.
    all_spectral_data = np.stack([
        read_spectral_data(os.path.join(directory, f'{i}.HDR'),
                           os.path.join(directory, f'{i}'))
        for i in range(1, num_samples + 1)
    ])
    print(all_spectral_data.shape)

    # Preprocess with the same routine used during training.
    data_prepared = prepare_data(all_spectral_data)
    print(data_prepared.shape)

    # Predict and print the results
    predictions = predict(model, data_prepared)
    print(predictions)
    return predictions
if __name__ == "__main__":
    # Script entry point: run the full load -> read -> preprocess -> predict pipeline.
    main()