Polarization_Camera/ThresholdSimulation.py
2024-12-19 14:18:12 +08:00

289 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import os
import cv2
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
class ImageLabel(QtWidgets.QLabel):
    """QLabel preconfigured for showing an image.

    The label is centred, uses the palette's ``Base`` role as background,
    ignores its size hint in both directions and scales its contents to
    whatever size it is given.
    """

    def __init__(self, parent=None):
        """Create the label.

        Args:
            parent: Forwarded unchanged as the single positional argument
                of ``QLabel``'s constructor.
        """
        super().__init__(parent)

        # Same policy in both directions: the surrounding layout decides
        # the size, not the label's own size hint.
        ignored = QtWidgets.QSizePolicy.Ignored
        self.setSizePolicy(ignored, ignored)

        self.setBackgroundRole(QtGui.QPalette.Base)
        self.setAlignment(QtCore.Qt.AlignCenter)
        self.setScaledContents(True)
class ImageViewer(QtWidgets.QWidget):
    """Browse the BMP images of a folder next to a threshold-mask preview.

    The widget shows a file list on the left and, on the right, the
    current image stacked above a copy in which every pixel satisfying
    ``G > T_g and (G - R) > T`` is tinted red. Both thresholds can be
    changed with a slider or a text field; images can be stepped through
    and zoomed with buttons or the N / P / + / - keys.
    """

    # Multiplicative step applied by one zoom-in / zoom-out action.
    ZOOM_STEP = 1.25
    # Bounds for zoom_factor. Without them repeated zooming could shrink the
    # pixmap to nothing or request an arbitrarily large pixmap.
    MIN_ZOOM = 0.1
    MAX_ZOOM = 8.0

    def __init__(self, folder_path):
        """Build the UI and show the first BMP file of *folder_path*.

        Args:
            folder_path: Directory whose ``.bmp`` files (case-insensitive
                extension match) are listed in sorted order.
        """
        super().__init__()
        self.folder_path = folder_path
        self.image_files = sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith('.bmp')
        )
        self.current_index = 0
        self.threshold_g = 50     # initial T_g; adjust as needed
        self.threshold_diff = 10  # initial T; adjust as needed
        self.zoom_factor = 1.0    # initial zoom factor

        # Explicit "nothing loaded yet" state instead of probing with hasattr.
        self.original_image = None
        self.masked_image = None

        self.initUI()
        if self.image_files:
            self.loadImage(self.current_index)

    def initUI(self):
        """Create all child widgets, lay them out and connect the signals."""
        # Passing ``self`` installs the layout on this widget, so no
        # separate setLayout() call is needed.
        main_layout = QtWidgets.QHBoxLayout(self)

        # Sidebar listing the image files.
        self.sidebar = QtWidgets.QListWidget()
        self.sidebar.addItems(self.image_files)
        self.sidebar.currentRowChanged.connect(self.sidebarSelectionChanged)
        main_layout.addWidget(self.sidebar, 1)

        # Right-hand area.
        right_layout = QtWidgets.QVBoxLayout()

        # Each image label lives in a QScrollArea so it can be scrolled.
        # NOTE(review): with setWidgetResizable(True) the scroll area manages
        # the label's size, which may override the adjustSize() call made in
        # displayImage -- confirm that zooming behaves as intended.
        self.original_label = ImageLabel("Original Image")
        self.masked_label = ImageLabel("Masked Image")
        self.original_scroll = QtWidgets.QScrollArea()
        self.original_scroll.setWidgetResizable(True)
        self.original_scroll.setWidget(self.original_label)
        self.masked_scroll = QtWidgets.QScrollArea()
        self.masked_scroll.setWidgetResizable(True)
        self.masked_scroll.setWidget(self.masked_label)

        # Stack the two images vertically.
        images_layout = QtWidgets.QVBoxLayout()
        images_layout.addWidget(self.original_scroll)
        images_layout.addWidget(self.masked_scroll)
        right_layout.addLayout(images_layout, 8)

        # Threshold controls.
        threshold_layout = QtWidgets.QGridLayout()

        # T_g control. The slider value is set before the signal is
        # connected, so no mask update is triggered during construction.
        self.slider_g = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.slider_g.setMinimum(0)
        self.slider_g.setMaximum(255)
        self.slider_g.setValue(self.threshold_g)
        self.slider_g.valueChanged.connect(self.sliderGChanged)
        self.input_g = QtWidgets.QLineEdit(str(self.threshold_g))
        self.input_g.setFixedWidth(50)
        self.input_g.returnPressed.connect(self.inputGChanged)
        threshold_layout.addWidget(QtWidgets.QLabel("T_g (G > T_g):"), 0, 0)
        threshold_layout.addWidget(self.slider_g, 0, 1)
        threshold_layout.addWidget(self.input_g, 0, 2)

        # T control.
        self.slider_diff = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.slider_diff.setMinimum(-255)
        self.slider_diff.setMaximum(255)
        self.slider_diff.setValue(self.threshold_diff)
        self.slider_diff.valueChanged.connect(self.sliderDiffChanged)
        self.input_diff = QtWidgets.QLineEdit(str(self.threshold_diff))
        self.input_diff.setFixedWidth(50)
        self.input_diff.returnPressed.connect(self.inputDiffChanged)
        threshold_layout.addWidget(QtWidgets.QLabel("T (G - R > T):"), 1, 0)
        threshold_layout.addWidget(self.slider_diff, 1, 1)
        threshold_layout.addWidget(self.input_diff, 1, 2)
        right_layout.addLayout(threshold_layout, 2)

        # Button row (navigation and zoom).
        button_layout = QtWidgets.QHBoxLayout()
        self.prev_button = QtWidgets.QPushButton("Previous (P)")
        self.next_button = QtWidgets.QPushButton("Next (N)")
        self.zoom_in_button = QtWidgets.QPushButton("Zoom In (+)")
        self.zoom_out_button = QtWidgets.QPushButton("Zoom Out (-)")
        self.prev_button.clicked.connect(self.showPreviousImage)
        self.next_button.clicked.connect(self.showNextImage)
        self.zoom_in_button.clicked.connect(self.zoomIn)
        self.zoom_out_button.clicked.connect(self.zoomOut)
        button_layout.addWidget(self.prev_button)
        button_layout.addWidget(self.next_button)
        button_layout.addStretch()
        button_layout.addWidget(self.zoom_in_button)
        button_layout.addWidget(self.zoom_out_button)
        right_layout.addLayout(button_layout, 1)

        main_layout.addLayout(right_layout, 4)
        self.setWindowTitle("BMP Image Viewer with Mask and Zoom")
        self.resize(1600, 900)

    @staticmethod
    def _readImage(image_path):
        """Decode *image_path* into a BGR array, or return None on failure.

        The file is read through NumPy and decoded from memory because
        ``cv2.imread`` is commonly reported to fail on non-ASCII paths on
        Windows.
        """
        try:
            raw = np.fromfile(image_path, dtype=np.uint8)
        except OSError:
            return None
        if raw.size == 0:
            # imdecode rejects an empty buffer; treat it as "cannot load".
            return None
        return cv2.imdecode(raw, cv2.IMREAD_COLOR)

    def loadImage(self, index):
        """Load image number *index*, display it and its masked version.

        Out-of-range indices are ignored. If the file cannot be decoded a
        warning box is shown and the previously displayed image stays.
        """
        if index < 0 or index >= len(self.image_files):
            return
        image_path = os.path.join(self.folder_path, self.image_files[index])
        image = self._readImage(image_path)
        if image is None:
            QtWidgets.QMessageBox.warning(self, "Error", f"无法加载图片: {image_path}")
            return

        # Reset the zoom BEFORE drawing so the new image is shown at the
        # factor that zoom_factor reports afterwards.
        self.zoom_factor = 1.0

        # Downscale to a manageable working size.
        self.original_image = self.resizeImage(image, max_width=800, max_height=800)
        self.displayImage(self.original_label, self.original_image)

        # Build and show the masked version.
        self.updateMask()

        # Sync the sidebar selection. Signals are blocked so that
        # currentRowChanged does not call loadImage again.
        self.sidebar.blockSignals(True)
        self.sidebar.setCurrentRow(index)
        self.sidebar.blockSignals(False)

    def resizeImage(self, image, max_width=800, max_height=800):
        """Shrink *image* to fit in ``max_width`` x ``max_height``.

        The aspect ratio is kept and the image is never enlarged. When no
        shrinking is needed the input array itself is returned (not a copy).
        """
        height, width = image.shape[:2]
        scaling_factor = min(max_width / width, max_height / height, 1)
        if scaling_factor >= 1:
            return image
        # Clamp to one pixel: truncation could otherwise give a 0-sized
        # dimension for very elongated images, which cv2.resize rejects.
        new_size = (
            max(1, int(width * scaling_factor)),
            max(1, int(height * scaling_factor)),
        )
        return cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)

    def displayImage(self, label, image):
        """Show BGR *image* on *label*, scaled by the current zoom factor."""
        # Qt expects RGB byte order; OpenCV delivers BGR.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = rgb_image.shape[:2]
        bytes_per_line = 3 * width
        q_image = QtGui.QImage(
            rgb_image.data, width, height, bytes_per_line,
            QtGui.QImage.Format_RGB888,
        )
        pixmap = QtGui.QPixmap.fromImage(q_image)
        # Never ask for an empty pixmap, even for tiny images at low zoom.
        target_size = (pixmap.size() * self.zoom_factor).expandedTo(QtCore.QSize(1, 1))
        scaled_pixmap = pixmap.scaled(
            target_size, QtCore.Qt.KeepAspectRatio, QtCore.Qt.SmoothTransformation
        )
        label.setPixmap(scaled_pixmap)
        label.adjustSize()

    def generateMask(self, image, threshold_g, threshold_diff):
        """Return the boolean mask ``G > T_g and (G - R) > T`` for BGR *image*.

        Args:
            image: Array indexed as ``[row, column, channel]`` in B, G, R order.
            threshold_g: T_g, lower bound (exclusive) on the green channel.
            threshold_diff: T, lower bound (exclusive) on green minus red.
        """
        # Widen to a signed type so G - R can go negative without wrapping.
        green = image[:, :, 1].astype(np.int16)
        red = image[:, :, 2].astype(np.int16)
        return (green > threshold_g) & ((green - red) > threshold_diff)

    def applyMask(self, image, mask):
        """Return a new image with the *mask* pixels tinted red at 50 %.

        *image* itself is left unmodified.
        """
        overlay = image.copy()
        overlay[mask] = [0, 0, 255]  # red in BGR order
        alpha = 0.5
        # Outside the mask overlay equals image, so blending changes nothing
        # there; inside it mixes red and the original half and half.
        return cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)

    def showNextImage(self):
        """Advance to the next file, if there is one."""
        if self.current_index < len(self.image_files) - 1:
            self.current_index += 1
            self.loadImage(self.current_index)

    def showPreviousImage(self):
        """Go back to the previous file, if there is one."""
        if self.current_index > 0:
            self.current_index -= 1
            self.loadImage(self.current_index)

    def sliderGChanged(self, value):
        """Adopt *value* as T_g, mirror it in the text field and refresh."""
        self.threshold_g = value
        self.input_g.setText(str(self.threshold_g))
        self.updateMask()

    def sliderDiffChanged(self, value):
        """Adopt *value* as T, mirror it in the text field and refresh."""
        self.threshold_diff = value
        self.input_diff.setText(str(self.threshold_diff))
        self.updateMask()

    @staticmethod
    def _parseBoundedInt(text, low, high):
        """Return ``int(text)`` if it parses and lies in [low, high], else None."""
        try:
            value = int(text)
        except ValueError:
            return None
        return value if low <= value <= high else None

    def inputGChanged(self):
        """Apply the T_g typed into the text field, or warn and restore it."""
        value = self._parseBoundedInt(self.input_g.text(), 0, 255)
        if value is None:
            QtWidgets.QMessageBox.warning(self, "Invalid Input", "T_g 必须是0到255之间的整数。")
            self.input_g.setText(str(self.threshold_g))
        elif value == self.slider_g.value():
            # The slider would not emit valueChanged; apply directly.
            self.sliderGChanged(value)
        else:
            # valueChanged -> sliderGChanged does the single mask refresh.
            self.slider_g.setValue(value)

    def inputDiffChanged(self):
        """Apply the T typed into the text field, or warn and restore it."""
        value = self._parseBoundedInt(self.input_diff.text(), -255, 255)
        if value is None:
            QtWidgets.QMessageBox.warning(self, "Invalid Input", "T 必须是-255到255之间的整数。")
            self.input_diff.setText(str(self.threshold_diff))
        elif value == self.slider_diff.value():
            # The slider would not emit valueChanged; apply directly.
            self.sliderDiffChanged(value)
        else:
            # valueChanged -> sliderDiffChanged does the single mask refresh.
            self.slider_diff.setValue(value)

    def updateMask(self):
        """Recompute the masked image from the current thresholds and show it."""
        if self.original_image is None:
            return
        mask = self.generateMask(self.original_image, self.threshold_g, self.threshold_diff)
        self.masked_image = self.applyMask(self.original_image, mask)
        self.displayImage(self.masked_label, self.masked_image)

    def sidebarSelectionChanged(self, index):
        """Load the file whose sidebar row was selected."""
        if 0 <= index < len(self.image_files):
            self.current_index = index
            self.loadImage(self.current_index)

    def keyPressEvent(self, event):
        """Keyboard shortcuts: N / P navigate, + (or =) and - zoom."""
        key = event.key()
        if key == QtCore.Qt.Key_N:
            self.showNextImage()
        elif key == QtCore.Qt.Key_P:
            self.showPreviousImage()
        elif key in (QtCore.Qt.Key_Plus, QtCore.Qt.Key_Equal):
            self.zoomIn()
        elif key == QtCore.Qt.Key_Minus:
            self.zoomOut()
        else:
            super().keyPressEvent(event)

    def _setZoom(self, factor):
        """Store *factor* clamped to [MIN_ZOOM, MAX_ZOOM] and redraw."""
        self.zoom_factor = min(max(factor, self.MIN_ZOOM), self.MAX_ZOOM)
        self.updateDisplayImages()

    def zoomIn(self):
        """Enlarge both images by one zoom step (up to MAX_ZOOM)."""
        self._setZoom(self.zoom_factor * self.ZOOM_STEP)

    def zoomOut(self):
        """Shrink both images by one zoom step (down to MIN_ZOOM)."""
        self._setZoom(self.zoom_factor / self.ZOOM_STEP)

    def updateDisplayImages(self):
        """Redraw whichever images are loaded using the current zoom factor."""
        if self.original_image is not None:
            self.displayImage(self.original_label, self.original_image)
        if self.masked_image is not None:
            self.displayImage(self.masked_label, self.masked_image)
def main():
    """Ask for an image folder, open the viewer on it and run the event loop.

    The process exits immediately if the folder dialog is cancelled;
    otherwise it exits with the event loop's return code.
    """
    application = QtWidgets.QApplication(sys.argv)

    # The folder to read is chosen interactively through a directory dialog.
    chosen_dir = QtWidgets.QFileDialog.getExistingDirectory(None, "选择包含 BMP 图片的文件夹")
    if not chosen_dir:
        sys.exit()

    window = ImageViewer(chosen_dir)
    window.show()

    exit_code = application.exec_()
    sys.exit(exit_code)
# Start the GUI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()