diff --git a/main.py b/main.py index e368ea7..ee51cf4 100755 --- a/main.py +++ b/main.py @@ -72,20 +72,31 @@ def main(): print(f'total time is:{t3 - t1}\n') -def read_c_captures(buffer_path): - buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] +def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None): + if os.path.isdir(buffer_path): + buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] + else: + buffer_names = [buffer_path, ] for buffer_name in buffer_names: with open(os.path.join(buffer_path, buffer_name), 'rb') as f: data = f.read() - img = np.frombuffer(data, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ + img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \ .transpose(0, 2, 1) - mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') - with open(os.path.join(buffer_path, mask_name), 'rb') as f: - data = f.read() - mask = np.frombuffer(data, dtype=np.uint8).reshape((256, 1024, -1)) + if selected_bands is not None: + img = img[..., selected_bands] + if img.shape[0] == 1: + img = img[0, ...] + if not no_mask: + mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') + with open(os.path.join(buffer_path, mask_name), 'rb') as f: + data = f.read() + mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1)) + else: + mask_name = "no mask" + mask = np.zeros_like(img) # mask = cv2.resize(mask, (1024, 256)) fig, axs = plt.subplots(2, 1) - axs[0].imshow(img[..., [21, 3, 0]]) + axs[0].matshow(img) axs[0].set_title(buffer_name) axs[1].imshow(mask) axs[1].set_title(mask_name) @@ -100,4 +111,5 @@ if __name__ == '__main__': # 主函数 main() - # read_c_captures('/home/lzy/2022.7.20/tobacco_v1_0/') + # read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024, + # selected_bands=[380, 300, 200])