mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
在main_test内添加了模拟运行程序
This commit is contained in:
parent
cdbe341591
commit
081b4e050f
31
main.py
31
main.py
@ -74,37 +74,6 @@ def main():
|
||||
print(f'total time is:{t3 - t1}')
|
||||
|
||||
|
||||
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
|
||||
if os.path.isdir(buffer_path):
|
||||
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
|
||||
else:
|
||||
buffer_names = [buffer_path, ]
|
||||
for buffer_name in buffer_names:
|
||||
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
|
||||
data = f.read()
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
|
||||
.transpose(0, 2, 1)
|
||||
if selected_bands is not None:
|
||||
img = img[..., selected_bands]
|
||||
if img.shape[0] == 1:
|
||||
img = img[0, ...]
|
||||
if not no_mask:
|
||||
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
|
||||
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
|
||||
data = f.read()
|
||||
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
|
||||
else:
|
||||
mask_name = "no mask"
|
||||
mask = np.zeros_like(img)
|
||||
# mask = cv2.resize(mask, (1024, 256))
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].matshow(img)
|
||||
axs[0].set_title(buffer_name)
|
||||
axs[1].imshow(mask)
|
||||
axs[1].set_title(mask_name)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 相关参数
|
||||
img_fifo_path = "/tmp/dkimg.fifo"
|
||||
|
||||
144
main_test.py
144
main_test.py
@ -10,73 +10,99 @@ import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import transmit
|
||||
from config import Config
|
||||
from models import Detector, AnonymousColorDetector, ManualTree
|
||||
from utils import read_labeled_img, size_threshold
|
||||
from models import Detector, AnonymousColorDetector, ManualTree, SpecDetector, RgbDetector
|
||||
from utils import read_labeled_img, size_threshold, natural_sort
|
||||
|
||||
|
||||
def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False):
|
||||
"""
|
||||
虚拟读图测试程序
|
||||
class TestMain:
|
||||
def __init__(self):
|
||||
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
|
||||
pixel_model_path=Config.pixel_model_path)
|
||||
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||
background_model_path=Config.rgb_background_model_path)
|
||||
|
||||
:param test_img: 测试图像,rgb格式的图片或者路径
|
||||
:param test_img_dir: 测试图像文件夹
|
||||
:param test_spectra: 是否测试光谱
|
||||
:param test_rgb: 是否测试rgb
|
||||
:return:
|
||||
"""
|
||||
if (test_img is not None) or (test_img_dir is not None):
|
||||
threshold = Config.spec_size_threshold
|
||||
rgb_threshold = Config.rgb_size_threshold
|
||||
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
||||
if test_img is not None:
|
||||
if isinstance(test_img, str):
|
||||
img = cv2.imread(test_img)[:, :, ::-1]
|
||||
elif isinstance(test_img, np.ndarray):
|
||||
img = test_img
|
||||
def pony_run(self, test_path, test_spectra=False, test_rgb=False,
|
||||
convert=False):
|
||||
"""
|
||||
虚拟读图测试程序
|
||||
|
||||
:param test_path: 测试文件夹或者图片
|
||||
:param test_spectra: 是否测试光谱
|
||||
:param test_rgb: 是否测试rgb
|
||||
:param convert: 是否进行格式转化
|
||||
:return:
|
||||
"""
|
||||
if os.path.isdir(test_path):
|
||||
rgb_file_names, spec_file_names = [[file_name for file_name in os.listdir(test_path) if
|
||||
file_name.startswith(file_type)] for file_type in ['rgb', 'spec']]
|
||||
rgb_file_names, spec_file_names = natural_sort(rgb_file_names), natural_sort(spec_file_names)
|
||||
else:
|
||||
raise TypeError("test img should be np.ndarray or str")
|
||||
if test_img_dir is not None:
|
||||
image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')]
|
||||
for image_name in image_names:
|
||||
rgb_data = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
|
||||
# 识别
|
||||
t1 = time.time()
|
||||
if test_spectra:
|
||||
# spectra part
|
||||
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = manual_tree.blk_predict(data=img_data)
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask_spec = size_threshold(mask, Config.blk_size, threshold)
|
||||
with open(test_path, 'rb') as f:
|
||||
data = f.read()
|
||||
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
|
||||
_ = self.test_spec(spec_img=spec_img)
|
||||
elif test_rgb:
|
||||
with open(test_path, 'rb') as f:
|
||||
data = f.read()
|
||||
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
|
||||
_ = self.test_rgb(rgb_img)
|
||||
return
|
||||
for rgb_file_name, spec_file_name in zip(rgb_file_names, spec_file_names):
|
||||
if test_spectra:
|
||||
with open(os.path.join(test_path, spec_file_name), 'rb') as f:
|
||||
data = f.read()
|
||||
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
|
||||
spec_mask = self.test_spec(spec_img, img_name=spec_file_name)
|
||||
if test_rgb:
|
||||
# rgb part
|
||||
rgb_data = tobacco_detector.pretreatment(rgb_data)
|
||||
background = background_detector.predict(rgb_data)
|
||||
tobacco = tobacco_detector.predict(rgb_data)
|
||||
tobacco_d = tobacco_detector.swell(tobacco)
|
||||
rgb_predict_result = 1 - (background | tobacco_d)
|
||||
mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, Config.rgb_size_threshold)
|
||||
fig, axs = plt.subplots(5, 1, figsize=(12, 10), constrained_layout=True)
|
||||
axs[0].imshow(rgb_data)
|
||||
axs[0].set_title("rgb raw data")
|
||||
axs[1].imshow(background)
|
||||
axs[1].set_title("background")
|
||||
axs[2].imshow(tobacco)
|
||||
axs[2].set_title("tobacco")
|
||||
axs[3].imshow(rgb_predict_result)
|
||||
axs[3].set_title("1 - (background + dilate(tobacco))")
|
||||
axs[4].imshow(mask_rgb)
|
||||
axs[4].set_title("final mask")
|
||||
plt.show()
|
||||
with open(os.path.join(test_path, rgb_file_name), 'rb') as f:
|
||||
data = f.read()
|
||||
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
|
||||
rgb_mask = self.test_rgb(rgb_img, img_name=rgb_file_name)
|
||||
if test_rgb and test_spectra:
|
||||
self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
|
||||
spec_img=spec_img[..., [21, 3, 0]], spec_mask=spec_mask,
|
||||
file_name=rgb_file_name)
|
||||
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
# mask_result = rgb_predict_result
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
print(f'rgb len = {len(rgb_data)}')
|
||||
def test_rgb(self, rgb_img, img_name):
|
||||
rgb_mask = self._rgb_detector.predict(rgb_img)
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].imshow(rgb_img)
|
||||
axs[0].set_title(f"rgb img {img_name}")
|
||||
axs[1].imshow(rgb_mask)
|
||||
axs[1].set_title('rgb mask')
|
||||
plt.show()
|
||||
return rgb_mask
|
||||
|
||||
def test_spec(self, spec_img, img_name):
|
||||
spec_mask = self._spec_detector.predict(spec_img)
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].imshow(spec_img[..., [21, 3, 0]])
|
||||
axs[0].set_title(f"spec img {img_name}")
|
||||
axs[1].imshow(spec_mask)
|
||||
axs[1].set_title('spec mask')
|
||||
plt.show()
|
||||
return spec_mask
|
||||
|
||||
@staticmethod
|
||||
def merge(rgb_img, rgb_mask, spec_img, spec_mask, file_name):
|
||||
mask_result = (spec_mask | rgb_mask).astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
fig, axs = plt.subplots(3, 2)
|
||||
axs[0, 0].set_title(file_name)
|
||||
axs[0, 0].imshow(rgb_img)
|
||||
axs[1, 0].imshow(spec_img)
|
||||
axs[2, 0].imshow(mask_result)
|
||||
axs[0, 1].imshow(rgb_mask)
|
||||
axs[1, 1].imshow(spec_mask)
|
||||
axs[2, 1].imshow(mask_result)
|
||||
plt.show()
|
||||
return mask_result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pony_run(test_img_dir=r'E:\zhouchao\725data', test_rgb=True)
|
||||
testor = TestMain()
|
||||
testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\728-1-3',
|
||||
test_rgb=True, test_spectra=True)
|
||||
|
||||
@ -9,7 +9,7 @@ from models import SpecDetector, RgbDetector
|
||||
import typing
|
||||
import logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||
level=logging.DEBUG)
|
||||
level=logging.INFO)
|
||||
|
||||
|
||||
class Transmitter(object):
|
||||
|
||||
7
utils.py
7
utils.py
@ -10,6 +10,13 @@ from queue import Queue
|
||||
import cv2
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
import re
|
||||
|
||||
|
||||
def natural_sort(l):
|
||||
convert = lambda text: int(text) if text.isdigit() else text.lower()
|
||||
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
return sorted(l, key=alphanum_key)
|
||||
|
||||
|
||||
class MergeDict(dict):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user