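"""Decision-tree classifier for hyperspectral identification of huangqi
(astragalus) slices against background and impurities.

Trains on labeled pixel spectra read from an ENVI ASCII export, persists
the fitted model with pickle, and smooths per-pixel predictions with a
connected-component majority vote.
"""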

import logging
import os
import pickle
import time

import numpy as np
from scipy import ndimage
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

import utils
from root_dir import ROOT_DIR


class Astragalin(object):
    def __init__(self, load_from=None, debug_mode=False, class_weight=None):
        # Set up the logger before load() so that load() can report progress.
        self.log = utils.Logger(is_to_file=debug_mode)
        self.debug_mode = debug_mode
        if load_from is None:
            self.model = DecisionTreeClassifier(random_state=65, class_weight=class_weight)
        else:
            self.load(load_from)
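        # class_weight is forwarded to DecisionTreeClassifier above; passing
        # e.g. class_weight='balanced' (a standard scikit-learn option)
        # reweights classes inversely to their frequencies, which can help
        # when background pixels dominate the training data.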
    def load(self, path=None):
        if path is None:
            path = os.path.join(ROOT_DIR, 'models')
            model_files = os.listdir(path)
            if len(model_files) == 0:
                self.log.log("No model found!")
                return 1
            self.log.log("./ Models Found:")
            _ = [self.log.log("├--" + str(model_file)) for model_file in model_files]
            # File names follow "model_%Y-%m-%d_%H-%M.p", so the timestamp
            # slice sorts lexicographically in chronological order.
            file_times = [model_file[6:-2] for model_file in model_files]
            latest_model = model_files[file_times.index(max(file_times))]
            self.log.log("└--Using the latest model: " + str(latest_model))
            path = os.path.join(ROOT_DIR, "models", str(latest_model))
        if not os.path.isabs(path):
            logging.warning('A relative path was given')
            return -1
        if not os.path.exists(path):
            logging.warning('File does not exist')
            return -1
        with open(path, 'rb') as f:
            model_dic = pickle.load(f)
        self.model = model_dic['model']
        return 0
    def fit(self, data_x, data_y):
        x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.3, random_state=65)
        self.model.fit(x_train, y_train)
        y_pred = self.model.predict(x_test)
        print(confusion_matrix(y_test, y_pred))
        pre_score = accuracy_score(y_test, y_pred)
        self.log.log("Test accuracy is: " + str(pre_score * 100) + "%.")
        y_pred = self.model.predict(x_train)
        pre_score = accuracy_score(y_train, y_pred)
        self.log.log("Train accuracy is: " + str(pre_score * 100) + "%.")
        y_pred = self.model.predict(data_x)
        pre_score = accuracy_score(data_y, y_pred)
        self.log.log("Total accuracy is: " + str(pre_score * 100) + "%.")
        # The returned score is the accuracy on the full dataset (train + test).
        return int(pre_score * 100)
    def fit_value(self, file_name=None, data_path='data/1.txt',
                  select_bands=[91, 92, 93, 94, 95, 96, 97, 98, 99, 100]):
        data_x, data_y = self.data_construction(data_path, select_bands)
        score = self.fit(data_x, data_y)
        print('score:', score)
        model_name = self.save(file_name=file_name)
        return score, model_name
    def save(self, file_name=None):
        # Save the model to the models directory, timestamping the file name
        # when none is given.
        if file_name is None:
            file_name = "model_" + time.strftime("%Y-%m-%d_%H-%M") + ".p"
        file_name = os.path.join(ROOT_DIR, "models", file_name)
        model_dic = {'model': self.model}
        with open(file_name, 'wb') as f:
            pickle.dump(model_dic, f)
        self.log.log("Model saved to '" + str(file_name) + "'.")
        return file_name
    # Class labels follow the mapping documented in predict():
    # 0 = beijing (background), 1 = huangqi, 2 = zazhi2, 3 = zazhi1,
    # 4 = gancaopian, 5 = hongqi.
    LABEL_MAP = {'beijing': 0, 'huangqi': 1, 'zazhi2': 2, 'zazhi1': 3,
                 'gancaopian': 4, 'hongqi': 5}

    def data_construction(self, data_path='data/1.txt',
                          select_bands=[91, 92, 93, 94, 95, 96, 97, 98, 99, 100],
                          types=['beijing', 'zazhi1', 'huangqi', 'gancaopian']):
        '''
        :param data_path: path to the data file
        :param select_bands: spectral bands to keep
        :param types: class names to load from the data file
        :return: data_x, data_y
        '''
        data = utils.read_envi_ascii(data_path)
        # The file must contain at least the 'beijing' and 'huangqi' classes
        if 'beijing' not in data or 'huangqi' not in data:
            logging.error("Data file is missing the 'beijing' or 'huangqi' class labels")
            raise ValueError("Data file is missing the 'beijing' or 'huangqi' class labels")
        data_x = np.concatenate([data[key][:, select_bands] for key in types], axis=0)
        data_y = np.concatenate([np.full(data[key].shape[0], self.LABEL_MAP[key]) for key in types], axis=0)
        return data_x, data_y
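    # Note: utils.read_envi_ascii is assumed to return a dict mapping each
    # ENVI ROI class name (e.g. 'beijing', 'huangqi') to an
    # (n_pixels, n_bands) array, as implied by the indexing in
    # data_construction above.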
    def predict(self, data_x):
        '''
        Run the classifier over a hyperspectral cube.
        :param data_x: band-selected data of shape (height, width, n_bands)
        :return: dict with keys 'data_y' (the (height, width) label map) and
                 'centers'/'categories' (one entry per connected region),
                 where 0 = background, 1 = huangqi, 2 = zazhi2 (impurity 2),
                 3 = zazhi1 (impurity 1), 4 = gancaopian (licorice slice),
                 5 = hongqi (red astragalus)
        '''
        data_x_shape = data_x.shape
        data_x = data_x.reshape(-1, data_x.shape[2])
        data_y = self.model.predict(data_x)
        data_y = data_y.reshape(data_x_shape[0], data_x_shape[1]).astype(np.uint8)
        data_y, centers, categories = self.connect_space(data_y)
        result = {'data_y': data_y, 'centers': centers, 'categories': categories}
        return result
    def connect_space(self, data_y):
        # Clean up isolated pixels: relabel every connected region to its
        # majority class and record each region's center of mass.
        labels, num_features = ndimage.label(data_y)
        centers = []
        categories = []
        for i in range(1, num_features + 1):
            mask = (labels == i)
            counts = np.bincount(data_y[mask])
            category = np.argmax(counts)
            data_y[mask] = category
            # The ndimage.measurements namespace is deprecated in recent
            # SciPy; the top-level function is equivalent.
            center = np.array(ndimage.center_of_mass(data_y, labels, i)).astype(int)
            centers.append(center)
            categories.append(category)
        return data_y, centers, categories


if __name__ == '__main__':
    detector = Astragalin()
    detector.fit_value(file_name="astragalin.p", data_path="data/1.txt")
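
    # Minimal prediction sketch (hypothetical: assumes `cube` is a
    # band-selected hyperspectral image of shape (H, W, len(select_bands))
    # obtained elsewhere):
    #
    #   detector = Astragalin(load_from=os.path.join(ROOT_DIR, "models", "astragalin.p"))
    #   result = detector.predict(cube)
    #   label_map = result['data_y']  # (H, W) uint8 class map
    #   for center, category in zip(result['centers'], result['categories']):
    #       print("region at", center, "classified as", category)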