mirror of
https://github.com/NanjingForestryUniversity/SCNet.git
synced 2025-11-08 14:24:03 +00:00
154 lines
6.2 KiB
Python
154 lines
6.2 KiB
Python
from scipy.io import loadmat
|
|
import numpy as np
|
|
from sklearn.model_selection import train_test_split
|
|
import os
|
|
import shutil
|
|
|
|
|
|
def load_data(data_path='./pine_water_cc.mat', validation_rate=0.25):
    """
    Load train/test arrays from a MATLAB ``.mat`` file and derive a validation set.

    The file layout is selected by comparing ``data_path`` to exact strings:

    * ``'./pine_water_cc.mat'`` -- features under ``DL_train``/``DL_test``,
      targets under ``value_train``/``value_test``, target range under
      ``value_max``/``value_min``; features are returned as stored.
    * ``'./N_100_leaf_cc.mat'`` -- features under ``x_train``/``x_test``,
      targets under ``y_train``/``y_test``, target range under
      ``max_y``/``min_y``; a singleton axis is inserted at position 1 of the
      features, and the test set is returned as the validation set (no split).
    * any other path -- same keys and axis insertion as the N_100 layout, but
      the validation set is split off the training set.

    :param data_path: path of the ``.mat`` file to read.
    :param validation_rate: ``test_size`` passed to ``train_test_split`` when
        carving the validation set out of the training set; unused for the
        N_100 layout.
    :return: ``(x_train, x_test, x_validation, y_train, y_test, y_validation,
        y_max_value, y_min_value)``.
    """
    mat = loadmat(data_path)

    if data_path == './pine_water_cc.mat':
        y_train, y_test = mat['value_train'], mat['value_test']
        print('Value train shape: ', y_train.shape, 'Value test shape', y_test.shape)
        y_max_value, y_min_value = mat['value_max'], mat['value_min']
        x_train, x_test = mat['DL_train'], mat['DL_test']
    else:
        y_train, y_test = mat['y_train'], mat['y_test']
        x_train, x_test = mat['x_train'], mat['x_test']
        y_max_value, y_min_value = mat['max_y'], mat['min_y']
        # Insert a singleton axis at position 1 (presumably a channel axis for
        # the CNN models -- TODO confirm against the model definitions).
        x_train, x_test = np.expand_dims(x_train, axis=1), np.expand_dims(x_test, axis=1)

        if data_path == './N_100_leaf_cc.mat':
            # This dataset gets no separate validation split: reuse the test set.
            return x_train, x_test, x_test, y_train, y_test, y_test, y_max_value, y_min_value

        print('SG17 DATA train shape: ', x_train.shape, 'SG17 DATA test shape', x_test.shape)

    print('Mini value: %s, Max value %s.' % (y_min_value, y_max_value))

    # Fixed seed so the train/validation partition is reproducible across runs.
    x_train, x_validation, y_train, y_validation = train_test_split(
        x_train, y_train, test_size=validation_rate, random_state=8)

    return x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value
|
|
|
|
|
|
def mkdir_if_not_exist(dir_name, is_delete=False):
    """
    Make sure the directory ``dir_name`` exists, optionally recreating it empty.

    :param dir_name: directory path to create (missing parent directories are
        created too).
    :param is_delete: if True and ``dir_name`` already exists, delete it and all
        of its contents first, so a fresh empty directory is created.
    :return: True on success; False if any step raised (the exception is printed,
        not propagated) -- this includes the case where ``dir_name`` exists but is
        not a directory.
    """
    try:
        if is_delete and os.path.exists(dir_name):
            shutil.rmtree(dir_name)
            print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)

        # isdir (not exists): a regular file with this name must not count as success.
        if not os.path.isdir(dir_name):
            # exist_ok tolerates a concurrent creator of the same directory, but
            # still raises FileExistsError when the path is a non-directory.
            os.makedirs(dir_name, exist_ok=True)
            print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
        return True
    except Exception as e:
        # Best-effort helper: report the failure and signal it via the return value.
        print('[Exception] %s' % e)
        return False
|
|
|
|
|
|
class Config:
    """Hyper-parameters, run switches and output locations for an experiment run."""

    def __init__(self):
        # Data parameters
        self.validation_rate = 0.2
        # Training parameters
        self.train_epoch = 20000
        self.batch_size = 20
        # Which models to train
        self.train_cnn = True
        self.train_ms_cnn = True
        self.train_ms_sc_cnn = True
        # Which models to evaluate
        self.evaluate_cnn = True
        self.evaluate_ms_cnn = True
        self.evaluate_ms_sc_cnn = True
        # Lists of saved model names to evaluate
        self.evaluate_cnn_name_list = []
        self.evaluate_ms_cnn_name_list = []
        self.evaluate_ms_sc_cnn_name_list = []

        # Directories for the trained models and the generated figures
        self.img_dir = './pictures0331'
        self.checkpoint_dir = './check_points0331'

        # Dataset selection
        self.data_set = './dataset_preprocess/corn/corn_mositure.mat'

    @staticmethod
    def _banner(text, width=36):
        """Return ``text`` padded with '=' to exactly ``width`` chars (left pad rounds down)."""
        left = '=' * ((width - len(text)) // 2)
        # Pad the right side up to the full width; the original code computed this
        # ljust but discarded the result, leaving odd-length headers one char short.
        return (left + text).ljust(width, '=')

    def show_yourself(self, to_text_file=None):
        """
        Print a human-readable summary of this configuration.

        :param to_text_file: optional path; when given, the same summary is also
            written to that file (overwritten, utf-8).
        :return: the summary text; it starts with a newline and every line ends
            with a newline.
        """
        lines = [
            self._banner('Data Parameters'),
            'Validation Rate: ' + str(self.validation_rate),
            self._banner('Training Parameters'),
            # Epoch/batch size are training settings; they were previously
            # listed under the "Evaluate Parameters" header by mistake.
            'Train Epoch: ' + str(self.train_epoch),
            'Train Batch Size: ' + str(self.batch_size),
            'Train CNN: ' + str(self.train_cnn),
            'Train Ms CNN: ' + str(self.train_ms_cnn),
            'Train Ms Sc CNN: ' + str(self.train_ms_sc_cnn),
            self._banner('Evaluate Parameters'),
        ]

        evaluate_sections = (
            ('Evaluate CNN', self.evaluate_cnn,
             'Saved CNNs to Evaluate', self.evaluate_cnn_name_list),
            ('Evaluate Ms CNN', self.evaluate_ms_cnn,
             'Saved Ms CNNs to Evaluate', self.evaluate_ms_cnn_name_list),
            # Uses evaluate_ms_sc_cnn (previously printed evaluate_ms_cnn by mistake).
            ('Evaluate Ms Sc CNN', self.evaluate_ms_sc_cnn,
             'Saved Ms Sc CNNs to Evaluate', self.evaluate_ms_sc_cnn_name_list),
        )
        for label, flag, saved_label, names in evaluate_sections:
            lines.append(label + ': ' + str(flag))
            if names:
                lines.append(saved_label + ':')
                lines.extend(str(name) for name in names)

        lines.append(self._banner('Saving Dir'))
        lines.append('Image Dir: ' + str(self.img_dir))
        # Uses checkpoint_dir (previously printed img_dir by mistake).
        lines.append('Check Point Dir: ' + str(self.checkpoint_dir))

        content = '\n' + ''.join(line + '\n' for line in lines)
        print(content)
        if to_text_file:
            with open(to_text_file, 'w', encoding='utf-8') as f:
                f.write(content)
        return content
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke run: dump the default configuration, then load a sample dataset
    # and report the resulting array shapes and target range.
    cfg = Config()
    cfg.show_yourself(to_text_file='name.txt')

    (x_train, x_test, x_validation,
     y_train, y_test, y_validation,
     y_max_value, y_min_value) = load_data(data_path='./yaowan_calibrate.mat',
                                           validation_rate=0.25)

    print(x_train.shape, x_test.shape, y_train.shape, y_test.shape,
          x_validation.shape, y_validation.shape,
          y_max_value, y_min_value)
|