添加了dt方法

This commit is contained in:
FEIJINTI 2022-07-19 17:09:04 +08:00
parent bfba067285
commit 0f5c8dae17
4 changed files with 258 additions and 25 deletions

View File

@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"metadata": {
"pycharm": {
"name": "#%%\n"
@ -22,7 +22,8 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import scipy\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from models import AnonymousColorDetector\n",
"from utils import read_labeled_img"
]
@ -40,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 7,
"metadata": {
"pycharm": {
"name": "#%%\n"
@ -49,8 +50,15 @@
"outputs": [],
"source": [
"data_dir = \"data/dataset\"\n",
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
"color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'}\n",
"label_index = {\"yangeng\": 1, \"beijing\": 0}\n",
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
"rus = RandomUnderSampler(random_state=0)\n",
"x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
" np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n",
"\n",
"x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
"dataset = {\"inside\": np.array(x_resampled)}"
]
},
{
@ -66,7 +74,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 8,
"metadata": {
"pycharm": {
"name": "#%%\n"
@ -77,16 +85,16 @@
"# 定义一些常量\n",
"threshold = 5\n",
"node_num = 20\n",
"negative_sample_num = None # None或者一个数字\n",
"negative_sample_num = None # None或者一个数字\n",
"world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
"# 对数据进行预处理\n",
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
"negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num"
"negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 9,
"metadata": {
"pycharm": {
"name": "#%%\n"
@ -99,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 10,
"metadata": {
"pycharm": {
"name": "#%%\n"
@ -107,21 +115,28 @@
},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"Input \u001B[1;32mIn [20]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnegative_sample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_num\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.7\u001B[39;49m\u001B[43m)\u001B[49m\n",
"File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:34\u001B[0m, in \u001B[0;36mAnonymousColorDetector.fit\u001B[1;34m(self, x, world_boundary, threshold, model_selected, negative_sample_size, train_size, **kwargs)\u001B[0m\n\u001B[0;32m 32\u001B[0m node_num \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mnode_num\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;241m10\u001B[39m)\n\u001B[0;32m 33\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel \u001B[38;5;241m=\u001B[39m ELM(input_size\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m1\u001B[39m], node_num\u001B[38;5;241m=\u001B[39mnode_num, output_num\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m---> 34\u001B[0m negative_samples \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgenerate_negative_samples\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 35\u001B[0m \u001B[43m \u001B[49m\u001B[43msample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 36\u001B[0m data_x, data_y \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mconcatenate([x, negative_samples], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m), \\\n\u001B[0;32m 37\u001B[0m np\u001B[38;5;241m.\u001B[39mconcatenate([np\u001B[38;5;241m.\u001B[39mones(x\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m),\n\u001B[0;32m 38\u001B[0m np\u001B[38;5;241m.\u001B[39mzeros(negative_samples\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m)], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m)\n\u001B[0;32m 39\u001B[0m x_train, x_val, y_train, y_val \u001B[38;5;241m=\u001B[39m train_test_split(data_x, data_y, train_size\u001B[38;5;241m=\u001B[39mtrain_size, shuffle\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n",
"File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:66\u001B[0m, in \u001B[0;36mAnonymousColorDetector.generate_negative_samples\u001B[1;34m(x, world_boundary, threshold, sample_size)\u001B[0m\n\u001B[0;32m 64\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m sample_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(generated_data\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]):\n\u001B[0;32m 65\u001B[0m sample \u001B[38;5;241m=\u001B[39m generated_data[sample_idx, :]\n\u001B[1;32m---> 66\u001B[0m in_threshold \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39many(np\u001B[38;5;241m.\u001B[39msum(\u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpower\u001B[49m\u001B[43m(\u001B[49m\u001B[43msample\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m<\u001B[39m threshold)\n\u001B[0;32m 67\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m in_threshold:\n\u001B[0;32m 68\u001B[0m negative_samples[sample_idx, :] \u001B[38;5;241m=\u001B[39m sample\n",
"\u001B[1;31mKeyboardInterrupt\u001B[0m: "
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 0.99 0.99 0.99 26314\n",
" 1.0 0.99 0.99 0.99 24492\n",
"\n",
" accuracy 0.99 50806\n",
" macro avg 0.99 0.99 0.99 50806\n",
"weighted avg 0.99 0.99 0.99 50806\n",
"\n"
]
}
],
"source": [
"model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
"# model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
"# is_save_dataset=True, model_selection='dt')\n",
"data = scipy.io.loadmat('data/dataset/dataset_2022-07-19_17-06.mat')\n",
"x, y = data['x'], data['y'].ravel()\n",
"model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
"model.save()"
]
}
],

212
03_data_update.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@ -31,7 +31,7 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
else:
raise TypeError("test img should be np.ndarray or str")
t1 = time.time()
img = cv2.resize(img, (1024, 256))
# img = cv2.resize(img, (1024, 256))
t2 = time.time()
result = 1 - detector.predict(img)
t3 = time.time()
@ -55,5 +55,7 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
if __name__ == '__main__':
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
detector = AnonymousColorDetector(file_path='dt_2022-07-19_17-07.model')
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)

View File

@ -86,7 +86,7 @@ class AnonymousColorDetector(Detector):
y_predict = self.model.predict(x_val)
print(classification_report(y_true=y_val, y_pred=y_predict))
def predict(self, x):
def predict(self, x, threshold_low=10, threshold_high=170):
"""
输入rgb彩色图像
@ -95,7 +95,11 @@ class AnonymousColorDetector(Detector):
"""
w, h = x.shape[1], x.shape[0]
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
result = self.model.predict(x.reshape(w * h, -1))
x = x.reshape(w * h, -1)
mask = (threshold_low < x[:, 0]) & (x[:, 0] < threshold_high)
mask_result = self.model.predict(x[mask])
result = np.ones((w * h,))
result[mask] = mask_result
return result.reshape(h, w)
@staticmethod