SANA-Denoiser - Our Architecture Design

Repurposing the SANA 1.6B DiT as an image restoration model. Combines efficient linear attention with paired restoration training and temporal tiling via block-causal linear attention.

Why SANA for Restoration

| Property | SANA 1.6B | Step1X-Edit (RealRestorer) | FLUX-dev |
|---|---|---|---|
| Params | 1.6B | ~15B | 12B |
| Attention | Linear O(N) | Quadratic O(N^2) | Quadratic O(N^2) |
| VAE compression | 32x (DC-AE) | 8x | 8x |
| Tokens at 1024px | 1,024 | 16,384 | 4,096 |
| Tokens at 4K | 16,384 | 262,144 (!) | 65,536 |
| Speed (1024px) | 1.2s | ~15s | 23s |

SANA is roughly 10x smaller, processes 4-16x fewer tokens, and scales linearly with sequence length. For restoration, where high-resolution processing is the norm, this is decisive.

Architecture Changes (Minimal)

1. Input Conditioning: Channel Concat

```
degraded → DC-AE.encode → condition_latents   [B, 32, H, W]
target   → DC-AE.encode → latents             [B, 32, H, W]

x_t = (1-σ)*noise + σ*latents                 [B, 32, H, W]
model_input = concat([x_t, condition_latents], dim=1)   # [B, 64, H, W]

projection = Conv2d(64, 32, 1)   # 1x1 conv, ~2K params
# Identity init for noise channels, zero init for condition channels
# At step 0: model = pretrained T2I behavior
# Condition signal learned gradually during fine-tuning

model(projection(model_input), timestep, text_embeddings)
```

Total new parameters: 2,048 (32 x 64 x 1 x 1 conv kernel, plus 32 bias terms). Compare: ControlNet adds ~800M.
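A minimal sketch of the identity/zero initialization, using NumPy in place of `torch.nn.Conv2d` for self-containedness (a 1x1 conv is just a per-pixel linear map over channels, so the weight layout `[out, in]` below matches the conv kernel with spatial dims squeezed):

```python
import numpy as np

C = 32  # latent channels, for both the noisy state and the condition
# 1x1 conv weight [out=32, in=64]: identity over the noise half, zeros over the condition half
W = np.concatenate([np.eye(C), np.zeros((C, C))], axis=1)

x_t  = np.random.randn(2, C, 8, 8)   # noisy latents
cond = np.random.randn(2, C, 8, 8)   # DC-AE encoded degraded image
model_input = np.concatenate([x_t, cond], axis=1)   # [B, 64, H, W]

# apply the 1x1 conv as a channel-wise matrix product
projected = np.einsum('oc,bchw->bohw', W, model_input)

# at step 0 the projection passes x_t through unchanged,
# so the pretrained T2I model sees exactly its native input
assert np.allclose(projected, x_t)
```

During fine-tuning the zeroed condition half of the kernel is free to grow away from zero, which is what "condition signal learned gradually" refers to.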

2. Text Conditioning for Degradation Type

Prompt describes what to restore:

- "Remove gaussian noise, restore sharp details"
- "Remove JPEG compression artifacts"
- "Enhance this low-light image"
- "Clean and restore this image"

Leverages SANA's Gemma-2-2B text encoder for degradation-type understanding.

3. Temporal Tiling for High-Resolution

For images > training resolution (e.g., 4K product photos):

```
4096x4096 image
  ↓ split into overlapping 1024px tiles (raster scan)
  ↓ each tile: DC-AE encode → 32x32x32 latent
  ↓ denoise with Block Causal Linear Attention
  ↓    (running sum S, Z from previous tiles = global context)
  ↓ stitch latents with linear blending in overlap
  ↓ DC-AE decode full stitched latent
4096x4096 restored image
```

Memory stays constant: an O(D^2) state cache plus one tile's latent, independent of image size, so any resolution can be processed.
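The (S, Z) carry can be sketched as follows. This is an illustrative NumPy implementation, not the actual SANA kernel; the feature map `phi` (ReLU + ε) is an assumption for illustration. Attention is bidirectional within a tile, while earlier tiles contribute only through the accumulated state, so memory is O(D^2) no matter how many tiles have been processed:

```python
import numpy as np

def phi(x):
    # positive feature map; a stand-in for the model's actual kernel
    return np.maximum(x, 0.0) + 1e-6

def process_tile(q, k, v, S, Z):
    """One tile of block-causal linear attention.

    q, k, v: [T, D] tokens of the current tile.
    S: [D, D] running sum of phi(k) v^T from previous tiles.
    Z: [D]    running sum of phi(k)     from previous tiles.
    Attention is full within the tile; earlier tiles enter only via (S, Z).
    """
    S = S + phi(k).T @ v            # constant-size O(D^2) state update
    Z = Z + phi(k).sum(axis=0)
    out = (phi(q) @ S) / (phi(q) @ Z)[:, None]
    return out, S, Z

rng = np.random.default_rng(0)
T, D = 8, 16
q1, k1, v1 = (rng.standard_normal((T, D)) for _ in range(3))
q2, k2, v2 = (rng.standard_normal((T, D)) for _ in range(3))

S, Z = np.zeros((D, D)), np.zeros(D)
out1, S, Z = process_tile(q1, k1, v1, S, Z)   # tile 1: sees only itself
out2, S, Z = process_tile(q2, k2, v2, S, Z)   # tile 2: itself + tile 1 via (S, Z)

# sanity check: tile 2 matches attending over all tokens seen so far
k_all, v_all = np.vstack([k1, k2]), np.vstack([v1, v2])
ref2 = (phi(q2) @ (phi(k_all).T @ v_all)) / (phi(q2) @ phi(k_all).sum(axis=0))[:, None]
assert np.allclose(out2, ref2)
```

The final assertion is the property the whole scheme relies on: sequential tile processing with a carried (S, Z) reproduces attention over the full token stream without ever materializing it.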

Training Strategy

Phase 1: LoRA (fast iteration)

  • Rank 32, target: attn.to_q/k/v/out + input projection conv
  • 512px, 10K steps, DIV2K + Flickr2K synthetic degradation
  • Evaluate: does it learn to denoise at all?
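A sketch of why LoRA fine-tuning starts exactly at the pretrained behavior: the rank-32 down-projection A is small-random, the up-projection B starts at zero, so the adapter's delta is initially a no-op. NumPy stand-in for one of the `attn.to_q/k/v/out` linears; shapes and the `alpha` scale are illustrative, not the project's config:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, rank = 64, 32

W = rng.standard_normal((d_model, d_model))      # frozen pretrained weight
A = rng.standard_normal((rank, d_model)) * 0.01  # trainable down-projection
B = np.zeros((d_model, rank))                    # trainable up-projection, zero-init

def lora_forward(x, alpha=32):
    # base path plus low-rank update, scaled by alpha / rank
    return x @ W.T + (alpha / rank) * (x @ A.T @ B.T)

x = rng.standard_normal((4, d_model))
# at initialization the adapter contributes nothing
assert np.allclose(lora_forward(x), x @ W.T)
```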

Phase 2: Full Fine-Tune (if LoRA insufficient)

  • Unfreeze all transformer params + projection
  • VAE stays frozen
  • Gradient checkpointing for memory
  • Curriculum: 512px → 1024px

Phase 3: Temporal Tiling (inference-only first)

  • No retraining needed - causal attention is native to linear attention
  • Just implement the tile loop + S, Z accumulation
  • If quality insufficient: fine-tune with multi-tile samples
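The "stitch latents with linear blending" step of the tile loop can be sketched as weighted averaging, with per-tile weights that ramp linearly across the overlap. These helpers are hypothetical, not the actual `temporal_tiling.py` API:

```python
import numpy as np

def blend_weights(h, w, overlap):
    """Per-pixel weights that ramp 0→1 over `overlap` pixels at each tile edge."""
    def ramp(n):
        r = np.minimum(1.0, (np.arange(n) + 1) / overlap)
        return np.minimum(r, r[::-1])   # symmetric: fade in and fade out
    return np.outer(ramp(h), ramp(w))

def stitch(tiles, positions, out_hw, overlap):
    """Accumulate weighted tiles, then normalize by the total weight."""
    acc = np.zeros(out_hw)
    wsum = np.zeros(out_hw)
    for tile, (y, x) in zip(tiles, positions):
        h, w = tile.shape
        wt = blend_weights(h, w, overlap)
        acc[y:y + h, x:x + w] += tile * wt
        wsum[y:y + h, x:x + w] += wt
    return acc / np.maximum(wsum, 1e-8)

# two overlapping constant tiles stitch back to a constant image (no seam)
tiles = [np.ones((8, 8)), np.ones((8, 8))]
out = stitch(tiles, [(0, 0), (0, 4)], (8, 12), overlap=4)
assert np.allclose(out, 1.0)
```

Normalizing by the accumulated weight sum means edge pixels covered by only one tile still come out correctly, while overlap pixels get a smooth cross-fade.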

Dataset

Source: DIV2K (800) + Flickr2K (2650) = 3,450 clean images.
Degradation: 5-8 variants per image = 17K-28K pairs.

| Degradation | Params | Prompt |
|---|---|---|
| Gaussian noise | σ = 10, 15, 25, 35, 50 | "Remove gaussian noise sigma {σ}" |
| JPEG | q = 15, 25, 40 | "Remove JPEG artifacts quality {q}" |
| Blur | k = 3, 5, 7, 9 | "Remove blur, restore sharpness" |
| Downscale | 2x, 3x, 4x | "Upscale and restore details" |
| Combined | 2-3 random | "Restore this degraded image" |
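A minimal sketch of pair generation for the gaussian-noise row (NumPy only; the JPEG and blur variants would need PIL/OpenCV and are omitted). Function names are illustrative, not the actual `prepare_dataset.py` API:

```python
import numpy as np

rng = np.random.default_rng(42)
SIGMAS = [10, 15, 25, 35, 50]   # matches the degradation table

def add_gaussian_noise(img, sigma):
    """img in [0, 1]; sigma is given on the 0-255 scale, as in the table."""
    noisy = img + rng.normal(0.0, sigma / 255.0, size=img.shape)
    return np.clip(noisy, 0.0, 1.0)

def make_pairs(clean):
    """One clean image → one (degraded, target, prompt) triple per sigma."""
    return [(add_gaussian_noise(clean, s), clean,
             f"Remove gaussian noise sigma {s}") for s in SIGMAS]

clean = rng.random((64, 64, 3))
pairs = make_pairs(clean)
assert len(pairs) == 5
```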

Evaluation Targets

| Benchmark | Metric | Target | SOTA Reference |
|---|---|---|---|
| SIDD val | PSNR | > 38 dB | NAFNet: 40.3 |
| SIDD val | SSIM | > 0.95 | NAFNet: 0.96 |
| DIV2K (σ=25) | PSNR | > 30 dB | SwinIR: 30.9 |
| Urban100 (σ=25) | PSNR | > 29 dB | SwinIR: 29.5 |
| Temporal tiling | Seam PSNR | > 40 dB | MultiDiffusion baseline |
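The PSNR targets above follow the standard definition, 10·log10(peak²/MSE). A minimal helper, assuming images normalized to [0, 1]:

```python
import numpy as np

def psnr(x, y, peak=1.0):
    """Peak signal-to-noise ratio in dB between images in [0, peak]."""
    mse = np.mean((x - y) ** 2)
    return 10.0 * np.log10(peak ** 2 / mse)

x = np.zeros((16, 16))
# a uniform error of 0.1 gives MSE 0.01 → 10 * log10(1 / 0.01) = 20 dB
assert np.isclose(psnr(x, x + 0.1), 20.0)
```

The seam-PSNR target would apply the same metric restricted to pixels inside tile overlap regions.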

Project Files

```
happyin-research/
├── sana-fm/
│   ├── data/paired_dataset.py      ← paired loader
│   ├── data/degradation.py         ← degradation functions
│   ├── configs/img2img_denoise.yaml
│   └── train_flowmatching.py       ← modified compute_loss
├── sana-denoiser/
│   ├── prepare_dataset.py          ← DIV2K + Flickr2K + degradations
│   ├── train.py                    ← wrapper
│   ├── temporal_tiling.py          ← tile-as-sequence inference
│   └── eval/benchmark.py           ← vs SwinIR, NAFNet
```

Risk Assessment

| Risk | Likelihood | Mitigation |
|---|---|---|
| DC-AE 32x compression loses fine details | Medium | Compare DC-AE reconstruction vs 8x VAE on jewelry textures |
| Linear attention insufficient for restoration | Low | SANA matches quadratic models on generation; restoration is simpler |
| Temporal tiling adds latency | High | Acceptable: quality > speed for product photography |
| 1.6B too small for complex degradations | Medium | Scale to 4.8B if needed; depth-pruning from 4.8B as fallback |