2121from monai .networks .blocks .feature_pyramid_network import FeaturePyramidNetwork
2222from monai .networks .nets .resnet import resnet50
2323from monai .utils import optional_import
24- from tests .test_utils import test_export_save
24+ from tests .test_utils import test_script_save
2525
2626_ , has_torchvision = optional_import ("torchvision" )
2727
@@ -55,13 +55,13 @@ def test_fpn_block(self, input_param, input_shape, expected_shape):
5555 self .assertEqual (result ["feat1" ].shape , expected_shape [1 ])
5656
5757 @parameterized .expand (TEST_CASES )
58- def test_export (self , input_param , input_shape , expected_shape ):
59- # test whether support torch.export
58+ def test_script (self , input_param , input_shape , expected_shape ):
59+ # test whether support torchscript
6060 net = FeaturePyramidNetwork (** input_param )
6161 data = OrderedDict ()
6262 data ["feat0" ] = torch .rand (input_shape [0 ])
6363 data ["feat1" ] = torch .rand (input_shape [1 ])
64- test_export_save (net , data )
64+ test_script_save (net , data )
6565
6666
6767@unittest .skipUnless (has_torchvision , "Requires torchvision" )
@@ -75,11 +75,11 @@ def test_fpn(self, input_param, input_shape, expected_shape):
7575 self .assertEqual (result ["pool" ].shape , expected_shape [1 ])
7676
7777 @parameterized .expand (TEST_CASES2 )
78- def test_export (self , input_param , input_shape , expected_shape ):
79- # test whether support torch.export
78+ def test_script (self , input_param , input_shape , expected_shape ):
79+ # test whether support torchscript
8080 net = _resnet_fpn_extractor (backbone = resnet50 (), spatial_dims = input_param ["spatial_dims" ], returned_layers = [1 ])
8181 data = torch .rand (input_shape )
82- test_export_save (net , data )
82+ test_script_save (net , data )
8383
8484
8585if __name__ == "__main__" :
0 commit comments