mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
添加了dt方法
This commit is contained in:
parent
bfba067285
commit
0f5c8dae17
@ -13,7 +13,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 6,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%%\n"
|
"name": "#%%\n"
|
||||||
@ -22,7 +22,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"\n",
|
"import scipy\n",
|
||||||
|
"from imblearn.under_sampling import RandomUnderSampler\n",
|
||||||
"from models import AnonymousColorDetector\n",
|
"from models import AnonymousColorDetector\n",
|
||||||
"from utils import read_labeled_img"
|
"from utils import read_labeled_img"
|
||||||
]
|
]
|
||||||
@ -40,7 +41,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 7,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%%\n"
|
"name": "#%%\n"
|
||||||
@ -49,8 +50,15 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"data_dir = \"data/dataset\"\n",
|
"data_dir = \"data/dataset\"\n",
|
||||||
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
"color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'}\n",
|
||||||
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
|
"label_index = {\"yangeng\": 1, \"beijing\": 0}\n",
|
||||||
|
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
|
||||||
|
"rus = RandomUnderSampler(random_state=0)\n",
|
||||||
|
"x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
|
||||||
|
" np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n",
|
||||||
|
"\n",
|
||||||
|
"x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
|
||||||
|
"dataset = {\"inside\": np.array(x_resampled)}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -66,7 +74,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 8,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%%\n"
|
"name": "#%%\n"
|
||||||
@ -81,12 +89,12 @@
|
|||||||
"world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
|
"world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
|
||||||
"# 对数据进行预处理\n",
|
"# 对数据进行预处理\n",
|
||||||
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
|
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
|
||||||
"negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num"
|
"negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 9,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%%\n"
|
"name": "#%%\n"
|
||||||
@ -99,7 +107,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 10,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%%\n"
|
"name": "#%%\n"
|
||||||
@ -107,21 +115,28 @@
|
|||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"ename": "KeyboardInterrupt",
|
"name": "stdout",
|
||||||
"evalue": "",
|
"output_type": "stream",
|
||||||
"output_type": "error",
|
"text": [
|
||||||
"traceback": [
|
" precision recall f1-score support\n",
|
||||||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
"\n",
|
||||||
"\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
" 0.0 0.99 0.99 0.99 26314\n",
|
||||||
"Input \u001B[1;32mIn [20]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnegative_sample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_num\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.7\u001B[39;49m\u001B[43m)\u001B[49m\n",
|
" 1.0 0.99 0.99 0.99 24492\n",
|
||||||
"File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:34\u001B[0m, in \u001B[0;36mAnonymousColorDetector.fit\u001B[1;34m(self, x, world_boundary, threshold, model_selected, negative_sample_size, train_size, **kwargs)\u001B[0m\n\u001B[0;32m 32\u001B[0m node_num \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mnode_num\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;241m10\u001B[39m)\n\u001B[0;32m 33\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel \u001B[38;5;241m=\u001B[39m ELM(input_size\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m1\u001B[39m], node_num\u001B[38;5;241m=\u001B[39mnode_num, output_num\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m---> 34\u001B[0m negative_samples \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgenerate_negative_samples\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 35\u001B[0m \u001B[43m \u001B[49m\u001B[43msample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 36\u001B[0m data_x, data_y \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mconcatenate([x, negative_samples], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m), \\\n\u001B[0;32m 37\u001B[0m np\u001B[38;5;241m.\u001B[39mconcatenate([np\u001B[38;5;241m.\u001B[39mones(x\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m),\n\u001B[0;32m 38\u001B[0m np\u001B[38;5;241m.\u001B[39mzeros(negative_samples\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m)], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m)\n\u001B[0;32m 39\u001B[0m x_train, x_val, y_train, y_val \u001B[38;5;241m=\u001B[39m train_test_split(data_x, data_y, train_size\u001B[38;5;241m=\u001B[39mtrain_size, shuffle\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n",
|
"\n",
|
||||||
"File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:66\u001B[0m, in \u001B[0;36mAnonymousColorDetector.generate_negative_samples\u001B[1;34m(x, world_boundary, threshold, sample_size)\u001B[0m\n\u001B[0;32m 64\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m sample_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(generated_data\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]):\n\u001B[0;32m 65\u001B[0m sample \u001B[38;5;241m=\u001B[39m generated_data[sample_idx, :]\n\u001B[1;32m---> 66\u001B[0m in_threshold \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39many(np\u001B[38;5;241m.\u001B[39msum(\u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpower\u001B[49m\u001B[43m(\u001B[49m\u001B[43msample\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m<\u001B[39m threshold)\n\u001B[0;32m 67\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m in_threshold:\n\u001B[0;32m 68\u001B[0m negative_samples[sample_idx, :] \u001B[38;5;241m=\u001B[39m sample\n",
|
" accuracy 0.99 50806\n",
|
||||||
"\u001B[1;31mKeyboardInterrupt\u001B[0m: "
|
" macro avg 0.99 0.99 0.99 50806\n",
|
||||||
|
"weighted avg 0.99 0.99 0.99 50806\n",
|
||||||
|
"\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
|
"# model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
|
||||||
|
"# is_save_dataset=True, model_selection='dt')\n",
|
||||||
|
"data = scipy.io.loadmat('data/dataset/dataset_2022-07-19_17-06.mat')\n",
|
||||||
|
"x, y = data['x'], data['y'].ravel()\n",
|
||||||
|
"model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
|
||||||
|
"model.save()"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
212
03_data_update.ipynb
Normal file
212
03_data_update.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -31,7 +31,7 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
|
|||||||
else:
|
else:
|
||||||
raise TypeError("test img should be np.ndarray or str")
|
raise TypeError("test img should be np.ndarray or str")
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
img = cv2.resize(img, (1024, 256))
|
# img = cv2.resize(img, (1024, 256))
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
result = 1 - detector.predict(img)
|
result = 1 - detector.predict(img)
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
@ -55,5 +55,7 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
detector = AnonymousColorDetector(file_path='dt_2022-07-19_17-07.model')
|
||||||
|
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
|
||||||
|
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
|
||||||
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
|
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
|
||||||
|
|||||||
@ -86,7 +86,7 @@ class AnonymousColorDetector(Detector):
|
|||||||
y_predict = self.model.predict(x_val)
|
y_predict = self.model.predict(x_val)
|
||||||
print(classification_report(y_true=y_val, y_pred=y_predict))
|
print(classification_report(y_true=y_val, y_pred=y_predict))
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x, threshold_low=10, threshold_high=170):
|
||||||
"""
|
"""
|
||||||
输入rgb彩色图像
|
输入rgb彩色图像
|
||||||
|
|
||||||
@ -95,7 +95,11 @@ class AnonymousColorDetector(Detector):
|
|||||||
"""
|
"""
|
||||||
w, h = x.shape[1], x.shape[0]
|
w, h = x.shape[1], x.shape[0]
|
||||||
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
|
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
|
||||||
result = self.model.predict(x.reshape(w * h, -1))
|
x = x.reshape(w * h, -1)
|
||||||
|
mask = (threshold_low < x[:, 0]) & (x[:, 0] < threshold_high)
|
||||||
|
mask_result = self.model.predict(x[mask])
|
||||||
|
result = np.ones((w * h,))
|
||||||
|
result[mask] = mask_result
|
||||||
return result.reshape(h, w)
|
return result.reshape(h, w)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user