11import os
22os .environ ['TL_BACKEND' ] = 'tensorflow'
3+ # os.environ['TL_BACKEND'] = 'torch'
34from tensorlayerx .nn import Module
45import tensorlayerx as tlx
56from tensorlayerx .nn import Conv2d , BatchNorm2d , Elementwise , SubpixelConv2d , UpSampling2d , Flatten , Sequential
@@ -31,24 +32,24 @@ def __init__(self):
3132 super (ResidualBlock , self ).__init__ ()
3233 self .conv1 = Conv2d (
3334 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
34- data_format = 'channels_last ' , b_init = None
35+ data_format = 'channels_first ' , b_init = None
3536 )
36- self .bn1 = BatchNorm2d (num_features = 64 , act = tlx .ReLU , gamma_init = G_init , data_format = 'channels_last ' )
37+ self .bn1 = BatchNorm2d (num_features = 64 , act = tlx .ReLU , gamma_init = G_init , data_format = 'channels_first ' )
3738 self .conv2 = Conv2d (
3839 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
39- data_format = 'channels_last ' , b_init = None
40+ data_format = 'channels_first ' , b_init = None
4041 )
41- self .bn2 = BatchNorm2d (num_features = 64 , act = None , gamma_init = G_init , data_format = 'channels_last ' )
42+ self .bn2 = BatchNorm2d (num_features = 64 , act = None , gamma_init = G_init , data_format = 'channels_first ' )
4243 self .add = Add ()
4344
4445 def forward (self , x ):
45- temp = x
4646 z = self .conv1 (x )
4747 z = self .bn1 (z )
4848 z = self .conv2 (z )
4949 z = self .bn2 (z )
50- out = self .add (temp ,z )
51- return out
50+ x = self .add (x , z )
51+ return x
52+
5253
5354class SRGAN_g (Module ):
5455 """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
@@ -59,19 +60,19 @@ def __init__(self):
5960 super (SRGAN_g , self ).__init__ ()
6061 self .conv1 = Conv2d (
6162 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = tlx .ReLU , padding = 'SAME' , W_init = W_init ,
62- data_format = 'channels_last '
63+ data_format = 'channels_first '
6364 )
6465 self .residual_block = self .make_layer ()
6566 self .conv2 = Conv2d (
6667 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init ,
67- data_format = 'channels_last ' , b_init = None
68+ data_format = 'channels_first ' , b_init = None
6869 )
69- self .bn1 = BatchNorm2d (num_features = 64 , act = None , gamma_init = G_init , data_format = 'channels_last ' )
70- self .conv3 = Conv2d (out_channels = 256 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init , data_format = 'channels_last ' )
71- self .subpiexlconv1 = SubpixelConv2d (data_format = 'channels_last ' , scale = 2 , act = tlx .ReLU )
72- self .conv4 = Conv2d (out_channels = 256 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init , data_format = 'channels_last ' )
73- self .subpiexlconv2 = SubpixelConv2d (data_format = 'channels_last ' , scale = 2 , act = tlx .ReLU )
74- self .conv5 = Conv2d (3 , kernel_size = (1 , 1 ), stride = (1 , 1 ), act = tlx .Tanh , padding = 'SAME' , W_init = W_init , data_format = 'channels_last ' )
70+ self .bn1 = BatchNorm2d (num_features = 64 , act = None , gamma_init = G_init , data_format = 'channels_first ' )
71+ self .conv3 = Conv2d (out_channels = 256 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init , data_format = 'channels_first ' )
72+ self .subpiexlconv1 = SubpixelConv2d (data_format = 'channels_first ' , scale = 2 , act = tlx .ReLU )
73+ self .conv4 = Conv2d (out_channels = 256 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init , data_format = 'channels_first ' )
74+ self .subpiexlconv2 = SubpixelConv2d (data_format = 'channels_first ' , scale = 2 , act = tlx .ReLU )
75+ self .conv5 = Conv2d (3 , kernel_size = (1 , 1 ), stride = (1 , 1 ), act = tlx .Tanh , padding = 'SAME' , W_init = W_init , data_format = 'channels_first ' )
7576 self .add = Add ()
7677
7778 def make_layer (self ):
@@ -87,7 +88,6 @@ def forward(self, x):
8788 x = self .conv2 (x )
8889 x = self .bn1 (x )
8990 x = self .add (x , temp )
90- # x = x + temp
9191 x = self .conv3 (x )
9292 x = self .subpiexlconv1 (x )
9393 x = self .conv4 (x )
@@ -108,26 +108,26 @@ def __init__(self):
108108 super (SRGAN_g2 , self ).__init__ ()
109109 self .conv1 = Conv2d (
110110 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
111- data_format = 'channels_last '
111+ data_format = 'channels_first '
112112 )
113113 self .residual_block = self .make_layer ()
114114 self .conv2 = Conv2d (
115115 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init ,
116- data_format = 'channels_last ' , b_init = None
116+ data_format = 'channels_first ' , b_init = None
117117 )
118- self .bn1 = BatchNorm2d (act = None , gamma_init = G_init , data_format = 'channels_last ' )
119- self .upsample1 = UpSampling2d (data_format = 'channels_last ' , scale = (2 , 2 ), method = 'bilinear' )
118+ self .bn1 = BatchNorm2d (act = None , gamma_init = G_init , data_format = 'channels_first ' )
119+ self .upsample1 = UpSampling2d (data_format = 'channels_first ' , scale = (2 , 2 ), method = 'bilinear' )
120120 self .conv3 = Conv2d (
121121 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init ,
122- data_format = 'channels_last ' , b_init = None
122+ data_format = 'channels_first ' , b_init = None
123123 )
124- self .bn2 = BatchNorm2d (act = tlx .ReLU , gamma_init = G_init , data_format = 'channels_last ' )
125- self .upsample2 = UpSampling2d (data_format = 'channels_last ' , scale = (4 , 4 ), method = 'bilinear' )
124+ self .bn2 = BatchNorm2d (act = tlx .ReLU , gamma_init = G_init , data_format = 'channels_first ' )
125+ self .upsample2 = UpSampling2d (data_format = 'channels_first ' , scale = (4 , 4 ), method = 'bilinear' )
126126 self .conv4 = Conv2d (
127127 out_channels = 32 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = 'SAME' , W_init = W_init ,
128- data_format = 'channels_last ' , b_init = None
128+ data_format = 'channels_first ' , b_init = None
129129 )
130- self .bn3 = BatchNorm2d (act = tlx .ReLU , gamma_init = G_init , data_format = 'channels_last ' )
130+ self .bn3 = BatchNorm2d (act = tlx .ReLU , gamma_init = G_init , data_format = 'channels_first ' )
131131 self .conv5 = Conv2d (
132132 out_channels = 3 , kernel_size = (1 , 1 ), stride = (1 , 1 ), act = tlx .Tanh , padding = 'SAME' , W_init = W_init
133133 )
@@ -164,43 +164,43 @@ def __init__(self, ):
164164 super (SRGAN_d2 , self ).__init__ ()
165165 self .conv1 = Conv2d (
166166 out_channels = 64 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
167- W_init = W_init , data_format = 'channels_last '
167+ W_init = W_init , data_format = 'channels_first '
168168 )
169169 self .conv2 = Conv2d (
170170 out_channels = 64 , kernel_size = (3 , 3 ), stride = (2 , 2 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
171- W_init = W_init , data_format = 'channels_last ' , b_init = None
171+ W_init = W_init , data_format = 'channels_first ' , b_init = None
172172 )
173- self .bn1 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
173+ self .bn1 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
174174 self .conv3 = Conv2d (
175175 out_channels = 128 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
176- W_init = W_init , data_format = 'channels_last ' , b_init = None
176+ W_init = W_init , data_format = 'channels_first ' , b_init = None
177177 )
178- self .bn2 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
178+ self .bn2 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
179179 self .conv4 = Conv2d (
180180 out_channels = 128 , kernel_size = (3 , 3 ), stride = (2 , 2 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
181- W_init = W_init , data_format = 'channels_last ' , b_init = None
181+ W_init = W_init , data_format = 'channels_first ' , b_init = None
182182 )
183- self .bn3 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
183+ self .bn3 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
184184 self .conv5 = Conv2d (
185185 out_channels = 256 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
186- W_init = W_init , data_format = 'channels_last ' , b_init = None
186+ W_init = W_init , data_format = 'channels_first ' , b_init = None
187187 )
188- self .bn4 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
188+ self .bn4 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
189189 self .conv6 = Conv2d (
190190 out_channels = 256 , kernel_size = (3 , 3 ), stride = (2 , 2 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
191- W_init = W_init , data_format = 'channels_last ' , b_init = None
191+ W_init = W_init , data_format = 'channels_first ' , b_init = None
192192 )
193- self .bn5 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
193+ self .bn5 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
194194 self .conv7 = Conv2d (
195195 out_channels = 512 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
196- W_init = W_init , data_format = 'channels_last ' , b_init = None
196+ W_init = W_init , data_format = 'channels_first ' , b_init = None
197197 )
198- self .bn6 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
198+ self .bn6 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
199199 self .conv8 = Conv2d (
200200 out_channels = 512 , kernel_size = (3 , 3 ), stride = (2 , 2 ), act = tlx .LeakyReLU (negative_slope = 0.2 ), padding = 'SAME' ,
201- W_init = W_init , data_format = 'channels_last ' , b_init = None
201+ W_init = W_init , data_format = 'channels_first ' , b_init = None
202202 )
203- self .bn7 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_last ' )
203+ self .bn7 = BatchNorm2d (gamma_init = G_init , data_format = 'channels_first ' )
204204 self .flat = Flatten ()
205205 self .dense1 = Linear (out_features = 1024 , act = tlx .LeakyReLU (negative_slope = 0.2 ))
206206 self .dense2 = Linear (out_features = 1 )
@@ -235,58 +235,58 @@ def __init__(self, dim=64):
235235 super (SRGAN_d , self ).__init__ ()
236236 self .conv1 = Conv2d (
237237 out_channels = dim , kernel_size = (4 , 4 ), stride = (2 , 2 ), act = tlx .LeakyReLU , padding = 'SAME' , W_init = W_init ,
238- data_format = 'channels_last '
238+ data_format = 'channels_first '
239239 )
240240 self .conv2 = Conv2d (
241241 out_channels = dim * 2 , kernel_size = (4 , 4 ), stride = (2 , 2 ), act = None , padding = 'SAME' , W_init = W_init ,
242- data_format = 'channels_last ' , b_init = None
242+ data_format = 'channels_first ' , b_init = None
243243 )
244- self .bn1 = BatchNorm2d (num_features = dim * 2 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
244+ self .bn1 = BatchNorm2d (num_features = dim * 2 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
245245 self .conv3 = Conv2d (
246246 out_channels = dim * 4 , kernel_size = (4 , 4 ), stride = (2 , 2 ), act = None , padding = 'SAME' , W_init = W_init ,
247- data_format = 'channels_last ' , b_init = None
247+ data_format = 'channels_first ' , b_init = None
248248 )
249- self .bn2 = BatchNorm2d (num_features = dim * 4 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
249+ self .bn2 = BatchNorm2d (num_features = dim * 4 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
250250 self .conv4 = Conv2d (
251251 out_channels = dim * 8 , kernel_size = (4 , 4 ), stride = (2 , 2 ), act = None , padding = 'SAME' , W_init = W_init ,
252- data_format = 'channels_last ' , b_init = None
252+ data_format = 'channels_first ' , b_init = None
253253 )
254- self .bn3 = BatchNorm2d (num_features = dim * 8 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
254+ self .bn3 = BatchNorm2d (num_features = dim * 8 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
255255 self .conv5 = Conv2d (
256256 out_channels = dim * 16 , kernel_size = (4 , 4 ), stride = (2 , 2 ), act = None , padding = 'SAME' , W_init = W_init ,
257- data_format = 'channels_last ' , b_init = None
257+ data_format = 'channels_first ' , b_init = None
258258 )
259- self .bn4 = BatchNorm2d (num_features = dim * 16 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
259+ self .bn4 = BatchNorm2d (num_features = dim * 16 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
260260 self .conv6 = Conv2d (
261261 out_channels = dim * 32 , kernel_size = (4 , 4 ), stride = (2 , 2 ), act = None , padding = 'SAME' , W_init = W_init ,
262- data_format = 'channels_last ' , b_init = None
262+ data_format = 'channels_first ' , b_init = None
263263 )
264- self .bn5 = BatchNorm2d (num_features = dim * 32 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
264+ self .bn5 = BatchNorm2d (num_features = dim * 32 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
265265 self .conv7 = Conv2d (
266266 out_channels = dim * 16 , kernel_size = (1 , 1 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
267- data_format = 'channels_last ' , b_init = None
267+ data_format = 'channels_first ' , b_init = None
268268 )
269- self .bn6 = BatchNorm2d (num_features = dim * 16 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
269+ self .bn6 = BatchNorm2d (num_features = dim * 16 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
270270 self .conv8 = Conv2d (
271271 out_channels = dim * 8 , kernel_size = (1 , 1 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
272- data_format = 'channels_last ' , b_init = None
272+ data_format = 'channels_first ' , b_init = None
273273 )
274- self .bn7 = BatchNorm2d (num_features = dim * 8 , act = None , gamma_init = G_init , data_format = 'channels_last ' )
274+ self .bn7 = BatchNorm2d (num_features = dim * 8 , act = None , gamma_init = G_init , data_format = 'channels_first ' )
275275 self .conv9 = Conv2d (
276276 out_channels = dim * 2 , kernel_size = (1 , 1 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
277- data_format = 'channels_last ' , b_init = None
277+ data_format = 'channels_first ' , b_init = None
278278 )
279- self .bn8 = BatchNorm2d (num_features = dim * 2 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
279+ self .bn8 = BatchNorm2d (num_features = dim * 2 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
280280 self .conv10 = Conv2d (
281281 out_channels = dim * 2 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
282- data_format = 'channels_last ' , b_init = None
282+ data_format = 'channels_first ' , b_init = None
283283 )
284- self .bn9 = BatchNorm2d (num_features = dim * 2 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_last ' )
284+ self .bn9 = BatchNorm2d (num_features = dim * 2 , act = tlx .LeakyReLU , gamma_init = G_init , data_format = 'channels_first ' )
285285 self .conv11 = Conv2d (
286286 out_channels = dim * 8 , kernel_size = (3 , 3 ), stride = (1 , 1 ), act = None , padding = 'SAME' , W_init = W_init ,
287- data_format = 'channels_last ' , b_init = None
287+ data_format = 'channels_first ' , b_init = None
288288 )
289- self .bn10 = BatchNorm2d (num_features = dim * 8 , gamma_init = G_init , data_format = 'channels_last ' )
289+ self .bn10 = BatchNorm2d (num_features = dim * 8 , gamma_init = G_init , data_format = 'channels_first ' )
290290 self .add = Elementwise (combine_fn = tlx .add , act = tlx .LeakyReLU )
291291 self .flat = Flatten ()
292292 self .dense = Linear (out_features = 1 , W_init = W_init )
@@ -390,35 +390,41 @@ def forward(self, x):
390390
391391
392392net = SRGAN_g ()
393- net .init_build (tlx .nn .Input (shape = (1 , 96 , 96 , 3 )))
393+ net .set_eval ()
394+ net .init_build (tlx .nn .Input (shape = (1 , 3 , 96 , 96 )))
394395net .load_weights ('model/g.npz' , format = 'npz_dict' )
395- input = tlx .nn .Input (shape = (1 , 96 , 96 , 3 ))
396-
397- onnx_model = export (net , input_spec = input , path = 'srgan.onnx' )
398-
396+ input = tlx .nn .Input (shape = (1 , 3 , 96 , 96 ))
397+ onnx_model = export (net , input_spec = input , path = 'srgan.onnx' , dynamic_axes = [2 ,3 ])
399398sess = onnxruntime .InferenceSession ('srgan.onnx' )
400399
401400input_name = sess .get_inputs ()[0 ].name
402401output_name = sess .get_outputs ()[0 ].name
403402
404- # TODO DYNAMIC INPUT SHAPE
405403valid_hr_img = tlx .vision .load_image ('data/0882.png' )
406404valid_lr_img = np .asarray (valid_hr_img )
407405hr_size = [valid_hr_img .shape [0 ], valid_hr_img .shape [1 ]]
408406valid_lr_img = cv2 .resize (valid_lr_img , dsize = (hr_size [1 ]// 4 , hr_size [0 ]// 4 ))
409407lr_size = [valid_hr_img .shape [0 ]// 4 , valid_hr_img .shape [1 ]// 4 ]
410408input = (valid_lr_img / 127.5 ) - 1
411409input = np .asarray (input , dtype = np .float32 )
410+ input = np .transpose (input , axes = [2 , 0 , 1 ])
412411input = input [np .newaxis , :, :, :]
413412output = sess .run ([output_name ], {input_name : input })
414- output = np .asarray ((output + 1 ) * 127.5 , dtype = np .uint8 )
415-
413+ output = output [0 ]
414+
415+ output1 = net (tlx .convert_to_tensor (input ))
416+ output1 = tlx .convert_to_numpy (output1 )
417+ output = np .asarray ((output [0 ] + 1 ) * 127.5 , dtype = np .uint8 )
418+ output1 = np .asarray ((output1 [0 ] + 1 ) * 127.5 , dtype = np .uint8 )
419+ output = np .transpose (output , axes = [1 , 2 , 0 ])
420+ output1 = np .transpose (output1 , axes = [1 , 2 , 0 ])
416421plt .figure ()
417- plt .subplot (1 ,3 ,1 )
418- plt .plot (valid_hr_img )
422+ plt .subplot (1 ,2 ,1 )
423+ plt .title ("ONNX" )
424+ plt .imshow (output )
419425
420- plt .subplot (1 ,3 ,2 )
421- plt .plot (valid_lr_img )
426+ plt .subplot (1 ,2 ,2 )
427+ plt .title (tlx .BACKEND )
428+ plt .imshow (output1 )
422429
423- plt .subplot (output )
424430plt .show ()
0 commit comments