diff --git a/20240529RGBtest3/brix_model_train&test/model_train.py b/20240529RGBtest3/brix_model_train&test/model_train.py new file mode 100644 index 0000000..553f6c2 --- /dev/null +++ b/20240529RGBtest3/brix_model_train&test/model_train.py @@ -0,0 +1,111 @@ +import numpy as np +import matplotlib.pyplot as plt +from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor +from sklearn.svm import SVR +from sklearn.neighbors import KNeighborsRegressor +from sklearn.model_selection import train_test_split, GridSearchCV +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score +from spec_read import all_spectral_data +import joblib + + +def prepare_data(data): + """Reshape data and select specified spectral bands.""" + selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] + # 筛选特定的波段 + data_selected = data[:, :, :, selected_bands] + # 将筛选后的数据重塑为二维数组,每行代表一个样本 + reshaped_data = data_selected.reshape(-1, 30 * 30 * len(selected_bands)) + return reshaped_data + + +def split_data(X, y, test_size=0.20, random_state=12): + """Split data into training and test sets.""" + return train_test_split(X, y, test_size=test_size, random_state=random_state) + + +def evaluate_model(model, X_test, y_test): + """Evaluate the model and return multiple metrics and predictions.""" + y_pred = model.predict(X_test) + mse = mean_squared_error(y_test, y_pred) + mae = mean_absolute_error(y_test, y_pred) + r2 = r2_score(y_test, y_pred) + return mse, mae, r2, y_pred + + +def print_predictions(y_test, y_pred, model_name): + """Print actual and predicted values.""" + print(f"Test Set Predictions for {model_name}:") + for i, (real, pred) in enumerate(zip(y_test, y_pred)): + print(f"Sample {i + 1}: True Value = {real:.2f}, Predicted Value = {pred:.2f}") + +def main(): + sweetness_acidity = np.array([ + 16.2, 16.1, 17, 16.9, 16.8, 17.8, 18.1, 17.2, 17, 17.2, 17.1, 17.2, + 17.2, 17.2, 18.1, 17, 17.6, 17.4, 17.1, 17.1, 16.9, 17.6, 17.3, 16.3, + 16.5, 18.7, 17.6, 16.2, 16.8, 17.2, 16.8, 17.3, 16, 16.6, 16.7, 16.7, + 17.3, 16.3, 16.8, 17.4, 17.3, 16.3, 16.1, 17.2, 18.6, 16.8, 16.1, 17.2, + 18.3, 16.5, 16.6, 17, 17, 17.8, 16.4, 18, 17.7, 17, 18.3, 16.8, 17.5, + 17.7, 18.5, 18, 17.7, 17, 18.3, 18.1, 17.4, 17.7, 17.8, 16.3, 17.1, 16.8, + 17.2, 17.5, 16.6, 17.7, 17.1, 17.7, 19.4, 20.3, 17.3, 15.8, 18, 17.7, + 17.2, 15.2, 18, 18.4, 18.3, 15.7, 17.2, 18.6, 15.6, 17, 16.9, 17.4, 17.8, + 16.5 + ]) + + X = prepare_data(all_spectral_data) + print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}') + X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity) + + models_params = { + "RandomForest": { + 'model': RandomForestRegressor(), + 'params': { + 'n_estimators': [100, 200, 300], + 'max_depth': [None, 10, 20], + 'min_samples_split': [2, 5], + 'min_samples_leaf': [1, 2], + 'random_state': [42] + } + }, + "GradientBoosting": { + 'model': GradientBoostingRegressor(), + 'params': { + 'n_estimators': [100, 200, 300], + 'learning_rate': [0.01, 0.1, 0.2], + 'max_depth': [3, 5, 7], + 'min_samples_split': [2, 5], + 'min_samples_leaf': [1, 2], + 'random_state': [42] + } + }, + "SVR": { + 'model': SVR(), + 'params': { + 'C': [0.1, 1, 10, 100], + 'gamma': ['scale', 'auto', 0.01, 0.1], + 'epsilon': [0.01, 0.1, 0.5] + } + } + } + + best_models = {} + + for model_name, mp in models_params.items(): + grid_search = GridSearchCV(mp['model'], mp['params'], cv=5, scoring='r2', verbose=2) + grid_search.fit(X_train, y_train) + best_models[model_name] = grid_search.best_estimator_ + mse, mae, r2, y_pred = evaluate_model(grid_search.best_estimator_, X_test, y_test) + print(f"Best {model_name} parameters: {grid_search.best_params_}") + print(f"Model: {model_name}") + print(f"MSE on the test set: {mse}") + print(f"MAE on the test set: {mae}") + print(f"R² score on the test set: {r2}") + print_predictions(y_test, y_pred, model_name) + print("\n" + "-" * 50 + "\n") + + # Optionally save the best model for each type + joblib.dump(grid_search.best_estimator_, f'{model_name}_best_model.joblib') + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/20240529RGBtest3/brix_model_train&test/predict.py b/20240529RGBtest3/brix_model_train&test/predict.py new file mode 100644 index 0000000..604759e --- /dev/null +++ b/20240529RGBtest3/brix_model_train&test/predict.py @@ -0,0 +1,87 @@ +import joblib +import numpy as np +import os +from model_train import prepare_data + +def read_spectral_data(hdr_path, raw_path): + # Read HDR file for image dimensions information + with open(hdr_path, 'r', encoding='latin1') as hdr_file: + lines = hdr_file.readlines() + height = width = bands = 0 + for line in lines: + if line.startswith('lines'): + height = int(line.split()[-1]) + elif line.startswith('samples'): + width = int(line.split()[-1]) + elif line.startswith('bands'): + bands = int(line.split()[-1]) + + # Read spectral data from RAW file + raw_image = np.fromfile(raw_path, dtype='uint16') + # Initialize the image with the actual read dimensions + formatImage = np.zeros((height, width, bands)) + + for row in range(height): + for dim in range(bands): + formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width] + + # Ensure the image is 30x30x224 by cropping or padding + target_height, target_width, target_bands = 30, 30, 224 + # Crop or pad height + if height > target_height: + formatImage = formatImage[:target_height, :, :] + elif height < target_height: + pad_height = target_height - height + formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0) + + # Crop or pad width + if width > target_width: + formatImage = formatImage[:, :target_width, :] + elif width < target_width: + pad_width = target_width - width + formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0) + + # Crop or pad bands if necessary (usually bands should not change) + if bands > target_bands: + formatImage = formatImage[:, :, :target_bands] + elif bands < target_bands: + pad_bands = target_bands - bands + formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0) + + return formatImage + +def load_model(model_path): + """加载模型""" + return joblib.load(model_path) + +def predict(model, data): + """预测数据""" + return model.predict(data) + +def main(): + # 加载模型 + model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib') + + # 读取数据 + directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030' + all_spectral_data = [] + for i in range(1, 101): + hdr_path = os.path.join(directory, f'{i}.HDR') + raw_path = os.path.join(directory, f'{i}') + spectral_data = read_spectral_data(hdr_path, raw_path) + all_spectral_data.append(spectral_data) + all_spectral_data = np.stack(all_spectral_data) + print(all_spectral_data.shape) + + # 预处理数据 + data_prepared = prepare_data(all_spectral_data) + print(data_prepared.shape) + + # 预测数据 + predictions = predict(model, data_prepared) + + # 打印预测结果 + print(predictions) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/20240529RGBtest3/brix_model_train&test/spec_read.py b/20240529RGBtest3/brix_model_train&test/spec_read.py new file mode 100644 index 0000000..4a764ce --- /dev/null +++ b/20240529RGBtest3/brix_model_train&test/spec_read.py @@ -0,0 +1,70 @@ +import numpy as np +import os + + +def read_spectral_data(hdr_path, raw_path): + # Read HDR file for image dimensions information + with open(hdr_path, 'r', encoding='latin1') as hdr_file: + lines = hdr_file.readlines() + height = width = bands = 0 + for line in lines: + if line.startswith('lines'): + height = int(line.split()[-1]) + elif line.startswith('samples'): + width = int(line.split()[-1]) + elif line.startswith('bands'): + bands = int(line.split()[-1]) + + # Read spectral data from RAW file + raw_image = np.fromfile(raw_path, dtype='uint16') + # Initialize the image with the actual read dimensions + formatImage = np.zeros((height, width, bands)) + + for row in range(height): + for dim in range(bands): + formatImage[row, :, dim] = raw_image[(dim + row * bands) * width:(dim + 1 + row * bands) * width] + + # Ensure the image is 30x30x224 by cropping or padding + target_height, target_width, target_bands = 30, 30, 224 + # Crop or pad height + if height > target_height: + formatImage = formatImage[:target_height, :, :] + elif height < target_height: + pad_height = target_height - height + formatImage = np.pad(formatImage, ((0, pad_height), (0, 0), (0, 0)), mode='constant', constant_values=0) + + # Crop or pad width + if width > target_width: + formatImage = formatImage[:, :target_width, :] + elif width < target_width: + pad_width = target_width - width + formatImage = np.pad(formatImage, ((0, 0), (0, pad_width), (0, 0)), mode='constant', constant_values=0) + + # Crop or pad bands if necessary (usually bands should not change) + if bands > target_bands: + formatImage = formatImage[:, :, :target_bands] + elif bands < target_bands: + pad_bands = target_bands - bands + formatImage = np.pad(formatImage, ((0, 0), (0, 0), (0, pad_bands)), mode='constant', constant_values=0) + + return formatImage + + +# Specify the directory containing the HDR and RAW files +directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030' + +# Initialize a list to hold all the spectral data arrays +all_spectral_data = [] + +# Loop through each data set (assuming there are 40 datasets) +for i in range(1, 101): + hdr_path = os.path.join(directory, f'{i}.HDR') + raw_path = os.path.join(directory, f'{i}') + + # Read data + spectral_data = read_spectral_data(hdr_path, raw_path) + all_spectral_data.append(spectral_data) + +# Stack all data into a single numpy array +all_spectral_data = np.stack(all_spectral_data) +print(all_spectral_data.shape) # This should print (40, 30, 30, 224) \ No newline at end of file diff --git a/20240529RGBtest3/classifer.py b/20240529RGBtest3/classifer.py index 28ac842..9f05ff6 100644 --- a/20240529RGBtest3/classifer.py +++ b/20240529RGBtest3/classifer.py @@ -14,6 +14,7 @@ import random import numpy as np from PIL import Image from utils import Pipe +from config import Config as setting from sklearn.ensemble import RandomForestRegressor #图像分类网络所需库,实际并未使用分类网络 # import torch @@ -23,8 +24,10 @@ from sklearn.ensemble import RandomForestRegressor #番茄RGB处理模型 class Tomato: - def __init__(self): + def __init__(self, find_reflection_threshold=setting.find_reflection_threshold, extract_g_r_factor=setting.extract_g_r_factor): ''' 初始化 Tomato 类。''' + self.find_reflection_threshold = find_reflection_threshold + self.extract_g_r_factor = extract_g_r_factor pass def extract_s_l(self, image): @@ -40,14 +43,14 @@ class Tomato: result = cv2.add(s_channel, l_channel) return result - def find_reflection(self, image, threshold=190): + def find_reflection(self, image): ''' 通过阈值处理识别图像中的反射区域。 :param image: 输入的单通道图像 :param threshold: 用于二值化的阈值 :return: 二值化后的图像,高于阈值的部分为白色,其余为黑色 ''' - _, reflection = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY) + _, reflection = cv2.threshold(image, self.find_reflection_threshold, 255, cv2.THRESH_BINARY) return reflection def otsu_threshold(self, image): @@ -67,7 +70,7 @@ class Tomato: ''' g_channel = image[:, :, 1] r_channel = image[:, :, 2] - result = cv2.subtract(cv2.multiply(g_channel, 1.5), r_channel) + result = cv2.subtract(cv2.multiply(g_channel, self.extract_g_r_factor), r_channel) return result def extract_r_b(self, image): @@ -226,7 +229,8 @@ class Tomato: #百香果RGB处理模型 class Passion_fruit: - def __init__(self, hue_value=37, hue_delta=10, value_target=25, value_delta=10): + def __init__(self, hue_value=setting.hue_value, hue_delta=setting.hue_delta, + value_target=setting.value_target, value_delta=setting.value_delta): # 初始化常用参数 self.hue_value = hue_value self.hue_delta = hue_delta @@ -259,7 +263,8 @@ class Passion_fruit: def find_largest_component(self, mask): if mask is None or mask.size == 0 or np.all(mask == 0): logging.warning("RGB 图像为空或全黑,返回一个全黑RGB图像。") - return np.zeros((100, 100, 3), dtype=np.uint8) if mask is None else np.zeros_like(mask) + return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \ + if mask is None else np.zeros_like(mask) # 寻找最大连通组件 num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S) if num_labels < 2: @@ -291,11 +296,13 @@ class Passion_fruit: # 检查 RGB 图像是否为空或全黑 if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0): logging.warning("RGB 图像为空或全黑,返回一个全黑RGB图像。") - return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img) + return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \ + if rgb_img is None else np.zeros_like(rgb_img) # 检查二值图像是否为空或全黑 if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0): logging.warning("二值图像为空或全黑,返回一个全黑RGB图像。") - return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img) + return np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8) \ + if bin_img is None else np.zeros_like(bin_img) # 转换二值图像为三通道 try: bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR) @@ -336,13 +343,15 @@ class Spec_predict(object): def predict(self, data_x): ''' - 对数据进行预测 - :param data_x: 波段选择后的数据 - :return: 预测结果二值化后的数据,0为背景,1为黄芪,2为杂质2,3为杂质1,4为甘草片,5为红芪 + 预测数据 + :param data_x: 重塑为二维数组的数据 + :return: 预测结果——糖度 ''' - data_x = data_x.reshape(1, -1) - selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] - data_x = data_x[:, selected_bands] + # 对数据进行切片,筛选谱段 + #qt_test进行测试时如果读取的是(30,30,224)需要解开注释进行数据切片,筛选谱段 + # data_x = data_x[ :25, :, setting.selected_bands ] + # 将筛选后的数据重塑为二维数组,每行代表一个样本 + data_x = data_x.reshape(-1, setting.n_spec_rows * setting.n_spec_cols * setting.n_spec_bands) data_y = self.model.predict(data_x) return data_y[0] @@ -508,8 +517,8 @@ class Data_processing: 返回: float: 估算的西红柿体积 """ - a = ((long_axis / 425) * 6.3) / 2 - b = ((short_axis / 425) * 6.3) / 2 + a = (long_axis * setting.pixel_length_ratio) / 2 + b = (short_axis * setting.pixel_length_ratio) / 2 volume = 4 / 3 * np.pi * a * b * b weight = round(volume * self.density) #重量单位为g @@ -524,19 +533,16 @@ class Data_processing: tuple: (长径, 短径, 缺陷区域个数, 缺陷区域总像素, 处理后的图像) """ tomato = Tomato() # 创建 Tomato 类的实例 - # 设置 S-L 通道阈值并处理图像 - threshold_s_l = 180 - threshold_fore_g_r_t = 20 img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR) s_l = tomato.extract_s_l(img) - thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l) + thresholded_s_l = tomato.threshold_segmentation(s_l, setting.threshold_s_l) new_bin_img = tomato.largest_connected_component(thresholded_s_l) filled_img, defect = self.fill_holes(new_bin_img) # 绘制西红柿边缘并获取缺陷信息 edge, mask = tomato.draw_tomato_edge(img, new_bin_img) org_defect = tomato.bitwise_and_rgb_with_binary(edge, new_bin_img) fore = tomato.bitwise_and_rgb_with_binary(img, mask) - fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=threshold_fore_g_r_t) + fore_g_r_t = tomato.threshold_segmentation(tomato.extract_g_r(fore), threshold=setting.threshold_fore_g_r_t) res = cv2.bitwise_or(new_bin_img, fore_g_r_t) nogreen = tomato.bitwise_and_rgb_with_binary(edge, res) # 统计白色像素点个数 @@ -554,8 +560,8 @@ class Data_processing: # cv2.imwrite('filled_img.jpg',filled_img) # 将处理后的图像转换为 RGB 格式 rp = cv2.cvtColor(nogreen, cv2.COLOR_BGR2RGB) - #直径单位为cm,所以需要除以10 - diameter = (long_axis + short_axis) /425 * 63 / 2 / 10 + #直径单位为cm + diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2 # print(f'直径:{diameter}') # 如果直径小于3,判断为空果拖异常图,则将所有值重置为0 if diameter < 2.5: @@ -563,16 +569,17 @@ class Data_processing: green_percentage = 0 number_defects = 0 total_pixels = 0 - rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB) + rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), + dtype=np.uint8), cv2.COLOR_BGR2RGB) return diameter, green_percentage, number_defects, total_pixels, rp - def analyze_passion_fruit(self, img, hue_value=37, hue_delta=10, value_target=25, value_delta=10): + def analyze_passion_fruit(self, img): if img is None: logging.error("Error: 无图像数据.") return None # 创建PassionFruit类的实例 - pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta) + pf = Passion_fruit() img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR) hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) @@ -596,15 +603,16 @@ class Data_processing: edge = pf.draw_contours_on_image(img, contour_mask) org_defect = pf.bitwise_and_rgb_with_binary(edge, max_mask) rp = cv2.cvtColor(org_defect, cv2.COLOR_BGR2RGB) - #直径单位为cm,所以需要除以10 - diameter = (long_axis + short_axis) /425 * 63 / 2 / 10 + #直径单位为cm + diameter = (long_axis + short_axis) * setting.pixel_length_ratio / 2 # print(f'直径:{diameter}') if diameter < 2.5: diameter = 0 - green_percentage = 0 + weight = 0 number_defects = 0 total_pixels = 0 - rp = cv2.cvtColor(np.ones((613, 800, 3), dtype=np.uint8), cv2.COLOR_BGR2RGB) + rp = cv2.cvtColor(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), + dtype=np.uint8), cv2.COLOR_BGR2RGB) return diameter, weight, number_defects, total_pixels, rp def process_data(seif, cmd: str, images: list, spec: any, pipe: Pipe, detector: Spec_predict) -> bool: diff --git a/20240529RGBtest3/config.py b/20240529RGBtest3/config.py index 9575cd0..c71f3e0 100644 --- a/20240529RGBtest3/config.py +++ b/20240529RGBtest3/config.py @@ -4,9 +4,50 @@ # @File : config.py # @Software: PyCharm - +from root_dir import ROOT_DIR class Config: #文件相关参数 #预热参数 - n_rows, n_cols, n_channels = 256, 256, 3 \ No newline at end of file + n_spec_rows, n_spec_cols, n_spec_bands = 25, 30, 13 + n_rgb_rows, n_rgb_cols, n_rgb_bands = 613, 800, 3 + tomato_img_dir = ROOT_DIR / 'models' / 'TO.bmp' + passion_fruit_img_dir = ROOT_DIR / 'models' / 'PF.bmp' + #模型路径 + #糖度模型 + brix_model_path = ROOT_DIR / 'models' / 'passion_fruit.joblib' + #图像分类模型 + imgclassifier_model_path = ROOT_DIR / 'models' / 'imgclassifier.joblib' + imgclassifier_class_indices_path = ROOT_DIR / 'models' / 'class_indices.json' + + + #classifer.py参数 + #tomato + find_reflection_threshold = 190 + extract_g_r_factor = 1.5 + + #passion_fruit + hue_value = 37 + hue_delta = 10 + value_target = 25 + value_delta = 10 + + #spec_predict + #筛选谱段并未使用,在qt取数据时已经筛选 + selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] + + #data_processing + #根据标定数据计算的参数,实际长度/像素长度,单位cm + pixel_length_ratio = 6.3/425 + #绿叶面积阈值,高于此阈值认为连通域是绿叶 + area_threshold = 20000 + #百香果密度(g/cm^3) + density = 0.652228972 + #百香果面积比例,每个像素代表的实际面积(cm^2) + area_ratio = 0.00021973702422145334 + + #def analyze_tomato + #s_l通道阈值 + threshold_s_l = 180 + threshold_fore_g_r_t = 20 + diff --git a/20240529RGBtest3/main.py b/20240529RGBtest3/main.py index ca53002..fce8333 100644 --- a/20240529RGBtest3/main.py +++ b/20240529RGBtest3/main.py @@ -14,11 +14,11 @@ import logging from utils import Pipe import numpy as np import time - - +from config import Config def main(is_debug=False): + setting = Config() file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8') file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) console_handler = logging.StreamHandler(sys.stdout) @@ -27,15 +27,18 @@ def main(is_debug=False): handlers=[file_handler, console_handler], level=logging.DEBUG) #模型加载 - detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib') - # classifier = ImageClassifier(ROOT_DIR/'models'/'resnet34_0619.pth', ROOT_DIR/'models'/'class_indices.json') + detector = Spec_predict() + detector.load(path=setting.brix_model_path) + # classifier = ImageClassifier(model_path=setting.imgclassifier_model_path, + # class_indices_path=setting.imgclassifier_class_indices_path) dp = Data_processing() print('系统初始化中...') #模型预热 - _ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16)) - # _ = classifier.predict(np.ones((224, 224, 3), dtype=np.uint8)) - # _, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp')) - # _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp')) + #与qt_test测试时需要注释掉预热,模型接收尺寸为(25,30,13),qt_test发送的数据为(30,30,224),需要对数据进行切片(classifer.py第352行) + _ = detector.predict(np.ones((setting.n_spec_rows, setting.n_spec_cols, setting.n_spec_bands), dtype=np.uint16)) + # _ = classifier.predict(np.ones((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), dtype=np.uint8)) + # _, _, _, _, _ =dp.analyze_tomato(cv2.imread(str(setting.tomato_img_dir))) + # _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(str(setting.passion_fruit_img_dir)) print('系统初始化完成') rgb_receive_name = r'\\.\pipe\rgb_receive' diff --git a/20240529RGBtest3/models/PF.bmp b/20240529RGBtest3/models/PF.bmp new file mode 100644 index 0000000..44fbe47 Binary files /dev/null and b/20240529RGBtest3/models/PF.bmp differ diff --git a/20240529RGBtest3/models/TO.bmp b/20240529RGBtest3/models/TO.bmp new file mode 100644 index 0000000..51d04e6 Binary files /dev/null and b/20240529RGBtest3/models/TO.bmp differ diff --git a/20240529RGBtest3/models/class_indices.json b/20240529RGBtest3/models/class_indices.json new file mode 100644 index 0000000..fd39595 --- /dev/null +++ b/20240529RGBtest3/models/class_indices.json @@ -0,0 +1,4 @@ +{ + "0": "exist", + "1": "no_exist" +} \ No newline at end of file diff --git a/20240529RGBtest3/models/passion_fruit.joblib b/20240529RGBtest3/models/passion_fruit.joblib new file mode 100644 index 0000000..604c96c Binary files /dev/null and b/20240529RGBtest3/models/passion_fruit.joblib differ diff --git a/20240529RGBtest3/models/passion_fruit_2.joblib b/20240529RGBtest3/models/passion_fruit_2.joblib deleted file mode 100644 index 3595d37..0000000 Binary files a/20240529RGBtest3/models/passion_fruit_2.joblib and /dev/null differ diff --git a/20240529RGBtest3/root_dir.py b/20240529RGBtest3/root_dir.py index 754b950..a3d55ab 100644 --- a/20240529RGBtest3/root_dir.py +++ b/20240529RGBtest3/root_dir.py @@ -1,4 +1,3 @@ - import pathlib file_path = pathlib.Path(__file__) diff --git a/20240529RGBtest3/utils.py b/20240529RGBtest3/utils.py index fd7ecdc..0568b10 100644 --- a/20240529RGBtest3/utils.py +++ b/20240529RGBtest3/utils.py @@ -12,6 +12,7 @@ import win32pipe import time import logging import numpy as np +from config import Config as setting from PIL import Image import io @@ -204,13 +205,15 @@ class Pipe: gp = int(green_percentage * 100).to_bytes(1, byteorder='big') weight = 0 weight = weight.to_bytes(1, byteorder='big') - send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes + send_message = (length + cmd_re + brix + gp + diameter + weight + + defect_num + total_defect_area + height + width + img_bytes) elif cmd == 'PF': brix = int(brix * 1000).to_bytes(2, byteorder='big') gp = 0 gp = gp.to_bytes(1, byteorder='big') weight = weight.to_bytes(1, byteorder='big') - send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes + send_message = (length + cmd_re + brix + gp + diameter + weight + + defect_num + total_defect_area + height + width + img_bytes) elif cmd == 'KO': brix = 0 brix = brix.to_bytes(2, byteorder='big') @@ -222,13 +225,15 @@ class Pipe: defect_num = defect_num.to_bytes(2, byteorder='big') total_defect_area = 0 total_defect_area = total_defect_area.to_bytes(4, byteorder='big') - height = 100 + height = setting.n_rgb_rows height = height.to_bytes(2, byteorder='big') - width = 100 + width = setting.n_rgb_cols width = width.to_bytes(2, byteorder='big') - img_bytes = np.zeros((100, 100, 3), dtype=np.uint8).tobytes() + img_bytes = np.zeros((setting.n_rgb_rows, setting.n_rgb_cols, setting.n_rgb_bands), + dtype=np.uint8).tobytes() length = (18).to_bytes(4, byteorder='big') - send_message = length + cmd_re + brix + gp + diameter + weight + defect_num + total_defect_area + height + width + img_bytes + send_message = (length + cmd_re + brix + gp + diameter + weight + + defect_num + total_defect_area + height + width + img_bytes) try: win32file.WriteFile(self.rgb_send, send_message) # time.sleep(0.01) diff --git a/20240529RGBtest3/xs/dimensionality_reduction.py b/20240529RGBtest3/xs/dimensionality_reduction.py index e0d5f30..a5bbc1b 100644 --- a/20240529RGBtest3/xs/dimensionality_reduction.py +++ b/20240529RGBtest3/xs/dimensionality_reduction.py @@ -4,25 +4,40 @@ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor from sklearn.svm import SVR from sklearn.neighbors import KNeighborsRegressor from sklearn.model_selection import train_test_split -from sklearn.metrics import mean_squared_error +from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score from spec_read import all_spectral_data import joblib +# def prepare_data(data): +# """Reshape data and select specified spectral bands.""" +# reshaped_data = data.reshape(100, -1) +# selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] +# return reshaped_data[:, selected_bands] + def prepare_data(data): """Reshape data and select specified spectral bands.""" - reshaped_data = data.reshape(100, -1) selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145] - return reshaped_data[:, selected_bands] + # 筛选特定的波段 + data_selected = data[:, :25, :, selected_bands] + print(f'筛选后的数据尺寸:{data_selected.shape}') + # 将筛选后的数据重塑为二维数组,每行代表一个样本 + reshaped_data = data_selected.reshape(-1, 25 * 30 * len(selected_bands)) + return reshaped_data + def split_data(X, y, test_size=0.20, random_state=12): """Split data into training and test sets.""" return train_test_split(X, y, test_size=test_size, random_state=random_state) + def evaluate_model(model, X_test, y_test): - """Evaluate the model and return MSE and predictions.""" + """Evaluate the model and return multiple metrics and predictions.""" y_pred = model.predict(X_test) mse = mean_squared_error(y_test, y_pred) - return mse, y_pred + mae = mean_absolute_error(y_test, y_pred) + r2 = r2_score(y_test, y_pred) + return mse, mae, r2, y_pred + def print_predictions(y_test, y_pred, model_name): """Print actual and predicted values.""" @@ -44,6 +59,7 @@ def main(): ]) X = prepare_data(all_spectral_data) + print(f'原数据尺寸:{all_spectral_data.shape};训练数据尺寸:{X.shape}') X_train, X_test, y_train, y_test = split_data(X, sweetness_acidity) models = { @@ -55,13 +71,17 @@ def main(): for model_name, model in models.items(): model.fit(X_train, y_train) if model_name == "RandomForest": - joblib.dump(model, r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib') + joblib.dump(model, + r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit.joblib') - mse, y_pred = evaluate_model(model, X_test, y_test) + mse, mae, r2, y_pred = evaluate_model(model, X_test, y_test) print(f"Model: {model_name}") - print(f"Mean Squared Error on the test set: {mse}") + print(f"MSE on the test set: {mse}") + print(f"MAE on the test set: {mae}") + print(f"R² score on the test set: {r2}") print_predictions(y_test, y_pred, model_name) - print("\n" + "-"*50 + "\n") + print("\n" + "-" * 50 + "\n") + if __name__ == "__main__": main() \ No newline at end of file diff --git a/20240529RGBtest3/xs/predict.py b/20240529RGBtest3/xs/predict.py index 7e29ad8..953dda3 100644 --- a/20240529RGBtest3/xs/predict.py +++ b/20240529RGBtest3/xs/predict.py @@ -60,7 +60,7 @@ def predict(model, data): def main(): # 加载模型 - model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib') + model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_3.joblib') # 读取数据 directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'