mirror of
https://github.com/NanjingForestryUniversity/SCNet.git
synced 2025-11-08 14:24:03 +00:00
First Commit
This commit is contained in:
commit
17e8992073
219
.gitignore
vendored
Normal file
219
.gitignore
vendored
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
preprocess/dataset/*
|
||||||
|
checkpoints/*
|
||||||
|
.idea
|
||||||
|
### JetBrains template
|
||||||
|
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
||||||
|
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||||
|
|
||||||
|
# User-specific stuff
|
||||||
|
.idea/**/workspace.xml
|
||||||
|
.idea/**/tasks.xml
|
||||||
|
.idea/**/usage.statistics.xml
|
||||||
|
.idea/**/dictionaries
|
||||||
|
.idea/**/shelf
|
||||||
|
|
||||||
|
# Generated files
|
||||||
|
.idea/**/contentModel.xml
|
||||||
|
|
||||||
|
# Sensitive or high-churn files
|
||||||
|
.idea/**/dataSources/
|
||||||
|
.idea/**/dataSources.ids
|
||||||
|
.idea/**/dataSources.local.xml
|
||||||
|
.idea/**/sqlDataSources.xml
|
||||||
|
.idea/**/dynamic.xml
|
||||||
|
.idea/**/uiDesigner.xml
|
||||||
|
.idea/**/dbnavigator.xml
|
||||||
|
|
||||||
|
# Gradle
|
||||||
|
.idea/**/gradle.xml
|
||||||
|
.idea/**/libraries
|
||||||
|
|
||||||
|
# Gradle and Maven with auto-import
|
||||||
|
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||||
|
# since they will be recreated, and may cause churn. Uncomment if using
|
||||||
|
# auto-import.
|
||||||
|
# .idea/artifacts
|
||||||
|
# .idea/compiler.xml
|
||||||
|
# .idea/jarRepositories.xml
|
||||||
|
# .idea/modules.xml
|
||||||
|
# .idea/*.iml
|
||||||
|
# .idea/modules
|
||||||
|
# *.iml
|
||||||
|
# *.ipr
|
||||||
|
|
||||||
|
# CMake
|
||||||
|
cmake-build-*/
|
||||||
|
|
||||||
|
# Mongo Explorer plugin
|
||||||
|
.idea/**/mongoSettings.xml
|
||||||
|
|
||||||
|
# File-based project format
|
||||||
|
*.iws
|
||||||
|
|
||||||
|
# IntelliJ
|
||||||
|
out/
|
||||||
|
|
||||||
|
# mpeltonen/sbt-idea plugin
|
||||||
|
.idea_modules/
|
||||||
|
|
||||||
|
# JIRA plugin
|
||||||
|
atlassian-ide-plugin.xml
|
||||||
|
|
||||||
|
# Cursive Clojure plugin
|
||||||
|
.idea/replstate.xml
|
||||||
|
|
||||||
|
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||||
|
com_crashlytics_export_strings.xml
|
||||||
|
crashlytics.properties
|
||||||
|
crashlytics-build.properties
|
||||||
|
fabric.properties
|
||||||
|
|
||||||
|
# Editor-based Rest Client
|
||||||
|
.idea/httpRequests
|
||||||
|
|
||||||
|
# Android studio 3.1+ serialized cache file
|
||||||
|
.idea/caches/build_file_checksums.ser
|
||||||
|
|
||||||
|
### Python template
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
!/checkpoints/
|
||||||
|
!/preprocess/dataset/
|
||||||
|
!/preprocess/dataset/
|
||||||
21
README.md
Normal file
21
README.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# SCNet: A deep learning network framework for analyzing near-infrared spectroscopy using short-cut
|
||||||
|
## Pre-processing
|
||||||
|
|
||||||
|
Since the method we proposed is a regression model, the classification dataset weat kernel is not used in this work.
|
||||||
|
|
||||||
|
The other three dataset (corn, marzipan, soil) were preprocessed manually with Matlab and saved in the sub dictionary of `./preprocess` dir. The original dataset of these three dataset were stored in the `./preprocess/dataset/`.
|
||||||
|
|
||||||
|
The mango dataset is not in Matlab .m file format, so we save them with the `process.py`.
|
||||||
|
Meanwhile, we drop the useless part and only save the data between 684 and 900 nm.
|
||||||
|
|
||||||
|
> The data set used in this study comprises a total of 11,691 NIR spectra (684–990 nm in 3 nm sampling with a total 103 variables) and DM measurements performed on 4675 mango fruit across 4 harvest seasons 2015, 2016, 2017 and 2018 [24].
|
||||||
|
|
||||||
|
The detailed preprocessing progress can be found in [./preprocess.ipynb](./preprocess.ipynb)
|
||||||
|
|
||||||
|
## Network Training
|
||||||
|
|
||||||
|
In order to show our network can prevent degration problem, we hold the experiment which contains the training loss curve of four models. The detailed information can be found in [model_training.ipynb](./model_training.ipynb).
|
||||||
|
|
||||||
|
## Network evaluation
|
||||||
|
After training our model on training set, we evaluate the models on testing dataset that spared before. The evaluation is done with [model_evaluation.ipynb](model_evaluating.ipynb).
|
||||||
|
|
||||||
155
model_evaluating.ipynb
Normal file
155
model_evaluating.ipynb
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Experiment 2: Model Evaluating"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 29,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from keras.models import load_model\n",
|
||||||
|
"from matplotlib import ticker\n",
|
||||||
|
"from scipy.io import loadmat\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from sklearn.metrics import mean_squared_error\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"%matplotlib inline"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"In this experiment, we load model weights from the experiment1 and evaluate them on test dataset."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 30,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"shape of data:\n",
|
||||||
|
"x_train: (5728, 1, 102), y_train: (5728, 1),\n",
|
||||||
|
"x_val: (2455, 1, 102), y_val: (2455, 1)\n",
|
||||||
|
"x_test: (3508, 1, 102), y_test: (3508, 1)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"data = loadmat('./preprocess/dataset/mango/mango_dm_split.mat')\n",
|
||||||
|
"x_train, y_train, x_test, y_test = data['x_train'], data['y_train'], data['x_test'], data['y_test']\n",
|
||||||
|
"x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.3, random_state=12, shuffle=True)\n",
|
||||||
|
"x_train, x_val, x_test = x_train[:, np.newaxis, :], x_val[:, np.newaxis, :], x_test[:, np.newaxis, :]\n",
|
||||||
|
"print(f\"shape of data:\\n\"\n",
|
||||||
|
" f\"x_train: {x_train.shape}, y_train: {y_train.shape},\\n\"\n",
|
||||||
|
" f\"x_val: {x_val.shape}, y_val: {y_val.shape}\\n\"\n",
|
||||||
|
" f\"x_test: {x_test.shape}, y_test: {y_test.shape}\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"## Build model and load weights\n",
|
||||||
|
"plain_5, plain_11 = load_model('./checkpoints/plain5.hdf5'), load_model('./checkpoints/plain11.hdf5')\n",
|
||||||
|
"shortcut5, shortcut11 = load_model('./checkpoints/shortcut5.hdf5'), load_model('./checkpoints/shortcut11.hdf5')\n",
|
||||||
|
"models = {'plain 5': plain_5, 'plain 11': plain_11, 'shortcut 5': shortcut5, 'shortcut11': shortcut11}\n",
|
||||||
|
"results = {model_name: model.predict(x_test).reshape((-1, )) for model_name, model in models.items()}\n",
|
||||||
|
"for model_name, model_result in results.items():\n",
|
||||||
|
" print(model_name, \" : \", mean_squared_error(y_test, model_result)*100, \"%\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"execution_count": 31,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"plain 5 : 0.2707851525589865 %\n",
|
||||||
|
"plain 11 : 0.26240810192725905 %\n",
|
||||||
|
"shortcut 5 : 0.28330442301217196 %\n",
|
||||||
|
"shortcut11 : 0.25743312483685266 %\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 31,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
7124
model_training.ipynb
Normal file
7124
model_training.ipynb
Normal file
File diff suppressed because one or more lines are too long
264
models.py
Normal file
264
models.py
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
import keras.callbacks
|
||||||
|
import keras.layers as KL
|
||||||
|
from keras import Model
|
||||||
|
from keras.optimizers import adam_v2
|
||||||
|
|
||||||
|
|
||||||
|
class Plain5(object):
|
||||||
|
def __init__(self, model_path=None, input_shape=None):
|
||||||
|
self.model = None
|
||||||
|
self.input_shape = input_shape
|
||||||
|
if model_path is not None:
|
||||||
|
# TODO: loading from the file
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.model = self.build_model()
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
input_layer = KL.Input(self.input_shape, name='input')
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv1')(input_layer)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Dense(20, activation='relu', name='dense')(x)
|
||||||
|
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||||
|
model = Model(input_layer, x)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||||
|
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||||
|
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/plain5.hdf5', monitor='val_loss',
|
||||||
|
mode="min", save_best_only=True)
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
|
||||||
|
patience=1000, verbose=0, mode='auto')
|
||||||
|
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25, min_delta=1e-6)
|
||||||
|
callbacks = [checkpoint, early_stop, lr_decay]
|
||||||
|
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||||
|
callbacks=callbacks, batch_size=batch_size)
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
class Residual5(object):
|
||||||
|
def __init__(self, model_path=None, input_shape=None):
|
||||||
|
self.model = None
|
||||||
|
self.input_shape = input_shape
|
||||||
|
if model_path is not None:
|
||||||
|
# TODO: loading from the file
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.model = self.build_model()
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
input_layer = KL.Input(self.input_shape, name='input')
|
||||||
|
fx = KL.Conv1D(8, 3, padding='same', name='Conv1')(input_layer)
|
||||||
|
fx = KL.BatchNormalization()(fx)
|
||||||
|
x = KL.Activation('relu')(fx)
|
||||||
|
|
||||||
|
fx = KL.Conv1D(8, 3, padding='same', name='Conv2')(x)
|
||||||
|
fx = KL.BatchNormalization()(fx)
|
||||||
|
fx = KL.Activation('relu')(fx)
|
||||||
|
x = fx + x
|
||||||
|
|
||||||
|
fx = KL.Conv1D(8, 3, padding='same', name='Conv3')(x)
|
||||||
|
fx = KL.BatchNormalization()(fx)
|
||||||
|
fx = KL.Activation('relu')(fx)
|
||||||
|
x = fx + x
|
||||||
|
|
||||||
|
x = KL.Dense(20, activation='relu', name='dense')(x)
|
||||||
|
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||||
|
model = Model(input_layer, x)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||||
|
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||||
|
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/res5.hdf5', monitor='val_loss',
|
||||||
|
mode="min", save_best_only=True)
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
|
||||||
|
patience=1000, verbose=0, mode='auto')
|
||||||
|
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25, min_delta=1e-6)
|
||||||
|
callbacks = [checkpoint, early_stop, lr_decay]
|
||||||
|
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||||
|
callbacks=callbacks, batch_size=batch_size)
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
class ShortCut5(object):
|
||||||
|
def __init__(self, model_path=None, input_shape=None):
|
||||||
|
self.model = None
|
||||||
|
self.input_shape = input_shape
|
||||||
|
if model_path is not None:
|
||||||
|
# TODO: loading from the file
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.model = self.build_model()
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
input_layer = KL.Input(self.input_shape, name='input')
|
||||||
|
x_raw = KL.Conv1D(8, 3, padding='same', name='Conv1')(input_layer)
|
||||||
|
fx1 = KL.BatchNormalization()(x_raw)
|
||||||
|
fx1 = KL.Activation('relu')(fx1)
|
||||||
|
|
||||||
|
fx2 = KL.Conv1D(8, 3, padding='same', name='Conv2')(fx1)
|
||||||
|
fx2 = KL.BatchNormalization()(fx2)
|
||||||
|
fx2 = KL.Activation('relu')(fx2)
|
||||||
|
|
||||||
|
fx3 = KL.Conv1D(8, 3, padding='same', name='Conv3')(fx2)
|
||||||
|
fx3 = KL.BatchNormalization()(fx3)
|
||||||
|
fx3 = KL.Activation('relu')(fx3)
|
||||||
|
x = KL.Concatenate(axis=2)([x_raw, fx1, fx2, fx3])
|
||||||
|
|
||||||
|
x = KL.Dense(20, activation='relu', name='dense')(x)
|
||||||
|
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||||
|
model = Model(input_layer, x)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||||
|
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||||
|
|
||||||
|
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/shortcut5.hdf5', monitor='val_loss',
|
||||||
|
mode="min", save_best_only=True)
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
|
||||||
|
patience=1000, verbose=0, mode='auto')
|
||||||
|
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25, min_delta=1e-6)
|
||||||
|
callbacks = [checkpoint, early_stop, lr_decay]
|
||||||
|
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||||
|
callbacks=callbacks, batch_size=batch_size)
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
class ShortCut11(object):
|
||||||
|
def __init__(self, model_path=None, input_shape=None):
|
||||||
|
self.model = None
|
||||||
|
self.input_shape = input_shape
|
||||||
|
if model_path is not None:
|
||||||
|
# TODO: loading from the file
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.model = self.build_model()
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
input_layer = KL.Input(self.input_shape, name='input')
|
||||||
|
x_raw = KL.Conv1D(8, 3, padding='same', name='Conv1_1')(input_layer)
|
||||||
|
x = KL.BatchNormalization()(x_raw)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv1_2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv1_3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
fx1 = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2_1')(fx1)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2_2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2_3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
fx2 = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3_1')(fx2)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3_2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3_3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
fx3 = KL.Activation('relu')(x)
|
||||||
|
x = KL.Concatenate(axis=2)([x_raw, fx1, fx2, fx3])
|
||||||
|
|
||||||
|
x = KL.Dense(200, activation='relu', name='dense1')(x)
|
||||||
|
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||||
|
model = Model(input_layer, x)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||||
|
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||||
|
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/shortcut11.hdf5', monitor='val_loss',
|
||||||
|
mode="min", save_best_only=True)
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-6,
|
||||||
|
patience=200, verbose=0, mode='auto')
|
||||||
|
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
|
||||||
|
patience=25, min_delta=1e-6)
|
||||||
|
callbacks = [checkpoint, early_stop, lr_decay]
|
||||||
|
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||||
|
callbacks=callbacks, batch_size=batch_size)
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
class Plain11(object):
|
||||||
|
def __init__(self, model_path=None, input_shape=None):
|
||||||
|
self.model = None
|
||||||
|
self.input_shape = input_shape
|
||||||
|
if model_path is not None:
|
||||||
|
# TODO: loading from the file
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.model = self.build_model()
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
input_layer = KL.Input(self.input_shape, name='input')
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv1_1')(input_layer)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv1_2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv1_3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2_1')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2_2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv2_3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3_1')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3_2')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
x = KL.Conv1D(8, 3, padding='same', name='Conv3_3')(x)
|
||||||
|
x = KL.BatchNormalization()(x)
|
||||||
|
x = KL.Activation('relu')(x)
|
||||||
|
|
||||||
|
x = KL.Dense(200, activation='relu', name='dense1')(x)
|
||||||
|
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||||
|
model = Model(input_layer, x)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||||
|
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||||
|
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/plain11.hdf5', monitor='val_loss',
|
||||||
|
mode="min", save_best_only=True)
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-6,
|
||||||
|
patience=200, verbose=0, mode='auto')
|
||||||
|
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
|
||||||
|
patience=25, min_delta=1e-6)
|
||||||
|
callbacks = [checkpoint, early_stop, lr_decay]
|
||||||
|
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||||
|
callbacks=callbacks, batch_size=batch_size)
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# plain5 = Plain5(model_path=None, input_shape=(1, 102))
|
||||||
|
# plain11 = Plain11(model_path=None, input_shape=(1, 102))
|
||||||
|
residual5 = Residual5(model_path=None, input_shape=(1, 102))
|
||||||
|
short5 = ShortCut5(model_path=None, input_shape=(1, 102))
|
||||||
127
preprocess.ipynb
Normal file
127
preprocess.ipynb
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "dd2c8c55",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Preprocessing"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "716880ac",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from scipy.io import savemat, loadmat\n",
|
||||||
|
"import os"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "4d7dc4a0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Step 1: \n",
|
||||||
|
"Convert the dataset to mat format for Matlab."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "711356a2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = pd.read_csv('preprocess/dataset/mango/NAnderson2020MendeleyMangoNIRData.csv')\n",
|
||||||
|
"y = dataset.DM\n",
|
||||||
|
"x = dataset.loc[:, '684': '990']\n",
|
||||||
|
"savemat('preprocess/dataset/mango/mango_origin.mat', {'x': x.values, 'y': y.values})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "3e41e8e6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ea5e54fd",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Step3:\n",
|
||||||
|
"Data split with train test split."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "6eac026e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"data = loadmat('preprocess/dataset/mango/mango_preprocessed.mat')\n",
|
||||||
|
"x, y = data['x'], data['y']\n",
|
||||||
|
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=24)\n",
|
||||||
|
"if not os.path.exists('mango'):\n",
|
||||||
|
" os.makedirs('mango')\n",
|
||||||
|
"savemat('preprocess/dataset/mango/mango_dm_split.mat',{'x_train':x_train, 'y_train':y_train, 'x_test':x_test, 'y_test':y_test,\n",
|
||||||
|
" 'max_y': data['max_y'], 'min_y': data['min_y'],\n",
|
||||||
|
" 'min_x':data['min_x'], 'max_x':data['max_x']})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b2977dae",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Step 4:\n",
|
||||||
|
"Show data with pictures\n",
|
||||||
|
"use `draw_pics_origin` to draw original spectra\n",
|
||||||
|
""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"use `draw_pics_preprocessed.m` to draw proprecessed spectra\n",
|
||||||
|
""
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "7f619fc91ee8bdab81d49e7c14228037474662e3f2d607687ae505108922fa06"
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3.9.7 64-bit ('base': conda)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
45
preprocess/draw_pics_origin.m
Executable file
45
preprocess/draw_pics_origin.m
Executable file
@ -0,0 +1,45 @@
|
|||||||
|
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||||
|
f = figure;
|
||||||
|
f.Position(3:4) = [1331 331];
|
||||||
|
%%% draw the pic of corn spectra
|
||||||
|
load('dataset/corn.mat');
|
||||||
|
x = m5spec.data;
|
||||||
|
wave_length = m5spec.axisscale{2, 1};
|
||||||
|
subplot(1, 4, 1)
|
||||||
|
plot(wave_length, x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Absorbance');
|
||||||
|
clear
|
||||||
|
|
||||||
|
%%% draw the pic of Marzipan spectra
|
||||||
|
load('dataset/marzipan.mat');
|
||||||
|
x = NIRS1;
|
||||||
|
wave_length = NIRS1_axis;
|
||||||
|
subplot(1, 4, 2)
|
||||||
|
plot(wave_length, x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Absorbance');
|
||||||
|
clear
|
||||||
|
|
||||||
|
%%% draw the pic of Marzipan spectra
|
||||||
|
load('dataset/soil.mat');
|
||||||
|
x = soil.data;
|
||||||
|
wave_length = soil.axisscale{2, 1};
|
||||||
|
subplot(1, 4, 3)
|
||||||
|
plot(wave_length, x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Absorbance');
|
||||||
|
clear
|
||||||
|
|
||||||
|
% draw the pic of Mango spectra
|
||||||
|
load('dataset/mango/mango_origin.mat');
|
||||||
|
wave_length = 684: 3: 990;
|
||||||
|
subplot(1, 4, 4)
|
||||||
|
plot(wave_length, x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Signal intensity');
|
||||||
|
clear
|
||||||
48
preprocess/draw_pics_preprocessed.m
Executable file
48
preprocess/draw_pics_preprocessed.m
Executable file
@ -0,0 +1,48 @@
|
|||||||
|
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||||
|
f = figure;
|
||||||
|
f.Position(3:4) = [1331 331];
|
||||||
|
%%% draw the pic of corn spectra
|
||||||
|
load('dataset/corn.mat');
|
||||||
|
x = m5spec.data;
|
||||||
|
wave_length = m5spec.axisscale{2, 1};
|
||||||
|
preprocess;
|
||||||
|
subplot(1, 4, 1)
|
||||||
|
plot(wave_length(1, 1:end-1), x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Absorbance');
|
||||||
|
clear
|
||||||
|
|
||||||
|
%%% draw the pic of Marzipan spectra
|
||||||
|
load('dataset/marzipan.mat');
|
||||||
|
x = NIRS1;
|
||||||
|
wave_length = NIRS1_axis;
|
||||||
|
preprocess;
|
||||||
|
subplot(1, 4, 2)
|
||||||
|
plot(wave_length(1, 1:end-1), x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Absorbance');
|
||||||
|
clear
|
||||||
|
|
||||||
|
%%% draw the pic of Marzipan spectra
|
||||||
|
load('dataset/soil.mat');
|
||||||
|
x = soil.data;
|
||||||
|
wave_length = soil.axisscale{2, 1};
|
||||||
|
preprocess;
|
||||||
|
subplot(1, 4, 3)
|
||||||
|
plot(wave_length(1, 1:end-1), x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Absorbance');
|
||||||
|
clear
|
||||||
|
|
||||||
|
% draw the pic of Mango spectra
|
||||||
|
load('dataset/mango/mango_preprocessed.mat');
|
||||||
|
wave_length = 687: 3: 990;
|
||||||
|
subplot(1, 4, 4)
|
||||||
|
plot(wave_length, x');
|
||||||
|
xlim([wave_length(1) wave_length(end)]);
|
||||||
|
xlabel('Wavelength(nm)');
|
||||||
|
ylabel('Signal intensity');
|
||||||
|
clear
|
||||||
BIN
preprocess/pics/preprocessed.png
Normal file
BIN
preprocess/pics/preprocessed.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 89 KiB |
BIN
preprocess/pics/raw.png
Normal file
BIN
preprocess/pics/raw.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 175 KiB |
8
preprocess/preprocess.m
Executable file
8
preprocess/preprocess.m
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
%% x preprocessing
|
||||||
|
x = x';
|
||||||
|
x = sgolayfilt(x,2,17);
|
||||||
|
x =diff(x);
|
||||||
|
max_x=max(max(x));
|
||||||
|
min_x=min(min(x));
|
||||||
|
x=(x-min_x)/(max_x-min_x);
|
||||||
|
x = x';
|
||||||
15
preprocess/preprocess_mango.m
Executable file
15
preprocess/preprocess_mango.m
Executable file
@ -0,0 +1,15 @@
|
|||||||
|
%% x preprocessing
|
||||||
|
clear;
|
||||||
|
load('dataset/mango/mango_origin.mat')
|
||||||
|
x = x';
|
||||||
|
x = sgolayfilt(x,2,17);
|
||||||
|
x =diff(x);
|
||||||
|
max_x=max(max(x));
|
||||||
|
min_x=min(min(x));
|
||||||
|
x=(x-min_x)/(max_x-min_x);
|
||||||
|
x = x';
|
||||||
|
y = y';
|
||||||
|
min_y = min(min(y));
|
||||||
|
max_y = max(max(y));
|
||||||
|
y = (y-min_y)/(max_y-min_y);
|
||||||
|
save('dataset/mango/mango_preprocessed.mat')
|
||||||
15
preprocess/train_test_split.m
Executable file
15
preprocess/train_test_split.m
Executable file
@ -0,0 +1,15 @@
|
|||||||
|
data=[x,y];
|
||||||
|
test_rate = 0.3;
|
||||||
|
data_num = size(x, 1);
|
||||||
|
train_num = round((1-test_rate) * data_num);
|
||||||
|
idx=randperm(data_num);
|
||||||
|
train_idx=idx(1:train_num);
|
||||||
|
test_idx=idx(train_num+1:data_num);
|
||||||
|
data_train=data(train_idx,:);
|
||||||
|
x_train=data_train(:,1:size(x, 2));
|
||||||
|
y_train=data_train(:,size(x, 2)+1);
|
||||||
|
test_data=data(test_idx,:);
|
||||||
|
x_test=test_data(:,1:size(x, 2));
|
||||||
|
y_test=test_data(:,size(x, 2)+1);
|
||||||
|
clear data_num train_num idx train_idx test_idx test_data train_data x y;
|
||||||
|
clear data data_train test_rate;
|
||||||
153
utils.py
Executable file
153
utils.py
Executable file
@ -0,0 +1,153 @@
|
|||||||
|
from scipy.io import loadmat
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(data_path='./pine_water_cc.mat', validation_rate=0.25):
|
||||||
|
if data_path == './pine_water_cc.mat':
|
||||||
|
data = loadmat(data_path)
|
||||||
|
y_train, y_test = data['value_train'], data['value_test']
|
||||||
|
print('Value train shape: ', y_train.shape, 'Value test shape', y_test.shape)
|
||||||
|
y_max_value, y_min_value = data['value_max'], data['value_min']
|
||||||
|
x_train, x_test = data['DL_train'], data['DL_test']
|
||||||
|
elif data_path == './N_100_leaf_cc.mat':
|
||||||
|
data = loadmat(data_path)
|
||||||
|
y_train, y_test = data['y_train'], data['y_test']
|
||||||
|
x_train, x_test = data['x_train'], data['x_test']
|
||||||
|
y_max_value, y_min_value = data['max_y'], data['min_y']
|
||||||
|
x_train = np.expand_dims(x_train, axis=1)
|
||||||
|
x_test = np.expand_dims(x_test, axis=1)
|
||||||
|
x_validation, y_validation = x_test, y_test
|
||||||
|
return x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value
|
||||||
|
else:
|
||||||
|
data = loadmat(data_path)
|
||||||
|
y_train, y_test = data['y_train'], data['y_test']
|
||||||
|
x_train, x_test = data['x_train'], data['x_test']
|
||||||
|
y_max_value, y_min_value = data['max_y'], data['min_y']
|
||||||
|
x_train = np.expand_dims(x_train, axis=1)
|
||||||
|
x_test = np.expand_dims(x_test, axis=1)
|
||||||
|
print('SG17 DATA train shape: ', x_train.shape, 'SG17 DATA test shape', x_test.shape)
|
||||||
|
|
||||||
|
print('Mini value: %s, Max value %s.' % (y_min_value, y_max_value))
|
||||||
|
|
||||||
|
x_train, x_validation, y_train, y_validation = train_test_split(x_train, y_train, test_size=validation_rate,
|
||||||
|
random_state=8)
|
||||||
|
|
||||||
|
return x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir_if_not_exist(dir_name, is_delete=False):
|
||||||
|
"""
|
||||||
|
创建文件夹
|
||||||
|
:param dir_name: 文件夹
|
||||||
|
:param is_delete: 是否删除
|
||||||
|
:return: 是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if is_delete:
|
||||||
|
if os.path.exists(dir_name):
|
||||||
|
shutil.rmtree(dir_name)
|
||||||
|
print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)
|
||||||
|
|
||||||
|
if not os.path.exists(dir_name):
|
||||||
|
os.makedirs(dir_name)
|
||||||
|
print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print('[Exception] %s' % e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
def __init__(self):
|
||||||
|
# 数据有关的参数
|
||||||
|
self.validation_rate = 0.2
|
||||||
|
# 训练有关参数
|
||||||
|
self.train_epoch = 20000
|
||||||
|
self.batch_size = 20
|
||||||
|
# 是否训练的参数
|
||||||
|
self.train_cnn = True
|
||||||
|
self.train_ms_cnn = True
|
||||||
|
self.train_ms_sc_cnn = True
|
||||||
|
# 是否评估参数
|
||||||
|
self.evaluate_cnn = True
|
||||||
|
self.evaluate_ms_cnn = True
|
||||||
|
self.evaluate_ms_sc_cnn = True
|
||||||
|
# 要评估的保存好的模型列表
|
||||||
|
self.evaluate_cnn_name_list = []
|
||||||
|
self.evaluate_ms_cnn_name_list = []
|
||||||
|
self.evaluate_ms_sc_cnn_name_list = []
|
||||||
|
|
||||||
|
# 存储训练出的模型和图片的文件夹
|
||||||
|
self.img_dir = './pictures0331'
|
||||||
|
self.checkpoint_dir = './check_points0331'
|
||||||
|
|
||||||
|
# 数据集选择
|
||||||
|
self.data_set = './dataset_preprocess/corn/corn_mositure.mat'
|
||||||
|
|
||||||
|
def show_yourself(self, to_text_file=None):
|
||||||
|
line_width = 36
|
||||||
|
content = '\n'
|
||||||
|
# create line
|
||||||
|
line_text = 'Data Parameters'
|
||||||
|
line = '='*((line_width-len(line_text))//2) + line_text + '='*((line_width-len(line_text))//2)
|
||||||
|
line.ljust(line_width, '=')
|
||||||
|
content += line + '\n'
|
||||||
|
content += 'Validation Rate: ' + str(self.validation_rate) + '\n'
|
||||||
|
# create line
|
||||||
|
line_text = 'Training Parameters'
|
||||||
|
line = '=' * ((line_width - len(line_text)) // 2) + line_text + '=' * ((line_width - len(line_text)) // 2)
|
||||||
|
line.ljust(line_width, '=')
|
||||||
|
content += line + '\n'
|
||||||
|
content += 'Train CNN: ' + str(self.train_cnn) + '\n'
|
||||||
|
content += 'Train Ms CNN: ' + str(self.train_ms_cnn) + '\n'
|
||||||
|
content += 'Train Ms Sc CNN: ' + str(self.train_ms_sc_cnn) + '\n'
|
||||||
|
# create line
|
||||||
|
line_text = 'Evaluate Parameters'
|
||||||
|
line = '=' * ((line_width - len(line_text)) // 2) + line_text + '=' * ((line_width - len(line_text)) // 2)
|
||||||
|
line.ljust(line_width, '=')
|
||||||
|
content += line + '\n'
|
||||||
|
content += 'Train Epoch: ' + str(self.train_epoch) + '\n'
|
||||||
|
content += 'Train Batch Size: ' + str(self.batch_size) + '\n'
|
||||||
|
|
||||||
|
content += 'Evaluate CNN: ' + str(self.evaluate_cnn) + '\n'
|
||||||
|
if len(self.evaluate_cnn_name_list) >=1:
|
||||||
|
content += 'Saved CNNs to Evaluate:\n'
|
||||||
|
for models in self.evaluate_cnn_name_list:
|
||||||
|
content += models + '\n'
|
||||||
|
|
||||||
|
content += 'Evaluate Ms CNN: ' + str(self.evaluate_ms_cnn) + '\n'
|
||||||
|
if len(self.evaluate_ms_cnn_name_list) >= 1:
|
||||||
|
content += 'Saved Ms CNNs to Evaluate:\n'
|
||||||
|
for models in self.evaluate_ms_cnn_name_list:
|
||||||
|
content += models + '\n'
|
||||||
|
|
||||||
|
content += 'Evaluate Ms Sc CNN: ' + str(self.evaluate_ms_cnn) + '\n'
|
||||||
|
if len(self.evaluate_ms_sc_cnn_name_list) >= 1:
|
||||||
|
content += 'Saved Ms Sc CNNs to Evaluate:\n'
|
||||||
|
for models in self.evaluate_ms_sc_cnn_name_list:
|
||||||
|
content += models + '\n'
|
||||||
|
|
||||||
|
# create line
|
||||||
|
line_text = 'Saving Dir'
|
||||||
|
line = '=' * ((line_width - len(line_text)) // 2) + line_text + '=' * ((line_width - len(line_text)) // 2)
|
||||||
|
line.ljust(line_width, '=')
|
||||||
|
content += line + '\n'
|
||||||
|
content += 'Image Dir: ' + str(self.img_dir) + '\n'
|
||||||
|
content += 'Check Point Dir: ' + str(self.img_dir) + '\n'
|
||||||
|
print(content)
|
||||||
|
if to_text_file:
|
||||||
|
with open(to_text_file, 'w') as f:
|
||||||
|
f.write(content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config()
|
||||||
|
config.show_yourself(to_text_file='name.txt')
|
||||||
|
x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value = \
|
||||||
|
load_data(data_path='./yaowan_calibrate.mat', validation_rate=0.25)
|
||||||
|
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape, x_validation.shape, y_validation.shape,
|
||||||
|
y_max_value, y_min_value)
|
||||||
Loading…
Reference in New Issue
Block a user