|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""ROI-aware preprocessing: denoise outside driving corridor to save bits for encoder.""" |
| 3 | +import argparse |
| 4 | +import sys |
| 5 | +from pathlib import Path |
| 6 | + |
| 7 | +import av |
| 8 | +import torch |
| 9 | +import torch.nn.functional as F |
| 10 | +from PIL import Image, ImageDraw, ImageFilter |
| 11 | + |
| 12 | +ROOT = Path(__file__).resolve().parents[2] |
| 13 | +if str(ROOT) not in sys.path: |
| 14 | + sys.path.insert(0, str(ROOT)) |
| 15 | + |
| 16 | +from frame_utils import yuv420_to_rgb |
| 17 | + |
| 18 | + |
def collapse_chroma(x: torch.Tensor, mode: str) -> torch.Tensor:
    """Box-blur the chroma (U/V) planes in place to cut chroma bitrate.

    Args:
        x: (N, 3, H, W) float tensor in Y/U/V channel order; modified in place.
        mode: one of "normal" (no-op), "soft", "medium", "strong".

    Returns:
        The same tensor ``x`` with channels 1-2 smoothed.

    Raises:
        ValueError: if ``mode`` is not a recognized option.
    """
    if mode == "normal":
        return x
    radii = {"soft": 1, "medium": 2, "strong": 4}
    if mode not in radii:
        raise ValueError(f"unknown chroma mode {mode!r}; expected one of {sorted(radii) + ['normal']}")
    k = radii[mode]
    uv = x[:, 1:3]
    # count_include_pad=False: with the default, the zero padding is averaged
    # into border pixels, dragging U/V toward 0 (a green cast at frame edges).
    uv = F.avg_pool2d(uv, kernel_size=k * 2 + 1, stride=1, padding=k, count_include_pad=False)
    x[:, 1:3] = uv
    return x
| 27 | + |
| 28 | + |
def apply_luma_denoise(x: torch.Tensor, strength: float) -> torch.Tensor:
    """Gaussian-blur the luma (Y) plane in place, blended by ``strength``.

    Args:
        x: (N, 3, H, W) float tensor in Y/U/V channel order; modified in place.
        strength: blur aggressiveness; <= 0 is a no-op. Controls kernel size
            (3 up to 2.0, else 5), sigma, and the blur/original blend weight.

    Returns:
        The same tensor ``x`` with channel 0 partially blurred.
    """
    if strength <= 0:
        return x
    kernel_size = 3 if strength <= 2.0 else 5
    sigma = max(0.1, strength * 0.35)
    coords = torch.arange(kernel_size, device=x.device) - kernel_size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma * sigma))
    kernel_1d = (g / g.sum()).float()
    kernel_2d = torch.outer(kernel_1d, kernel_1d).view(1, 1, kernel_size, kernel_size)
    y = x[:, 0:1]
    # Replicate-pad before the conv: conv2d's implicit zero padding would
    # darken the frame borders by averaging in black.
    pad = kernel_size // 2
    y_padded = F.pad(y, (pad, pad, pad, pad), mode="replicate")
    y_blur = F.conv2d(y_padded, kernel_2d)
    # Blend capped at 0.9 so some original detail always survives.
    blend = min(0.9, strength / 3.0)
    x[:, 0:1] = (1 - blend) * y + blend * y_blur
    return x
| 43 | + |
| 44 | + |
def rgb_to_yuv(rgb: torch.Tensor) -> torch.Tensor:
    """Convert an (N, 3, H, W) RGB tensor (0-255 range) to full-range Y/U/V.

    Uses BT.601 luma weights; chroma is centered at 128.
    """
    red = rgb[:, 0:1]
    green = rgb[:, 1:2]
    blue = rgb[:, 2:3]
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    chroma_u = (blue - luma) / 1.772 + 128.0
    chroma_v = (red - luma) / 1.402 + 128.0
    return torch.cat([luma, chroma_u, chroma_v], dim=1)
| 51 | + |
| 52 | + |
def yuv_to_rgb(yuv: torch.Tensor) -> torch.Tensor:
    """Convert an (N, 3, H, W) full-range Y/U/V tensor back to RGB.

    Inverse of ``rgb_to_yuv``; chroma is assumed centered at 128.
    """
    luma = yuv[:, 0:1]
    cu = yuv[:, 1:2] - 128.0
    cv = yuv[:, 2:3] - 128.0
    channels = [
        luma + 1.402 * cv,                     # R
        luma - 0.344136 * cu - 0.714136 * cv,  # G
        luma + 1.772 * cu,                     # B
    ]
    return torch.cat(channels, dim=1)
| 60 | + |
| 61 | + |
def segment_polygon(frame_idx: int, width: int, height: int) -> list[tuple[float, float]]:
    """Return the driving-corridor polygon for a frame, in pixel coordinates.

    Polygons are stored as normalized (x, y) vertices per frame-index range
    and scaled to the frame size; indices outside every range get a default
    corridor.
    """
    corridor_by_range = (
        (range(0, 300), [(0.14, 0.52), (0.82, 0.48), (0.98, 1.00), (0.05, 1.00)]),
        (range(300, 600), [(0.10, 0.50), (0.76, 0.47), (0.92, 1.00), (0.00, 1.00)]),
        (range(600, 900), [(0.18, 0.50), (0.84, 0.47), (0.98, 1.00), (0.06, 1.00)]),
        (range(900, 1200), [(0.22, 0.52), (0.90, 0.49), (1.00, 1.00), (0.10, 1.00)]),
    )
    for frame_range, normalized in corridor_by_range:
        if frame_idx in frame_range:
            return [(nx * width, ny * height) for nx, ny in normalized]
    # Fallback corridor for frames past the annotated segments.
    return [(0.15 * width, 0.52 * height), (0.85 * width, 0.48 * height), (width, height), (0, height)]
| 73 | + |
| 74 | + |
def build_mask(frame_idx: int, width: int, height: int, feather_radius: int) -> torch.Tensor:
    """Rasterize the corridor polygon into a (1, 1, H, W) float mask in [0, 1].

    Args:
        frame_idx: frame index used to pick the corridor polygon.
        width: frame width in pixels.
        height: frame height in pixels.
        feather_radius: Gaussian-blur radius for softening the mask edge;
            <= 0 leaves a hard edge.

    Returns:
        (1, 1, height, width) float tensor, 1.0 inside the corridor.
    """
    img = Image.new("L", (width, height), 0)
    draw = ImageDraw.Draw(img)
    draw.polygon(segment_polygon(frame_idx, width, height), fill=255)
    if feather_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=feather_radius))
    # Copy into a writable bytearray: torch.frombuffer emits a UserWarning on
    # read-only buffers (bytes). The tensor owns the copy, so no clone needed.
    raw = bytearray(img.tobytes())
    mask = torch.frombuffer(raw, dtype=torch.uint8).view(height, width).float() / 255.0
    return mask.unsqueeze(0).unsqueeze(0)
| 83 | + |
| 84 | + |
def process_frame(
    frame_rgb: torch.Tensor,
    frame_idx: int,
    outside_luma_denoise: float,
    outside_chroma_mode: str,
    feather_radius: int,
    outside_blend: float,
) -> torch.Tensor:
    """Denoise pixels outside the driving corridor and blend with the original.

    Args:
        frame_rgb: (H, W, 3) uint8-like RGB frame.
        frame_idx: frame index, used to select the corridor polygon.
        outside_luma_denoise: luma blur strength applied outside the corridor.
        outside_chroma_mode: chroma smoothing mode applied outside the corridor.
        feather_radius: mask feathering radius in pixels.
        outside_blend: max weight of the degraded version outside the corridor.

    Returns:
        (H, W, 3) uint8 RGB frame with the outside region softened.
    """
    frame = frame_rgb.permute(2, 0, 1).float().unsqueeze(0)
    height, width = frame.shape[-2], frame.shape[-1]
    mask = build_mask(frame_idx, width, height, feather_radius).to(frame.device)
    # Degrade a copy of the whole frame in YUV space, then mask in only the
    # outside region.
    degraded = rgb_to_yuv(frame).clone()
    degraded = apply_luma_denoise(degraded, outside_luma_denoise)
    degraded = collapse_chroma(degraded, outside_chroma_mode)
    degraded_rgb = yuv_to_rgb(degraded)
    # alpha is 0 inside the corridor (keep original), up to outside_blend out.
    alpha = (1.0 - mask) * outside_blend
    blended = frame * (1.0 - alpha) + degraded_rgb * alpha
    return blended.clamp(0, 255).round().to(torch.uint8).squeeze(0).permute(1, 2, 0)
| 103 | + |
| 104 | + |
def main() -> None:
    """CLI entry point: decode input video, ROI-process each frame, and encode
    the result losslessly with FFV1 at the source frame rate."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--outside-luma-denoise", type=float, default=2.5)
    parser.add_argument("--outside-chroma-mode", type=str, default="medium")
    parser.add_argument("--feather-radius", type=int, default=24)
    parser.add_argument("--outside-blend", type=float, default=0.60)
    args = parser.parse_args()

    in_container = av.open(str(args.input))
    try:
        in_stream = in_container.streams.video[0]
        width, height = in_stream.width, in_stream.height
        # Preserve the source frame rate instead of hard-coding one; fall back
        # to 20 fps when the container does not report a rate.
        rate = in_stream.average_rate or 20

        out_container = av.open(str(args.output), mode="w")
        try:
            out_stream = out_container.add_stream("ffv1", rate=rate)
            out_stream.width = width
            out_stream.height = height
            out_stream.pix_fmt = "yuv420p"

            for frame_idx, frame in enumerate(in_container.decode(in_stream)):
                rgb = yuv420_to_rgb(frame)
                out_rgb = process_frame(
                    rgb, frame_idx,
                    outside_luma_denoise=args.outside_luma_denoise,
                    outside_chroma_mode=args.outside_chroma_mode,
                    feather_radius=args.feather_radius,
                    outside_blend=args.outside_blend,
                )
                video_frame = av.VideoFrame.from_ndarray(out_rgb.cpu().numpy(), format="rgb24")
                for packet in out_stream.encode(video_frame):
                    out_container.mux(packet)

            # Flush any frames buffered inside the encoder.
            for packet in out_stream.encode():
                out_container.mux(packet)
        finally:
            out_container.close()
    finally:
        in_container.close()
| 143 | + |
| 144 | + |
| 145 | +if __name__ == "__main__": |
| 146 | + main() |
0 commit comments