add dual_main.py and dual_main_test.py

This commit is contained in:
karllzy 2022-06-23 12:33:27 +08:00
parent 52c4feee09
commit 9e1dad8637
6 changed files with 677 additions and 39 deletions

File diff suppressed because one or more lines are too long

110
dual_main.py Executable file
View File

@ -0,0 +1,110 @@
import os
import numpy as np
from models import SpecDetector, PixelWisedDetector
from root_dir import ROOT_DIR
from multiprocessing import Process, Queue
nrows, ncols, nbands = 256, 1024, 4
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
cmd_fifo_path = '/tmp/tobacco_cmd.fifo'
pxl_model_path = "rf_1x1_c4_1_sen1_4.model"
blk_model_path = "rf_8x8_c4_185_sen32_4.model"
def main(pxl_model_path=pxl_model_path, blk_model_path=blk_model_path):
# 启动两个模型线程
blk_cmd_queue, pxl_cmd_queue = Queue(maxsize=100), Queue(maxsize=100)
blk_img_queue, pxl_img_queue = Queue(maxsize=100), Queue(maxsize=100)
blk_msk_queue, pxl_msk_queue = Queue(maxsize=100), Queue(maxsize=100)
blk_process = Process(target=block_model, args=(blk_cmd_queue, blk_img_queue, blk_msk_queue, blk_model_path, ))
pxl_process = Process(target=pixel_model, args=(pxl_cmd_queue, pxl_img_queue, pxl_msk_queue, pxl_model_path, ))
blk_process.start()
pxl_process.start()
total_len = nrows * ncols * nbands * 4
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
data = b''
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
while len(data) < total_len:
data += os.read(fd_img, total_len)
if len(data) > total_len:
data_total = data[:total_len]
data = data[total_len:]
else:
data_total = data
data = b''
os.close(fd_img)
img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
print(f"get img shape {img.shape}")
pxl_img_queue.put(img)
blk_img_queue.put(img)
pxl_msk = pxl_msk_queue.get()
blk_msk = blk_msk_queue.get()
mask = pxl_msk & blk_msk
print(f"predict success get mask shape: {mask.shape}")
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask.tobytes())
os.close(fd_mask)
def block_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, blk_model_path=blk_model_path):
blk_model = SpecDetector(os.path.join(ROOT_DIR, "models", blk_model_path), blk_sz=8, channel_num=4)
_ = blk_model.predict(np.ones((nrows, ncols, nbands)))
rigor_rate = 70
while True:
# deal with the cmd if cmd_queue is not empty
if not cmd_queue.empty():
cmd = cmd_queue.get()
if isinstance(cmd, int):
rigor_rate = cmd
elif isinstance(cmd, str):
if cmd == 'stop':
break
else:
try:
blk_model_path = SpecDetector(os.path.join(ROOT_DIR, "models", blk_model_path),
blk_sz=8, channel_num=4)
except Exception as e:
print(f"Load Model Failed! {e}")
# deal with the img if img_queue is not empty
if not img_queue.empty():
img = img_queue.get()
mask = blk_model.predict(img, rigor_rate)
mask_queue.put(mask)
def pixel_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, pixel_model_path=pxl_model_path):
pixel_model = PixelWisedDetector(os.path.join(ROOT_DIR, "models", pixel_model_path), blk_sz=1, channel_num=4)
_ = pixel_model.predict(np.ones((nrows, ncols, nbands)))
rigor_rate = 70
while True:
# deal with the cmd if cmd_queue is not empty
if not cmd_queue.empty():
cmd = cmd_queue.get()
if isinstance(cmd, int):
rigor_rate = cmd
elif isinstance(cmd, str):
if cmd == 'stop':
break
else:
try:
pixel_model = PixelWisedDetector(os.path.join(ROOT_DIR, "models", pixel_model_path),
blk_sz=1, channel_num=4)
except Exception as e:
print(f"Load Model Failed! {e}")
# deal with the img if img_queue is not empty
if not img_queue.empty():
img = img_queue.get()
mask = pixel_model.predict(img, rigor_rate)
mask_queue.put(mask)
if __name__ == '__main__':
main()

33
main.py
View File

@ -1,31 +1,44 @@
import os import os
import cv2
import numpy as np import numpy as np
from models import SpecDetector from models import SpecDetector
from root_dir import ROOT_DIR from root_dir import ROOT_DIR
nrows, ncols, nbands = 600, 1024, 4 nrows, ncols, nbands = 256, 1024, 4
img_fifo_path = "/tmp/dkimg.fifo" img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo" mask_fifo_path = "/tmp/dkmask.fifo"
selected_model = "rf_8x8_c4_400_13.model" selected_model = "rf_8x8_c4_185_sen32_4.model"
def main(): def main():
model_path = os.path.join(ROOT_DIR, "models", selected_model) model_path = os.path.join(ROOT_DIR, "models", selected_model)
detector = SpecDetector(model_path, blk_sz=8, channel_num=4) detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
_ = detector.predict(np.ones((600, 1024, 4))) _ = detector.predict(np.ones((256, 1024, 4)))
total_len = nrows * ncols * nbands * 4 total_len = nrows * ncols * nbands * 4
if not os.access(img_fifo_path, os.F_OK): if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777) os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK): if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777) os.mkfifo(mask_fifo_path, 0o777)
data = b''
fd_img = os.open(img_fifo_path, os.O_RDONLY)
print("connect to fifo")
while True: while True:
data = os.read(fd_img, total_len) # 读取
print("get img") fd_img = os.open(img_fifo_path, os.O_RDONLY)
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1) while len(data) < total_len:
data += os.read(fd_img, total_len)
if len(data) > total_len:
data_total = data[:total_len]
data = data[total_len:]
else:
data_total = data
data = b''
os.close(fd_img)
# 识别
img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
mask = detector.predict(img) mask = detector.predict(img)
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY) fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask.tobytes()) os.write(fd_mask, mask.tobytes())
os.close(fd_mask) os.close(fd_mask)

View File

@ -2,13 +2,15 @@ import os
import pickle import pickle
import time import time
import cv2
import numpy as np import numpy as np
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
nrows, ncols, nbands = 256, 1024, 4
def feature(x): def feature(x):
x = x.reshape((x.shape[0], -1)) x = x.reshape((x.shape[0], -1))
@ -42,6 +44,31 @@ def train_rf_and_report(train_x, train_y, test_x, test_y,
return rfc return rfc
def train_t_and_report(train_x, train_y, test_x, test_y, save_path=None):
rfc = DecisionTreeClassifier(random_state=42, class_weight={0: 10, 1: 10})
rfc = rfc.fit(train_x, train_y)
t1 = time.time()
y_pred = rfc.predict(test_x)
y_pred_binary = np.ones_like(y_pred)
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
y_pred_binary[(y_pred > 1)] = 2
test_y_binary = np.ones_like(test_y)
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
test_y_binary[(test_y > 1)] = 2
print("预测时间:", time.time() - t1)
print("训练集acc" + str(accuracy_score(train_y, rfc.predict(train_x))))
print("测试集acc" + str(accuracy_score(test_y, rfc.predict(test_x))))
print('-'*50)
print('测试集报告\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
print('混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
print('二分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
print('二混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
if save_path is not None:
with open(save_path, 'wb') as f:
pickle.dump(rfc, f)
return rfc
def evaluation_and_report(model, test_x, test_y): def evaluation_and_report(model, test_x, test_y):
t1 = time.time() t1 = time.time()
y_pred = model.predict(test_x) y_pred = model.predict(test_x)
@ -96,21 +123,21 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
""" """
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别. Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x 1024 x num_channels) ;param data: image data, shape (num_rows x ncols x num_channels)
;param blk_sz: block size ;param blk_sz: block size
;param sensitivity: 最少有多少个杂物点能够被认为是杂物 ;param sensitivity: 最少有多少个杂物点能够被认为是杂物
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz) ;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
""" """
x_list = [] x_list = []
for i in range(0, 600 // blk_sz): for i in range(0, nrows // blk_sz):
for j in range(0, 1024 // blk_sz): for j in range(0, ncols // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...] block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
x_list.append(block_data) x_list.append(block_data)
return x_list return x_list
class SpecDetector(object): class SpecDetector(object):
def __init__(self, model_path, blk_sz=8, channel_num=4): def __init__(self, model_path, blk_sz=8, channel_num=nbands):
self.blk_sz, self.channel_num = blk_sz, channel_num self.blk_sz, self.channel_num = blk_sz, channel_num
if os.path.exists(model_path): if os.path.exists(model_path):
with open(model_path, "rb") as model_file: with open(model_path, "rb") as model_file:
@ -118,20 +145,22 @@ class SpecDetector(object):
else: else:
raise FileNotFoundError("Model File not found") raise FileNotFoundError("Model File not found")
def predict(self, data): def predict(self, data, rigor_rate=70):
blocks = split_x(data, blk_sz=self.blk_sz) blocks = split_x(data, blk_sz=self.blk_sz)
blocks = np.array(blocks) blocks = np.array(blocks)
features = feature(np.array(blocks)) features = feature(np.array(blocks))
y_pred = self.clf.predict(features) print("Spec Detector", rigor_rate)
y_pred_binary = np.ones_like(y_pred) y_pred = self.clf.predict_proba(features)
y_pred, y_prob = np.argmax(y_pred, axis=1), np.max(y_pred, axis=1)
y_pred_binary = np.zeros_like(y_pred)
# classes merge # classes merge
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0 y_pred_binary[((y_pred == 2) | (y_pred > 3)) & (y_prob > (100 - rigor_rate) / 100.0)] = 1
# transform to mask # transform to mask
mask = self.mask_transform(y_pred_binary, (1024, 600)) mask = self.mask_transform(y_pred_binary, (ncols, nrows))
return mask return mask
def mask_transform(self, result, dst_size): def mask_transform(self, result, dst_size):
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz mask_size = nrows // self.blk_sz, ncols // self.blk_sz
mask = np.zeros(mask_size, dtype=np.uint8) mask = np.zeros(mask_size, dtype=np.uint8)
for idx, r in enumerate(result): for idx, r in enumerate(result):
row, col = idx // mask_size[1], idx % mask_size[1] row, col = idx // mask_size[1], idx % mask_size[1]
@ -140,8 +169,34 @@ class SpecDetector(object):
return mask return mask
class PixelWisedDetector(object):
def __init__(self, model_path, blk_sz=1, channel_num=nbands):
self.blk_sz, self.channel_num = blk_sz, channel_num
if os.path.exists(model_path):
with open(model_path, "rb") as model_file:
self.clf = pickle.load(model_file)
else:
raise FileNotFoundError("Model File not found")
def predict(self, data, rigor_rate=70):
features = data.reshape((-1, self.channel_num))
y_pred = self.clf.predict(features, rigor_rate)
y_pred_binary = np.ones_like(y_pred, dtype=np.uint8)
print("pixel detector", rigor_rate)
# classes merge
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
# transform to mask
mask = self.mask_transform(y_pred_binary)
return mask
def mask_transform(self, result):
mask_size = (nrows, ncols)
mask = result.reshape(mask_size)
return mask
class PcaSpecDetector(object): class PcaSpecDetector(object):
def __init__(self, model_path, pca_path, blk_sz=8, channel_num=4): def __init__(self, model_path, pca_path, blk_sz=8, channel_num=nbands):
self.blk_sz, self.channel_num = blk_sz, channel_num self.blk_sz, self.channel_num = blk_sz, channel_num
if os.path.exists(model_path): if os.path.exists(model_path):
with open(model_path, "rb") as model_file: with open(model_path, "rb") as model_file:
@ -163,11 +218,11 @@ class PcaSpecDetector(object):
# classes merge # classes merge
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0 y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
# transform to mask # transform to mask
mask = self.mask_transform(y_pred_binary, (1024, 600)) mask = self.mask_transform(y_pred_binary, (ncols, nrows))
return mask return mask
def mask_transform(self, result, dst_size): def mask_transform(self, result, dst_size):
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz mask_size = nrows // self.blk_sz, ncols // self.blk_sz
mask = np.zeros(mask_size, dtype=np.uint8) mask = np.zeros(mask_size, dtype=np.uint8)
for idx, r in enumerate(result): for idx, r in enumerate(result):
row, col = idx // mask_size[1], idx % mask_size[1] row, col = idx // mask_size[1], idx % mask_size[1]

61
test_files/dual_main_test.py Executable file
View File

@ -0,0 +1,61 @@
import glob
import os
import unittest
import cv2
import numpy as np
from utils import read_raw_file
nrows, ncols = 256, 1024
class DualMainTestCase(unittest.TestCase):
def test_dual_main(self):
test_img_dirs = '/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/*.raw'
selected_bands = None
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
total_len = nrows * ncols
spectral_files = glob.glob(test_img_dirs)
print("reading raw files ...")
raw_files = [read_raw_file(file, selected_bands=selected_bands) for file in spectral_files]
print("reading file success!")
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
data = b''
for raw_file in raw_files:
if raw_file.shape[0] > nrows:
raw_file = raw_file[:nrows, ...]
# 写出
print(f"send {raw_file.shape}")
fd_img = os.open(img_fifo_path, os.O_WRONLY)
os.write(fd_img, raw_file.tobytes())
os.close(fd_img)
# 等待
fd_mask = os.open(mask_fifo_path, os.O_RDONLY)
while len(data) < total_len:
data += os.read(fd_mask, total_len)
if len(data) > total_len:
data_total = data[:total_len]
data = data[total_len:]
else:
data_total = data
data = b''
os.close(fd_mask)
mask = np.frombuffer(data_total, dtype=np.uint8).reshape((-1, ncols))
# 显示
rgb_img = np.asarray(raw_file[..., [0, 2, 3]] * 255, dtype=np.uint8)
mask_color = np.zeros_like(rgb_img)
mask_color[mask > 0] = (0, 0, 255)
combine = cv2.addWeighted(rgb_img, 1, mask_color, 0.5, 0)
cv2.imshow("img", combine)
cv2.waitKey(0)
if __name__ == '__main__':
unittest.main()

View File

@ -6,9 +6,12 @@ import os
import time import time
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import tqdm
from models import SpecDetector from models import SpecDetector
nrows, ncols = 256, 1024
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int: def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
""" """
@ -35,6 +38,8 @@ def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
:param sensitivity: 敏感度 :param sensitivity: 敏感度
:return: :return:
""" """
if (pixel_blk.shape[0] ==1) and (pixel_blk.shape[1] == 1):
return pixel_blk[0][0]
defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1} defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls) color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
for cls in defect_dict.keys()} for cls in defect_dict.keys()}
@ -55,7 +60,7 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
""" """
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别. Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x 1024 x num_channels) ;param data: image data, shape (num_rows x ncols x num_channels)
;param labeled_img: RGB labeled img with respect to the image! ;param labeled_img: RGB labeled img with respect to the image!
make sure that the defect is (255, 0, 0) and background is (255, 255, 255) make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
;param blk_sz: block size ;param blk_sz: block size
@ -71,8 +76,8 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
truth_map = np.all(labeled_img == color, axis=2) truth_map = np.all(labeled_img == color, axis=2)
class_img[truth_map] = class_idx class_img[truth_map] = class_idx
x_list, y_list = [], [] x_list, y_list = [], []
for i in range(0, 600 // blk_sz): for i in range(0, nrows // blk_sz):
for j in range(0, 1024 // blk_sz): for j in range(0, ncols // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...] block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...] block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
block_label = determine_class(block_label, sensitivity=sensitivity) block_label = determine_class(block_label, sensitivity=sensitivity)
@ -90,14 +95,14 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
""" """
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别. Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
;param data: image data, shape (num_rows x 1024 x num_channels) ;param data: image data, shape (num_rows x ncols x num_channels)
;param blk_sz: block size ;param blk_sz: block size
;param sensitivity: 最少有多少个杂物点能够被认为是杂物 ;param sensitivity: 最少有多少个杂物点能够被认为是杂物
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz) ;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
""" """
x_list = [] x_list = []
for i in range(0, 600 // blk_sz): for i in range(0, nrows // blk_sz):
for j in range(0, 1024 // blk_sz): for j in range(0, ncols // blk_sz):
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...] block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
x_list.append(block_data) x_list.append(block_data)
return x_list return x_list
@ -105,7 +110,6 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
def visualization_evaluation(detector, data_path, selected_bands=None): def visualization_evaluation(detector, data_path, selected_bands=None):
selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
nrows, ncols = 600, 1024
image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw")) image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
for idx, image_path in enumerate(image_paths): for idx, image_path in enumerate(image_paths):
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
@ -132,9 +136,9 @@ def visualization_evaluation(detector, data_path, selected_bands=None):
def visualization_y(y_list, k_size): def visualization_y(y_list, k_size):
mask = np.zeros((600 // k_size, 1024 // k_size), dtype=np.uint8) mask = np.zeros((nrows // k_size, ncols // k_size), dtype=np.uint8)
for idx, r in enumerate(y_list): for idx, r in enumerate(y_list):
row, col = idx // (1024 // k_size), idx % (1024 // k_size) row, col = idx // (ncols // k_size), idx % (ncols // k_size)
mask[row, col] = r mask[row, col] = r
fig, axs = plt.subplots() fig, axs = plt.subplots()
axs.imshow(mask) axs.imshow(mask)
@ -142,16 +146,23 @@ def visualization_y(y_list, k_size):
def read_raw_file(file_name, selected_bands=None): def read_raw_file(file_name, selected_bands=None):
print(f"reading file {file_name}")
with open(file_name, "rb") as f: with open(file_name, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1) data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, ncols)).transpose(0, 2, 1)
if selected_bands is not None: if selected_bands is not None:
data = data[..., selected_bands] data = data[..., selected_bands]
return data return data
def write_raw_file(file_name, data: np.ndarray):
data = data.transpose(0, 2, 1).reshape((nrows, -1, ncols))
with open(file_name, 'wb') as f:
f.write(data.tobytes())
def read_black_and_white_file(file_name): def read_black_and_white_file(file_name):
with open(file_name, "rb") as f: with open(file_name, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1) data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, ncols)).transpose(0, 2, 1)
return data return data
@ -166,8 +177,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands)) model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
y_label = model.predict(data) y_label = model.predict(data)
x_list, y_list = [], [] x_list, y_list = [], []
for i in range(0, 600 // blk_sz): for i in range(0, nrows // blk_sz):
for j in range(0, 1024 // blk_sz): for j in range(0, ncols // blk_sz):
if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \ if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
> 0: > 0:
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...] block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
@ -179,8 +190,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None, def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
target_class_left=None, ): target_class_left=None, ):
y_label = np.zeros((data.shape[0], data.shape[1])) y_label = np.zeros((data.shape[0], data.shape[1]))
for i in range(0, 600): for i in range(0, nrows):
for j in range(0, 1024): for j in range(0, ncols):
if np.sum(np.sum(data[i, j])) >= light_threshold: if np.sum(np.sum(data[i, j])) >= light_threshold:
if j > split_line: if j > split_line:
y_label[i, j] = target_class_right y_label[i, j] = target_class_right
@ -192,3 +203,21 @@ def generate_impurity_label(data, light_threshold, color_dict, split_line=0, tar
axs[1].matshow(data[..., 0]) axs[1].matshow(data[..., 0])
plt.show() plt.show()
return pic return pic
def file_transform(input_dir, output_dir, selected_bands=None):
files = os.listdir(input_dir)
filtered_files = [file for file in files if file.endswith('.raw')]
os.makedirs(output_dir, mode=0o777, exist_ok=True)
for file_path in filtered_files:
input_path = os.path.join(input_dir, file_path)
output_path = os.path.join(output_dir, file_path)
data = read_raw_file(input_path, selected_bands=selected_bands)
write_raw_file(output_path, data)
if __name__ == '__main__':
selected_bands = [127, 201, 202, 294]
input_dir, output_dir = r"/Volumes/LENOVO_USB_HDD/zhouchao/616/",\
r"/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/"
file_transform(input_dir=input_dir, output_dir=output_dir, selected_bands=selected_bands)