mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
模型部分clean重构
This commit is contained in:
parent
75ab881bc5
commit
4b56560cc8
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
20
config.py
20
config.py
@ -4,22 +4,26 @@ import numpy as np
|
||||
|
||||
class Config:
|
||||
# 文件相关参数
|
||||
nRows, nCols, nBands, threshold, = 256, 1024, 22, 3
|
||||
nrgbRows, nrgbCols, nrgbBands, rgb_threshold = 1024, 4096, 3, 2
|
||||
nRows, nCols, nBands, spec_size_threshold, = 256, 1024, 22, 3
|
||||
nRgbRows, nRgbCols, nRgbBands, rgb_size_threshold = 1024, 4096, 3, 4
|
||||
|
||||
# 需要设置的谱段等参数
|
||||
selected_bands = [127, 201, 202, 294]
|
||||
bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
|
||||
294]
|
||||
bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
|
||||
211, 212, 213, 214, 215, 216, 217, 218, 219, 294]
|
||||
is_yellow_min = np.array([0.10167048, 0.1644719, 0.1598884, 0.31534621])
|
||||
is_yellow_max = np.array([0.212984, 0.25896924, 0.26509268, 0.51943593])
|
||||
is_black_threshold = np.asarray([0.1369, 0.1472, 0.1439, 0.1814])
|
||||
black_yellow_bands = [0, 2, 3, 21]
|
||||
green_bands = [i for i in range(1, 21)]
|
||||
|
||||
# 机器学习模型相关参数
|
||||
# 光谱模型参数
|
||||
blk_size = 4
|
||||
pixel_model_path = r"./models/dt.p"
|
||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
|
||||
rgb_model_path = r"./models/dt_2022-07-20_14-40.model"
|
||||
rgb_tobacco_model_path = r"models/beijing_dt_2022-07-21_16-44.model"
|
||||
rgb_background_model_path = r"models/tobacco_dt_2022-07-21_16-30.model"
|
||||
|
||||
# rgb模型参数
|
||||
rgb_tobacco_model_path = r"models/tobacco_dt_2022-07-21_16-30.model"
|
||||
rgb_background_model_path = r"models/beijing_dt_2022-07-21_16-44.model"
|
||||
threshold_low, threshold_high = 5, 255
|
||||
threshold_s = 175
|
||||
|
||||
109
main.py
109
main.py
@ -1,20 +1,20 @@
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
from config import Config
|
||||
from models import ManualTree, AnonymousColorDetector
|
||||
from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
|
||||
import cv2
|
||||
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
threshold = Config.threshold
|
||||
rgb_threshold = Config.rgb_threshold
|
||||
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
||||
|
||||
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
|
||||
pixel_model_path=Config.pixel_model_path)
|
||||
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||
background_model_path=Config.rgb_background_model_path)
|
||||
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||
total_rgb = Config.nrgbRows * Config.nrgbCols * Config.nrgbBands * 1 # int型变量
|
||||
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||
if not os.access(img_fifo_path, os.F_OK):
|
||||
os.mkfifo(img_fifo_path, 0o777)
|
||||
if not os.access(mask_fifo_path, os.F_OK):
|
||||
@ -25,6 +25,67 @@ def main():
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||
data = os.read(fd_img, total_len)
|
||||
# 读取(开启一个管道)
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print("[INFO] Get threshold: ", threshold)
|
||||
continue
|
||||
else:
|
||||
data_total = data
|
||||
rgb_data = os.read(fd_rgb, total_rgb)
|
||||
if len(rgb_data) < 3:
|
||||
rgb_threshold = int(float(rgb_data))
|
||||
print(rgb_threshold)
|
||||
continue
|
||||
else:
|
||||
rgb_data_total = rgb_data
|
||||
os.close(fd_img)
|
||||
os.close(fd_rgb)
|
||||
# 识别
|
||||
t1 = time.time()
|
||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0,
|
||||
2,
|
||||
1)
|
||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
# 光谱识别
|
||||
mask = spec_detector.predict(img_data)
|
||||
# rgb识别
|
||||
mask_rgb = rgb_detector.predict(rgb_data)
|
||||
# 结果合并
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
print(f'rgb len = {len(rgb_data)}')
|
||||
|
||||
# 写出
|
||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||
os.write(fd_mask, mask_result.tobytes())
|
||||
os.close(fd_mask)
|
||||
t3 = time.time()
|
||||
print(f'total time is:{t3 - t1}')
|
||||
|
||||
|
||||
def save_main():
|
||||
threshold = Config.spec_size_threshold
|
||||
rgb_threshold = Config.rgb_size_threshold
|
||||
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
||||
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||
if not os.access(img_fifo_path, os.F_OK):
|
||||
os.mkfifo(img_fifo_path, 0o777)
|
||||
if not os.access(mask_fifo_path, os.F_OK):
|
||||
os.mkfifo(mask_fifo_path, 0o777)
|
||||
if not os.access(rgb_fifo_path, os.F_OK):
|
||||
os.mkfifo(rgb_fifo_path, 0o777)
|
||||
img_list = []
|
||||
idx = 0
|
||||
while idx <= 30:
|
||||
idx += 1
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||
data = os.read(fd_img, total_len)
|
||||
|
||||
# 读取(开启一个管道)
|
||||
if len(data) < 3:
|
||||
@ -40,36 +101,40 @@ def main():
|
||||
continue
|
||||
else:
|
||||
rgb_data_total = rgb_data
|
||||
|
||||
os.close(fd_img)
|
||||
os.close(fd_rgb)
|
||||
|
||||
# 识别
|
||||
t1 = time.time()
|
||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands,
|
||||
-1)).transpose(0, 2, 1)
|
||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8)
|
||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
|
||||
transpose(0, 2, 1)
|
||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
img_list.append((rgb_data.copy(), img_data.copy()))
|
||||
|
||||
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = manual_tree.blk_predict(data=img_data)
|
||||
rgb_data = tobacco_detector.pretreatment(rgb_data)
|
||||
rgb_predict_result = 1 - (
|
||||
background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data)))
|
||||
# print(rgb_data.shape)
|
||||
rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
|
||||
threshold_high=Config.threshold_high) |
|
||||
tobacco_detector.swell(tobacco_detector.predict(rgb_data,
|
||||
threshold_low=Config.threshold_low,
|
||||
threshold_high=Config.threshold_high)))
|
||||
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
||||
.sum(axis=1)
|
||||
mask_rgb[mask_rgb <= rgb_threshold] = 0
|
||||
mask_rgb[mask_rgb > rgb_threshold] = 1
|
||||
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
||||
.sum(axis=1)
|
||||
mask[mask <= threshold] = 0
|
||||
mask[mask > threshold] = 1
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
# mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
mask_result = mask_rgb
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
# print(f'shibie time is:{t2 - t1}')
|
||||
print(f'rgb len = {len(rgb_data)}')
|
||||
|
||||
# 写出
|
||||
@ -78,6 +143,14 @@ def main():
|
||||
os.close(fd_mask)
|
||||
t3 = time.time()
|
||||
print(f'total time is:{t3 - t1}')
|
||||
i = 0
|
||||
print("Stop Serving")
|
||||
for img in img_list:
|
||||
print(f"writing img {i}...")
|
||||
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
|
||||
np.save(f'./{i}.npy', img[1])
|
||||
i += 1
|
||||
print("save success")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
88
main_test.py
88
main_test.py
@ -10,20 +10,27 @@ import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from models import Detector, AnonymousColorDetector
|
||||
from utils import read_labeled_img
|
||||
from config import Config
|
||||
from models import Detector, AnonymousColorDetector, ManualTree
|
||||
from utils import read_labeled_img, size_threshold
|
||||
|
||||
|
||||
def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=None, test_model=False):
|
||||
def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False):
|
||||
"""
|
||||
虚拟读图测试程序
|
||||
|
||||
:param detector: 杂质探测器,需要继承Detector类
|
||||
:param test_img: 测试图像,rgb格式的图片或者路径
|
||||
:param test_img_dir: 测试图像文件夹
|
||||
:param test_model: 是否进行模型约束性测试
|
||||
:param test_spectra: 是否测试光谱
|
||||
:param test_rgb: 是否测试rgb
|
||||
:return:
|
||||
"""
|
||||
if (test_img is not None) or (test_img_dir is not None):
|
||||
threshold = Config.spec_size_threshold
|
||||
rgb_threshold = Config.rgb_size_threshold
|
||||
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
||||
if test_img is not None:
|
||||
if isinstance(test_img, str):
|
||||
img = cv2.imread(test_img)[:, :, ::-1]
|
||||
@ -31,42 +38,45 @@ def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=N
|
||||
img = test_img
|
||||
else:
|
||||
raise TypeError("test img should be np.ndarray or str")
|
||||
t1 = time.time()
|
||||
img = cv2.resize(img, (1024, 256))
|
||||
t2 = time.time()
|
||||
result = 1 - detector.predict(img)
|
||||
t3 = time.time()
|
||||
fig, axs = plt.subplots(3, 1)
|
||||
axs[0].imshow(img)
|
||||
axs[1].imshow(result)
|
||||
mask_color = np.zeros_like(img)
|
||||
mask_color[result > 0] = (0, 0, 255)
|
||||
result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
|
||||
axs[2].imshow(result_show)
|
||||
axs[0].set_title(
|
||||
f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms')
|
||||
plt.show()
|
||||
if test_img_dir is not None:
|
||||
image_names = os.listdir(test_img_dir)
|
||||
image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')]
|
||||
for image_name in image_names:
|
||||
img = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
|
||||
if test_model:
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||
ground_truth = dataset['yangeng']
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
rgb_data = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
|
||||
# 识别
|
||||
t1 = time.time()
|
||||
if test_spectra:
|
||||
# spectra part
|
||||
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = manual_tree.blk_predict(data=img_data)
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask_spec = size_threshold(mask, Config.blk_size, threshold)
|
||||
if test_rgb:
|
||||
# rgb part
|
||||
rgb_data = tobacco_detector.pretreatment(rgb_data)
|
||||
background = background_detector.predict(rgb_data)
|
||||
tobacco = tobacco_detector.predict(rgb_data)
|
||||
tobacco_d = tobacco_detector.swell(tobacco)
|
||||
rgb_predict_result = 1 - (background | tobacco_d)
|
||||
mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold)
|
||||
fig, axs = plt.subplots(5, 1, figsize=(12, 10), constrained_layout=True)
|
||||
axs[0].imshow(rgb_data)
|
||||
axs[0].set_title("rgb raw data")
|
||||
axs[1].imshow(background)
|
||||
axs[1].set_title("background")
|
||||
axs[2].imshow(tobacco)
|
||||
axs[2].set_title("tobacco")
|
||||
axs[3].imshow(rgb_predict_result)
|
||||
axs[3].set_title("1 - (background + dilate(tobacco))")
|
||||
axs[4].imshow(mask_rgb)
|
||||
axs[4].set_title("final mask")
|
||||
plt.show()
|
||||
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
# mask_result = rgb_predict_result
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
print(f'rgb len = {len(rgb_data)}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model')
|
||||
virtual_main(model,
|
||||
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
|
||||
test_model=True)
|
||||
virtual_main(model,
|
||||
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
|
||||
test_model=True)
|
||||
virtual_main(model,
|
||||
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
|
||||
test_model=True)
|
||||
pony_run(test_img_dir=r'E:\zhouchao\725data', test_rgb=True)
|
||||
|
||||
157
models.py
157
models.py
@ -9,17 +9,19 @@ import pickle
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import tqdm
|
||||
from scipy.ndimage import binary_dilation
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from config import Config
|
||||
from utils import lab_scatter, read_labeled_img
|
||||
from tqdm import tqdm
|
||||
from utils import lab_scatter, read_labeled_img, size_threshold
|
||||
|
||||
from elm import ELM
|
||||
deploy = False
|
||||
if not deploy:
|
||||
print("Training env")
|
||||
from tqdm import tqdm
|
||||
from elm import ELM
|
||||
|
||||
|
||||
class Detector(object):
|
||||
@ -87,7 +89,7 @@ class AnonymousColorDetector(Detector):
|
||||
y_predict = self.model.predict(x_val)
|
||||
print(classification_report(y_true=y_val, y_pred=y_predict))
|
||||
|
||||
def predict(self, x, threshold_low=10, threshold_high=170):
|
||||
def predict(self, x, threshold_low=5, threshold_high=255):
|
||||
"""
|
||||
输入rgb彩色图像
|
||||
|
||||
@ -98,9 +100,11 @@ class AnonymousColorDetector(Detector):
|
||||
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
|
||||
x = x.reshape(w * h, -1)
|
||||
mask = (threshold_low < x[:, 0]) & (x[:, 0] < threshold_high)
|
||||
mask_result = self.model.predict(x[mask])
|
||||
result = np.ones((w * h,))
|
||||
result[mask] = mask_result
|
||||
result = np.ones((w * h,), dtype=np.uint8)
|
||||
|
||||
if np.any(mask):
|
||||
mask_result = self.model.predict(x[mask])
|
||||
result[mask] = mask_result
|
||||
return result.reshape(h, w)
|
||||
|
||||
@staticmethod
|
||||
@ -133,6 +137,12 @@ class AnonymousColorDetector(Detector):
|
||||
bar.close()
|
||||
return negative_samples
|
||||
|
||||
def pretreatment(self, x):
|
||||
return cv2.resize(x, (1024, 256))
|
||||
|
||||
def swell(self, x):
|
||||
return cv2.dilate(x, kernel=np.ones((3, 3), np.uint8))
|
||||
|
||||
def save(self):
|
||||
path = datetime.datetime.now().strftime(f"{self.model_type}_%Y-%m-%d_%H-%M.model")
|
||||
with open(path, 'wb') as f:
|
||||
@ -292,6 +302,137 @@ class BlkModel:
|
||||
self.rfc = pickle.load(f)
|
||||
|
||||
|
||||
class RgbDetector(Detector):
|
||||
def __init__(self, tobacco_model_path, background_model_path):
|
||||
self.background_detector = None
|
||||
self.tobacco_detector = None
|
||||
self.load(tobacco_model_path, background_model_path)
|
||||
|
||||
def predict(self, rgb_data):
|
||||
rgb_data = self.tobacco_detector.pretreatment(rgb_data) # resize to the required size
|
||||
background = self.background_detector.predict(rgb_data)
|
||||
tobacco = self.tobacco_detector.predict(rgb_data)
|
||||
tobacco_d = self.tobacco_detector.swell(tobacco) # dilate the tobacco to remove the tobacco edge error
|
||||
high_s = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2HSV)[..., 1] > Config.threshold_s
|
||||
non_tobacco_or_background = 1 - (background | tobacco_d) # 既非烟梗也非背景的区域
|
||||
rgb_predict_result = high_s | non_tobacco_or_background # 高饱和度区域或者是双非区域都是杂质
|
||||
mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold) # 杂质大小限制,超过大小的才打
|
||||
return mask_rgb
|
||||
|
||||
def load(self, tobacco_model_path, background_model_path):
|
||||
self.tobacco_detector = AnonymousColorDetector(tobacco_model_path)
|
||||
self.background_detector = AnonymousColorDetector(background_model_path)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def fit(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class SpecDetector(Detector):
|
||||
# 初始化机器学习像素模型、深度学习像素模型、分块模型
|
||||
def __init__(self, blk_model_path, pixel_model_path):
|
||||
self.blk_model = None
|
||||
self.pixel_model_ml = None
|
||||
self.load(blk_model_path, pixel_model_path)
|
||||
|
||||
def load(self, blk_model_path, pixel_model_path):
|
||||
self.pixel_model_ml = PixelModelML(pixel_model_path)
|
||||
self.blk_model = BlkModel(blk_model_path)
|
||||
|
||||
def predict(self, img_data):
|
||||
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = self.blk_predict(data=img_data)
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold)
|
||||
return mask
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def fit(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
# 区分烟梗和非黄色且非背景的杂质
|
||||
@staticmethod
|
||||
def is_yellow(features):
|
||||
features = features.reshape((Config.nRows * Config.nCols), len(Config.selected_bands))
|
||||
sum_x = features.sum(axis=1)[..., np.newaxis]
|
||||
rate = features / sum_x
|
||||
mask = ((rate < Config.is_yellow_max) & (rate > Config.is_yellow_min))
|
||||
mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
|
||||
return mask
|
||||
|
||||
# 区分背景和黄色杂质
|
||||
@staticmethod
|
||||
def is_black(feature, threshold):
|
||||
feature = feature.reshape((Config.nRows * Config.nCols), feature.shape[2])
|
||||
mask = (feature <= threshold)
|
||||
mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
|
||||
return mask
|
||||
|
||||
# 预测出烟梗的mask
|
||||
def predict_tobacco(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
预测出烟梗的mask
|
||||
:param x: 图像数据,形状是 nRows x nCols x nBands
|
||||
:return: bool类型的mask,是否为烟梗, True为烟梗
|
||||
"""
|
||||
black_res = self.is_black(x[..., Config.black_yellow_bands], Config.is_black_threshold)
|
||||
yellow_res = self.is_yellow(x[..., Config.black_yellow_bands])
|
||||
yellow_things = (~black_res) & yellow_res
|
||||
x_yellow = x[yellow_things, ...]
|
||||
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands])
|
||||
yellow_things[yellow_things] = tobacco
|
||||
return yellow_things
|
||||
|
||||
# 预测出杂质的机器学习像素模型
|
||||
def pixel_predict_ml_dilation(self, data, iteration) -> np.ndarray:
|
||||
"""
|
||||
预测出杂质的位置mask
|
||||
:param data: 图像数据,形状是 nRows x nCols x nBands
|
||||
:param iteration: 膨胀的次数
|
||||
:return: bool类型的mask,是否为杂质, True为杂质
|
||||
"""
|
||||
black_res = self.is_black(data[..., Config.black_yellow_bands], Config.is_black_threshold)
|
||||
yellow_res = self.is_yellow(data[..., Config.black_yellow_bands])
|
||||
# non_yellow_things为异色杂质
|
||||
non_yellow_things = (~black_res) & (~yellow_res)
|
||||
# yellow_things为黄色物体(烟梗+杂质)
|
||||
yellow_things = (~black_res) & yellow_res
|
||||
# x_yellow为挑出的黄色物体
|
||||
x_yellow = data[yellow_things, ...]
|
||||
if x_yellow.shape[0] == 0:
|
||||
return non_yellow_things
|
||||
else:
|
||||
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5
|
||||
|
||||
non_yellow_things[yellow_things] = ~tobacco
|
||||
# 杂质mask中将背景赋值为0,将杂质赋值为1
|
||||
non_yellow_things = non_yellow_things + 0
|
||||
|
||||
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
|
||||
yellow_things[yellow_things] = tobacco
|
||||
yellow_things = yellow_things + 0
|
||||
yellow_things = binary_dilation(yellow_things, iterations=iteration)
|
||||
yellow_things = yellow_things + 0
|
||||
yellow_things[yellow_things == 1] = 2
|
||||
|
||||
# 将杂质mask和烟梗mask相加,得到的mask中含有0(背景),1(杂质),2(烟梗),3(膨胀后的烟梗与杂质相加的部分)
|
||||
mask = non_yellow_things + yellow_things
|
||||
mask[mask == 0] = False
|
||||
mask[mask == 1] = True
|
||||
mask[mask == 2] = False
|
||||
mask[mask == 3] = False
|
||||
return mask
|
||||
|
||||
# 预测出杂质的分块模型
|
||||
def blk_predict(self, data):
|
||||
blk_result_array = self.blk_model.predict(data)
|
||||
return blk_result_array
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
|
||||
8
utils.py
8
utils.py
@ -96,6 +96,14 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
||||
plt.show()
|
||||
|
||||
|
||||
def size_threshold(img, blk_size, threshold):
|
||||
mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \
|
||||
reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1)
|
||||
mask[mask <= threshold] = 0
|
||||
mask[mask > threshold] = 1
|
||||
return mask
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
|
||||
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user