分离预测

This commit is contained in:
FEIJINTI 2022-07-21 17:13:44 +08:00
parent b1be6eeb49
commit f7c88d3c2a
3 changed files with 23 additions and 32 deletions

View File

@ -37,7 +37,7 @@
"data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹放置相同文件名的图片和label\n",
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
"\n",
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n",
"color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
"# color_dict = {(255, 0, 0): 'beijing'}\n",
"color_dict = {(0, 255, 0): \"zibian\"}\n",

View File

@ -1,4 +1,3 @@
# %%
import numpy as np
import scipy
from imblearn.under_sampling import RandomUnderSampler
@ -6,12 +5,12 @@ from models import AnonymousColorDetector
from utils import read_labeled_img
# %%
train_from_existed = True # 是否从现有数据训练如果是的话那就从dataset_file训练否则就用data_dir里头的数据
train_from_existed = False # 是否从现有数据训练如果是的话那就从dataset_file训练否则就用data_dir里头的数据
data_dir = "data/dataset" # 数据集,文件夹下必须包含`img`和`label`两个文件夹放置相同文件名的图片和label
dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat"
# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'} # 颜色对应的类别
color_dict = {(0, 0, 255): "yangeng"}
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing', (0, 255, 0): "zibian"} # 颜色对应的类别
# color_dict = {(0, 0, 255): "yangeng"}
# color_dict = {(255, 0, 0): 'beijing'}
# color_dict = {(0, 255, 0): "zibian"}
label_index = {"yangeng": 1, "beijing": 0, "zibian": 2} # 类别对应的序号
@ -32,9 +31,11 @@ if show_samples:
# %% md
## 数据平衡化
# %%
if len(dataset) > 1:
rus = RandomUnderSampler(random_state=0)
x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()
np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()],
axis=0).tolist()
x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
dataset = {"inside": np.array(x_resampled)}
# %% md
@ -54,13 +55,3 @@ else:
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
is_save_dataset=True, model_selection='dt')
model.save()
# %%
if train_from_existed:
data = scipy.io.loadmat(dataset_file)
x, y = data['x'], data['y'].ravel()
model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
else:
world_boundary = np.array([0, 0, 0, 255, 255, 255])
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
is_save_dataset=True, model_selection='dt')
model.save()

View File

@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 44,
"outputs": [],
"source": [
"import datetime\n",
@ -39,7 +39,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 45,
"outputs": [],
"source": [
"img_path = r\"C:\\Users\\FEIJINTI\\Desktop\\721\\zazhi\\Image_2022_0721_1351_38_946-002034.bmp\"\n",
@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 46,
"outputs": [],
"source": [
"img = cv2.imread(img_path)[:, :, ::-1]\n",
@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 47,
"outputs": [],
"source": [
"t_detector = AnonymousColorDetector(file_path=model_path[1])\n",
@ -86,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 48,
"outputs": [],
"source": [
"z_detector = AnonymousColorDetector(file_path=model_path[2])\n",
@ -101,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 49,
"outputs": [],
"source": [
"result = 1 - (b_result | t_result | z_result)"
@ -115,7 +115,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 50,
"outputs": [
{
"data": {
@ -128,16 +128,16 @@
{
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "<div id='cec7c80e-e50e-4a78-858f-b58ac7b4a36e'></div>"
"text/html": "<div id='71e0e832-f293-40bc-9913-301125493d2e'></div>"
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "<matplotlib.image.AxesImage at 0x24e9025e170>"
"text/plain": "<matplotlib.image.AxesImage at 0x24e8faaa440>"
},
"execution_count": 36,
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
@ -158,7 +158,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 50,
"outputs": [],
"source": [],
"metadata": {