mirror of
https://github.com/NanjingForestryUniversity/tobacoo-industry.git
synced 2025-11-08 22:33:52 +00:00
First Commit
This commit is contained in:
commit
da5fde8f40
577
.gitignore
vendored
Normal file
577
.gitignore
vendored
Normal file
@ -0,0 +1,577 @@
|
|||||||
|
### Project sepecified
|
||||||
|
dataset/*
|
||||||
|
models/*
|
||||||
|
|
||||||
|
|
||||||
|
### VisualStudioCode template
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/settings.json
|
||||||
|
!.vscode/tasks.json
|
||||||
|
!.vscode/launch.json
|
||||||
|
!.vscode/extensions.json
|
||||||
|
*.code-workspace
|
||||||
|
|
||||||
|
# Local History for Visual Studio Code
|
||||||
|
.history/
|
||||||
|
|
||||||
|
### VirtualEnv template
|
||||||
|
# Virtualenv
|
||||||
|
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
||||||
|
.Python
|
||||||
|
[Bb]in
|
||||||
|
[Ii]nclude
|
||||||
|
[Ll]ib
|
||||||
|
[Ll]ib64
|
||||||
|
[Ll]ocal
|
||||||
|
[Ss]cripts
|
||||||
|
pyvenv.cfg
|
||||||
|
.venv
|
||||||
|
pip-selfcheck.json
|
||||||
|
|
||||||
|
### Example user template template
|
||||||
|
### Example user template
|
||||||
|
|
||||||
|
# IntelliJ project files
|
||||||
|
.idea
|
||||||
|
*.iml
|
||||||
|
out
|
||||||
|
gen
|
||||||
|
### VisualStudio template
|
||||||
|
## Ignore Visual Studio temporary files, build results, and
|
||||||
|
## files generated by popular Visual Studio add-ons.
|
||||||
|
##
|
||||||
|
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
||||||
|
|
||||||
|
# User-specific files
|
||||||
|
*.rsuser
|
||||||
|
*.suo
|
||||||
|
*.user
|
||||||
|
*.userosscache
|
||||||
|
*.sln.docstates
|
||||||
|
|
||||||
|
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||||
|
*.userprefs
|
||||||
|
|
||||||
|
# Mono auto generated files
|
||||||
|
mono_crash.*
|
||||||
|
|
||||||
|
# Build results
|
||||||
|
[Dd]ebug/
|
||||||
|
[Dd]ebugPublic/
|
||||||
|
[Rr]elease/
|
||||||
|
[Rr]eleases/
|
||||||
|
x64/
|
||||||
|
x86/
|
||||||
|
[Ww][Ii][Nn]32/
|
||||||
|
[Aa][Rr][Mm]/
|
||||||
|
[Aa][Rr][Mm]64/
|
||||||
|
bld/
|
||||||
|
[Bb]in/
|
||||||
|
[Oo]bj/
|
||||||
|
[Ll]og/
|
||||||
|
[Ll]ogs/
|
||||||
|
|
||||||
|
# Visual Studio 2015/2017 cache/options directory
|
||||||
|
.vs/
|
||||||
|
# Uncomment if you have tasks that create the project's static files in wwwroot
|
||||||
|
#wwwroot/
|
||||||
|
|
||||||
|
# Visual Studio 2017 auto generated files
|
||||||
|
Generated\ Files/
|
||||||
|
|
||||||
|
# MSTest test Results
|
||||||
|
[Tt]est[Rr]esult*/
|
||||||
|
[Bb]uild[Ll]og.*
|
||||||
|
|
||||||
|
# NUnit
|
||||||
|
*.VisualState.xml
|
||||||
|
TestResult.xml
|
||||||
|
nunit-*.xml
|
||||||
|
|
||||||
|
# Build Results of an ATL Project
|
||||||
|
[Dd]ebugPS/
|
||||||
|
[Rr]eleasePS/
|
||||||
|
dlldata.c
|
||||||
|
|
||||||
|
# Benchmark Results
|
||||||
|
BenchmarkDotNet.Artifacts/
|
||||||
|
|
||||||
|
# .NET Core
|
||||||
|
project.lock.json
|
||||||
|
project.fragment.lock.json
|
||||||
|
artifacts/
|
||||||
|
|
||||||
|
# ASP.NET Scaffolding
|
||||||
|
ScaffoldingReadMe.txt
|
||||||
|
|
||||||
|
# StyleCop
|
||||||
|
StyleCopReport.xml
|
||||||
|
|
||||||
|
# Files built by Visual Studio
|
||||||
|
*_i.c
|
||||||
|
*_p.c
|
||||||
|
*_h.h
|
||||||
|
*.ilk
|
||||||
|
*.meta
|
||||||
|
*.obj
|
||||||
|
*.iobj
|
||||||
|
*.pch
|
||||||
|
*.pdb
|
||||||
|
*.ipdb
|
||||||
|
*.pgc
|
||||||
|
*.pgd
|
||||||
|
*.rsp
|
||||||
|
*.sbr
|
||||||
|
*.tlb
|
||||||
|
*.tli
|
||||||
|
*.tlh
|
||||||
|
*.tmp
|
||||||
|
*.tmp_proj
|
||||||
|
*_wpftmp.csproj
|
||||||
|
*.log
|
||||||
|
*.vspscc
|
||||||
|
*.vssscc
|
||||||
|
.builds
|
||||||
|
*.pidb
|
||||||
|
*.svclog
|
||||||
|
*.scc
|
||||||
|
|
||||||
|
# Chutzpah Test files
|
||||||
|
_Chutzpah*
|
||||||
|
|
||||||
|
# Visual C++ cache files
|
||||||
|
ipch/
|
||||||
|
*.aps
|
||||||
|
*.ncb
|
||||||
|
*.opendb
|
||||||
|
*.opensdf
|
||||||
|
*.sdf
|
||||||
|
*.cachefile
|
||||||
|
*.VC.db
|
||||||
|
*.VC.VC.opendb
|
||||||
|
|
||||||
|
# Visual Studio profiler
|
||||||
|
*.psess
|
||||||
|
*.vsp
|
||||||
|
*.vspx
|
||||||
|
*.sap
|
||||||
|
|
||||||
|
# Visual Studio Trace Files
|
||||||
|
*.e2e
|
||||||
|
|
||||||
|
# TFS 2012 Local Workspace
|
||||||
|
$tf/
|
||||||
|
|
||||||
|
# Guidance Automation Toolkit
|
||||||
|
*.gpState
|
||||||
|
|
||||||
|
# ReSharper is a .NET coding add-in
|
||||||
|
_ReSharper*/
|
||||||
|
*.[Rr]e[Ss]harper
|
||||||
|
*.DotSettings.user
|
||||||
|
|
||||||
|
# TeamCity is a build add-in
|
||||||
|
_TeamCity*
|
||||||
|
|
||||||
|
# DotCover is a Code Coverage Tool
|
||||||
|
*.dotCover
|
||||||
|
|
||||||
|
# AxoCover is a Code Coverage Tool
|
||||||
|
.axoCover/*
|
||||||
|
!.axoCover/settings.json
|
||||||
|
|
||||||
|
# Coverlet is a free, cross platform Code Coverage Tool
|
||||||
|
coverage*.json
|
||||||
|
coverage*.xml
|
||||||
|
coverage*.info
|
||||||
|
|
||||||
|
# Visual Studio code coverage results
|
||||||
|
*.coverage
|
||||||
|
*.coveragexml
|
||||||
|
|
||||||
|
# NCrunch
|
||||||
|
_NCrunch_*
|
||||||
|
.*crunch*.local.xml
|
||||||
|
nCrunchTemp_*
|
||||||
|
|
||||||
|
# MightyMoose
|
||||||
|
*.mm.*
|
||||||
|
AutoTest.Net/
|
||||||
|
|
||||||
|
# Web workbench (sass)
|
||||||
|
.sass-cache/
|
||||||
|
|
||||||
|
# Installshield output folder
|
||||||
|
[Ee]xpress/
|
||||||
|
|
||||||
|
# DocProject is a documentation generator add-in
|
||||||
|
DocProject/buildhelp/
|
||||||
|
DocProject/Help/*.HxT
|
||||||
|
DocProject/Help/*.HxC
|
||||||
|
DocProject/Help/*.hhc
|
||||||
|
DocProject/Help/*.hhk
|
||||||
|
DocProject/Help/*.hhp
|
||||||
|
DocProject/Help/Html2
|
||||||
|
DocProject/Help/html
|
||||||
|
|
||||||
|
# Click-Once directory
|
||||||
|
publish/
|
||||||
|
|
||||||
|
# Publish Web Output
|
||||||
|
*.[Pp]ublish.xml
|
||||||
|
*.azurePubxml
|
||||||
|
# Note: Comment the next line if you want to checkin your web deploy settings,
|
||||||
|
# but database connection strings (with potential passwords) will be unencrypted
|
||||||
|
*.pubxml
|
||||||
|
*.publishproj
|
||||||
|
|
||||||
|
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
||||||
|
# checkin your Azure Web App publish settings, but sensitive information contained
|
||||||
|
# in these scripts will be unencrypted
|
||||||
|
PublishScripts/
|
||||||
|
|
||||||
|
# NuGet Packages
|
||||||
|
*.nupkg
|
||||||
|
# NuGet Symbol Packages
|
||||||
|
*.snupkg
|
||||||
|
# The packages folder can be ignored because of Package Restore
|
||||||
|
**/[Pp]ackages/*
|
||||||
|
# except build/, which is used as an MSBuild target.
|
||||||
|
!**/[Pp]ackages/build/
|
||||||
|
# Uncomment if necessary however generally it will be regenerated when needed
|
||||||
|
#!**/[Pp]ackages/repositories.config
|
||||||
|
# NuGet v3's project.json files produces more ignorable files
|
||||||
|
*.nuget.props
|
||||||
|
*.nuget.targets
|
||||||
|
|
||||||
|
# Microsoft Azure Build Output
|
||||||
|
csx/
|
||||||
|
*.build.csdef
|
||||||
|
|
||||||
|
# Microsoft Azure Emulator
|
||||||
|
ecf/
|
||||||
|
rcf/
|
||||||
|
|
||||||
|
# Windows Store app package directories and files
|
||||||
|
AppPackages/
|
||||||
|
BundleArtifacts/
|
||||||
|
Package.StoreAssociation.xml
|
||||||
|
_pkginfo.txt
|
||||||
|
*.appx
|
||||||
|
*.appxbundle
|
||||||
|
*.appxupload
|
||||||
|
|
||||||
|
# Visual Studio cache files
|
||||||
|
# files ending in .cache can be ignored
|
||||||
|
*.[Cc]ache
|
||||||
|
# but keep track of directories ending in .cache
|
||||||
|
!?*.[Cc]ache/
|
||||||
|
|
||||||
|
# Others
|
||||||
|
ClientBin/
|
||||||
|
~$*
|
||||||
|
*~
|
||||||
|
*.dbmdl
|
||||||
|
*.dbproj.schemaview
|
||||||
|
*.jfm
|
||||||
|
*.pfx
|
||||||
|
*.publishsettings
|
||||||
|
orleans.codegen.cs
|
||||||
|
|
||||||
|
# Including strong name files can present a security risk
|
||||||
|
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
||||||
|
#*.snk
|
||||||
|
|
||||||
|
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
||||||
|
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
||||||
|
#bower_components/
|
||||||
|
|
||||||
|
# RIA/Silverlight projects
|
||||||
|
Generated_Code/
|
||||||
|
|
||||||
|
# Backup & report files from converting an old project file
|
||||||
|
# to a newer Visual Studio version. Backup files are not needed,
|
||||||
|
# because we have git ;-)
|
||||||
|
_UpgradeReport_Files/
|
||||||
|
Backup*/
|
||||||
|
UpgradeLog*.XML
|
||||||
|
UpgradeLog*.htm
|
||||||
|
ServiceFabricBackup/
|
||||||
|
*.rptproj.bak
|
||||||
|
|
||||||
|
# SQL Server files
|
||||||
|
*.mdf
|
||||||
|
*.ldf
|
||||||
|
*.ndf
|
||||||
|
|
||||||
|
# Business Intelligence projects
|
||||||
|
*.rdl.data
|
||||||
|
*.bim.layout
|
||||||
|
*.bim_*.settings
|
||||||
|
*.rptproj.rsuser
|
||||||
|
*- [Bb]ackup.rdl
|
||||||
|
*- [Bb]ackup ([0-9]).rdl
|
||||||
|
*- [Bb]ackup ([0-9][0-9]).rdl
|
||||||
|
|
||||||
|
# Microsoft Fakes
|
||||||
|
FakesAssemblies/
|
||||||
|
|
||||||
|
# GhostDoc plugin setting file
|
||||||
|
*.GhostDoc.xml
|
||||||
|
|
||||||
|
# Node.js Tools for Visual Studio
|
||||||
|
.ntvs_analysis.dat
|
||||||
|
node_modules/
|
||||||
|
|
||||||
|
# Visual Studio 6 build log
|
||||||
|
*.plg
|
||||||
|
|
||||||
|
# Visual Studio 6 workspace options file
|
||||||
|
*.opt
|
||||||
|
|
||||||
|
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
||||||
|
*.vbw
|
||||||
|
|
||||||
|
# Visual Studio LightSwitch build output
|
||||||
|
**/*.HTMLClient/GeneratedArtifacts
|
||||||
|
**/*.DesktopClient/GeneratedArtifacts
|
||||||
|
**/*.DesktopClient/ModelManifest.xml
|
||||||
|
**/*.Server/GeneratedArtifacts
|
||||||
|
**/*.Server/ModelManifest.xml
|
||||||
|
_Pvt_Extensions
|
||||||
|
|
||||||
|
# Paket dependency manager
|
||||||
|
.paket/paket.exe
|
||||||
|
paket-files/
|
||||||
|
|
||||||
|
# FAKE - F# Make
|
||||||
|
.fake/
|
||||||
|
|
||||||
|
# CodeRush personal settings
|
||||||
|
.cr/personal
|
||||||
|
|
||||||
|
# Python Tools for Visual Studio (PTVS)
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
|
||||||
|
# Cake - Uncomment if you are using it
|
||||||
|
# tools/**
|
||||||
|
# !tools/packages.config
|
||||||
|
|
||||||
|
# Tabs Studio
|
||||||
|
*.tss
|
||||||
|
|
||||||
|
# Telerik's JustMock configuration file
|
||||||
|
*.jmconfig
|
||||||
|
|
||||||
|
# BizTalk build output
|
||||||
|
*.btp.cs
|
||||||
|
*.btm.cs
|
||||||
|
*.odx.cs
|
||||||
|
*.xsd.cs
|
||||||
|
|
||||||
|
# OpenCover UI analysis results
|
||||||
|
OpenCover/
|
||||||
|
|
||||||
|
# Azure Stream Analytics local run output
|
||||||
|
ASALocalRun/
|
||||||
|
|
||||||
|
# MSBuild Binary and Structured Log
|
||||||
|
*.binlog
|
||||||
|
|
||||||
|
# NVidia Nsight GPU debugger configuration file
|
||||||
|
*.nvuser
|
||||||
|
|
||||||
|
# MFractors (Xamarin productivity tool) working folder
|
||||||
|
.mfractor/
|
||||||
|
|
||||||
|
# Local History for Visual Studio
|
||||||
|
.localhistory/
|
||||||
|
|
||||||
|
# BeatPulse healthcheck temp database
|
||||||
|
healthchecksdb
|
||||||
|
|
||||||
|
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
||||||
|
MigrationBackup/
|
||||||
|
|
||||||
|
# Ionide (cross platform F# VS Code tools) working folder
|
||||||
|
.ionide/
|
||||||
|
|
||||||
|
# Fody - auto-generated XML schema
|
||||||
|
FodyWeavers.xsd
|
||||||
|
|
||||||
|
### JupyterNotebooks template
|
||||||
|
# gitignore template for Jupyter Notebooks
|
||||||
|
# website: http://jupyter.org/
|
||||||
|
|
||||||
|
.ipynb_checkpoints
|
||||||
|
*/.ipynb_checkpoints/*
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# Remove previous ipynb_checkpoints
|
||||||
|
# git rm -r .ipynb_checkpoints/
|
||||||
|
|
||||||
|
### macOS template
|
||||||
|
# General
|
||||||
|
.DS_Store
|
||||||
|
.AppleDouble
|
||||||
|
.LSOverride
|
||||||
|
|
||||||
|
# Icon must end with two \r
|
||||||
|
Icon
|
||||||
|
|
||||||
|
# Thumbnails
|
||||||
|
._*
|
||||||
|
|
||||||
|
# Files that might appear in the root of a volume
|
||||||
|
.DocumentRevisions-V100
|
||||||
|
.fseventsd
|
||||||
|
.Spotlight-V100
|
||||||
|
.TemporaryItems
|
||||||
|
.Trashes
|
||||||
|
.VolumeIcon.icns
|
||||||
|
.com.apple.timemachine.donotpresent
|
||||||
|
|
||||||
|
# Directories potentially created on remote AFP share
|
||||||
|
.AppleDB
|
||||||
|
.AppleDesktop
|
||||||
|
Network Trash Folder
|
||||||
|
Temporary Items
|
||||||
|
.apdisk
|
||||||
|
|
||||||
|
### Python template
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
149
01_dataset_building.ipynb
Executable file
149
01_dataset_building.ipynb
Executable file
@ -0,0 +1,149 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# 数据集的制作"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pickle\n",
|
||||||
|
"import cv2\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from utils import split_xy"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# 一些参数\n",
|
||||||
|
"blk_sz = 8\n",
|
||||||
|
"sensitivity = 8\n",
|
||||||
|
"selected_bands = [127, 201, 202, 294]\n",
|
||||||
|
"# [76, 146, 216, 367, 383, 406]\n",
|
||||||
|
"file_name, labeled_image_file = r\"F:\\zhouchao\\616\\calibrated1.raw\", \\\n",
|
||||||
|
"r\"F:\\zhouchao\\616\\label1.bmp\"\n",
|
||||||
|
"# file_name, labeled_image_file = \"./dataset/calibrated77.raw\", \"./dataset/label77.png\"\n",
|
||||||
|
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_1.p'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## 波长选择"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open(file_name, \"rb\") as f:\n",
|
||||||
|
" data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, 448, 1024)).transpose(0, 2, 1)\n",
|
||||||
|
"data = data[..., selected_bands]\n",
|
||||||
|
"label = cv2.imread(labeled_image_file)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## 块分割与数据存储"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
|
||||||
|
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
|
||||||
|
"x, y = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict)\n",
|
||||||
|
"with open(dataset_file, 'wb') as f:\n",
|
||||||
|
" pickle.dump((x, y), f)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3.9.7 ('base')",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.7"
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "7f619fc91ee8bdab81d49e7c14228037474662e3f2d607687ae505108922fa06"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
471
02_feature_design&training.ipynb
Executable file
471
02_feature_design&training.ipynb
Executable file
File diff suppressed because one or more lines are too long
301
03_data_update.ipynb
Executable file
301
03_data_update.ipynb
Executable file
@ -0,0 +1,301 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# 数据集扩充"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"虽然当前的模型已经能够达到较好的效果,但是还不够好,对于一些较老的烟梗不能够做到有效的判别,我们为此增加数据集。"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"import cv2\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"from utils import read_raw_file, split_xy, generate_tobacco_label, generate_impurity_label\n",
|
||||||
|
"from models import SpecDetector\n",
|
||||||
|
"import pickle"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# some parameters\n",
|
||||||
|
"new_spectra_file = r\"F:\\zhouchao\\618\\bomo\\calibrated7.raw\"\n",
|
||||||
|
"new_label_file = r\"F:\\zhouchao\\618\\bomo\\picture\\label7.bmp\"\n",
|
||||||
|
"\n",
|
||||||
|
"target_class = 0\n",
|
||||||
|
"target_class_left, target_class_right = 5, 4\n",
|
||||||
|
"light_threshold = 0.5\n",
|
||||||
|
"add_background = False\n",
|
||||||
|
"\n",
|
||||||
|
"split_line = 500\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"blk_sz, sensitivity = 8, 8\n",
|
||||||
|
"selected_bands = [127, 201, 202, 294]\n",
|
||||||
|
"tree_num = 185\n",
|
||||||
|
"\n",
|
||||||
|
"pic_row, pic_col= 600, 1024\n",
|
||||||
|
"\n",
|
||||||
|
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
|
||||||
|
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
|
||||||
|
"\n",
|
||||||
|
"new_dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_7.p'\n",
|
||||||
|
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_6.p'\n",
|
||||||
|
"\n",
|
||||||
|
"model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_6.model'\n",
|
||||||
|
"selected_bands = None"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# if len(new_spectra_files) == 1:\n",
|
||||||
|
"data = read_raw_file(new_spectra_file, selected_bands)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 烟梗标签生成\n",
|
||||||
|
"这会将纯烟梗图片中识别为杂质的部分提取出来"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"x_list, y_list = [], []\n",
|
||||||
|
"if (new_label_file is None) and (target_class == 1):\n",
|
||||||
|
" x_list, y_list = generate_tobacco_label(data, model_file, blk_sz, selected_bands)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 其他类别杂质阈值分割\n",
|
||||||
|
"通过阈值分割的形式获取其他类别的杂质"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"if (new_label_file is None) and (target_class != 1):\n",
|
||||||
|
" img = generate_impurity_label(data, light_threshold, color_dict,\n",
|
||||||
|
" target_class_right=target_class_right,\n",
|
||||||
|
" target_class_left=target_class_left,\n",
|
||||||
|
" split_line=split_line)\n",
|
||||||
|
" root, _ = os.path.splitext(new_dataset_file)\n",
|
||||||
|
" cv2.imwrite(root+\"_generated.bmp\", img)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 读取标签"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"(600, 1024, 3)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"if new_label_file is not None:\n",
|
||||||
|
" label = cv2.imread(new_label_file)\n",
|
||||||
|
" print(label.shape)\n",
|
||||||
|
" x_list, y_list = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict, add_background=add_background)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"148 148\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(len(x_list), len(y_list))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 读取旧数据合并"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open(dataset_file, 'rb') as f:\n",
|
||||||
|
" x, y = pickle.load(f)\n",
|
||||||
|
"x.extend(x_list)\n",
|
||||||
|
"y.extend(y_list)\n",
|
||||||
|
"with open(new_dataset_file, 'wb') as f:\n",
|
||||||
|
" pickle.dump((x, y), f)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 批量数据的处理"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
398
04_performance_tune.ipynb
Executable file
398
04_performance_tune.ipynb
Executable file
File diff suppressed because one or more lines are too long
225
05_evaluation.ipynb
Executable file
225
05_evaluation.ipynb
Executable file
File diff suppressed because one or more lines are too long
34
main.py
Executable file
34
main.py
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from models import SpecDetector
|
||||||
|
from root_dir import ROOT_DIR
|
||||||
|
|
||||||
|
nrows, ncols, nbands = 600, 1024, 4
|
||||||
|
img_fifo_path = "/tmp/dkimg.fifo"
|
||||||
|
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||||
|
selected_model = "rf_8x8_c4_400_13.model"
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model_path = os.path.join(ROOT_DIR, "models", selected_model)
|
||||||
|
detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
|
||||||
|
total_len = nrows * ncols * nbands * 4
|
||||||
|
if not os.access(img_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(img_fifo_path, 0o777)
|
||||||
|
if not os.access(mask_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(mask_fifo_path, 0o777)
|
||||||
|
|
||||||
|
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||||
|
print("connect to fifo")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = os.read(fd_img, total_len)
|
||||||
|
print("get img")
|
||||||
|
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
|
||||||
|
mask = detector.predict(img)
|
||||||
|
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||||
|
os.write(fd_mask, mask.tobytes())
|
||||||
|
os.close(fd_mask)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
142
models.py
Executable file
142
models.py
Executable file
@ -0,0 +1,142 @@
|
|||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
||||||
|
|
||||||
|
|
||||||
|
# def feature(x):
|
||||||
|
# x = x.reshape((x.shape[0], -1, x.shape[-1]))
|
||||||
|
# x = np.mean(x, axis=1)
|
||||||
|
# return x
|
||||||
|
|
||||||
|
def feature(x):
|
||||||
|
x = x.reshape((x.shape[0], -1))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def train_rf_and_report(train_x, train_y, test_x, test_y,
|
||||||
|
tree_num, save_path=None):
|
||||||
|
rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42, class_weight={0:10, 1:10})
|
||||||
|
rfc = rfc.fit(train_x, train_y)
|
||||||
|
t1 = time.time()
|
||||||
|
y_pred = rfc.predict(test_x)
|
||||||
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||||
|
y_pred_binary[(y_pred >1)] = 2
|
||||||
|
test_y_binary = np.ones_like(test_y)
|
||||||
|
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||||
|
test_y_binary[(test_y >1) ] = 2
|
||||||
|
print("预测时间:", time.time() - t1)
|
||||||
|
print("RFC训练模型评分:" + str(accuracy_score(train_y, rfc.predict(train_x))))
|
||||||
|
print("RFC待测模型评分:" + str(accuracy_score(test_y, rfc.predict(test_x))))
|
||||||
|
print('RFC预测结果:' + str(y_pred))
|
||||||
|
print('---------------------------------------------------------------------------------------------------')
|
||||||
|
print('RFC分类报告\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||||
|
print('RFC混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
print('rfc分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||||
|
print('rfc混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
if save_path is not None:
|
||||||
|
with open(save_path, 'wb') as f:
|
||||||
|
pickle.dump(rfc, f)
|
||||||
|
return rfc
|
||||||
|
|
||||||
|
|
||||||
|
def evaluation_and_report(model, test_x, test_y):
|
||||||
|
t1 = time.time()
|
||||||
|
y_pred = model.predict(test_x)
|
||||||
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||||
|
y_pred_binary[(y_pred >1)] = 2
|
||||||
|
test_y_binary = np.ones_like(test_y)
|
||||||
|
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||||
|
test_y_binary[(test_y >1) ] = 2
|
||||||
|
print("预测时间:", time.time() - t1)
|
||||||
|
print("RFC待测模型 accuracy:" + str(accuracy_score(test_y, model.predict(test_x))))
|
||||||
|
print('RFC预测结果:' + str(y_pred))
|
||||||
|
print('---------------------------------------------------------------------------------------------------')
|
||||||
|
print('RFC分类报告\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||||
|
print('RFC混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
print('rfc分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||||
|
print('rfc混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
|
||||||
|
|
||||||
|
def train_pca_rf(train_x, train_y, test_x, test_y, n_comp,
|
||||||
|
tree_num, save_path=None):
|
||||||
|
rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42,class_weight={0:100, 1:100})
|
||||||
|
pca = PCA(n_components=0.95)
|
||||||
|
rfc = rfc.fit(train_x, train_y)
|
||||||
|
t1 = time.time()
|
||||||
|
y_pred = rfc.predict(test_x)
|
||||||
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
|
||||||
|
y_pred_binary[(y_pred == 2) | (y_pred == 3) | (y_pred == 4)] = 2
|
||||||
|
test_y_binary = np.ones_like(test_y)
|
||||||
|
test_y_binary[(test_y == 0) | (test_y == 1)] = 0
|
||||||
|
test_y_binary[(test_y == 2) | (test_y == 3) | (test_y == 4)] = 2
|
||||||
|
print("预测时间:", time.time() - t1)
|
||||||
|
print("RFC训练模型评分:" + str(accuracy_score(train_y, rfc.predict(train_x))))
|
||||||
|
print("RFC待测模型评分:" + str(accuracy_score(test_y, rfc.predict(test_x))))
|
||||||
|
print('RFC预测结果:' + str(y_pred))
|
||||||
|
print('---------------------------------------------------------------------------------------------------')
|
||||||
|
print('RFC分类报告:\n' + str(classification_report(test_y, y_pred))) # 生成一个小报告呀
|
||||||
|
print('RFC混淆矩阵:\n' + str(confusion_matrix(test_y, y_pred))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
print('rfc分类报告:\n' + str(classification_report(test_y_binary, y_pred_binary))) # 生成一个小报告呀
|
||||||
|
print('rfc混淆矩阵:\n' + str(confusion_matrix(test_y_binary, y_pred_binary))) # 这个也是,生成的矩阵的意思是有多少
|
||||||
|
if save_path is not None:
|
||||||
|
with open(save_path, 'wb') as f:
|
||||||
|
pickle.dump((pca, rfc), f)
|
||||||
|
return pca, rfc
|
||||||
|
|
||||||
|
|
||||||
|
def split_x(data: np.ndarray, blk_sz: int) -> list:
|
||||||
|
"""
|
||||||
|
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||||
|
|
||||||
|
;param data: image data, shape (num_rows x 1024 x num_channels)
|
||||||
|
;param blk_sz: block size
|
||||||
|
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||||
|
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||||
|
"""
|
||||||
|
x_list = []
|
||||||
|
for i in range(0, 600 // blk_sz):
|
||||||
|
for j in range(0, 1024 // blk_sz):
|
||||||
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
|
x_list.append(block_data)
|
||||||
|
return x_list
|
||||||
|
|
||||||
|
|
||||||
|
class SpecDetector(object):
|
||||||
|
def __init__(self, model_path, blk_sz=8, channel_num=4):
|
||||||
|
self.blk_sz, self.channel_num = blk_sz, channel_num
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
with open(model_path, "rb") as model_file:
|
||||||
|
self.clf = pickle.load(model_file)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError("Model File not found")
|
||||||
|
|
||||||
|
def predict(self, data):
|
||||||
|
blocks = split_x(data, blk_sz=self.blk_sz)
|
||||||
|
blocks = np.array(blocks)
|
||||||
|
features = feature(np.array(blocks))
|
||||||
|
y_pred = self.clf.predict(features)
|
||||||
|
y_pred_binary = np.ones_like(y_pred)
|
||||||
|
# classes merge
|
||||||
|
y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
|
||||||
|
# transform to mask
|
||||||
|
mask = self.mask_transform(y_pred_binary, (1024, 600))
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def mask_transform(self, result, dst_size):
|
||||||
|
mask_size = 600//self.blk_sz, 1024 // self.blk_sz
|
||||||
|
mask = np.zeros(mask_size, dtype=np.uint8)
|
||||||
|
for idx, r in enumerate(result):
|
||||||
|
row, col = idx // mask_size[1], idx % mask_size[1]
|
||||||
|
mask[row, col] = r
|
||||||
|
mask = mask.repeat(self.blk_sz, axis = 0).repeat(self.blk_sz, axis = 1)
|
||||||
|
return mask
|
||||||
10
root_dir.py
Executable file
10
root_dir.py
Executable file
@ -0,0 +1,10 @@
|
|||||||
|
# -*- codeing = utf-8 -*-
|
||||||
|
# Time : 2022/6/18 9:32
|
||||||
|
# @Auther : zhouchao
|
||||||
|
# @File: root_dir.py
|
||||||
|
# @Software:PyCharm
|
||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
ROOT_DIR = os.path.split(sys.argv[0])[0]
|
||||||
|
|
||||||
19
test_files/models_test.py
Executable file
19
test_files/models_test.py
Executable file
@ -0,0 +1,19 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from models import feature
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTestCase(unittest.TestCase):
|
||||||
|
def test_feature(self):
|
||||||
|
x_list = [np.ones((8, 8, 6)) * i for i in range(9600)]
|
||||||
|
features = feature(x_list=x_list)
|
||||||
|
self.assertEqual(features[0][0], 0) # add assertion here
|
||||||
|
self.assertEqual(features[0][5], 0) # add assertion here
|
||||||
|
self.assertEqual(features[-1][5], 9599) # add assertion here
|
||||||
|
self.assertEqual(features[-1][0], 9599) # add assertion here
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
119
test_files/utils_test.py
Executable file
119
test_files/utils_test.py
Executable file
@ -0,0 +1,119 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from utils import determine_class, split_xy, split_x
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(unittest.TestCase):
|
||||||
|
def test_determine_class(self):
|
||||||
|
pixel_block = np.zeros((8, 8), dtype=np.uint8)
|
||||||
|
pixel_block[2: 4, 5: 6] = 2
|
||||||
|
pixel_block[5: 7, 1: 2] = 1
|
||||||
|
pixel_block[4: 6, 1: 6] = 3
|
||||||
|
cls = determine_class(pixel_block, sensitivity=8)
|
||||||
|
self.assertEqual(cls, 3)
|
||||||
|
pixel_block = np.zeros((8, 8), dtype=np.uint8)
|
||||||
|
pixel_block[2: 4, 5: 6] = 2
|
||||||
|
pixel_block[5: 7, 1: 2] = 1
|
||||||
|
pixel_block[4: 6, 1: 6] = 2
|
||||||
|
cls = determine_class(pixel_block, sensitivity=8)
|
||||||
|
self.assertEqual(cls, 2)
|
||||||
|
pixel_block = np.zeros((8, 8), dtype=np.uint8)
|
||||||
|
pixel_block[2: 4, 5: 6] = 1
|
||||||
|
pixel_block[5: 7, 1: 2] = 2
|
||||||
|
pixel_block[4: 6, 1: 6] = 1
|
||||||
|
cls = determine_class(pixel_block, sensitivity=8)
|
||||||
|
self.assertEqual(cls, 1)
|
||||||
|
|
||||||
|
def test_split_xy(self):
|
||||||
|
x = np.arange(600*1024).reshape((600, 1024))
|
||||||
|
y = np.zeros((600, 1024, 3))
|
||||||
|
color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
|
||||||
|
trans_color_dict = {v: k for k, v in color_dict.items()}
|
||||||
|
# modify the first block
|
||||||
|
y[2: 4, 5: 6] = trans_color_dict[2]
|
||||||
|
y[5: 7, 1: 2] = trans_color_dict[1]
|
||||||
|
y[4: 6, 1: 6] = trans_color_dict[3]
|
||||||
|
# modify the last block
|
||||||
|
y[-4: -2, -6: -5] = trans_color_dict[1]
|
||||||
|
y[-7: -5, -2: -1] = trans_color_dict[2]
|
||||||
|
y[-6: -4, -6: -1] = trans_color_dict[1]
|
||||||
|
# modify the middle block
|
||||||
|
y[64+2: 64+4, 64+5: 64+6] = trans_color_dict[2]
|
||||||
|
y[64+5: 64+7, 64+1: 64+2] = trans_color_dict[1]
|
||||||
|
y[64+4: 64+6, 64+1: 64+6] = trans_color_dict[2]
|
||||||
|
x_list, y_list = split_xy(x, y, blk_sz=8, sensitivity=8)
|
||||||
|
first_block = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
|
||||||
|
[1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031],
|
||||||
|
[2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055],
|
||||||
|
[3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079],
|
||||||
|
[4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103],
|
||||||
|
[5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127],
|
||||||
|
[6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151],
|
||||||
|
[7168, 7169, 7170, 7171, 7172, 7173, 7174, 7175],])
|
||||||
|
sum_value = np.sum(np.sum(x_list[0]-first_block))
|
||||||
|
self.assertEqual(sum_value, 0)
|
||||||
|
self.assertEqual(y_list[0], 3)
|
||||||
|
last_block = np.array(
|
||||||
|
[[607224, 607225, 607226, 607227, 607228, 607229, 607230, 607231],
|
||||||
|
[608248, 608249, 608250, 608251, 608252, 608253, 608254, 608255],
|
||||||
|
[609272, 609273, 609274, 609275, 609276, 609277, 609278, 609279],
|
||||||
|
[610296, 610297, 610298, 610299, 610300, 610301, 610302, 610303],
|
||||||
|
[611320, 611321, 611322, 611323, 611324, 611325, 611326, 611327],
|
||||||
|
[612344, 612345, 612346, 612347, 612348, 612349, 612350, 612351],
|
||||||
|
[613368, 613369, 613370, 613371, 613372, 613373, 613374, 613375],
|
||||||
|
[614392, 614393, 614394, 614395, 614396, 614397, 614398, 614399],])
|
||||||
|
sum_value = np.sum(np.sum(x_list[-1]-last_block))
|
||||||
|
self.assertEqual(sum_value, 0)
|
||||||
|
self.assertEqual(y_list[-1], 1)
|
||||||
|
middle_block = np.array([[65600, 65601, 65602, 65603, 65604, 65605, 65606, 65607],
|
||||||
|
[66624, 66625, 66626, 66627, 66628, 66629, 66630, 66631],
|
||||||
|
[67648, 67649, 67650, 67651, 67652, 67653, 67654, 67655],
|
||||||
|
[68672, 68673, 68674, 68675, 68676, 68677, 68678, 68679],
|
||||||
|
[69696, 69697, 69698, 69699, 69700, 69701, 69702, 69703],
|
||||||
|
[70720, 70721, 70722, 70723, 70724, 70725, 70726, 70727],
|
||||||
|
[71744, 71745, 71746, 71747, 71748, 71749, 71750, 71751],
|
||||||
|
[72768, 72769, 72770, 72771, 72772, 72773, 72774, 72775]])
|
||||||
|
sum_value = np.sum(np.sum(x_list[1032]-middle_block))
|
||||||
|
self.assertEqual(sum_value, 0)
|
||||||
|
self.assertEqual(y_list[1032], 2)
|
||||||
|
|
||||||
|
def test_split_x(self):
|
||||||
|
x = np.arange(600 * 1024).reshape((600, 1024))
|
||||||
|
x_list = split_x(x, blk_sz=8)
|
||||||
|
first_block = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
|
||||||
|
[1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031],
|
||||||
|
[2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055],
|
||||||
|
[3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079],
|
||||||
|
[4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103],
|
||||||
|
[5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127],
|
||||||
|
[6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151],
|
||||||
|
[7168, 7169, 7170, 7171, 7172, 7173, 7174, 7175], ])
|
||||||
|
sum_value = np.sum(np.sum(x_list[0] - first_block))
|
||||||
|
self.assertEqual(sum_value, 0)
|
||||||
|
last_block = np.array(
|
||||||
|
[[607224, 607225, 607226, 607227, 607228, 607229, 607230, 607231],
|
||||||
|
[608248, 608249, 608250, 608251, 608252, 608253, 608254, 608255],
|
||||||
|
[609272, 609273, 609274, 609275, 609276, 609277, 609278, 609279],
|
||||||
|
[610296, 610297, 610298, 610299, 610300, 610301, 610302, 610303],
|
||||||
|
[611320, 611321, 611322, 611323, 611324, 611325, 611326, 611327],
|
||||||
|
[612344, 612345, 612346, 612347, 612348, 612349, 612350, 612351],
|
||||||
|
[613368, 613369, 613370, 613371, 613372, 613373, 613374, 613375],
|
||||||
|
[614392, 614393, 614394, 614395, 614396, 614397, 614398, 614399], ])
|
||||||
|
sum_value = np.sum(np.sum(x_list[-1] - last_block))
|
||||||
|
self.assertEqual(sum_value, 0)
|
||||||
|
middle_block = np.array([[65600, 65601, 65602, 65603, 65604, 65605, 65606, 65607],
|
||||||
|
[66624, 66625, 66626, 66627, 66628, 66629, 66630, 66631],
|
||||||
|
[67648, 67649, 67650, 67651, 67652, 67653, 67654, 67655],
|
||||||
|
[68672, 68673, 68674, 68675, 68676, 68677, 68678, 68679],
|
||||||
|
[69696, 69697, 69698, 69699, 69700, 69701, 69702, 69703],
|
||||||
|
[70720, 70721, 70722, 70723, 70724, 70725, 70726, 70727],
|
||||||
|
[71744, 71745, 71746, 71747, 71748, 71749, 71750, 71751],
|
||||||
|
[72768, 72769, 72770, 72771, 72772, 72773, 72774, 72775]])
|
||||||
|
sum_value = np.sum(np.sum(x_list[1032] - middle_block))
|
||||||
|
self.assertEqual(sum_value, 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
190
utils.py
Executable file
190
utils.py
Executable file
@ -0,0 +1,190 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from models import SpecDetector
|
||||||
|
|
||||||
|
|
||||||
|
def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
|
||||||
|
"""
|
||||||
|
将label转为类别
|
||||||
|
|
||||||
|
:param pixel: 一个 n x n 的像素块
|
||||||
|
:param color_dict: 用于转化的字典 {(0, 0, 255): 1, ....} 色彩采用bgr
|
||||||
|
:return:类别白噢好
|
||||||
|
"""
|
||||||
|
# 0 表示的是背景, 1表示的是烟梗,剩下的都是杂质
|
||||||
|
if color_dict is None:
|
||||||
|
color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
|
||||||
|
if (pixel[0], pixel[1], pixel[2]) in color_dict.keys():
|
||||||
|
return color_dict[(pixel[0], pixel[1], pixel[2])]
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
|
||||||
|
"""
|
||||||
|
决定像素块的类别
|
||||||
|
|
||||||
|
:param pixel_blk: 像素块
|
||||||
|
:param sensitivity: 敏感度
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
|
||||||
|
color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
|
||||||
|
for cls in defect_dict.keys()}
|
||||||
|
grant_cls = {0: 0, 1: 0}
|
||||||
|
for cls, num in color_numbers.items():
|
||||||
|
grant_cls[defect_dict[cls]] += num
|
||||||
|
if grant_cls[1] >= sensitivity:
|
||||||
|
color_numbers = {cls: color_numbers[cls] for cls in [2, 3, 4, 5, 6]}
|
||||||
|
return max(color_numbers, key=color_numbers.get)
|
||||||
|
else:
|
||||||
|
if color_numbers[1] >= sensitivity:
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity: int = 12,
|
||||||
|
color_dict=None, add_background=True) -> tuple:
|
||||||
|
"""
|
||||||
|
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||||
|
|
||||||
|
;param data: image data, shape (num_rows x 1024 x num_channels)
|
||||||
|
;param labeled_img: RGB labeled img with respect to the image!
|
||||||
|
make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
|
||||||
|
;param blk_sz: block size
|
||||||
|
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||||
|
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||||
|
data y (block_num, ) 1 是杂质, 0是无杂质
|
||||||
|
"""
|
||||||
|
assert (data.shape[0] == labeled_img.shape[0]) and (data.shape[1] == labeled_img.shape[1])
|
||||||
|
color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}\
|
||||||
|
if color_dict is None else color_dict
|
||||||
|
class_img = np.zeros((labeled_img.shape[0], labeled_img.shape[1]), dtype=int)
|
||||||
|
for color, class_idx in color_dict.items():
|
||||||
|
truth_map = np.all(labeled_img == color, axis=2)
|
||||||
|
class_img[truth_map] = class_idx
|
||||||
|
x_list, y_list = [], []
|
||||||
|
for i in range(0, 600 // blk_sz):
|
||||||
|
for j in range(0, 1024 // blk_sz):
|
||||||
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
|
block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
|
block_label = determine_class(block_label, sensitivity=sensitivity)
|
||||||
|
if add_background:
|
||||||
|
y_list.append(block_label)
|
||||||
|
x_list.append(block_data)
|
||||||
|
else:
|
||||||
|
if block_label != 0:
|
||||||
|
y_list.append(block_label)
|
||||||
|
x_list.append(block_data)
|
||||||
|
return x_list, y_list
|
||||||
|
|
||||||
|
|
||||||
|
def split_x(data: np.ndarray, blk_sz: int) -> list:
|
||||||
|
"""
|
||||||
|
Split the data into slices for classification.将数据划分为多个像素块,便于后续识别.
|
||||||
|
|
||||||
|
;param data: image data, shape (num_rows x 1024 x num_channels)
|
||||||
|
;param blk_sz: block size
|
||||||
|
;param sensitivity: 最少有多少个杂物点能够被认为是杂物
|
||||||
|
;return data_x, data_y: sliced data x (block_num x num_charnnels x blk_sz x blk_sz)
|
||||||
|
"""
|
||||||
|
x_list = []
|
||||||
|
for i in range(0, 600 // blk_sz):
|
||||||
|
for j in range(0, 1024 // blk_sz):
|
||||||
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
|
x_list.append(block_data)
|
||||||
|
return x_list
|
||||||
|
|
||||||
|
|
||||||
|
def visualization_evaluation(detector, data_path, selected_bands=None):
|
||||||
|
selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
|
||||||
|
nrows, ncols = 600, 1024
|
||||||
|
image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
|
||||||
|
for idx, image_path in enumerate(image_paths):
|
||||||
|
with open(image_path, 'rb') as f:
|
||||||
|
data = f.read()
|
||||||
|
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)).transpose(0, 2, 1)
|
||||||
|
nbands = img.shape[2]
|
||||||
|
t1 = time.time()
|
||||||
|
mask = detector.predict(img[..., selected_bands] if nbands == 448 else img)
|
||||||
|
time_spent = time.time() - t1
|
||||||
|
if nbands == 448:
|
||||||
|
rgb_img = np.asarray(img[..., [372, 241, 169]] * 255, dtype=np.uint8)
|
||||||
|
else:
|
||||||
|
rgb_img = np.asarray(img[..., [0, 1, 2]] * 255, dtype=np.uint8)
|
||||||
|
fig, axs = plt.subplots(1, 2)
|
||||||
|
axs[0].imshow(rgb_img)
|
||||||
|
axs[1].imshow(mask)
|
||||||
|
fig.suptitle(f"time spent {time_spent*1000:.2f} ms" + f"\n{image_path}")
|
||||||
|
plt.savefig(f"./dataset/{idx}.png", dpi=300)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def visualization_y(y_list, k_size):
|
||||||
|
mask = np.zeros((600//k_size, 1024//k_size), dtype=np.uint8)
|
||||||
|
for idx, r in enumerate(y_list):
|
||||||
|
row, col = idx // (1024 // k_size), idx % (1024 // k_size)
|
||||||
|
mask[row, col] = r
|
||||||
|
fig, axs = plt.subplots()
|
||||||
|
axs.imshow(mask)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def read_raw_file(file_name, selected_bands=None):
|
||||||
|
with open(file_name, "rb") as f:
|
||||||
|
data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1)
|
||||||
|
if selected_bands is not None:
|
||||||
|
data = data[..., selected_bands]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def read_black_and_white_file(file_name):
|
||||||
|
with open(file_name, "rb") as f:
|
||||||
|
data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def label2pic(label, color_dict):
|
||||||
|
pic = np.zeros((label.shape[0], label.shape[1], 3))
|
||||||
|
for color, cls in color_dict.items():
|
||||||
|
pic[label == cls] = color
|
||||||
|
return pic
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
|
||||||
|
model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
|
||||||
|
y_label = model.predict(data)
|
||||||
|
x_list, y_list = [], []
|
||||||
|
for i in range(0, 600 // blk_sz):
|
||||||
|
for j in range(0, 1024 // blk_sz):
|
||||||
|
if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) \
|
||||||
|
> 0:
|
||||||
|
block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
|
||||||
|
x_list.append(block_data)
|
||||||
|
y_list.append(1)
|
||||||
|
return x_list, y_list
|
||||||
|
|
||||||
|
|
||||||
|
def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
|
||||||
|
target_class_left=None,):
|
||||||
|
y_label = np.zeros((data.shape[0], data.shape[1]))
|
||||||
|
for i in range(0, 600):
|
||||||
|
for j in range(0, 1024):
|
||||||
|
if np.sum(np.sum(data[i, j])) >= light_threshold:
|
||||||
|
if j > split_line:
|
||||||
|
y_label[i, j] = target_class_right
|
||||||
|
else:
|
||||||
|
y_label[i, j] = target_class_left
|
||||||
|
pic = label2pic(y_label, color_dict=color_dict)
|
||||||
|
fig, axs = plt.subplots(2, 1)
|
||||||
|
axs[0].matshow(y_label)
|
||||||
|
axs[1].matshow(data[..., 0])
|
||||||
|
plt.show()
|
||||||
|
return pic
|
||||||
Loading…
Reference in New Issue
Block a user