tobacoo-industry/test_files/utils_test.py
2022-06-18 19:13:39 +08:00

120 lines
6.3 KiB
Python
Executable File

import unittest
import numpy as np
from utils import determine_class, split_xy, split_x
class DatasetTest(unittest.TestCase):
    """Tests for ``determine_class``, ``split_xy`` and ``split_x`` from ``utils``.

    The split tests use a 600 x 1024 image whose pixel value equals its flat
    (row-major) index, so the expected content of any 8 x 8 block can be
    computed directly from the block's top-left coordinate.
    """

    # Shape of the synthetic image and the block edge length used by the tests.
    IMG_SHAPE = (600, 1024)
    BLK_SZ = 8

    # (index into the returned block list, (top row, left column) of that block).
    # Index 1032 is the block at row 64 / column 64: block row 8, block column 8,
    # with 1024 // 8 = 128 blocks per image row -> 8 * 128 + 8 = 1032. The tests
    # therefore expect the blocks to be listed in row-major order.
    _BLOCK_ORIGINS = (
        (0, (0, 0)),
        (-1, (592, 1016)),
        (1032, (64, 64)),
    )

    @classmethod
    def _make_image(cls):
        """Return the synthetic image: ``arange`` reshaped to ``IMG_SHAPE``."""
        height, width = cls.IMG_SHAPE
        return np.arange(height * width).reshape((height, width))

    @classmethod
    def _expected_block(cls, top, left):
        """Return the ``BLK_SZ`` x ``BLK_SZ`` values of the synthetic image at (top, left).

        Built independently of the image itself: the value at (row, col) is
        ``row * width + col``.
        """
        width = cls.IMG_SHAPE[1]
        rows = np.arange(top, top + cls.BLK_SZ).reshape(-1, 1)
        cols = np.arange(left, left + cls.BLK_SZ)
        return rows * width + cols

    def _assert_block_equal(self, actual, expected):
        """Fail unless ``actual`` and ``expected`` match element by element.

        Comparing ``np.sum(actual - expected)`` with zero is not sufficient:
        offsetting errors (+1 in one cell, -1 in another) sum to zero and would
        let a wrong block pass. ``assert_array_equal`` checks every element and
        reports the mismatching positions on failure.
        """
        np.testing.assert_array_equal(actual, expected)

    def test_determine_class(self):
        """``determine_class`` returns the expected label for three painted blocks.

        Each scenario paints three regions, in this order, on a zeroed 8 x 8
        block: a 2-pixel strip at ``[2:4, 5:6]``, a 2-pixel strip at
        ``[5:7, 1:2]`` and a 10-pixel band at ``[4:6, 1:6]``. The band is
        painted last and overwrites pixel (5, 1) of the second strip.
        """
        # ((label of first strip, label of second strip, label of band), expected class)
        scenarios = (
            ((2, 1, 3), 3),
            ((2, 1, 2), 2),
            ((1, 2, 1), 1),
        )
        for labels, expected_cls in scenarios:
            with self.subTest(labels=labels):
                first_strip, second_strip, band = labels
                pixel_block = np.zeros((8, 8), dtype=np.uint8)
                pixel_block[2:4, 5:6] = first_strip
                pixel_block[5:7, 1:2] = second_strip
                pixel_block[4:6, 1:6] = band
                cls = determine_class(pixel_block, sensitivity=8)
                self.assertEqual(cls, expected_cls)

    def test_split_xy(self):
        """``split_xy`` returns the right pixel block and label for three blocks."""
        x = self._make_image()
        y = np.zeros(self.IMG_SHAPE + (3,))
        color_dict = {
            (0, 0, 255): 1,
            (255, 255, 255): 0,
            (0, 255, 0): 2,
            (255, 255, 0): 3,
            (0, 255, 255): 4,
        }
        # Inverse lookup: class label -> colour triple used to paint ``y``.
        trans_color_dict = {v: k for k, v in color_dict.items()}

        # modify the first block
        y[2:4, 5:6] = trans_color_dict[2]
        y[5:7, 1:2] = trans_color_dict[1]
        y[4:6, 1:6] = trans_color_dict[3]
        # modify the last block
        y[-4:-2, -6:-5] = trans_color_dict[1]
        y[-7:-5, -2:-1] = trans_color_dict[2]
        y[-6:-4, -6:-1] = trans_color_dict[1]
        # modify the middle block
        y[64 + 2:64 + 4, 64 + 5:64 + 6] = trans_color_dict[2]
        y[64 + 5:64 + 7, 64 + 1:64 + 2] = trans_color_dict[1]
        y[64 + 4:64 + 6, 64 + 1:64 + 6] = trans_color_dict[2]

        x_list, y_list = split_xy(x, y, blk_sz=self.BLK_SZ, sensitivity=8)

        # Expected label per checked block, in the same order as _BLOCK_ORIGINS.
        expected_labels = (3, 1, 2)
        for (index, (top, left)), label in zip(self._BLOCK_ORIGINS, expected_labels):
            with self.subTest(block_index=index):
                self._assert_block_equal(x_list[index], self._expected_block(top, left))
                self.assertEqual(y_list[index], label)

    def test_split_x(self):
        """``split_x`` returns the right pixel values for the first, last and a middle block."""
        x = self._make_image()
        x_list = split_x(x, blk_sz=self.BLK_SZ)

        for index, (top, left) in self._BLOCK_ORIGINS:
            with self.subTest(block_index=index):
                self._assert_block_equal(x_list[index], self._expected_block(top, left))
# Run the test suite only when this file is executed as a script,
# not when it is imported by another module or a test collector.
if __name__ == "__main__":
    unittest.main()