mirror of
https://github.com/NanjingForestryUniversity/tobacoo-industry.git
synced 2025-11-08 22:33:52 +00:00
add dual_main.py and dual_main_test.py
This commit is contained in:
parent
52c4feee09
commit
9e1dad8637
370
07_pixelwised_detection.ipynb
Normal file
370
07_pixelwised_detection.ipynb
Normal file
File diff suppressed because one or more lines are too long
110
dual_main.py
Executable file
110
dual_main.py
Executable file
@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from models import SpecDetector, PixelWisedDetector
|
||||||
|
from root_dir import ROOT_DIR
|
||||||
|
from multiprocessing import Process, Queue
|
||||||
|
|
||||||
|
nrows, ncols, nbands = 256, 1024, 4
|
||||||
|
img_fifo_path = "/tmp/dkimg.fifo"
|
||||||
|
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||||
|
cmd_fifo_path = '/tmp/tobacco_cmd.fifo'
|
||||||
|
|
||||||
|
pxl_model_path = "rf_1x1_c4_1_sen1_4.model"
|
||||||
|
blk_model_path = "rf_8x8_c4_185_sen32_4.model"
|
||||||
|
|
||||||
|
|
||||||
|
def main(pxl_model_path=pxl_model_path, blk_model_path=blk_model_path):
|
||||||
|
# 启动两个模型线程
|
||||||
|
blk_cmd_queue, pxl_cmd_queue = Queue(maxsize=100), Queue(maxsize=100)
|
||||||
|
blk_img_queue, pxl_img_queue = Queue(maxsize=100), Queue(maxsize=100)
|
||||||
|
blk_msk_queue, pxl_msk_queue = Queue(maxsize=100), Queue(maxsize=100)
|
||||||
|
blk_process = Process(target=block_model, args=(blk_cmd_queue, blk_img_queue, blk_msk_queue, blk_model_path, ))
|
||||||
|
pxl_process = Process(target=pixel_model, args=(pxl_cmd_queue, pxl_img_queue, pxl_msk_queue, pxl_model_path, ))
|
||||||
|
blk_process.start()
|
||||||
|
pxl_process.start()
|
||||||
|
total_len = nrows * ncols * nbands * 4
|
||||||
|
if not os.access(img_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(img_fifo_path, 0o777)
|
||||||
|
if not os.access(mask_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(mask_fifo_path, 0o777)
|
||||||
|
data = b''
|
||||||
|
while True:
|
||||||
|
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||||
|
while len(data) < total_len:
|
||||||
|
data += os.read(fd_img, total_len)
|
||||||
|
if len(data) > total_len:
|
||||||
|
data_total = data[:total_len]
|
||||||
|
data = data[total_len:]
|
||||||
|
else:
|
||||||
|
data_total = data
|
||||||
|
data = b''
|
||||||
|
os.close(fd_img)
|
||||||
|
|
||||||
|
img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
|
||||||
|
print(f"get img shape {img.shape}")
|
||||||
|
pxl_img_queue.put(img)
|
||||||
|
blk_img_queue.put(img)
|
||||||
|
pxl_msk = pxl_msk_queue.get()
|
||||||
|
blk_msk = blk_msk_queue.get()
|
||||||
|
mask = pxl_msk & blk_msk
|
||||||
|
print(f"predict success get mask shape: {mask.shape}")
|
||||||
|
# 写出
|
||||||
|
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||||
|
os.write(fd_mask, mask.tobytes())
|
||||||
|
os.close(fd_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def block_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, blk_model_path=blk_model_path):
|
||||||
|
blk_model = SpecDetector(os.path.join(ROOT_DIR, "models", blk_model_path), blk_sz=8, channel_num=4)
|
||||||
|
_ = blk_model.predict(np.ones((nrows, ncols, nbands)))
|
||||||
|
rigor_rate = 70
|
||||||
|
while True:
|
||||||
|
# deal with the cmd if cmd_queue is not empty
|
||||||
|
if not cmd_queue.empty():
|
||||||
|
cmd = cmd_queue.get()
|
||||||
|
if isinstance(cmd, int):
|
||||||
|
rigor_rate = cmd
|
||||||
|
elif isinstance(cmd, str):
|
||||||
|
if cmd == 'stop':
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
blk_model_path = SpecDetector(os.path.join(ROOT_DIR, "models", blk_model_path),
|
||||||
|
blk_sz=8, channel_num=4)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Load Model Failed! {e}")
|
||||||
|
# deal with the img if img_queue is not empty
|
||||||
|
if not img_queue.empty():
|
||||||
|
img = img_queue.get()
|
||||||
|
mask = blk_model.predict(img, rigor_rate)
|
||||||
|
mask_queue.put(mask)
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, pixel_model_path=pxl_model_path):
|
||||||
|
pixel_model = PixelWisedDetector(os.path.join(ROOT_DIR, "models", pixel_model_path), blk_sz=1, channel_num=4)
|
||||||
|
_ = pixel_model.predict(np.ones((nrows, ncols, nbands)))
|
||||||
|
rigor_rate = 70
|
||||||
|
while True:
|
||||||
|
# deal with the cmd if cmd_queue is not empty
|
||||||
|
if not cmd_queue.empty():
|
||||||
|
cmd = cmd_queue.get()
|
||||||
|
if isinstance(cmd, int):
|
||||||
|
rigor_rate = cmd
|
||||||
|
elif isinstance(cmd, str):
|
||||||
|
if cmd == 'stop':
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
pixel_model = PixelWisedDetector(os.path.join(ROOT_DIR, "models", pixel_model_path),
|
||||||
|
blk_sz=1, channel_num=4)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Load Model Failed! {e}")
|
||||||
|
# deal with the img if img_queue is not empty
|
||||||
|
if not img_queue.empty():
|
||||||
|
img = img_queue.get()
|
||||||
|
mask = pixel_model.predict(img, rigor_rate)
|
||||||
|
mask_queue.put(mask)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
33
main.py
33
main.py
@ -1,31 +1,44 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from models import SpecDetector
|
from models import SpecDetector
|
||||||
from root_dir import ROOT_DIR
|
from root_dir import ROOT_DIR
|
||||||
|
|
||||||
nrows, ncols, nbands = 600, 1024, 4
|
nrows, ncols, nbands = 256, 1024, 4
|
||||||
img_fifo_path = "/tmp/dkimg.fifo"
|
img_fifo_path = "/tmp/dkimg.fifo"
|
||||||
mask_fifo_path = "/tmp/dkmask.fifo"
|
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||||
selected_model = "rf_8x8_c4_400_13.model"
|
selected_model = "rf_8x8_c4_185_sen32_4.model"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model_path = os.path.join(ROOT_DIR, "models", selected_model)
|
model_path = os.path.join(ROOT_DIR, "models", selected_model)
|
||||||
detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
|
detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
|
||||||
_ = detector.predict(np.ones((600, 1024, 4)))
|
_ = detector.predict(np.ones((256, 1024, 4)))
|
||||||
total_len = nrows * ncols * nbands * 4
|
total_len = nrows * ncols * nbands * 4
|
||||||
|
|
||||||
if not os.access(img_fifo_path, os.F_OK):
|
if not os.access(img_fifo_path, os.F_OK):
|
||||||
os.mkfifo(img_fifo_path, 0o777)
|
os.mkfifo(img_fifo_path, 0o777)
|
||||||
if not os.access(mask_fifo_path, os.F_OK):
|
if not os.access(mask_fifo_path, os.F_OK):
|
||||||
os.mkfifo(mask_fifo_path, 0o777)
|
os.mkfifo(mask_fifo_path, 0o777)
|
||||||
|
data = b''
|
||||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
|
||||||
print("connect to fifo")
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
data = os.read(fd_img, total_len)
|
# 读取
|
||||||
print("get img")
|
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||||
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
|
while len(data) < total_len:
|
||||||
|
data += os.read(fd_img, total_len)
|
||||||
|
if len(data) > total_len:
|
||||||
|
data_total = data[:total_len]
|
||||||
|
data = data[total_len:]
|
||||||
|
else:
|
||||||
|
data_total = data
|
||||||
|
data = b''
|
||||||
|
|
||||||
|
os.close(fd_img)
|
||||||
|
# 识别
|
||||||
|
img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
|
||||||
mask = detector.predict(img)
|
mask = detector.predict(img)
|
||||||
|
# 写出
|
||||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||||
os.write(fd_mask, mask.tobytes())
|
os.write(fd_mask, mask.tobytes())
|
||||||
os.close(fd_mask)
|
os.close(fd_mask)
|
||||||
|
|||||||
83
models.py
83
models.py
@ -2,13 +2,15 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
||||||
|
|
||||||
|
nrows, ncols, nbands = 256, 1024, 4
|
||||||
|
|
||||||
|
|
||||||
def feature(x):
|
def feature(x):
|
||||||
x = x.reshape((x.shape[0], -1))
|
x = x.reshape((x.shape[0], -1))
|
||||||
@ -42,6 +44,31 @@ def train_rf_and_report(train_x, train_y, test_x, test_y,
|
|||||||
return rfc
|
return rfc
|
||||||
|
|
||||||
|
|
||||||
|
def train_t_and_report(train_x, train_y, test_x, test_y, save_path=None):
|
||||||
|
rfc = DecisionTreeClassifier(random_state=42, class_weight={0: 10, 1: 10})
|
||||||
|
rfc = rfc.fit(train_x, train_y)
|
||||||
|
t1 = time.time()
|
||||||
|
y_pred = rfc.predict(test_x)
|
||||||
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||||
|
y_pred_binary[(y_pred > 1)] = 2
|
||||||
|
test_y_binary = np.ones_like(test_y)
|
||||||
|
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||||
|
test_y_binary[(test_y > 1)] = 2
|
||||||
|
print("预测时间:", time.time() - t1)
|
||||||
|
print("训练集acc:" + str(accuracy_score(train_y, rfc.predict(train_x))))
|
||||||
|
print("测试集acc:" + str(accuracy_score(test_y, rfc.predict(test_x))))
|
||||||
|
print('-'*50)
|
||||||
|
print('测试集报告\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||||
|
print('混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
print('二分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||||
|
print('二混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
if save_path is not None:
|
||||||
|
with open(save_path, 'wb') as f:
|
||||||
|
pickle.dump(rfc, f)
|
||||||
|
return rfc
|
||||||
|
|
||||||
|
|
||||||
def evaluation_and_report(model, test_x, test_y):
|
def evaluation_and_report(model, test_x, test_y):
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
y_pred = model.predict(test_x)
|
y_pred = model.predict(test_x)
|
||||||
@ -96,21 +123,21 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
|
|||||||
"""
|
"""
|
||||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||||
|
|
||||||
;param data: image data, shape (num_rows x 1024 x num_channels)
|
;param data: image data, shape (num_rows x ncols x num_channels)
|
||||||
;param blk_sz: block size
|
;param blk_sz: block size
|
||||||
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||||
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||||
"""
|
"""
|
||||||
x_list = []
|
x_list = []
|
||||||
for i in range(0, 600 // blk_sz):
|
for i in range(0, nrows // blk_sz):
|
||||||
for j in range(0, 1024 // blk_sz):
|
for j in range(0, ncols // blk_sz):
|
||||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
x_list.append(block_data)
|
x_list.append(block_data)
|
||||||
return x_list
|
return x_list
|
||||||
|
|
||||||
|
|
||||||
class SpecDetector(object):
|
class SpecDetector(object):
|
||||||
def __init__(self, model_path, blk_sz=8, channel_num=4):
|
def __init__(self, model_path, blk_sz=8, channel_num=nbands):
|
||||||
self.blk_sz, self.channel_num = blk_sz, channel_num
|
self.blk_sz, self.channel_num = blk_sz, channel_num
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
with open(model_path, "rb") as model_file:
|
with open(model_path, "rb") as model_file:
|
||||||
@ -118,20 +145,22 @@ class SpecDetector(object):
|
|||||||
else:
|
else:
|
||||||
raise FileNotFoundError("Model File not found")
|
raise FileNotFoundError("Model File not found")
|
||||||
|
|
||||||
def predict(self, data):
|
def predict(self, data, rigor_rate=70):
|
||||||
blocks = split_x(data, blk_sz=self.blk_sz)
|
blocks = split_x(data, blk_sz=self.blk_sz)
|
||||||
blocks = np.array(blocks)
|
blocks = np.array(blocks)
|
||||||
features = feature(np.array(blocks))
|
features = feature(np.array(blocks))
|
||||||
y_pred = self.clf.predict(features)
|
print("Spec Detector", rigor_rate)
|
||||||
y_pred_binary = np.ones_like(y_pred)
|
y_pred = self.clf.predict_proba(features)
|
||||||
|
y_pred, y_prob = np.argmax(y_pred, axis=1), np.max(y_pred, axis=1)
|
||||||
|
y_pred_binary = np.zeros_like(y_pred)
|
||||||
# classes merge
|
# classes merge
|
||||||
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
y_pred_binary[((y_pred == 2) | (y_pred > 3)) & (y_prob > (100 - rigor_rate) / 100.0)] = 1
|
||||||
# transform to mask
|
# transform to mask
|
||||||
mask = self.mask_transform(y_pred_binary, (1024, 600))
|
mask = self.mask_transform(y_pred_binary, (ncols, nrows))
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def mask_transform(self, result, dst_size):
|
def mask_transform(self, result, dst_size):
|
||||||
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
|
mask_size = nrows // self.blk_sz, ncols // self.blk_sz
|
||||||
mask = np.zeros(mask_size, dtype=np.uint8)
|
mask = np.zeros(mask_size, dtype=np.uint8)
|
||||||
for idx, r in enumerate(result):
|
for idx, r in enumerate(result):
|
||||||
row, col = idx // mask_size[1], idx % mask_size[1]
|
row, col = idx // mask_size[1], idx % mask_size[1]
|
||||||
@ -140,8 +169,34 @@ class SpecDetector(object):
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
class PixelWisedDetector(object):
|
||||||
|
def __init__(self, model_path, blk_sz=1, channel_num=nbands):
|
||||||
|
self.blk_sz, self.channel_num = blk_sz, channel_num
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
with open(model_path, "rb") as model_file:
|
||||||
|
self.clf = pickle.load(model_file)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError("Model File not found")
|
||||||
|
|
||||||
|
def predict(self, data, rigor_rate=70):
|
||||||
|
features = data.reshape((-1, self.channel_num))
|
||||||
|
y_pred = self.clf.predict(features, rigor_rate)
|
||||||
|
y_pred_binary = np.ones_like(y_pred, dtype=np.uint8)
|
||||||
|
print("pixel detector", rigor_rate)
|
||||||
|
# classes merge
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
||||||
|
# transform to mask
|
||||||
|
mask = self.mask_transform(y_pred_binary)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def mask_transform(self, result):
|
||||||
|
mask_size = (nrows, ncols)
|
||||||
|
mask = result.reshape(mask_size)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class PcaSpecDetector(object):
|
class PcaSpecDetector(object):
|
||||||
def __init__(self, model_path, pca_path, blk_sz=8, channel_num=4):
|
def __init__(self, model_path, pca_path, blk_sz=8, channel_num=nbands):
|
||||||
self.blk_sz, self.channel_num = blk_sz, channel_num
|
self.blk_sz, self.channel_num = blk_sz, channel_num
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
with open(model_path, "rb") as model_file:
|
with open(model_path, "rb") as model_file:
|
||||||
@ -163,11 +218,11 @@ class PcaSpecDetector(object):
|
|||||||
# classes merge
|
# classes merge
|
||||||
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
||||||
# transform to mask
|
# transform to mask
|
||||||
mask = self.mask_transform(y_pred_binary, (1024, 600))
|
mask = self.mask_transform(y_pred_binary, (ncols, nrows))
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def mask_transform(self, result, dst_size):
|
def mask_transform(self, result, dst_size):
|
||||||
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
|
mask_size = nrows // self.blk_sz, ncols // self.blk_sz
|
||||||
mask = np.zeros(mask_size, dtype=np.uint8)
|
mask = np.zeros(mask_size, dtype=np.uint8)
|
||||||
for idx, r in enumerate(result):
|
for idx, r in enumerate(result):
|
||||||
row, col = idx // mask_size[1], idx % mask_size[1]
|
row, col = idx // mask_size[1], idx % mask_size[1]
|
||||||
|
|||||||
61
test_files/dual_main_test.py
Executable file
61
test_files/dual_main_test.py
Executable file
@ -0,0 +1,61 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from utils import read_raw_file
|
||||||
|
|
||||||
|
nrows, ncols = 256, 1024
|
||||||
|
|
||||||
|
|
||||||
|
class DualMainTestCase(unittest.TestCase):
|
||||||
|
def test_dual_main(self):
|
||||||
|
test_img_dirs = '/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/*.raw'
|
||||||
|
selected_bands = None
|
||||||
|
img_fifo_path = "/tmp/dkimg.fifo"
|
||||||
|
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||||
|
|
||||||
|
total_len = nrows * ncols
|
||||||
|
spectral_files = glob.glob(test_img_dirs)
|
||||||
|
print("reading raw files ...")
|
||||||
|
raw_files = [read_raw_file(file, selected_bands=selected_bands) for file in spectral_files]
|
||||||
|
print("reading file success!")
|
||||||
|
if not os.access(img_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(img_fifo_path, 0o777)
|
||||||
|
if not os.access(mask_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(mask_fifo_path, 0o777)
|
||||||
|
data = b''
|
||||||
|
for raw_file in raw_files:
|
||||||
|
if raw_file.shape[0] > nrows:
|
||||||
|
raw_file = raw_file[:nrows, ...]
|
||||||
|
# 写出
|
||||||
|
print(f"send {raw_file.shape}")
|
||||||
|
fd_img = os.open(img_fifo_path, os.O_WRONLY)
|
||||||
|
os.write(fd_img, raw_file.tobytes())
|
||||||
|
os.close(fd_img)
|
||||||
|
# 等待
|
||||||
|
fd_mask = os.open(mask_fifo_path, os.O_RDONLY)
|
||||||
|
while len(data) < total_len:
|
||||||
|
data += os.read(fd_mask, total_len)
|
||||||
|
if len(data) > total_len:
|
||||||
|
data_total = data[:total_len]
|
||||||
|
data = data[total_len:]
|
||||||
|
else:
|
||||||
|
data_total = data
|
||||||
|
data = b''
|
||||||
|
os.close(fd_mask)
|
||||||
|
mask = np.frombuffer(data_total, dtype=np.uint8).reshape((-1, ncols))
|
||||||
|
|
||||||
|
# 显示
|
||||||
|
rgb_img = np.asarray(raw_file[..., [0, 2, 3]] * 255, dtype=np.uint8)
|
||||||
|
mask_color = np.zeros_like(rgb_img)
|
||||||
|
mask_color[mask > 0] = (0, 0, 255)
|
||||||
|
combine = cv2.addWeighted(rgb_img, 1, mask_color, 0.5, 0)
|
||||||
|
cv2.imshow("img", combine)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
59
utils.py
59
utils.py
@ -6,9 +6,12 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from models import SpecDetector
|
from models import SpecDetector
|
||||||
|
|
||||||
|
nrows, ncols = 256, 1024
|
||||||
|
|
||||||
|
|
||||||
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
|
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
|
||||||
"""
|
"""
|
||||||
@ -35,6 +38,8 @@ def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
|
|||||||
:param sensitivity: 敏感度
|
:param sensitivity: 敏感度
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if (pixel_blk.shape[0] ==1) and (pixel_blk.shape[1] == 1):
|
||||||
|
return pixel_blk[0][0]
|
||||||
defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
|
defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
|
||||||
color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
|
color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
|
||||||
for cls in defect_dict.keys()}
|
for cls in defect_dict.keys()}
|
||||||
@ -55,7 +60,7 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
|
|||||||
"""
|
"""
|
||||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||||
|
|
||||||
;param data: image data, shape (num_rows x 1024 x num_channels)
|
;param data: image data, shape (num_rows x ncols x num_channels)
|
||||||
;param labeled_img: RGB labeled img with respect to the image!
|
;param labeled_img: RGB labeled img with respect to the image!
|
||||||
make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
|
make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
|
||||||
;param blk_sz: block size
|
;param blk_sz: block size
|
||||||
@ -71,8 +76,8 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
|
|||||||
truth_map = np.all(labeled_img == color, axis=2)
|
truth_map = np.all(labeled_img == color, axis=2)
|
||||||
class_img[truth_map] = class_idx
|
class_img[truth_map] = class_idx
|
||||||
x_list, y_list = [], []
|
x_list, y_list = [], []
|
||||||
for i in range(0, 600 // blk_sz):
|
for i in range(0, nrows // blk_sz):
|
||||||
for j in range(0, 1024 // blk_sz):
|
for j in range(0, ncols // blk_sz):
|
||||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
block_label = determine_class(block_label, sensitivity=sensitivity)
|
block_label = determine_class(block_label, sensitivity=sensitivity)
|
||||||
@ -90,14 +95,14 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
|
|||||||
"""
|
"""
|
||||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||||
|
|
||||||
;param data: image data, shape (num_rows x 1024 x num_channels)
|
;param data: image data, shape (num_rows x ncols x num_channels)
|
||||||
;param blk_sz: block size
|
;param blk_sz: block size
|
||||||
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||||
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||||
"""
|
"""
|
||||||
x_list = []
|
x_list = []
|
||||||
for i in range(0, 600 // blk_sz):
|
for i in range(0, nrows // blk_sz):
|
||||||
for j in range(0, 1024 // blk_sz):
|
for j in range(0, ncols // blk_sz):
|
||||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
x_list.append(block_data)
|
x_list.append(block_data)
|
||||||
return x_list
|
return x_list
|
||||||
@ -105,7 +110,6 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
|
|||||||
|
|
||||||
def visualization_evaluation(detector, data_path, selected_bands=None):
|
def visualization_evaluation(detector, data_path, selected_bands=None):
|
||||||
selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
|
selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
|
||||||
nrows, ncols = 600, 1024
|
|
||||||
image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
|
image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
|
||||||
for idx, image_path in enumerate(image_paths):
|
for idx, image_path in enumerate(image_paths):
|
||||||
with open(image_path, 'rb') as f:
|
with open(image_path, 'rb') as f:
|
||||||
@ -132,9 +136,9 @@ def visualization_evaluation(detector, data_path, selected_bands=None):
|
|||||||
|
|
||||||
|
|
||||||
def visualization_y(y_list, k_size):
|
def visualization_y(y_list, k_size):
|
||||||
mask = np.zeros((600 // k_size, 1024 // k_size), dtype=np.uint8)
|
mask = np.zeros((nrows // k_size, ncols // k_size), dtype=np.uint8)
|
||||||
for idx, r in enumerate(y_list):
|
for idx, r in enumerate(y_list):
|
||||||
row, col = idx // (1024 // k_size), idx % (1024 // k_size)
|
row, col = idx // (ncols // k_size), idx % (ncols // k_size)
|
||||||
mask[row, col] = r
|
mask[row, col] = r
|
||||||
fig, axs = plt.subplots()
|
fig, axs = plt.subplots()
|
||||||
axs.imshow(mask)
|
axs.imshow(mask)
|
||||||
@ -142,16 +146,23 @@ def visualization_y(y_list, k_size):
|
|||||||
|
|
||||||
|
|
||||||
def read_raw_file(file_name, selected_bands=None):
|
def read_raw_file(file_name, selected_bands=None):
|
||||||
|
print(f"reading file {file_name}")
|
||||||
with open(file_name, "rb") as f:
|
with open(file_name, "rb") as f:
|
||||||
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1)
|
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, ncols)).transpose(0, 2, 1)
|
||||||
if selected_bands is not None:
|
if selected_bands is not None:
|
||||||
data = data[..., selected_bands]
|
data = data[..., selected_bands]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def write_raw_file(file_name, data: np.ndarray):
|
||||||
|
data = data.transpose(0, 2, 1).reshape((nrows, -1, ncols))
|
||||||
|
with open(file_name, 'wb') as f:
|
||||||
|
f.write(data.tobytes())
|
||||||
|
|
||||||
|
|
||||||
def read_black_and_white_file(file_name):
|
def read_black_and_white_file(file_name):
|
||||||
with open(file_name, "rb") as f:
|
with open(file_name, "rb") as f:
|
||||||
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1)
|
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, ncols)).transpose(0, 2, 1)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@ -166,8 +177,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
|
|||||||
model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
|
model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
|
||||||
y_label = model.predict(data)
|
y_label = model.predict(data)
|
||||||
x_list, y_list = [], []
|
x_list, y_list = [], []
|
||||||
for i in range(0, 600 // blk_sz):
|
for i in range(0, nrows // blk_sz):
|
||||||
for j in range(0, 1024 // blk_sz):
|
for j in range(0, ncols // blk_sz):
|
||||||
if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
|
if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
|
||||||
> 0:
|
> 0:
|
||||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
@ -179,8 +190,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
|
|||||||
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
|
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
|
||||||
target_class_left=None, ):
|
target_class_left=None, ):
|
||||||
y_label = np.zeros((data.shape[0], data.shape[1]))
|
y_label = np.zeros((data.shape[0], data.shape[1]))
|
||||||
for i in range(0, 600):
|
for i in range(0, nrows):
|
||||||
for j in range(0, 1024):
|
for j in range(0, ncols):
|
||||||
if np.sum(np.sum(data[i, j])) >= light_threshold:
|
if np.sum(np.sum(data[i, j])) >= light_threshold:
|
||||||
if j > split_line:
|
if j > split_line:
|
||||||
y_label[i, j] = target_class_right
|
y_label[i, j] = target_class_right
|
||||||
@ -192,3 +203,21 @@ def generate_impurity_label(data, light_threshold, color_dict, split_line=0, tar
|
|||||||
axs[1].matshow(data[..., 0])
|
axs[1].matshow(data[..., 0])
|
||||||
plt.show()
|
plt.show()
|
||||||
return pic
|
return pic
|
||||||
|
|
||||||
|
|
||||||
|
def file_transform(input_dir, output_dir, selected_bands=None):
|
||||||
|
files = os.listdir(input_dir)
|
||||||
|
filtered_files = [file for file in files if file.endswith('.raw')]
|
||||||
|
os.makedirs(output_dir, mode=0o777, exist_ok=True)
|
||||||
|
for file_path in filtered_files:
|
||||||
|
input_path = os.path.join(input_dir, file_path)
|
||||||
|
output_path = os.path.join(output_dir, file_path)
|
||||||
|
data = read_raw_file(input_path, selected_bands=selected_bands)
|
||||||
|
write_raw_file(output_path, data)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
selected_bands = [127, 201, 202, 294]
|
||||||
|
input_dir, output_dir = r"/Volumes/LENOVO_USB_HDD/zhouchao/616/",\
|
||||||
|
r"/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/"
|
||||||
|
file_transform(input_dir=input_dir, output_dir=output_dir, selected_bands=selected_bands)
|
||||||
Loading…
Reference in New Issue
Block a user