First Commit

This commit is contained in:
karllzy 2022-06-18 19:13:39 +08:00
commit da5fde8f40
12 changed files with 2635 additions and 0 deletions

577
.gitignore vendored Normal file

@@ -0,0 +1,577 @@
### Project specified
dataset/*
models/*
### VisualStudioCode template
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
# Local History for Visual Studio Code
.history/
### VirtualEnv template
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
.Python
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
.venv
pip-selfcheck.json
### Example user template template
### Example user template
# IntelliJ project files
.idea
*.iml
out
gen
### VisualStudio template
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
# User-specific files
*.rsuser
*.suo
*.user
*.userosscache
*.sln.docstates
# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
# Mono auto generated files
mono_crash.*
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
[Ww][Ii][Nn]32/
[Aa][Rr][Mm]/
[Aa][Rr][Mm]64/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
[Ll]ogs/
# Visual Studio 2015/2017 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# Visual Studio 2017 auto generated files
Generated\ Files/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUnit
*.VisualState.xml
TestResult.xml
nunit-*.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# Benchmark Results
BenchmarkDotNet.Artifacts/
# .NET Core
project.lock.json
project.fragment.lock.json
artifacts/
# ASP.NET Scaffolding
ScaffoldingReadMe.txt
# StyleCop
StyleCopReport.xml
# Files built by Visual Studio
*_i.c
*_p.c
*_h.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
*.pdb
*.ipdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*_wpftmp.csproj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# Visual Studio Trace Files
*.e2e
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# AxoCover is a Code Coverage Tool
.axoCover/*
!.axoCover/settings.json
# Coverlet is a free, cross platform Code Coverage Tool
coverage*.json
coverage*.xml
coverage*.info
# Visual Studio code coverage results
*.coverage
*.coveragexml
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# Note: Comment the next line if you want to checkin your web deploy settings,
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# NuGet Symbol Packages
*.snupkg
# The packages folder can be ignored because of Package Restore
**/[Pp]ackages/*
# except build/, which is used as an MSBuild target.
!**/[Pp]ackages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/[Pp]ackages/repositories.config
# NuGet v3's project.json files produces more ignorable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
*.appx
*.appxbundle
*.appxupload
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!?*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
orleans.codegen.cs
# Including strong name files can present a security risk
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
#*.snk
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
ServiceFabricBackup/
*.rptproj.bak
# SQL Server files
*.mdf
*.ldf
*.ndf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
*.rptproj.rsuser
*- [Bb]ackup.rdl
*- [Bb]ackup ([0-9]).rdl
*- [Bb]ackup ([0-9][0-9]).rdl
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# CodeRush personal settings
.cr/personal
# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc
# Cake - Uncomment if you are using it
# tools/**
# !tools/packages.config
# Tabs Studio
*.tss
# Telerik's JustMock configuration file
*.jmconfig
# BizTalk build output
*.btp.cs
*.btm.cs
*.odx.cs
*.xsd.cs
# OpenCover UI analysis results
OpenCover/
# Azure Stream Analytics local run output
ASALocalRun/
# MSBuild Binary and Structured Log
*.binlog
# NVidia Nsight GPU debugger configuration file
*.nvuser
# MFractors (Xamarin productivity tool) working folder
.mfractor/
# Local History for Visual Studio
.localhistory/
# BeatPulse healthcheck temp database
healthchecksdb
# Backup folder for Package Reference Convert tool in Visual Studio 2017
MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
# Fody - auto-generated XML schema
FodyWeavers.xsd
### JupyterNotebooks template
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/
.ipynb_checkpoints
*/.ipynb_checkpoints/*
# IPython
profile_default/
ipython_config.py
# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/
### macOS template
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### Python template
# Byte-compiled / optimized / DLL files
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
# IPython
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/

149
01_dataset_building.ipynb Executable file

@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# 数据集的制作"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import pickle\n",
"import cv2\n",
"import numpy as np\n",
"from utils import split_xy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# 一些参数\n",
"blk_sz = 8\n",
"sensitivity = 8\n",
"selected_bands = [127, 201, 202, 294]\n",
"# [76, 146, 216, 367, 383, 406]\n",
"file_name, labeled_image_file = r\"F:\\zhouchao\\616\\calibrated1.raw\", \\\n",
"r\"F:\\zhouchao\\616\\label1.bmp\"\n",
"# file_name, labeled_image_file = \"./dataset/calibrated77.raw\", \"./dataset/label77.png\"\n",
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_1.p'"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 波长选择"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"with open(file_name, \"rb\") as f:\n",
" data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, 448, 1024)).transpose(0, 2, 1)\n",
"data = data[..., selected_bands]\n",
"label = cv2.imread(labeled_image_file)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## 块分割与数据存储"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
"x, y = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict)\n",
"with open(dataset_file, 'wb') as f:\n",
" pickle.dump((x, y), f)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "7f619fc91ee8bdab81d49e7c14228037474662e3f2d607687ae505108922fa06"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
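
For later steps, the (x, y) block lists pickled above can be loaded back in a few lines. A minimal sketch, assuming the blk_sz=8, four-band, sensitivity-8 dataset produced by this notebook (the path below is just the notebook's naming scheme filled with example values):

import pickle
import numpy as np

dataset_file = './dataset/data_8x8_c4_sen8_1.p'  # hypothetical example path
with open(dataset_file, 'rb') as f:
    x, y = pickle.load(f)            # x: list of (8, 8, 4) blocks, y: list of class indices
x, y = np.array(x), np.array(y)      # (block_num, 8, 8, 4) and (block_num,)
print(x.shape, np.bincount(y))       # quick sanity check of the class balance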

471
02_feature_design&training.ipynb Executable file

File diff suppressed because one or more lines are too long

301
03_data_update.ipynb Executable file

@@ -0,0 +1,301 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# 数据集扩充"
]
},
{
"cell_type": "markdown",
"source": [
"虽然当前的模型已经能够达到较好的效果,但是还不够好,对于一些较老的烟梗不能够做到有效的判别,我们为此增加数据集。"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import os\n",
"\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from utils import read_raw_file, split_xy, generate_tobacco_label, generate_impurity_label\n",
"from models import SpecDetector\n",
"import pickle"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# some parameters\n",
"new_spectra_file = r\"F:\\zhouchao\\618\\bomo\\calibrated7.raw\"\n",
"new_label_file = r\"F:\\zhouchao\\618\\bomo\\picture\\label7.bmp\"\n",
"\n",
"target_class = 0\n",
"target_class_left, target_class_right = 5, 4\n",
"light_threshold = 0.5\n",
"add_background = False\n",
"\n",
"split_line = 500\n",
"\n",
"\n",
"blk_sz, sensitivity = 8, 8\n",
"selected_bands = [127, 201, 202, 294]\n",
"tree_num = 185\n",
"\n",
"pic_row, pic_col= 600, 1024\n",
"\n",
"color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n",
" (255, 255, 0): 5, (255, 0, 255): 6}\n",
"\n",
"new_dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_7.p'\n",
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_6.p'\n",
"\n",
"model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_6.model'\n",
"selected_bands = None"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# if len(new_spectra_files) == 1:\n",
"data = read_raw_file(new_spectra_file, selected_bands)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 烟梗标签生成\n",
"这会将纯烟梗图片中识别为杂质的部分提取出来"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"x_list, y_list = [], []\n",
"if (new_label_file is None) and (target_class == 1):\n",
" x_list, y_list = generate_tobacco_label(data, model_file, blk_sz, selected_bands)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 其他类别杂质阈值分割\n",
"通过阈值分割的形式获取其他类别的杂质"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"if (new_label_file is None) and (target_class != 1):\n",
" img = generate_impurity_label(data, light_threshold, color_dict,\n",
" target_class_right=target_class_right,\n",
" target_class_left=target_class_left,\n",
" split_line=split_line)\n",
" root, _ = os.path.splitext(new_dataset_file)\n",
" cv2.imwrite(root+\"_generated.bmp\", img)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 读取标签"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(600, 1024, 3)\n"
]
}
],
"source": [
"if new_label_file is not None:\n",
" label = cv2.imread(new_label_file)\n",
" print(label.shape)\n",
" x_list, y_list = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict, add_background=add_background)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"148 148\n"
]
}
],
"source": [
"print(len(x_list), len(y_list))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 读取旧数据合并"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"with open(dataset_file, 'rb') as f:\n",
" x, y = pickle.load(f)\n",
"x.extend(x_list)\n",
"y.extend(y_list)\n",
"with open(new_dataset_file, 'wb') as f:\n",
" pickle.dump((x, y), f)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 批量数据的处理"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

398
04_performance_tune.ipynb Executable file

File diff suppressed because one or more lines are too long

225
05_evaluation.ipynb Executable file

File diff suppressed because one or more lines are too long

34
main.py Executable file

@@ -0,0 +1,34 @@
import os

import numpy as np

from models import SpecDetector
from root_dir import ROOT_DIR

nrows, ncols, nbands = 600, 1024, 4
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
selected_model = "rf_8x8_c4_400_13.model"


def read_exact(fd, n):
    # os.read on a FIFO may return fewer bytes than requested, so accumulate a full frame.
    buf = b''
    while len(buf) < n:
        chunk = os.read(fd, n - len(buf))
        if not chunk:
            raise EOFError("image fifo closed before a full frame arrived")
        buf += chunk
    return buf


def main():
    model_path = os.path.join(ROOT_DIR, "models", selected_model)
    detector = SpecDetector(model_path, blk_sz=8, channel_num=4)
    total_len = nrows * ncols * nbands * 4  # bytes per float32 frame
    if not os.access(img_fifo_path, os.F_OK):
        os.mkfifo(img_fifo_path, 0o777)
    if not os.access(mask_fifo_path, os.F_OK):
        os.mkfifo(mask_fifo_path, 0o777)
    fd_img = os.open(img_fifo_path, os.O_RDONLY)
    print("connect to fifo")
    while True:
        data = read_exact(fd_img, total_len)
        print("get img")
        # Band-interleaved frame -> (rows, cols, bands)
        img = np.frombuffer(data, dtype=np.float32).reshape((nrows, nbands, -1)).transpose(0, 2, 1)
        mask = detector.predict(img)
        fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
        os.write(fd_mask, mask.tobytes())
        os.close(fd_mask)


if __name__ == '__main__':
    main()
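
main.py speaks a simple framed protocol over two named pipes: one float32 frame of shape (600, 4, 1024) goes into /tmp/dkimg.fifo, and one 600 x 1024 uint8 mask comes back on /tmp/dkmask.fifo. A minimal driver sketch for exercising it with a synthetic frame (start main.py first so the FIFOs exist; the loop on the read side is needed because FIFO reads can be partial):

import os
import numpy as np

nrows, ncols, nbands = 600, 1024, 4

# Writer side: push one synthetic band-interleaved frame into the image FIFO.
frame = np.zeros((nrows, nbands, ncols), dtype=np.float32)
fd_img = os.open("/tmp/dkimg.fifo", os.O_WRONLY)
os.write(fd_img, frame.tobytes())
os.close(fd_img)

# Reader side: collect the full 600 x 1024 uint8 mask.
fd_mask = os.open("/tmp/dkmask.fifo", os.O_RDONLY)
want, buf = nrows * ncols, b''
while len(buf) < want:
    buf += os.read(fd_mask, want - len(buf))
os.close(fd_mask)
mask = np.frombuffer(buf, dtype=np.uint8).reshape((nrows, ncols))
print(mask.shape, mask.max())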

142
models.py Executable file

@@ -0,0 +1,142 @@
import os
import pickle
import time

import cv2
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


# def feature(x):
#     x = x.reshape((x.shape[0], -1, x.shape[-1]))
#     x = np.mean(x, axis=1)
#     return x


def feature(x):
    # Flatten each block into a single feature vector.
    x = x.reshape((x.shape[0], -1))
    return x


def train_rf_and_report(train_x, train_y, test_x, test_y,
                        tree_num, save_path=None):
    rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42, class_weight={0: 10, 1: 10})
    rfc = rfc.fit(train_x, train_y)
    t1 = time.time()
    y_pred = rfc.predict(test_x)
    # Merge classes to binary labels: background/stem (0, 1) -> 0, impurities (> 1) -> 2.
    y_pred_binary = np.ones_like(y_pred)
    y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
    y_pred_binary[y_pred > 1] = 2
    test_y_binary = np.ones_like(test_y)
    test_y_binary[(test_y == 0) | (test_y == 1)] = 0
    test_y_binary[test_y > 1] = 2
    print("Prediction time:", time.time() - t1)
    print("RFC training accuracy: " + str(accuracy_score(train_y, rfc.predict(train_x))))
    print("RFC test accuracy: " + str(accuracy_score(test_y, rfc.predict(test_x))))
    print('RFC predictions: ' + str(y_pred))
    print('---------------------------------------------------------------------------------------------------')
    print('RFC classification report\n' + str(classification_report(test_y, y_pred)))
    print('RFC confusion matrix\n' + str(confusion_matrix(test_y, y_pred)))
    print('Binary classification report\n' + str(classification_report(test_y_binary, y_pred_binary)))
    print('Binary confusion matrix\n' + str(confusion_matrix(test_y_binary, y_pred_binary)))
    if save_path is not None:
        with open(save_path, 'wb') as f:
            pickle.dump(rfc, f)
    return rfc


def evaluation_and_report(model, test_x, test_y):
    t1 = time.time()
    y_pred = model.predict(test_x)
    y_pred_binary = np.ones_like(y_pred)
    y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
    y_pred_binary[y_pred > 1] = 2
    test_y_binary = np.ones_like(test_y)
    test_y_binary[(test_y == 0) | (test_y == 1)] = 0
    test_y_binary[test_y > 1] = 2
    print("Prediction time:", time.time() - t1)
    print("RFC test accuracy: " + str(accuracy_score(test_y, model.predict(test_x))))
    print('RFC predictions: ' + str(y_pred))
    print('---------------------------------------------------------------------------------------------------')
    print('RFC classification report\n' + str(classification_report(test_y, y_pred)))
    print('RFC confusion matrix\n' + str(confusion_matrix(test_y, y_pred)))
    print('Binary classification report\n' + str(classification_report(test_y_binary, y_pred_binary)))
    print('Binary confusion matrix\n' + str(confusion_matrix(test_y_binary, y_pred_binary)))


def train_pca_rf(train_x, train_y, test_x, test_y, n_comp,
                 tree_num, save_path=None):
    rfc = RandomForestClassifier(n_estimators=tree_num, random_state=42, class_weight={0: 100, 1: 100})
    pca = PCA(n_components=0.95)
    # Fit PCA before training so the pickled (pca, rfc) pair is consistent at predict time.
    train_x = pca.fit_transform(train_x)
    test_x = pca.transform(test_x)
    rfc = rfc.fit(train_x, train_y)
    t1 = time.time()
    y_pred = rfc.predict(test_x)
    y_pred_binary = np.ones_like(y_pred)
    y_pred_binary[(y_pred == 0) | (y_pred == 1)] = 0
    y_pred_binary[(y_pred == 2) | (y_pred == 3) | (y_pred == 4)] = 2
    test_y_binary = np.ones_like(test_y)
    test_y_binary[(test_y == 0) | (test_y == 1)] = 0
    test_y_binary[(test_y == 2) | (test_y == 3) | (test_y == 4)] = 2
    print("Prediction time:", time.time() - t1)
    print("RFC training accuracy: " + str(accuracy_score(train_y, rfc.predict(train_x))))
    print("RFC test accuracy: " + str(accuracy_score(test_y, rfc.predict(test_x))))
    print('RFC predictions: ' + str(y_pred))
    print('---------------------------------------------------------------------------------------------------')
    print('RFC classification report\n' + str(classification_report(test_y, y_pred)))
    print('RFC confusion matrix\n' + str(confusion_matrix(test_y, y_pred)))
    print('Binary classification report\n' + str(classification_report(test_y_binary, y_pred_binary)))
    print('Binary confusion matrix\n' + str(confusion_matrix(test_y_binary, y_pred_binary)))
    if save_path is not None:
        with open(save_path, 'wb') as f:
            pickle.dump((pca, rfc), f)
    return pca, rfc


def split_x(data: np.ndarray, blk_sz: int) -> list:
    """
    Split the data into blocks for classification.
    :param data: image data, shape (num_rows x 1024 x num_channels)
    :param blk_sz: block size
    :return x_list: list of blocks, each of shape (blk_sz x blk_sz x num_channels)
    """
    x_list = []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
            block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            x_list.append(block_data)
    return x_list


class SpecDetector(object):
    def __init__(self, model_path, blk_sz=8, channel_num=4):
        self.blk_sz, self.channel_num = blk_sz, channel_num
        if os.path.exists(model_path):
            with open(model_path, "rb") as model_file:
                self.clf = pickle.load(model_file)
        else:
            raise FileNotFoundError("Model file not found")

    def predict(self, data):
        blocks = np.array(split_x(data, blk_sz=self.blk_sz))
        features = feature(blocks)
        y_pred = self.clf.predict(features)
        y_pred_binary = np.ones_like(y_pred)
        # classes merge: background (0), stem (1) and class 3 count as non-impurity
        y_pred_binary[(y_pred == 0) | (y_pred == 1) | (y_pred == 3)] = 0
        # transform block predictions to a pixel mask
        mask = self.mask_transform(y_pred_binary, (1024, 600))
        return mask

    def mask_transform(self, result, dst_size):
        mask_size = 600 // self.blk_sz, 1024 // self.blk_sz
        mask = np.zeros(mask_size, dtype=np.uint8)
        for idx, r in enumerate(result):
            row, col = idx // mask_size[1], idx % mask_size[1]
            mask[row, col] = r
        # Blow each block prediction back up to blk_sz x blk_sz pixels.
        mask = mask.repeat(self.blk_sz, axis=0).repeat(self.blk_sz, axis=1)
        return mask
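
For offline use, the detector can be exercised directly. A minimal sketch, assuming a model pickle trained on flattened 8 x 8 x 4 blocks exists at the (hypothetical) path below:

import numpy as np
from models import SpecDetector

detector = SpecDetector("./models/rf_8x8_c4_185_6.model", blk_sz=8, channel_num=4)
img = np.random.rand(600, 1024, 4).astype(np.float32)  # one synthetic (rows, cols, bands) frame
mask = detector.predict(img)  # uint8 mask, 600 x 1024: 0 = stem/background, 1 = impurity
print(mask.shape, np.unique(mask))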

10
root_dir.py Executable file

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
# Time : 2022/6/18 9:32
# @Author : zhouchao
# @File: root_dir.py
# @Software: PyCharm
import os.path
import sys

ROOT_DIR = os.path.split(sys.argv[0])[0]

19
test_files/models_test.py Executable file

@@ -0,0 +1,19 @@
import unittest

import numpy as np

from models import feature


class ModelTestCase(unittest.TestCase):
    def test_feature(self):
        x_list = [np.ones((8, 8, 6)) * i for i in range(9600)]
        # feature() expects an ndarray of blocks, so stack the list first.
        features = feature(np.array(x_list))
        self.assertEqual(features[0][0], 0)
        self.assertEqual(features[0][5], 0)
        self.assertEqual(features[-1][5], 9599)
        self.assertEqual(features[-1][0], 9599)


if __name__ == '__main__':
    unittest.main()

119
test_files/utils_test.py Executable file

@@ -0,0 +1,119 @@
import unittest

import numpy as np

from utils import determine_class, split_xy, split_x


class DatasetTest(unittest.TestCase):
    def test_determine_class(self):
        pixel_block = np.zeros((8, 8), dtype=np.uint8)
        pixel_block[2: 4, 5: 6] = 2
        pixel_block[5: 7, 1: 2] = 1
        pixel_block[4: 6, 1: 6] = 3
        cls = determine_class(pixel_block, sensitivity=8)
        self.assertEqual(cls, 3)
        pixel_block = np.zeros((8, 8), dtype=np.uint8)
        pixel_block[2: 4, 5: 6] = 2
        pixel_block[5: 7, 1: 2] = 1
        pixel_block[4: 6, 1: 6] = 2
        cls = determine_class(pixel_block, sensitivity=8)
        self.assertEqual(cls, 2)
        pixel_block = np.zeros((8, 8), dtype=np.uint8)
        pixel_block[2: 4, 5: 6] = 1
        pixel_block[5: 7, 1: 2] = 2
        pixel_block[4: 6, 1: 6] = 1
        cls = determine_class(pixel_block, sensitivity=8)
        self.assertEqual(cls, 1)

    def test_split_xy(self):
        x = np.arange(600 * 1024).reshape((600, 1024))
        y = np.zeros((600, 1024, 3))
        color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
        trans_color_dict = {v: k for k, v in color_dict.items()}
        # modify the first block
        y[2: 4, 5: 6] = trans_color_dict[2]
        y[5: 7, 1: 2] = trans_color_dict[1]
        y[4: 6, 1: 6] = trans_color_dict[3]
        # modify the last block
        y[-4: -2, -6: -5] = trans_color_dict[1]
        y[-7: -5, -2: -1] = trans_color_dict[2]
        y[-6: -4, -6: -1] = trans_color_dict[1]
        # modify the middle block
        y[64 + 2: 64 + 4, 64 + 5: 64 + 6] = trans_color_dict[2]
        y[64 + 5: 64 + 7, 64 + 1: 64 + 2] = trans_color_dict[1]
        y[64 + 4: 64 + 6, 64 + 1: 64 + 6] = trans_color_dict[2]
        x_list, y_list = split_xy(x, y, blk_sz=8, sensitivity=8)
        first_block = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
                                [1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031],
                                [2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055],
                                [3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079],
                                [4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103],
                                [5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127],
                                [6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151],
                                [7168, 7169, 7170, 7171, 7172, 7173, 7174, 7175]])
        sum_value = np.sum(np.sum(x_list[0] - first_block))
        self.assertEqual(sum_value, 0)
        self.assertEqual(y_list[0], 3)
        last_block = np.array(
            [[607224, 607225, 607226, 607227, 607228, 607229, 607230, 607231],
             [608248, 608249, 608250, 608251, 608252, 608253, 608254, 608255],
             [609272, 609273, 609274, 609275, 609276, 609277, 609278, 609279],
             [610296, 610297, 610298, 610299, 610300, 610301, 610302, 610303],
             [611320, 611321, 611322, 611323, 611324, 611325, 611326, 611327],
             [612344, 612345, 612346, 612347, 612348, 612349, 612350, 612351],
             [613368, 613369, 613370, 613371, 613372, 613373, 613374, 613375],
             [614392, 614393, 614394, 614395, 614396, 614397, 614398, 614399]])
        sum_value = np.sum(np.sum(x_list[-1] - last_block))
        self.assertEqual(sum_value, 0)
        self.assertEqual(y_list[-1], 1)
        middle_block = np.array([[65600, 65601, 65602, 65603, 65604, 65605, 65606, 65607],
                                 [66624, 66625, 66626, 66627, 66628, 66629, 66630, 66631],
                                 [67648, 67649, 67650, 67651, 67652, 67653, 67654, 67655],
                                 [68672, 68673, 68674, 68675, 68676, 68677, 68678, 68679],
                                 [69696, 69697, 69698, 69699, 69700, 69701, 69702, 69703],
                                 [70720, 70721, 70722, 70723, 70724, 70725, 70726, 70727],
                                 [71744, 71745, 71746, 71747, 71748, 71749, 71750, 71751],
                                 [72768, 72769, 72770, 72771, 72772, 72773, 72774, 72775]])
        sum_value = np.sum(np.sum(x_list[1032] - middle_block))
        self.assertEqual(sum_value, 0)
        self.assertEqual(y_list[1032], 2)

    def test_split_x(self):
        x = np.arange(600 * 1024).reshape((600, 1024))
        x_list = split_x(x, blk_sz=8)
        first_block = np.array([[0, 1, 2, 3, 4, 5, 6, 7],
                                [1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031],
                                [2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055],
                                [3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079],
                                [4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103],
                                [5120, 5121, 5122, 5123, 5124, 5125, 5126, 5127],
                                [6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151],
                                [7168, 7169, 7170, 7171, 7172, 7173, 7174, 7175]])
        sum_value = np.sum(np.sum(x_list[0] - first_block))
        self.assertEqual(sum_value, 0)
        last_block = np.array(
            [[607224, 607225, 607226, 607227, 607228, 607229, 607230, 607231],
             [608248, 608249, 608250, 608251, 608252, 608253, 608254, 608255],
             [609272, 609273, 609274, 609275, 609276, 609277, 609278, 609279],
             [610296, 610297, 610298, 610299, 610300, 610301, 610302, 610303],
             [611320, 611321, 611322, 611323, 611324, 611325, 611326, 611327],
             [612344, 612345, 612346, 612347, 612348, 612349, 612350, 612351],
             [613368, 613369, 613370, 613371, 613372, 613373, 613374, 613375],
             [614392, 614393, 614394, 614395, 614396, 614397, 614398, 614399]])
        sum_value = np.sum(np.sum(x_list[-1] - last_block))
        self.assertEqual(sum_value, 0)
        middle_block = np.array([[65600, 65601, 65602, 65603, 65604, 65605, 65606, 65607],
                                 [66624, 66625, 66626, 66627, 66628, 66629, 66630, 66631],
                                 [67648, 67649, 67650, 67651, 67652, 67653, 67654, 67655],
                                 [68672, 68673, 68674, 68675, 68676, 68677, 68678, 68679],
                                 [69696, 69697, 69698, 69699, 69700, 69701, 69702, 69703],
                                 [70720, 70721, 70722, 70723, 70724, 70725, 70726, 70727],
                                 [71744, 71745, 71746, 71747, 71748, 71749, 71750, 71751],
                                 [72768, 72769, 72770, 72771, 72772, 72773, 72774, 72775]])
        sum_value = np.sum(np.sum(x_list[1032] - middle_block))
        self.assertEqual(sum_value, 0)


if __name__ == '__main__':
    unittest.main()

190
utils.py Executable file

@@ -0,0 +1,190 @@
import glob
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np

from models import SpecDetector


def trans_color(pixel: np.ndarray, color_dict: dict = None) -> int:
    """
    Convert a label color to a class index.
    :param pixel: a single BGR pixel
    :param color_dict: mapping used for the conversion, e.g. {(0, 0, 255): 1, ...}; colors are BGR
    :return: class index, or -1 for unknown colors
    """
    # 0 is the background, 1 is tobacco stem; everything else is an impurity.
    if color_dict is None:
        color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}
    if (pixel[0], pixel[1], pixel[2]) in color_dict.keys():
        return color_dict[(pixel[0], pixel[1], pixel[2])]
    else:
        return -1


def determine_class(pixel_blk: np.ndarray, sensitivity=8) -> int:
    """
    Decide the class of a pixel block.
    :param pixel_blk: block of per-pixel class labels
    :param sensitivity: minimum number of impurity pixels for the block to count as impurity
    :return: block class
    """
    defect_dict = {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1}
    # Count how many pixels of each class the block contains.
    color_numbers = {cls: pixel_blk.shape[0] ** 2 - np.count_nonzero(pixel_blk - cls)
                     for cls in defect_dict.keys()}
    grant_cls = {0: 0, 1: 0}
    for cls, num in color_numbers.items():
        grant_cls[defect_dict[cls]] += num
    if grant_cls[1] >= sensitivity:
        # Enough impurity pixels: return the most frequent impurity class.
        color_numbers = {cls: color_numbers[cls] for cls in [2, 3, 4, 5, 6]}
        return max(color_numbers, key=color_numbers.get)
    else:
        if color_numbers[1] >= sensitivity:
            return 1
        return 0


def split_xy(data: np.ndarray, labeled_img: np.ndarray, blk_sz: int, sensitivity: int = 12,
             color_dict=None, add_background=True) -> tuple:
    """
    Split the data into blocks for classification.
    :param data: image data, shape (num_rows x 1024 x num_channels)
    :param labeled_img: RGB labeled image aligned with the data;
                        make sure that the defect is (255, 0, 0) and background is (255, 255, 255)
    :param blk_sz: block size
    :param sensitivity: minimum number of impurity pixels for a block to count as impurity
    :return data_x, data_y: sliced data x (block_num x blk_sz x blk_sz x num_channels),
                            data y (block_num,) where 1 is impurity and 0 is no impurity
    """
    assert (data.shape[0] == labeled_img.shape[0]) and (data.shape[1] == labeled_img.shape[1])
    color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 255, 0): 3, (0, 255, 255): 4}\
        if color_dict is None else color_dict
    # Rasterize the color label image into a per-pixel class map.
    class_img = np.zeros((labeled_img.shape[0], labeled_img.shape[1]), dtype=int)
    for color, class_idx in color_dict.items():
        truth_map = np.all(labeled_img == color, axis=2)
        class_img[truth_map] = class_idx
    x_list, y_list = [], []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
            block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            block_label = class_img[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            block_label = determine_class(block_label, sensitivity=sensitivity)
            if add_background:
                y_list.append(block_label)
                x_list.append(block_data)
            else:
                if block_label != 0:
                    y_list.append(block_label)
                    x_list.append(block_data)
    return x_list, y_list


def split_x(data: np.ndarray, blk_sz: int) -> list:
    """
    Split the data into blocks for classification.
    :param data: image data, shape (num_rows x 1024 x num_channels)
    :param blk_sz: block size
    :return x_list: list of blocks, each of shape (blk_sz x blk_sz x num_channels)
    """
    x_list = []
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
            block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
            x_list.append(block_data)
    return x_list


def visualization_evaluation(detector, data_path, selected_bands=None):
    selected_bands = [76, 146, 216, 367, 383, 406] if selected_bands is None else selected_bands
    nrows, ncols = 600, 1024
    image_paths = glob.glob(os.path.join(data_path, "calibrated*.raw"))
    for idx, image_path in enumerate(image_paths):
        with open(image_path, 'rb') as f:
            data = f.read()
        img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)).transpose(0, 2, 1)
        nbands = img.shape[2]
        t1 = time.time()
        mask = detector.predict(img[..., selected_bands] if nbands == 448 else img)
        time_spent = time.time() - t1
        if nbands == 448:
            rgb_img = np.asarray(img[..., [372, 241, 169]] * 255, dtype=np.uint8)
        else:
            rgb_img = np.asarray(img[..., [0, 1, 2]] * 255, dtype=np.uint8)
        fig, axs = plt.subplots(1, 2)
        axs[0].imshow(rgb_img)
        axs[1].imshow(mask)
        fig.suptitle(f"time spent {time_spent * 1000:.2f} ms" + f"\n{image_path}")
        plt.savefig(f"./dataset/{idx}.png", dpi=300)
        plt.show()


def visualization_y(y_list, k_size):
    mask = np.zeros((600 // k_size, 1024 // k_size), dtype=np.uint8)
    for idx, r in enumerate(y_list):
        row, col = idx // (1024 // k_size), idx % (1024 // k_size)
        mask[row, col] = r
    fig, axs = plt.subplots()
    axs.imshow(mask)
    plt.show()


def read_raw_file(file_name, selected_bands=None):
    with open(file_name, "rb") as f:
        data = np.frombuffer(f.read(), dtype=np.float32).reshape((600, -1, 1024)).transpose(0, 2, 1)
    if selected_bands is not None:
        data = data[..., selected_bands]
    return data


def read_black_and_white_file(file_name):
    with open(file_name, "rb") as f:
        data = np.frombuffer(f.read(), dtype=np.float32).reshape((1, 448, 1024)).transpose(0, 2, 1)
    return data


def label2pic(label, color_dict):
    pic = np.zeros((label.shape[0], label.shape[1], 3))
    for color, cls in color_dict.items():
        pic[label == cls] = color
    return pic


def generate_tobacco_label(data, model_file, blk_sz, selected_bands):
    model = SpecDetector(model_path=model_file, blk_sz=blk_sz, channel_num=len(selected_bands))
    y_label = model.predict(data)
    x_list, y_list = [], []
    # Keep every block the current model flags as impurity and relabel it as stem (1).
    for i in range(0, 600 // blk_sz):
        for j in range(0, 1024 // blk_sz):
            if np.sum(np.sum(y_label[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...])) > 0:
                block_data = data[i * blk_sz: (i + 1) * blk_sz, j * blk_sz: (j + 1) * blk_sz, ...]
                x_list.append(block_data)
                y_list.append(1)
    return x_list, y_list


def generate_impurity_label(data, light_threshold, color_dict, split_line=0, target_class_right=None,
                            target_class_left=None):
    # Threshold on total brightness; pixels left/right of split_line get different target classes.
    y_label = np.zeros((data.shape[0], data.shape[1]))
    for i in range(0, 600):
        for j in range(0, 1024):
            if np.sum(np.sum(data[i, j])) >= light_threshold:
                if j > split_line:
                    y_label[i, j] = target_class_right
                else:
                    y_label[i, j] = target_class_left
    pic = label2pic(y_label, color_dict=color_dict)
    fig, axs = plt.subplots(2, 1)
    axs[0].matshow(y_label)
    axs[1].matshow(data[..., 0])
    plt.show()
    return pic
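
To make the block-labelling rule concrete, here is a small worked example of determine_class on synthetic labels (sensitivity is the impurity pixel-count threshold described above):

import numpy as np
from utils import determine_class

blk = np.zeros((8, 8), dtype=np.uint8)
blk[0:3, 0:3] = 2                             # 9 pixels of impurity class 2
print(determine_class(blk, sensitivity=8))    # -> 2: at least 8 impurity pixels; majority impurity class is 2
print(determine_class(blk, sensitivity=12))   # -> 0: fewer than 12 impurity pixels, falls back to background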