# Mirror of https://github.com/Karllzy/cotton_color.git
# Synced 2025-11-08 18:53:53 +00:00 (42 lines, 1.2 KiB, Python)
import argparse
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
|
|
from model_cls.models import Model as ClsModel
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
# Absolute, symlink-resolved path of this script and the directory holding it.
FILE = Path(__file__).resolve()
ROOT = FILE.parent

# Put the script's directory on the import path if it is not already there.
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.append(_root_str)

# Keep ROOT expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
|
|
|
|
def main():
    """Parse command-line options, pick the torch device and build the model.

    Command-line options:
        --weights    Path to the model weights file (required).
        --input-dir  Directory to input images (default ``dataset/``).
        --save-dir   Directory to save output images (default ``detect``).
        --gpu        Request GPU inference; falls back to CPU with a notice
                     when CUDA is not available.

    ``--input-dir`` and ``--save-dir`` are parsed here but not consumed by
    the code in this function.
    """
    # Command-line argument parsing.
    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")

    # ``args.weights`` is read when the model is built, but the option was
    # never registered, so every run died with AttributeError. It is declared
    # as required because there is no known default weights path.
    parser.add_argument('--weights', type=str, required=True, help='Path to the model weights file')

    # Options with defaults that the user may override on the command line.
    parser.add_argument('--input-dir', type=str, default='dataset/', help='Directory to input images')
    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')

    args = parser.parse_args()

    # Select the device: CUDA only when it was requested and is usable.
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    if args.gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")

    # Load the model.
    model = ClsModel(model_path=args.weights, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()