mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
分离预测
This commit is contained in:
parent
b1be6eeb49
commit
f7c88d3c2a
@ -37,7 +37,7 @@
|
||||
"data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹,放置相同文件名的图片和label\n",
|
||||
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
|
||||
"\n",
|
||||
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n",
|
||||
"color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
|
||||
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||
"# color_dict = {(255, 0, 0): 'beijing'}\n",
|
||||
"color_dict = {(0, 255, 0): \"zibian\"}\n",
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# %%
|
||||
import numpy as np
|
||||
import scipy
|
||||
from imblearn.under_sampling import RandomUnderSampler
|
||||
@ -6,12 +5,12 @@ from models import AnonymousColorDetector
|
||||
from utils import read_labeled_img
|
||||
|
||||
# %%
|
||||
train_from_existed = True # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据
|
||||
train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据
|
||||
data_dir = "data/dataset" # 数据集,文件夹下必须包含`img`和`label`两个文件夹,放置相同文件名的图片和label
|
||||
dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat"
|
||||
|
||||
# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'} # 颜色对应的类别
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing', (0, 255, 0): "zibian"} # 颜色对应的类别
|
||||
# color_dict = {(0, 0, 255): "yangeng"}
|
||||
# color_dict = {(255, 0, 0): 'beijing'}
|
||||
# color_dict = {(0, 255, 0): "zibian"}
|
||||
label_index = {"yangeng": 1, "beijing": 0, "zibian": 2} # 类别对应的序号
|
||||
@ -32,11 +31,13 @@ if show_samples:
|
||||
# %% md
|
||||
## 数据平衡化
|
||||
# %%
|
||||
rus = RandomUnderSampler(random_state=0)
|
||||
x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
|
||||
np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()
|
||||
x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
|
||||
dataset = {"inside": np.array(x_resampled)}
|
||||
if len(dataset) > 1:
|
||||
rus = RandomUnderSampler(random_state=0)
|
||||
x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
|
||||
np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()],
|
||||
axis=0).tolist()
|
||||
x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
|
||||
dataset = {"inside": np.array(x_resampled)}
|
||||
# %% md
|
||||
## 模型训练
|
||||
# %%
|
||||
@ -53,14 +54,4 @@ else:
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
|
||||
is_save_dataset=True, model_selection='dt')
|
||||
model.save()
|
||||
# %%
|
||||
if train_from_existed:
|
||||
data = scipy.io.loadmat(dataset_file)
|
||||
x, y = data['x'], data['y'].ravel()
|
||||
model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
|
||||
else:
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
|
||||
is_save_dataset=True, model_selection='dt')
|
||||
model.save()
|
||||
model.save()
|
||||
@ -14,7 +14,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 44,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import datetime\n",
|
||||
@ -39,7 +39,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 45,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img_path = r\"C:\\Users\\FEIJINTI\\Desktop\\721\\zazhi\\Image_2022_0721_1351_38_946-002034.bmp\"\n",
|
||||
@ -54,7 +54,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 46,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = cv2.imread(img_path)[:, :, ::-1]\n",
|
||||
@ -71,7 +71,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 47,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"t_detector = AnonymousColorDetector(file_path=model_path[1])\n",
|
||||
@ -86,7 +86,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 48,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"z_detector = AnonymousColorDetector(file_path=model_path[2])\n",
|
||||
@ -101,7 +101,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 49,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = 1 - (b_result | t_result | z_result)"
|
||||
@ -115,7 +115,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 50,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -128,16 +128,16 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<IPython.core.display.HTML object>",
|
||||
"text/html": "<div id='cec7c80e-e50e-4a78-858f-b58ac7b4a36e'></div>"
|
||||
"text/html": "<div id='71e0e832-f293-40bc-9913-301125493d2e'></div>"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<matplotlib.image.AxesImage at 0x24e9025e170>"
|
||||
"text/plain": "<matplotlib.image.AxesImage at 0x24e8faaa440>"
|
||||
},
|
||||
"execution_count": 36,
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -158,7 +158,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 50,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user