add dual_main.py and dual_main_test.py

This commit is contained in:
karllzy 2022-06-23 12:33:27 +08:00
parent 52c4feee09
commit 9e1dad8637
6 changed files with 677 additions and 39 deletions

File diff suppressed because one or more lines are too long

110
dual_main.py Executable file
View File

@ -0,0 +1,110 @@
import os
import numpy as np
from models import SpecDetector, PixelWisedDetector
from root_dir import ROOT_DIR
from multiprocessing import Process, Queue
# Frame geometry: used to size the FIFO payload and to reshape incoming frames.
nrows, ncols, nbands = 256, 1024, 4
# Named pipes: frames are read from the image FIFO, masks are written to the mask FIFO.
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
# NOTE(review): not referenced by the visible code — presumably a command channel; confirm.
cmd_fifo_path = '/tmp/tobacco_cmd.fifo'
# Default model file names; the worker functions join them onto ROOT_DIR/models.
pxl_model_path = "rf_1x1_c4_1_sen1_4.model"
blk_model_path = "rf_8x8_c4_185_sen32_4.model"
def _read_frame(fifo_path, total_len, pending=b''):
    """Read bytes from a FIFO until at least ``total_len`` bytes are available.

    :param fifo_path: path of the named pipe to read from
    :param total_len: number of bytes that make up one frame
    :param pending: bytes left over from the previous call
    :return: ``(frame, leftover)`` where ``frame`` is exactly ``total_len`` bytes
    """
    chunks = [pending]
    received = len(pending)
    while received < total_len:
        # Opening a FIFO read-only blocks until a writer connects.
        fd = os.open(fifo_path, os.O_RDONLY)
        try:
            while received < total_len:
                chunk = os.read(fd, total_len)
                if not chunk:
                    # EOF: the writer closed before a full frame arrived.
                    # Reopen and wait for the next writer instead of spinning
                    # on empty reads forever.
                    break
                chunks.append(chunk)
                received += len(chunk)
        finally:
            os.close(fd)
    data = b''.join(chunks)
    return data[:total_len], data[total_len:]


def main(pxl_model_path=pxl_model_path, blk_model_path=blk_model_path):
    """Serve masks over FIFOs by combining the pixel and block detectors.

    Each frame read from ``img_fifo_path`` is handed to both worker processes;
    their masks are AND-ed and written to ``mask_fifo_path``. Runs forever.

    :param pxl_model_path: model file name passed to the pixel worker
    :param blk_model_path: model file name passed to the block worker
    """
    # Start the two model worker processes.
    blk_cmd_queue, pxl_cmd_queue = Queue(maxsize=100), Queue(maxsize=100)
    blk_img_queue, pxl_img_queue = Queue(maxsize=100), Queue(maxsize=100)
    blk_msk_queue, pxl_msk_queue = Queue(maxsize=100), Queue(maxsize=100)
    blk_process = Process(target=block_model,
                          args=(blk_cmd_queue, blk_img_queue, blk_msk_queue, blk_model_path, ))
    pxl_process = Process(target=pixel_model,
                          args=(pxl_cmd_queue, pxl_img_queue, pxl_msk_queue, pxl_model_path, ))
    blk_process.start()
    pxl_process.start()

    # One frame is nrows * ncols * nbands float32 values (4 bytes each).
    total_len = nrows * ncols * nbands * 4
    for fifo_path in (img_fifo_path, mask_fifo_path):
        if not os.access(fifo_path, os.F_OK):
            os.mkfifo(fifo_path, 0o777)

    pending = b''
    while True:
        frame_bytes, pending = _read_frame(img_fifo_path, total_len, pending)
        # The wire layout is (rows, bands, cols); the detectors get (rows, cols, bands).
        img = np.frombuffer(frame_bytes, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
        print(f"get img shape {img.shape}")
        pxl_img_queue.put(img)
        blk_img_queue.put(img)
        pxl_msk = pxl_msk_queue.get()
        blk_msk = blk_msk_queue.get()
        mask = pxl_msk & blk_msk
        print(f"predict success get mask shape: {mask.shape}")
        # Write the mask out; loop in case a write is accepted only partially.
        payload = memoryview(mask.tobytes())
        fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
        try:
            while payload:
                written = os.write(fd_mask, payload)
                payload = payload[written:]
        finally:
            os.close(fd_mask)
def block_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, blk_model_path=blk_model_path):
    """Worker loop for the 8x8 block detector.

    Polls ``cmd_queue`` and ``img_queue`` until a ``'stop'`` command arrives.

    :param cmd_queue: commands; an ``int`` sets the rigor rate, ``'stop'`` ends
        the loop, any other string reloads the model
    :param img_queue: images to classify
    :param mask_queue: receives one mask per image taken from ``img_queue``
    :param blk_model_path: model file name, resolved under ``ROOT_DIR/models``
    """
    model_file = os.path.join(ROOT_DIR, "models", blk_model_path)
    blk_model = SpecDetector(model_file, blk_sz=8, channel_num=4)
    # Result is discarded — presumably a warm-up call before real frames arrive.
    _ = blk_model.predict(np.ones((nrows, ncols, nbands)))
    rigor_rate = 70
    while True:
        # deal with the cmd if cmd_queue is not empty
        if not cmd_queue.empty():
            cmd = cmd_queue.get()
            if isinstance(cmd, int):
                rigor_rate = cmd
            elif isinstance(cmd, str):
                if cmd == 'stop':
                    break
                try:
                    # Bind the reloaded detector to ``blk_model``. The original
                    # assigned it to ``blk_model_path``, which left the old model
                    # in use and replaced the path with a detector object.
                    # NOTE(review): the command string itself is not used as the
                    # new path, matching pixel_model — confirm that is intended.
                    blk_model = SpecDetector(model_file, blk_sz=8, channel_num=4)
                except Exception as e:
                    print(f"Load Model Failed! {e}")
        # deal with the img if img_queue is not empty
        if not img_queue.empty():
            img = img_queue.get()
            mask = blk_model.predict(img, rigor_rate)
            mask_queue.put(mask)
def pixel_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, pixel_model_path=pxl_model_path):
    """Worker loop for the per-pixel detector.

    Polls both queues until ``'stop'`` is received on ``cmd_queue``.

    :param cmd_queue: commands; an ``int`` replaces the rigor rate, ``'stop'``
        ends the loop, any other string rebuilds the detector from the same file
    :param img_queue: images to classify
    :param mask_queue: receives one mask per image taken from ``img_queue``
    :param pixel_model_path: model file name, resolved under ``ROOT_DIR/models``
    """
    def build_detector():
        location = os.path.join(ROOT_DIR, "models", pixel_model_path)
        return PixelWisedDetector(location, blk_sz=1, channel_num=4)

    detector = build_detector()
    # Run once on a dummy frame; the result is not used.
    detector.predict(np.ones((nrows, ncols, nbands)))
    rigor_rate = 70
    while True:
        # Handle at most one pending command per pass.
        if not cmd_queue.empty():
            command = cmd_queue.get()
            if isinstance(command, int):
                rigor_rate = command
            elif isinstance(command, str):
                if command == 'stop':
                    break
                try:
                    detector = build_detector()
                except Exception as e:
                    print(f"Load Model Failed! {e}")
        # Handle at most one pending image per pass.
        if img_queue.empty():
            continue
        frame = img_queue.get()
        mask_queue.put(detector.predict(frame, rigor_rate))
# Script entry point: run the dual-model FIFO service with the default model files.
if __name__ == '__main__':
main()

33
main.py
View File

@ -1,31 +1,44 @@
import os
import cv2
import numpy as np
from models import SpecDetector
from root_dir import ROOT_DIR
nrows, ncols, nbands = 600, 1024, 4
nrows, ncols, nbands = 256, 1024, 4
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
selected_model = "rf_8x8_c4_400_13.model"
selected_model = "rf_8x8_c4_185_sen32_4.model"
# Single-model FIFO service: read a frame, run SpecDetector on it, write the mask back.
# NOTE(review): this span appears to interleave the pre- and post-change versions of
# several statements (e.g. two predict calls on dummy arrays of different shapes, and
# two different read loops) — confirm against the real main.py before relying on it.
def main():
model_path = os.path.join(ROOT_DIR, "models", selected_model)
detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
_ = detector.predict(np.ones((600, 1024, 4)))
_ = detector.predict(np.ones((256, 1024, 4)))
total_len = nrows * ncols * nbands * 4
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
fd_img = os.open(img_fifo_path, os.O_RDONLY)
print("connect to fifo")
data = b''
while True:
data = os.read(fd_img, total_len)
print("get img")
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
# Read
fd_img = os.open(img_fifo_path, os.O_RDONLY)
while len(data) < total_len:
data += os.read(fd_img, total_len)
if len(data) > total_len:
data_total = data[:total_len]
data = data[total_len:]
else:
data_total = data
data = b''
os.close(fd_img)
# Recognise
img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
mask = detector.predict(img)
# Write out
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask.tobytes())
os.close(fd_mask)

View File

@ -2,13 +2,15 @@ import os
import pickle
import time
import cv2
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
nrows, ncols, nbands = 256, 1024, 4
def feature(x):
x = x.reshape((x.shape[0], -1))
@ -42,6 +44,31 @@ def train_rf_and_report(train_x, train_y, test_x, test_y,
return rfc
def train_t_and_report(train_x, train_y, test_x, test_y, save_path=None):
    """Train a decision tree, print evaluation reports and optionally pickle it.

    Besides the multi-class report, a merged report is printed in which labels
    0 and 1 become 0, labels above 1 become 2, and anything else stays 1.

    :param train_x: training features
    :param train_y: training labels
    :param test_x: test features
    :param test_y: test labels
    :param save_path: if not None, the fitted tree is pickled to this path
    :return: the fitted ``DecisionTreeClassifier``
    """
    def merge_labels(labels):
        # Collapse the multi-class labels into the coarse grouping described above.
        merged = np.ones_like(labels)
        merged[(labels == 0) | (labels == 1)] = 0
        merged[labels > 1] = 2
        return merged

    tree = DecisionTreeClassifier(random_state=42, class_weight={0: 10, 1: 10})
    tree = tree.fit(train_x, train_y)

    # Time only the prediction call, so the figure printed below is not
    # inflated by the label merging that follows.
    t1 = time.time()
    y_pred = tree.predict(test_x)
    predict_seconds = time.time() - t1

    y_pred_binary = merge_labels(y_pred)
    test_y_binary = merge_labels(test_y)

    print("预测时间:", predict_seconds)
    print("训练集acc" + str(accuracy_score(train_y, tree.predict(train_x))))
    # Reuse y_pred rather than predicting the test set a second time.
    print("测试集acc" + str(accuracy_score(test_y, y_pred)))
    print('-'*50)
    print('测试集报告\n' + str(classification_report(test_y, y_pred)))  # per-class report
    print('混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred)))  # per-class confusion matrix
    print('二分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary)))  # merged-label report
    print('二混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary)))  # merged-label confusion matrix
    if save_path is not None:
        with open(save_path, 'wb') as f:
            pickle.dump(tree, f)
    return tree
def evaluation_and_report(model, test_x, test_y):
t1 = time.time()
y_pred = model.predict(test_x)
@ -96,21 +123,21 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
"""
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x 1024 x num_channels)
;param data: image data, shape (num_rows x ncols x num_channels)
;param blk_sz: block size
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
"""
x_list = []
for i in range(0, 600 // blk_sz):
for j in range(0, 1024 // blk_sz):
for i in range(0, nrows // blk_sz):
for j in range(0, ncols // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
x_list.append(block_data)
return x_list
class SpecDetector(object):
def __init__(self, model_path, blk_sz=8, channel_num=4):
def __init__(self, model_path, blk_sz=8, channel_num=nbands):
self.blk_sz, self.channel_num = blk_sz, channel_num
if os.path.exists(model_path):
with open(model_path, "rb") as model_file:
@ -118,20 +145,22 @@ class SpecDetector(object):
else:
raise FileNotFoundError("Model File not found")
def predict(self, data):
def predict(self, data, rigor_rate=70):
blocks = split_x(data, blk_sz=self.blk_sz)
blocks = np.array(blocks)
features = feature(np.array(blocks))
y_pred = self.clf.predict(features)
y_pred_binary = np.ones_like(y_pred)
print("Spec Detector", rigor_rate)
y_pred = self.clf.predict_proba(features)
y_pred, y_prob = np.argmax(y_pred, axis=1), np.max(y_pred, axis=1)
y_pred_binary = np.zeros_like(y_pred)
# classes merge
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
y_pred_binary[((y_pred == 2) | (y_pred > 3)) & (y_prob > (100 - rigor_rate) / 100.0)] = 1
# transform to mask
mask = self.mask_transform(y_pred_binary, (1024, 600))
mask = self.mask_transform(y_pred_binary, (ncols, nrows))
return mask
def mask_transform(self, result, dst_size):
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
mask_size = nrows // self.blk_sz, ncols // self.blk_sz
mask = np.zeros(mask_size, dtype=np.uint8)
for idx, r in enumerate(result):
row, col = idx // mask_size[1], idx % mask_size[1]
@ -140,8 +169,34 @@ class SpecDetector(object):
return mask
class PixelWisedDetector(object):
    """Per-pixel classifier wrapper that turns an image into a binary mask."""

    def __init__(self, model_path, blk_sz=1, channel_num=nbands):
        """Load a pickled classifier.

        :param model_path: path of the pickled model file
        :param blk_sz: block size, stored on the instance
        :param channel_num: number of channels per pixel
        :raises FileNotFoundError: if ``model_path`` does not exist
        """
        self.blk_sz, self.channel_num = blk_sz, channel_num
        if not os.path.exists(model_path):
            raise FileNotFoundError("Model File not found")
        # NOTE(security): pickle.load executes code embedded in the file;
        # only load model files from a trusted location.
        with open(model_path, "rb") as model_file:
            self.clf = pickle.load(model_file)

    def predict(self, data, rigor_rate=70):
        """Classify every pixel and return a uint8 mask (1 = flagged pixel).

        A pixel is flagged when its most probable class is not 0, 1 or 3 and
        that probability exceeds ``(100 - rigor_rate) / 100``, the same
        thresholding ``SpecDetector.predict`` applies.

        :param data: image array reshaped here to ``(-1, channel_num)``
        :param rigor_rate: 0-100; higher values lower the probability threshold
        :return: mask of shape ``(nrows, ncols)``
        """
        features = data.reshape((-1, self.channel_num))
        print("pixel detector", rigor_rate)
        # scikit-learn's ``predict`` accepts only the feature matrix, so the
        # rigor rate is applied to the class probabilities instead of being
        # passed as an extra positional argument.
        proba = self.clf.predict_proba(features)
        best = np.argmax(proba, axis=1)
        y_prob = np.max(proba, axis=1)
        # Map probability columns back to class labels when the model exposes them.
        classes = getattr(self.clf, "classes_", None)
        y_pred = best if classes is None else np.asarray(classes)[best]
        # classes merge
        kept = (y_pred == 0) | (y_pred == 1) | (y_pred == 3)
        confident = y_prob > (100 - rigor_rate) / 100.0
        y_pred_binary = (~kept & confident).astype(np.uint8)
        # transform to mask
        return self.mask_transform(y_pred_binary)

    def mask_transform(self, result):
        """Reshape the flat per-pixel result into a ``(nrows, ncols)`` mask."""
        mask_size = (nrows, ncols)
        return result.reshape(mask_size)
class PcaSpecDetector(object):
def __init__(self, model_path, pca_path, blk_sz=8, channel_num=4):
def __init__(self, model_path, pca_path, blk_sz=8, channel_num=nbands):
self.blk_sz, self.channel_num = blk_sz, channel_num
if os.path.exists(model_path):
with open(model_path, "rb") as model_file:
@ -163,11 +218,11 @@ class PcaSpecDetector(object):
# classes merge
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
# transform to mask
mask = self.mask_transform(y_pred_binary, (1024, 600))
mask = self.mask_transform(y_pred_binary, (ncols, nrows))
return mask
def mask_transform(self, result, dst_size):
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
mask_size = nrows // self.blk_sz, ncols // self.blk_sz
mask = np.zeros(mask_size, dtype=np.uint8)
for idx, r in enumerate(result):
row, col = idx // mask_size[1], idx % mask_size[1]

61
test_files/dual_main_test.py Executable file
View File

@ -0,0 +1,61 @@
import glob
import os
import unittest
import cv2
import numpy as np
from utils import read_raw_file
nrows, ncols = 256, 1024
class DualMainTestCase(unittest.TestCase):
    """Manual end-to-end check against a running dual_main service."""

    def test_dual_main(self):
        """Send each raw file through the FIFOs and display the returned mask.

        Requires the service to be running and the hard-coded data directory
        to exist; each result is shown with OpenCV and waits for a key press.
        """
        test_img_dirs = '/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/*.raw'
        selected_bands = None
        img_fifo_path = "/tmp/dkimg.fifo"
        mask_fifo_path = "/tmp/dkmask.fifo"
        total_len = nrows * ncols  # one uint8 per mask pixel

        spectral_files = glob.glob(test_img_dirs)
        print("reading raw files ...")
        raw_files = [read_raw_file(file, selected_bands=selected_bands) for file in spectral_files]
        print("reading file success!")
        for fifo_path in (img_fifo_path, mask_fifo_path):
            if not os.access(fifo_path, os.F_OK):
                os.mkfifo(fifo_path, 0o777)
        data = b''
        for raw_file in raw_files:
            if raw_file.shape[0] > nrows:
                raw_file = raw_file[:nrows, ...]
            # Write out. read_raw_file returns (rows, cols, bands), but the
            # service reshapes the payload as (rows, bands, cols), so transpose
            # back before serialising — tobytes() always emits C order.
            print(f"send {raw_file.shape}")
            payload = raw_file.transpose(0, 2, 1).tobytes()
            fd_img = os.open(img_fifo_path, os.O_WRONLY)
            try:
                os.write(fd_img, payload)
            finally:
                os.close(fd_img)
            # Wait for the mask.
            fd_mask = os.open(mask_fifo_path, os.O_RDONLY)
            try:
                while len(data) < total_len:
                    chunk = os.read(fd_mask, total_len)
                    if not chunk:
                        # EOF: stop instead of spinning on empty reads.
                        break
                    data += chunk
            finally:
                os.close(fd_mask)
            self.assertGreaterEqual(len(data), total_len,
                                    "mask FIFO closed before a full mask arrived")
            data_total, data = data[:total_len], data[total_len:]
            mask = np.frombuffer(data_total, dtype=np.uint8).reshape((-1, ncols))
            # Display the mask over a three-band rendering of the frame.
            rgb_img = np.asarray(raw_file[..., [0, 2, 3]] * 255, dtype=np.uint8)
            mask_color = np.zeros_like(rgb_img)
            mask_color[mask > 0] = (0, 0, 255)
            combine = cv2.addWeighted(rgb_img, 1, mask_color, 0.5, 0)
            cv2.imshow("img", combine)
            cv2.waitKey(0)
# Run the test case when this file is executed directly.
if __name__ == '__main__':
unittest.main()

View File

@ -6,9 +6,12 @@ import os
import time
import matplotlib.pyplot as plt
import tqdm
from models import SpecDetector
nrows, ncols = 256, 1024
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
"""
@ -35,6 +38,8 @@ def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
:param sensitivity: 敏感度
:return:
"""
if (pixel_blk.shape[0] ==1) and (pixel_blk.shape[1] == 1):
return pixel_blk[0][0]
defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
for cls in defect_dict.keys()}
@ -55,7 +60,7 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
"""
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x 1024 x num_channels)
;param data: image data, shape (num_rows x ncols x num_channels)
;param labeled_img: RGB labeled img with respect to the image!
make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
;param blk_sz: block size
@ -71,8 +76,8 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
truth_map = np.all(labeled_img == color, axis=2)
class_img[truth_map] = class_idx
x_list, y_list = [], []
for i in range(0, 600 // blk_sz):
for j in range(0, 1024 // blk_sz):
for i in range(0, nrows // blk_sz):
for j in range(0, ncols // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
block_label = determine_class(block_label, sensitivity=sensitivity)
@ -90,14 +95,14 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
"""
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x 1024 x num_channels)
;param data: image data, shape (num_rows x ncols x num_channels)
;param blk_sz: block size
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
"""
x_list = []
for i in range(0, 600 // blk_sz):
for j in range(0, 1024 // blk_sz):
for i in range(0, nrows // blk_sz):
for j in range(0, ncols // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
x_list.append(block_data)
return x_list
@ -105,7 +110,6 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
def visualization_evaluation(detector, data_path, selected_bands=None):
selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
nrows, ncols = 600, 1024
image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
for idx, image_path in enumerate(image_paths):
with open(image_path, 'rb') as f:
@ -132,9 +136,9 @@ def visualization_evaluation(detector, data_path, selected_bands=None):
def visualization_y(y_list, k_size):
mask = np.zeros((600 // k_size, 1024 // k_size), dtype=np.uint8)
mask = np.zeros((nrows // k_size, ncols // k_size), dtype=np.uint8)
for idx, r in enumerate(y_list):
row, col = idx // (1024 // k_size), idx % (1024 // k_size)
row, col = idx // (ncols // k_size), idx % (ncols // k_size)
mask[row, col] = r
fig, axs = plt.subplots()
axs.imshow(mask)
@ -142,16 +146,23 @@ def visualization_y(y_list, k_size):
# Read a float32 raw file as (600, bands, ncols) and return it transposed to
# (600, ncols, bands); when selected_bands is given, only those indices of the
# last axis are kept.
# NOTE(review): the two consecutive ``data = ...`` assignments look like the pre- and
# post-change versions of the same line; only the second one takes effect.
def read_raw_file(file_name, selected_bands=None):
print(f"reading file {file_name}")
with open(file_name, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1)
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, ncols)).transpose(0, 2, 1)
if selected_bands is not None:
data = data[..., selected_bands]
return data
def write_raw_file(file_name, data: np.ndarray):
    """Write a 3-D array to ``file_name`` with its last two axes swapped.

    This is the inverse of the ``transpose(0, 2, 1)`` applied by
    ``read_raw_file``. The row count is taken from ``data`` itself: the earlier
    reshape to the module-level ``nrows`` never changed the bytes written, but
    raised for arrays with a different number of rows (``read_raw_file``
    always returns 600).

    :param file_name: destination path
    :param data: 3-D array; axis 1 and axis 2 are swapped on disk
    """
    swapped = np.ascontiguousarray(data.transpose(0, 2, 1))
    with open(file_name, 'wb') as f:
        f.write(swapped.tobytes())
# Read a float32 file as a (1, 448, ncols) array and return it transposed to (1, ncols, 448).
# NOTE(review): the two consecutive ``data = ...`` assignments look like the pre- and
# post-change versions of the same line; only the second one takes effect.
def read_black_and_white_file(file_name):
with open(file_name, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1)
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, ncols)).transpose(0, 2, 1)
return data
@ -166,8 +177,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
y_label = model.predict(data)
x_list, y_list = [], []
for i in range(0, 600 // blk_sz):
for j in range(0, 1024 // blk_sz):
for i in range(0, nrows // blk_sz):
for j in range(0, ncols // blk_sz):
if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
> 0:
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
@ -179,8 +190,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
target_class_left=None, ):
y_label = np.zeros((data.shape[0], data.shape[1]))
for i in range(0, 600):
for j in range(0, 1024):
for i in range(0, nrows):
for j in range(0, ncols):
if np.sum(np.sum(data[i, j])) >= light_threshold:
if j > split_line:
y_label[i, j] = target_class_right
@ -192,3 +203,21 @@ def generate_impurity_label(data, light_threshold, color_dict, split_line=0, tar
axs[1].matshow(data[..., 0])
plt.show()
return pic
def file_transform(input_dir, output_dir, selected_bands=None):
    """Re-write every ``.raw`` file of ``input_dir`` into ``output_dir``.

    Each file is loaded with ``read_raw_file`` (optionally keeping only
    ``selected_bands``) and saved under the same name with ``write_raw_file``.

    :param input_dir: directory scanned (non-recursively) for ``.raw`` files
    :param output_dir: destination directory, created if missing
    :param selected_bands: band indices forwarded to ``read_raw_file``
    """
    # List the inputs before the output directory is created.
    raw_names = list(filter(lambda name: name.endswith('.raw'), os.listdir(input_dir)))
    os.makedirs(output_dir, mode=0o777, exist_ok=True)
    for raw_name in raw_names:
        source = os.path.join(input_dir, raw_name)
        target = os.path.join(output_dir, raw_name)
        cube = read_raw_file(source, selected_bands=selected_bands)
        write_raw_file(target, cube)
# Script entry point: copy every .raw file from the source directory to the "_cut"
# directory, keeping only the four listed band indices.
if __name__ == '__main__':
selected_bands = [127, 201, 202, 294]
input_dir, output_dir = r"/Volumes/LENOVO_USB_HDD/zhouchao/616/",\
r"/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/"
file_transform(input_dir=input_dir, output_dir=output_dir, selected_bands=selected_bands)