SCNet/utils.py
karllzy 19d8a7c588 Add visualization
add visualization for our network
2023-03-28 10:49:37 +08:00

154 lines
6.2 KiB
Python

from scipy.io import loadmat
import numpy as np
from sklearn.model_selection import train_test_split
import os
import shutil
def load_data(data_path='./pine_water_cc.mat', validation_rate=0.25):
    """Load a regression dataset stored in a MATLAB ``.mat`` file.

    Two key layouts are understood:

    * "pine water" layout: inputs ``DL_train``/``DL_test``, targets
      ``value_train``/``value_test``, bounds ``value_max``/``value_min``.
      It is used for ``'./pine_water_cc.mat'`` and, as a fallback, for any file
      that has a ``value_train`` key but no ``y_train`` key. Inputs are returned
      exactly as stored.
    * generic layout: ``x_train``/``x_test``, ``y_train``/``y_test``,
      ``max_y``/``min_y``. A singleton axis is inserted at position 1 of both
      input arrays (``np.expand_dims(..., axis=1)``).

    For ``'./N_100_leaf_cc.mat'`` the test split is reused as the validation
    split and nothing is printed. For every other file, ``validation_rate`` of
    the training rows are held out with ``train_test_split(..., random_state=8)``.

    :param data_path: path of the ``.mat`` file.
    :param validation_rate: fraction of the training rows held out for
        validation (ignored for ``'./N_100_leaf_cc.mat'``).
    :return: ``(x_train, x_test, x_validation, y_train, y_test, y_validation,
        y_max_value, y_min_value)``.
    """
    data = loadmat(data_path)

    # The key check only changes files that previously crashed with
    # KeyError('y_train') in the generic branch.
    use_pine_layout = data_path == './pine_water_cc.mat' or (
        'value_train' in data and 'y_train' not in data)

    if use_pine_layout:
        y_train, y_test = data['value_train'], data['value_test']
        print('Value train shape: ', y_train.shape, 'Value test shape', y_test.shape)
        y_max_value, y_min_value = data['value_max'], data['value_min']
        x_train, x_test = data['DL_train'], data['DL_test']
    else:
        y_train, y_test = data['y_train'], data['y_test']
        x_train, x_test = data['x_train'], data['x_test']
        y_max_value, y_min_value = data['max_y'], data['min_y']
        x_train = np.expand_dims(x_train, axis=1)
        x_test = np.expand_dims(x_test, axis=1)
        if data_path == './N_100_leaf_cc.mat':
            # This dataset has no dedicated validation split: reuse the test split.
            return x_train, x_test, x_test, y_train, y_test, y_test, y_max_value, y_min_value

    # NOTE: the "SG17 DATA" label is printed for every dataset that reaches here.
    print('SG17 DATA train shape: ', x_train.shape, 'SG17 DATA test shape', x_test.shape)
    print('Mini value: %s, Max value %s.' % (y_min_value, y_max_value))
    x_train, x_validation, y_train, y_validation = train_test_split(
        x_train, y_train, test_size=validation_rate, random_state=8)
    return x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value
def mkdir_if_not_exist(dir_name, is_delete=False):
    """Ensure ``dir_name`` exists as a directory, optionally recreating it empty.

    :param dir_name: path of the directory to create.
    :param is_delete: if True and ``dir_name`` already exists, remove it (and
        everything under it) first, so the directory ends up empty.
    :return: True on success; False if any step raised (the exception is printed).
    """
    try:
        if is_delete and os.path.exists(dir_name):
            shutil.rmtree(dir_name)
            print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)
        # Check isdir rather than exists: an existing regular file must not be
        # reported as success. makedirs then raises FileExistsError, handled below.
        # exist_ok=True tolerates another process creating the directory meanwhile.
        if not os.path.isdir(dir_name):
            os.makedirs(dir_name, exist_ok=True)
            print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
        return True
    except Exception as e:  # deliberate best-effort: report and signal failure
        print('[Exception] %s' % e)
        return False
class Config:
    """Experiment settings: data split, training, which models to train/evaluate,
    output directories and dataset path.

    Every setting is a plain attribute assigned in ``__init__``; change it on an
    instance before use. ``show_yourself`` renders the settings as text.
    """

    def __init__(self):
        # Data parameters
        self.validation_rate = 0.2
        # Training parameters
        self.train_epoch = 20000
        self.batch_size = 20
        # Which models to train
        self.train_cnn = True
        self.train_ms_cnn = True
        self.train_ms_sc_cnn = True
        # Which models to evaluate
        self.evaluate_cnn = True
        self.evaluate_ms_cnn = True
        self.evaluate_ms_sc_cnn = True
        # Names of previously saved models to evaluate
        self.evaluate_cnn_name_list = []
        self.evaluate_ms_cnn_name_list = []
        self.evaluate_ms_sc_cnn_name_list = []
        # Directories for the saved pictures and model checkpoints
        self.img_dir = './pictures0331'
        self.checkpoint_dir = './check_points0331'
        # Dataset (.mat file) selection
        self.data_set = './dataset_preprocess/corn/corn_mositure.mat'

    def show_yourself(self, to_text_file=None):
        """Print a sectioned summary of this configuration and return it.

        :param to_text_file: optional path; when given, the summary is also
            written there as UTF-8 (overwriting any existing file).
        :return: the summary text: a leading blank line, then one header or
            setting per line, each terminated by a newline.
        """
        line_width = 36

        def header(title):
            # Frame the title with '=' runs; ljust supplies the odd filler
            # character so every header is exactly line_width wide.
            side = '=' * ((line_width - len(title)) // 2)
            return (side + title + side).ljust(line_width, '=')

        def evaluate_rows(label, enabled, saved_label, names):
            # One "Evaluate X" flag line, followed by the saved model names (if any).
            rows = [label + str(enabled)]
            if names:
                rows.append(saved_label)
                rows.extend(names)
            return rows

        rows = [
            header('Data Parameters'),
            'Validation Rate: ' + str(self.validation_rate),
            'Data Set: ' + str(self.data_set),
            header('Training Parameters'),
            'Train Epoch: ' + str(self.train_epoch),
            'Train Batch Size: ' + str(self.batch_size),
            'Train CNN: ' + str(self.train_cnn),
            'Train Ms CNN: ' + str(self.train_ms_cnn),
            'Train Ms Sc CNN: ' + str(self.train_ms_sc_cnn),
            header('Evaluate Parameters'),
        ]
        rows += evaluate_rows('Evaluate CNN: ', self.evaluate_cnn,
                              'Saved CNNs to Evaluate:', self.evaluate_cnn_name_list)
        rows += evaluate_rows('Evaluate Ms CNN: ', self.evaluate_ms_cnn,
                              'Saved Ms CNNs to Evaluate:', self.evaluate_ms_cnn_name_list)
        rows += evaluate_rows('Evaluate Ms Sc CNN: ', self.evaluate_ms_sc_cnn,
                              'Saved Ms Sc CNNs to Evaluate:', self.evaluate_ms_sc_cnn_name_list)
        rows += [
            header('Saving Dir'),
            'Image Dir: ' + str(self.img_dir),
            'Check Point Dir: ' + str(self.checkpoint_dir),
        ]

        content = '\n' + ''.join(row + '\n' for row in rows)
        print(content)
        if to_text_file:
            with open(to_text_file, 'w', encoding='utf-8') as f:
                f.write(content)
        return content
if __name__ == '__main__':
    # Manual smoke run: dump the default settings to name.txt, then load one
    # dataset and print the shapes of all six splits plus the target bounds.
    config = Config()
    config.show_yourself(to_text_file='name.txt')
    (x_train, x_test, x_validation,
     y_train, y_test, y_validation,
     y_max_value, y_min_value) = load_data(data_path='./yaowan_calibrate.mat',
                                           validation_rate=0.25)
    print(*(split.shape for split in (x_train, x_test, y_train, y_test,
                                      x_validation, y_validation)),
          y_max_value, y_min_value)