Mirror of https://github.com/NanjingForestryUniversity/supermachine--tomato-passion_fruit.git
(synced 2025-11-08 22:34:00 +00:00) — 106 lines, 3.3 KiB, Python
import cv2
|
|
import numpy as np
|
|
from sklearn import svm
|
|
from sklearn.preprocessing import StandardScaler
|
|
import time
|
|
import os
|
|
import joblib
|
|
|
|
|
|
def load_model(model_path):
    """Load a saved ``(model, scaler)`` pair from *model_path*.

    The file must hold a two-item sequence; unpacking it below raises
    ``ValueError`` otherwise.

    NOTE(review): joblib.load is pickle-based, so only load files from a
    trusted source.

    Args:
        model_path: path to the joblib file.

    Returns:
        A ``(model, scaler)`` tuple.
    """
    bundle = joblib.load(model_path)
    svm_model, pixel_scaler = bundle
    return svm_model, pixel_scaler
|
|
|
|
def predict_image_array(image_array, model_path):
    """Classify every pixel of an in-memory image with a saved model.

    The model and scaler are reloaded from *model_path* via ``load_model``
    on every call.

    Args:
        image_array: image array whose values are regrouped into rows of 3
            (presumably an H x W x 3 colour image — TODO confirm with callers).
        model_path: path handed to ``load_model``.

    Returns:
        Array of shape ``(image_array.shape[0], image_array.shape[1])`` with
        one predicted label per pixel.
    """
    model, scaler = load_model(model_path)

    # One row per pixel, one column per channel, standardized like training.
    features = scaler.transform(image_array.reshape(-1, 3))
    labels = model.predict(features)

    # Fold the flat per-pixel labels back into the image's 2-D layout.
    rows, cols = image_array.shape[0], image_array.shape[1]
    return labels.reshape(rows, cols)
|
|
def _sorted_visible_files(directory):
    """Return the sorted names of regular, non-hidden files in *directory*.

    Hidden entries (e.g. macOS ``.DS_Store``) and sub-directories are skipped
    so they can neither shift the positional image/mask pairing nor reach
    cv2.imread, which cannot decode them.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    )


def prepare_data(image_dir, mask_dir):
    """Build a per-pixel training set from paired image and mask directories.

    Images and masks are paired by position in their sorted file listings,
    so the i-th image must correspond to the i-th mask.

    Args:
        image_dir: directory of images, read with ``cv2.imread`` (colour).
        mask_dir: directory of masks, read as grayscale; mask values above
            128 become label 1, everything else label 0.

    Returns:
        ``(pixels, labels)`` where ``pixels`` has shape ``(N, 3)`` (one row
        per pixel, one column per channel) and ``labels`` has shape ``(N,)``
        with 0/1 ints. N is the total pixel count; both arrays are empty
        when the directories contain no files.

    Raises:
        ValueError: if the directories hold different numbers of files, a
            file cannot be read as an image, or an image and its mask
            differ in size.
    """
    image_files = _sorted_visible_files(image_dir)
    mask_files = _sorted_visible_files(mask_dir)

    # Pairing is positional; differing counts mean the pairs are misaligned
    # (zip would otherwise silently drop the surplus files).
    if len(image_files) != len(mask_files):
        raise ValueError(
            f'{image_dir!r} has {len(image_files)} files but '
            f'{mask_dir!r} has {len(mask_files)}'
        )

    all_pixels = []
    all_labels = []
    for image_file, mask_file in zip(image_files, mask_files):
        image_path = os.path.join(image_dir, image_file)
        mask_path = os.path.join(mask_dir, mask_file)
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f'could not read image {image_path!r}')
        if mask is None:
            raise ValueError(f'could not read mask {mask_path!r}')
        # A size mismatch would give pixel and label arrays of different
        # lengths and misalign every later sample.
        if image.shape[:2] != mask.shape[:2]:
            raise ValueError(
                f'size mismatch: {image_path!r} is {image.shape[:2]}, '
                f'{mask_path!r} is {mask.shape[:2]}'
            )

        all_pixels.append(image.reshape(-1, 3))  # (n_pixels, 3)
        all_labels.append((mask.reshape(-1) > 128).astype(int))  # 0 or 1

    if not all_pixels:
        # np.concatenate rejects an empty list; an empty dataset simply
        # yields empty arrays of the usual shapes.
        return np.empty((0, 3), dtype=np.uint8), np.empty(0, dtype=int)

    return np.concatenate(all_pixels, axis=0), np.concatenate(all_labels, axis=0)
|
|
|
|
# Training data and model locations (machine-specific; adjust as needed).
TRAIN_IMAGE_DIR = '/Users/xs/PycharmProjects/super-tomato/datasets_green/train-2/img'
TRAIN_MASK_DIR = '/Users/xs/PycharmProjects/super-tomato/datasets_green/train-2/label'
MODEL_PATH = '/Users/xs/PycharmProjects/super-tomato/svm_green.joblib'

# Set to True to retrain the SVM and overwrite MODEL_PATH. While False, the
# slow data loading and scaler fitting are skipped entirely: the prediction
# code later in this script reloads model and scaler through load_model, so a
# scaler fitted here would be discarded, and no "training complete" message
# is printed when nothing was trained.
RETRAIN_MODEL = False

if RETRAIN_MODEL:
    # Load every training pixel together with its 0/1 label.
    train_pixels, train_labels = prepare_data(TRAIN_IMAGE_DIR, TRAIN_MASK_DIR)

    # Standardize the features; the fitted scaler is saved alongside the
    # model so prediction can apply the identical transform.
    scaler = StandardScaler()
    train_pixels_scaled = scaler.fit_transform(train_pixels)

    # Linear SVM over per-pixel channel values.
    model = svm.SVC(kernel='linear', C=1.0)
    model.fit(train_pixels_scaled, train_labels)
    joblib.dump((model, scaler), MODEL_PATH)

    print('模型训练完成!')
|
|
|
|
def predict_image(image_path, model, scaler):
    """Read an image from disk and classify every pixel.

    Args:
        image_path: path of the image to read with ``cv2.imread``.
        model: fitted classifier exposing ``predict``.
        scaler: fitted scaler exposing ``transform``; must be the one used
            when training *model*.

    Returns:
        Array of shape ``(height, width)`` with one predicted label per pixel.

    Raises:
        OSError: if the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising; without this check the
    # failure would surface as an unhelpful AttributeError on .reshape.
    if image is None:
        raise OSError(f'could not read image {image_path!r}')

    # One row per pixel, one column per channel, standardized like training.
    test_pixels_scaled = scaler.transform(image.reshape(-1, 3))
    predictions = model.predict(test_pixels_scaled)

    # Fold the flat per-pixel predictions back into the image layout.
    return predictions.reshape(image.shape[0], image.shape[1])
|
|
|
|
|
|
# Predict the mask of a new image.
# The clock covers model loading and inference only. It is stopped before the
# preview window, because cv2.waitKey(0) blocks until a key is pressed and
# would otherwise be counted as "prediction time".
time1 = time.perf_counter()
model, scaler = load_model('/Users/xs/PycharmProjects/super-tomato/svm_green.joblib')
predicted_mask = predict_image('/Users/xs/PycharmProjects/super-tomato/defect_big.bmp', model, scaler)
time2 = time.perf_counter()
print(f'预测时间: {time2 - time1:.2f}秒')

# Scale the labels (0/1 as produced by prepare_data during training) to a
# 0/255 uint8 image once, and reuse it for both saving and display.
mask_image = (predicted_mask * 255).astype('uint8')
cv2.imwrite('/Users/xs/PycharmProjects/super-tomato/defect_mask.bmp', mask_image)
cv2.imshow('Predicted Mask', mask_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
|