"""
Helpers for hyperspectral data I/O, genetic-algorithm feature selection and a
small framed TCP protocol.

Frames read by ``receive_sock`` are: start byte 0xAA, 4-byte big-endian payload
length, payload, 3-byte trailer 0xFF 0xFF 0xBB.
"""
import logging
import os
import re
import shutil
import socket
import time
import warnings

import cv2
import numpy as np
from genetic_selection import GeneticSelectionCV
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Annotation colour (RGB order) -> class label, shared by read_rgb and read_data.
# white -> 0, yellow -> 1, cyan -> 2, red -> 3, green -> 4, blue -> 5
COLOR_TO_LABEL = {(255, 255, 255): 0, (255, 255, 0): 1, (0, 255, 255): 2,
                  (255, 0, 0): 3, (0, 255, 0): 4, (0, 0, 255): 5}

# Status byte sent by simple_sock for an 'IM' command, keyed by its result.
_IM_RESULT_CODES = {0: b'S', 1: b'Z', 2: b'Q'}
# Status byte sent by simple_sock for the other supported commands.
_SIMPLE_CMD_CODES = {'TR': b'A', 'MD': b'D', 'KM': b'K'}


def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
    """
    Read an ENVI ROI ASCII export (ENVI ROI Tool -> File -> Output ROIs to ASCII...).

    :param file_name: path of the ENVI ASCII file
    :param save_xy: if True, keep every column except the first (ID) so the X and Y
        positions come first; otherwise drop the first three columns
    :param hdr_file_name: optional .hdr file; when it holds a ``wavelength = {...}``
        list whose length equals ``bands = N``, the wavelengths are added to the
        result under the key "BANDS" (otherwise a warning is issued)
    :return: dict {roi_name: 2-D float array (n_points, n_kept_columns), ...}
    :raises ValueError: if the header is missing/inconsistent or a ROI block is malformed
    """
    number_line_start_with = "; Number of ROIs: "
    roi_name_start_with, roi_npts_start_with = "; ROI name: ", "; ROI npts:"
    class_num, class_names, class_nums, vectors = 0, [], [], []
    with open(file_name, 'r') as f:
        for line_text in f:
            if line_text.startswith(number_line_start_with):
                class_num = int(line_text[len(number_line_start_with):])
            elif line_text.startswith(roi_name_start_with):
                class_names.append(line_text[len(roi_name_start_with):].rstrip('\n'))
            elif line_text.startswith(roi_npts_start_with):
                # Slice with the length of the "npts" prefix (not the "name" prefix).
                class_nums.append(int(line_text[len(roi_npts_start_with):]))
            elif line_text.startswith(';') and line_text[1:].lstrip().startswith('ID'):
                # Column header such as ";   ID   X   Y   B1 ...": any spacing accepted.
                col_list = line_text[1:].split()
                if class_num != len(class_names) or len(class_names) != len(class_nums):
                    raise ValueError(
                        f"Inconsistent ROI header in {file_name}: expected {class_num} ROIs, "
                        f"found {len(class_names)} names and {len(class_nums)} point counts")
                break
        else:
            raise ValueError(f"No column header line ('; ID ...') found in {file_name}")

        n_cols = len(col_list)
        for vector_rows in class_nums:
            rows = [f.readline() for _ in range(vector_rows)]
            values = [float(token) for row in rows for token in row.split()]
            vector = np.asarray(values, dtype=float).reshape(-1, n_cols)
            if vector.shape[0] != vector_rows:
                raise ValueError(
                    f"Expected {vector_rows} data rows in {file_name}, got {vector.shape[0]}")
            vector = vector[:, 1:] if save_xy else vector[:, 3:]
            vectors.append(vector)
            f.readline()  # skip the blank line that separates ROI blocks

    if hdr_file_name is not None:
        with open(hdr_file_name, 'r') as f:
            hdr_info = f.read()
        wavelength_match = re.search(r"wavelength\s*=\s*\{([^{}]*)\}", hdr_info, flags=re.IGNORECASE)
        bands_num_match = re.search(r"bands\s*=\s*(\d+)", hdr_info, flags=re.IGNORECASE)
        if wavelength_match is None or bands_num_match is None:
            warnings.warn("The given hdr file is invalid, can't find bands = ? or wavelength = {?}.")
        else:
            # Comma-separated numbers; whitespace/newlines around them are ignored.
            bands = np.asarray([float(token) for token in wavelength_match.group(1).split(',')
                                if token.strip()], dtype=float)
            if int(bands_num_match.group(1)) == bands.shape[0]:
                vectors.append(bands)
                class_names.append("BANDS")
            else:
                warnings.warn("The given hdr file is invalid, bands num is not equal to wavelength.")
    return dict(zip(class_names, vectors))


def ga_feature_extraction(data_x, data_y):
    """
    Select features with a genetic algorithm and report a decision tree trained on them.

    Splits the data 70/30, runs GeneticSelectionCV on the training part, then prints
    the indices of the selected features, a classification report and a confusion
    matrix for the held-out part.

    :param data_x: feature matrix (indexed as ``data_x[:, mask]``)
    :param data_y: class labels
    :return: the fitted GeneticSelectionCV selector (``support_`` marks chosen features)
    """
    x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.3)
    clf = DecisionTreeClassifier(random_state=3)
    selector = GeneticSelectionCV(clf,
                                  cv=30,
                                  verbose=1,
                                  scoring="accuracy",
                                  max_features=10,
                                  n_population=500,
                                  crossover_proba=0.6,
                                  mutation_proba=0.3,
                                  n_generations=300,
                                  crossover_independent_proba=0.6,
                                  mutation_independent_proba=0.1,
                                  tournament_size=10,
                                  n_gen_no_change=10,
                                  caching=True,
                                  n_jobs=-1)
    selector = selector.fit(x_train, y_train)
    x_train_ga, x_test_ga = x_train[:, selector.support_], x_test[:, selector.support_]
    clf = clf.fit(x_train_ga, y_train)
    print(np.where(selector.support_))
    y_pred = clf.predict(x_test_ga)
    print(classification_report(y_test, y_pred))
    print(confusion_matrix(y_test, y_pred))
    return selector


def read_raw(file_name, shape=None, setect_bands=None, cut_shape=None):
    """
    Read a raw float32 cube.

    The file is reshaped to ``shape`` (default (692, 272, 384)) and its last two
    axes are swapped, so the result has shape (shape[0], shape[2], shape[1]).

    :param file_name: path of the raw file
    :param shape: on-disk array shape
    :param setect_bands: optional index/indices applied to the last axis (band selection)
    :param cut_shape: optional (rows, cols) crop applied to the first two axes
    :return: the (read-only) numpy array
    """
    if shape is None:
        shape = (692, 272, 384)
    with open(file_name, 'rb') as f:
        data = np.frombuffer(f.read(), dtype=np.float32).reshape(shape).transpose(0, 2, 1)
    if setect_bands is not None:
        data = data[:, :, setect_bands]
    if cut_shape is not None:
        data = data[:cut_shape[0], :cut_shape[1], :]
    return data


def save_raw(file_name, data):
    """
    Save a cube as raw float32, inverting the axis swap done by read_raw.

    :param file_name: output path (overwritten)
    :param data: 3-D array
    """
    payload = data.transpose(0, 2, 1).astype(np.float32).tobytes()
    with open(file_name, 'wb') as f:
        f.write(payload)


def _load_rgb(file_name):
    """Read an image with OpenCV and return it in RGB channel order."""
    image = cv2.imread(file_name)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Unable to read image file: {file_name}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def read_rgb(file_name):
    """
    Read a colour-coded label image and convert it to a 2-D array of class labels.

    :param file_name: image path
    :return: int array (height, width) using the COLOR_TO_LABEL mapping
    :raises ValueError: if any pixel has a colour that is not in COLOR_TO_LABEL
    """
    rgb = _load_rgb(file_name)
    height, width = rgb.shape[:2]
    pixels = rgb.reshape(-1, 3)
    labels = np.full(pixels.shape[0], -1, dtype=int)
    for color, label in COLOR_TO_LABEL.items():
        labels[np.all(pixels == color, axis=1)] = label
    unmapped = labels < 0
    if unmapped.any():
        bad_colors = np.unique(pixels[unmapped], axis=0)
        raise ValueError(f"{int(unmapped.sum())} pixel(s) in {file_name} have colours without a "
                         f"class mapping, e.g. {bad_colors[:5].tolist()}")
    return labels.reshape(height, width)


def read_data(raw_path, rgb_path, shape=None, setect_bands=None, blk_size=4, cut_shape=None, dp=False):
    """
    Build block samples from a raw cube and its colour-coded label image.

    The cube is tiled into non-overlapping blk_size x blk_size blocks; each block
    is labelled by the colour of its centre pixel in the label image, and blocks
    whose centre colour is not in COLOR_TO_LABEL are skipped.

    :param raw_path: raw cube path (see read_raw)
    :param rgb_path: label image path
    :param shape: on-disk shape of the raw cube (see read_raw)
    :param setect_bands: optional band selection (see read_raw)
    :param blk_size: block edge length in pixels
    :param cut_shape: optional (rows, cols) crop applied to both inputs
    :param dp: unused; kept for backward compatibility
    :return: (data_x of shape (n, blk_size, blk_size, bands), data_y uint8 labels)
    """
    raw = read_raw(raw_path, shape=shape, setect_bands=setect_bands, cut_shape=cut_shape)
    rgb = _load_rgb(rgb_path)
    if cut_shape is not None:
        rgb = rgb[:cut_shape[0], :cut_shape[1], :]

    n_rows = min(raw.shape[0], rgb.shape[0])
    n_cols = min(raw.shape[1], rgb.shape[1])
    half = blk_size // 2
    data_x, data_y = [], []
    # Only whole blocks are used: a partial block at the bottom/right edge would
    # make the stacked array ragged or put its centre index out of range.
    for i in range(0, n_rows - blk_size + 1, blk_size):
        for j in range(0, n_cols - blk_size + 1, blk_size):
            label = COLOR_TO_LABEL.get(tuple(rgb[i + half, j + half]))
            if label is not None:
                data_x.append(raw[i:i + blk_size, j:j + blk_size, :])
                data_y.append(label)
    data_x = np.array(data_x)
    data_y = np.array(data_y).astype(np.uint8)
    return data_x, data_y


def try_connect(connect_ip: str, port_number: int, is_repeat: bool = False,
                max_reconnect_times: int = 50) -> (bool, socket.socket):
    """
    Try to open a TCP connection, retrying every 5 seconds.

    :param connect_ip: server address
    :param port_number: server port
    :param is_repeat: whether this is a reconnection (only changes log messages)
    :param max_reconnect_times: maximum number of attempts
    :return: (True, connected PreSocket) on success, (False, None) otherwise
    """
    for attempt in range(1, max_reconnect_times + 1):
        logging.warning(f'尝试{"重新" if is_repeat else ""}发起第{attempt}次连接...')
        connected_sock = None
        try:
            connected_sock = PreSocket(socket.AF_INET, socket.SOCK_STREAM)
            connected_sock.connect((connect_ip, port_number))
        except Exception as e:
            # Close the failed socket so repeated attempts do not leak descriptors.
            if connected_sock is not None:
                connected_sock.close()
            logging.error(f'第{attempt}次连接失败... 5秒后重新连接...\n {e}')
            time.sleep(5)
            continue
        logging.warning(f'{"重新" if is_repeat else ""}连接成功')
        return True, connected_sock
    return False, None


class PreSocket(socket.socket):
    """
    Socket with a 5 s timeout and a "pre-pack" buffer.

    Bytes stored in the buffer (via set_prepack) are handed out by receive()
    before anything new is read from the network.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_pack = b''  # bytes already received but not yet consumed
        self.settimeout(5)

    def receive(self, *args, **kwargs):
        """
        Like recv(), but serve buffered bytes first.

        When the buffer is non-empty, at most ``args[0]`` bytes are taken from it
        and no network read happens.
        """
        if self.pre_pack == b'':
            return self.recv(*args, **kwargs)
        data_len = args[0]
        required, self.pre_pack = self.pre_pack[:data_len], self.pre_pack[data_len:]
        return required

    def set_prepack(self, pre_pack: bytes):
        """Append ``pre_pack`` to the end of the buffer."""
        self.pre_pack = self.pre_pack + pre_pack


class DualSock(PreSocket):
    """
    Pair of connections: one for receiving and one for sending.

    ``status`` is True when both connections succeeded.
    NOTE(review): super().__init__() creates a socket of its own that is never
    connected; all I/O goes through received_sock / send_sock.
    """

    def __init__(self, connect_ip='127.0.0.1', recv_port: int = 21122, send_port: int = 21123):
        super().__init__()
        received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port)
        send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port)
        self.status = received_status and send_status

    def send(self, *args, **kwargs) -> int:
        """Send through the sending connection."""
        return self.send_sock.send(*args, **kwargs)

    def receive(self, *args, **kwargs) -> bytes:
        """Receive through the receiving connection (honouring its pre-pack buffer)."""
        return self.received_sock.receive(*args, **kwargs)

    def set_prepack(self, pre_pack: bytes):
        """Append bytes to the receiving connection's pre-pack buffer."""
        self.received_sock.set_prepack(pre_pack)

    def reconnect(self, connect_ip='127.0.0.1', recv_port: int = 21122, send_port: int = 21123):
        """
        Close both connections and open new ones.

        :return: True when both new connections succeeded (also stored in ``status``)
        """
        for old_sock in (self.received_sock, self.send_sock):
            if old_sock is not None:
                try:
                    old_sock.close()
                except OSError:
                    pass  # best effort: the old connection is being discarded anyway
        received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port)
        send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port)
        self.status = received_status and send_status
        return self.status


def _receive_exact(recv_sock, size, time_start_recv, time_out_single, error_msg):
    """
    Read exactly ``size`` bytes from ``recv_sock``.

    Returns None (after logging) on timeout, error or peer close. It never
    requests more than the missing byte count, so no bytes of the following
    field or message are consumed.
    """
    buf = b''
    while len(buf) < size:
        if (time.time() - time_start_recv) > time_out_single:
            logging.error('单次指令接收超时')
            return None
        try:
            chunk = recv_sock.receive(size - len(buf))
        except Exception as e:
            logging.error(f'{error_msg}, 错误代码: \n{e}')
            return None
        if not chunk:
            # recv() returns b'' only when the peer closed the connection.
            logging.error('连接已断开')
            return None
        buf += chunk
    return buf


def receive_sock(recv_sock: PreSocket, pre_pack: bytes = b'', time_out: float = -1.0,
                 time_out_single=5e20) -> (bytes, bytes):
    """
    Read one framed message from ``recv_sock``.

    Frame layout: 0xAA, 4-byte big-endian payload length, payload, 0xFF 0xFF 0xBB.
    An empty ``data`` result means the connection has a problem and should be
    re-established.

    :param recv_sock: socket to read from (PreSocket or DualSock)
    :param pre_pack: bytes left over from a previous read, consumed first
    :param time_out: give up if no start byte arrives within this many seconds;
        values <= 0 wait forever
    :param time_out_single: limit, in seconds, for receiving the rest of a message
        once its start byte has arrived
    :return: (data, next_pack); next_pack is always b'' because any leftover bytes
        stay in the socket's pre-pack buffer
    """
    recv_sock.set_prepack(pre_pack)

    # 1. Wait for the start byte, discarding anything else.
    time_start_recv = time.time()
    while True:
        if time_out > 0 and (time.time() - time_start_recv) > time_out:
            logging.error('指令接收超时')
            return b'', b''
        try:
            temp = recv_sock.receive(1)
        except ConnectionError as e:
            logging.error(f'连接出错, 错误代码:\n{e}')
            return b'', b''
        except (TimeoutError, socket.timeout):
            # The socket's own timeout expired without data: keep waiting.
            logging.info('运行中,等待指令..')
            continue
        except Exception as e:
            logging.error(f'遇见未知错误,错误代码: \n{e}')
            return b'', b''
        if temp == b'':
            # Peer closed the connection; without this the loop would spin forever.
            logging.error('连接已断开')
            return b'', b''
        if temp == b'\xaa':
            break

    # The per-message timeout starts once the start byte has arrived.
    time_start_recv = time.time()

    # 2. Payload length.
    length_bytes = _receive_exact(recv_sock, 4, time_start_recv, time_out_single, '接收报文的长度不正确')
    if length_bytes is None:
        return b'', b''
    data_len = int.from_bytes(length_bytes, byteorder='big')

    # 3. Payload.
    data = _receive_exact(recv_sock, data_len, time_start_recv, time_out_single, '接收报文内容失败')
    if data is None:
        return b'', b''

    # 4. Trailer.
    trailer = _receive_exact(recv_sock, 3, time_start_recv, time_out_single, '接收报文校验失败')
    if trailer is None:
        return b'', b''
    if trailer == b'\xff\xff\xbb':
        return data, b''
    logging.error("接收了一个完美的只错了校验位的报文")
    return b'', b''


def parse_protocol(data: bytes) -> (str, any):
    """
    Split a received payload into command and content.

    The first 4 bytes are the ASCII command (stripped, upper-cased). For 'IM' the
    rest is 2-byte big-endian rows, 2-byte big-endian cols and rows*cols*3 bytes
    of uint8 pixels, returned as an array of shape (rows, cols, 3). Any other
    command returns the remaining bytes unchanged.

    :param data: received payload
    :return: (command, content), or ('', None) on failure
    """
    # Explicit checks instead of assert, which is stripped under `python -O`.
    if len(data) <= 4:
        logging.error('指令转换失败,长度不足5')
        return '', None
    cmd, data = data[:4], data[4:]
    try:
        cmd = cmd.decode('ascii').strip().upper()
    except UnicodeDecodeError as e:
        logging.error(f'指令类型解码失败, 错误代码{e}')
        return '', None
    if cmd == 'IM':
        n_rows = int.from_bytes(data[:2], byteorder='big')
        n_cols = int.from_bytes(data[2:4], byteorder='big')
        img = data[4:]
        if n_rows * n_cols * 3 != len(img):
            logging.error('图像指令IM转换失败,数据长度错误')
            return '', None
        # Explicit channel count (not -1) so zero-size images reshape too.
        img = np.frombuffer(img, dtype=np.uint8).reshape((n_rows, n_cols, 3))
        return cmd, img
    return cmd, data


def _send_all(send_sock, msg: bytes) -> None:
    """
    Send every byte of ``msg``.

    Loops because send() may transmit only part of the buffer. It calls .send()
    rather than .sendall() so that DualSock's delegating send() is used.
    """
    view = memoryview(msg)
    while len(view) > 0:
        sent = send_sock.send(view)
        if not sent:
            raise ConnectionError('socket connection broken')
        view = view[sent:]


def ack_sock(send_sock: socket.socket, cmd_type: str) -> bool:
    """
    Send an acknowledgement for ``cmd_type``.

    NOTE(review): the length field is fixed at 5 and the trailer is 4 bytes
    (ff ff ff bb), while receive_sock expects a 3-byte trailer. For a 2-character
    cmd_type the extra 0xFF makes up the 5th payload byte. Confirm the peer
    expects exactly these bytes before changing them.

    :param send_sock: socket to send on
    :param cmd_type: command being acknowledged
    :return: True if the message was sent
    """
    msg = b'\xaa\x00\x00\x00\x05' + (' A' + cmd_type).upper().encode('ascii') + b'\xff\xff\xff\xbb'
    try:
        _send_all(send_sock, msg)
    except Exception as e:
        logging.error(f'发送应答失败,错误类型:{e}')
        return False
    return True


def done_sock(send_sock: socket.socket, cmd_type: str, result='') -> bool:
    """
    Send a task result as-is, with no framing.

    :param send_sock: socket to send on
    :param cmd_type: command type (kept for interface compatibility; not used)
    :param result: str (encoded as UTF-8) or bytes to send
    :return: True if the result was sent
    """
    if isinstance(result, str):
        result = result.encode()
    try:
        _send_all(send_sock, result)
    except Exception as e:
        logging.error(f'发送完成指令失败,错误类型:{e}')
        return False
    return True


def simple_sock(send_sock: socket.socket, cmd_type: str, result) -> bool:
    """
    Send a short completion message.

    Message layout: status byte, 4-byte big-endian length of the payload, then a
    payload of b',' + result. The status byte depends on the command:
    'IM' maps result 0/1/2 to S/Z/Q; 'TR', 'MD' and 'KM' map to A, D and K.

    :param send_sock: socket to send on
    :param cmd_type: command type (whitespace and case are ignored)
    :param result: bytes are sent as-is; anything else as its ASCII str()
    :return: True if the message was sent; False for an unknown command/result
        or a send error
    """
    cmd_type = cmd_type.strip().upper()
    if cmd_type == 'IM':
        status = _IM_RESULT_CODES.get(result)
    else:
        status = _SIMPLE_CMD_CODES.get(cmd_type)
    if status is None:
        logging.error(f'未知的指令类型或结果: {cmd_type}, {result!r}')
        return False
    payload = result if isinstance(result, bytes) else str(result).encode('ascii')
    payload = b',' + payload
    msg = status + len(payload).to_bytes(4, 'big') + payload
    try:
        _send_all(send_sock, msg)
    except Exception as e:
        logging.error(f'发送完成指令失败,错误类型:{e}')
        return False
    return True


def mkdir_if_not_exist(dir_name, is_delete=False):
    """
    Create a directory (and parents) if it does not exist.

    :param dir_name: directory path
    :param is_delete: delete an existing directory first, then recreate it
    :return: True on success, False if an exception was raised (it is printed)
    """
    try:
        if is_delete and os.path.exists(dir_name):
            shutil.rmtree(dir_name)
            print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)
            print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
        return True
    except Exception as e:
        print('[Exception] %s' % e)
        return False


def create_file(file_name):
    """
    Create an empty file if it does not already exist.

    :param file_name: file path
    :return: True if the file was created, False if it already existed
    """
    if os.path.exists(file_name):
        print("文件存在:%s" % file_name)
        return False
    print("文件不存在,创建文件:%s" % file_name)
    open(file_name, 'a').close()
    return True


class Logger(object):
    """Print messages to stdout, or append them with a timestamp line to a file."""

    def __init__(self, is_to_file=False, path=None):
        self.is_to_file = is_to_file
        self.path = "Astragalins.log" if path is None else path
        create_file(self.path)  # created (or reported as existing) even when printing to stdout

    def log(self, content):
        """Write ``content``, preceded by a timestamp line when logging to file."""
        if not self.is_to_file:
            print(content)
            return
        with open(self.path, "a") as f:
            print(time.strftime("[%Y-%m-%d_%H-%M-%S]:"), file=f)
            print(content, file=f)