2121
2222class TestSSIMMetric (unittest .TestCase ):
2323
24- def test2d_gaussian (self ):
24+ def test_2d_gaussian (self ):
2525 set_determinism (0 )
2626 preds = torch .abs (torch .randn (2 , 3 , 16 , 16 ))
2727 target = torch .abs (torch .randn (2 , 3 , 16 , 16 ))
@@ -32,9 +32,9 @@ def test2d_gaussian(self):
3232 metric (preds , target )
3333 result = metric .aggregate ()
3434 expected_value = 0.045415
35- self .assertTrue (expected_value - result .item () < 0.000001 )
35+ self .assertTrue (abs ( expected_value - result .item () ) < 0.000001 )
3636
37- def test2d_uniform (self ):
37+ def test_2d_uniform (self ):
3838 set_determinism (0 )
3939 preds = torch .abs (torch .randn (2 , 3 , 16 , 16 ))
4040 target = torch .abs (torch .randn (2 , 3 , 16 , 16 ))
@@ -45,9 +45,9 @@ def test2d_uniform(self):
4545 metric (preds , target )
4646 result = metric .aggregate ()
4747 expected_value = 0.050103
48- self .assertTrue (expected_value - result .item () < 0.000001 )
48+ self .assertTrue (abs ( expected_value - result .item () ) < 0.000001 )
4949
50- def test3d_gaussian (self ):
50+ def test_3d_gaussian (self ):
5151 set_determinism (0 )
5252 preds = torch .abs (torch .randn (2 , 3 , 16 , 16 , 16 ))
5353 target = torch .abs (torch .randn (2 , 3 , 16 , 16 , 16 ))
0 commit comments