mirror of
https://github.com/NanjingForestryUniversity/SCNet.git
synced 2025-11-09 06:44:05 +00:00
Add visualization
add visualization for our network
This commit is contained in:
parent
ec4296de9b
commit
19d8a7c588
File diff suppressed because one or more lines are too long
@ -3,7 +3,6 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
@ -15,6 +14,15 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -49,17 +57,20 @@
|
||||
" f\"x_train: {x_train.shape}, y_train: {y_train.shape},\\n\"\n",
|
||||
" f\"x_val: {x_val.shape}, y_val: {y_val.shape}\\n\"\n",
|
||||
" f\"x_test: {x_test.shape}, y_test: {y_test.shape}\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
@ -363,20 +374,20 @@
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
||||
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
||||
"\u001B[0;32m/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_68464/326725923.py\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mi\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mrange\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;36m2\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m1000\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mmodel\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mShortCut11\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnetwork_parameter\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mi\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0minput_shape\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m102\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 6\u001B[0;31m \u001B[0mhistory_shortcut_11\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfit\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx_train\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my_train\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mx_val\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my_val\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mepoch\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mepoch\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mbatch_size\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mbatch_size\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0msave\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"/tmp/temp.hdf5\"\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 7\u001B[0m \u001B[0mmodel\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mload_model\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"/tmp/temp.hdf5\"\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 8\u001B[0m \u001B[0my_pred\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mpredict\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx_test\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mreshape\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/PycharmProjects/sccnn/models.py\u001B[0m in \u001B[0;36mfit\u001B[0;34m(self, x, y, x_val, y_val, epoch, batch_size, save)\u001B[0m\n\u001B[1;32m 197\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 198\u001B[0m history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,\n\u001B[0;32m--> 199\u001B[0;31m callbacks=callbacks, batch_size=batch_size)\n\u001B[0m\u001B[1;32m 200\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0mhistory\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 201\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/keras/utils/traceback_utils.py\u001B[0m in \u001B[0;36merror_handler\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 62\u001B[0m \u001B[0mfiltered_tb\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 63\u001B[0m \u001B[0;32mtry\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 64\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mfn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 65\u001B[0m \u001B[0;32mexcept\u001B[0m \u001B[0mException\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0;31m# pylint: disable=broad-except\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 66\u001B[0m \u001B[0mfiltered_tb\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0m_process_traceback_frames\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0me\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__traceback__\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/keras/engine/training.py\u001B[0m in \u001B[0;36mfit\u001B[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001B[0m\n\u001B[1;32m 1214\u001B[0m _r=1):\n\u001B[1;32m 1215\u001B[0m \u001B[0mcallbacks\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mon_train_batch_begin\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mstep\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m-> 1216\u001B[0;31m \u001B[0mtmp_logs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrain_function\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0miterator\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 1217\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mdata_handler\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshould_sync\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 1218\u001B[0m \u001B[0mcontext\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0masync_wait\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\u001B[0m in \u001B[0;36merror_handler\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 148\u001B[0m \u001B[0mfiltered_tb\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 149\u001B[0m \u001B[0;32mtry\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 150\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mfn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 151\u001B[0m \u001B[0;32mexcept\u001B[0m \u001B[0mException\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 152\u001B[0m \u001B[0mfiltered_tb\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0m_process_traceback_frames\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0me\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__traceback__\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py\u001B[0m in \u001B[0;36m__call__\u001B[0;34m(self, *args, **kwds)\u001B[0m\n\u001B[1;32m 908\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 909\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0mOptionalXlaContext\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_jit_compile\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 910\u001B[0;31m \u001B[0mresult\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_call\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwds\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 911\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 912\u001B[0m \u001B[0mnew_tracing_count\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mexperimental_get_tracing_count\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py\u001B[0m in \u001B[0;36m_call\u001B[0;34m(self, *args, **kwds)\u001B[0m\n\u001B[1;32m 940\u001B[0m \u001B[0;31m# In this case we have created variables on the first call, so we run the\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 941\u001B[0m \u001B[0;31m# defunned version which is guaranteed to never create variables.\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 942\u001B[0;31m \u001B[0;32mreturn\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_stateless_fn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m**\u001B[0m\u001B[0mkwds\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;31m# pylint: disable=not-callable\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 943\u001B[0m \u001B[0;32melif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_stateful_fn\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 944\u001B[0m \u001B[0;31m# Release the lock early so that multiple threads can perform the call\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/function.py\u001B[0m in \u001B[0;36m__call__\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 3128\u001B[0m (graph_function,\n\u001B[1;32m 3129\u001B[0m filtered_flat_args) = self._maybe_define_function(args, kwargs)\n\u001B[0;32m-> 3130\u001B[0;31m return graph_function._call_flat(\n\u001B[0m\u001B[1;32m 3131\u001B[0m filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access\n\u001B[1;32m 3132\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/function.py\u001B[0m in \u001B[0;36m_call_flat\u001B[0;34m(self, args, captured_inputs, cancellation_manager)\u001B[0m\n\u001B[1;32m 1957\u001B[0m and executing_eagerly):\n\u001B[1;32m 1958\u001B[0m \u001B[0;31m# No tape is watching; skip to running the function.\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m-> 1959\u001B[0;31m return self._build_call_outputs(self._inference_function.call(\n\u001B[0m\u001B[1;32m 1960\u001B[0m ctx, args, cancellation_manager=cancellation_manager))\n\u001B[1;32m 1961\u001B[0m forward_backward = self._select_forward_and_backward_functions(\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/function.py\u001B[0m in \u001B[0;36mcall\u001B[0;34m(self, ctx, args, cancellation_manager)\u001B[0m\n\u001B[1;32m 596\u001B[0m \u001B[0;32mwith\u001B[0m \u001B[0m_InterpolateFunctionError\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 597\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mcancellation_manager\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 598\u001B[0;31m outputs = execute.execute(\n\u001B[0m\u001B[1;32m 599\u001B[0m \u001B[0mstr\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msignature\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mname\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 600\u001B[0m \u001B[0mnum_outputs\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_num_outputs\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/execute.py\u001B[0m in \u001B[0;36mquick_execute\u001B[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001B[0m\n\u001B[1;32m 56\u001B[0m \u001B[0;32mtry\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 57\u001B[0m \u001B[0mctx\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mensure_initialized\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 58\u001B[0;31m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001B[0m\u001B[1;32m 59\u001B[0m inputs, attrs, num_outputs)\n\u001B[1;32m 60\u001B[0m \u001B[0;32mexcept\u001B[0m \u001B[0mcore\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_NotOkStatusException\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_68464/326725923.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mShortCut11\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnetwork_parameter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m102\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mhistory_shortcut_11\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"/tmp/temp.hdf5\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/tmp/temp.hdf5\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/PycharmProjects/sccnn/models.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, x_val, y_val, epoch, batch_size, save)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,\n\u001b[0;32m--> 199\u001b[0;31m callbacks=callbacks, batch_size=batch_size)\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhistory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 65\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m 1214\u001b[0m _r=1):\n\u001b[1;32m 1215\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1216\u001b[0;31m \u001b[0mtmp_logs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1217\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshould_sync\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masync_wait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 151\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m 908\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 909\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mOptionalXlaContext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_compile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 910\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 911\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 912\u001b[0m \u001b[0mnew_tracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[0;31m# In this case we have created variables on the first call, so we run the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 941\u001b[0m \u001b[0;31m# defunned version which is guaranteed to never create variables.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 942\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stateless_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# pylint: disable=not-callable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 943\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 944\u001b[0m \u001b[0;31m# Release the lock early so that multiple threads can perform the call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3128\u001b[0m (graph_function,\n\u001b[1;32m 3129\u001b[0m filtered_flat_args) = self._maybe_define_function(args, kwargs)\n\u001b[0;32m-> 3130\u001b[0;31m return graph_function._call_flat(\n\u001b[0m\u001b[1;32m 3131\u001b[0m filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access\n\u001b[1;32m 3132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m 1957\u001b[0m and executing_eagerly):\n\u001b[1;32m 1958\u001b[0m \u001b[0;31m# No tape is watching; skip to running the function.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1959\u001b[0;31m return self._build_call_outputs(self._inference_function.call(\n\u001b[0m\u001b[1;32m 1960\u001b[0m ctx, args, cancellation_manager=cancellation_manager))\n\u001b[1;32m 1961\u001b[0m forward_backward = self._select_forward_and_backward_functions(\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[1;32m 596\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0m_InterpolateFunctionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcancellation_manager\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 598\u001b[0;31m outputs = execute.execute(\n\u001b[0m\u001b[1;32m 599\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mnum_outputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_outputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniforge3/lib/python3.9/site-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0m\u001b[1;32m 59\u001b[0m inputs, attrs, num_outputs)\n\u001b[1;32m 60\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -386,53 +397,51 @@
|
||||
"\n",
|
||||
"for i in range(2, 500):\n",
|
||||
" model = ShortCut11(network_parameter=i, input_shape=(1, 102))\n",
|
||||
" history_shortcut_11 = model.fit(x_train, y_train, x_val, y_val, epoch=epoch, batch_size=batch_size, save=\"/tmp/temp.hdf5\")\n",
|
||||
" history_shortcut_11 = model.fit(x_train, y_train, x_val, y_val, epoch=epoch, batch_size=batch_size, save=\"/tmp/temp.hdf5\", is_show=False)\n",
|
||||
" model = load_model(\"/tmp/temp.hdf5\")\n",
|
||||
" y_pred = model.predict(x_test).reshape((-1, ))\n",
|
||||
" model_parameter_optimization['neuron num'].append(i)\n",
|
||||
" model_parameter_optimization['r2'].append(r2_score(y_test, y_pred))\n",
|
||||
" model_parameter_optimization['rmse'].append(mean_squared_error(y_test, y_pred))\n",
|
||||
" print(f\"model with parameter {i}: r2: {model_parameter_optimization['r2'][-1]}, rmse: {model_parameter_optimization['rmse'][-1]}\")\n",
|
||||
"pd.DataFrame(model_parameter_optimization).to_csv(\"./dataset/test_result.csv\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Shades",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "shades"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
54
05_network_parameter_optimization.py
Normal file
54
05_network_parameter_optimization.py
Normal file
@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# # Network Parameter Optimization
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from keras.models import load_model
|
||||
from sklearn.metrics import r2_score, mean_squared_error
|
||||
from sklearn.model_selection import train_test_split
|
||||
from scipy.io import loadmat
|
||||
from models import ShortCut11
|
||||
from numpy.random import seed
|
||||
import tensorflow
|
||||
import time
|
||||
seed(4750)
|
||||
tensorflow.random.set_seed(4750)
|
||||
time1 = time.time()
|
||||
data = loadmat('./dataset/mango/mango_dm_split.mat')
|
||||
x_train, y_train, x_test, y_test = data['x_train'], data['y_train'], data['x_test'], data['y_test']
|
||||
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.3, random_state=12, shuffle=True)
|
||||
x_train, x_val, x_test = x_train[:, np.newaxis, :], x_val[:, np.newaxis, :], x_test[:, np.newaxis, :]
|
||||
print(f"shape of data:\n"
|
||||
f"x_train: {x_train.shape}, y_train: {y_train.shape},\n"
|
||||
f"x_val: {x_val.shape}, y_val: {y_val.shape}\n"
|
||||
f"x_test: {x_test.shape}, y_test: {y_test.shape}")
|
||||
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
model_parameter_optimization = {"neuron num":[], "r2":[], "rmse":[]}
|
||||
epoch, batch_size = 1024, 64
|
||||
|
||||
for i in range(2, 500):
|
||||
model = ShortCut11(network_parameter=i, input_shape=(1, 102))
|
||||
history_shortcut_11 = model.fit(x_train, y_train, x_val, y_val, epoch=epoch, batch_size=batch_size, save="/tmp/temp.hdf5", is_show=False)
|
||||
model = load_model("/tmp/temp.hdf5")
|
||||
y_pred = model.predict(x_test).reshape((-1, ))
|
||||
model_parameter_optimization['neuron num'].append(i)
|
||||
model_parameter_optimization['r2'].append(r2_score(y_test, y_pred))
|
||||
model_parameter_optimization['rmse'].append(mean_squared_error(y_test, y_pred))
|
||||
print(f"model with parameter {i}: r2: {model_parameter_optimization['r2'][-1]}, rmse: {model_parameter_optimization['rmse'][-1]}")
|
||||
pd.DataFrame(model_parameter_optimization).to_csv("./dataset/test_result.csv")
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
|
||||
|
||||
BIN
assets/shortcut5.png
Normal file
BIN
assets/shortcut5.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 127 KiB |
@ -182,7 +182,7 @@ class ShortCut11(object):
|
||||
model = Model(input_layer, x)
|
||||
return model
|
||||
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size, save='checkpoints/shortcut11.hdf5'):
|
||||
def fit(self, x, y, x_val, y_val, epoch, batch_size, save='checkpoints/shortcut11.hdf5', is_show=True):
|
||||
self.model.compile(loss='mse', optimizer=adam_v2.Adam(learning_rate=0.01 * (batch_size / 256)))
|
||||
callbacks = []
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(filepath=save, monitor='val_loss',
|
||||
@ -194,8 +194,8 @@ class ShortCut11(object):
|
||||
patience=25, min_delta=1e-6)
|
||||
callbacks.append(early_stop)
|
||||
callbacks.append(lr_decay)
|
||||
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=1,
|
||||
verbose_num = 1 if is_show else 0
|
||||
history = self.model.fit(x, y, validation_data=(x_val, y_val), epochs=epoch, verbose=verbose_num,
|
||||
callbacks=callbacks, batch_size=batch_size)
|
||||
return history
|
||||
|
||||
|
||||
88
preprocess/draw_pics_origin.m
Executable file → Normal file
88
preprocess/draw_pics_origin.m
Executable file → Normal file
@ -1,45 +1,45 @@
|
||||
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||
f = figure;
|
||||
f.Position(3:4) = [1331 331];
|
||||
%%% draw the pic of corn spectra
|
||||
load('dataset/corn.mat');
|
||||
x = m5spec.data;
|
||||
wave_length = m5spec.axisscale{2, 1};
|
||||
subplot(1, 4, 1)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/marzipan.mat');
|
||||
x = NIRS1;
|
||||
wave_length = NIRS1_axis;
|
||||
subplot(1, 4, 2)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/soil.mat');
|
||||
x = soil.data;
|
||||
wave_length = soil.axisscale{2, 1};
|
||||
subplot(1, 4, 3)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
% draw the pic of Mango spectra
|
||||
load('dataset/mango/mango_origin.mat');
|
||||
wave_length = 684: 3: 990;
|
||||
subplot(1, 4, 4)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Signal intensity');
|
||||
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||
f = figure;
|
||||
f.Position(3:4) = [1331 331];
|
||||
%%% draw the pic of corn spectra
|
||||
load('dataset/corn.mat');
|
||||
x = m5spec.data;
|
||||
wave_length = m5spec.axisscale{2, 1};
|
||||
subplot(1, 4, 1)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/marzipan.mat');
|
||||
x = NIRS1;
|
||||
wave_length = NIRS1_axis;
|
||||
subplot(1, 4, 2)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/soil.mat');
|
||||
x = soil.data;
|
||||
wave_length = soil.axisscale{2, 1};
|
||||
subplot(1, 4, 3)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
% draw the pic of Mango spectra
|
||||
load('dataset/mango/mango_origin.mat');
|
||||
wave_length = 684: 3: 990;
|
||||
subplot(1, 4, 4)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Signal intensity');
|
||||
clear
|
||||
94
preprocess/draw_pics_preprocessed.m
Executable file → Normal file
94
preprocess/draw_pics_preprocessed.m
Executable file → Normal file
@ -1,48 +1,48 @@
|
||||
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||
f = figure;
|
||||
f.Position(3:4) = [1331 331];
|
||||
%%% draw the pic of corn spectra
|
||||
load('dataset/corn.mat');
|
||||
x = m5spec.data;
|
||||
wave_length = m5spec.axisscale{2, 1};
|
||||
preprocess;
|
||||
subplot(1, 4, 1)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/marzipan.mat');
|
||||
x = NIRS1;
|
||||
wave_length = NIRS1_axis;
|
||||
preprocess;
|
||||
subplot(1, 4, 2)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/soil.mat');
|
||||
x = soil.data;
|
||||
wave_length = soil.axisscale{2, 1};
|
||||
preprocess;
|
||||
subplot(1, 4, 3)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
% draw the pic of Mango spectra
|
||||
load('dataset/mango/mango_preprocessed.mat');
|
||||
wave_length = 687: 3: 990;
|
||||
subplot(1, 4, 4)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Signal intensity');
|
||||
set(gca,'LooseInset',get(gca,'TightInset'))
|
||||
f = figure;
|
||||
f.Position(3:4) = [1331 331];
|
||||
%%% draw the pic of corn spectra
|
||||
load('dataset/corn.mat');
|
||||
x = m5spec.data;
|
||||
wave_length = m5spec.axisscale{2, 1};
|
||||
preprocess;
|
||||
subplot(1, 4, 1)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/marzipan.mat');
|
||||
x = NIRS1;
|
||||
wave_length = NIRS1_axis;
|
||||
preprocess;
|
||||
subplot(1, 4, 2)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
%%% draw the pic of Marzipan spectra
|
||||
load('dataset/soil.mat');
|
||||
x = soil.data;
|
||||
wave_length = soil.axisscale{2, 1};
|
||||
preprocess;
|
||||
subplot(1, 4, 3)
|
||||
plot(wave_length(1, 1:end-1), x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Absorbance');
|
||||
clear
|
||||
|
||||
% draw the pic of Mango spectra
|
||||
load('dataset/mango/mango_preprocessed.mat');
|
||||
wave_length = 687: 3: 990;
|
||||
subplot(1, 4, 4)
|
||||
plot(wave_length, x');
|
||||
xlim([wave_length(1) wave_length(end)]);
|
||||
xlabel('Wavelength(nm)');
|
||||
ylabel('Signal intensity');
|
||||
clear
|
||||
16
preprocess/preprocess.m
Executable file → Normal file
16
preprocess/preprocess.m
Executable file → Normal file
@ -1,8 +1,8 @@
|
||||
%% x preprocessing
|
||||
x = x';
|
||||
x = sgolayfilt(x,2,17);
|
||||
x =diff(x);
|
||||
max_x=max(max(x));
|
||||
min_x=min(min(x));
|
||||
x=(x-min_x)/(max_x-min_x);
|
||||
x = x';
|
||||
%% x preprocessing
|
||||
x = x';
|
||||
x = sgolayfilt(x,2,17);
|
||||
x =diff(x);
|
||||
max_x=max(max(x));
|
||||
min_x=min(min(x));
|
||||
x=(x-min_x)/(max_x-min_x);
|
||||
x = x';
|
||||
|
||||
30
preprocess/preprocess_mango.m
Executable file → Normal file
30
preprocess/preprocess_mango.m
Executable file → Normal file
@ -1,15 +1,15 @@
|
||||
%% x preprocessing
|
||||
clear;
|
||||
load('dataset/mango/mango_origin.mat')
|
||||
x = x';
|
||||
x = sgolayfilt(x,2,17);
|
||||
x =diff(x);
|
||||
max_x=max(max(x));
|
||||
min_x=min(min(x));
|
||||
x=(x-min_x)/(max_x-min_x);
|
||||
x = x';
|
||||
y = y';
|
||||
min_y = min(min(y));
|
||||
max_y = max(max(y));
|
||||
y = (y-min_y)/(max_y-min_y);
|
||||
save('dataset/mango/mango_preprocessed.mat')
|
||||
%% x preprocessing
|
||||
clear;
|
||||
load('dataset/mango/mango_origin.mat')
|
||||
x = x';
|
||||
x = sgolayfilt(x,2,17);
|
||||
x =diff(x);
|
||||
max_x=max(max(x));
|
||||
min_x=min(min(x));
|
||||
x=(x-min_x)/(max_x-min_x);
|
||||
x = x';
|
||||
y = y';
|
||||
min_y = min(min(y));
|
||||
max_y = max(max(y));
|
||||
y = (y-min_y)/(max_y-min_y);
|
||||
save('dataset/mango/mango_preprocessed.mat')
|
||||
|
||||
30
preprocess/train_test_split.m
Executable file → Normal file
30
preprocess/train_test_split.m
Executable file → Normal file
@ -1,15 +1,15 @@
|
||||
data=[x,y];
|
||||
test_rate = 0.3;
|
||||
data_num = size(x, 1);
|
||||
train_num = round((1-test_rate) * data_num);
|
||||
idx=randperm(data_num);
|
||||
train_idx=idx(1:train_num);
|
||||
test_idx=idx(train_num+1:data_num);
|
||||
data_train=data(train_idx,:);
|
||||
x_train=data_train(:,1:size(x, 2));
|
||||
y_train=data_train(:,size(x, 2)+1);
|
||||
test_data=data(test_idx,:);
|
||||
x_test=test_data(:,1:size(x, 2));
|
||||
y_test=test_data(:,size(x, 2)+1);
|
||||
clear data_num train_num idx train_idx test_idx test_data train_data x y;
|
||||
clear data data_train test_rate;
|
||||
data=[x,y];
|
||||
test_rate = 0.3;
|
||||
data_num = size(x, 1);
|
||||
train_num = round((1-test_rate) * data_num);
|
||||
idx=randperm(data_num);
|
||||
train_idx=idx(1:train_num);
|
||||
test_idx=idx(train_num+1:data_num);
|
||||
data_train=data(train_idx,:);
|
||||
x_train=data_train(:,1:size(x, 2));
|
||||
y_train=data_train(:,size(x, 2)+1);
|
||||
test_data=data(test_idx,:);
|
||||
x_test=test_data(:,1:size(x, 2));
|
||||
y_test=test_data(:,size(x, 2)+1);
|
||||
clear data_num train_num idx train_idx test_idx test_data train_data x y;
|
||||
clear data data_train test_rate;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user