# Mirror of https://github.com/NanjingForestryUniversity/supermachine--tomato-passion_fruit.git
# Synced 2025-11-09 06:44:02 +00:00
# 87 lines, 3.1 KiB, Python
import joblib
|
|
import numpy as np
|
|
import os
|
|
from model_train import prepare_data
|
|
|
|
def read_spectral_data(hdr_path, raw_path, target_shape=(30, 30, 224)):
    """Load a spectral cube and crop or zero-pad it to a fixed shape.

    The header file is scanned for lines beginning with ``lines``,
    ``samples`` and ``bands``; the last whitespace-separated token of each
    such line is taken as the image height, width and band count.  The raw
    file is read as a flat ``uint16`` array in which every image row is
    stored as ``bands`` consecutive runs of ``samples`` values.

    Args:
        hdr_path: Path to the text header describing the cube dimensions.
        raw_path: Path to the binary file holding the ``uint16`` values.
        target_shape: ``(height, width, bands)`` of the returned array.
            Axes longer than the target are truncated; shorter ones are
            zero-padded at the end.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``target_shape`` indexed as
        ``[row, sample, band]``.

    Raises:
        ValueError: If the header does not define positive ``lines``,
            ``samples`` and ``bands`` values, or if the raw file holds fewer
            values than the header dimensions require.
    """
    # Read the image dimensions from the header.
    dims = {'lines': 0, 'samples': 0, 'bands': 0}
    with open(hdr_path, 'r', encoding='latin1') as hdr_file:
        for line in hdr_file:
            for key in dims:
                if line.startswith(key):
                    dims[key] = int(line.split()[-1])
                    break
    height, width, bands = dims['lines'], dims['samples'], dims['bands']

    # Without this check a malformed header would silently yield an
    # all-zero cube that downstream code cannot tell from real data.
    if height <= 0 or width <= 0 or bands <= 0:
        raise ValueError(
            f'{hdr_path}: header must define positive lines/samples/bands, '
            f'got lines={height}, samples={width}, bands={bands}'
        )

    # Read the spectral values from the raw file.
    raw_image = np.fromfile(raw_path, dtype='uint16')
    expected = height * width * bands
    if raw_image.size < expected:
        raise ValueError(
            f'{raw_path}: expected at least {expected} values for a '
            f'{height}x{width}x{bands} cube, found {raw_image.size}'
        )

    # The file stores (row, band, sample); move the band axis last so the
    # result is indexed (row, sample, band).  Trailing extra values in the
    # file are ignored.
    cube = (
        raw_image[:expected]
        .reshape(height, bands, width)
        .transpose(0, 2, 1)
        .astype(np.float64, order='C')
    )

    # Crop every axis that is too long, then zero-pad every axis that is
    # too short, so the result always has exactly ``target_shape``.
    cube = cube[tuple(slice(0, size) for size in target_shape)]
    padding = [
        (0, max(size - have, 0))
        for size, have in zip(target_shape, cube.shape)
    ]
    if any(after for _, after in padding):
        cube = np.pad(cube, padding, mode='constant', constant_values=0)

    return cube
|
|
|
|
def load_model(model_path):
    """Load a persisted model from ``model_path`` with joblib and return it."""
    model = joblib.load(model_path)
    return model
|
|
|
|
def predict(model, data):
    """Return the result of calling ``model.predict`` on ``data``."""
    predictions = model.predict(data)
    return predictions
|
|
|
|
def main(
    model_path=r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib',
    directory=r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030',
    sample_count=100,
):
    """Load a model, read a numbered set of spectral cubes and print predictions.

    Files are expected in ``directory`` as ``<i>.HDR`` (header) and ``<i>``
    (raw data) for ``i`` from 1 to ``sample_count`` inclusive.

    Args:
        model_path: Path of the joblib model file to load.
        directory: Folder containing the numbered header/raw file pairs.
        sample_count: Number of samples to read, starting at index 1.
    """
    # Load the model.
    model = load_model(model_path)

    # Read every sample and stack them along a new leading axis.
    cubes = []
    for index in range(1, sample_count + 1):
        hdr_path = os.path.join(directory, f'{index}.HDR')
        raw_path = os.path.join(directory, f'{index}')
        cubes.append(read_spectral_data(hdr_path, raw_path))
    all_spectral_data = np.stack(cubes)
    print(all_spectral_data.shape)

    # Preprocess with the helper imported from model_train.
    # NOTE(review): presumably the same transform used at training time;
    # confirm against model_train.prepare_data.
    data_prepared = prepare_data(all_spectral_data)
    print(data_prepared.shape)

    # Predict and print the results.
    predictions = predict(model, data_prepared)
    print(predictions)
|
|
|
|
# Run the prediction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()