mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
模型整合
This commit is contained in:
parent
ebff13d9fc
commit
874fe2c69c
@ -17,7 +17,7 @@
|
||||
3. 使用[02_classification.ipynb](./02_classification.ipynb)进行训练
|
||||
4. 使用[03_data_update.ipynb](02_classification.ipynb)进行数据的更新与添加
|
||||
5. 使用`main_test.py`文件进行读图测试
|
||||
6. **部署**,复制`utils.py`、`models.py`、`main.py`、`models`、`config.py`这5个文件或文件夹,运行main.py来提供预测服务。
|
||||
6. **部署**,复制`utils.py`、`models.py`、`main.py`、`models/`、`config.py`这5个文件或文件夹,运行main.py来提供预测服务。
|
||||
|
||||
|
||||
## 训练的原理
|
||||
|
||||
25
config.py
Normal file
25
config.py
Normal file
@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Config:
|
||||
# 文件相关参数
|
||||
nRows, nCols, nBands, threshold, = 256, 1024, 22, 3
|
||||
nrgbRows, nrgbCols, nrgbBands, rgb_threshold = 1024, 4096, 3, 2
|
||||
# 需要设置的谱段等参数
|
||||
selected_bands = [127, 201, 202, 294]
|
||||
bands = [127, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
|
||||
294]
|
||||
is_yellow_min = np.array([0.10167048, 0.1644719, 0.1598884, 0.31534621])
|
||||
is_yellow_max = np.array([0.212984, 0.25896924, 0.26509268, 0.51943593])
|
||||
is_black_threshold = np.asarray([0.1369, 0.1472, 0.1439, 0.1814])
|
||||
black_yellow_bands = [0, 2, 3, 21]
|
||||
green_bands = [i for i in range(1, 21)]
|
||||
|
||||
# 机器学习模型相关参数
|
||||
blk_size = 4
|
||||
pixel_model_path = r"./models/dt.p"
|
||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
|
||||
rgb_model_path = r"./models/dt_2022-07-20_14-40.model"
|
||||
rgb_tobacco_model_path = r"models/beijing_dt_2022-07-21_16-44.model"
|
||||
rgb_background_model_path = r"models/tobacco_dt_2022-07-21_16-30.model"
|
||||
92
main.py
Normal file
92
main.py
Normal file
@ -0,0 +1,92 @@
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
from config import Config
|
||||
from models import ManualTree, AnonymousColorDetector
|
||||
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
threshold = Config.threshold
|
||||
rgb_threshold = Config.rgb_threshold
|
||||
manualTree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
||||
|
||||
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||
total_rgb = Config.nrgbRows * Config.nrgbCols * Config.nrgbBands * 1 # int型变量
|
||||
if not os.access(img_fifo_path, os.F_OK):
|
||||
os.mkfifo(img_fifo_path, 0o777)
|
||||
if not os.access(mask_fifo_path, os.F_OK):
|
||||
os.mkfifo(mask_fifo_path, 0o777)
|
||||
if not os.access(rgb_fifo_path, os.F_OK):
|
||||
os.mkfifo(rgb_fifo_path, 0o777)
|
||||
while True:
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||
data = os.read(fd_img, total_len)
|
||||
|
||||
# 读取(开启一个管道)
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print(threshold)
|
||||
continue
|
||||
else:
|
||||
data_total = data
|
||||
rgb_data = os.read(fd_rgb, total_rgb)
|
||||
if len(rgb_data) < 3:
|
||||
rgb_threshold = int(float(rgb_data))
|
||||
print(rgb_threshold)
|
||||
continue
|
||||
else:
|
||||
rgb_data_total = rgb_data
|
||||
|
||||
os.close(fd_img)
|
||||
os.close(fd_rgb)
|
||||
|
||||
# 识别
|
||||
t1 = time.time()
|
||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0,
|
||||
2,
|
||||
1)
|
||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8)
|
||||
pixel_predict_result = manualTree.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = manualTree.blk_predict(data=img_data)
|
||||
rgb_data = tobacco_detector.pretreatment(rgb_data)
|
||||
rgb_predict_result = 1 - (
|
||||
background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data)))
|
||||
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
||||
.sum(axis=1)
|
||||
mask_rgb[mask_rgb <= rgb_threshold] = 0
|
||||
mask_rgb[mask_rgb > rgb_threshold] = 1
|
||||
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
||||
.sum(axis=1)
|
||||
# print(threshold)
|
||||
mask[mask <= threshold] = 0
|
||||
mask[mask > threshold] = 1
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
# print(f'shibie time is:{t2 - t1}')
|
||||
print(f'rgb len = {len(rgb_data)}')
|
||||
|
||||
# 写出
|
||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||
os.write(fd_mask, mask_result.tobytes())
|
||||
os.close(fd_mask)
|
||||
t3 = time.time()
|
||||
print(f'total time is:{t3 - t1}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 相关参数
|
||||
img_fifo_path = "/tmp/dkimg.fifo"
|
||||
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
||||
|
||||
# 主函数
|
||||
main()
|
||||
152
models.py
152
models.py
@ -10,9 +10,12 @@ import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import tqdm
|
||||
from scipy.ndimage import binary_dilation
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from config import Config
|
||||
from utils import lab_scatter, read_labeled_img
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -20,9 +23,6 @@ from elm import ELM
|
||||
|
||||
|
||||
class Detector(object):
|
||||
def __int__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -38,6 +38,7 @@ class Detector(object):
|
||||
|
||||
class AnonymousColorDetector(Detector):
|
||||
def __init__(self, file_path: str = None):
|
||||
super(AnonymousColorDetector, self).__init__()
|
||||
self.model = None
|
||||
self.model_type = 'None'
|
||||
if file_path is not None:
|
||||
@ -154,6 +155,143 @@ class AnonymousColorDetector(Detector):
|
||||
lab_scatter(draw_dataset, is_3d=True, is_ps_color_space=False, **kwargs)
|
||||
|
||||
|
||||
class ManualTree:
|
||||
# 初始化机器学习像素模型、深度学习像素模型、分块模型
|
||||
def __init__(self, blk_model_path, pixel_model_path):
|
||||
self.pixel_model_ml = PixelModelML(pixel_model_path)
|
||||
self.blk_model = BlkModel(blk_model_path)
|
||||
|
||||
# 区分烟梗和非黄色且非背景的杂质
|
||||
@staticmethod
|
||||
def is_yellow(features):
|
||||
features = features.reshape((Config.nRows * Config.nCols), len(Config.selected_bands))
|
||||
sum_x = features.sum(axis=1)[..., np.newaxis]
|
||||
rate = features / sum_x
|
||||
mask = ((rate < Config.is_yellow_max) & (rate > Config.is_yellow_min))
|
||||
mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
|
||||
return mask
|
||||
|
||||
# 区分背景和黄色杂质
|
||||
@staticmethod
|
||||
def is_black(feature, threshold):
|
||||
feature = feature.reshape((Config.nRows * Config.nCols), feature.shape[2])
|
||||
mask = (feature <= threshold)
|
||||
mask = np.all(mask, axis=1).reshape(Config.nRows, Config.nCols)
|
||||
return mask
|
||||
|
||||
# 预测出烟梗的mask
|
||||
def predict_tobacco(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
预测出烟梗的mask
|
||||
:param x: 图像数据,形状是 nRows x nCols x nBands
|
||||
:return: bool类型的mask,是否为烟梗, True为烟梗
|
||||
"""
|
||||
black_res = self.is_black(x[..., Config.black_yellow_bands], Config.is_black_threshold)
|
||||
yellow_res = self.is_yellow(x[..., Config.black_yellow_bands])
|
||||
yellow_things = (~black_res) & yellow_res
|
||||
x_yellow = x[yellow_things, ...]
|
||||
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands])
|
||||
yellow_things[yellow_things] = tobacco
|
||||
return yellow_things
|
||||
|
||||
# 预测出杂质的机器学习像素模型
|
||||
def pixel_predict_ml_dilation(self, data, iteration) -> np.ndarray:
|
||||
"""
|
||||
预测出杂质的位置mask
|
||||
:param data: 图像数据,形状是 nRows x nCols x nBands
|
||||
:param iteration: 膨胀的次数
|
||||
:return: bool类型的mask,是否为杂质, True为杂质
|
||||
"""
|
||||
black_res = self.is_black(data[..., Config.black_yellow_bands], Config.is_black_threshold)
|
||||
yellow_res = self.is_yellow(data[..., Config.black_yellow_bands])
|
||||
# non_yellow_things为异色杂质
|
||||
non_yellow_things = (~black_res) & (~yellow_res)
|
||||
# yellow_things为黄色物体(烟梗+杂质)
|
||||
yellow_things = (~black_res) & yellow_res
|
||||
# x_yellow为挑出的黄色物体
|
||||
x_yellow = data[yellow_things, ...]
|
||||
if x_yellow.shape[0] == 0:
|
||||
return non_yellow_things
|
||||
else:
|
||||
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5
|
||||
|
||||
non_yellow_things[yellow_things] = ~tobacco
|
||||
# 杂质mask中将背景赋值为0,将杂质赋值为1
|
||||
non_yellow_things = non_yellow_things + 0
|
||||
|
||||
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
|
||||
yellow_things[yellow_things] = tobacco
|
||||
yellow_things = yellow_things + 0
|
||||
yellow_things = binary_dilation(yellow_things, iterations=iteration)
|
||||
yellow_things = yellow_things + 0
|
||||
yellow_things[yellow_things == 1] = 2
|
||||
|
||||
# 将杂质mask和烟梗mask相加,得到的mask中含有0(背景),1(杂质),2(烟梗),3(膨胀后的烟梗与杂质相加的部分)
|
||||
mask = non_yellow_things + yellow_things
|
||||
mask[mask == 0] = False
|
||||
mask[mask == 1] = True
|
||||
mask[mask == 2] = False
|
||||
mask[mask == 3] = False
|
||||
return mask
|
||||
|
||||
# 预测出杂质的分块模型
|
||||
def blk_predict(self, data):
|
||||
blk_result_array = self.blk_model.predict(data)
|
||||
return blk_result_array
|
||||
|
||||
|
||||
# 机器学习像素模型类
|
||||
class PixelModelML:
|
||||
def __init__(self, pixel_model_path):
|
||||
with open(pixel_model_path, "rb") as f:
|
||||
self.dt = pickle.load(f)
|
||||
|
||||
def predict(self, feature):
|
||||
pixel_result_array = self.dt.predict(feature)
|
||||
return pixel_result_array
|
||||
|
||||
|
||||
# 分块模型类
|
||||
class BlkModel:
|
||||
def __init__(self, blk_model_path):
|
||||
self.rfc = None
|
||||
self.load(blk_model_path)
|
||||
|
||||
@staticmethod
|
||||
def split_x(data: np.ndarray, blk_sz: int) -> list:
|
||||
"""
|
||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||
|
||||
;param data: image data, shape (num_rows x ncols x num_channels)
|
||||
;param blk_sz: block size
|
||||
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||
"""
|
||||
x_list = []
|
||||
for i in range(0, 256 // blk_sz):
|
||||
for j in range(0, 1024 // blk_sz):
|
||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||
x_list.append(block_data)
|
||||
return x_list
|
||||
|
||||
def predict(self, data):
|
||||
data_blk = data
|
||||
data_blk = np.array(self.split_x(data_blk, blk_sz=Config.blk_size))
|
||||
data_blk = data_blk.reshape((data_blk.shape[0]), -1)
|
||||
y_pred = self.rfc.predict(data_blk)
|
||||
y_pred[y_pred < 2] = 0
|
||||
y_pred[y_pred > 1] = 1
|
||||
blk_result_array = y_pred.reshape(256 // Config.blk_size, 1024 // Config.blk_size).repeat(Config.blk_size,
|
||||
axis=0).repeat(
|
||||
Config.blk_size,
|
||||
axis=1)
|
||||
return blk_result_array
|
||||
|
||||
def load(self, model_path: str):
|
||||
with open(model_path, "rb") as f:
|
||||
self.rfc = pickle.load(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
@ -161,10 +299,10 @@ if __name__ == '__main__':
|
||||
ground_truth = dataset['yangeng']
|
||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
||||
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||
detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
data = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||
x, y = data['x'], data['y']
|
||||
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||
x, y = temp['x'], temp['y']
|
||||
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
||||
lab_scatter(dataset, class_max_num=5000, is_3d=True, is_ps_color_space=False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user