mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
42 lines
888 B
Python
42 lines
888 B
Python
#!/usr/bin/env python
|
|
# coding: utf-8
|
|
|
|
# # Model training
|
|
|
|
# In[16]:
|
|
|
|
|
|
import numpy as np
|
|
|
|
from models import AnonymousColorDetector
|
|
from utils import read_labeled_img
|
|
|
|
# ## Reading the data and building the dataset
|
|
|
|
# In[17]:
|
|
|
|
|
|
# Directory handed to read_labeled_img as the source of the labeled images.
data_dir = "data/dataset"

# Lookup from a 3-component color tuple to a label name.
# NOTE(review): channel order (RGB vs. BGR) is not visible here -- confirm
# against utils.read_labeled_img.
color_dict = {
    (0, 0, 255): "yangeng",
}

# The result is consumed below through its dict-style interface.
dataset = read_labeled_img(
    data_dir,
    color_dict=color_dict,
    is_ps_color_space=False,
)
|
|
|
|
# ## Model training
|
|
|
|
# In[18]:
|
|
|
|
|
|
# Define some constants
threshold = 5
# NOTE(review): node_num is not referenced anywhere else in this script.
node_num = 20
negative_sample_num = None  # None or a number
# NOTE(review): presumably the lower (0, 0, 0) and upper (255, 255, 255)
# bounds of the color space -- confirm against AnonymousColorDetector.fit.
world_boundary = np.array([0, 0, 0, 255, 255, 255])

# Preprocess the data
# Fail early with a readable message; np.concatenate on an empty sequence
# only reports "need at least one array to concatenate".
if not dataset:
    raise ValueError(f"no labeled samples were loaded from {data_dir!r}")

# Stack the per-label arrays along the first axis; the label keys are not
# needed here, so only the values are read.
x = np.concatenate(list(dataset.values()), axis=0)

# When no explicit count is configured, fall back to 70% of the row count of x.
if negative_sample_num is None:
    negative_sample_num = int(x.shape[0] * 0.7)

model = AnonymousColorDetector()
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)

# save() is called without arguments, so the output location is whatever
# the model class defaults to.
model.save()
|