模型部分clean重构

This commit is contained in:
FEIJINTI 2022-07-26 17:23:25 +08:00
parent 75ab881bc5
commit 4b56560cc8
7 changed files with 500 additions and 131 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,22 +4,26 @@ import numpy as np
class Config: class Config:
# 文件相关参数 # 文件相关参数
nRows, nCols, nBands, threshold, = 256, 1024, 22, 3 nRows, nCols, nBands, spec_size_threshold, = 256, 1024, 22, 3
nrgbRows, nrgbCols, nrgbBands, rgb_threshold = 1024, 4096, 3, 2 nRgbRows, nRgbCols, nRgbBands, rgb_size_threshold = 1024, 4096, 3, 4
# 需要设置的谱段等参数 # 需要设置的谱段等参数
selected_bands = [127, 201, 202, 294] selected_bands = [127, 201, 202, 294]
bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
294] 211, 212, 213, 214, 215, 216, 217, 218, 219, 294]
is_yellow_min = np.array([0.10167048, 0.1644719, 0.1598884, 0.31534621]) is_yellow_min = np.array([0.10167048, 0.1644719, 0.1598884, 0.31534621])
is_yellow_max = np.array([0.212984, 0.25896924, 0.26509268, 0.51943593]) is_yellow_max = np.array([0.212984, 0.25896924, 0.26509268, 0.51943593])
is_black_threshold = np.asarray([0.1369, 0.1472, 0.1439, 0.1814]) is_black_threshold = np.asarray([0.1369, 0.1472, 0.1439, 0.1814])
black_yellow_bands = [0, 2, 3, 21] black_yellow_bands = [0, 2, 3, 21]
green_bands = [i for i in range(1, 21)] green_bands = [i for i in range(1, 21)]
# 机器学习模型相关参数 # 光谱模型参数
blk_size = 4 blk_size = 4
pixel_model_path = r"./models/dt.p" pixel_model_path = r"./models/dt.p"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model" blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
rgb_model_path = r"./models/dt_2022-07-20_14-40.model"
rgb_tobacco_model_path = r"models/beijing_dt_2022-07-21_16-44.model" # rgb模型参数
rgb_background_model_path = r"models/tobacco_dt_2022-07-21_16-30.model" rgb_tobacco_model_path = r"models/tobacco_dt_2022-07-21_16-30.model"
rgb_background_model_path = r"models/beijing_dt_2022-07-21_16-44.model"
threshold_low, threshold_high = 5, 255
threshold_s = 175

109
main.py
View File

@ -1,20 +1,20 @@
import os import os
import time import time
import numpy as np import numpy as np
import scipy.io
from config import Config from config import Config
from models import ManualTree, AnonymousColorDetector from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
import cv2
# 主函数
def main(): def main():
threshold = Config.threshold spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
rgb_threshold = Config.rgb_threshold pixel_model_path=Config.pixel_model_path)
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) background_model_path=Config.rgb_background_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nrgbRows * Config.nrgbCols * Config.nrgbBands * 1 # int型变量 total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
if not os.access(img_fifo_path, os.F_OK): if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777) os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK): if not os.access(mask_fifo_path, os.F_OK):
@ -25,6 +25,67 @@ def main():
fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
data = os.read(fd_img, total_len) data = os.read(fd_img, total_len)
# 读取(开启一个管道)
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get threshold: ", threshold)
continue
else:
data_total = data
rgb_data = os.read(fd_rgb, total_rgb)
if len(rgb_data) < 3:
rgb_threshold = int(float(rgb_data))
print(rgb_threshold)
continue
else:
rgb_data_total = rgb_data
os.close(fd_img)
os.close(fd_rgb)
# 识别
t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0,
2,
1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
# 光谱识别
mask = spec_detector.predict(img_data)
# rgb识别
mask_rgb = rgb_detector.predict(rgb_data)
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
print(f'total time is:{t3 - t1}')
def save_main():
threshold = Config.spec_size_threshold
rgb_threshold = Config.rgb_size_threshold
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777)
img_list = []
idx = 0
while idx <= 30:
idx += 1
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
data = os.read(fd_img, total_len)
# 读取(开启一个管道) # 读取(开启一个管道)
if len(data) < 3: if len(data) < 3:
@ -40,36 +101,40 @@ def main():
continue continue
else: else:
rgb_data_total = rgb_data rgb_data_total = rgb_data
os.close(fd_img) os.close(fd_img)
os.close(fd_rgb) os.close(fd_rgb)
# 识别 # 识别
t1 = time.time() t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
-1)).transpose(0, 2, 1) transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
img_list.append((rgb_data.copy(), img_data.copy()))
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1) pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manual_tree.blk_predict(data=img_data) blk_predict_result = manual_tree.blk_predict(data=img_data)
rgb_data = tobacco_detector.pretreatment(rgb_data) rgb_data = tobacco_detector.pretreatment(rgb_data)
rgb_predict_result = 1 - ( # print(rgb_data.shape)
background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data))) rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
threshold_high=Config.threshold_high) |
tobacco_detector.swell(tobacco_detector.predict(rgb_data,
threshold_low=Config.threshold_low,
threshold_high=Config.threshold_high)))
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1) .sum(axis=1)
mask_rgb[mask_rgb <= rgb_threshold] = 0 mask_rgb[mask_rgb <= rgb_threshold] = 0
mask_rgb[mask_rgb > rgb_threshold] = 1 mask_rgb[mask_rgb > rgb_threshold] = 1
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1) .sum(axis=1)
mask[mask <= threshold] = 0 mask[mask <= threshold] = 0
mask[mask > threshold] = 1 mask[mask > threshold] = 1
mask_result = (mask | mask_rgb).astype(np.uint8) # mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_rgb
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time() t2 = time.time()
# print(f'shibie time is:{t2 - t1}')
print(f'rgb len = {len(rgb_data)}') print(f'rgb len = {len(rgb_data)}')
# 写出 # 写出
@ -78,6 +143,14 @@ def main():
os.close(fd_mask) os.close(fd_mask)
t3 = time.time() t3 = time.time()
print(f'total time is:{t3 - t1}') print(f'total time is:{t3 - t1}')
i = 0
print("Stop Serving")
for img in img_list:
print(f"writing img {i}...")
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
np.save(f'./{i}.npy', img[1])
i += 1
print("save success")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -10,20 +10,27 @@ import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from models import Detector, AnonymousColorDetector from config import Config
from utils import read_labeled_img from models import Detector, AnonymousColorDetector, ManualTree
from utils import read_labeled_img, size_threshold
def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=None, test_model=False): def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False):
""" """
虚拟读图测试程序 虚拟读图测试程序
:param detector: 杂质探测器需要继承Detector类
:param test_img: 测试图像rgb格式的图片或者路径 :param test_img: 测试图像rgb格式的图片或者路径
:param test_img_dir: 测试图像文件夹 :param test_img_dir: 测试图像文件夹
:param test_model: 是否进行模型约束性测试 :param test_spectra: 是否测试光谱
:param test_rgb: 是否测试rgb
:return: :return:
""" """
if (test_img is not None) or (test_img_dir is not None):
threshold = Config.spec_size_threshold
rgb_threshold = Config.rgb_size_threshold
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
if test_img is not None: if test_img is not None:
if isinstance(test_img, str): if isinstance(test_img, str):
img = cv2.imread(test_img)[:, :, ::-1] img = cv2.imread(test_img)[:, :, ::-1]
@ -31,42 +38,45 @@ def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=N
img = test_img img = test_img
else: else:
raise TypeError("test img should be np.ndarray or str") raise TypeError("test img should be np.ndarray or str")
t1 = time.time()
img = cv2.resize(img, (1024, 256))
t2 = time.time()
result = 1 - detector.predict(img)
t3 = time.time()
fig, axs = plt.subplots(3, 1)
axs[0].imshow(img)
axs[1].imshow(result)
mask_color = np.zeros_like(img)
mask_color[result > 0] = (0, 0, 255)
result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
axs[2].imshow(result_show)
axs[0].set_title(
f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms')
plt.show()
if test_img_dir is not None: if test_img_dir is not None:
image_names = os.listdir(test_img_dir) image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')]
for image_name in image_names: for image_name in image_names:
img = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1] rgb_data = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
if test_model: # 识别
data_dir = "data/dataset" t1 = time.time()
color_dict = {(0, 0, 255): "yangeng"} if test_spectra:
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False) # spectra part
ground_truth = dataset['yangeng'] pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
world_boundary = np.array([0, 0, 0, 255, 255, 255]) blk_predict_result = manual_tree.blk_predict(data=img_data)
detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth) mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
mask_spec = size_threshold(mask, Config.blk_size, threshold)
if test_rgb:
# rgb part
rgb_data = tobacco_detector.pretreatment(rgb_data)
background = background_detector.predict(rgb_data)
tobacco = tobacco_detector.predict(rgb_data)
tobacco_d = tobacco_detector.swell(tobacco)
rgb_predict_result = 1 - (background | tobacco_d)
mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold)
fig, axs = plt.subplots(5, 1, figsize=(12, 10), constrained_layout=True)
axs[0].imshow(rgb_data)
axs[0].set_title("rgb raw data")
axs[1].imshow(background)
axs[1].set_title("background")
axs[2].imshow(tobacco)
axs[2].set_title("tobacco")
axs[3].imshow(rgb_predict_result)
axs[3].set_title("1 - (background + dilate(tobacco))")
axs[4].imshow(mask_rgb)
axs[4].set_title("final mask")
plt.show()
mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = rgb_predict_result
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
print(f'rgb len = {len(rgb_data)}')
if __name__ == '__main__': if __name__ == '__main__':
model = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model') pony_run(test_img_dir=r'E:\zhouchao\725data', test_rgb=True)
virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True)
virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True)
virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True)

157
models.py
View File

@ -9,17 +9,19 @@ import pickle
import cv2 import cv2
import numpy as np import numpy as np
import scipy.io import scipy.io
import tqdm
from scipy.ndimage import binary_dilation from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from config import Config from config import Config
from utils import lab_scatter, read_labeled_img from utils import lab_scatter, read_labeled_img, size_threshold
from tqdm import tqdm
from elm import ELM deploy = False
if not deploy:
print("Training env")
from tqdm import tqdm
from elm import ELM
class Detector(object): class Detector(object):
@ -87,7 +89,7 @@ class AnonymousColorDetector(Detector):
y_predict = self.model.predict(x_val) y_predict = self.model.predict(x_val)
print(classification_report(y_true=y_val, y_pred=y_predict)) print(classification_report(y_true=y_val, y_pred=y_predict))
def predict(self, x, threshold_low=10, threshold_high=170): def predict(self, x, threshold_low=5, threshold_high=255):
""" """
输入rgb彩色图像 输入rgb彩色图像
@ -98,9 +100,11 @@ class AnonymousColorDetector(Detector):
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB) x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
x = x.reshape(w * h, -1) x = x.reshape(w * h, -1)
mask = (threshold_low < x[:, 0]) & (x[:, 0] < threshold_high) mask = (threshold_low < x[:, 0]) & (x[:, 0] < threshold_high)
mask_result = self.model.predict(x[mask]) result = np.ones((w * h,), dtype=np.uint8)
result = np.ones((w * h,))
result[mask] = mask_result if np.any(mask):
mask_result = self.model.predict(x[mask])
result[mask] = mask_result
return result.reshape(h, w) return result.reshape(h, w)
@staticmethod @staticmethod
@ -133,6 +137,12 @@ class AnonymousColorDetector(Detector):
bar.close() bar.close()
return negative_samples return negative_samples
def pretreatment(self, x):
    """
    Resize an RGB frame to the grid size configured for the spectral image.

    The target size is read from Config (nCols x nRows) rather than repeating
    the literal 1024 x 256, so the RGB frame stays aligned with the spectral
    frame if the configured geometry changes.

    :param x: image array accepted by cv2.resize
    :return: the resized image, Config.nRows rows by Config.nCols columns
    """
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    return cv2.resize(x, (Config.nCols, Config.nRows))
def swell(self, x):
    """
    Dilate a mask with a 3x3 all-ones structuring element.

    :param x: mask image accepted by cv2.dilate
    :return: the dilated mask
    """
    structuring_element = np.ones((3, 3), np.uint8)
    return cv2.dilate(x, kernel=structuring_element)
def save(self): def save(self):
path = datetime.datetime.now().strftime(f"{self.model_type}_%Y-%m-%d_%H-%M.model") path = datetime.datetime.now().strftime(f"{self.model_type}_%Y-%m-%d_%H-%M.model")
with open(path, 'wb') as f: with open(path, 'wb') as f:
@ -292,6 +302,137 @@ class BlkModel:
self.rfc = pickle.load(f) self.rfc = pickle.load(f)
class RgbDetector(Detector):
    """
    Impurity detector for RGB frames.

    Combines two AnonymousColorDetector models (tobacco and background) with a
    saturation threshold, then applies a block-wise size threshold.
    """

    def __init__(self, tobacco_model_path, background_model_path):
        """
        :param tobacco_model_path: model file passed to the tobacco AnonymousColorDetector
        :param background_model_path: model file passed to the background AnonymousColorDetector
        """
        self.background_detector = None
        self.tobacco_detector = None
        self.load(tobacco_model_path, background_model_path)

    def predict(self, rgb_data):
        """
        Compute the block-level impurity mask for one RGB frame.

        :param rgb_data: RGB image array
        :return: block mask produced by size_threshold (1 = impurity block)
        """
        # Resize to the size the rest of the pipeline expects.
        rgb_data = self.tobacco_detector.pretreatment(rgb_data)
        # Pass the configured lightness window explicitly instead of relying on
        # the defaults of AnonymousColorDetector.predict matching Config.
        background = self.background_detector.predict(rgb_data,
                                                      threshold_low=Config.threshold_low,
                                                      threshold_high=Config.threshold_high)
        tobacco = self.tobacco_detector.predict(rgb_data,
                                                threshold_low=Config.threshold_low,
                                                threshold_high=Config.threshold_high)
        # Dilate the tobacco mask to suppress errors along tobacco edges.
        tobacco_d = self.tobacco_detector.swell(tobacco)
        # Pixels whose HSV saturation exceeds the configured threshold.
        high_s = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2HSV)[..., 1] > Config.threshold_s
        # Regions that are neither tobacco nor background.
        non_tobacco_or_background = 1 - (background | tobacco_d)
        # High-saturation pixels and "neither" pixels are both treated as impurities.
        rgb_predict_result = high_s | non_tobacco_or_background
        # Size limit: only blocks whose impurity pixel count exceeds the threshold are kept.
        mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold)
        return mask_rgb

    def load(self, tobacco_model_path, background_model_path):
        """Create the tobacco and background colour detectors from their model files."""
        self.tobacco_detector = AnonymousColorDetector(tobacco_model_path)
        self.background_detector = AnonymousColorDetector(background_model_path)

    def save(self, *args, **kwargs):
        """No-op: this detector has nothing of its own to save."""
        pass

    def fit(self, *args, **kwargs):
        """No-op: the two underlying models are loaded from file, not trained here."""
        pass
class SpecDetector(Detector):
    """
    Impurity detector for spectral frames.

    Combines a per-pixel model (PixelModelML) with a block model (BlkModel) and
    applies a block-wise size threshold to the combined mask.
    """

    def __init__(self, blk_model_path, pixel_model_path):
        """
        Initialise the pixel model and the block model.

        :param blk_model_path: model file passed to BlkModel
        :param pixel_model_path: model file passed to PixelModelML
        """
        self.blk_model = None
        self.pixel_model_ml = None
        self.load(blk_model_path, pixel_model_path)

    def load(self, blk_model_path, pixel_model_path):
        """Create the pixel model and the block model from their model files."""
        self.pixel_model_ml = PixelModelML(pixel_model_path)
        self.blk_model = BlkModel(blk_model_path)

    def predict(self, img_data):
        """
        Compute the block-level impurity mask for one spectral frame.

        :param img_data: spectral image, nRows x nCols x nBands
        :return: block mask produced by size_threshold (1 = impurity block)
        """
        pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
        blk_predict_result = self.blk_predict(data=img_data)
        # A pixel counts as impurity only when both models agree.
        mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
        mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold)
        return mask

    def save(self, *args, **kwargs):
        """No-op: this detector has nothing of its own to save."""
        pass

    def fit(self, *args, **kwargs):
        """No-op: the underlying models are loaded from file, not trained here."""
        pass

    # Separates yellow objects from non-yellow, non-background impurities.
    @staticmethod
    def is_yellow(features):
        """
        Flag pixels whose band ratios all lie strictly inside the yellow range.

        :param features: selected bands of the frame; reshaped to
            (Config.nRows * Config.nCols, len(Config.selected_bands))
        :return: bool mask of shape (Config.nRows, Config.nCols)
        """
        features = features.reshape((Config.nRows * Config.nCols), len(Config.selected_bands))
        sum_x = features.sum(axis=1)[..., np.newaxis]
        # A pixel whose bands sum to zero produces nan/inf ratios; every
        # comparison with those is False, so such a pixel is simply "not yellow".
        # Only the NumPy warnings are silenced, the result is unchanged.
        with np.errstate(divide='ignore', invalid='ignore'):
            rate = features / sum_x
            mask = ((rate < Config.is_yellow_max) & (rate > Config.is_yellow_min))
        mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
        return mask

    # Separates background from yellow objects.
    @staticmethod
    def is_black(feature, threshold):
        """
        Flag pixels whose every band is at or below the given threshold.

        :param feature: 3-D array of selected bands; its last axis is the band axis
        :param threshold: per-band (or scalar) upper bound for "black"
        :return: bool mask of shape (Config.nRows, Config.nCols)
        """
        feature = feature.reshape((Config.nRows * Config.nCols), feature.shape[2])
        mask = (feature <= threshold)
        mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
        return mask

    # Predict the tobacco-stem mask.
    def predict_tobacco(self, x: np.ndarray) -> np.ndarray:
        """
        Predict the tobacco-stem mask.

        :param x: image data of shape nRows x nCols x nBands
        :return: bool mask, True where the pixel is tobacco stem
        """
        black_res = self.is_black(x[..., Config.black_yellow_bands], Config.is_black_threshold)
        yellow_res = self.is_yellow(x[..., Config.black_yellow_bands])
        yellow_things = (~black_res) & yellow_res
        x_yellow = x[yellow_things, ...]
        # Same guard as pixel_predict_ml_dilation: with no yellow pixels there is
        # nothing to classify, and the all-False mask is already the answer.
        if x_yellow.shape[0] == 0:
            return yellow_things
        tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands])
        yellow_things[yellow_things] = tobacco
        return yellow_things

    # Pixel-level impurity prediction with dilation of the tobacco mask.
    def pixel_predict_ml_dilation(self, data, iteration) -> np.ndarray:
        """
        Predict the impurity mask.

        :param data: image data of shape nRows x nCols x nBands
        :param iteration: number of dilation iterations applied to the tobacco mask
        :return: mask that is non-zero where the pixel is an impurity; a bool
            array when the frame contains no yellow pixels, otherwise an
            integer 0/1 array
        """
        black_res = self.is_black(data[..., Config.black_yellow_bands], Config.is_black_threshold)
        yellow_res = self.is_yellow(data[..., Config.black_yellow_bands])
        # non_yellow_things: off-colour impurities (neither black nor yellow).
        non_yellow_things = (~black_res) & (~yellow_res)
        # yellow_things: yellow objects (tobacco stems + yellow impurities).
        yellow_things = (~black_res) & yellow_res
        # x_yellow: spectra of the selected yellow pixels.
        x_yellow = data[yellow_things, ...]
        if x_yellow.shape[0] == 0:
            return non_yellow_things
        tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5
        # Yellow pixels the model rejects as tobacco are impurities too.
        non_yellow_things[yellow_things] = ~tobacco
        # From here on yellow_things holds the tobacco mask only.
        yellow_things[yellow_things] = tobacco
        dilated_tobacco = binary_dilation(yellow_things, iterations=iteration)
        # Keep impurities that do not touch the dilated tobacco. This is the
        # boolean form of the former 0/1/2/3 label sum (only label 1 survived);
        # "+ 0" keeps the integer 0/1 result type.
        mask = (non_yellow_things & ~dilated_tobacco) + 0
        return mask

    # Block-level impurity prediction.
    def blk_predict(self, data):
        """Return the block model's prediction for the frame."""
        blk_result_array = self.blk_model.predict(data)
        return blk_result_array
if __name__ == '__main__': if __name__ == '__main__':
data_dir = "data/dataset" data_dir = "data/dataset"
color_dict = {(0, 0, 255): "yangeng"} color_dict = {(0, 0, 255): "yangeng"}

View File

@ -96,6 +96,14 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
plt.show() plt.show()
def size_threshold(img, blk_size, threshold):
    """
    Reduce a pixel mask to a block mask by counting hits per block.

    The image is tiled into blk_size x blk_size blocks, the values in each
    block are summed, and a block is marked 1 when its sum is strictly greater
    than threshold, otherwise 0. The input array is not modified.

    :param img: array whose first two axes are rows and columns
    :param blk_size: side length of a block; must divide both rows and columns
    :param threshold: a block is kept only if its sum exceeds this value
    :return: array of shape (rows // blk_size, cols // blk_size) holding 0/1,
        with the dtype of the block sums
    :raises ValueError: if rows or columns are not a multiple of blk_size
    """
    rows, cols = img.shape[0], img.shape[1]
    if rows % blk_size or cols % blk_size:
        raise ValueError(
            f"image shape ({rows}, {cols}) is not divisible by blk_size {blk_size}")
    block_sums = img.reshape(rows // blk_size, blk_size, cols // blk_size, blk_size) \
        .sum(axis=(1, 3))
    # One comparison instead of two in-place assignments: the second assignment
    # used to re-test cells already set to 0, which turned every block into 1
    # whenever threshold was negative.
    return (block_sums > threshold).astype(block_sums.dtype)
if __name__ == '__main__': if __name__ == '__main__':
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"} (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}