修改了float32类型的图片乱识别的bug

This commit is contained in:
FEIJINTI 2022-10-12 17:04:59 +08:00
parent e8af097ec7
commit f1cf707f6e
3 changed files with 11 additions and 9 deletions

View File

@ -13,6 +13,7 @@ from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from scipy.stats import binom from scipy.stats import binom
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import time import time
@ -22,7 +23,7 @@ sys.path.append(os.getcwd())
from root_dir import ROOT_DIR from root_dir import ROOT_DIR
import utils import utils
FEATURE_INDEX = [0, 1, 2, 3, 4, 5] FEATURE_INDEX = [1, 2]
class WoodClass(object): class WoodClass(object):

View File

@ -48,7 +48,7 @@ def main(is_debug=False):
console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
logging.basicConfig(format='%(asctime)s %(filename)s[line:%(lineno)d] - %(levelname)s - %(message)s', logging.basicConfig(format='%(asctime)s %(filename)s[line:%(lineno)d] - %(levelname)s - %(message)s',
handlers=[file_handler, console_handler], level=logging.DEBUG) handlers=[file_handler, console_handler], level=logging.DEBUG)
dual_sock = DualSock(connect_ip='192.168.2.221') dual_sock = DualSock(connect_ip='127.0.0.1')
while not dual_sock.status: while not dual_sock.status:
dual_sock.reconnect() dual_sock.reconnect()
@ -56,6 +56,7 @@ def main(is_debug=False):
# model_path = os.path.join(ROOT_DIR, r"models\model_2022-09-28_13-15.p") # model_path = os.path.join(ROOT_DIR, r"models\model_2022-09-28_13-15.p")
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False) detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
detector.load(path=model_path) detector.load(path=model_path)
_ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8))
while True: while True:
pack, next_pack = receive_sock(dual_sock) pack, next_pack = receive_sock(dual_sock)
if pack == b"": if pack == b"":

View File

@ -201,7 +201,7 @@ bytes, bytes):
try: try:
temp += recv_sock.receive(data_len) temp += recv_sock.receive(data_len)
except Exception as e: except Exception as e:
logging.error(f'接收报文内容失败, 错误代码: \n{e}\n报文内容\n{temp}') logging.error(f'接收报文内容失败, 错误代码: \n{e}')
return b'', b'' return b'', b''
data, next_pack = temp[:data_len], temp[data_len:] data, next_pack = temp[:data_len], temp[data_len:]
recv_sock.set_prepack(next_pack) recv_sock.set_prepack(next_pack)
@ -216,7 +216,7 @@ bytes, bytes):
try: try:
temp += recv_sock.receive(1) temp += recv_sock.receive(1)
except Exception as e: except Exception as e:
logging.error(f'接收报文校验失败, 错误代码: \n{e}, 报文如下: \n{temp}') logging.error(f'接收报文校验失败, 错误代码: \n{e}')
return b'', b'' return b'', b''
if temp == b'\xff\xff\xbb': if temp == b'\xff\xff\xbb':
return data, next_pack return data, next_pack
@ -244,15 +244,15 @@ def parse_protocol(data: bytes) -> (str, any):
try: try:
n_rows, n_cols = [int.from_bytes(x, byteorder='big') for x in [n_rows, n_cols]] n_rows, n_cols = [int.from_bytes(x, byteorder='big') for x in [n_rows, n_cols]]
except Exception as e: except Exception as e:
logging.error(f'长宽转换失败, 错误代码{e}, 报文内容: n_rows:{n_rows}, n_cols: {n_cols}') logging.error(f'长宽转换失败, 错误代码{e}, 报文大小: n_rows:{n_rows}, n_cols: {n_cols}')
return '', None return '', None
try: try:
assert n_rows * n_cols * 12 == len(img) assert n_rows * n_cols * 3 == len(img)
# 因为是float32类型 所以长度要乘12 如果是uint8则乘3 # 因为是float32类型 所以长度要乘12 如果是uint8则乘3
except AssertionError: except AssertionError:
logging.error('图像指令IM转换失败数据长度错误') logging.error('图像指令IM转换失败数据长度错误')
return '', None return '', None
img = np.frombuffer(img, dtype=np.float32).reshape((n_rows, n_cols, -1)) img = np.frombuffer(img, dtype=np.uint8).reshape((n_rows, n_cols, -1))
return cmd, img return cmd, img
elif cmd == 'TR': elif cmd == 'TR':
data = data.decode('ascii') data = data.decode('ascii')
@ -313,11 +313,11 @@ def simple_sock(send_sock: socket.socket, cmd_type: str, result: int = '') -> bo
cmd_type = cmd_type.strip().upper() cmd_type = cmd_type.strip().upper()
if cmd_type == 'IM': if cmd_type == 'IM':
if result == 0: if result == 0:
msg = b'Q' msg = b'S'
elif result == 1: elif result == 1:
msg = b'Z' msg = b'Z'
elif result == 2: elif result == 2:
msg = b'S' msg = b'Q'
elif cmd_type == 'TR': elif cmd_type == 'TR':
msg = b'A' msg = b'A'
elif cmd_type == 'MD': elif cmd_type == 'MD':