在main_test内添加了模拟运行程序

This commit is contained in:
FEIJINTI 2022-07-28 17:24:48 +08:00
parent cdbe341591
commit 081b4e050f
4 changed files with 93 additions and 91 deletions

31
main.py
View File

@ -74,37 +74,6 @@ def main():
print(f'total time is:{t3 - t1}')
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
if os.path.isdir(buffer_path):
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
else:
buffer_names = [buffer_path, ]
for buffer_name in buffer_names:
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
data = f.read()
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
.transpose(0, 2, 1)
if selected_bands is not None:
img = img[..., selected_bands]
if img.shape[0] == 1:
img = img[0, ...]
if not no_mask:
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
data = f.read()
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
else:
mask_name = "no mask"
mask = np.zeros_like(img)
# mask = cv2.resize(mask, (1024, 256))
fig, axs = plt.subplots(2, 1)
axs[0].matshow(img)
axs[0].set_title(buffer_name)
axs[1].imshow(mask)
axs[1].set_title(mask_name)
plt.show()
if __name__ == '__main__':
# 相关参数
img_fifo_path = "/tmp/dkimg.fifo"

View File

@ -10,73 +10,99 @@ import cv2
import matplotlib.pyplot as plt
import numpy as np
import transmit
from config import Config
from models import Detector, AnonymousColorDetector, ManualTree
from utils import read_labeled_img, size_threshold
from models import Detector, AnonymousColorDetector, ManualTree, SpecDetector, RgbDetector
from utils import read_labeled_img, size_threshold, natural_sort
def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False):
"""
虚拟读图测试程序
class TestMain:
def __init__(self):
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
pixel_model_path=Config.pixel_model_path)
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
:param test_img: 测试图像rgb格式的图片或者路径
:param test_img_dir: 测试图像文件夹
:param test_spectra: 是否测试光谱
:param test_rgb: 是否测试rgb
:return:
"""
if (test_img is not None) or (test_img_dir is not None):
threshold = Config.spec_size_threshold
rgb_threshold = Config.rgb_size_threshold
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
if test_img is not None:
if isinstance(test_img, str):
img = cv2.imread(test_img)[:, :, ::-1]
elif isinstance(test_img, np.ndarray):
img = test_img
def pony_run(self, test_path, test_spectra=False, test_rgb=False,
convert=False):
"""
虚拟读图测试程序
:param test_path: 测试文件夹或者图片
:param test_spectra: 是否测试光谱
:param test_rgb: 是否测试rgb
:param convert: 是否进行格式转化
:return:
"""
if os.path.isdir(test_path):
rgb_file_names, spec_file_names = [[file_name for file_name in os.listdir(test_path) if
file_name.startswith(file_type)] for file_type in ['rgb', 'spec']]
rgb_file_names, spec_file_names = natural_sort(rgb_file_names), natural_sort(spec_file_names)
else:
raise TypeError("test img should be np.ndarray or str")
if test_img_dir is not None:
image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')]
for image_name in image_names:
rgb_data = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
# 识别
t1 = time.time()
if test_spectra:
# spectra part
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manual_tree.blk_predict(data=img_data)
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
mask_spec = size_threshold(mask, Config.blk_size, threshold)
with open(test_path, 'rb') as f:
data = f.read()
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
_ = self.test_spec(spec_img=spec_img)
elif test_rgb:
with open(test_path, 'rb') as f:
data = f.read()
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
_ = self.test_rgb(rgb_img)
return
for rgb_file_name, spec_file_name in zip(rgb_file_names, spec_file_names):
if test_spectra:
with open(os.path.join(test_path, spec_file_name), 'rb') as f:
data = f.read()
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
spec_mask = self.test_spec(spec_img, img_name=spec_file_name)
if test_rgb:
# rgb part
rgb_data = tobacco_detector.pretreatment(rgb_data)
background = background_detector.predict(rgb_data)
tobacco = tobacco_detector.predict(rgb_data)
tobacco_d = tobacco_detector.swell(tobacco)
rgb_predict_result = 1 - (background | tobacco_d)
mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold)
fig, axs = plt.subplots(5, 1, figsize=(12, 10), constrained_layout=True)
axs[0].imshow(rgb_data)
axs[0].set_title("rgb raw data")
axs[1].imshow(background)
axs[1].set_title("background")
axs[2].imshow(tobacco)
axs[2].set_title("tobacco")
axs[3].imshow(rgb_predict_result)
axs[3].set_title("1 - (background + dilate(tobacco))")
axs[4].imshow(mask_rgb)
axs[4].set_title("final mask")
plt.show()
with open(os.path.join(test_path, rgb_file_name), 'rb') as f:
data = f.read()
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
rgb_mask = self.test_rgb(rgb_img, img_name=rgb_file_name)
if test_rgb and test_spectra:
self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
spec_img=spec_img[..., [21, 3, 0]], spec_mask=spec_mask,
file_name=rgb_file_name)
mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = rgb_predict_result
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
print(f'rgb len = {len(rgb_data)}')
def test_rgb(self, rgb_img, img_name):
rgb_mask = self._rgb_detector.predict(rgb_img)
fig, axs = plt.subplots(2, 1)
axs[0].imshow(rgb_img)
axs[0].set_title(f"rgb img {img_name}")
axs[1].imshow(rgb_mask)
axs[1].set_title('rgb mask')
plt.show()
return rgb_mask
def test_spec(self, spec_img, img_name):
spec_mask = self._spec_detector.predict(spec_img)
fig, axs = plt.subplots(2, 1)
axs[0].imshow(spec_img[..., [21, 3, 0]])
axs[0].set_title(f"spec img {img_name}")
axs[1].imshow(spec_mask)
axs[1].set_title('spec mask')
plt.show()
return spec_mask
@staticmethod
def merge(rgb_img, rgb_mask, spec_img, spec_mask, file_name):
mask_result = (spec_mask | rgb_mask).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
fig, axs = plt.subplots(3, 2)
axs[0, 0].set_title(file_name)
axs[0, 0].imshow(rgb_img)
axs[1, 0].imshow(spec_img)
axs[2, 0].imshow(mask_result)
axs[0, 1].imshow(rgb_mask)
axs[1, 1].imshow(spec_mask)
axs[2, 1].imshow(mask_result)
plt.show()
return mask_result
if __name__ == '__main__':
pony_run(test_img_dir=r'E:\zhouchao\725data', test_rgb=True)
testor = TestMain()
testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\728-1-3',
test_rgb=True, test_spectra=True)

View File

@ -9,7 +9,7 @@ from models import SpecDetector, RgbDetector
import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.DEBUG)
level=logging.INFO)
class Transmitter(object):

View File

@ -10,6 +10,13 @@ from queue import Queue
import cv2
import numpy as np
from matplotlib import pyplot as plt
import re
def natural_sort(l):
convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted(l, key=alphanum_key)
class MergeDict(dict):