mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
268 lines
10 KiB
Python
Executable File
268 lines
10 KiB
Python
Executable File
# -*- coding: utf-8 -*-
# @Time : 2022/7/18 9:46
# @Author : zhouchao
# @File: utils.py
# @Software: PyCharm
import glob
|
||
import logging
|
||
import os
|
||
from queue import Queue
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from numpy.random import default_rng
|
||
from matplotlib import pyplot as plt
|
||
import re
|
||
|
||
|
||
def natural_sort(l):
    """
    Sort strings in human ("natural") order, so that 'img2' sorts
    before 'img10'.

    :param l: iterable of strings to sort
    :return: a new, naturally-sorted list
    """
    def _to_token(text):
        # Digit runs compare numerically; everything else case-insensitively.
        return int(text) if text.isdigit() else text.lower()

    def _natural_key(s):
        return [_to_token(chunk) for chunk in re.split('([0-9]+)', s)]

    return sorted(l, key=_natural_key)
|
||
|
||
|
||
class MergeDict(dict):
    """
    A dict that accumulates numpy arrays: merging an existing key
    concatenates the new rows onto the stored array (axis 0).
    """

    def __init__(self):
        super(MergeDict, self).__init__()

    def merge(self, merged: dict):
        """
        Fold ``merged`` into this dict.

        :param merged: mapping {key: ndarray}; new keys are stored as-is,
            known keys are row-concatenated with the stored array
        :return: self, to allow chaining
        """
        for key, rows in merged.items():
            if key in self:
                self[key] = np.concatenate([self[key], rows], axis=0)
            else:
                self[key] = rows
        return self
|
||
|
||
|
||
class ImgQueue(Queue):
    """
    A queue subclass for image frames that adds :meth:`clear` and a
    non-blocking :meth:`safe_put`.
    """

    def clear(self):
        """
        Clears all items from the queue while keeping the
        ``task_done()``/``join()`` accounting consistent.

        :raises ValueError: if ``task_done()`` was called more times than
            items were put
        """
        with self.mutex:
            # Treat every queued-but-unconsumed item as finished so that
            # join() will not wait on items that will never be processed.
            unfinished = self.unfinished_tasks - len(self.queue)
            if unfinished <= 0:
                if unfinished < 0:
                    raise ValueError('task_done() called too many times')
                self.all_tasks_done.notify_all()
            self.unfinished_tasks = unfinished
            self.queue.clear()
            self.not_full.notify_all()

    def safe_put(self, item):
        """
        Put ``item`` without ever blocking: when the queue is full, the
        oldest item is discarded first so the newest frame is always kept.

        :param item: the item to enqueue
        :return: True if the queue had room, False if the oldest item was
            dropped to make space
        """
        if self.full():
            _ = self.get()
            # Bug fix: the original returned here without enqueueing the new
            # item, so a full queue lost BOTH the oldest and the newest frame.
            self.put(item)
            return False
        self.put(item)
        return True
|
||
|
||
|
||
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
    """
    Build a labelled color dataset from the files under ``dataset_dir``.

    :param dataset_dir: directory that must contain an 'img' and a 'label'
        sub-directory holding same-named images and label maps
    :param color_dict: mapping {BGR color tuple: label name} used to look up
        pixels in the label image
    :param ext: image file extension, defaults to '.bmp'
    :param is_ps_color_space: whether to convert to Photoshop's standard Lab
        ranges (L in [0, 100], a/b in [-128, 127]), defaults to True
    :return: dataset as {label: vector(n x 3)} with vectors in Lab space
    :raises FileNotFoundError: if an image or its label map cannot be read
    """
    img_names = [img_name for img_name in os.listdir(os.path.join(dataset_dir, 'label'))
                 if img_name.endswith(ext)]
    total_dataset = MergeDict()
    for img_name in img_names:
        img_path, label_path = [os.path.join(dataset_dir, folder, img_name) for folder in ['img', 'label']]
        # Read the image pair. cv2.imread silently returns None on failure,
        # which previously crashed later in cvtColor with a cryptic error —
        # fail loudly here instead.
        img = cv2.imread(img_path)
        label_img = cv2.imread(label_path)
        if img is None or label_img is None:
            raise FileNotFoundError(f'could not read image pair: {img_path}, {label_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        # Map OpenCV's 8-bit Lab encoding onto Photoshop's value ranges.
        if is_ps_color_space:
            alpha, beta = np.array([100 / 255, 1, 1]), np.array([0, -128, -128])
            img = img * alpha + beta
            img = np.asarray(np.round(img, 0), dtype=int)
        # Collect, per label, every pixel whose label-map color matches.
        dataset = {label: img[np.all(label_img == color, axis=2)] for color, label in color_dict.items()}
        total_dataset.merge(dataset)
    return total_dataset
|
||
|
||
|
||
def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_space=True, **kwargs):
    """
    Plot the distribution of the dataset in the Lab color space.

    :param dataset: dataset as {label: vector(n x 3)}, vectors in Lab space
    :param class_max_num: maximum number of samples drawn per class,
        unlimited by default
    :param is_3d: draw all three L/a/b axes, or only the a/b plane
    :param is_ps_color_space: whether the data uses Photoshop's standard Lab
        value ranges, defaults to True
    :return: None
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d') if is_3d else fig.add_subplot()
    for label, points in dataset.items():
        if class_max_num is not None:
            assert isinstance(class_max_num, int)
            if points.shape[0] > class_max_num:
                # Randomly subsample oversized classes to keep the plot light.
                chosen = np.random.choice(np.arange(points.shape[0]), class_max_num)
                points = points[chosen, :]
        l, a, b = points[:, 0], points[:, 1], points[:, 2]
        if is_3d:
            ax.scatter(a, b, l, label=label, alpha=0.1)
        else:
            ax.scatter(a, b, label=label, alpha=0.1)
    # Axis limits depend on whether Photoshop or OpenCV value ranges are used.
    if is_ps_color_space:
        x_max, x_min, y_max, y_min, z_max, z_min = 127, -127, 127, -127, 100, 0
    else:
        x_max, x_min, y_max, y_min, z_max, z_min = 255, 0, 255, 0, 255, 0
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel('a*')
    ax.set_ylabel('b*')
    if is_3d:
        ax.set_zlim(z_min, z_max)
        ax.set_zlabel('L')
    plt.legend()
    plt.show()
|
||
|
||
|
||
def size_threshold(img, blk_size, threshold, last_end: np.ndarray = None) -> np.ndarray:
    """
    Down-sample ``img`` into blk_size x blk_size blocks and threshold each
    block on the sum of its pixels.

    :param img: 2-D array; both dimensions must be divisible by ``blk_size``
        (the reshape below requires it) — assumes a 0/1 mask, TODO confirm
    :param blk_size: side length of one block
    :param threshold: block sums <= threshold become 0, larger sums become 1
    :param last_end: optional last ``blk_size // 2`` rows of the previous
        chunk; when given, half-block-shifted copies of the image are also
        thresholded so that blobs straddling block borders are not missed
    :return: block-level 0/1 mask of shape (rows // blk_size, cols // blk_size)
    """
    # Sum over blk_size columns, then over blk_size rows: one sum per block.
    mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \
        reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1)
    # In-place binarization: order matters (<= first, then >).
    mask[mask <= threshold] = 0
    mask[mask > threshold] = 1
    if last_end is not None:
        half_blk_size = blk_size // 2
        assert (last_end.shape[0] == half_blk_size) and (last_end.shape[1] == img.shape[1])
        # Image shifted down by half a block (previous chunk's tail on top).
        mask_up = np.concatenate((last_end, img[:-half_blk_size, :]), axis=0)
        # Same, additionally shifted left by half a block, zero-padded right.
        mask_up_right = np.concatenate((mask_up[:, half_blk_size:],
                                        np.zeros((img.shape[0], half_blk_size), dtype=np.uint8)), axis=1)
        # Recurse WITHOUT last_end so only the plain block threshold runs.
        mask_up = size_threshold(mask_up, blk_size, threshold)
        mask_up_right = size_threshold(mask_up_right, blk_size, threshold)
        # NOTE(review): the shifted results OVERWRITE (not OR into) the
        # unshifted grid for all but the last block row — confirm intended.
        mask[:-1, :] = mask_up[1:, :]
        mask[:-1, 1:] = mask_up_right[1:, :-1]
    return mask
|
||
|
||
|
||
def valve_merge(img: np.ndarray, merge_size: int = 2) -> np.ndarray:
    """
    Merge every ``merge_size`` adjacent columns into one valve signal, then
    stretch the result back to the original image size.

    :param img: 0/1 valve mask
    :param merge_size: number of adjacent columns merged together, default 2
    :return: uint8 0/1 mask resized back to the original (height, width)
    """
    assert img.shape[1] % merge_size == 0  # 列数必须能够被整除
    rows, cols = img.shape[0], img.shape[1]
    # A merged column is open if any of its constituent columns is open.
    grouped = img.reshape((rows, cols // merge_size, merge_size)).sum(axis=2)
    merged = (grouped > 0).astype(np.uint8)
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(merged, dsize=(cols, rows))
|
||
|
||
|
||
def valve_expend(img: np.ndarray) -> np.ndarray:
    """
    Dilate each open valve horizontally by one position on each side.

    :param img: 0/1 valve mask
    :return: horizontally dilated mask
    """
    # 1x3 structuring element: dilation acts along rows only.
    horizontal_kernel = np.ones((1, 3), np.uint8)
    return cv2.dilate(img, horizontal_kernel, iterations=1)
|
||
|
||
|
||
def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
    """
    Limit the number of valves that may be open simultaneously in one row.

    :param mask: valve mask in 0/1 form; each 1 corresponds to an open valve
        (modified in place)
    :param max_valve_num: maximum number of simultaneously open valves,
        must be in [5, 50)
    :return: the mask with at most ``max_valve_num`` open valves per row
    """
    assert (max_valve_num >= 5) and (max_valve_num < 50)
    row_valve_count = np.sum(mask, axis=1)
    if np.any(row_valve_count > max_valve_num):
        over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel()
        logging.warning(f'发现单行喷阀数量{len(over_rows_idx)}超过限制,已限制到最大许可值{max_valve_num}')
        over_rows = mask[over_rows_idx, :]

        # a simple function to get lucky valves when too many valves appear in the same line
        def process_row(each_row):
            valve_idx = np.argwhere(each_row > 0).ravel()
            # Bug fix: replace=False is required — the default samples WITH
            # replacement, so duplicate picks could leave fewer than
            # max_valve_num valves open.
            lucky_valve_idx = default_rng().choice(valve_idx, max_valve_num, replace=False)
            new_row = np.zeros_like(each_row)
            new_row[lucky_valve_idx] = 1
            return new_row

        limited_rows = np.apply_along_axis(process_row, 1, over_rows)
        mask[over_rows_idx] = limited_rows
    return mask
|
||
|
||
|
||
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
    """
    Read an ENVI ASCII ROI file. Use ENVI ROI Tool -> File -> output ROIs to ASCII...

    :param file_name: file name of ENVI ascii file
    :param hdr_file_name: hdr file name for a "BANDS" vector in the output
    :param save_xy: save the x, y position on the first two cols of the result vector
    :return: dict {class_name: vector, ...}
    """
    number_line_start_with = "; Number of ROIs: "
    roi_name_start_with, roi_npts_start_with = "; ROI name: ", "; ROI npts: "
    data_start_with = "; ID"
    class_num, class_names, class_nums, vectors = 0, [], [], []
    with open(file_name, 'r') as f:
        # Parse the comment header until the column-description line.
        for line_text in f:
            if line_text.startswith(number_line_start_with):
                class_num = int(line_text[len(number_line_start_with):])
            elif line_text.startswith(roi_name_start_with):
                class_names.append(line_text[len(roi_name_start_with):-1])
            elif line_text.startswith(roi_npts_start_with):
                # Consistency fix: slice with the npts prefix (same length as
                # the name prefix, so behavior is unchanged).
                class_nums.append(int(line_text[len(roi_npts_start_with):-1]))
            elif line_text.startswith(data_start_with):
                col_list = list(filter(None, line_text[1:].split(" ")))
                assert (len(class_names) == class_num) and (len(class_names) == len(class_nums))
                break
            elif line_text.startswith(";"):
                continue
        # Read the data rows of each ROI, one block per class.
        for vector_rows in class_nums:
            vector_str = ''
            for i in range(vector_rows):
                vector_str += f.readline()
            # Bug fix: np.float was removed in NumPy 1.24 (AttributeError)
            # and np.fromstring is deprecated — parse with the builtin float.
            vector = np.array(vector_str.split(), dtype=float).reshape(-1, len(col_list))
            assert vector.shape[0] == vector_rows
            # Drop the ID column; keep x/y only when requested.
            vector = vector[:, 3:] if not save_xy else vector[:, 1:]
            vectors.append(vector)
            f.readline()  # supposed to read a blank line between ROI blocks
    if hdr_file_name is not None:
        # Optionally parse the wavelength list from the companion .hdr file.
        bands = []
        with open(hdr_file_name, 'r') as f:
            start_bands = False
            for line_text in f:
                if start_bands:
                    if line_text.endswith(",\n"):
                        bands.append(float(line_text[:-2]))
                    else:
                        # The last wavelength has no trailing comma.
                        bands.append(float(line_text))
                        break
                elif line_text.startswith("wavelength ="):
                    start_bands = True
        bands = np.array(bands, dtype=float)
        vectors.append(bands)
        class_names.append("BANDS")
    return dict(zip(class_names, vectors))
|
||
|
||
|
||
if __name__ == '__main__':
    # Example for building and visualising a labelled color dataset:
    # color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
    #               (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
    # dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
    # lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)

    # Quick manual check of the valve helpers on a small 0/1 mask.
    demo_mask = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8)
    # demo_mask.repeat(3, axis=0)
    # merged = valve_merge(demo_mask, 2)
    # print(merged)
    expanded = valve_expend(demo_mask)
    print(expanded)
|