refactor:修改代码结构,使用config文件进行参数设置;添加必要的注释;

feat:增加糖度模型训练&预测代码;重新训练了新的糖度模型;增加图像处理预热原文件,置于models文件夹下;
fix:修改胜哥糖度模型训练代码的数据切片部分逻辑
This commit is contained in:
TG 2024-06-26 16:29:10 +08:00
parent 6a304ceb7c
commit 158cf4d9a9
15 changed files with 405 additions and 57 deletions

View File

@ -0,0 +1,111 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from spec_read import all_spectral_data
import joblib
def prepare_data(data):
"""Reshape data and select specified spectral bands."""
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
# 筛选特定的波段
data_selected = data[:, :, :, selected_bands]
# 将筛选后的数据重塑为二维数组,每行代表一个样本
reshaped_data = data_selected.reshape(-1, 30 * 30 * len(selected_bands))
return reshaped_data
def split_data(X, y, test_size=0.20, random_state=12):
"""Split data into training and test sets."""
return train_test_split(X, y, test_size=test_size, random_state=random_state)
def evaluate_model(model, X_test, y_test):
"""Evaluate the model and return multiple metrics and predictions."""
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
return mse, mae, r2, y_pred
def print_predictions(y_test, y_pred, model_name):
"""Print actual and predicted values."""
print(f"Test Set Predictions for {model_name}:")
for i, (real, pred) in enumerate(zip(y_test, y_pred)):
print(f"Sample {i + 1}: True Value = {real:.2f}, Predicted Value = {pred:.2f}")
def main():
sweetness_acidity = np.array([
16.2, 16.1, 17, 16.9, 16.8, 17.8, 18.1, 17.2, 17, 17.2, 17.1, 17.2,
17.2, 17.2, 18.1, 17, 17.6, 17.4, 17.1, 17.1, 16.9, 17.6, 17.3, 16.3,
16.5, 18.7, 17.6, 16.2, 16.8, 17.2, 16.8, 17.3, 16, 16.6, 16.7, 16.7,
17.3, 16.3, 16.8, 17.4, 17.3, 16.3, 16.1, 17.2, 18.6, 16.8, 16.1, 17.2,
18.3, 16.5, 16.6, 17, 17, 17.8, 16.4, 18, 17.7, 17, 18.3, 16.8, 17.5,
17.7, 18.5, 18, 17.7, 17, 18.3, 18.1, 17.4, 17.7, 17.8, 16.3, 17.1, 16.8,
17.2, 17.5, 16.6, 17.7, 17.1, 17.7, 19.4, 20.3, 17.3, 15.8, 18, 17.7,
17.2, 15.2, 18, 18.4, 18.3, 15.7, 17.2, 18.6, 15.6, 17, 16.9, 17.4, 17.8,
16.5
])
X = prepare_data(all_spectral_data)
print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}')
X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity)
models_params = {
"RandomForest": {
'model': RandomForestRegressor(),
'params': {
'n_estimators': [100, 200, 300],
'max_depth': [None, 10, 20],
'min_samples_split': [2, 5],
'min_samples_leaf': [1, 2],
'random_state': [42]
}
},
"GradientBoosting": {
'model': GradientBoostingRegressor(),
'params': {
'n_estimators': [100, 200, 300],
'learning_rate': [0.01, 0.1, 0.2],
'max_depth': [3, 5, 7],
'min_samples_split': [2, 5],
'min_samples_leaf': [1, 2],
'random_state': [42]
}
},
"SVR": {
'model': SVR(),
'params': {
'C': [0.1, 1, 10, 100],
'gamma': ['scale', 'auto', 0.01, 0.1],
'epsilon': [0.01, 0.1, 0.5]
}
}
}
best_models = {}
for model_name, mp in models_params.items():
grid_search = GridSearchCV(mp['model'], mp['params'], cv=5, scoring='r2', verbose=2)
grid_search.fit(X_train, y_train)
best_models[model_name] = grid_search.best_estimator_
mse, mae, r2, y_pred = evaluate_model(grid_search.best_estimator_, X_test, y_test)
print(f"Best {model_name} parameters: {grid_search.best_params_}")
print(f"Model: {model_name}")
print(f"MSE on the test set: {mse}")
print(f"MAE on the test set: {mae}")
print(f"R² score on the test set: {r2}")
print_predictions(y_test, y_pred, model_name)
print("\n" + "-" * 50 + "\n")
# Optionally save the best model for each type
joblib.dump(grid_search.best_estimator_, f'{model_name}_best_model.joblib')
if __name__ == "__main__":
main()

View File

@ -0,0 +1,87 @@
import joblib
import numpy as np
import os
from model_train import prepare_data
def read_spectral_data(hdr_path, raw_path):
# Read HDR file for image dimensions information
with open(hdr_path, 'r', encoding='latin1') as hdr_file:
lines = hdr_file.readlines()
height = width = bands = 0
for line in lines:
if line.startswith('lines'):
height = int(line.split()[-1])
elif line.startswith('samples'):
width = int(line.split()[-1])
elif line.startswith('bands'):
bands = int(line.split()[-1])
# Read spectral data from RAW file
raw_image = np.fromfile(raw_path, dtype='uint16')
# Initialize the image with the actual read dimensions
formatImage = np.zeros((height, width, bands))
for row in range(height):
for dim in range(bands):
formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width]
# Ensure the image is 30x30x224 by cropping or padding
target_height, target_width, target_bands = 30, 30, 224
# Crop or pad height
if height > target_height:
formatImage = formatImage[:target_height, :, :]
elif height < target_height:
pad_height = target_height - height
formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0)
# Crop or pad width
if width > target_width:
formatImage = formatImage[:, :target_width, :]
elif width < target_width:
pad_width = target_width - width
formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
# Crop or pad bands if necessary (usually bands should not change)
if bands > target_bands:
formatImage = formatImage[:, :, :target_bands]
elif bands < target_bands:
pad_bands = target_bands - bands
formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0)
return formatImage
def load_model(model_path):
"""加载模型"""
return joblib.load(model_path)
def predict(model, data):
"""预测数据"""
return model.predict(data)
def main():
# 加载模型
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib')
# 读取数据
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
all_spectral_data = []
for i in range(1, 101):
hdr_path = os.path.join(directory, f'{i}.HDR')
raw_path = os.path.join(directory, f'{i}')
spectral_data = read_spectral_data(hdr_path, raw_path)
all_spectral_data.append(spectral_data)
all_spectral_data = np.stack(all_spectral_data)
print(all_spectral_data.shape)
# 预处理数据
data_prepared = prepare_data(all_spectral_data)
print(data_prepared.shape)
# 预测数据
predictions = predict(model, data_prepared)
# 打印预测结果
print(predictions)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,70 @@
import numpy as np
import os
def read_spectral_data(hdr_path, raw_path):
# Read HDR file for image dimensions information
with open(hdr_path, 'r', encoding='latin1') as hdr_file:
lines = hdr_file.readlines()
height = width = bands = 0
for line in lines:
if line.startswith('lines'):
height = int(line.split()[-1])
elif line.startswith('samples'):
width = int(line.split()[-1])
elif line.startswith('bands'):
bands = int(line.split()[-1])
# Read spectral data from RAW file
raw_image = np.fromfile(raw_path, dtype='uint16')
# Initialize the image with the actual read dimensions
formatImage = np.zeros((height, width, bands))
for row in range(height):
for dim in range(bands):
formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width]
# Ensure the image is 30x30x224 by cropping or padding
target_height, target_width, target_bands = 30, 30, 224
# Crop or pad height
if height > target_height:
formatImage = formatImage[:target_height, :, :]
elif height < target_height:
pad_height = target_height - height
formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0)
# Crop or pad width
if width > target_width:
formatImage = formatImage[:, :target_width, :]
elif width < target_width:
pad_width = target_width - width
formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
# Crop or pad bands if necessary (usually bands should not change)
if bands > target_bands:
formatImage = formatImage[:, :, :target_bands]
elif bands < target_bands:
pad_bands = target_bands - bands
formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0)
return formatImage
# Specify the directory containing the HDR and RAW files
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
# Initialize a list to hold all the spectral data arrays
all_spectral_data = []
# Loop through each data set (assuming there are 40 datasets)
for i in range(1, 101):
hdr_path = os.path.join(directory, f'{i}.HDR')
raw_path = os.path.join(directory, f'{i}')
# Read data
spectral_data = read_spectral_data(hdr_path, raw_path)
all_spectral_data.append(spectral_data)
# Stack all data into a single numpy array
all_spectral_data = np.stack(all_spectral_data)
print(all_spectral_data.shape) # This should print (40, 30, 30, 224)

View File

@ -14,6 +14,7 @@ import random
import numpy as np
from PIL import Image
from utils import Pipe
from config import Config as setting
from sklearn.ensemble import RandomForestRegressor
#图像分类网络所需库,实际并未使用分类网络
# import torch
@ -23,8 +24,10 @@ from sklearn.ensemble import RandomForestRegressor
#番茄RGB处理模型
class Tomato:
def __init__(self):
def __init__(self, find_reflection_threshold=setting.find_reflection_threshold, extract_g_r_factor=setting.extract_g_r_factor):
''' 初始化 Tomato 类。'''
self.find_reflection_threshold = find_reflection_threshold
self.extract_g_r_factor = extract_g_r_factor
pass
def extract_s_l(self, image):
@ -40,14 +43,14 @@ class Tomato:
result = cv2.add(s_channel, l_channel)
return result
def find_reflection(self, image, threshold=190):
def find_reflection(self, image):
'''
通过阈值处理识别图像中的反射区域
:param image: 输入的单通道图像
:param threshold: 用于二值化的阈值
:return: 二值化后的图像高于阈值的部分为白色其余为黑色
'''
_, reflection = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY)
_, reflection = cv2.threshold(image, self.find_reflection_threshold, 255, cv2.THRESH_BINARY)
return reflection
def otsu_threshold(self, image):
@ -67,7 +70,7 @@ class Tomato:
'''
g_channel = image[:, :, 1]
r_channel = image[:, :, 2]
result = cv2.subtract(cv2.multiply(g_channel, 1.5), r_channel)
result = cv2.subtract(cv2.multiply(g_channel, self.extract_g_r_factor), r_channel)
return result
def extract_r_b(self, image):
@ -226,7 +229,8 @@ class Tomato:
#百香果RGB处理模型
class Passion_fruit:
def __init__(self, hue_value=37, hue_delta=10, value_target=25, value_delta=10):
def __init__(self, hue_value=setting.hue_value, hue_delta=setting.hue_delta,
value_target=setting.value_target, value_delta=setting.value_delta):
# 初始化常用参数
self.hue_value = hue_value
self.hue_delta = hue_delta
@ -259,7 +263,8 @@ class Passion_fruit:
def find_largest_component(self, mask):
if mask is None or mask.size == 0 or np.all(mask == 0):
logging.warning("RGB 图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if mask is None else np.zeros_like(mask)
return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
if mask is None else np.zeros_like(mask)
# 寻找最大连通组件
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)
if num_labels < 2:
@ -291,11 +296,13 @@ class Passion_fruit:
# 检查 RGB 图像是否为空或全黑
if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0):
logging.warning("RGB 图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
if rgb_img is None else np.zeros_like(rgb_img)
# 检查二值图像是否为空或全黑
if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0):
logging.warning("二值图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
if bin_img is None else np.zeros_like(bin_img)
# 转换二值图像为三通道
try:
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
@ -336,13 +343,15 @@ class Spec_predict(object):
def predict(self, data_x):
'''
对数据进行预测
:param data_x: 波段选择后的数据
:return: 预测结果二值化后的数据0为背景1为黄芪,2为杂质23为杂质14为甘草片5为红芪
预测数据
:param data_x: 重塑为二维数组的数据
:return: 预测结果糖度
'''
data_x = data_x.reshape(1, -1)
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
data_x = data_x[:, selected_bands]
# 对数据进行切片,筛选谱段
#qt_test进行测试时如果读取的是3030224需要解开注释进行数据切片筛选谱段
# data_x = data_x[ :25, :, setting.selected_bands ]
# 将筛选后的数据重塑为二维数组,每行代表一个样本
data_x = data_x.reshape(-1, setting.n_spec_rows * setting.n_spec_cols * setting.n_spec_bands)
data_y = self.model.predict(data_x)
return data_y[0]
@ -508,8 +517,8 @@ class Data_processing:
返回:
float: 估算的西红柿体积
"""
a = ((long_axis / 425) * 6.3) / 2
b = ((short_axis / 425) * 6.3) / 2
a = (long_axis * setting.pixel_length_ratio) / 2
b = (short_axis * setting.pixel_length_ratio) / 2
volume = 4 / 3 * np.pi * a * b * b
weight = round(volume * self.density)
#重量单位为g
@ -524,19 +533,16 @@ class Data_processing:
tuple: (长径, 短径, 缺陷区域个数, 缺陷区域总像素, 处理后的图像)
"""
tomato = Tomato() # 创建 Tomato 类的实例
# 设置 S-L 通道阈值并处理图像
threshold_s_l = 180
threshold_fore_g_r_t = 20
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
s_l = tomato.extract_s_l(img)
thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l)
thresholded_s_l = tomato.threshold_segmentation(s_l, setting.threshold_s_l)
new_bin_img = tomato.largest_connected_component(thresholded_s_l)
filled_img, defect = self.fill_holes(new_bin_img)
# 绘制西红柿边缘并获取缺陷信息
edge, mask = tomato.draw_tomato_edge(img, new_bin_img)
org_defect = tomato.bitwise_and_rgb_with_binary(edge, new_bin_img)
fore = tomato.bitwise_and_rgb_with_binary(img, mask)
fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=threshold_fore_g_r_t)
fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=setting.threshold_fore_g_r_t)
res = cv2.bitwise_or(new_bin_img, fore_g_r_t)
nogreen = tomato.bitwise_and_rgb_with_binary(edge, res)
# 统计白色像素点个数
@ -554,8 +560,8 @@ class Data_processing:
# cv2.imwrite('filled_img.jpg',filled_img)
# 将处理后的图像转换为 RGB 格式
rp = cv2.cvtColor(nogreen, cv2.COLOR_BGR2RGB)
#直径单位为cm所以需要除以10
diameter = (long_axis + short_axis) /425 * 63 / 2 / 10
#直径单位为cm
diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2
# print(f'直径:{diameter}')
# 如果直径小于3判断为空果拖异常图则将所有值重置为0
if diameter < 2.5:
@ -563,16 +569,17 @@ class Data_processing:
green_percentage = 0
number_defects = 0
total_pixels = 0
rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB)
rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
dtype=np.uint8), cv2.COLOR_BGR2RGB)
return diameter, green_percentage, number_defects, total_pixels, rp
def analyze_passion_fruit(self, img, hue_value=37, hue_delta=10, value_target=25, value_delta=10):
def analyze_passion_fruit(self, img):
if img is None:
logging.error("Error: 无图像数据.")
return None
# 创建PassionFruit类的实例
pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta)
pf = Passion_fruit()
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
@ -596,15 +603,16 @@ class Data_processing:
edge = pf.draw_contours_on_image(img, contour_mask)
org_defect = pf.bitwise_and_rgb_with_binary(edge, max_mask)
rp = cv2.cvtColor(org_defect, cv2.COLOR_BGR2RGB)
#直径单位为cm所以需要除以10
diameter = (long_axis + short_axis) /425 * 63 / 2 / 10
#直径单位为cm
diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2
# print(f'直径:{diameter}')
if diameter < 2.5:
diameter = 0
green_percentage = 0
weight = 0
number_defects = 0
total_pixels = 0
rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB)
rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
dtype=np.uint8), cv2.COLOR_BGR2RGB)
return diameter, weight, number_defects, total_pixels, rp
def process_data(seif, cmd: str, images: list, spec: any, pipe: Pipe, detector: Spec_predict) -> bool:

View File

@ -4,9 +4,50 @@
# @File : config.py
# @Software: PyCharm
from root_dir import ROOT_DIR
class Config:
#文件相关参数
#预热参数
n_rows, n_cols, n_channels = 256, 256, 3
n_spec_rows, n_spec_cols, n_spec_bands = 25, 30, 13
n_rgb_rows, n_rgb_cols, n_rgb_bands = 613, 800, 3
tomato_img_dir = ROOT_DIR / 'models' / 'TO.bmp'
passion_fruit_img_dir = ROOT_DIR / 'models' / 'PF.bmp'
#模型路径
#糖度模型
brix_model_path = ROOT_DIR / 'models' / 'passion_fruit.joblib'
#图像分类模型
imgclassifier_model_path = ROOT_DIR / 'models' / 'imgclassifier.joblib'
imgclassifier_class_indices_path = ROOT_DIR / 'models' / 'class_indices.json'
#classifer.py参数
#tomato
find_reflection_threshold = 190
extract_g_r_factor = 1.5
#passion_fruit
hue_value = 37
hue_delta = 10
value_target = 25
value_delta = 10
#spec_predict
#筛选谱段并未使用在qt取数据时已经筛选
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
#data_processing
#根据标定数据计算的参数,实际长度/像素长度单位cm
pixel_length_ratio = 6.3/425
#绿叶面积阈值,高于此阈值认为连通域是绿叶
area_threshold = 20000
#百香果密度g/cm^3
density = 0.652228972
#百香果面积比例每个像素代表的实际面积cm^2
area_ratio = 0.00021973702422145334
#def analyze_tomato
#s_l通道阈值
threshold_s_l = 180
threshold_fore_g_r_t = 20

View File

@ -14,11 +14,11 @@ import logging
from utils import Pipe
import numpy as np
import time
from config import Config
def main(is_debug=False):
setting = Config()
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8')
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
console_handler = logging.StreamHandler(sys.stdout)
@ -27,15 +27,18 @@ def main(is_debug=False):
handlers=[file_handler, console_handler],
level=logging.DEBUG)
#模型加载
detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib')
# classifier = ImageClassifier(ROOT_DIR/'models'/'resnet34_0619.pth', ROOT_DIR/'models'/'class_indices.json')
detector = Spec_predict()
detector.load(path=setting.brix_model_path)
# classifier = ImageClassifier(model_path=setting.imgclassifier_model_path,
# class_indices_path=setting.imgclassifier_class_indices_path)
dp = Data_processing()
print('系统初始化中...')
#模型预热
_ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16))
# _ = classifier.predict(np.ones((224, 224, 3), dtype=np.uint8))
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp'))
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp'))
#与qt_test测试时需要注释掉预热模型接收尺寸为253013qt_test发送的数据为3030224需要对数据进行切片classifer.py第352行
_ = detector.predict(np.ones((setting.n_spec_rows, setting.n_spec_cols, setting.n_spec_bands), dtype=np.uint16))
# _ = classifier.predict(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8))
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(str(setting.tomato_img_dir)))
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(str(setting.passion_fruit_img_dir))
print('系统初始化完成')
rgb_receive_name = r'\\.\pipe\rgb_receive'

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@ -0,0 +1,4 @@
{
"0": "exist",
"1": "no_exist"
}

Binary file not shown.

View File

@ -1,4 +1,3 @@
import pathlib
file_path = pathlib.Path(__file__)

View File

@ -12,6 +12,7 @@ import win32pipe
import time
import logging
import numpy as np
from config import Config as setting
from PIL import Image
import io
@ -204,13 +205,15 @@ class Pipe:
gp = int(green_percentage * 100).to_bytes(1, byteorder='big')
weight = 0
weight = weight.to_bytes(1, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes
send_message = (length + cmd_re + brix + gp + diameter + weight +
defect_num + total_defect_area + height + width + img_bytes)
elif cmd == 'PF':
brix = int(brix * 1000).to_bytes(2, byteorder='big')
gp = 0
gp = gp.to_bytes(1, byteorder='big')
weight = weight.to_bytes(1, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes
send_message = (length + cmd_re + brix + gp + diameter + weight +
defect_num + total_defect_area + height + width + img_bytes)
elif cmd == 'KO':
brix = 0
brix = brix.to_bytes(2, byteorder='big')
@ -222,13 +225,15 @@ class Pipe:
defect_num = defect_num.to_bytes(2, byteorder='big')
total_defect_area = 0
total_defect_area = total_defect_area.to_bytes(4, byteorder='big')
height = 100
height = setting.n_rgb_rows
height = height.to_bytes(2, byteorder='big')
width = 100
width = setting.n_rgb_cols
width = width.to_bytes(2, byteorder='big')
img_bytes = np.zeros((100, 100, 3), dtype=np.uint8).tobytes()
img_bytes = np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
dtype=np.uint8).tobytes()
length = (18).to_bytes(4, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes
send_message = (length + cmd_re + brix + gp + diameter + weight +
defect_num + total_defect_area + height + width + img_bytes)
try:
win32file.WriteFile(self.rgb_send, send_message)
# time.sleep(0.01)

View File

@ -4,25 +4,40 @@ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from spec_read import all_spectral_data
import joblib
# def prepare_data(data):
# """Reshape data and select specified spectral bands."""
# reshaped_data = data.reshape(100, -1)
# selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
# return reshaped_data[:, selected_bands]
def prepare_data(data):
"""Reshape data and select specified spectral bands."""
reshaped_data = data.reshape(100, -1)
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
return reshaped_data[:, selected_bands]
# 筛选特定的波段
data_selected = data[:, :25, :, selected_bands]
print(f'筛选后的数据尺寸:{data_selected.shape}')
# 将筛选后的数据重塑为二维数组,每行代表一个样本
reshaped_data = data_selected.reshape(-1, 25 * 30 * len(selected_bands))
return reshaped_data
def split_data(X, y, test_size=0.20, random_state=12):
"""Split data into training and test sets."""
return train_test_split(X, y, test_size=test_size, random_state=random_state)
def evaluate_model(model, X_test, y_test):
"""Evaluate the model and return MSE and predictions."""
"""Evaluate the model and return multiple metrics and predictions."""
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
return mse, y_pred
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
return mse, mae, r2, y_pred
def print_predictions(y_test, y_pred, model_name):
"""Print actual and predicted values."""
@ -44,6 +59,7 @@ def main():
])
X = prepare_data(all_spectral_data)
print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}')
X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity)
models = {
@ -55,13 +71,17 @@ def main():
for model_name, model in models.items():
model.fit(X_train, y_train)
if model_name == "RandomForest":
joblib.dump(model, r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib')
joblib.dump(model,
r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit.joblib')
mse, y_pred = evaluate_model(model, X_test, y_test)
mse, mae, r2, y_pred = evaluate_model(model, X_test, y_test)
print(f"Model: {model_name}")
print(f"Mean Squared Error on the test set: {mse}")
print(f"MSE on the test set: {mse}")
print(f"MAE on the test set: {mae}")
print(f"R² score on the test set: {r2}")
print_predictions(y_test, y_pred, model_name)
print("\n" + "-" * 50 + "\n")
if __name__ == "__main__":
main()

View File

@ -60,7 +60,7 @@ def predict(model, data):
def main():
# 加载模型
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib')
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib')
# 读取数据
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'