Mirror of https://github.com/NanjingForestryUniversity/tobacoo-industry.git
add dual_main.py and dual_main_test.py
parent 52c4feee09, commit 9e1dad8637
07_pixelwised_detection.ipynb (new file, 370 lines added)
File diff suppressed because one or more lines are too long
dual_main.py (new executable file, 110 lines added)
@@ -0,0 +1,110 @@
import os

import numpy as np

from models import SpecDetector, PixelWisedDetector
from root_dir import ROOT_DIR
from multiprocessing import Process, Queue

nrows, ncols, nbands = 256, 1024, 4
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
cmd_fifo_path = '/tmp/tobacco_cmd.fifo'  # reserved for the command channel; not read in this file

pxl_model_path = "rf_1x1_c4_1_sen1_4.model"
blk_model_path = "rf_8x8_c4_185_sen32_4.model"


def main(pxl_model_path=pxl_model_path, blk_model_path=blk_model_path):
    # start the two model worker processes
    blk_cmd_queue, pxl_cmd_queue = Queue(maxsize=100), Queue(maxsize=100)
    blk_img_queue, pxl_img_queue = Queue(maxsize=100), Queue(maxsize=100)
    blk_msk_queue, pxl_msk_queue = Queue(maxsize=100), Queue(maxsize=100)
    blk_process = Process(target=block_model, args=(blk_cmd_queue, blk_img_queue, blk_msk_queue, blk_model_path, ))
    pxl_process = Process(target=pixel_model, args=(pxl_cmd_queue, pxl_img_queue, pxl_msk_queue, pxl_model_path, ))
    blk_process.start()
    pxl_process.start()
    total_len = nrows * ncols * nbands * 4
    if not os.access(img_fifo_path, os.F_OK):
        os.mkfifo(img_fifo_path, 0o777)
    if not os.access(mask_fifo_path, os.F_OK):
        os.mkfifo(mask_fifo_path, 0o777)
    data = b''
    while True:
        fd_img = os.open(img_fifo_path, os.O_RDONLY)
        while len(data) < total_len:
            data += os.read(fd_img, total_len)
        if len(data) > total_len:
            data_total = data[:total_len]
            data = data[total_len:]
        else:
            data_total = data
            data = b''
        os.close(fd_img)

        img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
        print(f"get img shape {img.shape}")
        pxl_img_queue.put(img)
        blk_img_queue.put(img)
        pxl_msk = pxl_msk_queue.get()
        blk_msk = blk_msk_queue.get()
        mask = pxl_msk & blk_msk
        print(f"predict success get mask shape: {mask.shape}")
        # write the mask out
        fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
        os.write(fd_mask, mask.tobytes())
        os.close(fd_mask)


def block_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, blk_model_path=blk_model_path):
    blk_model = SpecDetector(os.path.join(ROOT_DIR, "models", blk_model_path), blk_sz=8, channel_num=4)
    _ = blk_model.predict(np.ones((nrows, ncols, nbands)))  # warm-up prediction on a dummy frame
    rigor_rate = 70
    while True:
        # deal with the cmd if cmd_queue is not empty
        if not cmd_queue.empty():
            cmd = cmd_queue.get()
            if isinstance(cmd, int):
                rigor_rate = cmd
            elif isinstance(cmd, str):
                if cmd == 'stop':
                    break
                else:
                    # any other string: try to reload the detector
                    try:
                        blk_model = SpecDetector(os.path.join(ROOT_DIR, "models", blk_model_path),
                                                 blk_sz=8, channel_num=4)
                    except Exception as e:
                        print(f"Load Model Failed! {e}")
        # deal with the img if img_queue is not empty
        if not img_queue.empty():
            img = img_queue.get()
            mask = blk_model.predict(img, rigor_rate)
            mask_queue.put(mask)


def pixel_model(cmd_queue: Queue, img_queue: Queue, mask_queue: Queue, pixel_model_path=pxl_model_path):
    pixel_model = PixelWisedDetector(os.path.join(ROOT_DIR, "models", pixel_model_path), blk_sz=1, channel_num=4)
    _ = pixel_model.predict(np.ones((nrows, ncols, nbands)))  # warm-up prediction on a dummy frame
    rigor_rate = 70
    while True:
        # deal with the cmd if cmd_queue is not empty
        if not cmd_queue.empty():
            cmd = cmd_queue.get()
            if isinstance(cmd, int):
                rigor_rate = cmd
            elif isinstance(cmd, str):
                if cmd == 'stop':
                    break
                else:
                    # any other string: try to reload the detector
                    try:
                        pixel_model = PixelWisedDetector(os.path.join(ROOT_DIR, "models", pixel_model_path),
                                                         blk_sz=1, channel_num=4)
                    except Exception as e:
                        print(f"Load Model Failed! {e}")
        # deal with the img if img_queue is not empty
        if not img_queue.empty():
            img = img_queue.get()
            mask = pixel_model.predict(img, rigor_rate)
            mask_queue.put(mask)


if __name__ == '__main__':
    main()
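Note on the frame format: each frame written to /tmp/dkimg.fifo is expected to be nrows * ncols * nbands * 4 = 4 MiB of float32 data laid out as (rows, bands, cols), which the reshape/transpose in main() turns back into a (rows, cols, bands) cube. The snippet below is an illustration only, not part of the commit; it packs a cube the same way write_raw_file in utils.py does and checks that main()'s unpacking recovers it.

import numpy as np

nrows, ncols, nbands = 256, 1024, 4

cube = np.random.rand(nrows, ncols, nbands).astype(np.float32)   # (rows, cols, bands)
payload = cube.transpose(0, 2, 1).tobytes()                      # (rows, bands, cols) on the wire
assert len(payload) == nrows * ncols * nbands * 4                # matches total_len in main()

# receiver side, mirroring dual_main.main()
recovered = np.frombuffer(payload, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
assert np.array_equal(recovered, cube)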
main.py (33 lines changed)
@@ -1,31 +1,44 @@
import os

import cv2
import numpy as np
from models import SpecDetector
from root_dir import ROOT_DIR

nrows, ncols, nbands = 600, 1024, 4
nrows, ncols, nbands = 256, 1024, 4
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
selected_model = "rf_8x8_c4_400_13.model"
selected_model = "rf_8x8_c4_185_sen32_4.model"


def main():
    model_path = os.path.join(ROOT_DIR, "models", selected_model)
    detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
    _ = detector.predict(np.ones((600, 1024, 4)))
    _ = detector.predict(np.ones((256, 1024, 4)))
    total_len = nrows * ncols * nbands * 4

    if not os.access(img_fifo_path, os.F_OK):
        os.mkfifo(img_fifo_path, 0o777)
    if not os.access(mask_fifo_path, os.F_OK):
        os.mkfifo(mask_fifo_path, 0o777)

    fd_img = os.open(img_fifo_path, os.O_RDONLY)
    print("connect to fifo")

    data = b''
    while True:
        data = os.read(fd_img, total_len)
        print("get img")
        img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
        # read
        fd_img = os.open(img_fifo_path, os.O_RDONLY)
        while len(data) < total_len:
            data += os.read(fd_img, total_len)
        if len(data) > total_len:
            data_total = data[:total_len]
            data = data[total_len:]
        else:
            data_total = data
            data = b''

        os.close(fd_img)
        # detect
        img = np.frombuffer(data_total, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
        mask = detector.predict(img)
        # write the mask out
        fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
        os.write(fd_mask, mask.tobytes())
        os.close(fd_mask)
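The reworked loop above buffers partial FIFO reads until a whole frame has arrived and carries any surplus bytes over to the next frame; dual_main.py uses the same pattern. A stand-alone sketch of that logic (the helper name read_frame is hypothetical, not from the repository):

import os
from typing import Tuple


def read_frame(fd: int, frame_len: int, carry: bytes = b'') -> Tuple[bytes, bytes]:
    # accumulate short reads until one full frame is buffered
    buf = carry
    while len(buf) < frame_len:
        chunk = os.read(fd, frame_len - len(buf))
        if not chunk:  # writer closed the FIFO before a full frame arrived
            raise EOFError(f"expected {frame_len} bytes, got {len(buf)}")
        buf += chunk
    # return the frame plus whatever belongs to the next one
    return buf[:frame_len], buf[frame_len:]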
models.py (83 lines changed)
@@ -2,13 +2,15 @@ import os
import pickle
import time

import cv2
import numpy as np

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

nrows, ncols, nbands = 256, 1024, 4


def feature(x):
    x = x.reshape((x.shape[0], -1))
@@ -42,6 +44,31 @@ def train_rf_and_report(train_x, train_y, test_x, test_y,
    return rfc


def train_t_and_report(train_x, train_y, test_x, test_y, save_path=None):
    rfc = DecisionTreeClassifier(random_state=42, class_weight={0: 10, 1: 10})
    rfc = rfc.fit(train_x, train_y)
    t1 = time.time()
    y_pred = rfc.predict(test_x)
    y_pred_binary = np.ones_like(y_pred)
    y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
    y_pred_binary[(y_pred > 1)] = 2
    test_y_binary = np.ones_like(test_y)
    test_y_binary[(test_y == 0) | (test_y == 1)] = 0
    test_y_binary[(test_y > 1)] = 2
    print("prediction time:", time.time() - t1)
    print("training set acc: " + str(accuracy_score(train_y, rfc.predict(train_x))))
    print("test set acc: " + str(accuracy_score(test_y, rfc.predict(test_x))))
    print('-'*50)
    print('test set report:\n' + str(classification_report(test_y, y_pred)))  # short per-class report
    print('confusion matrix:\n' + str(confusion_matrix(test_y, y_pred)))  # counts per true/predicted class
    print('binary report:\n' + str(classification_report(test_y_binary, y_pred_binary)))  # report after merging classes
    print('binary confusion matrix:\n' + str(confusion_matrix(test_y_binary, y_pred_binary)))  # counts after merging classes
    if save_path is not None:
        with open(save_path, 'wb') as f:
            pickle.dump(rfc, f)
    return rfc


def evaluation_and_report(model, test_x, test_y):
    t1 = time.time()
    y_pred = model.predict(test_x)
@@ -96,21 +123,21 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
    """
    Split the data into slices for classification. Divide the image into pixel blocks for later detection.

    :param data: image data, shape (num_rows x 1024 x num_channels)
    :param data: image data, shape (num_rows x ncols x num_channels)
    :param blk_sz: block size
    :param sensitivity: minimum number of impurity pixels for a block to count as impurity
    :return data_x, data_y: sliced data x (block_num x num_channels x blk_sz x blk_sz)
    """
    x_list = []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
    for i in range(0, nrows // blk_sz):
        for j in range(0, ncols // blk_sz):
            block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            x_list.append(block_data)
    return x_list


class SpecDetector(object):
    def __init__(self, model_path, blk_sz=8, channel_num=4):
    def __init__(self, model_path, blk_sz=8, channel_num=nbands):
        self.blk_sz, self.channel_num = blk_sz, channel_num
        if os.path.exists(model_path):
            with open(model_path, "rb") as model_file:
@@ -118,20 +145,22 @@ class SpecDetector(object):
        else:
            raise FileNotFoundError("Model File not found")

    def predict(self, data):
    def predict(self, data, rigor_rate=70):
        blocks = split_x(data, blk_sz=self.blk_sz)
        blocks = np.array(blocks)
        features = feature(np.array(blocks))
        y_pred = self.clf.predict(features)
        y_pred_binary = np.ones_like(y_pred)
        print("Spec Detector", rigor_rate)
        y_pred = self.clf.predict_proba(features)
        y_pred, y_prob = np.argmax(y_pred, axis=1), np.max(y_pred, axis=1)
        y_pred_binary = np.zeros_like(y_pred)
        # classes merge
        y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
        y_pred_binary[((y_pred == 2) | (y_pred > 3)) & (y_prob > (100 - rigor_rate) / 100.0)] = 1
        # transform to mask
        mask = self.mask_transform(y_pred_binary, (1024, 600))
        mask = self.mask_transform(y_pred_binary, (ncols, nrows))
        return mask

    def mask_transform(self, result, dst_size):
        mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
        mask_size = nrows // self.blk_sz, ncols // self.blk_sz
        mask = np.zeros(mask_size, dtype=np.uint8)
        for idx, r in enumerate(result):
            row, col = idx // mask_size[1], idx % mask_size[1]
@@ -140,8 +169,34 @@ class SpecDetector(object):
        return mask


class PixelWisedDetector(object):
    def __init__(self, model_path, blk_sz=1, channel_num=nbands):
        self.blk_sz, self.channel_num = blk_sz, channel_num
        if os.path.exists(model_path):
            with open(model_path, "rb") as model_file:
                self.clf = pickle.load(model_file)
        else:
            raise FileNotFoundError("Model File not found")

    def predict(self, data, rigor_rate=70):
        features = data.reshape((-1, self.channel_num))
        y_pred = self.clf.predict(features)  # the classifier takes only the feature matrix; rigor_rate is only reported below
        y_pred_binary = np.ones_like(y_pred, dtype=np.uint8)
        print("pixel detector", rigor_rate)
        # classes merge
        y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
        # transform to mask
        mask = self.mask_transform(y_pred_binary)
        return mask

    def mask_transform(self, result):
        mask_size = (nrows, ncols)
        mask = result.reshape(mask_size)
        return mask


class PcaSpecDetector(object):
    def __init__(self, model_path, pca_path, blk_sz=8, channel_num=4):
    def __init__(self, model_path, pca_path, blk_sz=8, channel_num=nbands):
        self.blk_sz, self.channel_num = blk_sz, channel_num
        if os.path.exists(model_path):
            with open(model_path, "rb") as model_file:
@@ -163,11 +218,11 @@ class PcaSpecDetector(object):
        # classes merge
        y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
        # transform to mask
        mask = self.mask_transform(y_pred_binary, (1024, 600))
        mask = self.mask_transform(y_pred_binary, (ncols, nrows))
        return mask

    def mask_transform(self, result, dst_size):
        mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
        mask_size = nrows // self.blk_sz, ncols // self.blk_sz
        mask = np.zeros(mask_size, dtype=np.uint8)
        for idx, r in enumerate(result):
            row, col = idx // mask_size[1], idx % mask_size[1]
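The updated SpecDetector.predict switches from hard class predictions to predict_proba: a block is flagged as impurity only when its winning class is 2 or above 3 and the winning probability exceeds (100 - rigor_rate) / 100. The toy snippet below (random features and a throwaway classifier, for exposition only) isolates that thresholding step; like the repository code it assumes the class labels are the contiguous integers 0..4.

import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X, y = rng.random((200, 16)), rng.integers(0, 5, 200)        # stand-ins for block features and classes
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

rigor_rate = 70
proba = clf.predict_proba(X)
y_pred, y_prob = np.argmax(proba, axis=1), np.max(proba, axis=1)

# merge classes: 0, 1 and 3 stay background; 2 and anything above 3 count as impurity,
# but only when the winning probability clears the rigor threshold
y_bin = np.zeros_like(y_pred)
y_bin[((y_pred == 2) | (y_pred > 3)) & (y_prob > (100 - rigor_rate) / 100.0)] = 1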
test_files/dual_main_test.py (new executable file, 61 lines added)
@@ -0,0 +1,61 @@
import glob
import os
import unittest

import cv2
import numpy as np

from utils import read_raw_file

nrows, ncols = 256, 1024


class DualMainTestCase(unittest.TestCase):
    def test_dual_main(self):
        test_img_dirs = '/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/*.raw'
        selected_bands = None
        img_fifo_path = "/tmp/dkimg.fifo"
        mask_fifo_path = "/tmp/dkmask.fifo"

        total_len = nrows * ncols
        spectral_files = glob.glob(test_img_dirs)
        print("reading raw files ...")
        raw_files = [read_raw_file(file, selected_bands=selected_bands) for file in spectral_files]
        print("reading file success!")
        if not os.access(img_fifo_path, os.F_OK):
            os.mkfifo(img_fifo_path, 0o777)
        if not os.access(mask_fifo_path, os.F_OK):
            os.mkfifo(mask_fifo_path, 0o777)
        data = b''
        for raw_file in raw_files:
            if raw_file.shape[0] > nrows:
                raw_file = raw_file[:nrows, ...]
            # write out the frame
            print(f"send {raw_file.shape}")
            fd_img = os.open(img_fifo_path, os.O_WRONLY)
            os.write(fd_img, raw_file.tobytes())
            os.close(fd_img)
            # wait for the mask
            fd_mask = os.open(mask_fifo_path, os.O_RDONLY)
            while len(data) < total_len:
                data += os.read(fd_mask, total_len)
            if len(data) > total_len:
                data_total = data[:total_len]
                data = data[total_len:]
            else:
                data_total = data
                data = b''
            os.close(fd_mask)
            mask = np.frombuffer(data_total, dtype=np.uint8).reshape((-1, ncols))

            # display
            rgb_img = np.asarray(raw_file[..., [0, 2, 3]] * 255, dtype=np.uint8)
            mask_color = np.zeros_like(rgb_img)
            mask_color[mask > 0] = (0, 0, 255)
            combine = cv2.addWeighted(rgb_img, 1, mask_color, 0.5, 0)
            cv2.imshow("img", combine)
            cv2.waitKey(0)


if __name__ == '__main__':
    unittest.main()
utils.py (59 lines changed)
@@ -6,9 +6,12 @@ import os
import time

import matplotlib.pyplot as plt
import tqdm

from models import SpecDetector

nrows, ncols = 256, 1024


def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
    """
@@ -35,6 +38,8 @@ def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
    :param sensitivity: sensitivity
    :return:
    """
    if (pixel_blk.shape[0] == 1) and (pixel_blk.shape[1] == 1):
        return pixel_blk[0][0]
    defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
    color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
                     for cls in defect_dict.keys()}
@@ -55,7 +60,7 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
    """
    Split the data into slices for classification. Divide the image into pixel blocks for later detection.

    :param data: image data, shape (num_rows x 1024 x num_channels)
    :param data: image data, shape (num_rows x ncols x num_channels)
    :param labeled_img: RGB labeled img with respect to the image!
        make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
    :param blk_sz: block size
@@ -71,8 +76,8 @@ def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity
        truth_map = np.all(labeled_img == color, axis=2)
        class_img[truth_map] = class_idx
    x_list, y_list = [], []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
    for i in range(0, nrows // blk_sz):
        for j in range(0, ncols // blk_sz):
            block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            block_label = determine_class(block_label, sensitivity=sensitivity)
@@ -90,14 +95,14 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:
    """
    Split the data into slices for classification. Divide the image into pixel blocks for later detection.

    :param data: image data, shape (num_rows x 1024 x num_channels)
    :param data: image data, shape (num_rows x ncols x num_channels)
    :param blk_sz: block size
    :param sensitivity: minimum number of impurity pixels for a block to count as impurity
    :return data_x, data_y: sliced data x (block_num x num_channels x blk_sz x blk_sz)
    """
    x_list = []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
    for i in range(0, nrows // blk_sz):
        for j in range(0, ncols // blk_sz):
            block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            x_list.append(block_data)
    return x_list
@@ -105,7 +110,6 @@ def split_x(data: np.ndarray, blk_sz: int) -> list:

def visualization_evaluation(detector, data_path, selected_bands=None):
    selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
    nrows, ncols = 600, 1024
    image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
    for idx, image_path in enumerate(image_paths):
        with open(image_path, 'rb') as f:
@@ -132,9 +136,9 @@ def visualization_evaluation(detector, data_path, selected_bands=None):


def visualization_y(y_list, k_size):
    mask = np.zeros((600 // k_size, 1024 // k_size), dtype=np.uint8)
    mask = np.zeros((nrows // k_size, ncols // k_size), dtype=np.uint8)
    for idx, r in enumerate(y_list):
        row, col = idx // (1024 // k_size), idx % (1024 // k_size)
        row, col = idx // (ncols // k_size), idx % (ncols // k_size)
        mask[row, col] = r
    fig, axs = plt.subplots()
    axs.imshow(mask)
@@ -142,16 +146,23 @@ def visualization_y(y_list, k_size):


def read_raw_file(file_name, selected_bands=None):
    print(f"reading file {file_name}")
    with open(file_name, "rb") as f:
        data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1)
        data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, ncols)).transpose(0, 2, 1)
    if selected_bands is not None:
        data = data[..., selected_bands]
    return data


def write_raw_file(file_name, data: np.ndarray):
    data = data.transpose(0, 2, 1).reshape((nrows, -1, ncols))
    with open(file_name, 'wb') as f:
        f.write(data.tobytes())


def read_black_and_white_file(file_name):
    with open(file_name, "rb") as f:
        data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1)
        data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, ncols)).transpose(0, 2, 1)
    return data

@@ -166,8 +177,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
    model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
    y_label = model.predict(data)
    x_list, y_list = [], []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
    for i in range(0, nrows // blk_sz):
        for j in range(0, ncols // blk_sz):
            if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
                    > 0:
                block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
@@ -179,8 +190,8 @@ def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
                            target_class_left=None, ):
    y_label = np.zeros((data.shape[0], data.shape[1]))
    for i in range(0, 600):
        for j in range(0, 1024):
    for i in range(0, nrows):
        for j in range(0, ncols):
            if np.sum(np.sum(data[i, j])) >= light_threshold:
                if j > split_line:
                    y_label[i, j] = target_class_right
@@ -192,3 +203,21 @@ def generate_impurity_label(data, light_threshold, color_dict, split_line=0, tar
    axs[1].matshow(data[..., 0])
    plt.show()
    return pic


def file_transform(input_dir, output_dir, selected_bands=None):
    files = os.listdir(input_dir)
    filtered_files = [file for file in files if file.endswith('.raw')]
    os.makedirs(output_dir, mode=0o777, exist_ok=True)
    for file_path in filtered_files:
        input_path = os.path.join(input_dir, file_path)
        output_path = os.path.join(output_dir, file_path)
        data = read_raw_file(input_path, selected_bands=selected_bands)
        write_raw_file(output_path, data)


if __name__ == '__main__':
    selected_bands = [127, 201, 202, 294]
    input_dir, output_dir = r"/Volumes/LENOVO_USB_HDD/zhouchao/616/",\
                            r"/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/"
    file_transform(input_dir=input_dir, output_dir=output_dir, selected_bands=selected_bands)
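Several of the helpers above (split_x, split_xy, visualization_y, and mask_transform in models.py) share the same row-major block enumeration over a (nrows // blk_sz) x (ncols // blk_sz) grid. A small illustrative sketch of that indexing and its vectorised equivalent:

import numpy as np

nrows, ncols, blk_sz = 256, 1024, 8
grid = nrows // blk_sz, ncols // blk_sz                    # 32 x 128 blocks

result = np.random.randint(0, 2, grid[0] * grid[1])        # stand-in per-block predictions
mask = np.zeros(grid, dtype=np.uint8)
for idx, r in enumerate(result):
    row, col = idx // grid[1], idx % grid[1]               # same mapping as mask_transform
    mask[row, col] = r

assert np.array_equal(mask, result.reshape(grid).astype(np.uint8))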