@@ -722,7 +722,7 @@ def test_prediction_shape(
722722
723723 @parameterized .expand (LATENT_CNDM_TEST_CASES )
724724 @skipUnless (has_einops , "Requires einops" )
725- def test_sample_shape (
725+ def test_pred_shape (
726726 self ,
727727 ae_model_type ,
728728 autoencoder_params ,
@@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(
11651165
11661166 @parameterized .expand (LATENT_CNDM_TEST_CASES_DIFF_SHAPES )
11671167 @skipUnless (has_einops , "Requires einops" )
1168- def test_sample_shape_different_latents (
1168+ def test_shape_different_latents (
11691169 self ,
11701170 ae_model_type ,
11711171 autoencoder_params ,
@@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
12421242 )
12431243 self .assertEqual (prediction .shape , latent_shape )
12441244
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
    self,
    ae_model_type,
    autoencoder_params,
    dm_model_type,
    stage_2_params,
    controlnet_params,
    input_shape,
    latent_shape,
):
    # Sampling counterpart of the "different latents" shape check: the inferer
    # is given both an LDM latent shape and an autoencoder latent shape, and
    # the sampled output is expected to have the shape of the original input.

    # Networks are created in the order autoencoder -> diffusion model ->
    # ControlNet; an unrecognised autoencoder type leaves the first one as None.
    if ae_model_type == "AutoencoderKL":
        autoencoder = AutoencoderKL(**autoencoder_params)
    elif ae_model_type == "VQVAE":
        autoencoder = VQVAE(**autoencoder_params)
    elif ae_model_type == "SPADEAutoencoderKL":
        autoencoder = SPADEAutoencoderKL(**autoencoder_params)
    else:
        autoencoder = None

    unet_cls = SPADEDiffusionModelUNet if dm_model_type == "SPADEDiffusionModelUNet" else DiffusionModelUNet
    diffusion_unet = unet_cls(**stage_2_params)
    control_net = ControlNet(**controlnet_params)

    # Move everything to the available device, then switch to eval mode.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    networks = (autoencoder, diffusion_unet, control_net)
    for net in networks:
        net.to(device)
    for net in networks:
        net.eval()

    # Random inputs: noise in the LDM latent shape, ControlNet conditioning in
    # the input shape.
    noise = torch.randn(latent_shape).to(device)
    mask = torch.randn(input_shape).to(device)
    scheduler = DDPMScheduler(num_train_timesteps=10)

    # Infer the autoencoder latent shape from the entries of input_shape after
    # the first two (presumably the spatial dims - TODO confirm): each is
    # divided by 2**len(channels) for VQVAE and by 2**(len(channels) - 1)
    # for the other autoencoder types.
    downsample_exponent = len(autoencoder_params["channels"])
    if ae_model_type != "VQVAE":
        downsample_exponent -= 1
    downsample_factor = 2**downsample_exponent
    autoencoder_latent_shape = [size // downsample_factor for size in input_shape[2:]]

    inferer = ControlNetLatentDiffusionInferer(
        scheduler=scheduler,
        scale_factor=1.0,
        ldm_latent_shape=list(latent_shape[2:]),
        autoencoder_latent_shape=autoencoder_latent_shape,
    )
    scheduler.set_timesteps(num_inference_steps=10)

    # Arguments shared by both sampling paths.
    sample_kwargs = {
        "autoencoder_model": autoencoder,
        "diffusion_model": diffusion_unet,
        "controlnet": control_net,
        "cn_cond": mask,
        "input_noise": noise,
    }

    needs_seg = dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL"
    if not needs_seg:
        prediction = inferer.sample(save_intermediates=False, **sample_kwargs)
    else:
        # SPADE variants take an extra segmentation input whose channel count
        # comes from "label_nc" (diffusion params first, autoencoder params
        # as the fallback).
        if "label_nc" in stage_2_params:
            label_channels = stage_2_params["label_nc"]
        else:
            label_channels = autoencoder_params["label_nc"]
        seg_shape = list(input_shape)
        seg_shape[1] = label_channels
        input_seg = torch.randn(seg_shape).to(device)
        # With save_intermediates=True the call result is unpacked into two
        # values; only the first one is checked.
        prediction, _ = inferer.sample(seg=input_seg, save_intermediates=True, **sample_kwargs)

    self.assertEqual(prediction.shape, input_shape)
1322+
12451323 @skipUnless (has_einops , "Requires einops" )
12461324 def test_incompatible_spade_setup (self ):
12471325 stage_1 = SPADEAutoencoderKL (
0 commit comments