diff --git a/main.py b/main.py index 9f4a8f6..2f88c89 100755 --- a/main.py +++ b/main.py @@ -74,37 +74,6 @@ def main(): print(f'total time is:{t3 - t1}') -def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None): - if os.path.isdir(buffer_path): - buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] - else: - buffer_names = [buffer_path, ] - for buffer_name in buffer_names: - with open(os.path.join(buffer_path, buffer_name), 'rb') as f: - data = f.read() - img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \ - .transpose(0, 2, 1) - if selected_bands is not None: - img = img[..., selected_bands] - if img.shape[0] == 1: - img = img[0, ...] - if not no_mask: - mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') - with open(os.path.join(buffer_path, mask_name), 'rb') as f: - data = f.read() - mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1)) - else: - mask_name = "no mask" - mask = np.zeros_like(img) - # mask = cv2.resize(mask, (1024, 256)) - fig, axs = plt.subplots(2, 1) - axs[0].matshow(img) - axs[0].set_title(buffer_name) - axs[1].imshow(mask) - axs[1].set_title(mask_name) - plt.show() - - if __name__ == '__main__': # 相关参数 img_fifo_path = "/tmp/dkimg.fifo" diff --git a/main_test.py b/main_test.py index 2a329b0..80b667b 100644 --- a/main_test.py +++ b/main_test.py @@ -10,73 +10,99 @@ import cv2 import matplotlib.pyplot as plt import numpy as np +import transmit from config import Config -from models import Detector, AnonymousColorDetector, ManualTree -from utils import read_labeled_img, size_threshold +from models import Detector, AnonymousColorDetector, ManualTree, SpecDetector, RgbDetector +from utils import read_labeled_img, size_threshold, natural_sort -def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False): - """ - 虚拟读图测试程序 +class TestMain: + def __init__(self): + self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, + pixel_model_path=Config.pixel_model_path) + self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, + background_model_path=Config.rgb_background_model_path) - :param test_img: 测试图像,rgb格式的图片或者路径 - :param test_img_dir: 测试图像文件夹 - :param test_spectra: 是否测试光谱 - :param test_rgb: 是否测试rgb - :return: - """ - if (test_img is not None) or (test_img_dir is not None): - threshold = Config.spec_size_threshold - rgb_threshold = Config.rgb_size_threshold - manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) - tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) - background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) - if test_img is not None: - if isinstance(test_img, str): - img = cv2.imread(test_img)[:, :, ::-1] - elif isinstance(test_img, np.ndarray): - img = test_img + def pony_run(self, test_path, test_spectra=False, test_rgb=False, + convert=False): + """ + 虚拟读图测试程序 + + :param test_path: 测试文件夹或者图片 + :param test_spectra: 是否测试光谱 + :param test_rgb: 是否测试rgb + :param convert: 是否进行格式转化 + :return: + """ + if os.path.isdir(test_path): + rgb_file_names, spec_file_names = [[file_name for file_name in os.listdir(test_path) if + file_name.startswith(file_type)] for file_type in ['rgb', 'spec']] + rgb_file_names, spec_file_names = natural_sort(rgb_file_names), natural_sort(spec_file_names) else: - raise TypeError("test img should be np.ndarray or str") - if test_img_dir is not None: - image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')] - for image_name in image_names: - rgb_data = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1] - # 识别 - t1 = time.time() if test_spectra: - # spectra part - pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1) - blk_predict_result = manual_tree.blk_predict(data=img_data) - mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) - mask_spec = size_threshold(mask, Config.blk_size, threshold) + with open(test_path, 'rb') as f: + data = f.read() + spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data) + _ = self.test_spec(spec_img=spec_img) + elif test_rgb: + with open(test_path, 'rb') as f: + data = f.read() + rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data) + _ = self.test_rgb(rgb_img) + return + for rgb_file_name, spec_file_name in zip(rgb_file_names, spec_file_names): + if test_spectra: + with open(os.path.join(test_path, spec_file_name), 'rb') as f: + data = f.read() + spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data) + spec_mask = self.test_spec(spec_img, img_name=spec_file_name) if test_rgb: - # rgb part - rgb_data = tobacco_detector.pretreatment(rgb_data) - background = background_detector.predict(rgb_data) - tobacco = tobacco_detector.predict(rgb_data) - tobacco_d = tobacco_detector.swell(tobacco) - rgb_predict_result = 1 - (background | tobacco_d) - mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold) - fig, axs = plt.subplots(5, 1, figsize=(12, 10), constrained_layout=True) - axs[0].imshow(rgb_data) - axs[0].set_title("rgb raw data") - axs[1].imshow(background) - axs[1].set_title("background") - axs[2].imshow(tobacco) - axs[2].set_title("tobacco") - axs[3].imshow(rgb_predict_result) - axs[3].set_title("1 - (background + dilate(tobacco))") - axs[4].imshow(mask_rgb) - axs[4].set_title("final mask") - plt.show() + with open(os.path.join(test_path, rgb_file_name), 'rb') as f: + data = f.read() + rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data) + rgb_mask = self.test_rgb(rgb_img, img_name=rgb_file_name) + if test_rgb and test_spectra: + self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask, + spec_img=spec_img[..., [21, 3, 0]], spec_mask=spec_mask, + file_name=rgb_file_name) - mask_result = (mask | mask_rgb).astype(np.uint8) - # mask_result = rgb_predict_result - mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) - t2 = time.time() - print(f'rgb len = {len(rgb_data)}') + def test_rgb(self, rgb_img, img_name): + rgb_mask = self._rgb_detector.predict(rgb_img) + fig, axs = plt.subplots(2, 1) + axs[0].imshow(rgb_img) + axs[0].set_title(f"rgb img {img_name}") + axs[1].imshow(rgb_mask) + axs[1].set_title('rgb mask') + plt.show() + return rgb_mask + + def test_spec(self, spec_img, img_name): + spec_mask = self._spec_detector.predict(spec_img) + fig, axs = plt.subplots(2, 1) + axs[0].imshow(spec_img[..., [21, 3, 0]]) + axs[0].set_title(f"spec img {img_name}") + axs[1].imshow(spec_mask) + axs[1].set_title('spec mask') + plt.show() + return spec_mask + + @staticmethod + def merge(rgb_img, rgb_mask, spec_img, spec_mask, file_name): + mask_result = (spec_mask | rgb_mask).astype(np.uint8) + mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + fig, axs = plt.subplots(3, 2) + axs[0, 0].set_title(file_name) + axs[0, 0].imshow(rgb_img) + axs[1, 0].imshow(spec_img) + axs[2, 0].imshow(mask_result) + axs[0, 1].imshow(rgb_mask) + axs[1, 1].imshow(spec_mask) + axs[2, 1].imshow(mask_result) + plt.show() + return mask_result if __name__ == '__main__': - pony_run(test_img_dir=r'E:\zhouchao\725data', test_rgb=True) + testor = TestMain() + testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\728-1-3', + test_rgb=True, test_spectra=True) diff --git a/transmit.py b/transmit.py index 2632a12..099befd 100644 --- a/transmit.py +++ b/transmit.py @@ -9,7 +9,7 @@ from models import SpecDetector, RgbDetector import typing import logging logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', - level=logging.DEBUG) + level=logging.INFO) class Transmitter(object): diff --git a/utils.py b/utils.py index bf466f8..9359bc4 100755 --- a/utils.py +++ b/utils.py @@ -10,6 +10,13 @@ from queue import Queue import cv2 import numpy as np from matplotlib import pyplot as plt +import re + + +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] + return sorted(l, key=alphanum_key) class MergeDict(dict):