分离预测

This commit is contained in:
FEIJINTI 2022-07-21 17:13:44 +08:00
parent b1be6eeb49
commit f7c88d3c2a
3 changed files with 23 additions and 32 deletions

View File

@ -37,7 +37,7 @@
"data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹放置相同文件名的图片和label\n", "data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹放置相同文件名的图片和label\n",
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n", "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
"\n", "\n",
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n", "color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
"# color_dict = {(0, 0, 255): \"yangeng\"}\n", "# color_dict = {(0, 0, 255): \"yangeng\"}\n",
"# color_dict = {(255, 0, 0): 'beijing'}\n", "# color_dict = {(255, 0, 0): 'beijing'}\n",
"color_dict = {(0, 255, 0): \"zibian\"}\n", "color_dict = {(0, 255, 0): \"zibian\"}\n",

View File

@ -1,4 +1,3 @@
# %%
import numpy as np import numpy as np
import scipy import scipy
from imblearn.under_sampling import RandomUnderSampler from imblearn.under_sampling import RandomUnderSampler
@ -6,12 +5,12 @@ from models import AnonymousColorDetector
from utils import read_labeled_img from utils import read_labeled_img
# %% # %%
train_from_existed = True # 是否从现有数据训练如果是的话那就从dataset_file训练否则就用data_dir里头的数据 train_from_existed = False # 是否从现有数据训练如果是的话那就从dataset_file训练否则就用data_dir里头的数据
data_dir = "data/dataset" # 数据集,文件夹下必须包含`img`和`label`两个文件夹放置相同文件名的图片和label data_dir = "data/dataset" # 数据集,文件夹下必须包含`img`和`label`两个文件夹放置相同文件名的图片和label
dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat" dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat"
# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'} # 颜色对应的类别 color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing', (0, 255, 0): "zibian"} # 颜色对应的类别
color_dict = {(0, 0, 255): "yangeng"} # color_dict = {(0, 0, 255): "yangeng"}
# color_dict = {(255, 0, 0): 'beijing'} # color_dict = {(255, 0, 0): 'beijing'}
# color_dict = {(0, 255, 0): "zibian"} # color_dict = {(0, 255, 0): "zibian"}
label_index = {"yangeng": 1, "beijing": 0, "zibian": 2} # 类别对应的序号 label_index = {"yangeng": 1, "beijing": 0, "zibian": 2} # 类别对应的序号
@ -32,11 +31,13 @@ if show_samples:
# %% md # %% md
## 数据平衡化 ## 数据平衡化
# %% # %%
rus = RandomUnderSampler(random_state=0) if len(dataset) > 1:
x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \ rus = RandomUnderSampler(random_state=0)
np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist() x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
x_resampled, y_resampled = rus.fit_resample(x_list, y_list) np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()],
dataset = {"inside": np.array(x_resampled)} axis=0).tolist()
x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
dataset = {"inside": np.array(x_resampled)}
# %% md # %% md
## 模型训练 ## 模型训练
# %% # %%
@ -53,14 +54,4 @@ else:
world_boundary = np.array([0, 0, 0, 255, 255, 255]) world_boundary = np.array([0, 0, 0, 255, 255, 255])
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7, model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
is_save_dataset=True, model_selection='dt') is_save_dataset=True, model_selection='dt')
model.save() model.save()
# %%
if train_from_existed:
data = scipy.io.loadmat(dataset_file)
x, y = data['x'], data['y'].ravel()
model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
else:
world_boundary = np.array([0, 0, 0, 255, 255, 255])
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
is_save_dataset=True, model_selection='dt')
model.save()

View File

@ -14,7 +14,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 44,
"outputs": [], "outputs": [],
"source": [ "source": [
"import datetime\n", "import datetime\n",
@ -39,7 +39,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 45,
"outputs": [], "outputs": [],
"source": [ "source": [
"img_path = r\"C:\\Users\\FEIJINTI\\Desktop\\721\\zazhi\\Image_2022_0721_1351_38_946-002034.bmp\"\n", "img_path = r\"C:\\Users\\FEIJINTI\\Desktop\\721\\zazhi\\Image_2022_0721_1351_38_946-002034.bmp\"\n",
@ -54,7 +54,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 46,
"outputs": [], "outputs": [],
"source": [ "source": [
"img = cv2.imread(img_path)[:, :, ::-1]\n", "img = cv2.imread(img_path)[:, :, ::-1]\n",
@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 47,
"outputs": [], "outputs": [],
"source": [ "source": [
"t_detector = AnonymousColorDetector(file_path=model_path[1])\n", "t_detector = AnonymousColorDetector(file_path=model_path[1])\n",
@ -86,7 +86,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 48,
"outputs": [], "outputs": [],
"source": [ "source": [
"z_detector = AnonymousColorDetector(file_path=model_path[2])\n", "z_detector = AnonymousColorDetector(file_path=model_path[2])\n",
@ -101,7 +101,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 49,
"outputs": [], "outputs": [],
"source": [ "source": [
"result = 1 - (b_result | t_result | z_result)" "result = 1 - (b_result | t_result | z_result)"
@ -115,7 +115,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 50,
"outputs": [ "outputs": [
{ {
"data": { "data": {
@ -128,16 +128,16 @@
{ {
"data": { "data": {
"text/plain": "<IPython.core.display.HTML object>", "text/plain": "<IPython.core.display.HTML object>",
"text/html": "<div id='cec7c80e-e50e-4a78-858f-b58ac7b4a36e'></div>" "text/html": "<div id='71e0e832-f293-40bc-9913-301125493d2e'></div>"
}, },
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
}, },
{ {
"data": { "data": {
"text/plain": "<matplotlib.image.AxesImage at 0x24e9025e170>" "text/plain": "<matplotlib.image.AxesImage at 0x24e8faaa440>"
}, },
"execution_count": 36, "execution_count": 50,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -158,7 +158,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 50,
"outputs": [], "outputs": [],
"source": [], "source": [],
"metadata": { "metadata": {