mirror of
https://github.com/NanjingForestryUniversity/tobacoo-industry.git
synced 2025-11-08 22:33:52 +00:00
修改灵敏度
修改灵敏度为32,删去了问题数据
This commit is contained in:
parent
b992ec150d
commit
a42fcb3438
@ -42,7 +42,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# 一些参数\n",
|
"# 一些参数\n",
|
||||||
"blk_sz = 8\n",
|
"blk_sz = 8\n",
|
||||||
"sensitivity = 8\n",
|
"sensitivity = 32\n",
|
||||||
"selected_bands = [127, 201, 202, 294]\n",
|
"selected_bands = [127, 201, 202, 294]\n",
|
||||||
"# [76, 146, 216, 367, 383, 406]\n",
|
"# [76, 146, 216, 367, 383, 406]\n",
|
||||||
"file_name, labeled_image_file = r\"F:\\zhouchao\\616\\calibrated1.raw\", \\\n",
|
"file_name, labeled_image_file = r\"F:\\zhouchao\\616\\calibrated1.raw\", \\\n",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@ -26,7 +26,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 49,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
@ -48,12 +48,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 50,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# some parameters\n",
|
"# some parameters\n",
|
||||||
"new_spectra_file = r\"F:\\zhouchao\\618\\bomo\\calibrated7.raw\"\n",
|
"new_spectra_file = r\"F:\\zhouchao\\615\\calibrated0.raw\"\n",
|
||||||
"new_label_file = r\"F:\\zhouchao\\618\\bomo\\picture\\label7.bmp\"\n",
|
"new_label_file = r\"F:\\zhouchao\\615\\label0.bmp\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"target_class = 0\n",
|
"target_class = 0\n",
|
||||||
"target_class_left, target_class_right = 5, 4\n",
|
"target_class_left, target_class_right = 5, 4\n",
|
||||||
@ -63,7 +63,7 @@
|
|||||||
"split_line = 500\n",
|
"split_line = 500\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"blk_sz, sensitivity = 8, 8\n",
|
"blk_sz, sensitivity = 8, 32\n",
|
||||||
"selected_bands = [127, 201, 202, 294]\n",
|
"selected_bands = [127, 201, 202, 294]\n",
|
||||||
"tree_num = 185\n",
|
"tree_num = 185\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -72,11 +72,11 @@
|
|||||||
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
|
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
|
||||||
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
|
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"new_dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_7.p'\n",
|
"new_dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_4.p'\n",
|
||||||
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_6.p'\n",
|
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_3.p'\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_6.model'\n",
|
"model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_sen{sensitivity}_3.model'\n",
|
||||||
"selected_bands = None"
|
"# selected_bands = None"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
@ -87,10 +87,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 51,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# if len(new_spectra_files) == 1:\n",
|
|
||||||
"data = read_raw_file(new_spectra_file, selected_bands)"
|
"data = read_raw_file(new_spectra_file, selected_bands)"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -115,7 +114,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 52,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"x_list, y_list = [], []\n",
|
"x_list, y_list = [], []\n",
|
||||||
@ -144,7 +143,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 53,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"if (new_label_file is None) and (target_class != 1):\n",
|
"if (new_label_file is None) and (target_class != 1):\n",
|
||||||
@ -176,7 +175,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 54,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
@ -201,13 +200,13 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 55,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"148 148\n"
|
"301 301\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -235,7 +234,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 56,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"with open(dataset_file, 'rb') as f:\n",
|
"with open(dataset_file, 'rb') as f:\n",
|
||||||
@ -266,7 +265,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 56,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [],
|
"source": [],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
213
06_other_models.ipynb
Normal file
213
06_other_models.ipynb
Normal file
File diff suppressed because one or more lines are too long
54
models.py
54
models.py
@ -10,11 +10,6 @@ from sklearn.decomposition import PCA
|
|||||||
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
||||||
|
|
||||||
|
|
||||||
# def feature(x):
|
|
||||||
# x = x.reshape((x.shape[0], -1, x.shape[-1]))
|
|
||||||
# x = np.mean(x, axis=1)
|
|
||||||
# return x
|
|
||||||
|
|
||||||
def feature(x):
|
def feature(x):
|
||||||
x = x.reshape((x.shape[0], -1))
|
x = x.reshape((x.shape[0], -1))
|
||||||
return x
|
return x
|
||||||
@ -69,10 +64,13 @@ def evaluation_and_report(model, test_x, test_y):
|
|||||||
def train_pca_rf(train_x, train_y, test_x, test_y, n_comp,
|
def train_pca_rf(train_x, train_y, test_x, test_y, n_comp,
|
||||||
tree_num, save_path=None):
|
tree_num, save_path=None):
|
||||||
rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42, class_weight={0: 100, 1: 100})
|
rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42, class_weight={0: 100, 1: 100})
|
||||||
pca = PCA(n_components=0.95)
|
pca = PCA(n_components=n_comp)
|
||||||
rfc = rfc.fit(train_x, train_y)
|
pca = pca.fit(train_x)
|
||||||
|
pca_train_x = pca.transform(train_x)
|
||||||
|
rfc = rfc.fit(pca_train_x, train_y)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
y_pred = rfc.predict(test_x)
|
pca_test_x = pca.transform(test_x)
|
||||||
|
y_pred = rfc.predict(pca_test_x)
|
||||||
y_pred_binary = np.ones_like(y_pred)
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||||
y_pred_binary[(y_pred == 2) | (y_pred == 3) | (y_pred == 4)] = 2
|
y_pred_binary[(y_pred == 2) | (y_pred == 3) | (y_pred == 4)] = 2
|
||||||
@ -80,8 +78,8 @@ def train_pca_rf(train_x, train_y, test_x, test_y, n_comp,
|
|||||||
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||||
test_y_binary[(test_y == 2) | (test_y == 3) | (test_y == 4)] = 2
|
test_y_binary[(test_y == 2) | (test_y == 3) | (test_y == 4)] = 2
|
||||||
print("预测时间:", time.time() - t1)
|
print("预测时间:", time.time() - t1)
|
||||||
print("RFC训练模型评分:" + str(accuracy_score(train_y, rfc.predict(train_x))))
|
print("RFC训练模型评分:" + str(accuracy_score(train_y, rfc.predict(pca_train_x))))
|
||||||
print("RFC待测模型评分:" + str(accuracy_score(test_y, rfc.predict(test_x))))
|
print("RFC待测模型评分:" + str(accuracy_score(test_y, rfc.predict(pca_test_x))))
|
||||||
print('RFC预测结果:' + str(y_pred))
|
print('RFC预测结果:' + str(y_pred))
|
||||||
print('---------------------------------------------------------------------------------------------------')
|
print('---------------------------------------------------------------------------------------------------')
|
||||||
print('RFC分类报告:\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
print('RFC分类报告:\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||||
@ -140,3 +138,39 @@ class SpecDetector(object):
|
|||||||
mask[row, col] = r
|
mask[row, col] = r
|
||||||
mask = mask.repeat(self.blk_sz, axis=0).repeat(self.blk_sz, axis=1)
|
mask = mask.repeat(self.blk_sz, axis=0).repeat(self.blk_sz, axis=1)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
class PcaSpecDetector(object):
|
||||||
|
def __init__(self, model_path, pca_path, blk_sz=8, channel_num=4):
|
||||||
|
self.blk_sz, self.channel_num = blk_sz, channel_num
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
with open(model_path, "rb") as model_file:
|
||||||
|
self.clf = pickle.load(model_file)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError("Model File not found")
|
||||||
|
if os.path.exists(pca_path):
|
||||||
|
with open(pca_path, "rb") as pca_file:
|
||||||
|
self.pca = pickle.load(pca_file)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError("Pca File not found")
|
||||||
|
|
||||||
|
def predict(self, data):
|
||||||
|
blocks = split_x(data, blk_sz=self.blk_sz)
|
||||||
|
blocks = np.array(blocks)
|
||||||
|
features = feature(np.array(blocks))
|
||||||
|
y_pred = self.clf.predict(features)
|
||||||
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
|
# classes merge
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
||||||
|
# transform to mask
|
||||||
|
mask = self.mask_transform(y_pred_binary, (1024, 600))
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def mask_transform(self, result, dst_size):
|
||||||
|
mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
|
||||||
|
mask = np.zeros(mask_size, dtype=np.uint8)
|
||||||
|
for idx, r in enumerate(result):
|
||||||
|
row, col = idx // mask_size[1], idx % mask_size[1]
|
||||||
|
mask[row, col] = r
|
||||||
|
mask = mask.repeat(self.blk_sz, axis=0).repeat(self.blk_sz, axis=1)
|
||||||
|
return mask
|
||||||
|
|||||||
6
utils.py
6
utils.py
@ -119,9 +119,13 @@ def visualization_evaluation(detector, data_path, selected_bands=None):
|
|||||||
rgb_img = np.asarray(img[..., [372, 241, 169]] * 255, dtype=np.uint8)
|
rgb_img = np.asarray(img[..., [372, 241, 169]] * 255, dtype=np.uint8)
|
||||||
else:
|
else:
|
||||||
rgb_img = np.asarray(img[..., [0, 1, 2]] * 255, dtype=np.uint8)
|
rgb_img = np.asarray(img[..., [0, 1, 2]] * 255, dtype=np.uint8)
|
||||||
fig, axs = plt.subplots(1, 2)
|
mask_color = np.zeros_like(rgb_img)
|
||||||
|
mask_color[mask > 0] = (0, 0 , 255)
|
||||||
|
combine = cv2.addWeighted(rgb_img, 1, mask_color, 0.5, 0)
|
||||||
|
fig, axs = plt.subplots(1, 3)
|
||||||
axs[0].imshow(rgb_img)
|
axs[0].imshow(rgb_img)
|
||||||
axs[1].imshow(mask)
|
axs[1].imshow(mask)
|
||||||
|
axs[2].imshow(combine)
|
||||||
fig.suptitle(f"time spent {time_spent * 1000:.2f} ms" + f"\n{image_path}")
|
fig.suptitle(f"time spent {time_spent * 1000:.2f} ms" + f"\n{image_path}")
|
||||||
plt.savefig(f"./dataset/{idx}.png", dpi=300)
|
plt.savefig(f"./dataset/{idx}.png", dpi=300)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user