mirror of
https://github.com/NanjingForestryUniversity/SCNet.git
synced 2025-11-08 14:24:03 +00:00
First Commit
This commit is contained in:
commit
17e8992073
219
.gitignore
vendored
Normal file
219
.gitignore
vendored
Normal file
@ -0,0 +1,219 @@
|
||||
preprocess/dataset/*
|
||||
checkpoints/*
|
||||
.idea
|
||||
### JetBrains template
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
|
||||
# User-specific stuff
|
||||
.idea/**/workspace.xml
|
||||
.idea/**/tasks.xml
|
||||
.idea/**/usage.statistics.xml
|
||||
.idea/**/dictionaries
|
||||
.idea/**/shelf
|
||||
|
||||
# Generated files
|
||||
.idea/**/contentModel.xml
|
||||
|
||||
# Sensitive or high-churn files
|
||||
.idea/**/dataSources/
|
||||
.idea/**/dataSources.ids
|
||||
.idea/**/dataSources.local.xml
|
||||
.idea/**/sqlDataSources.xml
|
||||
.idea/**/dynamic.xml
|
||||
.idea/**/uiDesigner.xml
|
||||
.idea/**/dbnavigator.xml
|
||||
|
||||
# Gradle
|
||||
.idea/**/gradle.xml
|
||||
.idea/**/libraries
|
||||
|
||||
# Gradle and Maven with auto-import
|
||||
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||
# since they will be recreated, and may cause churn. Uncomment if using
|
||||
# auto-import.
|
||||
# .idea/artifacts
|
||||
# .idea/compiler.xml
|
||||
# .idea/jarRepositories.xml
|
||||
# .idea/modules.xml
|
||||
# .idea/*.iml
|
||||
# .idea/modules
|
||||
# *.iml
|
||||
# *.ipr
|
||||
|
||||
# CMake
|
||||
cmake-build-*/
|
||||
|
||||
# Mongo Explorer plugin
|
||||
.idea/**/mongoSettings.xml
|
||||
|
||||
# File-based project format
|
||||
*.iws
|
||||
|
||||
# IntelliJ
|
||||
out/
|
||||
|
||||
# mpeltonen/sbt-idea plugin
|
||||
.idea_modules/
|
||||
|
||||
# JIRA plugin
|
||||
atlassian-ide-plugin.xml
|
||||
|
||||
# Cursive Clojure plugin
|
||||
.idea/replstate.xml
|
||||
|
||||
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||
com_crashlytics_export_strings.xml
|
||||
crashlytics.properties
|
||||
crashlytics-build.properties
|
||||
fabric.properties
|
||||
|
||||
# Editor-based Rest Client
|
||||
.idea/httpRequests
|
||||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
|
||||
### Python template
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
!/checkpoints/
|
||||
!/preprocess/dataset/
|
||||
!/preprocess/dataset/
|
||||
21
README.md
Normal file
21
README.md
Normal file
@ -0,0 +1,21 @@
|
||||
# SCNet: A deep learning network framework for analyzing near-infrared spectroscopy using short-cut
|
||||
## Pre-processing
|
||||
|
||||
Since the method we proposed is a regression model, the classification dataset weat kernel is not used in this work.
|
||||
|
||||
The other three dataset (corn, marzipan, soil) were preprocessed manually with Matlab and saved in the sub dictionary of `./preprocess` dir. The original dataset of these three dataset were stored in the `./preprocess/dataset/`.
|
||||
|
||||
The mango dataset is not in Matlab .m file format, so we save them with the `process.py`.
|
||||
Meanwhile, we drop the useless part and only save the data between 684 and 900 nm.
|
||||
|
||||
> The data set used in this study comprises a total of 11,691 NIR spectra (684–990 nm in 3 nm sampling with a total 103 variables) and DM measurements performed on 4675 mango fruit across 4 harvest seasons 2015, 2016, 2017 and 2018 [24].
|
||||
|
||||
The detailed preprocessing progress can be found in [./preprocess.ipynb](./preprocess.ipynb)
|
||||
|
||||
## Network Training
|
||||
|
||||
In order to show our network can prevent degration problem, we hold the experiment which contains the training loss curve of four models. The detailed information can be found in [model_training.ipynb](./model_training.ipynb).
|
||||
|
||||
## Network evaluation
|
||||
After training our model on training set, we evaluate the models on testing dataset that spared before. The evaluation is done with [model_evaluation.ipynb](model_evaluating.ipynb).
|
||||
|
||||
155
model_evaluating.ipynb
Normal file
155
model_evaluating.ipynb
Normal file
@ -0,0 +1,155 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Experiment 2: Model Evaluating"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from keras.models import load_model\n",
|
||||
"from matplotlib import ticker\n",
|
||||
"from scipy.io import loadmat\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.metrics import mean_squared_error\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"%matplotlib inline"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"In this experiment, we load model weights from the experiment1 and evaluate them on test dataset."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"shape of data:\n",
|
||||
"x_train: (5728, 1, 102), y_train: (5728, 1),\n",
|
||||
"x_val: (2455, 1, 102), y_val: (2455, 1)\n",
|
||||
"x_test: (3508, 1, 102), y_test: (3508, 1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data = loadmat('./preprocess/dataset/mango/mango_dm_split.mat')\n",
|
||||
"x_train, y_train, x_test, y_test = data['x_train'], data['y_train'], data['x_test'], data['y_test']\n",
|
||||
"x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.3, random_state=12, shuffle=True)\n",
|
||||
"x_train, x_val, x_test = x_train[:, np.newaxis, :], x_val[:, np.newaxis, :], x_test[:, np.newaxis, :]\n",
|
||||
"print(f\"shape of data:\\n\"\n",
|
||||
" f\"x_train: {x_train.shape}, y_train: {y_train.shape},\\n\"\n",
|
||||
" f\"x_val: {x_val.shape}, y_val: {y_val.shape}\\n\"\n",
|
||||
" f\"x_test: {x_test.shape}, y_test: {y_test.shape}\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"## Build model and load weights\n",
|
||||
"plain_5, plain_11 = load_model('./checkpoints/plain5.hdf5'), load_model('./checkpoints/plain11.hdf5')\n",
|
||||
"shortcut5, shortcut11 = load_model('./checkpoints/shortcut5.hdf5'), load_model('./checkpoints/shortcut11.hdf5')\n",
|
||||
"models = {'plain 5': plain_5, 'plain 11': plain_11, 'shortcut 5': shortcut5, 'shortcut11': shortcut11}\n",
|
||||
"results = {model_name: model.predict(x_test).reshape((-1, )) for model_name, model in models.items()}\n",
|
||||
"for model_name, model_result in results.items():\n",
|
||||
" print(model_name, \" : \", mean_squared_error(y_test, model_result)*100, \"%\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"execution_count": 31,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"plain 5 : 0.2707851525589865 %\n",
|
||||
"plain 11 : 0.26240810192725905 %\n",
|
||||
"shortcut 5 : 0.28330442301217196 %\n",
|
||||
"shortcut11 : 0.25743312483685266 %\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
7124
model_training.ipynb
Normal file
7124
model_training.ipynb
Normal file
File diff suppressed because one or more lines are too long
264
models.py
Normal file
264
models.py
Normal file
@ -0,0 +1,264 @@
|
||||
import keras.callbacks
|
||||
import keras.layers as KL
|
||||
from keras import Model
|
||||
from keras.optimizers import adam_v2
|
||||
|
||||
|
||||
class Plain5(object):
|
||||
def __init__(self, model_path=None, input_shape=None):
|
||||
self.model = None
|
||||
self.input_shape = input_shape
|
||||
if model_path is not None:
|
||||
# TODO: loading from the file
|
||||
pass
|
||||
else:
|
||||
self.model = self.build_model()
|
||||
|
||||
def build_model(self):
|
||||
input_layer = KL.Input(self.input_shape, name='input')
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv1')(input_layer)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Dense(20, activation='relu', name='dense')(x)
|
||||
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||
model = Model(input_layer, x)
|
||||
return model
|
||||
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/plain5.hdf5', monitor='val_loss',
|
||||
mode="min", save_best_only=True)
|
||||
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
|
||||
patience=1000, verbose=0, mode='auto')
|
||||
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25, min_delta=1e-6)
|
||||
callbacks = [checkpoint, early_stop, lr_decay]
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||
callbacks=callbacks, batch_size=batch_size)
|
||||
return history
|
||||
|
||||
|
||||
class Residual5(object):
|
||||
def __init__(self, model_path=None, input_shape=None):
|
||||
self.model = None
|
||||
self.input_shape = input_shape
|
||||
if model_path is not None:
|
||||
# TODO: loading from the file
|
||||
pass
|
||||
else:
|
||||
self.model = self.build_model()
|
||||
|
||||
def build_model(self):
|
||||
input_layer = KL.Input(self.input_shape, name='input')
|
||||
fx = KL.Conv1D(8, 3, padding='same', name='Conv1')(input_layer)
|
||||
fx = KL.BatchNormalization()(fx)
|
||||
x = KL.Activation('relu')(fx)
|
||||
|
||||
fx = KL.Conv1D(8, 3, padding='same', name='Conv2')(x)
|
||||
fx = KL.BatchNormalization()(fx)
|
||||
fx = KL.Activation('relu')(fx)
|
||||
x = fx + x
|
||||
|
||||
fx = KL.Conv1D(8, 3, padding='same', name='Conv3')(x)
|
||||
fx = KL.BatchNormalization()(fx)
|
||||
fx = KL.Activation('relu')(fx)
|
||||
x = fx + x
|
||||
|
||||
x = KL.Dense(20, activation='relu', name='dense')(x)
|
||||
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||
model = Model(input_layer, x)
|
||||
return model
|
||||
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/res5.hdf5', monitor='val_loss',
|
||||
mode="min", save_best_only=True)
|
||||
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
|
||||
patience=1000, verbose=0, mode='auto')
|
||||
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25, min_delta=1e-6)
|
||||
callbacks = [checkpoint, early_stop, lr_decay]
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||
callbacks=callbacks, batch_size=batch_size)
|
||||
return history
|
||||
|
||||
|
||||
class ShortCut5(object):
|
||||
def __init__(self, model_path=None, input_shape=None):
|
||||
self.model = None
|
||||
self.input_shape = input_shape
|
||||
if model_path is not None:
|
||||
# TODO: loading from the file
|
||||
pass
|
||||
else:
|
||||
self.model = self.build_model()
|
||||
|
||||
def build_model(self):
|
||||
input_layer = KL.Input(self.input_shape, name='input')
|
||||
x_raw = KL.Conv1D(8, 3, padding='same', name='Conv1')(input_layer)
|
||||
fx1 = KL.BatchNormalization()(x_raw)
|
||||
fx1 = KL.Activation('relu')(fx1)
|
||||
|
||||
fx2 = KL.Conv1D(8, 3, padding='same', name='Conv2')(fx1)
|
||||
fx2 = KL.BatchNormalization()(fx2)
|
||||
fx2 = KL.Activation('relu')(fx2)
|
||||
|
||||
fx3 = KL.Conv1D(8, 3, padding='same', name='Conv3')(fx2)
|
||||
fx3 = KL.BatchNormalization()(fx3)
|
||||
fx3 = KL.Activation('relu')(fx3)
|
||||
x = KL.Concatenate(axis=2)([x_raw, fx1, fx2, fx3])
|
||||
|
||||
x = KL.Dense(20, activation='relu', name='dense')(x)
|
||||
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||
model = Model(input_layer, x)
|
||||
return model
|
||||
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/shortcut5.hdf5', monitor='val_loss',
|
||||
mode="min", save_best_only=True)
|
||||
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
|
||||
patience=1000, verbose=0, mode='auto')
|
||||
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25, min_delta=1e-6)
|
||||
callbacks = [checkpoint, early_stop, lr_decay]
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||
callbacks=callbacks, batch_size=batch_size)
|
||||
return history
|
||||
|
||||
|
||||
class ShortCut11(object):
|
||||
def __init__(self, model_path=None, input_shape=None):
|
||||
self.model = None
|
||||
self.input_shape = input_shape
|
||||
if model_path is not None:
|
||||
# TODO: loading from the file
|
||||
pass
|
||||
else:
|
||||
self.model = self.build_model()
|
||||
|
||||
def build_model(self):
|
||||
input_layer = KL.Input(self.input_shape, name='input')
|
||||
x_raw = KL.Conv1D(8, 3, padding='same', name='Conv1_1')(input_layer)
|
||||
x = KL.BatchNormalization()(x_raw)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv1_2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv1_3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
fx1 = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2_1')(fx1)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2_2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2_3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
fx2 = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3_1')(fx2)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3_2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3_3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
fx3 = KL.Activation('relu')(x)
|
||||
x = KL.Concatenate(axis=2)([x_raw, fx1, fx2, fx3])
|
||||
|
||||
x = KL.Dense(200, activation='relu', name='dense1')(x)
|
||||
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||
model = Model(input_layer, x)
|
||||
return model
|
||||
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/shortcut11.hdf5', monitor='val_loss',
|
||||
mode="min", save_best_only=True)
|
||||
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-6,
|
||||
patience=200, verbose=0, mode='auto')
|
||||
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
|
||||
patience=25, min_delta=1e-6)
|
||||
callbacks = [checkpoint, early_stop, lr_decay]
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||
callbacks=callbacks, batch_size=batch_size)
|
||||
return history
|
||||
|
||||
|
||||
class Plain11(object):
|
||||
def __init__(self, model_path=None, input_shape=None):
|
||||
self.model = None
|
||||
self.input_shape = input_shape
|
||||
if model_path is not None:
|
||||
# TODO: loading from the file
|
||||
pass
|
||||
else:
|
||||
self.model = self.build_model()
|
||||
|
||||
def build_model(self):
|
||||
input_layer = KL.Input(self.input_shape, name='input')
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv1_1')(input_layer)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv1_2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv1_3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2_1')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2_2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv2_3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3_1')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3_2')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
x = KL.Conv1D(8, 3, padding='same', name='Conv3_3')(x)
|
||||
x = KL.BatchNormalization()(x)
|
||||
x = KL.Activation('relu')(x)
|
||||
|
||||
x = KL.Dense(200, activation='relu', name='dense1')(x)
|
||||
x = KL.Dense(1, activation='sigmoid', name='output')(x)
|
||||
model = Model(input_layer, x)
|
||||
return model
|
||||
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size):
|
||||
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(filepath='checkpoints/plain11.hdf5', monitor='val_loss',
|
||||
mode="min", save_best_only=True)
|
||||
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-6,
|
||||
patience=200, verbose=0, mode='auto')
|
||||
lr_decay = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
|
||||
patience=25, min_delta=1e-6)
|
||||
callbacks = [checkpoint, early_stop, lr_decay]
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||
callbacks=callbacks, batch_size=batch_size)
|
||||
return history
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# plain5 = Plain5(model_path=None, input_shape=(1, 102))
|
||||
# plain11 = Plain11(model_path=None, input_shape=(1, 102))
|
||||
residual5 = Residual5(model_path=None, input_shape=(1, 102))
|
||||
short5 = ShortCut5(model_path=None, input_shape=(1, 102))
|
||||
127
preprocess.ipynb
Normal file
127
preprocess.ipynb
Normal file
@ -0,0 +1,127 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dd2c8c55",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Preprocessing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "716880ac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from scipy.io import savemat, loadmat\n",
|
||||
"import os"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d7dc4a0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 1: \n",
|
||||
"Convert the dataset to mat format for Matlab."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "711356a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = pd.read_csv('preprocess/dataset/mango/NAnderson2020MendeleyMangoNIRData.csv')\n",
|
||||
"y = dataset.DM\n",
|
||||
"x = dataset.loc[:, '684': '990']\n",
|
||||
"savemat('preprocess/dataset/mango/mango_origin.mat', {'x': x.values, 'y': y.values})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3e41e8e6",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ea5e54fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step3:\n",
|
||||
"Data split with train test split."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6eac026e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = loadmat('preprocess/dataset/mango/mango_preprocessed.mat')\n",
|
||||
"x, y = data['x'], data['y']\n",
|
||||
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=24)\n",
|
||||
"if not os.path.exists('mango'):\n",
|
||||
" os.makedirs('mango')\n",
|
||||
"savemat('preprocess/dataset/mango/mango_dm_split.mat',{'x_train':x_train, 'y_train':y_train, 'x_test':x_test, 'y_test':y_test,\n",
|
||||
" 'max_y': data['max_y'], 'min_y': data['min_y'],\n",
|
||||
" 'min_x':data['min_x'], 'max_x':data['max_x']})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b2977dae",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 4:\n",
|
||||
"Show data with pictures\n",
|
||||
"use `draw_pics_origin` to draw original spectra\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"use `draw_pics_preprocessed.m` to draw proprecessed spectra\n",
|
||||
""
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "7f619fc91ee8bdab81d49e7c14228037474662e3f2d607687ae505108922fa06"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.7 64-bit ('base': conda)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
45
preprocess/draw_pics_origin.m
Executable file
45
preprocess/draw_pics_origin.m
Executable file
@ -0,0 +1,45 @@
|
||||
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||
f = figure;
|
||||
f.Position(3:4) = [1331 331];
|
||||
%%% draw the pic of corn spectra
|
||||
load('dataset/corn.mat');
|
||||
x = m5spec.data;
|
||||
wave_length = m5spec.axisscale{2, 1};
|
||||
subplot(1, 4, 1)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/marzipan.mat');
|
||||
x = NIRS1;
|
||||
wave_length = NIRS1_axis;
|
||||
subplot(1, 4, 2)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/soil.mat');
|
||||
x = soil.data;
|
||||
wave_length = soil.axisscale{2, 1};
|
||||
subplot(1, 4, 3)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
% draw the pic of Mango spectra
|
||||
load('dataset/mango/mango_origin.mat');
|
||||
wave_length = 684: 3: 990;
|
||||
subplot(1, 4, 4)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Signal intensity');
|
||||
clear
|
||||
48
preprocess/draw_pics_preprocessed.m
Executable file
48
preprocess/draw_pics_preprocessed.m
Executable file
@ -0,0 +1,48 @@
|
||||
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||
f = figure;
|
||||
f.Position(3:4) = [1331 331];
|
||||
%%% draw the pic of corn spectra
|
||||
load('dataset/corn.mat');
|
||||
x = m5spec.data;
|
||||
wave_length = m5spec.axisscale{2, 1};
|
||||
preprocess;
|
||||
subplot(1, 4, 1)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/marzipan.mat');
|
||||
x = NIRS1;
|
||||
wave_length = NIRS1_axis;
|
||||
preprocess;
|
||||
subplot(1, 4, 2)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/soil.mat');
|
||||
x = soil.data;
|
||||
wave_length = soil.axisscale{2, 1};
|
||||
preprocess;
|
||||
subplot(1, 4, 3)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
% draw the pic of Mango spectra
|
||||
load('dataset/mango/mango_preprocessed.mat');
|
||||
wave_length = 687: 3: 990;
|
||||
subplot(1, 4, 4)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Signal intensity');
|
||||
clear
|
||||
BIN
preprocess/pics/preprocessed.png
Normal file
BIN
preprocess/pics/preprocessed.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 89 KiB |
BIN
preprocess/pics/raw.png
Normal file
BIN
preprocess/pics/raw.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 175 KiB |
8
preprocess/preprocess.m
Executable file
8
preprocess/preprocess.m
Executable file
@ -0,0 +1,8 @@
|
||||
%% x preprocessing
|
||||
x = x';
|
||||
x = sgolayfilt(x,2,17);
|
||||
x =diff(x);
|
||||
max_x=max(max(x));
|
||||
min_x=min(min(x));
|
||||
x=(x-min_x)/(max_x-min_x);
|
||||
x = x';
|
||||
15
preprocess/preprocess_mango.m
Executable file
15
preprocess/preprocess_mango.m
Executable file
@ -0,0 +1,15 @@
|
||||
%% x preprocessing
|
||||
clear;
|
||||
load('dataset/mango/mango_origin.mat')
|
||||
x = x';
|
||||
x = sgolayfilt(x,2,17);
|
||||
x =diff(x);
|
||||
max_x=max(max(x));
|
||||
min_x=min(min(x));
|
||||
x=(x-min_x)/(max_x-min_x);
|
||||
x = x';
|
||||
y = y';
|
||||
min_y = min(min(y));
|
||||
max_y = max(max(y));
|
||||
y = (y-min_y)/(max_y-min_y);
|
||||
save('dataset/mango/mango_preprocessed.mat')
|
||||
15
preprocess/train_test_split.m
Executable file
15
preprocess/train_test_split.m
Executable file
@ -0,0 +1,15 @@
|
||||
data=[x,y];
|
||||
test_rate = 0.3;
|
||||
data_num = size(x, 1);
|
||||
train_num = round((1-test_rate) * data_num);
|
||||
idx=randperm(data_num);
|
||||
train_idx=idx(1:train_num);
|
||||
test_idx=idx(train_num+1:data_num);
|
||||
data_train=data(train_idx,:);
|
||||
x_train=data_train(:,1:size(x, 2));
|
||||
y_train=data_train(:,size(x, 2)+1);
|
||||
test_data=data(test_idx,:);
|
||||
x_test=test_data(:,1:size(x, 2));
|
||||
y_test=test_data(:,size(x, 2)+1);
|
||||
clear data_num train_num idx train_idx test_idx test_data train_data x y;
|
||||
clear data data_train test_rate;
|
||||
153
utils.py
Executable file
153
utils.py
Executable file
@ -0,0 +1,153 @@
|
||||
from scipy.io import loadmat
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
def load_data(data_path='./pine_water_cc.mat', validation_rate=0.25):
|
||||
if data_path == './pine_water_cc.mat':
|
||||
data = loadmat(data_path)
|
||||
y_train, y_test = data['value_train'], data['value_test']
|
||||
print('Value train shape: ', y_train.shape, 'Value test shape', y_test.shape)
|
||||
y_max_value, y_min_value = data['value_max'], data['value_min']
|
||||
x_train, x_test = data['DL_train'], data['DL_test']
|
||||
elif data_path == './N_100_leaf_cc.mat':
|
||||
data = loadmat(data_path)
|
||||
y_train, y_test = data['y_train'], data['y_test']
|
||||
x_train, x_test = data['x_train'], data['x_test']
|
||||
y_max_value, y_min_value = data['max_y'], data['min_y']
|
||||
x_train = np.expand_dims(x_train, axis=1)
|
||||
x_test = np.expand_dims(x_test, axis=1)
|
||||
x_validation, y_validation = x_test, y_test
|
||||
return x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value
|
||||
else:
|
||||
data = loadmat(data_path)
|
||||
y_train, y_test = data['y_train'], data['y_test']
|
||||
x_train, x_test = data['x_train'], data['x_test']
|
||||
y_max_value, y_min_value = data['max_y'], data['min_y']
|
||||
x_train = np.expand_dims(x_train, axis=1)
|
||||
x_test = np.expand_dims(x_test, axis=1)
|
||||
print('SG17 DATA train shape: ', x_train.shape, 'SG17 DATA test shape', x_test.shape)
|
||||
|
||||
print('Mini value: %s, Max value %s.' % (y_min_value, y_max_value))
|
||||
|
||||
x_train, x_validation, y_train, y_validation = train_test_split(x_train, y_train, test_size=validation_rate,
|
||||
random_state=8)
|
||||
|
||||
return x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value
|
||||
|
||||
|
||||
def mkdir_if_not_exist(dir_name, is_delete=False):
|
||||
"""
|
||||
创建文件夹
|
||||
:param dir_name: 文件夹
|
||||
:param is_delete: 是否删除
|
||||
:return: 是否成功
|
||||
"""
|
||||
try:
|
||||
if is_delete:
|
||||
if os.path.exists(dir_name):
|
||||
shutil.rmtree(dir_name)
|
||||
print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)
|
||||
|
||||
if not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
|
||||
return True
|
||||
except Exception as e:
|
||||
print('[Exception] %s' % e)
|
||||
return False
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
# 数据有关的参数
|
||||
self.validation_rate = 0.2
|
||||
# 训练有关参数
|
||||
self.train_epoch = 20000
|
||||
self.batch_size = 20
|
||||
# 是否训练的参数
|
||||
self.train_cnn = True
|
||||
self.train_ms_cnn = True
|
||||
self.train_ms_sc_cnn = True
|
||||
# 是否评估参数
|
||||
self.evaluate_cnn = True
|
||||
self.evaluate_ms_cnn = True
|
||||
self.evaluate_ms_sc_cnn = True
|
||||
# 要评估的保存好的模型列表
|
||||
self.evaluate_cnn_name_list = []
|
||||
self.evaluate_ms_cnn_name_list = []
|
||||
self.evaluate_ms_sc_cnn_name_list = []
|
||||
|
||||
# 存储训练出的模型和图片的文件夹
|
||||
self.img_dir = './pictures0331'
|
||||
self.checkpoint_dir = './check_points0331'
|
||||
|
||||
# 数据集选择
|
||||
self.data_set = './dataset_preprocess/corn/corn_mositure.mat'
|
||||
|
||||
def show_yourself(self, to_text_file=None):
|
||||
line_width = 36
|
||||
content = '\n'
|
||||
# create line
|
||||
line_text = 'Data Parameters'
|
||||
line = '='*((line_width-len(line_text))//2) + line_text + '='*((line_width-len(line_text))//2)
|
||||
line.ljust(line_width, '=')
|
||||
content += line + '\n'
|
||||
content += 'Validation Rate: ' + str(self.validation_rate) + '\n'
|
||||
# create line
|
||||
line_text = 'Training Parameters'
|
||||
line = '=' * ((line_width - len(line_text)) // 2) + line_text + '=' * ((line_width - len(line_text)) // 2)
|
||||
line.ljust(line_width, '=')
|
||||
content += line + '\n'
|
||||
content += 'Train CNN: ' + str(self.train_cnn) + '\n'
|
||||
content += 'Train Ms CNN: ' + str(self.train_ms_cnn) + '\n'
|
||||
content += 'Train Ms Sc CNN: ' + str(self.train_ms_sc_cnn) + '\n'
|
||||
# create line
|
||||
line_text = 'Evaluate Parameters'
|
||||
line = '=' * ((line_width - len(line_text)) // 2) + line_text + '=' * ((line_width - len(line_text)) // 2)
|
||||
line.ljust(line_width, '=')
|
||||
content += line + '\n'
|
||||
content += 'Train Epoch: ' + str(self.train_epoch) + '\n'
|
||||
content += 'Train Batch Size: ' + str(self.batch_size) + '\n'
|
||||
|
||||
content += 'Evaluate CNN: ' + str(self.evaluate_cnn) + '\n'
|
||||
if len(self.evaluate_cnn_name_list) >=1:
|
||||
content += 'Saved CNNs to Evaluate:\n'
|
||||
for models in self.evaluate_cnn_name_list:
|
||||
content += models + '\n'
|
||||
|
||||
content += 'Evaluate Ms CNN: ' + str(self.evaluate_ms_cnn) + '\n'
|
||||
if len(self.evaluate_ms_cnn_name_list) >= 1:
|
||||
content += 'Saved Ms CNNs to Evaluate:\n'
|
||||
for models in self.evaluate_ms_cnn_name_list:
|
||||
content += models + '\n'
|
||||
|
||||
content += 'Evaluate Ms Sc CNN: ' + str(self.evaluate_ms_cnn) + '\n'
|
||||
if len(self.evaluate_ms_sc_cnn_name_list) >= 1:
|
||||
content += 'Saved Ms Sc CNNs to Evaluate:\n'
|
||||
for models in self.evaluate_ms_sc_cnn_name_list:
|
||||
content += models + '\n'
|
||||
|
||||
# create line
|
||||
line_text = 'Saving Dir'
|
||||
line = '=' * ((line_width - len(line_text)) // 2) + line_text + '=' * ((line_width - len(line_text)) // 2)
|
||||
line.ljust(line_width, '=')
|
||||
content += line + '\n'
|
||||
content += 'Image Dir: ' + str(self.img_dir) + '\n'
|
||||
content += 'Check Point Dir: ' + str(self.img_dir) + '\n'
|
||||
print(content)
|
||||
if to_text_file:
|
||||
with open(to_text_file, 'w') as f:
|
||||
f.write(content)
|
||||
return content
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Config()
|
||||
config.show_yourself(to_text_file='name.txt')
|
||||
x_train, x_test, x_validation, y_train, y_test, y_validation, y_max_value, y_min_value = \
|
||||
load_data(data_path='./yaowan_calibrate.mat', validation_rate=0.25)
|
||||
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape, x_validation.shape, y_validation.shape,
|
||||
y_max_value, y_min_value)
|
||||
Loading…
Reference in New Issue
Block a user