模型整合

This commit is contained in:
FEIJINTI 2022-07-25 09:49:17 +08:00
parent ebff13d9fc
commit 874fe2c69c
4 changed files with 263 additions and 8 deletions

View File

@ -17,7 +17,7 @@
3. 使用[02_classification.ipynb](./02_classification.ipynb)进行训练
4. 使用[03_data_update.ipynb](02_classification.ipynb)进行数据的更新与添加
5. 使用`main_test.py`文件进行读图测试
6. **部署**,复制`utils.py`、`models.py`、`main.py`、`models`、`config.py`这5个文件或文件夹运行main.py来提供预测服务。
6. **部署**,复制`utils.py`、`models.py`、`main.py`、`models/`、`config.py`这5个文件或文件夹运行main.py来提供预测服务。
## 训练的原理

25
config.py Normal file
View File

@ -0,0 +1,25 @@
import torch
import numpy as np
class Config:
# 文件相关参数
nRows, nCols, nBands, threshold, = 256, 1024, 22, 3
nrgbRows, nrgbCols, nrgbBands, rgb_threshold = 1024, 4096, 3, 2
# 需要设置的谱段等参数
selected_bands = [127, 201, 202, 294]
bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
294]
is_yellow_min = np.array([0.10167048, 0.1644719, 0.1598884, 0.31534621])
is_yellow_max = np.array([0.212984, 0.25896924, 0.26509268, 0.51943593])
is_black_threshold = np.asarray([0.1369, 0.1472, 0.1439, 0.1814])
black_yellow_bands = [0, 2, 3, 21]
green_bands = [i for i in range(1, 21)]
# 机器学习模型相关参数
blk_size = 4
pixel_model_path = r"./models/dt.p"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
rgb_model_path = r"./models/dt_2022-07-20_14-40.model"
rgb_tobacco_model_path = r"models/beijing_dt_2022-07-21_16-44.model"
rgb_background_model_path = r"models/tobacco_dt_2022-07-21_16-30.model"

92
main.py Normal file
View File

@ -0,0 +1,92 @@
import os
import time
import numpy as np
from config import Config
from models import ManualTree, AnonymousColorDetector
# 主函数
def main():
threshold = Config.threshold
rgb_threshold = Config.rgb_threshold
manualTree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nrgbRows * Config.nrgbCols * Config.nrgbBands * 1 # int型变量
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777)
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
data = os.read(fd_img, total_len)
# 读取(开启一个管道)
if len(data) < 3:
threshold = int(float(data))
print(threshold)
continue
else:
data_total = data
rgb_data = os.read(fd_rgb, total_rgb)
if len(rgb_data) < 3:
rgb_threshold = int(float(rgb_data))
print(rgb_threshold)
continue
else:
rgb_data_total = rgb_data
os.close(fd_img)
os.close(fd_rgb)
# 识别
t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0,
2,
1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8)
pixel_predict_result = manualTree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manualTree.blk_predict(data=img_data)
rgb_data = tobacco_detector.pretreatment(rgb_data)
rgb_predict_result = 1 - (
background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data)))
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
mask_rgb[mask_rgb <= rgb_threshold] = 0
mask_rgb[mask_rgb > rgb_threshold] = 1
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
# print(threshold)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
# print(f'shibie time is:{t2 - t1}')
print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
print(f'total time is:{t3 - t1}')
if __name__ == '__main__':
# 相关参数
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
rgb_fifo_path = "/tmp/dkrgb.fifo"
# 主函数
main()

152
models.py
View File

@ -10,9 +10,12 @@ import cv2
import numpy as np
import scipy.io
import tqdm
from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from config import Config
from utils import lab_scatter, read_labeled_img
from tqdm import tqdm
@ -20,9 +23,6 @@ from elm import ELM
class Detector(object):
def __int__(self, *args, **kwargs):
raise NotImplementedError
def predict(self, *args, **kwargs):
raise NotImplementedError
@ -38,6 +38,7 @@ class Detector(object):
class AnonymousColorDetector(Detector):
def __init__(self, file_path: str = None):
super(AnonymousColorDetector, self).__init__()
self.model = None
self.model_type = 'None'
if file_path is not None:
@ -154,6 +155,143 @@ class AnonymousColorDetector(Detector):
lab_scatter(draw_dataset, is_3d=True, is_ps_color_space=False, **kwargs)
class ManualTree:
# 初始化机器学习像素模型、深度学习像素模型、分块模型
def __init__(self, blk_model_path, pixel_model_path):
self.pixel_model_ml = PixelModelML(pixel_model_path)
self.blk_model = BlkModel(blk_model_path)
# 区分烟梗和非黄色且非背景的杂质
@staticmethod
def is_yellow(features):
features = features.reshape((Config.nRows * Config.nCols), len(Config.selected_bands))
sum_x = features.sum(axis=1)[..., np.newaxis]
rate = features / sum_x
mask = ((rate < Config.is_yellow_max) & (rate > Config.is_yellow_min))
mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
return mask
# 区分背景和黄色杂质
@staticmethod
def is_black(feature, threshold):
feature = feature.reshape((Config.nRows * Config.nCols), feature.shape[2])
mask = (feature <= threshold)
mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
return mask
# 预测出烟梗的mask
def predict_tobacco(self, x: np.ndarray) -> np.ndarray:
"""
预测出烟梗的mask
:param x: 图像数据形状是 nRows x nCols x nBands
:return: bool类型的mask是否为烟梗, True为烟梗
"""
black_res = self.is_black(x[..., Config.black_yellow_bands], Config.is_black_threshold)
yellow_res = self.is_yellow(x[..., Config.black_yellow_bands])
yellow_things = (~black_res) & yellow_res
x_yellow = x[yellow_things, ...]
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands])
yellow_things[yellow_things] = tobacco
return yellow_things
# 预测出杂质的机器学习像素模型
def pixel_predict_ml_dilation(self, data, iteration) -> np.ndarray:
"""
预测出杂质的位置mask
:param data: 图像数据形状是 nRows x nCols x nBands
:param iteration: 膨胀的次数
:return: bool类型的mask是否为杂质, True为杂质
"""
black_res = self.is_black(data[..., Config.black_yellow_bands], Config.is_black_threshold)
yellow_res = self.is_yellow(data[..., Config.black_yellow_bands])
# non_yellow_things为异色杂质
non_yellow_things = (~black_res) & (~yellow_res)
# yellow_things为黄色物体(烟梗+杂质)
yellow_things = (~black_res) & yellow_res
# x_yellow为挑出的黄色物体
x_yellow = data[yellow_things, ...]
if x_yellow.shape[0] == 0:
return non_yellow_things
else:
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5
non_yellow_things[yellow_things] = ~tobacco
# 杂质mask中将背景赋值为0,将杂质赋值为1
non_yellow_things = non_yellow_things + 0
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
yellow_things[yellow_things] = tobacco
yellow_things = yellow_things + 0
yellow_things = binary_dilation(yellow_things, iterations=iteration)
yellow_things = yellow_things + 0
yellow_things[yellow_things == 1] = 2
# 将杂质mask和烟梗mask相加,得到的mask中含有0(背景),1(杂质),2(烟梗),3(膨胀后的烟梗与杂质相加的部分)
mask = non_yellow_things + yellow_things
mask[mask == 0] = False
mask[mask == 1] = True
mask[mask == 2] = False
mask[mask == 3] = False
return mask
# 预测出杂质的分块模型
def blk_predict(self, data):
blk_result_array = self.blk_model.predict(data)
return blk_result_array
# 机器学习像素模型类
class PixelModelML:
def __init__(self, pixel_model_path):
with open(pixel_model_path, "rb") as f:
self.dt = pickle.load(f)
def predict(self, feature):
pixel_result_array = self.dt.predict(feature)
return pixel_result_array
# 分块模型类
class BlkModel:
def __init__(self, blk_model_path):
self.rfc = None
self.load(blk_model_path)
@staticmethod
def split_x(data: np.ndarray, blk_sz: int) -> list:
"""
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x ncols x num_channels)
;param blk_sz: block size
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
"""
x_list = []
for i in range(0, 256 // blk_sz):
for j in range(0, 1024 // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
x_list.append(block_data)
return x_list
def predict(self, data):
data_blk = data
data_blk = np.array(self.split_x(data_blk, blk_sz=Config.blk_size))
data_blk = data_blk.reshape((data_blk.shape[0]), -1)
y_pred = self.rfc.predict(data_blk)
y_pred[y_pred < 2] = 0
y_pred[y_pred > 1] = 1
blk_result_array = y_pred.reshape(256 // Config.blk_size, 1024 // Config.blk_size).repeat(Config.blk_size,
axis=0).repeat(
Config.blk_size,
axis=1)
return blk_result_array
def load(self, model_path: str):
with open(model_path, "rb") as f:
self.rfc = pickle.load(f)
if __name__ == '__main__':
data_dir = "data/dataset"
color_dict = {(0, 0, 255): "yangeng"}
@ -161,10 +299,10 @@ if __name__ == '__main__':
ground_truth = dataset['yangeng']
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
world_boundary = np.array([0, 0, 0, 255, 255, 255])
boundary = np.array([0, 0, 0, 255, 255, 255])
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
data = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
x, y = data['x'], data['y']
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
x, y = temp['x'], temp['y']
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
lab_scatter(dataset, class_max_num=5000, is_3d=True, is_ps_color_space=False)