mirror of
https://github.com/NanjingForestryUniversity/supermachine-wood.git
synced 2025-11-08 10:13:53 +00:00
代码测试比较
This commit is contained in:
parent
02d222f479
commit
b37e064594
52
classifer.py
52
classifer.py
@ -14,7 +14,9 @@ import cv2
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.metrics import accuracy_score, confusion_matrix
|
||||
|
||||
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from scipy.stats import binom
|
||||
@ -31,7 +33,7 @@ import utils
|
||||
|
||||
FEATURE_INDEX = [0,1,2]
|
||||
delete_columns = 10 # 已弃用
|
||||
num_bins = 7
|
||||
num_bins = 10
|
||||
|
||||
|
||||
class WoodClass(object):
|
||||
@ -58,9 +60,10 @@ class WoodClass(object):
|
||||
self._single_pick = single_pick_mode
|
||||
self.set_purity(self.pur)
|
||||
self.change_pick_mode(single_pick_mode)
|
||||
self.model = LogisticRegression(C=1e5)
|
||||
# self.model = LogisticRegression(C=1e5)
|
||||
self.left_correct = left_correct
|
||||
# self.model = KNeighborsClassifier()
|
||||
self.model = DecisionTreeClassifier()
|
||||
else:
|
||||
self.load(load_from)
|
||||
self.isCorrect = False
|
||||
@ -125,25 +128,32 @@ class WoodClass(object):
|
||||
y_pred = self.model.predict(x_test)
|
||||
|
||||
#将y_pred中1和2转化为1
|
||||
y_pred[y_pred == 2] = 1
|
||||
y_test[y_test == 2] = 1
|
||||
# y_pred[y_pred == 2] = 1
|
||||
# y_test[y_test == 2] = 1
|
||||
|
||||
print(confusion_matrix(y_test, y_pred))
|
||||
|
||||
pre_score = accuracy_score(y_test, y_pred)
|
||||
self.log.log("Test accuracy is:" + str(pre_score * 100) + "%.")
|
||||
y_pred = self.model.predict(x_train)
|
||||
|
||||
y_pred[y_pred == 2] = 1
|
||||
y_train[y_train == 2] = 1
|
||||
# y_pred[y_pred == 2] = 1
|
||||
# y_train[y_train == 2] = 1
|
||||
|
||||
pre_score = accuracy_score(y_train, y_pred)
|
||||
self.log.log("Train accuracy is:" + str(pre_score * 100) + "%.")
|
||||
y_pred = self.model.predict(x)
|
||||
|
||||
y[y == 2] = 1
|
||||
y_pred[y_pred == 2] = 1
|
||||
# y[y == 2] = 1
|
||||
# y_pred[y_pred == 2] = 1
|
||||
|
||||
pre_score = accuracy_score(y, y_pred)
|
||||
self.log.log("Total accuracy is:" + str(pre_score * 100) + "%.")
|
||||
|
||||
#显示结果报告
|
||||
|
||||
|
||||
|
||||
return int(pre_score * 100)
|
||||
|
||||
def calculate_p1(self, x, remove_background=False):
|
||||
@ -296,11 +306,25 @@ class WoodClass(object):
|
||||
# x = x[np.argsort(x[:, 0])]
|
||||
# x = x[-self.k:, :]
|
||||
|
||||
|
||||
hist, bins = np.histogram(x[:, 0], bins=num_bins)
|
||||
hist = hist[1:]
|
||||
bins = bins[1:]
|
||||
hist_number = np.argmax(hist)
|
||||
x = x[(x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number + 1]), :]
|
||||
sorted_indices = np.argsort(hist)
|
||||
hist_number = sorted_indices[-1]
|
||||
second_hist_number = sorted_indices[-2]
|
||||
x = x[((x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number + 1])) | (
|
||||
(x[:, 0] > bins[second_hist_number]) & (x[:, 0] < bins[second_hist_number + 1])), :]
|
||||
|
||||
|
||||
# hist, bins = np.histogram(x[:, 0], bins=5)
|
||||
# sorted_indices = np.argsort(hist)
|
||||
# hist_number = sorted_indices[-1]
|
||||
# second_hist_number = sorted_indices[-2]
|
||||
# x = x[((x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number + 1])) | (
|
||||
# (x[:, 0] > bins[second_hist_number]) & (x[:, 0] < bins[second_hist_number + 1])), :]
|
||||
|
||||
|
||||
|
||||
if debug_mode:
|
||||
# self.log.log(x)
|
||||
@ -376,7 +400,7 @@ class WoodClass(object):
|
||||
plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data)
|
||||
plt.show()
|
||||
# 尝试最合适的特征组合,保存提取出的特征的方法
|
||||
# 0: l, 1: a, 2: b, 3: var(l), 4: var(a), 5: var(s), 6: h, 7: s, 8: v, 9: var(h) 10: var(s): 11: var(v)
|
||||
# 0: l, 1: a, 2: b, 3: var(l), 4: var(a), 5: var(b), 6: h, 7: s, 8: v, 9: var(h) 10: var(s): 11: var(v)
|
||||
# 全部:0.941
|
||||
# [:, [0, 1, 2, 3, 4, 5, 6, 7]] : 0.88
|
||||
# [:, [0, 1, 2]] : 0.911
|
||||
@ -433,7 +457,7 @@ if __name__ == '__main__':
|
||||
|
||||
settings = Config()
|
||||
# 初始化wood
|
||||
wood = WoodClass(w=4096, h=1200, n=8000, p1=0.46, debug_mode=False)
|
||||
wood = WoodClass(w=4096, h=1200, n=8000, p1=0.8, debug_mode=False)
|
||||
print("色彩纯度控制量{}/{}".format(wood.k, wood.n))
|
||||
data_path = settings.data_path
|
||||
# wood.correct()
|
||||
@ -442,7 +466,7 @@ if __name__ == '__main__':
|
||||
settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path))
|
||||
|
||||
# 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的
|
||||
pic = cv2.imread(r"data/99/dark/rgb70.png")
|
||||
pic = cv2.imread(r"data/316/dark/rgb70.png")
|
||||
start_time = time.time()
|
||||
# for i in range(100):
|
||||
wood_color = wood.predict(pic)
|
||||
|
||||
22
hist.py
22
hist.py
@ -2,19 +2,33 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
img_path = 'data/data1103/dark/rgb20.png'
|
||||
img_path = 'data/99/dark/rgb59.png'
|
||||
img = cv2.imread(img_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
|
||||
|
||||
w = img.shape[0]
|
||||
h = img.shape[1]
|
||||
|
||||
ratio = np.sqrt(5000 / (w * h))
|
||||
ww, hh = int(ratio * w), int(ratio * h)
|
||||
img = cv2.resize(img, (hh, ww))
|
||||
x = img.reshape(img.shape[0]*img.shape[1], img.shape[2])
|
||||
|
||||
hist, bins = np.histogram(x[:, 0], bins=10)
|
||||
hist = hist[1:]
|
||||
bins = bins[1:]
|
||||
hist_number = np.argmax(hist)
|
||||
x = x[(x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number + 1]), :]
|
||||
|
||||
|
||||
sorted_indices = np.argsort(hist)
|
||||
hist_number = sorted_indices[-1]
|
||||
second_hist_number = sorted_indices[-2]
|
||||
|
||||
x = x[((x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number + 1]))|((x[:, 0] > bins[second_hist_number]) & (x[:, 0] < bins[second_hist_number + 1])), :]
|
||||
|
||||
mean_value = np.mean(x, axis=0).astype(np.uint8)
|
||||
y = np.zeros((img.shape[0]*img.shape[1], img.shape[2]), dtype=np.uint8)
|
||||
y[:, ] = mean_value
|
||||
#lab转rgb再保存
|
||||
y = cv2.cvtColor(y.reshape(img.shape[0], img.shape[1], img.shape[2]), cv2.COLOR_LAB2BGR)
|
||||
cv2.imwrite('5.png', y.reshape(img.shape[0], img.shape[1], img.shape[2]))
|
||||
cv2.imwrite('59.png', y.reshape(img.shape[0], img.shape[1], img.shape[2]))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user