From f2c614dbcf15857b1c47d7c684afe5d999a8769f Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Fri, 17 Mar 2023 13:25:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B1=8F=E8=94=BD=E4=BA=86=E5=8F=B3=E4=BE=A710?= =?UTF-8?q?=E4=B8=AA=E5=96=B7=E9=98=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 02_classification.ipynb | 26 +++++--------------------- main.py | 2 +- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/02_classification.ipynb b/02_classification.ipynb index 52dd0ee..71acc99 100644 --- a/02_classification.ipynb +++ b/02_classification.ipynb @@ -12,24 +12,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\FEIJINTI\\miniconda3\\envs\\cv\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training env\n" - ] - } - ], + "execution_count": 7, + "outputs": [], "source": [ "import numpy as np\n", "import scipy\n", @@ -46,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "outputs": [], "source": [ "train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n", @@ -86,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "outputs": [], "source": [ "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n", @@ -193,7 +177,7 @@ " data = scipy.io.loadmat(dataset_file)\n", " x, y = data['x'], data['y'].ravel()\n", " model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n", - "else:8\n", + "else:\n", " world_boundary = np.array([0, 0, 0, 255, 255, 255])\n", " model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n", " is_save_dataset=True, model_selection='dt')\n", diff --git a/main.py b/main.py index 15437b8..b4cd416 100755 --- a/main.py +++ b/main.py @@ -124,7 +124,7 @@ def main(only_spec=False, only_color=False, if_merge=False, interval_time=None, mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8) masks = [mask_spec, mask_rgb] # 进行多个喷阀的合并 - masks = [utils_customized.shield_valve(mask, left_shield=10) for mask in masks] + masks = [utils_customized.shield_valve(mask, left_shield=10, right_shield=10) for mask in masks] masks = [utils_customized.valve_expend(mask) for mask in masks] mask_nums = sum([np.sum(np.sum(mask)) for mask in masks]) log_time_count += 1