From edd3b0f9c82e859112bd270d9745579aa918d81b Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Tue, 15 Nov 2022 12:46:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B1=8F=E8=94=BD=E4=BA=86=E5=B7=A6=E4=BE=A78?= =?UTF-8?q?=E4=B8=AA=E5=96=B7=E9=98=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 02_classification.ipynb | 27 ++++++++++----------------- 05_model_training.ipynb | 18 ++++++++++++------ config.py | 12 +++++------- main.py | 1 + utils/__init__.py | 12 ++++++++++-- 5 files changed, 38 insertions(+), 32 deletions(-) diff --git a/02_classification.ipynb b/02_classification.ipynb index d73e588..52dd0ee 100644 --- a/02_classification.ipynb +++ b/02_classification.ipynb @@ -18,7 +18,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "C:\\Users\\FEIJINTI\\miniconda3\\envs\\cv\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, @@ -54,8 +54,8 @@ "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n", "\n", "# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n", - "color_dict = {(0, 0, 255): \"yangeng\"}\n", - "# color_dict = {(255, 0, 0): 'beijing'}\n", + "# color_dict = {(0, 0, 255): \"yangeng\"}\n", + "color_dict = {(255, 0, 0): 'beijing'}\n", "# color_dict = {(0, 255, 0): \"zibian\"}\n", "label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n", "show_samples = False # 是否展示样本\n", @@ -169,7 +169,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "23961it [00:46, 519.87it/s] " + "152147it [42:48, 59.24it/s] \n" ] }, { @@ -178,19 +178,12 @@ "text": [ " precision recall f1-score support\n", "\n", - " 0 1.00 1.00 1.00 7188\n", - " 1 1.00 1.00 1.00 5991\n", + " 0 1.00 1.00 1.00 45644\n", + " 1 1.00 1.00 1.00 38037\n", "\n", - " accuracy 1.00 13179\n", - " macro avg 1.00 1.00 1.00 13179\n", - "weighted avg 1.00 1.00 1.00 13179\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + " accuracy 1.00 83681\n", + " macro avg 1.00 1.00 1.00 83681\n", + "weighted avg 1.00 1.00 1.00 83681\n", "\n" ] } @@ -200,7 +193,7 @@ " data = scipy.io.loadmat(dataset_file)\n", " x, y = data['x'], data['y'].ravel()\n", " model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n", - "else:\n", + "else:8\n", " world_boundary = np.array([0, 0, 0, 255, 255, 255])\n", " model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n", " is_save_dataset=True, model_selection='dt')\n", diff --git a/05_model_training.ipynb b/05_model_training.ipynb index d610ed1..3583fe6 100644 --- a/05_model_training.ipynb +++ b/05_model_training.ipynb @@ -15,14 +15,20 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "ename": "NotImplementedError", + "evalue": "开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mNotImplementedError\u001B[0m Traceback (most recent call last)", + "Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m 4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n", + "File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m\u001B[1;34m\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m 24\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;124;03m 自然排序\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03m :param l: 待排序\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m :return:\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n", + "File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m\u001B[1;34m\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m 8\u001B[0m nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m 9\u001B[0m nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n", + "File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m 28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n", + "\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行" ] } ], diff --git a/config.py b/config.py index eb11103..ce5c46a 100644 --- a/config.py +++ b/config.py @@ -20,23 +20,21 @@ class Config: # 光谱模型参数 blk_size = 4 # 必须是2的倍数,不然会出错 - raise NotImplementedError("开发时和机器上部署的路径不同,请注意选择pixel_model_path、blk_model_path后删除本行") - # pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径 + pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径 # pixel_model_path = r"/home/dt/tobacco-color/weights/pixel_2022-08-02_15-22.model" # 机器上部署的路径 - # blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径 + blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径 # blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径 spec_size_threshold = 3 # rgb模型参数 - raise NotImplementedError("开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行") - # rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径 + rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径 # rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径 - # rgb_background_model_path = r"weights/background_dt_2022-08-22_22-15.model" # 开发时的路径 + rgb_background_model_path = r"weights/background_dt_2022-08-22_22-15.model" # 开发时的路径 # rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径 threshold_low, threshold_high = 10, 230 threshold_s = 190 # 饱和度的最高允许值 rgb_size_threshold = 4 # rgb的尺寸限制 - # ai_path = 'weights/best0827.pt' # 开发时的路径 + ai_path = 'weights/best0827.pt' # 开发时的路径 # ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径 ai_conf_threshold = 0.6 diff --git a/main.py b/main.py index e6eee8e..36fd479 100755 --- a/main.py +++ b/main.py @@ -124,6 +124,7 @@ def main(only_spec=False, only_color=False, if_merge=False, interval_time=None, mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8) masks = [mask_spec, mask_rgb] # 进行多个喷阀的合并 + masks = [utils_customized.shield_valve(mask, left_shield=8) for mask in masks] masks = [utils_customized.valve_expend(mask) for mask in masks] mask_nums = sum([np.sum(np.sum(mask)) for mask in masks]) log_time_count += 1 diff --git a/utils/__init__.py b/utils/__init__.py index ab9e5a2..be7ebc3 100755 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -178,6 +178,14 @@ def valve_expend(img: np.ndarray) -> np.ndarray: return img +def shield_valve(mask: np.ndarray, left_shield: int = -1, right_shield: int = -1) -> np.asarray: + if (left_shield < mask.shape[1]) & (left_shield > 0): + mask[:, :left_shield] = 0 + if (right_shield < mask.shape[1]) & (right_shield > 0): + mask[:, -right_shield:] = 0 + return mask + + def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray: """ 用于限制阀门同时开启个数的函数 @@ -282,8 +290,8 @@ def valve_log(log_path: pathlib.Path, valve_num: [int, str]): 将喷阀的开启次数记录到文件log_path当中。 """ valve_str = "截至 " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + f' 喷阀使用次数: {valve_num}.' - with open(log_path, "w") as f: - f.write(str(valve_str)) + with open(log_path, "a") as f: + f.write(str(valve_str) + "\n") if __name__ == '__main__':