Skip to content

Commit cddc00f

Browse files
committed
add dynamic_axes
1 parent f4da04b commit cddc00f

3 files changed

Lines changed: 108 additions & 86 deletions

File tree

tests/export_application/test_export_srgan.py

Lines changed: 80 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
os.environ['TL_BACKEND'] = 'tensorflow'
3+
# os.environ['TL_BACKEND'] = 'torch'
34
from tensorlayerx.nn import Module
45
import tensorlayerx as tlx
56
from tensorlayerx.nn import Conv2d, BatchNorm2d, Elementwise, SubpixelConv2d, UpSampling2d, Flatten, Sequential
@@ -31,24 +32,24 @@ def __init__(self):
3132
super(ResidualBlock, self).__init__()
3233
self.conv1 = Conv2d(
3334
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
34-
data_format='channels_last', b_init=None
35+
data_format='channels_first', b_init=None
3536
)
36-
self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init, data_format='channels_last')
37+
self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
3738
self.conv2 = Conv2d(
3839
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
39-
data_format='channels_last', b_init=None
40+
data_format='channels_first', b_init=None
4041
)
41-
self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_last')
42+
self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
4243
self.add = Add()
4344

4445
def forward(self, x):
45-
temp = x
4646
z = self.conv1(x)
4747
z = self.bn1(z)
4848
z = self.conv2(z)
4949
z = self.bn2(z)
50-
out = self.add(temp,z)
51-
return out
50+
x = self.add(x, z)
51+
return x
52+
5253

5354
class SRGAN_g(Module):
5455
""" Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
@@ -59,19 +60,19 @@ def __init__(self):
5960
super(SRGAN_g, self).__init__()
6061
self.conv1 = Conv2d(
6162
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME', W_init=W_init,
62-
data_format='channels_last'
63+
data_format='channels_first'
6364
)
6465
self.residual_block = self.make_layer()
6566
self.conv2 = Conv2d(
6667
out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
67-
data_format='channels_last', b_init=None
68+
data_format='channels_first', b_init=None
6869
)
69-
self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_last')
70-
self.conv3 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_last')
71-
self.subpiexlconv1 = SubpixelConv2d(data_format='channels_last', scale=2, act=tlx.ReLU)
72-
self.conv4 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_last')
73-
self.subpiexlconv2 = SubpixelConv2d(data_format='channels_last', scale=2, act=tlx.ReLU)
74-
self.conv5 = Conv2d(3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init, data_format='channels_last')
70+
self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
71+
self.conv3 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
72+
self.subpiexlconv1 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
73+
self.conv4 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
74+
self.subpiexlconv2 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
75+
self.conv5 = Conv2d(3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init, data_format='channels_first')
7576
self.add = Add()
7677

7778
def make_layer(self):
@@ -87,7 +88,6 @@ def forward(self, x):
8788
x = self.conv2(x)
8889
x = self.bn1(x)
8990
x = self.add(x, temp)
90-
# x = x + temp
9191
x = self.conv3(x)
9292
x = self.subpiexlconv1(x)
9393
x = self.conv4(x)
@@ -108,26 +108,26 @@ def __init__(self):
108108
super(SRGAN_g2, self).__init__()
109109
self.conv1 = Conv2d(
110110
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
111-
data_format='channels_last'
111+
data_format='channels_first'
112112
)
113113
self.residual_block = self.make_layer()
114114
self.conv2 = Conv2d(
115115
out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
116-
data_format='channels_last', b_init=None
116+
data_format='channels_first', b_init=None
117117
)
118-
self.bn1 = BatchNorm2d(act=None, gamma_init=G_init, data_format='channels_last')
119-
self.upsample1 = UpSampling2d(data_format='channels_last', scale=(2, 2), method='bilinear')
118+
self.bn1 = BatchNorm2d(act=None, gamma_init=G_init, data_format='channels_first')
119+
self.upsample1 = UpSampling2d(data_format='channels_first', scale=(2, 2), method='bilinear')
120120
self.conv3 = Conv2d(
121121
out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
122-
data_format='channels_last', b_init=None
122+
data_format='channels_first', b_init=None
123123
)
124-
self.bn2 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_last')
125-
self.upsample2 = UpSampling2d(data_format='channels_last', scale=(4, 4), method='bilinear')
124+
self.bn2 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
125+
self.upsample2 = UpSampling2d(data_format='channels_first', scale=(4, 4), method='bilinear')
126126
self.conv4 = Conv2d(
127127
out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
128-
data_format='channels_last', b_init=None
128+
data_format='channels_first', b_init=None
129129
)
130-
self.bn3 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_last')
130+
self.bn3 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
131131
self.conv5 = Conv2d(
132132
out_channels=3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init
133133
)
@@ -164,43 +164,43 @@ def __init__(self, ):
164164
super(SRGAN_d2, self).__init__()
165165
self.conv1 = Conv2d(
166166
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
167-
W_init=W_init, data_format='channels_last'
167+
W_init=W_init, data_format='channels_first'
168168
)
169169
self.conv2 = Conv2d(
170170
out_channels=64, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
171-
W_init=W_init, data_format='channels_last', b_init=None
171+
W_init=W_init, data_format='channels_first', b_init=None
172172
)
173-
self.bn1 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
173+
self.bn1 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
174174
self.conv3 = Conv2d(
175175
out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
176-
W_init=W_init, data_format='channels_last', b_init=None
176+
W_init=W_init, data_format='channels_first', b_init=None
177177
)
178-
self.bn2 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
178+
self.bn2 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
179179
self.conv4 = Conv2d(
180180
out_channels=128, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
181-
W_init=W_init, data_format='channels_last', b_init=None
181+
W_init=W_init, data_format='channels_first', b_init=None
182182
)
183-
self.bn3 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
183+
self.bn3 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
184184
self.conv5 = Conv2d(
185185
out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
186-
W_init=W_init, data_format='channels_last', b_init=None
186+
W_init=W_init, data_format='channels_first', b_init=None
187187
)
188-
self.bn4 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
188+
self.bn4 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
189189
self.conv6 = Conv2d(
190190
out_channels=256, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
191-
W_init=W_init, data_format='channels_last', b_init=None
191+
W_init=W_init, data_format='channels_first', b_init=None
192192
)
193-
self.bn5 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
193+
self.bn5 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
194194
self.conv7 = Conv2d(
195195
out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
196-
W_init=W_init, data_format='channels_last', b_init=None
196+
W_init=W_init, data_format='channels_first', b_init=None
197197
)
198-
self.bn6 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
198+
self.bn6 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
199199
self.conv8 = Conv2d(
200200
out_channels=512, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
201-
W_init=W_init, data_format='channels_last', b_init=None
201+
W_init=W_init, data_format='channels_first', b_init=None
202202
)
203-
self.bn7 = BatchNorm2d(gamma_init=G_init, data_format='channels_last')
203+
self.bn7 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
204204
self.flat = Flatten()
205205
self.dense1 = Linear(out_features=1024, act=tlx.LeakyReLU(negative_slope=0.2))
206206
self.dense2 = Linear(out_features=1)
@@ -235,58 +235,58 @@ def __init__(self, dim=64):
235235
super(SRGAN_d, self).__init__()
236236
self.conv1 = Conv2d(
237237
out_channels=dim, kernel_size=(4, 4), stride=(2, 2), act=tlx.LeakyReLU, padding='SAME', W_init=W_init,
238-
data_format='channels_last'
238+
data_format='channels_first'
239239
)
240240
self.conv2 = Conv2d(
241241
out_channels=dim * 2, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
242-
data_format='channels_last', b_init=None
242+
data_format='channels_first', b_init=None
243243
)
244-
self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
244+
self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
245245
self.conv3 = Conv2d(
246246
out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
247-
data_format='channels_last', b_init=None
247+
data_format='channels_first', b_init=None
248248
)
249-
self.bn2 = BatchNorm2d(num_features=dim * 4, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
249+
self.bn2 = BatchNorm2d(num_features=dim * 4, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
250250
self.conv4 = Conv2d(
251251
out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
252-
data_format='channels_last', b_init=None
252+
data_format='channels_first', b_init=None
253253
)
254-
self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
254+
self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
255255
self.conv5 = Conv2d(
256256
out_channels=dim * 16, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
257-
data_format='channels_last', b_init=None
257+
data_format='channels_first', b_init=None
258258
)
259-
self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
259+
self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
260260
self.conv6 = Conv2d(
261261
out_channels=dim * 32, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
262-
data_format='channels_last', b_init=None
262+
data_format='channels_first', b_init=None
263263
)
264-
self.bn5 = BatchNorm2d(num_features=dim * 32, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
264+
self.bn5 = BatchNorm2d(num_features=dim * 32, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
265265
self.conv7 = Conv2d(
266266
out_channels=dim * 16, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
267-
data_format='channels_last', b_init=None
267+
data_format='channels_first', b_init=None
268268
)
269-
self.bn6 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
269+
self.bn6 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
270270
self.conv8 = Conv2d(
271271
out_channels=dim * 8, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
272-
data_format='channels_last', b_init=None
272+
data_format='channels_first', b_init=None
273273
)
274-
self.bn7 = BatchNorm2d(num_features=dim * 8, act=None, gamma_init=G_init, data_format='channels_last')
274+
self.bn7 = BatchNorm2d(num_features=dim * 8, act=None, gamma_init=G_init, data_format='channels_first')
275275
self.conv9 = Conv2d(
276276
out_channels=dim * 2, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
277-
data_format='channels_last', b_init=None
277+
data_format='channels_first', b_init=None
278278
)
279-
self.bn8 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
279+
self.bn8 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
280280
self.conv10 = Conv2d(
281281
out_channels=dim * 2, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
282-
data_format='channels_last', b_init=None
282+
data_format='channels_first', b_init=None
283283
)
284-
self.bn9 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_last')
284+
self.bn9 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
285285
self.conv11 = Conv2d(
286286
out_channels=dim * 8, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
287-
data_format='channels_last', b_init=None
287+
data_format='channels_first', b_init=None
288288
)
289-
self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init, data_format='channels_last')
289+
self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init, data_format='channels_first')
290290
self.add = Elementwise(combine_fn=tlx.add, act=tlx.LeakyReLU)
291291
self.flat = Flatten()
292292
self.dense = Linear(out_features=1, W_init=W_init)
@@ -390,35 +390,41 @@ def forward(self, x):
390390

391391

392392
net = SRGAN_g()
393-
net.init_build(tlx.nn.Input(shape=(1, 96, 96, 3)))
393+
net.set_eval()
394+
net.init_build(tlx.nn.Input(shape=(1, 3, 96, 96)))
394395
net.load_weights('model/g.npz', format='npz_dict')
395-
input = tlx.nn.Input(shape=(1, 96, 96, 3))
396-
397-
onnx_model = export(net, input_spec=input, path='srgan.onnx')
398-
396+
input = tlx.nn.Input(shape=(1, 3, 96, 96))
397+
onnx_model = export(net, input_spec=input, path='srgan.onnx', dynamic_axes=[2,3])
399398
sess = onnxruntime.InferenceSession('srgan.onnx')
400399

401400
input_name = sess.get_inputs()[0].name
402401
output_name = sess.get_outputs()[0].name
403402

404-
# TODO DYNAMIC INPUT SHAPE
405403
valid_hr_img = tlx.vision.load_image('data/0882.png')
406404
valid_lr_img = np.asarray(valid_hr_img)
407405
hr_size = [valid_hr_img.shape[0], valid_hr_img.shape[1]]
408406
valid_lr_img = cv2.resize(valid_lr_img, dsize=(hr_size[1]//4, hr_size[0]//4))
409407
lr_size = [valid_hr_img.shape[0]//4, valid_hr_img.shape[1]//4]
410408
input = (valid_lr_img / 127.5) - 1
411409
input = np.asarray(input, dtype=np.float32)
410+
input = np.transpose(input, axes=[2, 0, 1])
412411
input = input[np.newaxis, :, :, :]
413412
output = sess.run([output_name], {input_name : input})
414-
output = np.asarray((output + 1) * 127.5, dtype=np.uint8)
415-
413+
output = output[0]
414+
415+
output1 = net(tlx.convert_to_tensor(input))
416+
output1 = tlx.convert_to_numpy(output1)
417+
output = np.asarray((output[0] + 1) * 127.5, dtype=np.uint8)
418+
output1 = np.asarray((output1[0] + 1) * 127.5, dtype=np.uint8)
419+
output = np.transpose(output, axes=[1, 2, 0])
420+
output1 = np.transpose(output1, axes=[1, 2, 0])
416421
plt.figure()
417-
plt.subplot(1,3,1)
418-
plt.plot(valid_hr_img)
422+
plt.subplot(1,2,1)
423+
plt.title("ONNX")
424+
plt.imshow(output)
419425

420-
plt.subplot(1,3,2)
421-
plt.plot(valid_lr_img)
426+
plt.subplot(1,2,2)
427+
plt.title(tlx.BACKEND)
428+
plt.imshow(output1)
422429

423-
plt.subplot(output)
424430
plt.show()

tests/test_batchnorm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# -*- coding: utf-8 -*-
33

44
import os
5-
os.environ["TL_BACKEND"] = 'tensorflow'
5+
# os.environ["TL_BACKEND"] = 'tensorflow'
6+
os.environ['TL_BACKEND'] = 'torch'
67
import tensorlayerx as tlx
78
from tensorlayerx.nn import Module
89
from tensorlayerx.nn import Conv2d, BatchNorm2d
@@ -40,4 +41,4 @@ def forward(self, x):
4041
input_data = np.array(input_data, dtype=np.float32)
4142

4243
result = sess.run([output_name], {input_name: input_data})
43-
print("onnx out", result)
44+
print("onnx out", result)

0 commit comments

Comments
 (0)