diff --git a/DPL/dgd_class/predict.py b/DPL/dgd_class/predict.py index 2d72872..679f0fe 100644 --- a/DPL/dgd_class/predict.py +++ b/DPL/dgd_class/predict.py @@ -8,7 +8,7 @@ from torchvision import datasets, models, transforms import os import matplotlib.pyplot as plt -from test import device, class_names, imshow +from train import device, class_names, imshow # 加载已训练的 ONNX 模型