refactor:修改代码结构,使用config文件进行参数设置;添加必要的注释;

feat:增加糖度模型训练&预测代码;重新训练了新的糖度模型;增加图像处理预热原文件,置于models文件夹下;
fix:修改胜哥糖度模型训练代码的数据切片部分逻辑
This commit is contained in:
TG 2024-06-26 16:29:10 +08:00
parent 6a304ceb7c
commit 158cf4d9a9
15 changed files with 405 additions and 57 deletions

View File

@ -0,0 +1,111 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from spec_read import all_spectral_data
import joblib
def prepare_data(data):
"""Reshape data and select specified spectral bands."""
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
# 筛选特定的波段
data_selected = data[:, :, :, selected_bands]
# 将筛选后的数据重塑为二维数组,每行代表一个样本
reshaped_data = data_selected.reshape(-1, 30 * 30 * len(selected_bands))
return reshaped_data
def split_data(X, y, test_size=0.20, random_state=12):
"""Split data into training and test sets."""
return train_test_split(X, y, test_size=test_size, random_state=random_state)
def evaluate_model(model, X_test, y_test):
"""Evaluate the model and return multiple metrics and predictions."""
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
return mse, mae, r2, y_pred
def print_predictions(y_test, y_pred, model_name):
"""Print actual and predicted values."""
print(f"Test Set Predictions for {model_name}:")
for i, (real, pred) in enumerate(zip(y_test, y_pred)):
print(f"Sample {i + 1}: True Value = {real:.2f}, Predicted Value = {pred:.2f}")
def main():
sweetness_acidity = np.array([
16.2, 16.1, 17, 16.9, 16.8, 17.8, 18.1, 17.2, 17, 17.2, 17.1, 17.2,
17.2, 17.2, 18.1, 17, 17.6, 17.4, 17.1, 17.1, 16.9, 17.6, 17.3, 16.3,
16.5, 18.7, 17.6, 16.2, 16.8, 17.2, 16.8, 17.3, 16, 16.6, 16.7, 16.7,
17.3, 16.3, 16.8, 17.4, 17.3, 16.3, 16.1, 17.2, 18.6, 16.8, 16.1, 17.2,
18.3, 16.5, 16.6, 17, 17, 17.8, 16.4, 18, 17.7, 17, 18.3, 16.8, 17.5,
17.7, 18.5, 18, 17.7, 17, 18.3, 18.1, 17.4, 17.7, 17.8, 16.3, 17.1, 16.8,
17.2, 17.5, 16.6, 17.7, 17.1, 17.7, 19.4, 20.3, 17.3, 15.8, 18, 17.7,
17.2, 15.2, 18, 18.4, 18.3, 15.7, 17.2, 18.6, 15.6, 17, 16.9, 17.4, 17.8,
16.5
])
X = prepare_data(all_spectral_data)
print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}')
X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity)
models_params = {
"RandomForest": {
'model': RandomForestRegressor(),
'params': {
'n_estimators': [100, 200, 300],
'max_depth': [None, 10, 20],
'min_samples_split': [2, 5],
'min_samples_leaf': [1, 2],
'random_state': [42]
}
},
"GradientBoosting": {
'model': GradientBoostingRegressor(),
'params': {
'n_estimators': [100, 200, 300],
'learning_rate': [0.01, 0.1, 0.2],
'max_depth': [3, 5, 7],
'min_samples_split': [2, 5],
'min_samples_leaf': [1, 2],
'random_state': [42]
}
},
"SVR": {
'model': SVR(),
'params': {
'C': [0.1, 1, 10, 100],
'gamma': ['scale', 'auto', 0.01, 0.1],
'epsilon': [0.01, 0.1, 0.5]
}
}
}
best_models = {}
for model_name, mp in models_params.items():
grid_search = GridSearchCV(mp['model'], mp['params'], cv=5, scoring='r2', verbose=2)
grid_search.fit(X_train, y_train)
best_models[model_name] = grid_search.best_estimator_
mse, mae, r2, y_pred = evaluate_model(grid_search.best_estimator_, X_test, y_test)
print(f"Best {model_name} parameters: {grid_search.best_params_}")
print(f"Model: {model_name}")
print(f"MSE on the test set: {mse}")
print(f"MAE on the test set: {mae}")
print(f"R² score on the test set: {r2}")
print_predictions(y_test, y_pred, model_name)
print("\n" + "-" * 50 + "\n")
# Optionally save the best model for each type
joblib.dump(grid_search.best_estimator_, f'{model_name}_best_model.joblib')
if __name__ == "__main__":
main()

View File

@ -0,0 +1,87 @@
import joblib
import numpy as np
import os
from model_train import prepare_data
def read_spectral_data(hdr_path, raw_path):
# Read HDR file for image dimensions information
with open(hdr_path, 'r', encoding='latin1') as hdr_file:
lines = hdr_file.readlines()
height = width = bands = 0
for line in lines:
if line.startswith('lines'):
height = int(line.split()[-1])
elif line.startswith('samples'):
width = int(line.split()[-1])
elif line.startswith('bands'):
bands = int(line.split()[-1])
# Read spectral data from RAW file
raw_image = np.fromfile(raw_path, dtype='uint16')
# Initialize the image with the actual read dimensions
formatImage = np.zeros((height, width, bands))
for row in range(height):
for dim in range(bands):
formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width]
# Ensure the image is 30x30x224 by cropping or padding
target_height, target_width, target_bands = 30, 30, 224
# Crop or pad height
if height > target_height:
formatImage = formatImage[:target_height, :, :]
elif height < target_height:
pad_height = target_height - height
formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0)
# Crop or pad width
if width > target_width:
formatImage = formatImage[:, :target_width, :]
elif width < target_width:
pad_width = target_width - width
formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
# Crop or pad bands if necessary (usually bands should not change)
if bands > target_bands:
formatImage = formatImage[:, :, :target_bands]
elif bands < target_bands:
pad_bands = target_bands - bands
formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0)
return formatImage
def load_model(model_path):
"""加载模型"""
return joblib.load(model_path)
def predict(model, data):
"""预测数据"""
return model.predict(data)
def main():
# 加载模型
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib')
# 读取数据
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
all_spectral_data = []
for i in range(1, 101):
hdr_path = os.path.join(directory, f'{i}.HDR')
raw_path = os.path.join(directory, f'{i}')
spectral_data = read_spectral_data(hdr_path, raw_path)
all_spectral_data.append(spectral_data)
all_spectral_data = np.stack(all_spectral_data)
print(all_spectral_data.shape)
# 预处理数据
data_prepared = prepare_data(all_spectral_data)
print(data_prepared.shape)
# 预测数据
predictions = predict(model, data_prepared)
# 打印预测结果
print(predictions)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,70 @@
import numpy as np
import os
def read_spectral_data(hdr_path, raw_path):
# Read HDR file for image dimensions information
with open(hdr_path, 'r', encoding='latin1') as hdr_file:
lines = hdr_file.readlines()
height = width = bands = 0
for line in lines:
if line.startswith('lines'):
height = int(line.split()[-1])
elif line.startswith('samples'):
width = int(line.split()[-1])
elif line.startswith('bands'):
bands = int(line.split()[-1])
# Read spectral data from RAW file
raw_image = np.fromfile(raw_path, dtype='uint16')
# Initialize the image with the actual read dimensions
formatImage = np.zeros((height, width, bands))
for row in range(height):
for dim in range(bands):
formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width]
# Ensure the image is 30x30x224 by cropping or padding
target_height, target_width, target_bands = 30, 30, 224
# Crop or pad height
if height > target_height:
formatImage = formatImage[:target_height, :, :]
elif height < target_height:
pad_height = target_height - height
formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0)
# Crop or pad width
if width > target_width:
formatImage = formatImage[:, :target_width, :]
elif width < target_width:
pad_width = target_width - width
formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
# Crop or pad bands if necessary (usually bands should not change)
if bands > target_bands:
formatImage = formatImage[:, :, :target_bands]
elif bands < target_bands:
pad_bands = target_bands - bands
formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0)
return formatImage
# Specify the directory containing the HDR and RAW files
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
# Initialize a list to hold all the spectral data arrays
all_spectral_data = []
# Loop through each data set (assuming there are 40 datasets)
for i in range(1, 101):
hdr_path = os.path.join(directory, f'{i}.HDR')
raw_path = os.path.join(directory, f'{i}')
# Read data
spectral_data = read_spectral_data(hdr_path, raw_path)
all_spectral_data.append(spectral_data)
# Stack all data into a single numpy array
all_spectral_data = np.stack(all_spectral_data)
print(all_spectral_data.shape) # This should print (40, 30, 30, 224)

View File

@ -14,6 +14,7 @@ import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from utils import Pipe from utils import Pipe
from config import Config as setting
from sklearn.ensemble import RandomForestRegressor from sklearn.ensemble import RandomForestRegressor
#图像分类网络所需库,实际并未使用分类网络 #图像分类网络所需库,实际并未使用分类网络
# import torch # import torch
@ -23,8 +24,10 @@ from sklearn.ensemble import RandomForestRegressor
#番茄RGB处理模型 #番茄RGB处理模型
class Tomato: class Tomato:
def __init__(self): def __init__(self, find_reflection_threshold=setting.find_reflection_threshold, extract_g_r_factor=setting.extract_g_r_factor):
''' 初始化 Tomato 类。''' ''' 初始化 Tomato 类。'''
self.find_reflection_threshold = find_reflection_threshold
self.extract_g_r_factor = extract_g_r_factor
pass pass
def extract_s_l(self, image): def extract_s_l(self, image):
@ -40,14 +43,14 @@ class Tomato:
result = cv2.add(s_channel, l_channel) result = cv2.add(s_channel, l_channel)
return result return result
def find_reflection(self, image, threshold=190): def find_reflection(self, image):
''' '''
通过阈值处理识别图像中的反射区域 通过阈值处理识别图像中的反射区域
:param image: 输入的单通道图像 :param image: 输入的单通道图像
:param threshold: 用于二值化的阈值 :param threshold: 用于二值化的阈值
:return: 二值化后的图像高于阈值的部分为白色其余为黑色 :return: 二值化后的图像高于阈值的部分为白色其余为黑色
''' '''
_, reflection = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY) _, reflection = cv2.threshold(image, self.find_reflection_threshold, 255, cv2.THRESH_BINARY)
return reflection return reflection
def otsu_threshold(self, image): def otsu_threshold(self, image):
@ -67,7 +70,7 @@ class Tomato:
''' '''
g_channel = image[:, :, 1] g_channel = image[:, :, 1]
r_channel = image[:, :, 2] r_channel = image[:, :, 2]
result = cv2.subtract(cv2.multiply(g_channel, 1.5), r_channel) result = cv2.subtract(cv2.multiply(g_channel, self.extract_g_r_factor), r_channel)
return result return result
def extract_r_b(self, image): def extract_r_b(self, image):
@ -226,7 +229,8 @@ class Tomato:
#百香果RGB处理模型 #百香果RGB处理模型
class Passion_fruit: class Passion_fruit:
def __init__(self, hue_value=37, hue_delta=10, value_target=25, value_delta=10): def __init__(self, hue_value=setting.hue_value, hue_delta=setting.hue_delta,
value_target=setting.value_target, value_delta=setting.value_delta):
# 初始化常用参数 # 初始化常用参数
self.hue_value = hue_value self.hue_value = hue_value
self.hue_delta = hue_delta self.hue_delta = hue_delta
@ -259,7 +263,8 @@ class Passion_fruit:
def find_largest_component(self, mask): def find_largest_component(self, mask):
if mask is None or mask.size == 0 or np.all(mask == 0): if mask is None or mask.size == 0 or np.all(mask == 0):
logging.warning("RGB 图像为空或全黑返回一个全黑RGB图像。") logging.warning("RGB 图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if mask is None else np.zeros_like(mask) return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
if mask is None else np.zeros_like(mask)
# 寻找最大连通组件 # 寻找最大连通组件
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S) num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)
if num_labels < 2: if num_labels < 2:
@ -291,11 +296,13 @@ class Passion_fruit:
# 检查 RGB 图像是否为空或全黑 # 检查 RGB 图像是否为空或全黑
if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0): if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0):
logging.warning("RGB 图像为空或全黑返回一个全黑RGB图像。") logging.warning("RGB 图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img) return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
if rgb_img is None else np.zeros_like(rgb_img)
# 检查二值图像是否为空或全黑 # 检查二值图像是否为空或全黑
if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0): if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0):
logging.warning("二值图像为空或全黑返回一个全黑RGB图像。") logging.warning("二值图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img) return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
if bin_img is None else np.zeros_like(bin_img)
# 转换二值图像为三通道 # 转换二值图像为三通道
try: try:
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR) bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
@ -336,13 +343,15 @@ class Spec_predict(object):
def predict(self, data_x): def predict(self, data_x):
''' '''
对数据进行预测 预测数据
:param data_x: 波段选择后的数据 :param data_x: 重塑为二维数组的数据
:return: 预测结果二值化后的数据0为背景1为黄芪,2为杂质23为杂质14为甘草片5为红芪 :return: 预测结果糖度
''' '''
data_x = data_x.reshape(1, -1) # 对数据进行切片,筛选谱段
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] #qt_test进行测试时如果读取的是3030224需要解开注释进行数据切片筛选谱段
data_x = data_x[:, selected_bands] # data_x = data_x[ :25, :, setting.selected_bands ]
# 将筛选后的数据重塑为二维数组,每行代表一个样本
data_x = data_x.reshape(-1, setting.n_spec_rows * setting.n_spec_cols * setting.n_spec_bands)
data_y = self.model.predict(data_x) data_y = self.model.predict(data_x)
return data_y[0] return data_y[0]
@ -508,8 +517,8 @@ class Data_processing:
返回: 返回:
float: 估算的西红柿体积 float: 估算的西红柿体积
""" """
a = ((long_axis / 425) * 6.3) / 2 a = (long_axis * setting.pixel_length_ratio) / 2
b = ((short_axis / 425) * 6.3) / 2 b = (short_axis * setting.pixel_length_ratio) / 2
volume = 4 / 3 * np.pi * a * b * b volume = 4 / 3 * np.pi * a * b * b
weight = round(volume * self.density) weight = round(volume * self.density)
#重量单位为g #重量单位为g
@ -524,19 +533,16 @@ class Data_processing:
tuple: (长径, 短径, 缺陷区域个数, 缺陷区域总像素, 处理后的图像) tuple: (长径, 短径, 缺陷区域个数, 缺陷区域总像素, 处理后的图像)
""" """
tomato = Tomato() # 创建 Tomato 类的实例 tomato = Tomato() # 创建 Tomato 类的实例
# 设置 S-L 通道阈值并处理图像
threshold_s_l = 180
threshold_fore_g_r_t = 20
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR) img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
s_l = tomato.extract_s_l(img) s_l = tomato.extract_s_l(img)
thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l) thresholded_s_l = tomato.threshold_segmentation(s_l, setting.threshold_s_l)
new_bin_img = tomato.largest_connected_component(thresholded_s_l) new_bin_img = tomato.largest_connected_component(thresholded_s_l)
filled_img, defect = self.fill_holes(new_bin_img) filled_img, defect = self.fill_holes(new_bin_img)
# 绘制西红柿边缘并获取缺陷信息 # 绘制西红柿边缘并获取缺陷信息
edge, mask = tomato.draw_tomato_edge(img, new_bin_img) edge, mask = tomato.draw_tomato_edge(img, new_bin_img)
org_defect = tomato.bitwise_and_rgb_with_binary(edge, new_bin_img) org_defect = tomato.bitwise_and_rgb_with_binary(edge, new_bin_img)
fore = tomato.bitwise_and_rgb_with_binary(img, mask) fore = tomato.bitwise_and_rgb_with_binary(img, mask)
fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=threshold_fore_g_r_t) fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=setting.threshold_fore_g_r_t)
res = cv2.bitwise_or(new_bin_img, fore_g_r_t) res = cv2.bitwise_or(new_bin_img, fore_g_r_t)
nogreen = tomato.bitwise_and_rgb_with_binary(edge, res) nogreen = tomato.bitwise_and_rgb_with_binary(edge, res)
# 统计白色像素点个数 # 统计白色像素点个数
@ -554,8 +560,8 @@ class Data_processing:
# cv2.imwrite('filled_img.jpg',filled_img) # cv2.imwrite('filled_img.jpg',filled_img)
# 将处理后的图像转换为 RGB 格式 # 将处理后的图像转换为 RGB 格式
rp = cv2.cvtColor(nogreen, cv2.COLOR_BGR2RGB) rp = cv2.cvtColor(nogreen, cv2.COLOR_BGR2RGB)
#直径单位为cm所以需要除以10 #直径单位为cm
diameter = (long_axis + short_axis) /425 * 63 / 2 / 10 diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2
# print(f'直径:{diameter}') # print(f'直径:{diameter}')
# 如果直径小于3判断为空果拖异常图则将所有值重置为0 # 如果直径小于3判断为空果拖异常图则将所有值重置为0
if diameter < 2.5: if diameter < 2.5:
@ -563,16 +569,17 @@ class Data_processing:
green_percentage = 0 green_percentage = 0
number_defects = 0 number_defects = 0
total_pixels = 0 total_pixels = 0
rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB) rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
dtype=np.uint8), cv2.COLOR_BGR2RGB)
return diameter, green_percentage, number_defects, total_pixels, rp return diameter, green_percentage, number_defects, total_pixels, rp
def analyze_passion_fruit(self, img, hue_value=37, hue_delta=10, value_target=25, value_delta=10): def analyze_passion_fruit(self, img):
if img is None: if img is None:
logging.error("Error: 无图像数据.") logging.error("Error: 无图像数据.")
return None return None
# 创建PassionFruit类的实例 # 创建PassionFruit类的实例
pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta) pf = Passion_fruit()
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR) img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
@ -596,15 +603,16 @@ class Data_processing:
edge = pf.draw_contours_on_image(img, contour_mask) edge = pf.draw_contours_on_image(img, contour_mask)
org_defect = pf.bitwise_and_rgb_with_binary(edge, max_mask) org_defect = pf.bitwise_and_rgb_with_binary(edge, max_mask)
rp = cv2.cvtColor(org_defect, cv2.COLOR_BGR2RGB) rp = cv2.cvtColor(org_defect, cv2.COLOR_BGR2RGB)
#直径单位为cm所以需要除以10 #直径单位为cm
diameter = (long_axis + short_axis) /425 * 63 / 2 / 10 diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2
# print(f'直径:{diameter}') # print(f'直径:{diameter}')
if diameter < 2.5: if diameter < 2.5:
diameter = 0 diameter = 0
green_percentage = 0 weight = 0
number_defects = 0 number_defects = 0
total_pixels = 0 total_pixels = 0
rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB) rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
dtype=np.uint8), cv2.COLOR_BGR2RGB)
return diameter, weight, number_defects, total_pixels, rp return diameter, weight, number_defects, total_pixels, rp
def process_data(seif, cmd: str, images: list, spec: any, pipe: Pipe, detector: Spec_predict) -> bool: def process_data(seif, cmd: str, images: list, spec: any, pipe: Pipe, detector: Spec_predict) -> bool:

View File

@ -4,9 +4,50 @@
# @File : config.py # @File : config.py
# @Software: PyCharm # @Software: PyCharm
from root_dir import ROOT_DIR
class Config: class Config:
#文件相关参数 #文件相关参数
#预热参数 #预热参数
n_rows, n_cols, n_channels = 256, 256, 3 n_spec_rows, n_spec_cols, n_spec_bands = 25, 30, 13
n_rgb_rows, n_rgb_cols, n_rgb_bands = 613, 800, 3
tomato_img_dir = ROOT_DIR / 'models' / 'TO.bmp'
passion_fruit_img_dir = ROOT_DIR / 'models' / 'PF.bmp'
#模型路径
#糖度模型
brix_model_path = ROOT_DIR / 'models' / 'passion_fruit.joblib'
#图像分类模型
imgclassifier_model_path = ROOT_DIR / 'models' / 'imgclassifier.joblib'
imgclassifier_class_indices_path = ROOT_DIR / 'models' / 'class_indices.json'
#classifer.py参数
#tomato
find_reflection_threshold = 190
extract_g_r_factor = 1.5
#passion_fruit
hue_value = 37
hue_delta = 10
value_target = 25
value_delta = 10
#spec_predict
#筛选谱段并未使用在qt取数据时已经筛选
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
#data_processing
#根据标定数据计算的参数,实际长度/像素长度单位cm
pixel_length_ratio = 6.3/425
#绿叶面积阈值,高于此阈值认为连通域是绿叶
area_threshold = 20000
#百香果密度g/cm^3
density = 0.652228972
#百香果面积比例每个像素代表的实际面积cm^2
area_ratio = 0.00021973702422145334
#def analyze_tomato
#s_l通道阈值
threshold_s_l = 180
threshold_fore_g_r_t = 20

View File

@ -14,11 +14,11 @@ import logging
from utils import Pipe from utils import Pipe
import numpy as np import numpy as np
import time import time
from config import Config
def main(is_debug=False): def main(is_debug=False):
setting = Config()
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8') file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8')
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
@ -27,15 +27,18 @@ def main(is_debug=False):
handlers=[file_handler, console_handler], handlers=[file_handler, console_handler],
level=logging.DEBUG) level=logging.DEBUG)
#模型加载 #模型加载
detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib') detector = Spec_predict()
# classifier = ImageClassifier(ROOT_DIR/'models'/'resnet34_0619.pth', ROOT_DIR/'models'/'class_indices.json') detector.load(path=setting.brix_model_path)
# classifier = ImageClassifier(model_path=setting.imgclassifier_model_path,
# class_indices_path=setting.imgclassifier_class_indices_path)
dp = Data_processing() dp = Data_processing()
print('系统初始化中...') print('系统初始化中...')
#模型预热 #模型预热
_ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16)) #与qt_test测试时需要注释掉预热模型接收尺寸为253013qt_test发送的数据为3030224需要对数据进行切片classifer.py第352行
# _ = classifier.predict(np.ones((224, 224, 3), dtype=np.uint8)) _ = detector.predict(np.ones((setting.n_spec_rows, setting.n_spec_cols, setting.n_spec_bands), dtype=np.uint16))
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp')) # _ = classifier.predict(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8))
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp')) # _, _, _, _, _ =dp.analyze_tomato(cv2.imread(str(setting.tomato_img_dir)))
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(str(setting.passion_fruit_img_dir))
print('系统初始化完成') print('系统初始化完成')
rgb_receive_name = r'\\.\pipe\rgb_receive' rgb_receive_name = r'\\.\pipe\rgb_receive'

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@ -0,0 +1,4 @@
{
"0": "exist",
"1": "no_exist"
}

Binary file not shown.

View File

@ -1,4 +1,3 @@
import pathlib import pathlib
file_path = pathlib.Path(__file__) file_path = pathlib.Path(__file__)

View File

@ -12,6 +12,7 @@ import win32pipe
import time import time
import logging import logging
import numpy as np import numpy as np
from config import Config as setting
from PIL import Image from PIL import Image
import io import io
@ -204,13 +205,15 @@ class Pipe:
gp = int(green_percentage * 100).to_bytes(1, byteorder='big') gp = int(green_percentage * 100).to_bytes(1, byteorder='big')
weight = 0 weight = 0
weight = weight.to_bytes(1, byteorder='big') weight = weight.to_bytes(1, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes send_message = (length + cmd_re + brix + gp + diameter + weight +
defect_num + total_defect_area + height + width + img_bytes)
elif cmd == 'PF': elif cmd == 'PF':
brix = int(brix * 1000).to_bytes(2, byteorder='big') brix = int(brix * 1000).to_bytes(2, byteorder='big')
gp = 0 gp = 0
gp = gp.to_bytes(1, byteorder='big') gp = gp.to_bytes(1, byteorder='big')
weight = weight.to_bytes(1, byteorder='big') weight = weight.to_bytes(1, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes send_message = (length + cmd_re + brix + gp + diameter + weight +
defect_num + total_defect_area + height + width + img_bytes)
elif cmd == 'KO': elif cmd == 'KO':
brix = 0 brix = 0
brix = brix.to_bytes(2, byteorder='big') brix = brix.to_bytes(2, byteorder='big')
@ -222,13 +225,15 @@ class Pipe:
defect_num = defect_num.to_bytes(2, byteorder='big') defect_num = defect_num.to_bytes(2, byteorder='big')
total_defect_area = 0 total_defect_area = 0
total_defect_area = total_defect_area.to_bytes(4, byteorder='big') total_defect_area = total_defect_area.to_bytes(4, byteorder='big')
height = 100 height = setting.n_rgb_rows
height = height.to_bytes(2, byteorder='big') height = height.to_bytes(2, byteorder='big')
width = 100 width = setting.n_rgb_cols
width = width.to_bytes(2, byteorder='big') width = width.to_bytes(2, byteorder='big')
img_bytes = np.zeros((100, 100, 3), dtype=np.uint8).tobytes() img_bytes = np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
dtype=np.uint8).tobytes()
length = (18).to_bytes(4, byteorder='big') length = (18).to_bytes(4, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes send_message = (length + cmd_re + brix + gp + diameter + weight +
defect_num + total_defect_area + height + width + img_bytes)
try: try:
win32file.WriteFile(self.rgb_send, send_message) win32file.WriteFile(self.rgb_send, send_message)
# time.sleep(0.01) # time.sleep(0.01)

View File

@ -4,25 +4,40 @@ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from spec_read import all_spectral_data from spec_read import all_spectral_data
import joblib import joblib
# def prepare_data(data):
# """Reshape data and select specified spectral bands."""
# reshaped_data = data.reshape(100, -1)
# selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
# return reshaped_data[:, selected_bands]
def prepare_data(data): def prepare_data(data):
"""Reshape data and select specified spectral bands.""" """Reshape data and select specified spectral bands."""
reshaped_data = data.reshape(100, -1)
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
return reshaped_data[:, selected_bands] # 筛选特定的波段
data_selected = data[:, :25, :, selected_bands]
print(f'筛选后的数据尺寸:{data_selected.shape}')
# 将筛选后的数据重塑为二维数组,每行代表一个样本
reshaped_data = data_selected.reshape(-1, 25 * 30 * len(selected_bands))
return reshaped_data
def split_data(X, y, test_size=0.20, random_state=12): def split_data(X, y, test_size=0.20, random_state=12):
"""Split data into training and test sets.""" """Split data into training and test sets."""
return train_test_split(X, y, test_size=test_size, random_state=random_state) return train_test_split(X, y, test_size=test_size, random_state=random_state)
def evaluate_model(model, X_test, y_test): def evaluate_model(model, X_test, y_test):
"""Evaluate the model and return MSE and predictions.""" """Evaluate the model and return multiple metrics and predictions."""
y_pred = model.predict(X_test) y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred) mse = mean_squared_error(y_test, y_pred)
return mse, y_pred mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
return mse, mae, r2, y_pred
def print_predictions(y_test, y_pred, model_name): def print_predictions(y_test, y_pred, model_name):
"""Print actual and predicted values.""" """Print actual and predicted values."""
@ -44,6 +59,7 @@ def main():
]) ])
X = prepare_data(all_spectral_data) X = prepare_data(all_spectral_data)
print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}')
X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity) X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity)
models = { models = {
@ -55,13 +71,17 @@ def main():
for model_name, model in models.items(): for model_name, model in models.items():
model.fit(X_train, y_train) model.fit(X_train, y_train)
if model_name == "RandomForest": if model_name == "RandomForest":
joblib.dump(model, r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib') joblib.dump(model,
r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit.joblib')
mse, y_pred = evaluate_model(model, X_test, y_test) mse, mae, r2, y_pred = evaluate_model(model, X_test, y_test)
print(f"Model: {model_name}") print(f"Model: {model_name}")
print(f"Mean Squared Error on the test set: {mse}") print(f"MSE on the test set: {mse}")
print(f"MAE on the test set: {mae}")
print(f"R² score on the test set: {r2}")
print_predictions(y_test, y_pred, model_name) print_predictions(y_test, y_pred, model_name)
print("\n" + "-"*50 + "\n") print("\n" + "-" * 50 + "\n")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -60,7 +60,7 @@ def predict(model, data):
def main(): def main():
# 加载模型 # 加载模型
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib') model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib')
# 读取数据 # 读取数据
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030' directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'