mirror of https://github.com/NanjingForestryUniversity/supermachine--tomato-passion_fruit.git (synced 2025-11-09 06:44:02 +00:00)

import logging
import os
import pickle
import time

import numpy as np
from scipy import ndimage
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

import utils
from root_dir import ROOT_DIR


class Astragalin(object):
    # Fixed label scheme shared by data_construction() and predict():
    # 0 = beijing (background), 1 = huangqi, 2 = zazhi2, 3 = zazhi1,
    # 4 = gancaopian, 5 = hongqi.
    LABELS = {'beijing': 0, 'huangqi': 1, 'zazhi2': 2, 'zazhi1': 3,
              'gancaopian': 4, 'hongqi': 5}

    def __init__(self, load_from=None, debug_mode=False, class_weight=None):
        # Create the logger before load(), which logs through self.log.
        self.log = utils.Logger(is_to_file=debug_mode)
        self.debug_mode = debug_mode
        if load_from is None:
            self.model = DecisionTreeClassifier(random_state=65, class_weight=class_weight)
        else:
            self.load(load_from)

    def load(self, path=None):
        if path is None:
            path = os.path.join(ROOT_DIR, 'models')
        if os.path.isdir(path):
            # Pick the newest model by the timestamp embedded in its file
            # name ("model_YYYY-MM-DD_HH-MM.p").
            model_files = os.listdir(path)
            if len(model_files) == 0:
                self.log.log("No model found!")
                return 1
            self.log.log("./ Models Found:")
            _ = [self.log.log("├--" + str(model_file)) for model_file in model_files]
            file_times = [model_file[6:-2] for model_file in model_files]
            latest_model = model_files[int(np.argmax(file_times))]
            self.log.log("└--Using the latest model: " + str(latest_model))
            path = os.path.join(path, str(latest_model))
        if not os.path.isabs(path):
            logging.warning('a relative path was given')
            return -1
        if not os.path.exists(path):
            logging.warning('model file does not exist')
            return -1
        with open(path, 'rb') as f:
            model_dic = pickle.load(f)
        self.model = model_dic['model']
        return 0

    def fit(self, data_x, data_y):
        x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.3, random_state=65)
        self.model.fit(x_train, y_train)

        y_pred = self.model.predict(x_test)
        print(confusion_matrix(y_test, y_pred))
        pre_score = accuracy_score(y_test, y_pred)
        self.log.log("Test accuracy is: " + str(pre_score * 100) + "%.")

        y_pred = self.model.predict(x_train)
        pre_score = accuracy_score(y_train, y_pred)
        self.log.log("Train accuracy is: " + str(pre_score * 100) + "%.")

        y_pred = self.model.predict(data_x)
        pre_score = accuracy_score(data_y, y_pred)
        self.log.log("Total accuracy is: " + str(pre_score * 100) + "%.")

        return int(pre_score * 100)

    def fit_value(self, file_name=None, data_path='data/1.txt', select_bands=[91, 92, 93, 94, 95, 96, 97, 98, 99, 100]):
        data_x, data_y = self.data_construction(data_path, select_bands)
        score = self.fit(data_x, data_y)
        print('score:', score)
        model_name = self.save(file_name=file_name)
        return score, model_name

    def save(self, file_name=None):
        # Save the model to the models directory, stamping the file name
        # with the current time when none is given.
        if file_name is None:
            file_name = "model_" + time.strftime("%Y-%m-%d_%H-%M") + ".p"
        file_name = os.path.join(ROOT_DIR, "models", file_name)
        model_dic = {'model': self.model}
        with open(file_name, 'wb') as f:
            pickle.dump(model_dic, f)
        self.log.log("Model saved to '" + str(file_name) + "'.")
        return file_name

    def data_construction(self, data_path='data/1.txt', select_bands=[91, 92, 93, 94, 95, 96, 97, 98, 99, 100],
                          class_names=['beijing', 'zazhi1', 'huangqi', 'gancaopian']):
        '''
        :param data_path: path to the data file
        :param select_bands: selected bands
        :param class_names: class names to read from the file
        :return: data_x, data_y
        '''
        data = utils.read_envi_ascii(data_path)
        # The file must contain at least the 'beijing' and 'huangqi' classes.
        if 'beijing' not in data or 'huangqi' not in data:
            logging.error("data file is missing the 'beijing' or 'huangqi' class label")
            raise ValueError("data file is missing the 'beijing' or 'huangqi' class label")
        data_x = np.concatenate([data[key][:, select_bands] for key in class_names], axis=0)
        # Label every sample with its fixed class id from LABELS, so the
        # mapping does not depend on the order of class_names.
        data_y = np.concatenate([np.full(data[key].shape[0], self.LABELS[key]) for key in class_names], axis=0)
        return data_x, data_y

    def predict(self, data_x):
        '''
        Predict on an image cube.
        :param data_x: band-selected data, shape (height, width, n_bands)
        :return: dict with the binarized prediction map: 0 = background,
                 1 = huangqi, 2 = zazhi2, 3 = zazhi1, 4 = gancaopian, 5 = hongqi
        '''
        data_x_shape = data_x.shape
        data_x = data_x.reshape(-1, data_x_shape[2])
        data_y = self.model.predict(data_x)
        data_y = data_y.reshape(data_x_shape[0], data_x_shape[1]).astype(np.uint8)
        data_y, centers, categories = self.connect_space(data_y)
        result = {'data_y': data_y, 'centers': centers, 'categories': categories}
        return result

    def connect_space(self, data_y):
        # Clean up isolated pixels: label the connected components of the
        # prediction map, assign each component the majority class of its
        # pixels, and record each component's integer centre of mass.
        labels, num_features = ndimage.label(data_y)
        centers = []
        categories = []
        for i in range(1, num_features + 1):
            mask = (labels == i)
            counts = np.bincount(data_y[mask])
            category = np.argmax(counts)
            data_y[mask] = category
            center = np.array(ndimage.center_of_mass(data_y, labels, i)).astype(int)
            centers.append(center)
            categories.append(category)
        return data_y, centers, categories


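def _connect_space_demo():
    # Illustrative sketch, not part of the original pipeline: shows how
    # connect_space() majority-votes each connected component of a label
    # map. The 3x3 map below is made up; the lone '3' inside a blob of
    # '1's is absorbed into the majority class of its component.
    label_map = np.array([[1, 1, 0],
                          [1, 3, 0],
                          [1, 1, 0]], dtype=np.uint8)
    cleaned, centers, categories = Astragalin().connect_space(label_map)
    print(cleaned)     # the '3' pixel has become 1
    print(centers)     # one integer centre of mass per component
    print(categories)  # majority class of each component

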
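def _data_construction_demo():
    # Illustrative sketch, not part of the original pipeline: exercises the
    # fixed label scheme without the real ENVI text file by stubbing out
    # utils.read_envi_ascii. All names and shapes here are made up.
    fake_data = {'beijing': np.random.rand(4, 128), 'huangqi': np.random.rand(3, 128)}
    real_reader = utils.read_envi_ascii
    utils.read_envi_ascii = lambda path: fake_data  # stub the file read
    try:
        data_x, data_y = Astragalin().data_construction(
            select_bands=[0, 1, 2], class_names=['beijing', 'huangqi'])
    finally:
        utils.read_envi_ascii = real_reader  # restore the real reader
    print(data_x.shape)  # (7, 3): 4 + 3 samples, 3 selected bands
    print(data_y)        # [0 0 0 0 1 1 1]

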
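def _predict_demo(cube_path='cube.npy'):
    # Illustrative inference sketch, not part of the original pipeline.
    # Assumes a model was already saved under models/ (e.g. by fit_value) and
    # that cube_path names a hypothetical (height, width, bands) reflectance
    # cube stored with np.save; neither file ships with the repository.
    detector = Astragalin()
    if detector.load() != 0:  # load() picks the newest model under models/
        return
    cube = np.load(cube_path)
    bands = [91, 92, 93, 94, 95, 96, 97, 98, 99, 100]
    result = detector.predict(cube[:, :, bands])
    print(result['data_y'].shape)  # cleaned (height, width) label map
    print(result['categories'])    # majority class per connected blob

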
if __name__ == '__main__':
    detector = Astragalin()
    detector.fit_value(file_name="astragalin.p", data_path="data/1.txt")