feat:新增resnnet分类模型,判断是否空果;

fix:修复了一些小问题,确保在空果情况下不会程序报错卡死
This commit is contained in:
TG 2024-06-18 10:47:44 +08:00
parent 5d6d53c95d
commit e7bbaf4b01
7 changed files with 318 additions and 111 deletions

1
.gitignore vendored
View File

@ -89,3 +89,4 @@ fabric.properties
!/20240410RGBtest1/super-tomato/defect_big.bmp
!/20240410RGBtest1/super-tomato/defect_mask.bmp
!/20240410RGBtest1/super-tomato/prediction.png
/20240529RGBtest3/data/

View File

@ -7,12 +7,18 @@
import cv2
import numpy as np
import logging
import os
import utils
from root_dir import ROOT_DIR
from sklearn.ensemble import RandomForestRegressor
import joblib
import torch.nn as nn
import torch
from PIL import Image
from torchvision import transforms
import numpy as np
import json
import logging
class Tomato:
def __init__(self):
@ -242,13 +248,11 @@ class Passion_fruit:
lower_hue = np.array([self.hue_value - self.hue_delta, 0, 0])
upper_hue = np.array([self.hue_value + self.hue_delta, 255, 255])
hue_mask = cv2.inRange(hsv_image, lower_hue, upper_hue)
# 创建V通道排除中心值的掩码
lower_value_1 = np.array([0, 0, 0])
upper_value_1 = np.array([180, 255, self.value_target - self.value_delta])
lower_value_2 = np.array([0, 0, self.value_target + self.value_delta])
upper_value_2 = np.array([180, 255, 255])
value_mask_1 = cv2.inRange(hsv_image, lower_value_1, upper_value_1)
value_mask_1 = cv2.bitwise_not(value_mask_1)
value_mask_2 = cv2.inRange(hsv_image, lower_value_2, upper_value_2)
@ -278,13 +282,10 @@ class Passion_fruit:
"""
# 确保mask_image是二值图像
_, binary_mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)
# 查找mask图像中的轮廓
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 在原图上绘制轮廓
cv2.drawContours(original_image, contours, -1, (0, 255, 0), 2)
return original_image
def bitwise_and_rgb_with_binary(self, rgb_img, bin_img):
@ -294,8 +295,26 @@ class Passion_fruit:
:param bin_img: 二值图像
:return: 按位与后的结果图像
'''
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
result = cv2.bitwise_and(rgb_img, bin_img_3channel)
# 检查 RGB 图像是否为空或全黑
if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0):
logging.error("RGB 图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
# 检查二值图像是否为空或全黑
if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0):
logging.error("二值图像为空或全黑返回一个全黑RGB图像。")
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
# 转换二值图像为三通道
try:
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
except cv2.error as e:
logging.error(f"转换二值图像时发生错误: {e}")
return np.zeros_like(rgb_img)
# 进行按位与操作
try:
result = cv2.bitwise_and(rgb_img, bin_img_3channel)
except cv2.error as e:
logging.error(f"执行按位与操作时发生错误: {e}")
return np.zeros_like(rgb_img)
return result
class Spec_predict(object):
@ -331,65 +350,17 @@ class Spec_predict(object):
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
data_x = data_x[:, selected_bands]
data_y = self.model.predict(data_x)
return data_y
# def get_tomato_dimensions(edge_img):
# """
# 根据边缘二值化轮廓图,计算果子的长径、短径和长短径比值。
# 使用最小外接矩形和最小外接圆两种方法。
#
# 参数:
# edge_img (numpy.ndarray): 边缘二值化轮廓图,背景为黑色,番茄区域为白色。
#
# 返回:
# tuple: (长径, 短径, 长短径比值)
# """
# if edge_img is None or edge_img.any() == 0:
# return (0, 0)
# # 最小外接矩形
# rect = cv2.minAreaRect(cv2.findContours(edge_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0][0])
# major_axis, minor_axis = rect[1]
# # aspect_ratio = max(major_axis, minor_axis) / min(major_axis, minor_axis)
#
# # # 最小外接圆
# # (x, y), radius = cv2.minEnclosingCircle(
# # cv2.findContours(edge_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0][0])
# # diameter = 2 * radius
# # aspect_ratio_circle = 1.0
#
# return (max(major_axis, minor_axis), min(major_axis, minor_axis))
# def get_defect_info(defect_img):
# """
# 根据区域缺陷二值化轮廓图,计算缺陷区域的个数和总面积。
#
# 参数:
# defect_img (numpy.ndarray): 番茄区域缺陷二值化轮廓图,背景为黑色,番茄区域为白色,缺陷区域为黑色连通域。
#
# 返回:
# tuple: (缺陷区域个数, 缺陷区域像素面积,缺陷像素总面积)
# """
# # 检查输入是否为空
# if defect_img is None or defect_img.any() == 0:
# return (0, 0)
#
# nb_components, output, stats, centroids = cv2.connectedComponentsWithStats(defect_img, connectivity=4)
# max_area = max(stats[i, cv2.CC_STAT_AREA] for i in range(1, nb_components))
# areas = []
# for i in range(1, nb_components):
# area = stats[i, cv2.CC_STAT_AREA]
# if area != max_area:
# areas.append(area)
# number_defects = len(areas)
# total_pixels = sum(areas)
# return number_defects, total_pixels
return data_y[0]
class Data_processing:
def __init__(self):
pass
def contour_process(self, image_array):
# 检查图像是否为空或全黑
if image_array is None or image_array.size == 0 or np.all(image_array == 0):
# print("输入的图像为空或全黑,返回一个全黑图像。")
return np.zeros_like(image_array) if image_array is not None else np.zeros((100, 100), dtype=np.uint8)
# 应用中值滤波
image_filtered = cv2.medianBlur(image_array, 5)
@ -495,6 +466,7 @@ class Data_processing:
# 设置 S-L 通道阈值并处理图像
threshold_s_l = 180
threshold_fore_g_r_t = 20
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
s_l = tomato.extract_s_l(img)
thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l)
new_bin_img = tomato.largest_connected_component(thresholded_s_l)
@ -526,6 +498,7 @@ class Data_processing:
# 创建PassionFruit类的实例
pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta)
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
combined_mask = pf.create_mask(hsv_image)
combined_mask = pf.apply_morphology(combined_mask)
@ -540,3 +513,221 @@ class Data_processing:
diameter = (long_axis + short_axis) / 2
return diameter, weigth, number_defects, total_pixels, rp
class BasicBlock(nn.Module):
'''
BasicBlock for ResNet18 and ResNet34
'''
expansion = 1
def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel)
self.downsample = downsample
def forward(self, x):
identity = x
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
"""
注意原论文中在虚线残差结构的主分支上第一个1x1卷积层的步距是2第二个3x3卷积层步距是1
但在pytorch官方实现过程中是第一个1x1卷积层的步距是1第二个3x3卷积层步距是2
这么做的好处是能够在top1上提升大概0.5%的准确率
可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
"""
expansion = 4
def __init__(self, in_channel, out_channel, stride=1, downsample=None,
groups=1, width_per_group=64):
super(Bottleneck, self).__init__()
width = int(out_channel * (width_per_group / 64.)) * groups
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
kernel_size=1, stride=1, bias=False) # squeeze channels
self.bn1 = nn.BatchNorm2d(width)
# -----------------------------------------
self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
kernel_size=3, stride=stride, bias=False, padding=1)
self.bn2 = nn.BatchNorm2d(width)
# -----------------------------------------
self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
kernel_size=1, stride=1, bias=False) # unsqueeze channels
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
identity = x
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
'''
ResNet18 and ResNet34
'''
def __init__(self,
block,
blocks_num,
num_classes=1000,
include_top=True,
groups=1,
width_per_group=64):
super(ResNet, self).__init__()
self.include_top = include_top
self.in_channel = 64
self.groups = groups
self.width_per_group = width_per_group
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_channel)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, blocks_num[0])
self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
if self.include_top:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def _make_layer(self, block, channel, block_num, stride=1):
downsample = None
if stride != 1 or self.in_channel != channel * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(channel * block.expansion))
layers = []
layers.append(block(self.in_channel,
channel,
downsample=downsample,
stride=stride,
groups=self.groups,
width_per_group=self.width_per_group))
self.in_channel = channel * block.expansion
for _ in range(1, block_num):
layers.append(block(self.in_channel,
channel,
groups=self.groups,
width_per_group=self.width_per_group))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.include_top:
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def resnet18(num_classes=1000, include_top=True):
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
def resnetzy(num_classes=1000, include_top=True):
return ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
class ImageClassifier:
'''
图像分类器用于加载预训练的 ResNet 模型并进行图像分类
'''
def __init__(self, model_path, class_indices_path, device=None):
if device is None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = device
# 加载类别索引
assert os.path.exists(class_indices_path), f"File: '{class_indices_path}' does not exist."
with open(class_indices_path, "r") as json_file:
self.class_indict = json.load(json_file)
# 创建模型并加载权重
self.model = resnetzy(num_classes=len(self.class_indict)).to(self.device)
assert os. path.exists(model_path), f"File: '{model_path}' does not exist."
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
self.model.eval()
# 设置图像转换
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(self, image_np):
'''
对图像进行分类预测
:param image_np:
:return:
'''
# 将numpy数组转换为图像
image = Image.fromarray(image_np.astype('uint8'), 'RGB')
image = self.transform(image).unsqueeze(0).to(self.device)
with torch.no_grad():
output = self.model(image).cpu()
predict = torch.softmax(output, dim=1)
predict_cla = torch.argmax(predict, dim=1).numpy()
# return self.class_indict[str(predict_cla[0])]
return predict_cla[0]

View File

@ -10,7 +10,7 @@ import os
import cv2
from root_dir import ROOT_DIR
from classifer import Spec_predict, Data_processing
from classifer import Spec_predict, Data_processing, ImageClassifier
import logging
from utils import Pipe
import numpy as np
@ -76,7 +76,7 @@ def process_data(cmd: str, images: list, spec: any, dp: Data_processing, pipe: P
return response
def main(is_debug=False):
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'))
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8')
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
@ -84,12 +84,14 @@ def main(is_debug=False):
handlers=[file_handler, console_handler],
level=logging.DEBUG)
detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib')
classifier = ImageClassifier(ROOT_DIR/'models'/'resnet18_0616.pth', ROOT_DIR/'models'/'class_indices.json')
dp = Data_processing()
_ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16))
_, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp'))
_, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp'))
print('初始化完成')
_ = classifier.predict(np.ones((224, 224, 3), dtype=np.uint8))
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp'))
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp'))
print('系统初始化完成')
rgb_receive_name = r'\\.\pipe\rgb_receive'
rgb_send_name = r'\\.\pipe\rgb_send'
@ -106,7 +108,6 @@ def main(is_debug=False):
if cmd == 'YR':
break # 当接收到的不是预热命令时,结束预热循环
while True:
start_time = time.time()
images = []
cmd = None
@ -115,19 +116,28 @@ def main(is_debug=False):
data = pipe.receive_rgb_data(rgb_receive)
end_time10 = time.time()
print(f'接收一份数据时间:{end_time10 - start_time1}')
start_time11 = time.time()
cmd, img = pipe.parse_img(data)
end_time1 = time.time()
print(f'处理一份数据时间:{end_time1 - start_time11}')
print(f'接收1张图时间{end_time1 - start_time1}')
# print(cmd, img.shape)
# #打印img的数据类型
# print(img.dtype)
images.append(img)
# print(len(images))
if cmd not in ['TO', 'PF', 'YR']:
print(f'接收一张图时间:{end_time1 - start_time1}')
# 使用分类器进行预测
prediction = classifier.predict(img)
print(f'预测结果:{prediction}')
if prediction == 1:
images.append(img)
else:
response = pipe.send_data(cmd='KO', brix=0, diameter=0, green_percentage=0, weigth=0, defect_num=0,
total_defect_area=0, rp=None)
print("图像中无果,跳过此图像")
continue
if cmd not in ['TO', 'PF', 'YR', 'KO']:
logging.error(f'错误指令,指令为{cmd}')
continue
spec = None
if cmd == 'PF':
start_time2 = time.time()
@ -135,17 +145,21 @@ def main(is_debug=False):
_, spec = pipe.parse_spec(spec_data)
end_time2 = time.time()
print(f'接收光谱数据时间:{end_time2 - start_time2}')
# print(spec.shape)
start_time3 = time.time()
response = process_data(cmd, images, spec, dp, pipe, detector)
end_time3 = time.time()
print(f'处理时间:{end_time3 - start_time3}')
if images: # 确保images不为空
response = process_data(cmd, images, spec, dp, pipe, detector)
end_time3 = time.time()
print(f'处理时间:{end_time3 - start_time3}')
if response:
logging.info(f'处理成功,响应为: {response}')
else:
logging.error('处理失败')
else:
print("没有有效的图像进行处理")
end_time = time.time()
print(f'全流程时间:{end_time - start_time}')
if response:
logging.info(f'处理成功,响应为: {response}')
else:
logging.error('处理失败')
if __name__ == '__main__':

View File

@ -6,11 +6,7 @@
import shutil
import os
import win32file
import win32pipe
import time
@ -164,26 +160,31 @@ class Pipe:
return cmd, spec
def send_data(self,cmd:str, brix, green_percentage, weigth, diameter, defect_num, total_defect_area, rp):
# start_time = time.time()
#
# rp1 = Image.fromarray(rp.astype(np.uint8))
# # cv2.imwrite('rp1.bmp', rp1)
#
# # 将 Image 对象保存到 BytesIO 流中
# img_bytes = io.BytesIO()
# rp1.save(img_bytes, format='BMP')
# img_bytes = img_bytes.getvalue()
# width = rp.shape[0]
# height = rp.shape[1]
# print(width, height)
# img_bytes = rp.tobytes()
# length = len(img_bytes) + 18
# print(length)
# length = length.to_bytes(4, byteorder='big')
# width = width.to_bytes(2, byteorder='big')
# height = height.to_bytes(2, byteorder='big')
'''
发送数据
:param cmd:
:param brix:
:param green_percentage:
:param weigth:
:param diameter:
:param defect_num:
:param total_defect_area:
:param rp:
:return:
'''
cmd = cmd.strip().upper()
if cmd == 'KO':
cmd_ko = cmd.encode('ascii')
length = (2).to_bytes(4, byteorder='big') # 因为只有KO两个字节所以长度是2
send_message = length + cmd_ko
try:
win32file.WriteFile(self.rgb_send, send_message)
print('KO消息发送成功')
except Exception as e:
logging.error(f'发送KO指令失败错误类型{e}')
return False
return True
cmd_type = 'RE'
cmd_re = cmd_type.upper().encode('ascii')
img = np.asarray(rp, dtype=np.uint8) # 将图像转换为 NumPy 数组
@ -205,19 +206,19 @@ class Pipe:
weigth = weigth.to_bytes(1, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weigth + defect_num + total_defect_area + height + width + img_bytes
elif cmd == 'PF':
brix = int(brix.item() * 1000).to_bytes(2, byteorder='big')
brix = int(brix * 1000).to_bytes(2, byteorder='big')
gp = 0
gp = gp.to_bytes(1, byteorder='big')
weigth = weigth.to_bytes(1, byteorder='big')
send_message = length + cmd_re + brix + gp + diameter + weigth + defect_num + total_defect_area + height + width + img_bytes
try:
win32file.WriteFile(self.rgb_send, send_message)
time.sleep(0.01)
# time.sleep(0.01)
print('发送成功')
print(len(send_message), len(img_bytes))
# print(len(send_message), len(img_bytes))
# print(len(send_message))
except Exception as e:
logging.error(f'发送完成指令失败,错误类型:{e}')
logging.error(f'发送指令失败,错误类型:{e}')
return False
# end_time = time.time()

View File

@ -60,7 +60,7 @@ def predict(model, data):
def main():
# 加载模型
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit.joblib')
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib')
# 读取数据
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'