mirror of
https://github.com/NanjingForestryUniversity/tobacoo-industry.git
synced 2025-11-08 14:23:53 +00:00
223 lines
8.7 KiB
Python
Executable File
223 lines
8.7 KiB
Python
Executable File
import cv2
|
||
import numpy as np
|
||
|
||
import glob
|
||
import os
|
||
import time
|
||
|
||
import matplotlib.pyplot as plt
|
||
import tqdm
|
||
|
||
from models import SpecDetector
|
||
|
||
# Frame height (rows) and width (columns) shared by the helpers in this module.
nrows, ncols = 256, 1024
|
||
|
||
|
||
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
    """
    Map the colour of one labelled pixel to its class index.

    :param pixel: sequence whose first three entries are the pixel's colour channels
    :param color_dict: mapping from colour tuples to class indices,
        e.g. {(0, 0, 255): 1, ...}; colours are given in BGR order
    :return: the class index, or -1 when the colour is not in the mapping
    """
    # 0 is background, 1 is tobacco stem, every other class is an impurity
    if color_dict is None:
        color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
    # A single lookup with a fallback replaces the membership test plus second lookup
    return color_dict.get((pixel[0], pixel[1], pixel[2]), -1)
|
||
|
||
|
||
def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
    """
    Decide the single class of a block of per-pixel class labels.

    Classes 0 and 1 are treated as non-defect, classes 2 to 6 as impurities.

    :param pixel_blk: 2-D array of class indices; it does not have to be square
    :param sensitivity: minimum pixel count a class group needs to claim the block
    :return: the stored label for a 1 x 1 block; otherwise the most frequent impurity
        class when impurity pixels total at least ``sensitivity``, else 1 when class-1
        pixels reach ``sensitivity``, else 0
    """
    if pixel_blk.shape[0] == 1 and pixel_blk.shape[1] == 1:
        return pixel_blk[0][0]
    impurity_classes = (2, 3, 4, 5, 6)
    # Count matches directly; deriving the count from shape[0] ** 2 is only
    # correct for square blocks.
    counts = {cls: int(np.count_nonzero(pixel_blk == cls)) for cls in (0, 1) + impurity_classes}
    impurity_total = sum(counts[cls] for cls in impurity_classes)
    if impurity_total >= sensitivity:
        # On a tie the lowest impurity class index wins
        return max(impurity_classes, key=counts.get)
    return 1 if counts[1] >= sensitivity else 0
|
||
|
||
|
||
def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity: int = 12,
             color_dict=None, add_background=True) -> tuple:
    """
    Cut an image and its colour-coded label picture into blocks with one class each.

    :param data: image data whose first two axes are rows and columns
    :param labeled_img: three-axis colour label picture with the same height and
        width as ``data``
    :param blk_sz: side length of a block in pixels
    :param sensitivity: pixel-count threshold forwarded to ``determine_class``
    :param color_dict: mapping from colour tuples to class indices; a built-in
        five-colour mapping is used when omitted
    :param add_background: when false, blocks whose class is 0 are left out
    :return: ``(x_list, y_list)`` - the data blocks and their class, in row-major order
    """
    assert (data.shape[0] == labeled_img.shape[0]) and (data.shape[1] == labeled_img.shape[1])
    if color_dict is None:
        color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}

    # Turn the colour picture into a per-pixel class map; unmatched colours stay 0
    height, width = labeled_img.shape[0], labeled_img.shape[1]
    class_img = np.zeros((height, width), dtype=int)
    for color, class_idx in color_dict.items():
        class_img[np.all(labeled_img == color, axis=2)] = class_idx

    x_list, y_list = [], []
    # The block grid follows the module-level nrows / ncols, not data.shape
    for blk_row in range(nrows // blk_sz):
        rows = slice(blk_row * blk_sz, (blk_row + 1) * blk_sz)
        for blk_col in range(ncols // blk_sz):
            cols = slice(blk_col * blk_sz, (blk_col + 1) * blk_sz)
            label = determine_class(class_img[rows, cols], sensitivity=sensitivity)
            if add_background or label != 0:
                y_list.append(label)
                x_list.append(data[rows, cols, ...])
    return x_list, y_list
|
||
|
||
|
||
def split_x(data: np.ndarray, blk_sz: int) -> list:
    """
    Cut an image into square blocks for classification.

    The block grid is derived from the module-level ``nrows`` and ``ncols``.

    :param data: image data whose first two axes are rows and columns
    :param blk_sz: side length of a block in pixels
    :return: list of blocks in row-major order
    """
    return [
        data[blk_row * blk_sz: (blk_row + 1) * blk_sz, blk_col * blk_sz: (blk_col + 1) * blk_sz, ...]
        for blk_row in range(nrows // blk_sz)
        for blk_col in range(ncols // blk_sz)
    ]
|
||
|
||
|
||
def visualization_evaluation(detector, data_path, selected_bands=None):
    """
    Run a detector on every ``calibrated*.raw`` file in a folder and plot the outcome.

    For each file a three-panel figure is saved to ``./dataset/<index>.png`` and
    shown: a three-band rendering, the predicted mask, and the mask blended onto
    the rendering. The prediction time is printed in the figure title.

    :param detector: object whose ``predict`` method receives the image and returns a mask
    :param data_path: directory searched for ``calibrated*.raw`` files
    :param selected_bands: band indices passed to the detector when the file has
        448 bands; defaults to [76, 146, 216, 367, 383, 406]
    """
    if selected_bands is None:
        selected_bands = [76, 146, 216, 367, 383, 406]
    pattern = os.path.join(data_path, "calibrated*.raw")
    for idx, image_path in enumerate(glob.glob(pattern)):
        with open(image_path, 'rb') as f:
            raw_bytes = f.read()
        # Stored as (rows, bands, cols); move the band axis to the end
        img = np.frombuffer(raw_bytes, dtype=np.float32).reshape((nrows, -1, ncols)).transpose(0, 2, 1)
        full_spectrum = img.shape[2] == 448

        # Band selection stays inside the timed section, as part of the measured cost
        t1 = time.time()
        mask = detector.predict(img[..., selected_bands] if full_spectrum else img)
        time_spent = time.time() - t1

        display_bands = [372, 241, 169] if full_spectrum else [0, 1, 2]
        rgb_img = np.asarray(img[..., display_bands] * 255, dtype=np.uint8)
        mask_color = np.zeros_like(rgb_img)
        mask_color[mask > 0] = (0, 0, 255)
        combine = cv2.addWeighted(rgb_img, 1, mask_color, 0.5, 0)

        fig, axs = plt.subplots(1, 3)
        for ax, panel in zip(axs, (rgb_img, mask, combine)):
            ax.imshow(panel)
        fig.suptitle(f"time spent {time_spent * 1000:.2f} ms" + f"\n{image_path}")
        plt.savefig(f"./dataset/{idx}.png", dpi=300)
        plt.show()
|
||
|
||
|
||
def visualization_y(y_list, k_size):
    """
    Display block classes as an image with one pixel per block.

    :param y_list: block classes in row-major order
    :param k_size: block side length used to derive the grid from ``nrows`` / ``ncols``
    """
    blocks_per_row = ncols // k_size
    mask = np.zeros((nrows // k_size, blocks_per_row), dtype=np.uint8)
    for idx, cls in enumerate(y_list):
        row, col = divmod(idx, blocks_per_row)
        mask[row, col] = cls
    fig, axs = plt.subplots()
    axs.imshow(mask)
    plt.show()
|
||
|
||
|
||
def read_raw_file(file_name, selected_bands=None):
    """
    Load a float32 raw file as a (600, ncols, bands) array.

    :param file_name: path of the raw file
    :param selected_bands: optional band indices to keep along the last axis
    :return: the full array, or only the selected bands when given
    """
    print(f"reading file {file_name}")
    with open(file_name, "rb") as f:
        raw_bytes = f.read()
    # NOTE(review): the row count is fixed at 600 here while the module-level
    # nrows is 256 - confirm which one the recordings really use.
    cube = np.frombuffer(raw_bytes, dtype=np.float32).reshape((600, -1, ncols)).transpose(0, 2, 1)
    if selected_bands is None:
        return cube
    return cube[..., selected_bands]
|
||
|
||
|
||
def write_raw_file(file_name, data: np.ndarray):
    """
    Write a (rows, cols, bands) array to disk in (rows, bands, cols) order.

    :param file_name: destination path; an existing file is overwritten
    :param data: three-axis array; its bytes are written with the dtype it already has
    """
    # No reshape is applied: it cannot change the C-order bytes, and reshaping
    # against fixed frame dimensions fails for arrays with another row count.
    payload = data.transpose(0, 2, 1).tobytes()
    with open(file_name, 'wb') as f:
        f.write(payload)
|
||
|
||
|
||
def read_black_and_white_file(file_name):
    """
    Load a float32 raw file holding a single row of 448 bands.

    :param file_name: path of the raw file
    :return: array of shape (1, ncols, 448)
    """
    with open(file_name, "rb") as f:
        payload = f.read()
    frame = np.frombuffer(payload, dtype=np.float32)
    return frame.reshape((1, 448, ncols)).transpose(0, 2, 1)
|
||
|
||
|
||
def label2pic(label, color_dict):
    """
    Paint a class map as a three-channel picture.

    :param label: 2-D array of class indices
    :param color_dict: mapping from colour tuples to class indices
    :return: float array of shape (rows, cols, 3); classes absent from the mapping stay 0
    """
    height, width = label.shape[0], label.shape[1]
    canvas = np.zeros((height, width, 3))
    for paint, class_id in color_dict.items():
        selected = label == class_id
        canvas[selected] = paint
    return canvas
|
||
|
||
|
||
def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
    """
    Collect the blocks of ``data`` in which a SpecDetector prediction is positive.

    :param data: image handed to the detector; first two axes are rows and columns
    :param model_file: path passed to ``SpecDetector`` as ``model_path``
    :param blk_sz: block side length in pixels
    :param selected_bands: only its length is used, as the detector's ``channel_num``
    :return: ``(x_list, y_list)`` - the kept blocks and an equally long list of 1s
    """
    detector = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
    y_label = detector.predict(data)
    x_list = []
    for blk_row in range(nrows // blk_sz):
        rows = slice(blk_row * blk_sz, (blk_row + 1) * blk_sz)
        for blk_col in range(ncols // blk_sz):
            cols = slice(blk_col * blk_sz, (blk_col + 1) * blk_sz)
            # Keep the block as soon as the prediction inside it sums above zero
            if np.sum(y_label[rows, cols, ...]) > 0:
                x_list.append(data[rows, cols, ...])
    return x_list, [1] * len(x_list)
|
||
|
||
|
||
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
                            target_class_left=None, ):
    """
    Label bright pixels according to the side of a column split they fall on.

    A pixel whose values sum to at least ``light_threshold`` receives
    ``target_class_right`` when its column index is greater than ``split_line``
    and ``target_class_left`` otherwise; every other pixel stays 0. The label map
    and the first channel of ``data`` are displayed before returning.

    :param data: image data; pixels are visited over the module-level nrows x ncols grid
    :param light_threshold: minimum per-pixel sum for a pixel to be labelled
    :param color_dict: mapping from colour tuples to class indices, used for the picture
    :param split_line: column index separating the left and right classes
    :param target_class_right: class written for columns beyond ``split_line``
    :param target_class_left: class written for the remaining columns
    :return: colour picture built by ``label2pic`` from the label map
    """
    y_label = np.zeros((data.shape[0], data.shape[1]))
    # np.ndindex walks the grid row by row, like two nested range loops
    for i, j in np.ndindex(nrows, ncols):
        if np.sum(data[i, j]) >= light_threshold:
            y_label[i, j] = target_class_right if j > split_line else target_class_left
    pic = label2pic(y_label, color_dict=color_dict)
    fig, axs = plt.subplots(2, 1)
    axs[0].matshow(y_label)
    axs[1].matshow(data[..., 0])
    plt.show()
    return pic
|
||
|
||
|
||
def file_transform(input_dir, output_dir, selected_bands=None):
    """
    Re-save every ``.raw`` file of a folder, optionally keeping only some bands.

    :param input_dir: folder scanned (non-recursively) for ``.raw`` files
    :param output_dir: folder receiving the files under the same names; created if missing
    :param selected_bands: band indices forwarded to ``read_raw_file``
    """
    raw_names = [name for name in os.listdir(input_dir) if name.endswith('.raw')]
    os.makedirs(output_dir, mode=0o777, exist_ok=True)
    for name in raw_names:
        cube = read_raw_file(os.path.join(input_dir, name), selected_bands=selected_bands)
        write_raw_file(os.path.join(output_dir, name), cube)
|
||
|
||
|
||
if __name__ == '__main__':
    # Copy the recordings of one folder into another, keeping four bands only
    selected_bands = [127, 201, 202, 294]
    input_dir = r"/Volumes/LENOVO_USB_HDD/zhouchao/616/"
    output_dir = r"/Volumes/LENOVO_USB_HDD/zhouchao/616_cut/"
    file_transform(input_dir=input_dir, output_dir=output_dir, selected_bands=selected_bands)