diff --git a/classifer.py b/classifer.py index b0f2b6b..00f4025 100755 --- a/classifer.py +++ b/classifer.py @@ -30,11 +30,13 @@ from root_dir import ROOT_DIR import utils FEATURE_INDEX = [0, 1, 2] -delete_columns = 10 # 已弃用 +delete_columns = 10 # 已弃用 num_bins = 10 + class WoodClass(object): - def __init__(self, load_from=None, w=2048, h=12450, n=5000, p1=0.3, pur=0.99999, left_correct=False, single_pick_mode=False, + def __init__(self, load_from=None, w=2048, h=12450, n=5000, p1=0.3, pur=0.99999, left_correct=False, + single_pick_mode=False, debug_mode=False): """ 初始化. @@ -79,7 +81,7 @@ class WoodClass(object): if single_pick_mode: self._single_pick = True width = int(np.floor(np.sqrt(w * h / n))) - w0, h0 = np.arange(0, w-width, width), np.arange(0, h-width, width) + w0, h0 = np.arange(0, w - width, width), np.arange(0, h - width, width) self.ww, self.hh = np.meshgrid(w0, h0) self.width = width else: @@ -96,7 +98,7 @@ class WoodClass(object): """ if self._single_pick: offset_w, offset_h = np.random.randint(0, self.width), np.random.randint(0, self.width) - sample = x[self.hh+offset_h, self.ww+offset_w, ...] + sample = x[self.hh + offset_h, self.ww + offset_w, ...] else: sample = cv2.resize(x, (self.ww, self.hh)) return sample @@ -117,19 +119,19 @@ class WoodClass(object): model_name = self.save(file_name) return model_name - def fit(self, x, y, test_size=0.1): + def fit(self, x, y, test_size=0.1): x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0) self.model.fit(x_train, y_train) y_pred = self.model.predict(x_test) pre_score = accuracy_score(y_test, y_pred) - self.log.log("Test accuracy is:"+str(pre_score * 100) + "%.") + self.log.log("Test accuracy is:" + str(pre_score * 100) + "%.") y_pred = self.model.predict(x_train) pre_score = accuracy_score(y_train, y_pred) - self.log.log("Train accuracy is:"+str(pre_score * 100) + "%.") + self.log.log("Train accuracy is:" + str(pre_score * 100) + "%.") y_pred = self.model.predict(x) pre_score = accuracy_score(y, y_pred) - self.log.log("Total accuracy is:"+str(pre_score * 100) + "%.") - return int(pre_score*100) + self.log.log("Total accuracy is:" + str(pre_score * 100) + "%.") + return int(pre_score * 100) def calculate_p1(self, x, remove_background=False): """ @@ -156,7 +158,7 @@ class WoodClass(object): feature = self.extract_feature(img, remove_background=False, debug_mode=False) feature = feature.reshape(1, -1)[:, FEATURE_INDEX] if self.isCorrect: - feature = feature / (self.correct_color+1e-4) + feature = feature / (self.correct_color + 1e-4) plt.figure() plt.scatter(feature[:, 0], feature[:, 1]) plt.show() @@ -232,10 +234,10 @@ class WoodClass(object): self.log.log("No model found!") return 1 self.log.log("./ Models Found:") - _ = [self.log.log("├--"+str(model_file)) for model_file in model_files] + _ = [self.log.log("├--" + str(model_file)) for model_file in model_files] file_times = [model_file[6:-2] for model_file in model_files] latest_model = model_files[int(np.argmax(file_times))] - self.log.log("└--Using the latest model: "+str(latest_model)) + self.log.log("└--Using the latest model: " + str(latest_model)) path = os.path.join(ROOT_DIR, "models", str(latest_model)) if not os.path.isabs(path): logging.warning('给的是相对路径') @@ -276,14 +278,14 @@ class WoodClass(object): x_hsv = cv2.cvtColor(x, cv2.COLOR_BGR2HSV) x = cv2.cvtColor(x, cv2.COLOR_BGR2LAB) x = np.concatenate((x, x_hsv), axis=2) - x = np.reshape(x, (x.shape[0]*x.shape[1], x.shape[2])) + x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2])) hist, bins = np.histogram(x[:, 0], bins=num_bins) hist = hist[1:] bins = bins[1:] # x = x[np.argsort(x[:, 0])] # x = x[-self.k:, :] hist_number = np.argmax(hist) - x = x[(x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number+1]), :] + x = x[(x[:, 0] > bins[hist_number]) & (x[:, 0] < bins[hist_number + 1]), :] if debug_mode: # self.log.log(x) self.log.log(x.shape) @@ -292,7 +294,7 @@ class WoodClass(object): self.log.log(x.shape) mean_value = np.mean(x, axis=0) if debug_mode: - self.log.log("mean color:"+str(mean_value)) + self.log.log("mean color:" + str(mean_value)) plt.subplot(212) color_img = np.asarray(np.ones((100, 100, 3), dtype=np.uint8) * mean_value[:3], dtype=np.uint8) color_img = cv2.cvtColor(color_img, cv2.COLOR_LAB2RGB) @@ -301,7 +303,7 @@ class WoodClass(object): var_value = np.var(x, axis=0) feature = np.hstack((mean_value, var_value)) if debug_mode: - self.log.log("var: "+str(var_value)) + self.log.log("var: " + str(var_value)) return feature def get_image_data(self, img_dir="./data/dark"): @@ -344,7 +346,7 @@ class WoodClass(object): x_data = x_data[:, FEATURE_INDEX] # 进行色彩数据校正 if self.isCorrect: - x_data = x_data / (self.correct_color+1e-4) + x_data = x_data / (self.correct_color + 1e-4) if plot_data_3d: fig = plt.figure() ax = fig.add_subplot(1, 1, 1, projection="3d") @@ -372,7 +374,6 @@ class WoodClass(object): pass return x_data, y_data - def realtime_correction(self, img): """ 实时校正 @@ -386,7 +387,7 @@ class WoodClass(object): img = (img / img_like * standard_img)[:, delete_columns:, :] return img - def realtime_correct(self, img:np.ndarray, correct_col_num: int, cut_col_num: Optional[int] = None, + def realtime_correct(self, img: np.ndarray, correct_col_num: int, cut_col_num: Optional[int] = None, standard_color: Optional[tuple] = (255, 255, 255)) -> np.ndarray: """ 实时利用左侧边的correct_col_num列进行色彩校正 @@ -398,6 +399,7 @@ class WoodClass(object): :return: 校正后的图片 shape = (n_rows, n_cols - cut_col_num, n_channels) """ if self.left_correct: + img = img[:, 20:, :] # 按照correct_col_num列数量取出最左侧校正板区域成像结果 correct_img = img[:, :correct_col_num, :] # 校正区域进行均值化 @@ -412,6 +414,7 @@ class WoodClass(object): if __name__ == '__main__': from config import Config + settings = Config() # 初始化wood wood = WoodClass(w=4096, h=1200, n=12500, p1=0.4, debug_mode=False) @@ -428,6 +431,5 @@ if __name__ == '__main__': # for i in range(100): wood_color = wood.predict(pic) end_time = time.time() - print("time consume:"+str((end_time - start_time)/100)) - print("wood_color:"+str(wood_color)) - + print("time consume:" + str((end_time - start_time) / 100)) + print("wood_color:" + str(wood_color))