在main_test内添加了模拟运行程序

This commit is contained in:
FEIJINTI 2022-07-28 17:24:48 +08:00
parent cdbe341591
commit 081b4e050f
4 changed files with 93 additions and 91 deletions

31
main.py
View File

@ -74,37 +74,6 @@ def main():
print(f'total time is:{t3 - t1}') print(f'total time is:{t3 - t1}')
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
if os.path.isdir(buffer_path):
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
else:
buffer_names = [buffer_path, ]
for buffer_name in buffer_names:
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
data = f.read()
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
.transpose(0, 2, 1)
if selected_bands is not None:
img = img[..., selected_bands]
if img.shape[0] == 1:
img = img[0, ...]
if not no_mask:
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
data = f.read()
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
else:
mask_name = "no mask"
mask = np.zeros_like(img)
# mask = cv2.resize(mask, (1024, 256))
fig, axs = plt.subplots(2, 1)
axs[0].matshow(img)
axs[0].set_title(buffer_name)
axs[1].imshow(mask)
axs[1].set_title(mask_name)
plt.show()
if __name__ == '__main__': if __name__ == '__main__':
# 相关参数 # 相关参数
img_fifo_path = "/tmp/dkimg.fifo" img_fifo_path = "/tmp/dkimg.fifo"

View File

@ -10,73 +10,99 @@ import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import transmit
from config import Config from config import Config
from models import Detector, AnonymousColorDetector, ManualTree from models import Detector, AnonymousColorDetector, ManualTree, SpecDetector, RgbDetector
from utils import read_labeled_img, size_threshold from utils import read_labeled_img, size_threshold, natural_sort
def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False): class TestMain:
def __init__(self):
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
pixel_model_path=Config.pixel_model_path)
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
def pony_run(self, test_path, test_spectra=False, test_rgb=False,
convert=False):
""" """
虚拟读图测试程序 虚拟读图测试程序
:param test_img: 测试图像rgb格式的图片或者路径 :param test_path: 测试文件夹或者图片
:param test_img_dir: 测试图像文件夹
:param test_spectra: 是否测试光谱 :param test_spectra: 是否测试光谱
:param test_rgb: 是否测试rgb :param test_rgb: 是否测试rgb
:param convert: 是否进行格式转化
:return: :return:
""" """
if (test_img is not None) or (test_img_dir is not None): if os.path.isdir(test_path):
threshold = Config.spec_size_threshold rgb_file_names, spec_file_names = [[file_name for file_name in os.listdir(test_path) if
rgb_threshold = Config.rgb_size_threshold file_name.startswith(file_type)] for file_type in ['rgb', 'spec']]
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_file_names, spec_file_names = natural_sort(rgb_file_names), natural_sort(spec_file_names)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
if test_img is not None:
if isinstance(test_img, str):
img = cv2.imread(test_img)[:, :, ::-1]
elif isinstance(test_img, np.ndarray):
img = test_img
else: else:
raise TypeError("test img should be np.ndarray or str")
if test_img_dir is not None:
image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')]
for image_name in image_names:
rgb_data = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
# 识别
t1 = time.time()
if test_spectra: if test_spectra:
# spectra part with open(test_path, 'rb') as f:
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1) data = f.read()
blk_predict_result = manual_tree.blk_predict(data=img_data) spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) _ = self.test_spec(spec_img=spec_img)
mask_spec = size_threshold(mask, Config.blk_size, threshold) elif test_rgb:
with open(test_path, 'rb') as f:
data = f.read()
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
_ = self.test_rgb(rgb_img)
return
for rgb_file_name, spec_file_name in zip(rgb_file_names, spec_file_names):
if test_spectra:
with open(os.path.join(test_path, spec_file_name), 'rb') as f:
data = f.read()
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
spec_mask = self.test_spec(spec_img, img_name=spec_file_name)
if test_rgb: if test_rgb:
# rgb part with open(os.path.join(test_path, rgb_file_name), 'rb') as f:
rgb_data = tobacco_detector.pretreatment(rgb_data) data = f.read()
background = background_detector.predict(rgb_data) rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
tobacco = tobacco_detector.predict(rgb_data) rgb_mask = self.test_rgb(rgb_img, img_name=rgb_file_name)
tobacco_d = tobacco_detector.swell(tobacco) if test_rgb and test_spectra:
rgb_predict_result = 1 - (background | tobacco_d) self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold) spec_img=spec_img[..., [21, 3, 0]], spec_mask=spec_mask,
fig, axs = plt.subplots(5, 1, figsize=(12, 10), constrained_layout=True) file_name=rgb_file_name)
axs[0].imshow(rgb_data)
axs[0].set_title("rgb raw data")
axs[1].imshow(background)
axs[1].set_title("background")
axs[2].imshow(tobacco)
axs[2].set_title("tobacco")
axs[3].imshow(rgb_predict_result)
axs[3].set_title("1 - (background + dilate(tobacco))")
axs[4].imshow(mask_rgb)
axs[4].set_title("final mask")
plt.show()
mask_result = (mask | mask_rgb).astype(np.uint8) def test_rgb(self, rgb_img, img_name):
# mask_result = rgb_predict_result rgb_mask = self._rgb_detector.predict(rgb_img)
fig, axs = plt.subplots(2, 1)
axs[0].imshow(rgb_img)
axs[0].set_title(f"rgb img {img_name}")
axs[1].imshow(rgb_mask)
axs[1].set_title('rgb mask')
plt.show()
return rgb_mask
def test_spec(self, spec_img, img_name):
spec_mask = self._spec_detector.predict(spec_img)
fig, axs = plt.subplots(2, 1)
axs[0].imshow(spec_img[..., [21, 3, 0]])
axs[0].set_title(f"spec img {img_name}")
axs[1].imshow(spec_mask)
axs[1].set_title('spec mask')
plt.show()
return spec_mask
@staticmethod
def merge(rgb_img, rgb_mask, spec_img, spec_mask, file_name):
mask_result = (spec_mask | rgb_mask).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time() fig, axs = plt.subplots(3, 2)
print(f'rgb len = {len(rgb_data)}') axs[0, 0].set_title(file_name)
axs[0, 0].imshow(rgb_img)
axs[1, 0].imshow(spec_img)
axs[2, 0].imshow(mask_result)
axs[0, 1].imshow(rgb_mask)
axs[1, 1].imshow(spec_mask)
axs[2, 1].imshow(mask_result)
plt.show()
return mask_result
if __name__ == '__main__': if __name__ == '__main__':
pony_run(test_img_dir=r'E:\zhouchao\725data', test_rgb=True) testor = TestMain()
testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\728-1-3',
test_rgb=True, test_spectra=True)

View File

@ -9,7 +9,7 @@ from models import SpecDetector, RgbDetector
import typing import typing
import logging import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.DEBUG) level=logging.INFO)
class Transmitter(object): class Transmitter(object):

View File

@ -10,6 +10,13 @@ from queue import Queue
import cv2 import cv2
import numpy as np import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import re
def natural_sort(l):
convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted(l, key=alphanum_key)
class MergeDict(dict): class MergeDict(dict):