Skip to content
11 changes: 5 additions & 6 deletions monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,12 @@ def __init__(
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
)
elif self.proj_type == "perceptron":
# for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
# for 3d: "b c (h 16) (w 16) (d 16) -> b (h w d) (16 16 16 c)"
dim_names = ("h", "w", "d")[:spatial_dims]
from_chars = "b c " + " ".join(f"({name} {psize})" for name, psize in zip(dim_names, patch_size))
to_chars = f"b ({' '.join(dim_names)}) ({' '.join(str(p) for p in patch_size)} c)"
self.patch_embeddings = nn.Sequential(
Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
Rearrange(f"{from_chars} -> {to_chars}"), nn.Linear(self.patch_dim, hidden_size)
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
Expand Down
Loading