mirror of
https://github.com/NanjingForestryUniversity/supermachine--tomato-passion_fruit.git
synced 2025-11-09 14:54:07 +00:00
feat:新增resnnet分类模型,判断是否空果;
fix:修复了一些小问题,确保在空果情况下不会程序报错卡死
This commit is contained in:
parent
5d6d53c95d
commit
e7bbaf4b01
1
.gitignore
vendored
1
.gitignore
vendored
@ -89,3 +89,4 @@ fabric.properties
|
|||||||
!/20240410RGBtest1/super-tomato/defect_big.bmp
|
!/20240410RGBtest1/super-tomato/defect_big.bmp
|
||||||
!/20240410RGBtest1/super-tomato/defect_mask.bmp
|
!/20240410RGBtest1/super-tomato/defect_mask.bmp
|
||||||
!/20240410RGBtest1/super-tomato/prediction.png
|
!/20240410RGBtest1/super-tomato/prediction.png
|
||||||
|
/20240529RGBtest3/data/
|
||||||
|
|||||||
@ -7,12 +7,18 @@
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import utils
|
import utils
|
||||||
from root_dir import ROOT_DIR
|
|
||||||
from sklearn.ensemble import RandomForestRegressor
|
from sklearn.ensemble import RandomForestRegressor
|
||||||
import joblib
|
import joblib
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class Tomato:
|
class Tomato:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -242,13 +248,11 @@ class Passion_fruit:
|
|||||||
lower_hue = np.array([self.hue_value - self.hue_delta, 0, 0])
|
lower_hue = np.array([self.hue_value - self.hue_delta, 0, 0])
|
||||||
upper_hue = np.array([self.hue_value + self.hue_delta, 255, 255])
|
upper_hue = np.array([self.hue_value + self.hue_delta, 255, 255])
|
||||||
hue_mask = cv2.inRange(hsv_image, lower_hue, upper_hue)
|
hue_mask = cv2.inRange(hsv_image, lower_hue, upper_hue)
|
||||||
|
|
||||||
# 创建V通道排除中心值的掩码
|
# 创建V通道排除中心值的掩码
|
||||||
lower_value_1 = np.array([0, 0, 0])
|
lower_value_1 = np.array([0, 0, 0])
|
||||||
upper_value_1 = np.array([180, 255, self.value_target - self.value_delta])
|
upper_value_1 = np.array([180, 255, self.value_target - self.value_delta])
|
||||||
lower_value_2 = np.array([0, 0, self.value_target + self.value_delta])
|
lower_value_2 = np.array([0, 0, self.value_target + self.value_delta])
|
||||||
upper_value_2 = np.array([180, 255, 255])
|
upper_value_2 = np.array([180, 255, 255])
|
||||||
|
|
||||||
value_mask_1 = cv2.inRange(hsv_image, lower_value_1, upper_value_1)
|
value_mask_1 = cv2.inRange(hsv_image, lower_value_1, upper_value_1)
|
||||||
value_mask_1 = cv2.bitwise_not(value_mask_1)
|
value_mask_1 = cv2.bitwise_not(value_mask_1)
|
||||||
value_mask_2 = cv2.inRange(hsv_image, lower_value_2, upper_value_2)
|
value_mask_2 = cv2.inRange(hsv_image, lower_value_2, upper_value_2)
|
||||||
@ -278,13 +282,10 @@ class Passion_fruit:
|
|||||||
"""
|
"""
|
||||||
# 确保mask_image是二值图像
|
# 确保mask_image是二值图像
|
||||||
_, binary_mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)
|
_, binary_mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)
|
||||||
|
|
||||||
# 查找mask图像中的轮廓
|
# 查找mask图像中的轮廓
|
||||||
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
# 在原图上绘制轮廓
|
# 在原图上绘制轮廓
|
||||||
cv2.drawContours(original_image, contours, -1, (0, 255, 0), 2)
|
cv2.drawContours(original_image, contours, -1, (0, 255, 0), 2)
|
||||||
|
|
||||||
return original_image
|
return original_image
|
||||||
|
|
||||||
def bitwise_and_rgb_with_binary(self, rgb_img, bin_img):
|
def bitwise_and_rgb_with_binary(self, rgb_img, bin_img):
|
||||||
@ -294,8 +295,26 @@ class Passion_fruit:
|
|||||||
:param bin_img: 二值图像
|
:param bin_img: 二值图像
|
||||||
:return: 按位与后的结果图像
|
:return: 按位与后的结果图像
|
||||||
'''
|
'''
|
||||||
|
# 检查 RGB 图像是否为空或全黑
|
||||||
|
if rgb_img is None or rgb_img.size == 0 or np.all(rgb_img == 0):
|
||||||
|
logging.error("RGB 图像为空或全黑,返回一个全黑RGB图像。")
|
||||||
|
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
|
||||||
|
# 检查二值图像是否为空或全黑
|
||||||
|
if bin_img is None or bin_img.size == 0 or np.all(bin_img == 0):
|
||||||
|
logging.error("二值图像为空或全黑,返回一个全黑RGB图像。")
|
||||||
|
return np.zeros((100, 100, 3), dtype=np.uint8) if rgb_img is None else np.zeros_like(rgb_img)
|
||||||
|
# 转换二值图像为三通道
|
||||||
|
try:
|
||||||
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
|
bin_img_3channel = cv2.cvtColor(bin_img, cv2.COLOR_GRAY2BGR)
|
||||||
|
except cv2.error as e:
|
||||||
|
logging.error(f"转换二值图像时发生错误: {e}")
|
||||||
|
return np.zeros_like(rgb_img)
|
||||||
|
# 进行按位与操作
|
||||||
|
try:
|
||||||
result = cv2.bitwise_and(rgb_img, bin_img_3channel)
|
result = cv2.bitwise_and(rgb_img, bin_img_3channel)
|
||||||
|
except cv2.error as e:
|
||||||
|
logging.error(f"执行按位与操作时发生错误: {e}")
|
||||||
|
return np.zeros_like(rgb_img)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
class Spec_predict(object):
|
class Spec_predict(object):
|
||||||
@ -331,65 +350,17 @@ class Spec_predict(object):
|
|||||||
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
selected_bands = [8, 9, 10, 48, 49, 50, 77, 80, 103, 108, 115, 143, 145]
|
||||||
data_x = data_x[:, selected_bands]
|
data_x = data_x[:, selected_bands]
|
||||||
data_y = self.model.predict(data_x)
|
data_y = self.model.predict(data_x)
|
||||||
return data_y
|
return data_y[0]
|
||||||
|
|
||||||
|
|
||||||
# def get_tomato_dimensions(edge_img):
|
|
||||||
# """
|
|
||||||
# 根据边缘二值化轮廓图,计算果子的长径、短径和长短径比值。
|
|
||||||
# 使用最小外接矩形和最小外接圆两种方法。
|
|
||||||
#
|
|
||||||
# 参数:
|
|
||||||
# edge_img (numpy.ndarray): 边缘二值化轮廓图,背景为黑色,番茄区域为白色。
|
|
||||||
#
|
|
||||||
# 返回:
|
|
||||||
# tuple: (长径, 短径, 长短径比值)
|
|
||||||
# """
|
|
||||||
# if edge_img is None or edge_img.any() == 0:
|
|
||||||
# return (0, 0)
|
|
||||||
# # 最小外接矩形
|
|
||||||
# rect = cv2.minAreaRect(cv2.findContours(edge_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0][0])
|
|
||||||
# major_axis, minor_axis = rect[1]
|
|
||||||
# # aspect_ratio = max(major_axis, minor_axis) / min(major_axis, minor_axis)
|
|
||||||
#
|
|
||||||
# # # 最小外接圆
|
|
||||||
# # (x, y), radius = cv2.minEnclosingCircle(
|
|
||||||
# # cv2.findContours(edge_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0][0])
|
|
||||||
# # diameter = 2 * radius
|
|
||||||
# # aspect_ratio_circle = 1.0
|
|
||||||
#
|
|
||||||
# return (max(major_axis, minor_axis), min(major_axis, minor_axis))
|
|
||||||
|
|
||||||
# def get_defect_info(defect_img):
|
|
||||||
# """
|
|
||||||
# 根据区域缺陷二值化轮廓图,计算缺陷区域的个数和总面积。
|
|
||||||
#
|
|
||||||
# 参数:
|
|
||||||
# defect_img (numpy.ndarray): 番茄区域缺陷二值化轮廓图,背景为黑色,番茄区域为白色,缺陷区域为黑色连通域。
|
|
||||||
#
|
|
||||||
# 返回:
|
|
||||||
# tuple: (缺陷区域个数, 缺陷区域像素面积,缺陷像素总面积)
|
|
||||||
# """
|
|
||||||
# # 检查输入是否为空
|
|
||||||
# if defect_img is None or defect_img.any() == 0:
|
|
||||||
# return (0, 0)
|
|
||||||
#
|
|
||||||
# nb_components, output, stats, centroids = cv2.connectedComponentsWithStats(defect_img, connectivity=4)
|
|
||||||
# max_area = max(stats[i, cv2.CC_STAT_AREA] for i in range(1, nb_components))
|
|
||||||
# areas = []
|
|
||||||
# for i in range(1, nb_components):
|
|
||||||
# area = stats[i, cv2.CC_STAT_AREA]
|
|
||||||
# if area != max_area:
|
|
||||||
# areas.append(area)
|
|
||||||
# number_defects = len(areas)
|
|
||||||
# total_pixels = sum(areas)
|
|
||||||
# return number_defects, total_pixels
|
|
||||||
|
|
||||||
class Data_processing:
|
class Data_processing:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def contour_process(self, image_array):
|
def contour_process(self, image_array):
|
||||||
|
# 检查图像是否为空或全黑
|
||||||
|
if image_array is None or image_array.size == 0 or np.all(image_array == 0):
|
||||||
|
# print("输入的图像为空或全黑,返回一个全黑图像。")
|
||||||
|
return np.zeros_like(image_array) if image_array is not None else np.zeros((100, 100), dtype=np.uint8)
|
||||||
# 应用中值滤波
|
# 应用中值滤波
|
||||||
image_filtered = cv2.medianBlur(image_array, 5)
|
image_filtered = cv2.medianBlur(image_array, 5)
|
||||||
|
|
||||||
@ -495,6 +466,7 @@ class Data_processing:
|
|||||||
# 设置 S-L 通道阈值并处理图像
|
# 设置 S-L 通道阈值并处理图像
|
||||||
threshold_s_l = 180
|
threshold_s_l = 180
|
||||||
threshold_fore_g_r_t = 20
|
threshold_fore_g_r_t = 20
|
||||||
|
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
|
||||||
s_l = tomato.extract_s_l(img)
|
s_l = tomato.extract_s_l(img)
|
||||||
thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l)
|
thresholded_s_l = tomato.threshold_segmentation(s_l, threshold_s_l)
|
||||||
new_bin_img = tomato.largest_connected_component(thresholded_s_l)
|
new_bin_img = tomato.largest_connected_component(thresholded_s_l)
|
||||||
@ -526,6 +498,7 @@ class Data_processing:
|
|||||||
# 创建PassionFruit类的实例
|
# 创建PassionFruit类的实例
|
||||||
pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta)
|
pf = Passion_fruit(hue_value=hue_value, hue_delta=hue_delta, value_target=value_target, value_delta=value_delta)
|
||||||
|
|
||||||
|
img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
|
||||||
hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||||
combined_mask = pf.create_mask(hsv_image)
|
combined_mask = pf.create_mask(hsv_image)
|
||||||
combined_mask = pf.apply_morphology(combined_mask)
|
combined_mask = pf.apply_morphology(combined_mask)
|
||||||
@ -540,3 +513,221 @@ class Data_processing:
|
|||||||
diameter = (long_axis + short_axis) / 2
|
diameter = (long_axis + short_axis) / 2
|
||||||
|
|
||||||
return diameter, weigth, number_defects, total_pixels, rp
|
return diameter, weigth, number_defects, total_pixels, rp
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
'''
|
||||||
|
BasicBlock for ResNet18 and ResNet34
|
||||||
|
|
||||||
|
'''
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
|
||||||
|
kernel_size=3, stride=stride, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channel)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
|
||||||
|
kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channel)
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
"""
|
||||||
|
注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
|
||||||
|
但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
|
||||||
|
这么做的好处是能够在top1上提升大概0.5%的准确率。
|
||||||
|
可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
|
||||||
|
"""
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, in_channel, out_channel, stride=1, downsample=None,
|
||||||
|
groups=1, width_per_group=64):
|
||||||
|
super(Bottleneck, self).__init__()
|
||||||
|
|
||||||
|
width = int(out_channel * (width_per_group / 64.)) * groups
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
|
||||||
|
kernel_size=1, stride=1, bias=False) # squeeze channels
|
||||||
|
self.bn1 = nn.BatchNorm2d(width)
|
||||||
|
# -----------------------------------------
|
||||||
|
self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
|
||||||
|
kernel_size=3, stride=stride, bias=False, padding=1)
|
||||||
|
self.bn2 = nn.BatchNorm2d(width)
|
||||||
|
# -----------------------------------------
|
||||||
|
self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
|
||||||
|
kernel_size=1, stride=1, bias=False) # unsqueeze channels
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
identity = x
|
||||||
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv3(out)
|
||||||
|
out = self.bn3(out)
|
||||||
|
|
||||||
|
out += identity
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
'''
|
||||||
|
ResNet18 and ResNet34
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
block,
|
||||||
|
blocks_num,
|
||||||
|
num_classes=1000,
|
||||||
|
include_top=True,
|
||||||
|
groups=1,
|
||||||
|
width_per_group=64):
|
||||||
|
super(ResNet, self).__init__()
|
||||||
|
self.include_top = include_top
|
||||||
|
self.in_channel = 64
|
||||||
|
|
||||||
|
self.groups = groups
|
||||||
|
self.width_per_group = width_per_group
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
|
||||||
|
padding=3, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
self.layer1 = self._make_layer(block, 64, blocks_num[0])
|
||||||
|
self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
|
||||||
|
self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
|
||||||
|
if self.include_top:
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
|
||||||
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
|
||||||
|
def _make_layer(self, block, channel, block_num, stride=1):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.in_channel != channel * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||||
|
nn.BatchNorm2d(channel * block.expansion))
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append(block(self.in_channel,
|
||||||
|
channel,
|
||||||
|
downsample=downsample,
|
||||||
|
stride=stride,
|
||||||
|
groups=self.groups,
|
||||||
|
width_per_group=self.width_per_group))
|
||||||
|
self.in_channel = channel * block.expansion
|
||||||
|
|
||||||
|
for _ in range(1, block_num):
|
||||||
|
layers.append(block(self.in_channel,
|
||||||
|
channel,
|
||||||
|
groups=self.groups,
|
||||||
|
width_per_group=self.width_per_group))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.maxpool(x)
|
||||||
|
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
|
if self.include_top:
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
x = self.fc(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def resnet18(num_classes=1000, include_top=True):
|
||||||
|
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
|
||||||
|
|
||||||
|
def resnetzy(num_classes=1000, include_top=True):
|
||||||
|
return ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifier:
|
||||||
|
'''
|
||||||
|
图像分类器,用于加载预训练的 ResNet 模型并进行图像分类。
|
||||||
|
'''
|
||||||
|
def __init__(self, model_path, class_indices_path, device=None):
|
||||||
|
if device is None:
|
||||||
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
else:
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# 加载类别索引
|
||||||
|
assert os.path.exists(class_indices_path), f"File: '{class_indices_path}' does not exist."
|
||||||
|
with open(class_indices_path, "r") as json_file:
|
||||||
|
self.class_indict = json.load(json_file)
|
||||||
|
|
||||||
|
# 创建模型并加载权重
|
||||||
|
self.model = resnetzy(num_classes=len(self.class_indict)).to(self.device)
|
||||||
|
assert os. path.exists(model_path), f"File: '{model_path}' does not exist."
|
||||||
|
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# 设置图像转换
|
||||||
|
self.transform = transforms.Compose([
|
||||||
|
transforms.Resize(256),
|
||||||
|
transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
def predict(self, image_np):
|
||||||
|
'''
|
||||||
|
对图像进行分类预测。
|
||||||
|
:param image_np:
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
# 将numpy数组转换为图像
|
||||||
|
image = Image.fromarray(image_np.astype('uint8'), 'RGB')
|
||||||
|
image = self.transform(image).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = self.model(image).cpu()
|
||||||
|
predict = torch.softmax(output, dim=1)
|
||||||
|
predict_cla = torch.argmax(predict, dim=1).numpy()
|
||||||
|
|
||||||
|
# return self.class_indict[str(predict_cla[0])]
|
||||||
|
return predict_cla[0]
|
||||||
@ -10,7 +10,7 @@ import os
|
|||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from root_dir import ROOT_DIR
|
from root_dir import ROOT_DIR
|
||||||
from classifer import Spec_predict, Data_processing
|
from classifer import Spec_predict, Data_processing, ImageClassifier
|
||||||
import logging
|
import logging
|
||||||
from utils import Pipe
|
from utils import Pipe
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -76,7 +76,7 @@ def process_data(cmd: str, images: list, spec: any, dp: Data_processing, pipe: P
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
def main(is_debug=False):
|
def main(is_debug=False):
|
||||||
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'))
|
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'tomato.log'), encoding='utf-8')
|
||||||
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
||||||
@ -84,12 +84,14 @@ def main(is_debug=False):
|
|||||||
handlers=[file_handler, console_handler],
|
handlers=[file_handler, console_handler],
|
||||||
level=logging.DEBUG)
|
level=logging.DEBUG)
|
||||||
detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib')
|
detector = Spec_predict(ROOT_DIR/'models'/'passion_fruit_2.joblib')
|
||||||
|
classifier = ImageClassifier(ROOT_DIR/'models'/'resnet18_0616.pth', ROOT_DIR/'models'/'class_indices.json')
|
||||||
dp = Data_processing()
|
dp = Data_processing()
|
||||||
|
|
||||||
_ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16))
|
_ = detector.predict(np.ones((30, 30, 224), dtype=np.uint16))
|
||||||
_, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp'))
|
_ = classifier.predict(np.ones((224, 224, 3), dtype=np.uint8))
|
||||||
_, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp'))
|
# _, _, _, _, _ =dp.analyze_tomato(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\tomato_img\bad\71.bmp'))
|
||||||
print('初始化完成')
|
# _, _, _, _, _ = dp.analyze_passion_fruit(cv2.imread(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\data\passion_fruit_img\38.bmp'))
|
||||||
|
print('系统初始化完成')
|
||||||
|
|
||||||
rgb_receive_name = r'\\.\pipe\rgb_receive'
|
rgb_receive_name = r'\\.\pipe\rgb_receive'
|
||||||
rgb_send_name = r'\\.\pipe\rgb_send'
|
rgb_send_name = r'\\.\pipe\rgb_send'
|
||||||
@ -106,7 +108,6 @@ def main(is_debug=False):
|
|||||||
if cmd == 'YR':
|
if cmd == 'YR':
|
||||||
break # 当接收到的不是预热命令时,结束预热循环
|
break # 当接收到的不是预热命令时,结束预热循环
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
images = []
|
images = []
|
||||||
cmd = None
|
cmd = None
|
||||||
@ -115,19 +116,28 @@ def main(is_debug=False):
|
|||||||
data = pipe.receive_rgb_data(rgb_receive)
|
data = pipe.receive_rgb_data(rgb_receive)
|
||||||
end_time10 = time.time()
|
end_time10 = time.time()
|
||||||
print(f'接收一份数据时间:{end_time10 - start_time1}秒')
|
print(f'接收一份数据时间:{end_time10 - start_time1}秒')
|
||||||
|
|
||||||
start_time11 = time.time()
|
start_time11 = time.time()
|
||||||
cmd, img = pipe.parse_img(data)
|
cmd, img = pipe.parse_img(data)
|
||||||
end_time1 = time.time()
|
end_time1 = time.time()
|
||||||
print(f'处理一份数据时间:{end_time1 - start_time11}秒')
|
print(f'处理一份数据时间:{end_time1 - start_time11}秒')
|
||||||
print(f'接收1张图时间:{end_time1 - start_time1}秒')
|
print(f'接收一张图时间:{end_time1 - start_time1}秒')
|
||||||
# print(cmd, img.shape)
|
|
||||||
# #打印img的数据类型
|
# 使用分类器进行预测
|
||||||
# print(img.dtype)
|
prediction = classifier.predict(img)
|
||||||
|
print(f'预测结果:{prediction}')
|
||||||
|
if prediction == 1:
|
||||||
images.append(img)
|
images.append(img)
|
||||||
# print(len(images))
|
else:
|
||||||
if cmd not in ['TO', 'PF', 'YR']:
|
response = pipe.send_data(cmd='KO', brix=0, diameter=0, green_percentage=0, weigth=0, defect_num=0,
|
||||||
|
total_defect_area=0, rp=None)
|
||||||
|
print("图像中无果,跳过此图像")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cmd not in ['TO', 'PF', 'YR', 'KO']:
|
||||||
logging.error(f'错误指令,指令为{cmd}')
|
logging.error(f'错误指令,指令为{cmd}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
spec = None
|
spec = None
|
||||||
if cmd == 'PF':
|
if cmd == 'PF':
|
||||||
start_time2 = time.time()
|
start_time2 = time.time()
|
||||||
@ -135,17 +145,21 @@ def main(is_debug=False):
|
|||||||
_, spec = pipe.parse_spec(spec_data)
|
_, spec = pipe.parse_spec(spec_data)
|
||||||
end_time2 = time.time()
|
end_time2 = time.time()
|
||||||
print(f'接收光谱数据时间:{end_time2 - start_time2}秒')
|
print(f'接收光谱数据时间:{end_time2 - start_time2}秒')
|
||||||
# print(spec.shape)
|
|
||||||
start_time3 = time.time()
|
start_time3 = time.time()
|
||||||
|
if images: # 确保images不为空
|
||||||
response = process_data(cmd, images, spec, dp, pipe, detector)
|
response = process_data(cmd, images, spec, dp, pipe, detector)
|
||||||
end_time3 = time.time()
|
end_time3 = time.time()
|
||||||
print(f'处理时间:{end_time3 - start_time3}秒')
|
print(f'处理时间:{end_time3 - start_time3}秒')
|
||||||
end_time = time.time()
|
|
||||||
print(f'全流程时间:{end_time - start_time}秒')
|
|
||||||
if response:
|
if response:
|
||||||
logging.info(f'处理成功,响应为: {response}')
|
logging.info(f'处理成功,响应为: {response}')
|
||||||
else:
|
else:
|
||||||
logging.error('处理失败')
|
logging.error('处理失败')
|
||||||
|
else:
|
||||||
|
print("没有有效的图像进行处理")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f'全流程时间:{end_time - start_time}秒')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@ -6,11 +6,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import win32file
|
import win32file
|
||||||
import win32pipe
|
import win32pipe
|
||||||
import time
|
import time
|
||||||
@ -164,26 +160,31 @@ class Pipe:
|
|||||||
return cmd, spec
|
return cmd, spec
|
||||||
|
|
||||||
def send_data(self,cmd:str, brix, green_percentage, weigth, diameter, defect_num, total_defect_area, rp):
|
def send_data(self,cmd:str, brix, green_percentage, weigth, diameter, defect_num, total_defect_area, rp):
|
||||||
# start_time = time.time()
|
'''
|
||||||
#
|
发送数据
|
||||||
# rp1 = Image.fromarray(rp.astype(np.uint8))
|
:param cmd:
|
||||||
# # cv2.imwrite('rp1.bmp', rp1)
|
:param brix:
|
||||||
#
|
:param green_percentage:
|
||||||
# # 将 Image 对象保存到 BytesIO 流中
|
:param weigth:
|
||||||
# img_bytes = io.BytesIO()
|
:param diameter:
|
||||||
# rp1.save(img_bytes, format='BMP')
|
:param defect_num:
|
||||||
# img_bytes = img_bytes.getvalue()
|
:param total_defect_area:
|
||||||
|
:param rp:
|
||||||
# width = rp.shape[0]
|
:return:
|
||||||
# height = rp.shape[1]
|
'''
|
||||||
# print(width, height)
|
|
||||||
# img_bytes = rp.tobytes()
|
|
||||||
# length = len(img_bytes) + 18
|
|
||||||
# print(length)
|
|
||||||
# length = length.to_bytes(4, byteorder='big')
|
|
||||||
# width = width.to_bytes(2, byteorder='big')
|
|
||||||
# height = height.to_bytes(2, byteorder='big')
|
|
||||||
cmd = cmd.strip().upper()
|
cmd = cmd.strip().upper()
|
||||||
|
if cmd == 'KO':
|
||||||
|
cmd_ko = cmd.encode('ascii')
|
||||||
|
length = (2).to_bytes(4, byteorder='big') # 因为只有KO两个字节,所以长度是2
|
||||||
|
send_message = length + cmd_ko
|
||||||
|
try:
|
||||||
|
win32file.WriteFile(self.rgb_send, send_message)
|
||||||
|
print('KO消息发送成功')
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'发送KO指令失败,错误类型:{e}')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
cmd_type = 'RE'
|
cmd_type = 'RE'
|
||||||
cmd_re = cmd_type.upper().encode('ascii')
|
cmd_re = cmd_type.upper().encode('ascii')
|
||||||
img = np.asarray(rp, dtype=np.uint8) # 将图像转换为 NumPy 数组
|
img = np.asarray(rp, dtype=np.uint8) # 将图像转换为 NumPy 数组
|
||||||
@ -205,19 +206,19 @@ class Pipe:
|
|||||||
weigth = weigth.to_bytes(1, byteorder='big')
|
weigth = weigth.to_bytes(1, byteorder='big')
|
||||||
send_message = length + cmd_re + brix + gp + diameter + weigth + defect_num + total_defect_area + height + width + img_bytes
|
send_message = length + cmd_re + brix + gp + diameter + weigth + defect_num + total_defect_area + height + width + img_bytes
|
||||||
elif cmd == 'PF':
|
elif cmd == 'PF':
|
||||||
brix = int(brix.item() * 1000).to_bytes(2, byteorder='big')
|
brix = int(brix * 1000).to_bytes(2, byteorder='big')
|
||||||
gp = 0
|
gp = 0
|
||||||
gp = gp.to_bytes(1, byteorder='big')
|
gp = gp.to_bytes(1, byteorder='big')
|
||||||
weigth = weigth.to_bytes(1, byteorder='big')
|
weigth = weigth.to_bytes(1, byteorder='big')
|
||||||
send_message = length + cmd_re + brix + gp + diameter + weigth + defect_num + total_defect_area + height + width + img_bytes
|
send_message = length + cmd_re + brix + gp + diameter + weigth + defect_num + total_defect_area + height + width + img_bytes
|
||||||
try:
|
try:
|
||||||
win32file.WriteFile(self.rgb_send, send_message)
|
win32file.WriteFile(self.rgb_send, send_message)
|
||||||
time.sleep(0.01)
|
# time.sleep(0.01)
|
||||||
print('发送成功')
|
print('发送成功')
|
||||||
print(len(send_message), len(img_bytes))
|
# print(len(send_message), len(img_bytes))
|
||||||
# print(len(send_message))
|
# print(len(send_message))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f'发送完成指令失败,错误类型:{e}')
|
logging.error(f'发送指令失败,错误类型:{e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# end_time = time.time()
|
# end_time = time.time()
|
||||||
|
|||||||
@ -60,7 +60,7 @@ def predict(model, data):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 加载模型
|
# 加载模型
|
||||||
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit.joblib')
|
model = load_model(r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\models\passion_fruit_2.joblib')
|
||||||
|
|
||||||
# 读取数据
|
# 读取数据
|
||||||
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
|
directory = r'D:\project\supermachine--tomato-passion_fruit\20240529RGBtest3\xs\光谱数据3030'
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user