From 874fe2c69cf3cf3ffcce29ef9ee87cc596473a9a Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Mon, 25 Jul 2022 09:49:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=95=B4=E5=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- config.py | 25 +++++++++ main.py | 92 +++++++++++++++++++++++++++++++++ models.py | 152 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 263 insertions(+), 8 deletions(-) create mode 100644 config.py create mode 100644 main.py diff --git a/README.md b/README.md index e305a66..3393cfd 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ 3. 使用[02_classification.ipynb](./02_classification.ipynb)进行训练 4. 使用[03_data_update.ipynb](02_classification.ipynb)进行数据的更新与添加 5. 使用`main_test.py`文件进行读图测试 -6. **部署**,复制`utils.py`、`models.py`、`main.py`、`models`、`config.py`这5个文件或文件夹,运行main.py来提供预测服务。 +6. **部署**,复制`utils.py`、`models.py`、`main.py`、`models/`、`config.py`这5个文件或文件夹,运行main.py来提供预测服务。 ## 训练的原理 diff --git a/config.py b/config.py new file mode 100644 index 0000000..9702068 --- /dev/null +++ b/config.py @@ -0,0 +1,25 @@ +import torch +import numpy as np + + +class Config: + # 文件相关参数 + nRows, nCols, nBands, threshold, = 256, 1024, 22, 3 + nrgbRows, nrgbCols, nrgbBands, rgb_threshold = 1024, 4096, 3, 2 + # 需要设置的谱段等参数 + selected_bands = [127, 201, 202, 294] + bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, + 294] + is_yellow_min = np.array([0.10167048, 0.1644719, 0.1598884, 0.31534621]) + is_yellow_max = np.array([0.212984, 0.25896924, 0.26509268, 0.51943593]) + is_black_threshold = np.asarray([0.1369, 0.1472, 0.1439, 0.1814]) + black_yellow_bands = [0, 2, 3, 21] + green_bands = [i for i in range(1, 21)] + + # 机器学习模型相关参数 + blk_size = 4 + pixel_model_path = r"./models/dt.p" + blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model" + rgb_model_path = r"./models/dt_2022-07-20_14-40.model" + rgb_tobacco_model_path = r"models/beijing_dt_2022-07-21_16-44.model" + rgb_background_model_path = r"models/tobacco_dt_2022-07-21_16-30.model" diff --git a/main.py b/main.py new file mode 100644 index 0000000..0a242f5 --- /dev/null +++ b/main.py @@ -0,0 +1,92 @@ +import os +import time +import numpy as np +from config import Config +from models import ManualTree, AnonymousColorDetector + + +# 主函数 +def main(): + threshold = Config.threshold + rgb_threshold = Config.rgb_threshold + manualTree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) + tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) + background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) + + total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 + total_rgb = Config.nrgbRows * Config.nrgbCols * Config.nrgbBands * 1 # int型变量 + if not os.access(img_fifo_path, os.F_OK): + os.mkfifo(img_fifo_path, 0o777) + if not os.access(mask_fifo_path, os.F_OK): + os.mkfifo(mask_fifo_path, 0o777) + if not os.access(rgb_fifo_path, os.F_OK): + os.mkfifo(rgb_fifo_path, 0o777) + while True: + fd_img = os.open(img_fifo_path, os.O_RDONLY) + fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) + data = os.read(fd_img, total_len) + + # 读取(开启一个管道) + if len(data) < 3: + threshold = int(float(data)) + print(threshold) + continue + else: + data_total = data + rgb_data = os.read(fd_rgb, total_rgb) + if len(rgb_data) < 3: + rgb_threshold = int(float(rgb_data)) + print(rgb_threshold) + continue + else: + rgb_data_total = rgb_data + + os.close(fd_img) + os.close(fd_rgb) + + # 识别 + t1 = time.time() + img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0, + 2, + 1) + rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8) + pixel_predict_result = manualTree.pixel_predict_ml_dilation(data=img_data, iteration=1) + blk_predict_result = manualTree.blk_predict(data=img_data) + rgb_data = tobacco_detector.pretreatment(rgb_data) + rgb_predict_result = 1 - ( + background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data))) + mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ + .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ + .sum(axis=1) + mask_rgb[mask_rgb <= rgb_threshold] = 0 + mask_rgb[mask_rgb > rgb_threshold] = 1 + + mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) + mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ + .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ + .sum(axis=1) + # print(threshold) + mask[mask <= threshold] = 0 + mask[mask > threshold] = 1 + mask_result = (mask | mask_rgb).astype(np.uint8) + mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + t2 = time.time() + # print(f'shibie time is:{t2 - t1}') + print(f'rgb len = {len(rgb_data)}') + + # 写出 + fd_mask = os.open(mask_fifo_path, os.O_WRONLY) + os.write(fd_mask, mask_result.tobytes()) + os.close(fd_mask) + t3 = time.time() + print(f'total time is:{t3 - t1}') + + +if __name__ == '__main__': + # 相关参数 + img_fifo_path = "/tmp/dkimg.fifo" + mask_fifo_path = "/tmp/dkmask.fifo" + rgb_fifo_path = "/tmp/dkrgb.fifo" + + # 主函数 + main() diff --git a/models.py b/models.py index 50b4a7a..bb4eae9 100644 --- a/models.py +++ b/models.py @@ -10,9 +10,12 @@ import cv2 import numpy as np import scipy.io import tqdm +from scipy.ndimage import binary_dilation from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import classification_report from sklearn.model_selection import train_test_split + +from config import Config from utils import lab_scatter, read_labeled_img from tqdm import tqdm @@ -20,9 +23,6 @@ from elm import ELM class Detector(object): - def __int__(self, *args, **kwargs): - raise NotImplementedError - def predict(self, *args, **kwargs): raise NotImplementedError @@ -38,6 +38,7 @@ class Detector(object): class AnonymousColorDetector(Detector): def __init__(self, file_path: str = None): + super(AnonymousColorDetector, self).__init__() self.model = None self.model_type = 'None' if file_path is not None: @@ -154,6 +155,143 @@ class AnonymousColorDetector(Detector): lab_scatter(draw_dataset, is_3d=True, is_ps_color_space=False, **kwargs) +class ManualTree: + # 初始化机器学习像素模型、深度学习像素模型、分块模型 + def __init__(self, blk_model_path, pixel_model_path): + self.pixel_model_ml = PixelModelML(pixel_model_path) + self.blk_model = BlkModel(blk_model_path) + + # 区分烟梗和非黄色且非背景的杂质 + @staticmethod + def is_yellow(features): + features = features.reshape((Config.nRows * Config.nCols), len(Config.selected_bands)) + sum_x = features.sum(axis=1)[..., np.newaxis] + rate = features / sum_x + mask = ((rate < Config.is_yellow_max) & (rate > Config.is_yellow_min)) + mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols) + return mask + + # 区分背景和黄色杂质 + @staticmethod + def is_black(feature, threshold): + feature = feature.reshape((Config.nRows * Config.nCols), feature.shape[2]) + mask = (feature <= threshold) + mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols) + return mask + + # 预测出烟梗的mask + def predict_tobacco(self, x: np.ndarray) -> np.ndarray: + """ + 预测出烟梗的mask + :param x: 图像数据,形状是 nRows x nCols x nBands + :return: bool类型的mask,是否为烟梗, True为烟梗 + """ + black_res = self.is_black(x[..., Config.black_yellow_bands], Config.is_black_threshold) + yellow_res = self.is_yellow(x[..., Config.black_yellow_bands]) + yellow_things = (~black_res) & yellow_res + x_yellow = x[yellow_things, ...] + tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) + yellow_things[yellow_things] = tobacco + return yellow_things + + # 预测出杂质的机器学习像素模型 + def pixel_predict_ml_dilation(self, data, iteration) -> np.ndarray: + """ + 预测出杂质的位置mask + :param data: 图像数据,形状是 nRows x nCols x nBands + :param iteration: 膨胀的次数 + :return: bool类型的mask,是否为杂质, True为杂质 + """ + black_res = self.is_black(data[..., Config.black_yellow_bands], Config.is_black_threshold) + yellow_res = self.is_yellow(data[..., Config.black_yellow_bands]) + # non_yellow_things为异色杂质 + non_yellow_things = (~black_res) & (~yellow_res) + # yellow_things为黄色物体(烟梗+杂质) + yellow_things = (~black_res) & yellow_res + # x_yellow为挑出的黄色物体 + x_yellow = data[yellow_things, ...] + if x_yellow.shape[0] == 0: + return non_yellow_things + else: + tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5 + + non_yellow_things[yellow_things] = ~tobacco + # 杂质mask中将背景赋值为0,将杂质赋值为1 + non_yellow_things = non_yellow_things + 0 + + # 烟梗mask中将背景赋值为0,将烟梗赋值为2 + yellow_things[yellow_things] = tobacco + yellow_things = yellow_things + 0 + yellow_things = binary_dilation(yellow_things, iterations=iteration) + yellow_things = yellow_things + 0 + yellow_things[yellow_things == 1] = 2 + + # 将杂质mask和烟梗mask相加,得到的mask中含有0(背景),1(杂质),2(烟梗),3(膨胀后的烟梗与杂质相加的部分) + mask = non_yellow_things + yellow_things + mask[mask == 0] = False + mask[mask == 1] = True + mask[mask == 2] = False + mask[mask == 3] = False + return mask + + # 预测出杂质的分块模型 + def blk_predict(self, data): + blk_result_array = self.blk_model.predict(data) + return blk_result_array + + +# 机器学习像素模型类 +class PixelModelML: + def __init__(self, pixel_model_path): + with open(pixel_model_path, "rb") as f: + self.dt = pickle.load(f) + + def predict(self, feature): + pixel_result_array = self.dt.predict(feature) + return pixel_result_array + + +# 分块模型类 +class BlkModel: + def __init__(self, blk_model_path): + self.rfc = None + self.load(blk_model_path) + + @staticmethod + def split_x(data: np.ndarray, blk_sz: int) -> list: + """ + Split the data into slices for classification.将数据划分为多个像素块,便于后续识别. + + ;param data: image data, shape (num_rows x ncols x num_channels) + ;param blk_sz: block size + ;param sensitivity: 最少有多少个杂物点能够被认为是杂物 + ;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz) + """ + x_list = [] + for i in range(0, 256 // blk_sz): + for j in range(0, 1024 // blk_sz): + block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...] + x_list.append(block_data) + return x_list + + def predict(self, data): + data_blk = data + data_blk = np.array(self.split_x(data_blk, blk_sz=Config.blk_size)) + data_blk = data_blk.reshape((data_blk.shape[0]), -1) + y_pred = self.rfc.predict(data_blk) + y_pred[y_pred < 2] = 0 + y_pred[y_pred > 1] = 1 + blk_result_array = y_pred.reshape(256 // Config.blk_size, 1024 // Config.blk_size).repeat(Config.blk_size, + axis=0).repeat( + Config.blk_size, + axis=1) + return blk_result_array + + def load(self, model_path: str): + with open(model_path, "rb") as f: + self.rfc = pickle.load(f) + + if __name__ == '__main__': data_dir = "data/dataset" color_dict = {(0, 0, 255): "yangeng"} @@ -161,10 +299,10 @@ if __name__ == '__main__': ground_truth = dataset['yangeng'] detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model') # x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]]) - world_boundary = np.array([0, 0, 0, 255, 255, 255]) + boundary = np.array([0, 0, 0, 255, 255, 255]) # detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000) - detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth) - data = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat') - x, y = data['x'], data['y'] + detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth) + temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat') + x, y = temp['x'], temp['y'] dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]} lab_scatter(dataset, class_max_num=5000, is_3d=True, is_ps_color_space=False)