commit 848e261137db330975fdb707862895877a4d254f Author: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Sat Sep 3 16:25:55 2022 +0800 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..a4f6466 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +data/correct/*.bmp +data/dark/*.bmp +data/light/*.bmp +data/middle/*.bmp +data1 +data2 +data3 +data4 +data5 +.idea +__pycache__ +*.pyc +test.py +*.log \ No newline at end of file diff --git a/README.md b/README.md new file mode 100755 index 0000000..ea35201 --- /dev/null +++ b/README.md @@ -0,0 +1,48 @@ +# 木地板分色识别项目 + +## 训练数据准备 + +请在使用前在同一文件夹下创建data文件夹,所需的文件目录如下: + +```bash +. +├── README.md +├── classifer.py +├── data +│ ├── dark +│ ├── light +│ └── middle +``` + +在上面所示的三个文件夹下分别放置三种不同色彩的木板图片就可以训练了。 + +当然不要放除了图片外的东西,不然程序会出错哦。 + +## 色彩提取 + +下图为色彩提取效果: + +![从木板中提取色彩](./pics/从木板中提取色彩.png) + +我们利用概率的方法以随机对抗随机,然后使用二项分布的概率分位点即可获得任意纯度下的木板色彩。 + +但是纯度过高有时反而会难以反映木板的颜色,所以关于纯度的取舍还需调整。在目前的纯度要求0.99999下,该木板的分类效果好像可以有较好的表现,反正之后再调整嘛~ + +## 色彩分布 + +我们可以明显发现,在进行了色彩提纯后,色彩的区分度在lab空间下的ab平面还是很具备可分性的,所以直接用logistic regression这样的线性方法就够了。 + +![色彩分类](./pics/色彩分类.png) + +在classifer中还有6处简单的TODO, 交给老孙完成了,加油! 
+ +![image-20201105104225975](./pics/TODO.png) + + + +最终通过选择多次特征的组合,找到了效果比较好的特征 1 2 6 7 8 或 1 2 6 7 + +在数据集上的测试准确率为: 97.05% + +![测试结果](./pics/result.png) + diff --git a/classifer.py b/classifer.py new file mode 100755 index 0000000..16b4691 --- /dev/null +++ b/classifer.py @@ -0,0 +1,363 @@ +# -*- coding: utf-8 -*- +""" +Created on Nov 3 21:18:26 2020 + +@author: l.z.y +@e-mail: li.zhenye@qq.com +""" +import sys +import numpy as np +import cv2 +from sklearn.cluster import KMeans +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score +from scipy.stats import binom +import matplotlib.pyplot as plt +import time +import pickle +import os +sys.path.append(os.getcwd()) +from root_dir import ROOT_DIR +import utils + +FEATURE_INDEX = [1, 2] + + +class WoodClass(object): + def __init__(self, load_from=None, w=2048, h=12450, n=5000, p1=0.3, pur=0.99999, single_pick_mode=False, + debug_mode=False): + """ + 初始化. 
+ + :param w: 图像的尺寸w + :param h: 图像的尺寸h + :param p1: 木板色彩在图像中的比例p1 + :param n: 采集用于识别的样本点的个数 + :param single_pick_mode: 是否使用单点提取方案 + """ + if load_from is None: + if w is None or h is None: + print("It will damage your performance if you don't set w and h and use single_pick_mode!") + raise ValueError("w or h is None") + self.pur, self.p1, self.k = pur, p1, 1 + self.w, self.h, self.n = w, h, n + self.ww, self.hh = None, None + self.width = None + self._single_pick = single_pick_mode + self.set_purity(self.pur) + self.change_pick_mode(single_pick_mode) + self.model = LogisticRegression(C=1e5) + else: + self.load(load_from) + self.isCorrect = False + self.correct_color = None + self.log = utils.Logger(is_to_file=debug_mode) + self.debug_mode = debug_mode + self.image_num = 0 + + def change_pick_mode(self, single_pick_mode): + """ + 更改图像提取方法, + + :param single_pick_mode:若True, 则为单点随机抽取模式 + :return: None + """ + w, h, n = self.w, self.h, self.n + + if single_pick_mode: + self._single_pick = True + width = int(np.floor(np.sqrt(w * h / n))) + w0, h0 = np.arange(0, w-width, width), np.arange(0, h-width, width) + self.ww, self.hh = np.meshgrid(w0, h0) + self.width = width + else: + self._single_pick = False + ratio = np.sqrt(n / (w * h)) + self.ww, self.hh = int(ratio * w), int(ratio * h) + + def get_rand_sample(self, x): + """ + 在图像中进行随机抽取,如果single_pick_mode为True,则为真随机抽取,反之为假的抽取. + + :param x: + :return: + """ + if self._single_pick: + offset_w, offset_h = np.random.randint(0, self.width), np.random.randint(0, self.width) + sample = x[self.hh+offset_h, self.ww+offset_w, ...] 
+ else: + sample = cv2.resize(x, (self.ww, self.hh)) + return sample + + def fit_pictures(self, data_path=ROOT_DIR): + """ + 根据给出的data_path 进行 fit.如果没有给出data目录,那么将会使用当前文件夹 + :param data_path: + :return: + """ + # 训练数据文件位置 + result = self.get_train_data(data_path) + if result is False: + return 0 + x, y = result + score = self.fit(x, y) + self.save() + return score + + def fit(self, x, y, test_size=0.1): + x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0) + self.model.fit(x_train, y_train) + y_pred = self.model.predict(x_test) + pre_score = accuracy_score(y_test, y_pred) + self.log.log("Test accuracy is:"+str(pre_score * 100) + "%.") + y_pred = self.model.predict(x_train) + pre_score = accuracy_score(y_train, y_pred) + self.log.log("Train accuracy is:"+str(pre_score * 100) + "%.") + y_pred = self.model.predict(x) + pre_score = accuracy_score(y, y_pred) + self.log.log("Total accuracy is:"+str(pre_score * 100) + "%.") + return int(pre_score*100) + + def calculate_p1(self, x, remove_background=False): + """ + + :param x: + :param remove_background: + :return: + """ + if remove_background: + x = self.remove_background(x) + kmeans = KMeans(n_clusters=2, init='k-means++') + kmeans.fit(x) + result = kmeans.predict(x) # 聚类结果 + + def predict(self, img): + """ + :param img: 输入图像 + :return: 分类值 + """ + if self.debug_mode: + cv2.imwrite(str(self.image_num) + ".bmp", img) + self.image_num += 1 + feature = self.extract_feature(img, remove_background=False, debug_mode=False) + feature = feature.reshape(1, -1)[:, [1, 2]] + if self.isCorrect: + feature = feature / (self.correct_color+1e-4) + pred_color = self.model.predict(feature) + if self.debug_mode: + self.log.log(feature) + return int(pred_color[0]) + + def correct(self, img=None, img_path=None): + """ + 记录校准值,将校准值记录到类别内 + :param img: 校准板图片 + :param img_path: 用于校准的图片路径 + :return: 0 if correct success, 1 if failed + """ + if img is None: + path = os.path.join(ROOT_DIR, "data", 
"correct") + utils.mkdir_if_not_exist(path) + file_list = os.listdir(path) + if len(file_list) == 0: + return 1 + if img_path is None: + file_name = os.path.join(ROOT_DIR, "data", "correct", file_list[-1]) + else: + file_name = img_path + img = cv2.imread(file_name) + + feature = self.extract_feature(img)[FEATURE_INDEX] + self.correct_color = feature + self.isCorrect = True + self.log.log("Correct Successfully!") + return 0 + + def set_purity(self, purity): + self.pur = purity + vs_pur = 1 - self.pur + for i in range(self.n): + vs_pur_i = binom.cdf(k=i, p=self.p1, n=self.n) + if vs_pur_i > vs_pur: + self.k = i + return i + + def remove_background(self, x): + # TODO: 利用色度?饱和度或者明度?亮度?去除背景 + # 去背景的方法效果不好,太慢了,所以没弄了。。。 + # x = x[2000:16000, 300:1600, :] + x = x[2000:10000, 300:1600, :] + return x + + def save(self, file_name=None): + """ + 保存当前文件下的classify.model文件模型 + save_parameters 为要保存的参数 + :return: None + """ + if file_name is None: + file_name = "model_" + time.strftime("%Y-%m-%d_%H-%M") + ".p" + file_name = os.path.join(ROOT_DIR, "models", file_name) + model_dic = {"n": self.n, "k": self.k, "p1": self.p1, "pur": self.pur, "model": self.model, + "ww": self.ww, "hh": self.hh, "width": self.width, "w": self.w, "h": self.h, + "mode": self._single_pick, "isCorrect": self.isCorrect} + with open(file_name, "wb") as f: + pickle.dump(model_dic, f) + self.log.log("Save file to '" + str(file_name) + "'") + + def load(self, path=None): + if path is None: + path = os.path.join(ROOT_DIR, "models") + utils.mkdir_if_not_exist(path) + model_files = os.listdir(path) + if len(model_files) == 0: + self.log.log("No model found!") + return 1 + self.log.log("./ Models Found:") + _ = [self.log.log("├--"+str(model_file)) for model_file in model_files] + file_times = [model_file[6:-2] for model_file in model_files] + latest_model = model_files[int(np.argmax(file_times))] + self.log.log("└--Using the latest model: "+str(latest_model)) + path = os.path.join(ROOT_DIR, "models", 
str(latest_model)) + with open(path, "rb") as f: + model_dic = pickle.load(f) + self.n, self.k, self.p1, self.pur = model_dic["n"], model_dic["k"], model_dic["p1"], model_dic["pur"] + self.ww, self.hh, self.width = model_dic["ww"], model_dic["hh"], model_dic["width"] + self.w, self.h, self.model = model_dic["w"], model_dic["h"], model_dic["model"] + self.isCorrect = model_dic["isCorrect"] + self._single_pick = model_dic["mode"] + self.set_purity(self.pur) + self.change_pick_mode(self._single_pick) + return 0 + + def extract_feature(self, x, correct_color=False, remove_background=False, debug_mode=False): + """ + 获取图片的特征,色彩值的mean和var【l, a, b, s_l, s_a, s_b】. + :param x: 图片 + :param correct_color: 是否进行颜色校准 + :param remove_background:是否需要移除背景 + :param debug_mode: 是否使用debug模式 + :return: + """ + if remove_background: + x = self.remove_background(x) + x = self.get_rand_sample(x) + if correct_color is True: + x = x / self.correct_color + if debug_mode: + plt.figure() + plt.subplot(211) + plt.imshow(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) + x_hsv = cv2.cvtColor(x, cv2.COLOR_BGR2HSV) + x = cv2.cvtColor(x, cv2.COLOR_BGR2LAB) + x = np.concatenate((x, x_hsv), axis=2) + x = np.reshape(x, (x.shape[0]*x.shape[1], x.shape[2])) + x = x[np.argsort(x[:, 0])] + x = x[-self.k:, :] + if debug_mode: + # self.log.log(x) + self.log.log(x.shape) + # self.log.log(self.k) + # self.log.log(x) + self.log.log(x.shape) + mean_value = np.mean(x, axis=0) + if debug_mode: + self.log.log("mean color:"+str(mean_value)) + plt.subplot(212) + color_img = np.asarray(np.ones((100, 100, 3), dtype=np.uint8) * mean_value[:3], dtype=np.uint8) + color_img = cv2.cvtColor(color_img, cv2.COLOR_LAB2RGB) + plt.imshow(color_img) + plt.show() + var_value = np.var(x, axis=0) + feature = np.hstack((mean_value, var_value)) + if debug_mode: + self.log.log("var: "+str(var_value)) + return feature + + def get_image_data(self, img_dir="./data/dark"): + """ + :param img_dir: 图像文件的路径 + :return: 图像数据 + """ + img_data = [] + 
utils.mkdir_if_not_exist(img_dir) + files = os.listdir(img_dir) + if len(files) == 0: + return False + for file in files: + path = os.path.join(img_dir, file) + if self.debug_mode: + self.log.log(path) + train_img = cv2.imread(path) + data = self.extract_feature(train_img) + img_data.append(data) + img_data = np.array(img_data) + return img_data + + def get_train_data(self, data_dir=None, plot_2d=True, plot_data_3d=False, save_data=False): + """ + 获取图像数据 + :return: x_data, y_data + """ + data_dir = ROOT_DIR if data_dir is None else data_dir + dark_data = self.get_image_data(img_dir=os.path.join(data_dir, "data", "dark")) + middle_data = self.get_image_data(img_dir=os.path.join(data_dir, "data", "middle")) + light_data = self.get_image_data(img_dir=os.path.join(data_dir, "data", "light")) + if (dark_data is False) or (middle_data is False) or (light_data is False): + return False + x_data = np.vstack((dark_data, middle_data, light_data)) + dark_label = np.zeros(len(dark_data)).T + middle_label = np.ones(len(middle_data)).T + light_label = 2 * np.ones(len(light_data)).T + y_data = np.hstack((dark_label, middle_label, light_label)) + x_data = x_data[:, FEATURE_INDEX] + # 进行色彩数据校正 + if self.isCorrect: + x_data = x_data / (self.correct_color+1e-4) + if plot_data_3d: + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1, projection="3d") + ax.scatter(x_data[:, 1], x_data[:, 2], x_data[:, 0], c=y_data, edgecolors="k") + ax.set_xlabel("a*") + ax.set_ylabel("b*") + ax.set_zlabel("l") + plt.show() + if plot_2d: + plt.figure() + plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data) + plt.show() + # 尝试最合适的特征组合,保存提取出的特征的方法 + # 0: l, 1: a, 2: b, 3: var(l), 4: var(a), 5: var(s), 6: h, 7: s, 8: v, 9: var(h) 10: var(s): 11: var(v) + # 全部:0.941 + # [:, [0, 1, 2, 3, 4, 5, 6, 7]] : 0.88 + # [:, [0, 1, 2]] : 0.911 + # [:, [0, 1, 2, 6, 7, 8]] : 0.941 + # [:, [1, 2, 6, 7, 8]] : 0.9705 + # [:, [1, 2, 6, 7]] : 0.9705 + # [:, [1, 2, 4, 5, 6, 7]] : 0.941 + # [:, [0, 1, 2, 6, 7]] : 0.8529 + if 
save_data: + with open(os.path.join("data", "data.p"), "rb") as f: + pass + return x_data, y_data + + +if __name__ == '__main__': + # 初始化wood + wood = WoodClass(w=2048, h=12450, n=5000, debug_mode=False) + print("色彩纯度控制量{}/{}".format(wood.k, wood.n)) + wood.correct() + # wood.load() + # fit 相应的文件夹 + wood.fit_pictures(data_path=r"C:\Users\Administrator.DESKTOP-K75IPPC\Desktop\data1108") + + # 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的 + pic = cv2.imread(r"./data/dark/15.bmp") + start_time = time.time() + for i in range(100): + wood_color = wood.predict(pic) + end_time = time.time() + print("time consume:"+str((end_time - start_time)/100)) + print("wood_color:"+str(wood_color)) + diff --git a/pics/TODO.png b/pics/TODO.png new file mode 100755 index 0000000..2a6bbea Binary files /dev/null and b/pics/TODO.png differ diff --git a/pics/result.png b/pics/result.png new file mode 100755 index 0000000..743e538 Binary files /dev/null and b/pics/result.png differ diff --git a/pics/从木板中提取色彩.png b/pics/从木板中提取色彩.png new file mode 100755 index 0000000..da8f63e Binary files /dev/null and b/pics/从木板中提取色彩.png differ diff --git a/pics/色彩分类.png b/pics/色彩分类.png new file mode 100755 index 0000000..521035e Binary files /dev/null and b/pics/色彩分类.png differ diff --git a/root_dir.py b/root_dir.py new file mode 100755 index 0000000..3c59a52 --- /dev/null +++ b/root_dir.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +""" +Created on Nov 3 21:18:26 2020 + +@author: l.z.y +@e-mail: li.zhenye@qq.com +""" +import os + +ROOT_DIR = r"C:\Users\Administrator.DESKTOP-K75IPPC\Desktop\wood-color" diff --git a/utils.py b/utils.py new file mode 100755 index 0000000..eec4a8f --- /dev/null +++ b/utils.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +""" +Created on Nov 3 21:18:26 2020 + +@author: l.z.y +@e-mail: li.zhenye@qq.com +""" +import os +import shutil +import time + + +def mkdir_if_not_exist(dir_name, is_delete=False): + """ + 创建文件夹 + :param dir_name: 文件夹 + :param is_delete: 是否删除 + :return: 
是否成功 + """ + try: + if is_delete: + if os.path.exists(dir_name): + shutil.rmtree(dir_name) + print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name) + + if not os.path.exists(dir_name): + os.makedirs(dir_name) + print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name) + return True + except Exception as e: + print('[Exception] %s' % e) + return False + + +def create_file(file_name): + """ + 创建文件 + :param file_name: 文件名 + :return: None + """ + if os.path.exists(file_name): + print("文件存在:%s" % file_name) + return False + # os.remove(file_name) # 删除已有文件 + if not os.path.exists(file_name): + print("文件不存在,创建文件:%s" % file_name) + open(file_name, 'a').close() + return True + + +class Logger(object): + def __init__(self, is_to_file=False, path=None): + self.is_to_file = is_to_file + if path is None: + path = "wood.log" + self.path = path + create_file(path) + + def log(self, content): + if self.is_to_file: + with open(self.path, "a") as f: + print(time.strftime("[%Y-%m-%d_%H-%M-%S]:"), file=f) + print(content, file=f) + else: + print(content) + + +if __name__ == '__main__': + log = Logger(is_to_file=True) + log.log("nihao") + import numpy as np + a = np.ones((100, 100, 3)) + log.log(a.shape)