mirror of
https://github.com/NanjingForestryUniversity/tobacoo-industry.git
synced 2025-11-08 14:23:53 +00:00
First Commit
This commit is contained in:
commit
da5fde8f40
577
.gitignore
vendored
Normal file
577
.gitignore
vendored
Normal file
@ -0,0 +1,577 @@
|
||||
### Project sepecified
|
||||
dataset/*
|
||||
models/*
|
||||
|
||||
|
||||
### VisualStudioCode template
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
!.vscode/launch.json
|
||||
!.vscode/extensions.json
|
||||
*.code-workspace
|
||||
|
||||
# Local History for Visual Studio Code
|
||||
.history/
|
||||
|
||||
### VirtualEnv template
|
||||
# Virtualenv
|
||||
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
||||
.Python
|
||||
[Bb]in
|
||||
[Ii]nclude
|
||||
[Ll]ib
|
||||
[Ll]ib64
|
||||
[Ll]ocal
|
||||
[Ss]cripts
|
||||
pyvenv.cfg
|
||||
.venv
|
||||
pip-selfcheck.json
|
||||
|
||||
### Example user template template
|
||||
### Example user template
|
||||
|
||||
# IntelliJ project files
|
||||
.idea
|
||||
*.iml
|
||||
out
|
||||
gen
|
||||
### VisualStudio template
|
||||
## Ignore Visual Studio temporary files, build results, and
|
||||
## files generated by popular Visual Studio add-ons.
|
||||
##
|
||||
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
||||
|
||||
# User-specific files
|
||||
*.rsuser
|
||||
*.suo
|
||||
*.user
|
||||
*.userosscache
|
||||
*.sln.docstates
|
||||
|
||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||
*.userprefs
|
||||
|
||||
# Mono auto generated files
|
||||
mono_crash.*
|
||||
|
||||
# Build results
|
||||
[Dd]ebug/
|
||||
[Dd]ebugPublic/
|
||||
[Rr]elease/
|
||||
[Rr]eleases/
|
||||
x64/
|
||||
x86/
|
||||
[Ww][Ii][Nn]32/
|
||||
[Aa][Rr][Mm]/
|
||||
[Aa][Rr][Mm]64/
|
||||
bld/
|
||||
[Bb]in/
|
||||
[Oo]bj/
|
||||
[Ll]og/
|
||||
[Ll]ogs/
|
||||
|
||||
# Visual Studio 2015/2017 cache/options directory
|
||||
.vs/
|
||||
# Uncomment if you have tasks that create the project's static files in wwwroot
|
||||
#wwwroot/
|
||||
|
||||
# Visual Studio 2017 auto generated files
|
||||
Generated\ Files/
|
||||
|
||||
# MSTest test Results
|
||||
[Tt]est[Rr]esult*/
|
||||
[Bb]uild[Ll]og.*
|
||||
|
||||
# NUnit
|
||||
*.VisualState.xml
|
||||
TestResult.xml
|
||||
nunit-*.xml
|
||||
|
||||
# Build Results of an ATL Project
|
||||
[Dd]ebugPS/
|
||||
[Rr]eleasePS/
|
||||
dlldata.c
|
||||
|
||||
# Benchmark Results
|
||||
BenchmarkDotNet.Artifacts/
|
||||
|
||||
# .NET Core
|
||||
project.lock.json
|
||||
project.fragment.lock.json
|
||||
artifacts/
|
||||
|
||||
# ASP.NET Scaffolding
|
||||
ScaffoldingReadMe.txt
|
||||
|
||||
# StyleCop
|
||||
StyleCopReport.xml
|
||||
|
||||
# Files built by Visual Studio
|
||||
*_i.c
|
||||
*_p.c
|
||||
*_h.h
|
||||
*.ilk
|
||||
*.meta
|
||||
*.obj
|
||||
*.iobj
|
||||
*.pch
|
||||
*.pdb
|
||||
*.ipdb
|
||||
*.pgc
|
||||
*.pgd
|
||||
*.rsp
|
||||
*.sbr
|
||||
*.tlb
|
||||
*.tli
|
||||
*.tlh
|
||||
*.tmp
|
||||
*.tmp_proj
|
||||
*_wpftmp.csproj
|
||||
*.log
|
||||
*.vspscc
|
||||
*.vssscc
|
||||
.builds
|
||||
*.pidb
|
||||
*.svclog
|
||||
*.scc
|
||||
|
||||
# Chutzpah Test files
|
||||
_Chutzpah*
|
||||
|
||||
# Visual C++ cache files
|
||||
ipch/
|
||||
*.aps
|
||||
*.ncb
|
||||
*.opendb
|
||||
*.opensdf
|
||||
*.sdf
|
||||
*.cachefile
|
||||
*.VC.db
|
||||
*.VC.VC.opendb
|
||||
|
||||
# Visual Studio profiler
|
||||
*.psess
|
||||
*.vsp
|
||||
*.vspx
|
||||
*.sap
|
||||
|
||||
# Visual Studio Trace Files
|
||||
*.e2e
|
||||
|
||||
# TFS 2012 Local Workspace
|
||||
$tf/
|
||||
|
||||
# Guidance Automation Toolkit
|
||||
*.gpState
|
||||
|
||||
# ReSharper is a .NET coding add-in
|
||||
_ReSharper*/
|
||||
*.[Rr]e[Ss]harper
|
||||
*.DotSettings.user
|
||||
|
||||
# TeamCity is a build add-in
|
||||
_TeamCity*
|
||||
|
||||
# DotCover is a Code Coverage Tool
|
||||
*.dotCover
|
||||
|
||||
# AxoCover is a Code Coverage Tool
|
||||
.axoCover/*
|
||||
!.axoCover/settings.json
|
||||
|
||||
# Coverlet is a free, cross platform Code Coverage Tool
|
||||
coverage*.json
|
||||
coverage*.xml
|
||||
coverage*.info
|
||||
|
||||
# Visual Studio code coverage results
|
||||
*.coverage
|
||||
*.coveragexml
|
||||
|
||||
# NCrunch
|
||||
_NCrunch_*
|
||||
.*crunch*.local.xml
|
||||
nCrunchTemp_*
|
||||
|
||||
# MightyMoose
|
||||
*.mm.*
|
||||
AutoTest.Net/
|
||||
|
||||
# Web workbench (sass)
|
||||
.sass-cache/
|
||||
|
||||
# Installshield output folder
|
||||
[Ee]xpress/
|
||||
|
||||
# DocProject is a documentation generator add-in
|
||||
DocProject/buildhelp/
|
||||
DocProject/Help/*.HxT
|
||||
DocProject/Help/*.HxC
|
||||
DocProject/Help/*.hhc
|
||||
DocProject/Help/*.hhk
|
||||
DocProject/Help/*.hhp
|
||||
DocProject/Help/Html2
|
||||
DocProject/Help/html
|
||||
|
||||
# Click-Once directory
|
||||
publish/
|
||||
|
||||
# Publish Web Output
|
||||
*.[Pp]ublish.xml
|
||||
*.azurePubxml
|
||||
# Note: Comment the next line if you want to checkin your web deploy settings,
|
||||
# but database connection strings (with potential passwords) will be unencrypted
|
||||
*.pubxml
|
||||
*.publishproj
|
||||
|
||||
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
||||
# checkin your Azure Web App publish settings, but sensitive information contained
|
||||
# in these scripts will be unencrypted
|
||||
PublishScripts/
|
||||
|
||||
# NuGet Packages
|
||||
*.nupkg
|
||||
# NuGet Symbol Packages
|
||||
*.snupkg
|
||||
# The packages folder can be ignored because of Package Restore
|
||||
**/[Pp]ackages/*
|
||||
# except build/, which is used as an MSBuild target.
|
||||
!**/[Pp]ackages/build/
|
||||
# Uncomment if necessary however generally it will be regenerated when needed
|
||||
#!**/[Pp]ackages/repositories.config
|
||||
# NuGet v3's project.json files produces more ignorable files
|
||||
*.nuget.props
|
||||
*.nuget.targets
|
||||
|
||||
# Microsoft Azure Build Output
|
||||
csx/
|
||||
*.build.csdef
|
||||
|
||||
# Microsoft Azure Emulator
|
||||
ecf/
|
||||
rcf/
|
||||
|
||||
# Windows Store app package directories and files
|
||||
AppPackages/
|
||||
BundleArtifacts/
|
||||
Package.StoreAssociation.xml
|
||||
_pkginfo.txt
|
||||
*.appx
|
||||
*.appxbundle
|
||||
*.appxupload
|
||||
|
||||
# Visual Studio cache files
|
||||
# files ending in .cache can be ignored
|
||||
*.[Cc]ache
|
||||
# but keep track of directories ending in .cache
|
||||
!?*.[Cc]ache/
|
||||
|
||||
# Others
|
||||
ClientBin/
|
||||
~$*
|
||||
*~
|
||||
*.dbmdl
|
||||
*.dbproj.schemaview
|
||||
*.jfm
|
||||
*.pfx
|
||||
*.publishsettings
|
||||
orleans.codegen.cs
|
||||
|
||||
# Including strong name files can present a security risk
|
||||
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
||||
#*.snk
|
||||
|
||||
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
||||
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
||||
#bower_components/
|
||||
|
||||
# RIA/Silverlight projects
|
||||
Generated_Code/
|
||||
|
||||
# Backup & report files from converting an old project file
|
||||
# to a newer Visual Studio version. Backup files are not needed,
|
||||
# because we have git ;-)
|
||||
_UpgradeReport_Files/
|
||||
Backup*/
|
||||
UpgradeLog*.XML
|
||||
UpgradeLog*.htm
|
||||
ServiceFabricBackup/
|
||||
*.rptproj.bak
|
||||
|
||||
# SQL Server files
|
||||
*.mdf
|
||||
*.ldf
|
||||
*.ndf
|
||||
|
||||
# Business Intelligence projects
|
||||
*.rdl.data
|
||||
*.bim.layout
|
||||
*.bim_*.settings
|
||||
*.rptproj.rsuser
|
||||
*- [Bb]ackup.rdl
|
||||
*- [Bb]ackup ([0-9]).rdl
|
||||
*- [Bb]ackup ([0-9][0-9]).rdl
|
||||
|
||||
# Microsoft Fakes
|
||||
FakesAssemblies/
|
||||
|
||||
# GhostDoc plugin setting file
|
||||
*.GhostDoc.xml
|
||||
|
||||
# Node.js Tools for Visual Studio
|
||||
.ntvs_analysis.dat
|
||||
node_modules/
|
||||
|
||||
# Visual Studio 6 build log
|
||||
*.plg
|
||||
|
||||
# Visual Studio 6 workspace options file
|
||||
*.opt
|
||||
|
||||
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
||||
*.vbw
|
||||
|
||||
# Visual Studio LightSwitch build output
|
||||
**/*.HTMLClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/ModelManifest.xml
|
||||
**/*.Server/GeneratedArtifacts
|
||||
**/*.Server/ModelManifest.xml
|
||||
_Pvt_Extensions
|
||||
|
||||
# Paket dependency manager
|
||||
.paket/paket.exe
|
||||
paket-files/
|
||||
|
||||
# FAKE - F# Make
|
||||
.fake/
|
||||
|
||||
# CodeRush personal settings
|
||||
.cr/personal
|
||||
|
||||
# Python Tools for Visual Studio (PTVS)
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Cake - Uncomment if you are using it
|
||||
# tools/**
|
||||
# !tools/packages.config
|
||||
|
||||
# Tabs Studio
|
||||
*.tss
|
||||
|
||||
# Telerik's JustMock configuration file
|
||||
*.jmconfig
|
||||
|
||||
# BizTalk build output
|
||||
*.btp.cs
|
||||
*.btm.cs
|
||||
*.odx.cs
|
||||
*.xsd.cs
|
||||
|
||||
# OpenCover UI analysis results
|
||||
OpenCover/
|
||||
|
||||
# Azure Stream Analytics local run output
|
||||
ASALocalRun/
|
||||
|
||||
# MSBuild Binary and Structured Log
|
||||
*.binlog
|
||||
|
||||
# NVidia Nsight GPU debugger configuration file
|
||||
*.nvuser
|
||||
|
||||
# MFractors (Xamarin productivity tool) working folder
|
||||
.mfractor/
|
||||
|
||||
# Local History for Visual Studio
|
||||
.localhistory/
|
||||
|
||||
# BeatPulse healthcheck temp database
|
||||
healthchecksdb
|
||||
|
||||
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
||||
MigrationBackup/
|
||||
|
||||
# Ionide (cross platform F# VS Code tools) working folder
|
||||
.ionide/
|
||||
|
||||
# Fody - auto-generated XML schema
|
||||
FodyWeavers.xsd
|
||||
|
||||
### JupyterNotebooks template
|
||||
# gitignore template for Jupyter Notebooks
|
||||
# website: http://jupyter.org/
|
||||
|
||||
.ipynb_checkpoints
|
||||
*/.ipynb_checkpoints/*
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# Remove previous ipynb_checkpoints
|
||||
# git rm -r .ipynb_checkpoints/
|
||||
|
||||
### macOS template
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### Python template
|
||||
# Byte-compiled / optimized / DLL files
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
|
||||
# IPython
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
149
01_dataset_building.ipynb
Executable file
149
01_dataset_building.ipynb
Executable file
@ -0,0 +1,149 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# 数据集的制作"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"from utils import split_xy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 一些参数\n",
|
||||
"blk_sz = 8\n",
|
||||
"sensitivity = 8\n",
|
||||
"selected_bands = [127, 201, 202, 294]\n",
|
||||
"# [76, 146, 216, 367, 383, 406]\n",
|
||||
"file_name, labeled_image_file = r\"F:\\zhouchao\\616\\calibrated1.raw\", \\\n",
|
||||
"r\"F:\\zhouchao\\616\\label1.bmp\"\n",
|
||||
"# file_name, labeled_image_file = \"./dataset/calibrated77.raw\", \"./dataset/label77.png\"\n",
|
||||
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_1.p'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"## 波长选择"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(file_name, \"rb\") as f:\n",
|
||||
" data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, 448, 1024)).transpose(0, 2, 1)\n",
|
||||
"data = data[..., selected_bands]\n",
|
||||
"label = cv2.imread(labeled_image_file)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## 块分割与数据存储"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
|
||||
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
|
||||
"x, y = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict)\n",
|
||||
"with open(dataset_file, 'wb') as f:\n",
|
||||
" pickle.dump((x, y), f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.7 ('base')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "7f619fc91ee8bdab81d49e7c14228037474662e3f2d607687ae505108922fa06"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
471
02_feature_design&training.ipynb
Executable file
471
02_feature_design&training.ipynb
Executable file
File diff suppressed because one or more lines are too long
301
03_data_update.ipynb
Executable file
301
03_data_update.ipynb
Executable file
@ -0,0 +1,301 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# 数据集扩充"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"虽然当前的模型已经能够达到较好的效果,但是还不够好,对于一些较老的烟梗不能够做到有效的判别,我们为此增加数据集。"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import cv2\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from utils import read_raw_file, split_xy, generate_tobacco_label, generate_impurity_label\n",
|
||||
"from models import SpecDetector\n",
|
||||
"import pickle"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# some parameters\n",
|
||||
"new_spectra_file = r\"F:\\zhouchao\\618\\bomo\\calibrated7.raw\"\n",
|
||||
"new_label_file = r\"F:\\zhouchao\\618\\bomo\\picture\\label7.bmp\"\n",
|
||||
"\n",
|
||||
"target_class = 0\n",
|
||||
"target_class_left, target_class_right = 5, 4\n",
|
||||
"light_threshold = 0.5\n",
|
||||
"add_background = False\n",
|
||||
"\n",
|
||||
"split_line = 500\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"blk_sz, sensitivity = 8, 8\n",
|
||||
"selected_bands = [127, 201, 202, 294]\n",
|
||||
"tree_num = 185\n",
|
||||
"\n",
|
||||
"pic_row, pic_col= 600, 1024\n",
|
||||
"\n",
|
||||
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
|
||||
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
|
||||
"\n",
|
||||
"new_dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_7.p'\n",
|
||||
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_6.p'\n",
|
||||
"\n",
|
||||
"model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_6.model'\n",
|
||||
"selected_bands = None"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# if len(new_spectra_files) == 1:\n",
|
||||
"data = read_raw_file(new_spectra_file, selected_bands)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 烟梗标签生成\n",
|
||||
"这会将纯烟梗图片中识别为杂质的部分提取出来"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_list, y_list = [], []\n",
|
||||
"if (new_label_file is None) and (target_class == 1):\n",
|
||||
" x_list, y_list = generate_tobacco_label(data, model_file, blk_sz, selected_bands)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 其他类别杂质阈值分割\n",
|
||||
"通过阈值分割的形式获取其他类别的杂质"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if (new_label_file is None) and (target_class != 1):\n",
|
||||
" img = generate_impurity_label(data, light_threshold, color_dict,\n",
|
||||
" target_class_right=target_class_right,\n",
|
||||
" target_class_left=target_class_left,\n",
|
||||
" split_line=split_line)\n",
|
||||
" root, _ = os.path.splitext(new_dataset_file)\n",
|
||||
" cv2.imwrite(root+\"_generated.bmp\", img)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 读取标签"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(600, 1024, 3)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"if new_label_file is not None:\n",
|
||||
" label = cv2.imread(new_label_file)\n",
|
||||
" print(label.shape)\n",
|
||||
" x_list, y_list = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict, add_background=add_background)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"148 148\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(len(x_list), len(y_list))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 读取旧数据合并"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(dataset_file, 'rb') as f:\n",
|
||||
" x, y = pickle.load(f)\n",
|
||||
"x.extend(x_list)\n",
|
||||
"y.extend(y_list)\n",
|
||||
"with open(new_dataset_file, 'wb') as f:\n",
|
||||
" pickle.dump((x, y), f)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 批量数据的处理"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
398
04_performance_tune.ipynb
Executable file
398
04_performance_tune.ipynb
Executable file
File diff suppressed because one or more lines are too long
225
05_evaluation.ipynb
Executable file
225
05_evaluation.ipynb
Executable file
File diff suppressed because one or more lines are too long
34
main.py
Executable file
34
main.py
Executable file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from models import SpecDetector
|
||||
from root_dir import ROOT_DIR
|
||||
|
||||
nrows, ncols, nbands = 600, 1024, 4
|
||||
img_fifo_path = "/tmp/dkimg.fifo"
|
||||
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||
selected_model = "rf_8x8_c4_400_13.model"
|
||||
|
||||
def main():
|
||||
model_path = os.path.join(ROOT_DIR, "models", selected_model)
|
||||
detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
|
||||
total_len = nrows * ncols * nbands * 4
|
||||
if not os.access(img_fifo_path, os.F_OK):
|
||||
os.mkfifo(img_fifo_path, 0o777)
|
||||
if not os.access(mask_fifo_path, os.F_OK):
|
||||
os.mkfifo(mask_fifo_path, 0o777)
|
||||
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
print("connect to fifo")
|
||||
|
||||
while True:
|
||||
data = os.read(fd_img, total_len)
|
||||
print("get img")
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
|
||||
mask = detector.predict(img)
|
||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||
os.write(fd_mask, mask.tobytes())
|
||||
os.close(fd_mask)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
142
models.py
Executable file
142
models.py
Executable file
@ -0,0 +1,142 @@
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
||||
|
||||
|
||||
# def feature(x):
|
||||
# x = x.reshape((x.shape[0], -1, x.shape[-1]))
|
||||
# x = np.mean(x, axis=1)
|
||||
# return x
|
||||
|
||||
def feature(x):
|
||||
x = x.reshape((x.shape[0], -1))
|
||||
return x
|
||||
|
||||
|
||||
def train_rf_and_report(train_x, train_y, test_x, test_y,
|
||||
tree_num, save_path=None):
|
||||
rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42, class_weight={0:10, 1:10})
|
||||
rfc = rfc.fit(train_x, train_y)
|
||||
t1 = time.time()
|
||||
y_pred = rfc.predict(test_x)
|
||||
y_pred_binary = np.ones_like(y_pred)
|
||||
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||
y_pred_binary[(y_pred >1)] = 2
|
||||
test_y_binary = np.ones_like(test_y)
|
||||
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||
test_y_binary[(test_y >1) ] = 2
|
||||
print("预测时间:", time.time() - t1)
|
||||
print("RFC训练模型评分:" + str(accuracy_score(train_y, rfc.predict(train_x))))
|
||||
print("RFC待测模型评分:" + str(accuracy_score(test_y, rfc.predict(test_x))))
|
||||
print('RFC预测结果:' + str(y_pred))
|
||||
print('---------------------------------------------------------------------------------------------------')
|
||||
print('RFC分类报告\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||
print('RFC混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||
print('rfc分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||
print('rfc混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||
if save_path is not None:
|
||||
with open(save_path, 'wb') as f:
|
||||
pickle.dump(rfc, f)
|
||||
return rfc
|
||||
|
||||
|
||||
def evaluation_and_report(model, test_x, test_y):
|
||||
t1 = time.time()
|
||||
y_pred = model.predict(test_x)
|
||||
y_pred_binary = np.ones_like(y_pred)
|
||||
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||
y_pred_binary[(y_pred >1)] = 2
|
||||
test_y_binary = np.ones_like(test_y)
|
||||
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||
test_y_binary[(test_y >1) ] = 2
|
||||
print("预测时间:", time.time() - t1)
|
||||
print("RFC待测模型 accuracy:" + str(accuracy_score(test_y, model.predict(test_x))))
|
||||
print('RFC预测结果:' + str(y_pred))
|
||||
print('---------------------------------------------------------------------------------------------------')
|
||||
print('RFC分类报告\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||
print('RFC混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||
print('rfc分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||
print('rfc混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||
|
||||
|
||||
def train_pca_rf(train_x, train_y, test_x, test_y, n_comp,
|
||||
tree_num, save_path=None):
|
||||
rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42,class_weight={0:100, 1:100})
|
||||
pca = PCA(n_components=0.95)
|
||||
rfc = rfc.fit(train_x, train_y)
|
||||
t1 = time.time()
|
||||
y_pred = rfc.predict(test_x)
|
||||
y_pred_binary = np.ones_like(y_pred)
|
||||
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||
y_pred_binary[(y_pred == 2) | (y_pred == 3) | (y_pred == 4)] = 2
|
||||
test_y_binary = np.ones_like(test_y)
|
||||
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||
test_y_binary[(test_y == 2) | (test_y == 3) | (test_y == 4)] = 2
|
||||
print("预测时间:", time.time() - t1)
|
||||
print("RFC训练模型评分:" + str(accuracy_score(train_y, rfc.predict(train_x))))
|
||||
print("RFC待测模型评分:" + str(accuracy_score(test_y, rfc.predict(test_x))))
|
||||
print('RFC预测结果:' + str(y_pred))
|
||||
print('---------------------------------------------------------------------------------------------------')
|
||||
print('RFC分类报告:\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||
print('RFC混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||
print('rfc分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||
print('rfc混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||
if save_path is not None:
|
||||
with open(save_path, 'wb') as f:
|
||||
pickle.dump((pca, rfc), f)
|
||||
return pca, rfc
|
||||
|
||||
|
||||
def split_x(data: np.ndarray, blk_sz: int) -> list:
|
||||
"""
|
||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||
|
||||
;param data: image data, shape (num_rows x 1024 x num_channels)
|
||||
;param blk_sz: block size
|
||||
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||
"""
|
||||
x_list = []
|
||||
for i in range(0, 600 // blk_sz):
|
||||
for j in range(0, 1024 // blk_sz):
|
||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||
x_list.append(block_data)
|
||||
return x_list
|
||||
|
||||
|
||||
class SpecDetector(object):
|
||||
def __init__(self, model_path, blk_sz=8, channel_num=4):
|
||||
self.blk_sz, self.channel_num = blk_sz, channel_num
|
||||
if os.path.exists(model_path):
|
||||
with open(model_path, "rb") as model_file:
|
||||
self.clf = pickle.load(model_file)
|
||||
else:
|
||||
raise FileNotFoundError("Model File not found")
|
||||
|
||||
def predict(self, data):
|
||||
blocks = split_x(data, blk_sz=self.blk_sz)
|
||||
blocks = np.array(blocks)
|
||||
features = feature(np.array(blocks))
|
||||
y_pred = self.clf.predict(features)
|
||||
y_pred_binary = np.ones_like(y_pred)
|
||||
# classes merge
|
||||
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
||||
# transform to mask
|
||||
mask = self.mask_transform(y_pred_binary, (1024, 600))
|
||||
return mask
|
||||
|
||||
def mask_transform(self, result, dst_size):
|
||||
mask_size = 600//self.blk_sz, 1024 // self.blk_sz
|
||||
mask = np.zeros(mask_size, dtype=np.uint8)
|
||||
for idx, r in enumerate(result):
|
||||
row, col = idx // mask_size[1], idx % mask_size[1]
|
||||
mask[row, col] = r
|
||||
mask = mask.repeat(self.blk_sz, axis = 0).repeat(self.blk_sz, axis = 1)
|
||||
return mask
|
||||
10
root_dir.py
Executable file
10
root_dir.py
Executable file
@ -0,0 +1,10 @@
|
||||
# -*- codeing = utf-8 -*-
|
||||
# Time : 2022/6/18 9:32
|
||||
# @Auther : zhouchao
|
||||
# @File: root_dir.py
|
||||
# @Software:PyCharm
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
ROOT_DIR = os.path.split(sys.argv[0])[0]
|
||||
|
||||
19
test_files/models_test.py
Executable file
19
test_files/models_test.py
Executable file
@ -0,0 +1,19 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from models import feature
|
||||
|
||||
|
||||
class ModelTestCase(unittest.TestCase):
|
||||
def test_feature(self):
|
||||
x_list = [np.ones((8, 8, 6)) * i for i in range(9600)]
|
||||
features = feature(x_list=x_list)
|
||||
self.assertEqual(features[0][0], 0) # add assertion here
|
||||
self.assertEqual(features[0][5], 0) # add assertion here
|
||||
self.assertEqual(features[-1][5], 9599) # add assertion here
|
||||
self.assertEqual(features[-1][0], 9599) # add assertion here
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
119
test_files/utils_test.py
Executable file
119
test_files/utils_test.py
Executable file
@ -0,0 +1,119 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils import determine_class, split_xy, split_x
|
||||
|
||||
|
||||
class DatasetTest(unittest.TestCase):
|
||||
def test_determine_class(self):
|
||||
pixel_block = np.zeros((8, 8), dtype=np.uint8)
|
||||
pixel_block[2: 4, 5: 6] = 2
|
||||
pixel_block[5: 7, 1: 2] = 1
|
||||
pixel_block[4: 6, 1: 6] = 3
|
||||
cls = determine_class(pixel_block, sensitivity=8)
|
||||
self.assertEqual(cls, 3)
|
||||
pixel_block = np.zeros((8, 8), dtype=np.uint8)
|
||||
pixel_block[2: 4, 5: 6] = 2
|
||||
pixel_block[5: 7, 1: 2] = 1
|
||||
pixel_block[4: 6, 1: 6] = 2
|
||||
cls = determine_class(pixel_block, sensitivity=8)
|
||||
self.assertEqual(cls, 2)
|
||||
pixel_block = np.zeros((8, 8), dtype=np.uint8)
|
||||
pixel_block[2: 4, 5: 6] = 1
|
||||
pixel_block[5: 7, 1: 2] = 2
|
||||
pixel_block[4: 6, 1: 6] = 1
|
||||
cls = determine_class(pixel_block, sensitivity=8)
|
||||
self.assertEqual(cls, 1)
|
||||
|
||||
def test_split_xy(self):
|
||||
x = np.arange(600*1024).reshape((600, 1024))
|
||||
y = np.zeros((600, 1024, 3))
|
||||
color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
|
||||
trans_color_dict = {v: k for k, v in color_dict.items()}
|
||||
# modify the first block
|
||||
y[2: 4, 5: 6] = trans_color_dict[2]
|
||||
y[5: 7, 1: 2] = trans_color_dict[1]
|
||||
y[4: 6, 1: 6] = trans_color_dict[3]
|
||||
# modify the last block
|
||||
y[-4: -2, -6: -5] = trans_color_dict[1]
|
||||
y[-7: -5, -2: -1] = trans_color_dict[2]
|
||||
y[-6: -4, -6: -1] = trans_color_dict[1]
|
||||
# modify the middle block
|
||||
y[64+2: 64+4, 64+5: 64+6] = trans_color_dict[2]
|
||||
y[64+5: 64+7, 64+1: 64+2] = trans_color_dict[1]
|
||||
y[64+4: 64+6, 64+1: 64+6] = trans_color_dict[2]
|
||||
x_list, y_list = split_xy(x, y, blk_sz=8, sensitivity=8)
|
||||
first_block = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031],
|
||||
[2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055],
|
||||
[3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079],
|
||||
[4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103],
|
||||
[5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127],
|
||||
[6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151],
|
||||
[7168, 7169, 7170, 7171, 7172, 7173, 7174, 7175],])
|
||||
sum_value = np.sum(np.sum(x_list[0]-first_block))
|
||||
self.assertEqual(sum_value, 0)
|
||||
self.assertEqual(y_list[0], 3)
|
||||
last_block = np.array(
|
||||
[[607224, 607225, 607226, 607227, 607228, 607229, 607230, 607231],
|
||||
[608248, 608249, 608250, 608251, 608252, 608253, 608254, 608255],
|
||||
[609272, 609273, 609274, 609275, 609276, 609277, 609278, 609279],
|
||||
[610296, 610297, 610298, 610299, 610300, 610301, 610302, 610303],
|
||||
[611320, 611321, 611322, 611323, 611324, 611325, 611326, 611327],
|
||||
[612344, 612345, 612346, 612347, 612348, 612349, 612350, 612351],
|
||||
[613368, 613369, 613370, 613371, 613372, 613373, 613374, 613375],
|
||||
[614392, 614393, 614394, 614395, 614396, 614397, 614398, 614399],])
|
||||
sum_value = np.sum(np.sum(x_list[-1]-last_block))
|
||||
self.assertEqual(sum_value, 0)
|
||||
self.assertEqual(y_list[-1], 1)
|
||||
middle_block = np.array([[65600, 65601, 65602, 65603, 65604, 65605, 65606, 65607],
|
||||
[66624, 66625, 66626, 66627, 66628, 66629, 66630, 66631],
|
||||
[67648, 67649, 67650, 67651, 67652, 67653, 67654, 67655],
|
||||
[68672, 68673, 68674, 68675, 68676, 68677, 68678, 68679],
|
||||
[69696, 69697, 69698, 69699, 69700, 69701, 69702, 69703],
|
||||
[70720, 70721, 70722, 70723, 70724, 70725, 70726, 70727],
|
||||
[71744, 71745, 71746, 71747, 71748, 71749, 71750, 71751],
|
||||
[72768, 72769, 72770, 72771, 72772, 72773, 72774, 72775]])
|
||||
sum_value = np.sum(np.sum(x_list[1032]-middle_block))
|
||||
self.assertEqual(sum_value, 0)
|
||||
self.assertEqual(y_list[1032], 2)
|
||||
|
||||
def test_split_x(self):
|
||||
x = np.arange(600 * 1024).reshape((600, 1024))
|
||||
x_list = split_x(x, blk_sz=8)
|
||||
first_block = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031],
|
||||
[2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055],
|
||||
[3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079],
|
||||
[4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103],
|
||||
[5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127],
|
||||
[6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151],
|
||||
[7168, 7169, 7170, 7171, 7172, 7173, 7174, 7175], ])
|
||||
sum_value = np.sum(np.sum(x_list[0] - first_block))
|
||||
self.assertEqual(sum_value, 0)
|
||||
last_block = np.array(
|
||||
[[607224, 607225, 607226, 607227, 607228, 607229, 607230, 607231],
|
||||
[608248, 608249, 608250, 608251, 608252, 608253, 608254, 608255],
|
||||
[609272, 609273, 609274, 609275, 609276, 609277, 609278, 609279],
|
||||
[610296, 610297, 610298, 610299, 610300, 610301, 610302, 610303],
|
||||
[611320, 611321, 611322, 611323, 611324, 611325, 611326, 611327],
|
||||
[612344, 612345, 612346, 612347, 612348, 612349, 612350, 612351],
|
||||
[613368, 613369, 613370, 613371, 613372, 613373, 613374, 613375],
|
||||
[614392, 614393, 614394, 614395, 614396, 614397, 614398, 614399], ])
|
||||
sum_value = np.sum(np.sum(x_list[-1] - last_block))
|
||||
self.assertEqual(sum_value, 0)
|
||||
middle_block = np.array([[65600, 65601, 65602, 65603, 65604, 65605, 65606, 65607],
|
||||
[66624, 66625, 66626, 66627, 66628, 66629, 66630, 66631],
|
||||
[67648, 67649, 67650, 67651, 67652, 67653, 67654, 67655],
|
||||
[68672, 68673, 68674, 68675, 68676, 68677, 68678, 68679],
|
||||
[69696, 69697, 69698, 69699, 69700, 69701, 69702, 69703],
|
||||
[70720, 70721, 70722, 70723, 70724, 70725, 70726, 70727],
|
||||
[71744, 71745, 71746, 71747, 71748, 71749, 71750, 71751],
|
||||
[72768, 72769, 72770, 72771, 72772, 72773, 72774, 72775]])
|
||||
sum_value = np.sum(np.sum(x_list[1032] - middle_block))
|
||||
self.assertEqual(sum_value, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
190
utils.py
Executable file
190
utils.py
Executable file
@ -0,0 +1,190 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from models import SpecDetector
|
||||
|
||||
|
||||
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
|
||||
"""
|
||||
将label转为类别
|
||||
|
||||
:param pixel: 一个 n x n 的像素块
|
||||
:param color_dict: 用于转化的字典 {(0, 0, 255): 1, ....} 色彩采用bgr
|
||||
:return:类别白噢好
|
||||
"""
|
||||
# 0 表示的是背景, 1表示的是烟梗,剩下的都是杂质
|
||||
if color_dict is None:
|
||||
color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
|
||||
if (pixel[0], pixel[1], pixel[2]) in color_dict.keys():
|
||||
return color_dict[(pixel[0], pixel[1], pixel[2])]
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
|
||||
"""
|
||||
决定像素块的类别
|
||||
|
||||
:param pixel_blk: 像素块
|
||||
:param sensitivity: 敏感度
|
||||
:return:
|
||||
"""
|
||||
defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
|
||||
color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
|
||||
for cls in defect_dict.keys()}
|
||||
grant_cls = {0: 0, 1: 0}
|
||||
for cls, num in color_numbers.items():
|
||||
grant_cls[defect_dict[cls]] += num
|
||||
if grant_cls[1] >= sensitivity:
|
||||
color_numbers = {cls: color_numbers[cls] for cls in [2, 3, 4, 5, 6]}
|
||||
return max(color_numbers, key=color_numbers.get)
|
||||
else:
|
||||
if color_numbers[1] >= sensitivity:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity: int = 12,
|
||||
color_dict=None, add_background=True) -> tuple:
|
||||
"""
|
||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||
|
||||
;param data: image data, shape (num_rows x 1024 x num_channels)
|
||||
;param labeled_img: RGB labeled img with respect to the image!
|
||||
make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
|
||||
;param blk_sz: block size
|
||||
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||
data y (block_num, ) 1 是杂质, 0是无杂质
|
||||
"""
|
||||
assert (data.shape[0] == labeled_img.shape[0]) and (data.shape[1] == labeled_img.shape[1])
|
||||
color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}\
|
||||
if color_dict is None else color_dict
|
||||
class_img = np.zeros((labeled_img.shape[0], labeled_img.shape[1]), dtype=int)
|
||||
for color, class_idx in color_dict.items():
|
||||
truth_map = np.all(labeled_img == color, axis=2)
|
||||
class_img[truth_map] = class_idx
|
||||
x_list, y_list = [], []
|
||||
for i in range(0, 600 // blk_sz):
|
||||
for j in range(0, 1024 // blk_sz):
|
||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||
block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||
block_label = determine_class(block_label, sensitivity=sensitivity)
|
||||
if add_background:
|
||||
y_list.append(block_label)
|
||||
x_list.append(block_data)
|
||||
else:
|
||||
if block_label != 0:
|
||||
y_list.append(block_label)
|
||||
x_list.append(block_data)
|
||||
return x_list, y_list
|
||||
|
||||
|
||||
def split_x(data: np.ndarray, blk_sz: int) -> list:
|
||||
"""
|
||||
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||
|
||||
;param data: image data, shape (num_rows x 1024 x num_channels)
|
||||
;param blk_sz: block size
|
||||
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||
"""
|
||||
x_list = []
|
||||
for i in range(0, 600 // blk_sz):
|
||||
for j in range(0, 1024 // blk_sz):
|
||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||
x_list.append(block_data)
|
||||
return x_list
|
||||
|
||||
|
||||
def visualization_evaluation(detector, data_path, selected_bands=None):
|
||||
selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
|
||||
nrows, ncols = 600, 1024
|
||||
image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
|
||||
for idx, image_path in enumerate(image_paths):
|
||||
with open(image_path, 'rb') as f:
|
||||
data = f.read()
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)).transpose(0, 2, 1)
|
||||
nbands = img.shape[2]
|
||||
t1 = time.time()
|
||||
mask = detector.predict(img[..., selected_bands] if nbands == 448 else img)
|
||||
time_spent = time.time() - t1
|
||||
if nbands == 448:
|
||||
rgb_img = np.asarray(img[..., [372, 241, 169]] * 255, dtype=np.uint8)
|
||||
else:
|
||||
rgb_img = np.asarray(img[..., [0, 1, 2]] * 255, dtype=np.uint8)
|
||||
fig, axs = plt.subplots(1, 2)
|
||||
axs[0].imshow(rgb_img)
|
||||
axs[1].imshow(mask)
|
||||
fig.suptitle(f"time spent {time_spent*1000:.2f} ms" + f"\n{image_path}")
|
||||
plt.savefig(f"./dataset/{idx}.png", dpi=300)
|
||||
plt.show()
|
||||
|
||||
|
||||
def visualization_y(y_list, k_size):
|
||||
mask = np.zeros((600//k_size, 1024//k_size), dtype=np.uint8)
|
||||
for idx, r in enumerate(y_list):
|
||||
row, col = idx // (1024 // k_size), idx % (1024 // k_size)
|
||||
mask[row, col] = r
|
||||
fig, axs = plt.subplots()
|
||||
axs.imshow(mask)
|
||||
plt.show()
|
||||
|
||||
|
||||
def read_raw_file(file_name, selected_bands=None):
|
||||
with open(file_name, "rb") as f:
|
||||
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1)
|
||||
if selected_bands is not None:
|
||||
data = data[..., selected_bands]
|
||||
return data
|
||||
|
||||
|
||||
def read_black_and_white_file(file_name):
|
||||
with open(file_name, "rb") as f:
|
||||
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1)
|
||||
return data
|
||||
|
||||
|
||||
def label2pic(label, color_dict):
|
||||
pic = np.zeros((label.shape[0], label.shape[1], 3))
|
||||
for color, cls in color_dict.items():
|
||||
pic[label == cls] = color
|
||||
return pic
|
||||
|
||||
|
||||
def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
|
||||
model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
|
||||
y_label = model.predict(data)
|
||||
x_list, y_list = [], []
|
||||
for i in range(0, 600 // blk_sz):
|
||||
for j in range(0, 1024 // blk_sz):
|
||||
if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
|
||||
> 0:
|
||||
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||
x_list.append(block_data)
|
||||
y_list.append(1)
|
||||
return x_list, y_list
|
||||
|
||||
|
||||
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
|
||||
target_class_left=None,):
|
||||
y_label = np.zeros((data.shape[0], data.shape[1]))
|
||||
for i in range(0, 600):
|
||||
for j in range(0, 1024):
|
||||
if np.sum(np.sum(data[i, j])) >= light_threshold:
|
||||
if j > split_line:
|
||||
y_label[i, j] = target_class_right
|
||||
else:
|
||||
y_label[i, j] = target_class_left
|
||||
pic = label2pic(y_label, color_dict=color_dict)
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].matshow(y_label)
|
||||
axs[1].matshow(data[..., 0])
|
||||
plt.show()
|
||||
return pic
|
||||
Loading…
Reference in New Issue
Block a user