mirror of
https://github.com/NanjingForestryUniversity/supermachine--tomato-passion_fruit.git
synced 2025-11-09 14:54:07 +00:00
refactor:修改代码结构,使用config文件进行参数设置;添加必要的注释;
feat:增加糖度模型训练&预测代码;重新训练了新的糖度模型;增加图像处理预热原文件,置于models文件夹下; fix:修改胜哥糖度模型训练代码的数据切片部分逻辑
This commit is contained in:
parent
6a304ceb7c
commit
158cf4d9a9
111
20240529RGBtest3/brix_model_train&test/model_train.py
Normal file
111
20240529RGBtest3/brix_model_train&test/model_train.py
Normal file
@ -0,0 +1,111 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.model_selection import train_test_split, GridSearchCV
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
from spec_read import all_spectral_data
|
||||
import joblib
|
||||
|
||||
|
||||
def prepare_data(data):
|
||||
"""Reshape data and select specified spectral bands."""
|
||||
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
||||
# 筛选特定的波段
|
||||
data_selected = data[:, :, :, selected_bands]
|
||||
# 将筛选后的数据重塑为二维数组,每行代表一个样本
|
||||
reshaped_data = data_selected.reshape(-1, 30 * 30 * len(selected_bands))
|
||||
return reshaped_data
|
||||
|
||||
|
||||
def split_data(X, y, test_size=0.20, random_state=12):
|
||||
"""Split data into training and test sets."""
|
||||
return train_test_split(X, y, test_size=test_size, random_state=random_state)
|
||||
|
||||
|
||||
def evaluate_model(model, X_test, y_test):
|
||||
"""Evaluate the model and return multiple metrics and predictions."""
|
||||
y_pred = model.predict(X_test)
|
||||
mse = mean_squared_error(y_test, y_pred)
|
||||
mae = mean_absolute_error(y_test, y_pred)
|
||||
r2 = r2_score(y_test, y_pred)
|
||||
return mse, mae, r2, y_pred
|
||||
|
||||
|
||||
def print_predictions(y_test, y_pred, model_name):
|
||||
"""Print actual and predicted values."""
|
||||
print(f"Test Set Predictions for {model_name}:")
|
||||
for i, (real, pred) in enumerate(zip(y_test, y_pred)):
|
||||
print(f"Sample {i + 1}: True Value = {real:.2f}, Predicted Value = {pred:.2f}")
|
||||
|
||||
def main():
|
||||
sweetness_acidity = np.array([
|
||||
16.2, 16.1, 17, 16.9, 16.8, 17.8, 18.1, 17.2, 17, 17.2, 17.1, 17.2,
|
||||
17.2, 17.2, 18.1, 17, 17.6, 17.4, 17.1, 17.1, 16.9, 17.6, 17.3, 16.3,
|
||||
16.5, 18.7, 17.6, 16.2, 16.8, 17.2, 16.8, 17.3, 16, 16.6, 16.7, 16.7,
|
||||
17.3, 16.3, 16.8, 17.4, 17.3, 16.3, 16.1, 17.2, 18.6, 16.8, 16.1, 17.2,
|
||||
18.3, 16.5, 16.6, 17, 17, 17.8, 16.4, 18, 17.7, 17, 18.3, 16.8, 17.5,
|
||||
17.7, 18.5, 18, 17.7, 17, 18.3, 18.1, 17.4, 17.7, 17.8, 16.3, 17.1, 16.8,
|
||||
17.2, 17.5, 16.6, 17.7, 17.1, 17.7, 19.4, 20.3, 17.3, 15.8, 18, 17.7,
|
||||
17.2, 15.2, 18, 18.4, 18.3, 15.7, 17.2, 18.6, 15.6, 17, 16.9, 17.4, 17.8,
|
||||
16.5
|
||||
])
|
||||
|
||||
X = prepare_data(all_spectral_data)
|
||||
print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}')
|
||||
X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity)
|
||||
|
||||
models_params = {
|
||||
"RandomForest": {
|
||||
'model': RandomForestRegressor(),
|
||||
'params': {
|
||||
'n_estimators': [100, 200, 300],
|
||||
'max_depth': [None, 10, 20],
|
||||
'min_samples_split': [2, 5],
|
||||
'min_samples_leaf': [1, 2],
|
||||
'random_state': [42]
|
||||
}
|
||||
},
|
||||
"GradientBoosting": {
|
||||
'model': GradientBoostingRegressor(),
|
||||
'params': {
|
||||
'n_estimators': [100, 200, 300],
|
||||
'learning_rate': [0.01, 0.1, 0.2],
|
||||
'max_depth': [3, 5, 7],
|
||||
'min_samples_split': [2, 5],
|
||||
'min_samples_leaf': [1, 2],
|
||||
'random_state': [42]
|
||||
}
|
||||
},
|
||||
"SVR": {
|
||||
'model': SVR(),
|
||||
'params': {
|
||||
'C': [0.1, 1, 10, 100],
|
||||
'gamma': ['scale', 'auto', 0.01, 0.1],
|
||||
'epsilon': [0.01, 0.1, 0.5]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
best_models = {}
|
||||
|
||||
for model_name, mp in models_params.items():
|
||||
grid_search = GridSearchCV(mp['model'], mp['params'], cv=5, scoring='r2', verbose=2)
|
||||
grid_search.fit(X_train, y_train)
|
||||
best_models[model_name] = grid_search.best_estimator_
|
||||
mse, mae, r2, y_pred = evaluate_model(grid_search.best_estimator_, X_test, y_test)
|
||||
print(f"Best {model_name} parameters: {grid_search.best_params_}")
|
||||
print(f"Model: {model_name}")
|
||||
print(f"MSE on the test set: {mse}")
|
||||
print(f"MAE on the test set: {mae}")
|
||||
print(f"R² score on the test set: {r2}")
|
||||
print_predictions(y_test, y_pred, model_name)
|
||||
print("\n" + "-" * 50 + "\n")
|
||||
|
||||
# Optionally save the best model for each type
|
||||
joblib.dump(grid_search.best_estimator_, f'{model_name}_best_model.joblib')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
87
20240529RGBtest3/brix_model_train&test/predict.py
Normal file
87
20240529RGBtest3/brix_model_train&test/predict.py
Normal file
@ -0,0 +1,87 @@
|
||||
import joblib
|
||||
import numpy as np
|
||||
import os
|
||||
from model_train import prepare_data
|
||||
|
||||
def read_spectral_data(hdr_path, raw_path):
|
||||
# Read HDR file for image dimensions information
|
||||
with open(hdr_path, 'r', encoding='latin1') as hdr_file:
|
||||
lines = hdr_file.readlines()
|
||||
height = width = bands = 0
|
||||
for line in lines:
|
||||
if line.startswith('lines'):
|
||||
height = int(line.split()[-1])
|
||||
elif line.startswith('samples'):
|
||||
width = int(line.split()[-1])
|
||||
elif line.startswith('bands'):
|
||||
bands = int(line.split()[-1])
|
||||
|
||||
# Read spectral data from RAW file
|
||||
raw_image = np.fromfile(raw_path, dtype='uint16')
|
||||
# Initialize the image with the actual read dimensions
|
||||
formatImage = np.zeros((height, width, bands))
|
||||
|
||||
for row in range(height):
|
||||
for dim in range(bands):
|
||||
formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width]
|
||||
|
||||
# Ensure the image is 30x30x224 by cropping or padding
|
||||
target_height, target_width, target_bands = 30, 30, 224
|
||||
# Crop or pad height
|
||||
if height > target_height:
|
||||
formatImage = formatImage[:target_height, :, :]
|
||||
elif height < target_height:
|
||||
pad_height = target_height - height
|
||||
formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0)
|
||||
|
||||
# Crop or pad width
|
||||
if width > target_width:
|
||||
formatImage = formatImage[:, :target_width, :]
|
||||
elif width < target_width:
|
||||
pad_width = target_width - width
|
||||
formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
|
||||
|
||||
# Crop or pad bands if necessary (usually bands should not change)
|
||||
if bands > target_bands:
|
||||
formatImage = formatImage[:, :, :target_bands]
|
||||
elif bands < target_bands:
|
||||
pad_bands = target_bands - bands
|
||||
formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0)
|
||||
|
||||
return formatImage
|
||||
|
||||
def load_model(model_path):
|
||||
"""加载模型"""
|
||||
return joblib.load(model_path)
|
||||
|
||||
def predict(model, data):
|
||||
"""预测数据"""
|
||||
return model.predict(data)
|
||||
|
||||
def main():
|
||||
# 加载模型
|
||||
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib')
|
||||
|
||||
# 读取数据
|
||||
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
|
||||
all_spectral_data = []
|
||||
for i in range(1, 101):
|
||||
hdr_path = os.path.join(directory, f'{i}.HDR')
|
||||
raw_path = os.path.join(directory, f'{i}')
|
||||
spectral_data = read_spectral_data(hdr_path, raw_path)
|
||||
all_spectral_data.append(spectral_data)
|
||||
all_spectral_data = np.stack(all_spectral_data)
|
||||
print(all_spectral_data.shape)
|
||||
|
||||
# 预处理数据
|
||||
data_prepared = prepare_data(all_spectral_data)
|
||||
print(data_prepared.shape)
|
||||
|
||||
# 预测数据
|
||||
predictions = predict(model, data_prepared)
|
||||
|
||||
# 打印预测结果
|
||||
print(predictions)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
70
20240529RGBtest3/brix_model_train&test/spec_read.py
Normal file
70
20240529RGBtest3/brix_model_train&test/spec_read.py
Normal file
@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
def read_spectral_data(hdr_path, raw_path):
|
||||
# Read HDR file for image dimensions information
|
||||
with open(hdr_path, 'r', encoding='latin1') as hdr_file:
|
||||
lines = hdr_file.readlines()
|
||||
height = width = bands = 0
|
||||
for line in lines:
|
||||
if line.startswith('lines'):
|
||||
height = int(line.split()[-1])
|
||||
elif line.startswith('samples'):
|
||||
width = int(line.split()[-1])
|
||||
elif line.startswith('bands'):
|
||||
bands = int(line.split()[-1])
|
||||
|
||||
# Read spectral data from RAW file
|
||||
raw_image = np.fromfile(raw_path, dtype='uint16')
|
||||
# Initialize the image with the actual read dimensions
|
||||
formatImage = np.zeros((height, width, bands))
|
||||
|
||||
for row in range(height):
|
||||
for dim in range(bands):
|
||||
formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width]
|
||||
|
||||
# Ensure the image is 30x30x224 by cropping or padding
|
||||
target_height, target_width, target_bands = 30, 30, 224
|
||||
# Crop or pad height
|
||||
if height > target_height:
|
||||
formatImage = formatImage[:target_height, :, :]
|
||||
elif height < target_height:
|
||||
pad_height = target_height - height
|
||||
formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0)
|
||||
|
||||
# Crop or pad width
|
||||
if width > target_width:
|
||||
formatImage = formatImage[:, :target_width, :]
|
||||
elif width < target_width:
|
||||
pad_width = target_width - width
|
||||
formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0)
|
||||
|
||||
# Crop or pad bands if necessary (usually bands should not change)
|
||||
if bands > target_bands:
|
||||
formatImage = formatImage[:, :, :target_bands]
|
||||
elif bands < target_bands:
|
||||
pad_bands = target_bands - bands
|
||||
formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0)
|
||||
|
||||
return formatImage
|
||||
|
||||
|
||||
# Specify the directory containing the HDR and RAW files
|
||||
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
|
||||
|
||||
# Initialize a list to hold all the spectral data arrays
|
||||
all_spectral_data = []
|
||||
|
||||
# Loop through each data set (assuming there are 40 datasets)
|
||||
for i in range(1, 101):
|
||||
hdr_path = os.path.join(directory, f'{i}.HDR')
|
||||
raw_path = os.path.join(directory, f'{i}')
|
||||
|
||||
# Read data
|
||||
spectral_data = read_spectral_data(hdr_path, raw_path)
|
||||
all_spectral_data.append(spectral_data)
|
||||
|
||||
# Stack all data into a single numpy array
|
||||
all_spectral_data = np.stack(all_spectral_data)
|
||||
print(all_spectral_data.shape) # This should print (40, 30, 30, 224)
|
||||
@ -14,6 +14,7 @@ import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from utils import Pipe
|
||||
from config import Config as setting
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
#图像分类网络所需库,实际并未使用分类网络
|
||||
# import torch
|
||||
@ -23,8 +24,10 @@ from sklearn.ensemble import RandomForestRegressor
|
||||
|
||||
#番茄RGB处理模型
|
||||
class Tomato:
|
||||
def __init__(self):
|
||||
def __init__(self, find_reflection_threshold=setting.find_reflection_threshold, extract_g_r_factor=setting.extract_g_r_factor):
|
||||
''' 初始化 Tomato 类。'''
|
||||
self.find_reflection_threshold = find_reflection_threshold
|
||||
self.extract_g_r_factor = extract_g_r_factor
|
||||
pass
|
||||
|
||||
def extract_s_l(self, image):
|
||||
@ -40,14 +43,14 @@ class Tomato:
|
||||
result = cv2.add(s_channel, l_channel)
|
||||
return result
|
||||
|
||||
def find_reflection(self, image, threshold=190):
|
||||
def find_reflection(self, image):
|
||||
'''
|
||||
通过阈值处理识别图像中的反射区域。
|
||||
:param image: 输入的单通道图像
|
||||
:param threshold: 用于二值化的阈值
|
||||
:return: 二值化后的图像,高于阈值的部分为白色,其余为黑色
|
||||
'''
|
||||
_, reflection = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY)
|
||||
_, reflection = cv2.threshold(image, self.find_reflection_threshold, 255, cv2.THRESH_BINARY)
|
||||
return reflection
|
||||
|
||||
def otsu_threshold(self, image):
|
||||
@ -67,7 +70,7 @@ class Tomato:
|
||||
'''
|
||||
g_channel = image[:, :, 1]
|
||||
r_channel = image[:, :, 2]
|
||||
result = cv2.subtract(cv2.multiply(g_channel, 1.5), r_channel)
|
||||
result = cv2.subtract(cv2.multiply(g_channel, self.extract_g_r_factor), r_channel)
|
||||
return result
|
||||
|
||||
def extract_r_b(self, image):
|
||||
@ -226,7 +229,8 @@ class Tomato:
|
||||
|
||||
#百香果RGB处理模型
|
||||
class Passion_fruit:
|
||||
def __init__(self, hue_value=37, hue_delta=10, value_target=25, value_delta=10):
|
||||
def __init__(self, hue_value=setting.hue_value, hue_delta=setting.hue_delta,
|
||||
value_target=setting.value_target, value_delta=setting.value_delta):
|
||||
# 初始化常用参数
|
||||
self.hue_value = hue_value
|
||||
self.hue_delta = hue_delta
|
||||
@ -259,7 +263,8 @@ class Passion_fruit:
|
||||
def find_largest_component(self, mask):
|
||||
if mask is None or mask.size == 0 or np.all(mask == 0):
|
||||
logging.warning("RGB 图像为空或全黑,返回一个全黑RGB图像。")
|
||||
return np.zeros((100, 100, 3), dtype=np.uint8) if mask is None else np.zeros_like(mask)
|
||||
return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
|
||||
if mask is None else np.zeros_like(mask)
|
||||
# 寻找最大连通组件
|
||||
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)
|
||||
if num_labels < 2:
|
||||
@ -291,11 +296,13 @@ class Passion_fruit:
|
||||
# 检查 RGB 图像是否为空或全黑
|
||||
if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0):
|
||||
logging.warning("RGB 图像为空或全黑,返回一个全黑RGB图像。")
|
||||
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
|
||||
return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
|
||||
if rgb_img is None else np.zeros_like(rgb_img)
|
||||
# 检查二值图像是否为空或全黑
|
||||
if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0):
|
||||
logging.warning("二值图像为空或全黑,返回一个全黑RGB图像。")
|
||||
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
|
||||
return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \
|
||||
if bin_img is None else np.zeros_like(bin_img)
|
||||
# 转换二值图像为三通道
|
||||
try:
|
||||
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
|
||||
@ -336,13 +343,15 @@ class Spec_predict(object):
|
||||
|
||||
def predict(self, data_x):
|
||||
'''
|
||||
对数据进行预测
|
||||
:param data_x: 波段选择后的数据
|
||||
:return: 预测结果二值化后的数据,0为背景,1为黄芪,2为杂质2,3为杂质1,4为甘草片,5为红芪
|
||||
预测数据
|
||||
:param data_x: 重塑为二维数组的数据
|
||||
:return: 预测结果——糖度
|
||||
'''
|
||||
data_x = data_x.reshape(1, -1)
|
||||
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
||||
data_x = data_x[:, selected_bands]
|
||||
# 对数据进行切片,筛选谱段
|
||||
#qt_test进行测试时如果读取的是(30,30,224)需要解开注释进行数据切片,筛选谱段
|
||||
# data_x = data_x[ :25, :, setting.selected_bands ]
|
||||
# 将筛选后的数据重塑为二维数组,每行代表一个样本
|
||||
data_x = data_x.reshape(-1, setting.n_spec_rows * setting.n_spec_cols * setting.n_spec_bands)
|
||||
data_y = self.model.predict(data_x)
|
||||
return data_y[0]
|
||||
|
||||
@ -508,8 +517,8 @@ class Data_processing:
|
||||
返回:
|
||||
float: 估算的西红柿体积
|
||||
"""
|
||||
a = ((long_axis / 425) * 6.3) / 2
|
||||
b = ((short_axis / 425) * 6.3) / 2
|
||||
a = (long_axis * setting.pixel_length_ratio) / 2
|
||||
b = (short_axis * setting.pixel_length_ratio) / 2
|
||||
volume = 4 / 3 * np.pi * a * b * b
|
||||
weight = round(volume * self.density)
|
||||
#重量单位为g
|
||||
@ -524,19 +533,16 @@ class Data_processing:
|
||||
tuple: (长径, 短径, 缺陷区域个数, 缺陷区域总像素, 处理后的图像)
|
||||
"""
|
||||
tomato = Tomato() # 创建 Tomato 类的实例
|
||||
# 设置 S-L 通道阈值并处理图像
|
||||
threshold_s_l = 180
|
||||
threshold_fore_g_r_t = 20
|
||||
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
|
||||
s_l = tomato.extract_s_l(img)
|
||||
thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l)
|
||||
thresholded_s_l = tomato.threshold_segmentation(s_l, setting.threshold_s_l)
|
||||
new_bin_img = tomato.largest_connected_component(thresholded_s_l)
|
||||
filled_img, defect = self.fill_holes(new_bin_img)
|
||||
# 绘制西红柿边缘并获取缺陷信息
|
||||
edge, mask = tomato.draw_tomato_edge(img, new_bin_img)
|
||||
org_defect = tomato.bitwise_and_rgb_with_binary(edge, new_bin_img)
|
||||
fore = tomato.bitwise_and_rgb_with_binary(img, mask)
|
||||
fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=threshold_fore_g_r_t)
|
||||
fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=setting.threshold_fore_g_r_t)
|
||||
res = cv2.bitwise_or(new_bin_img, fore_g_r_t)
|
||||
nogreen = tomato.bitwise_and_rgb_with_binary(edge, res)
|
||||
# 统计白色像素点个数
|
||||
@ -554,8 +560,8 @@ class Data_processing:
|
||||
# cv2.imwrite('filled_img.jpg',filled_img)
|
||||
# 将处理后的图像转换为 RGB 格式
|
||||
rp = cv2.cvtColor(nogreen, cv2.COLOR_BGR2RGB)
|
||||
#直径单位为cm,所以需要除以10
|
||||
diameter = (long_axis + short_axis) /425 * 63 / 2 / 10
|
||||
#直径单位为cm
|
||||
diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2
|
||||
# print(f'直径:{diameter}')
|
||||
# 如果直径小于3,判断为空果拖异常图,则将所有值重置为0
|
||||
if diameter < 2.5:
|
||||
@ -563,16 +569,17 @@ class Data_processing:
|
||||
green_percentage = 0
|
||||
number_defects = 0
|
||||
total_pixels = 0
|
||||
rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB)
|
||||
rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
|
||||
dtype=np.uint8), cv2.COLOR_BGR2RGB)
|
||||
return diameter, green_percentage, number_defects, total_pixels, rp
|
||||
|
||||
def analyze_passion_fruit(self, img, hue_value=37, hue_delta=10, value_target=25, value_delta=10):
|
||||
def analyze_passion_fruit(self, img):
|
||||
if img is None:
|
||||
logging.error("Error: 无图像数据.")
|
||||
return None
|
||||
|
||||
# 创建PassionFruit类的实例
|
||||
pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta)
|
||||
pf = Passion_fruit()
|
||||
|
||||
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
|
||||
hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
@ -596,15 +603,16 @@ class Data_processing:
|
||||
edge = pf.draw_contours_on_image(img, contour_mask)
|
||||
org_defect = pf.bitwise_and_rgb_with_binary(edge, max_mask)
|
||||
rp = cv2.cvtColor(org_defect, cv2.COLOR_BGR2RGB)
|
||||
#直径单位为cm,所以需要除以10
|
||||
diameter = (long_axis + short_axis) /425 * 63 / 2 / 10
|
||||
#直径单位为cm
|
||||
diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2
|
||||
# print(f'直径:{diameter}')
|
||||
if diameter < 2.5:
|
||||
diameter = 0
|
||||
green_percentage = 0
|
||||
weight = 0
|
||||
number_defects = 0
|
||||
total_pixels = 0
|
||||
rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB)
|
||||
rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
|
||||
dtype=np.uint8), cv2.COLOR_BGR2RGB)
|
||||
return diameter, weight, number_defects, total_pixels, rp
|
||||
|
||||
def process_data(seif, cmd: str, images: list, spec: any, pipe: Pipe, detector: Spec_predict) -> bool:
|
||||
|
||||
@ -4,9 +4,50 @@
|
||||
# @File : config.py
|
||||
# @Software: PyCharm
|
||||
|
||||
|
||||
from root_dir import ROOT_DIR
|
||||
|
||||
class Config:
|
||||
#文件相关参数
|
||||
#预热参数
|
||||
n_rows, n_cols, n_channels = 256, 256, 3
|
||||
n_spec_rows, n_spec_cols, n_spec_bands = 25, 30, 13
|
||||
n_rgb_rows, n_rgb_cols, n_rgb_bands = 613, 800, 3
|
||||
tomato_img_dir = ROOT_DIR / 'models' / 'TO.bmp'
|
||||
passion_fruit_img_dir = ROOT_DIR / 'models' / 'PF.bmp'
|
||||
#模型路径
|
||||
#糖度模型
|
||||
brix_model_path = ROOT_DIR / 'models' / 'passion_fruit.joblib'
|
||||
#图像分类模型
|
||||
imgclassifier_model_path = ROOT_DIR / 'models' / 'imgclassifier.joblib'
|
||||
imgclassifier_class_indices_path = ROOT_DIR / 'models' / 'class_indices.json'
|
||||
|
||||
|
||||
#classifer.py参数
|
||||
#tomato
|
||||
find_reflection_threshold = 190
|
||||
extract_g_r_factor = 1.5
|
||||
|
||||
#passion_fruit
|
||||
hue_value = 37
|
||||
hue_delta = 10
|
||||
value_target = 25
|
||||
value_delta = 10
|
||||
|
||||
#spec_predict
|
||||
#筛选谱段并未使用,在qt取数据时已经筛选
|
||||
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
||||
|
||||
#data_processing
|
||||
#根据标定数据计算的参数,实际长度/像素长度,单位cm
|
||||
pixel_length_ratio = 6.3/425
|
||||
#绿叶面积阈值,高于此阈值认为连通域是绿叶
|
||||
area_threshold = 20000
|
||||
#百香果密度(g/cm^3)
|
||||
density = 0.652228972
|
||||
#百香果面积比例,每个像素代表的实际面积(cm^2)
|
||||
area_ratio = 0.00021973702422145334
|
||||
|
||||
#def analyze_tomato
|
||||
#s_l通道阈值
|
||||
threshold_s_l = 180
|
||||
threshold_fore_g_r_t = 20
|
||||
|
||||
|
||||
@ -14,11 +14,11 @@ import logging
|
||||
from utils import Pipe
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
from config import Config
|
||||
|
||||
|
||||
def main(is_debug=False):
|
||||
setting = Config()
|
||||
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8')
|
||||
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
@ -27,15 +27,18 @@ def main(is_debug=False):
|
||||
handlers=[file_handler, console_handler],
|
||||
level=logging.DEBUG)
|
||||
#模型加载
|
||||
detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib')
|
||||
# classifier = ImageClassifier(ROOT_DIR/'models'/'resnet34_0619.pth', ROOT_DIR/'models'/'class_indices.json')
|
||||
detector = Spec_predict()
|
||||
detector.load(path=setting.brix_model_path)
|
||||
# classifier = ImageClassifier(model_path=setting.imgclassifier_model_path,
|
||||
# class_indices_path=setting.imgclassifier_class_indices_path)
|
||||
dp = Data_processing()
|
||||
print('系统初始化中...')
|
||||
#模型预热
|
||||
_ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16))
|
||||
# _ = classifier.predict(np.ones((224, 224, 3), dtype=np.uint8))
|
||||
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp'))
|
||||
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp'))
|
||||
#与qt_test测试时需要注释掉预热,模型接收尺寸为(25,30,13),qt_test发送的数据为(30,30,224),需要对数据进行切片(classifer.py第352行)
|
||||
_ = detector.predict(np.ones((setting.n_spec_rows, setting.n_spec_cols, setting.n_spec_bands), dtype=np.uint16))
|
||||
# _ = classifier.predict(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8))
|
||||
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(str(setting.tomato_img_dir)))
|
||||
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(str(setting.passion_fruit_img_dir))
|
||||
print('系统初始化完成')
|
||||
|
||||
rgb_receive_name = r'\\.\pipe\rgb_receive'
|
||||
|
||||
BIN
20240529RGBtest3/models/PF.bmp
Normal file
BIN
20240529RGBtest3/models/PF.bmp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.5 MiB |
BIN
20240529RGBtest3/models/TO.bmp
Normal file
BIN
20240529RGBtest3/models/TO.bmp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.5 MiB |
4
20240529RGBtest3/models/class_indices.json
Normal file
4
20240529RGBtest3/models/class_indices.json
Normal file
@ -0,0 +1,4 @@
|
||||
{
|
||||
"0": "exist",
|
||||
"1": "no_exist"
|
||||
}
|
||||
BIN
20240529RGBtest3/models/passion_fruit.joblib
Normal file
BIN
20240529RGBtest3/models/passion_fruit.joblib
Normal file
Binary file not shown.
Binary file not shown.
@ -1,4 +1,3 @@
|
||||
|
||||
import pathlib
|
||||
|
||||
file_path = pathlib.Path(__file__)
|
||||
|
||||
@ -12,6 +12,7 @@ import win32pipe
|
||||
import time
|
||||
import logging
|
||||
import numpy as np
|
||||
from config import Config as setting
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
@ -204,13 +205,15 @@ class Pipe:
|
||||
gp = int(green_percentage * 100).to_bytes(1, byteorder='big')
|
||||
weight = 0
|
||||
weight = weight.to_bytes(1, byteorder='big')
|
||||
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes
|
||||
send_message = (length + cmd_re + brix + gp + diameter + weight +
|
||||
defect_num + total_defect_area + height + width + img_bytes)
|
||||
elif cmd == 'PF':
|
||||
brix = int(brix * 1000).to_bytes(2, byteorder='big')
|
||||
gp = 0
|
||||
gp = gp.to_bytes(1, byteorder='big')
|
||||
weight = weight.to_bytes(1, byteorder='big')
|
||||
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes
|
||||
send_message = (length + cmd_re + brix + gp + diameter + weight +
|
||||
defect_num + total_defect_area + height + width + img_bytes)
|
||||
elif cmd == 'KO':
|
||||
brix = 0
|
||||
brix = brix.to_bytes(2, byteorder='big')
|
||||
@ -222,13 +225,15 @@ class Pipe:
|
||||
defect_num = defect_num.to_bytes(2, byteorder='big')
|
||||
total_defect_area = 0
|
||||
total_defect_area = total_defect_area.to_bytes(4, byteorder='big')
|
||||
height = 100
|
||||
height = setting.n_rgb_rows
|
||||
height = height.to_bytes(2, byteorder='big')
|
||||
width = 100
|
||||
width = setting.n_rgb_cols
|
||||
width = width.to_bytes(2, byteorder='big')
|
||||
img_bytes = np.zeros((100, 100, 3), dtype=np.uint8).tobytes()
|
||||
img_bytes = np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands),
|
||||
dtype=np.uint8).tobytes()
|
||||
length = (18).to_bytes(4, byteorder='big')
|
||||
send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes
|
||||
send_message = (length + cmd_re + brix + gp + diameter + weight +
|
||||
defect_num + total_defect_area + height + width + img_bytes)
|
||||
try:
|
||||
win32file.WriteFile(self.rgb_send, send_message)
|
||||
# time.sleep(0.01)
|
||||
|
||||
@ -4,25 +4,40 @@ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
from spec_read import all_spectral_data
|
||||
import joblib
|
||||
|
||||
# def prepare_data(data):
|
||||
# """Reshape data and select specified spectral bands."""
|
||||
# reshaped_data = data.reshape(100, -1)
|
||||
# selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
||||
# return reshaped_data[:, selected_bands]
|
||||
|
||||
def prepare_data(data):
|
||||
"""Reshape data and select specified spectral bands."""
|
||||
reshaped_data = data.reshape(100, -1)
|
||||
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
||||
return reshaped_data[:, selected_bands]
|
||||
# 筛选特定的波段
|
||||
data_selected = data[:, :25, :, selected_bands]
|
||||
print(f'筛选后的数据尺寸:{data_selected.shape}')
|
||||
# 将筛选后的数据重塑为二维数组,每行代表一个样本
|
||||
reshaped_data = data_selected.reshape(-1, 25 * 30 * len(selected_bands))
|
||||
return reshaped_data
|
||||
|
||||
|
||||
def split_data(X, y, test_size=0.20, random_state=12):
|
||||
"""Split data into training and test sets."""
|
||||
return train_test_split(X, y, test_size=test_size, random_state=random_state)
|
||||
|
||||
|
||||
def evaluate_model(model, X_test, y_test):
|
||||
"""Evaluate the model and return MSE and predictions."""
|
||||
"""Evaluate the model and return multiple metrics and predictions."""
|
||||
y_pred = model.predict(X_test)
|
||||
mse = mean_squared_error(y_test, y_pred)
|
||||
return mse, y_pred
|
||||
mae = mean_absolute_error(y_test, y_pred)
|
||||
r2 = r2_score(y_test, y_pred)
|
||||
return mse, mae, r2, y_pred
|
||||
|
||||
|
||||
def print_predictions(y_test, y_pred, model_name):
|
||||
"""Print actual and predicted values."""
|
||||
@ -44,6 +59,7 @@ def main():
|
||||
])
|
||||
|
||||
X = prepare_data(all_spectral_data)
|
||||
print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}')
|
||||
X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity)
|
||||
|
||||
models = {
|
||||
@ -55,13 +71,17 @@ def main():
|
||||
for model_name, model in models.items():
|
||||
model.fit(X_train, y_train)
|
||||
if model_name == "RandomForest":
|
||||
joblib.dump(model, r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib')
|
||||
joblib.dump(model,
|
||||
r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit.joblib')
|
||||
|
||||
mse, y_pred = evaluate_model(model, X_test, y_test)
|
||||
mse, mae, r2, y_pred = evaluate_model(model, X_test, y_test)
|
||||
print(f"Model: {model_name}")
|
||||
print(f"Mean Squared Error on the test set: {mse}")
|
||||
print(f"MSE on the test set: {mse}")
|
||||
print(f"MAE on the test set: {mae}")
|
||||
print(f"R² score on the test set: {r2}")
|
||||
print_predictions(y_test, y_pred, model_name)
|
||||
print("\n" + "-" * 50 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -60,7 +60,7 @@ def predict(model, data):
|
||||
|
||||
def main():
|
||||
# 加载模型
|
||||
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib')
|
||||
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib')
|
||||
|
||||
# 读取数据
|
||||
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user