diff --git a/main_test.py b/main_test.py index fc3a084..c8e895a 100644 --- a/main_test.py +++ b/main_test.py @@ -168,5 +168,5 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description='Run image test or ') tester = TestMain() tester.pony_run(test_path=r'/home/lzy/2022.7.30/tobacco_v1_0/saved_img/', - test_rgb=False, test_spectra=False, get_delta=False) + test_rgb=True, test_spectra=True, get_delta=False) diff --git a/models.py b/models.py index f0acf56..9af462f 100755 --- a/models.py +++ b/models.py @@ -260,7 +260,7 @@ class PixelModelML: self.dt = pickle.load(f) def predict(self, feature): - pixel_result_array = self.dt.predict(feature) + pixel_result_array = self.dt.predict_bin(feature) return pixel_result_array @@ -409,7 +409,7 @@ class SpecDetector(Detector): if x_yellow.shape[0] == 0: return non_yellow_things else: - tobacco = self.pixel_model_ml.predict_bin(x_yellow) < 0.5 + tobacco = self.pixel_model_ml.predict(x_yellow) < 0.5 non_yellow_things[yellow_things] = ~tobacco # 杂质mask中将背景赋值为0,将杂质赋值为1