Technical Roadmap
Stage 1 (backbone). Given partial observations \(u_{obs} = M \odot u_0\) and a binary mask \(M\), split the observed support into a context mask \(M_{ctx}\) (model input) and a query mask \(M_{qry}\) (loss region). Train a diffusion denoiser \(u_\phi\) to predict clean data from context only.
Stage 2a (Work I). Sample \(M_{ctx}\) with the same structural pattern as \(p_{\text{mask}}(M)\) so that every observed dimension can be queried with positive probability. At inference, ensemble over context masks to bridge the train/test conditioning mismatch.
Stage 2b (Work II). Pretrain a Bayesian Flow Network on \(p(M)\), then form \(M_{ctx} = M_1 \odot M_2\) from two i.i.d. mask draws. Intersection yields strict positivity for any spatial topology; observation-aligned guidance anchors generation to real occlusions.
1. Problem Setup
Complete data \(u_0 \in \mathbb{R}^d \sim p_{\text{data}}\), binary observation mask \(M \in \{0,1\}^d \sim p_{\text{mask}}(M)\) with \(p_{\text{mask}}(M \mid u_0) = p_{\text{mask}}(M)\), i.e. the mask is drawn independently of the data. Only partial observations
\[u_{obs} = M \odot u_0\]are available — no complete sample ever appears in training. Goal: learn \(p_\theta(u_0 \mid u_{obs}, M)\).
2. Unified Context-Query Backbone
Hierarchical masking
Treat \(u_{obs}\) as “complete within its support”, then split it into a context mask \(M_{ctx} \subseteq M\) (model input) and a query mask \(M_{qry} \subseteq M\) (loss region). With the noisy state
\[u_{obs,t} = M \odot (\alpha_t u_{obs} + \sigma_t \epsilon),\]train \(u_\phi\) to predict clean data from the context only:
\[\mathcal{L}(t, u_{obs}, M_{ctx}, M_{qry}) = \big\| M_{qry} \odot \big( u_\phi(t,\, M_{ctx} \odot u_{obs,t},\, M_{ctx}) - u_{obs} \big) \big\|^2\]The model sees only \(M_{ctx}\) and the masked input — it must infer the withheld region, not memorize it.
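The loss above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the actual training code: the linear schedule \(\alpha_t = 1 - t\), \(\sigma_t = t\) and the names `context_query_loss` and `toy_denoiser` are assumptions introduced here, and the toy denoiser merely stands in for \(u_\phi\).

```python
import numpy as np

def context_query_loss(denoiser, t, u_obs, m, m_ctx, m_qry, rng):
    """Squared error on the query region; the denoiser sees only the context."""
    # Assumed schedule (not specified in the text): alpha_t = 1 - t, sigma_t = t.
    alpha_t, sigma_t = 1.0 - t, t
    eps = rng.standard_normal(u_obs.shape)
    u_obs_t = m * (alpha_t * u_obs + sigma_t * eps)  # noisy state on the observed support
    pred = denoiser(t, m_ctx * u_obs_t, m_ctx)       # model input is masked to the context
    return float(np.sum((m_qry * (pred - u_obs)) ** 2))

def toy_denoiser(t, x_ctx, m_ctx):
    # Hypothetical stand-in for u_phi: predict the mean visible value everywhere.
    n_ctx = max(float(m_ctx.sum()), 1.0)
    return np.full_like(x_ctx, x_ctx.sum() / n_ctx)

rng = np.random.default_rng(0)
u_0 = rng.standard_normal(8)
m = np.array([1, 1, 1, 1, 1, 1, 0, 0], dtype=float)      # observation mask M
u_obs = m * u_0
m_ctx = np.array([1, 1, 1, 0, 0, 0, 0, 0], dtype=float)  # context: subset of M
m_qry = m * (1.0 - m_ctx)                                 # query: observed but withheld
loss = context_query_loss(toy_denoiser, 0.3, u_obs, m, m_ctx, m_qry, rng)
```

Entries outside \(M_{qry}\) contribute nothing to the sum, so the never-observed dimensions (the last two here) cannot leak into the objective.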
Core result
Minimizing the loss gives, per dimension \(i\),
\[\big(u_\phi(t, M_{ctx} \odot u_{obs,t}, M_{ctx})\big)_i = \begin{cases} \mathbb{E}\!\left[(u_0)_i \mid M_{ctx} \odot u_{obs,t},\, M_{ctx}\right], & P((M_{qry})_i = 1 \mid M_{ctx}) > 0 \\[4pt] \text{arbitrary}, & P((M_{qry})_i = 1 \mid M_{ctx}) = 0 \end{cases}\]and if the union of possible \(M_{qry}\) covers all dimensions,
\[u_\phi = \mathbb{E}[u_0 \mid M_{ctx} \odot u_{obs,t}, M_{ctx}].\]The expected squared gradient and the parameter-update frequency for dimension \(i\) both scale with \(p_i := P((M_{qry})_i = 1 \mid M_{ctx})\) — a zero-query dimension simply never gets a gradient.
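A short worked step makes the case split explicit. It relies on an added assumption that is not stated above: given \(M_{ctx}\), the query mask is independent of \(u_0\) and of the noise. Write \(z = M_{ctx} \odot u_{obs,t}\); since \(M_{qry} \subseteq M\), \((u_{obs})_i = (u_0)_i\) wherever \((M_{qry})_i = 1\), so
\[\mathbb{E}\big[\mathcal{L} \mid z, M_{ctx}\big] = \sum_i p_i\, \mathbb{E}\big[\big((u_\phi)_i - (u_0)_i\big)^2 \mid z, M_{ctx}\big], \qquad \frac{\partial}{\partial (u_\phi)_i}\, \mathbb{E}\big[\mathcal{L} \mid z, M_{ctx}\big] = 2\, p_i \big((u_\phi)_i - \mathbb{E}[(u_0)_i \mid z, M_{ctx}]\big).\]
For \(p_i > 0\) the gradient vanishes only at the conditional mean; for \(p_i = 0\) it is identically zero, so that output is left unconstrained.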
Design principle
The whole framework reduces to one requirement (plus a balance condition):
\[P((M_{qry})_i = 1 \mid M_{ctx}) > 0 \quad \forall\, i:\ (M_{ctx})_i = 0, \qquad P((M_{qry})_i = 1 \mid M_{ctx}) \approx P((M_{qry})_j = 1 \mid M_{ctx})\]The two works below are two answers to “how do we guarantee this strict positivity?” — one heuristic, one generative.
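Both conditions can be checked empirically for a candidate partitioning scheme. The sketch below is a diagnostic under assumed toy settings (an i.i.d. Bernoulli mask with pixel-wise context dropout); `query_frequencies` and `toy_pair` are hypothetical names, not part of the framework. It groups sampled partitions by context mask and estimates \(P((M_{qry})_i = 1 \mid M_{ctx})\) per dimension.

```python
import numpy as np

def query_frequencies(sample_pair, n_samples, rng):
    """Empirical P((M_qry)_i = 1 | M_ctx) for every context mask that was drawn."""
    sums, counts = {}, {}
    for _ in range(n_samples):
        m_ctx, m_qry = sample_pair(rng)
        key = tuple(int(v) for v in m_ctx)
        sums[key] = sums.get(key, 0.0) + m_qry
        counts[key] = counts.get(key, 0) + 1
    return {k: sums[k] / counts[k] for k in sums}

def toy_pair(rng, d=4, p_obs=0.7, p_keep=0.5):
    # Assumed toy setting: i.i.d. Bernoulli mask, context kept pixel by pixel.
    m = (rng.random(d) < p_obs).astype(float)
    m_ctx = m * (rng.random(d) < p_keep)
    return m_ctx, m * (1.0 - m_ctx)

rng = np.random.default_rng(0)
freqs = query_frequencies(toy_pair, 20000, rng)
# Positivity and balance are only required where the context entry is zero.
worst = min(p[i] for ctx, p in freqs.items() for i in range(len(ctx)) if ctx[i] == 0)
```

In this toy setting every non-context dimension has the same query probability, \(0.35 / 0.65\), so the estimates should be both positive and balanced; a scheme that violates the requirement would show up as a zero or a large spread.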
3. Work I — Distribution-Preserving Partitioning
Partitioning
Decompose the query probability by the law of total probability:
\[P((M_{qry})_i = 1 \mid M_{ctx}) = \sum_{M} P((M_{qry})_i = 1 \mid M_{ctx}, M)\, P(M \mid M_{ctx})\]Since \(M_{ctx} \subseteq M\), sampling \(M_{ctx}\) with the same structural pattern as \(p_{\text{mask}}(M)\) (i.i.d. pixels → drop pixels independently; block-structured → drop whole blocks) makes multiple \(M\) compatible with a given \(M_{ctx}\), so both factors can be simultaneously positive → positivity holds for every observed dimension.
Too few context points → large information gap, high variance, slow convergence; too many → \(p_i\) tiny, infrequent updates. A moderate ratio is optimal.
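The two structural patterns mentioned above can be sketched as follows. This is an illustration only: `drop_pixels`, `drop_blocks`, the tile size, and the keep probability are assumptions made for the example, with `keep_prob` playing the role of the context ratio.

```python
import numpy as np

def drop_pixels(m, keep_prob, rng):
    """i.i.d. pattern: keep each observed pixel independently."""
    return m * (rng.random(m.shape) < keep_prob)

def drop_blocks(m, block, keep_prob, rng):
    """Block pattern: keep or drop whole block x block tiles together."""
    h, w = m.shape
    n_rows, n_cols = -(-h // block), -(-w // block)   # ceil division
    tiles = (rng.random((n_rows, n_cols)) < keep_prob).astype(float)
    keep = np.kron(tiles, np.ones((block, block)))[:h, :w]
    return m * keep

rng = np.random.default_rng(0)
m = np.ones((8, 8))
m[:, 6:] = 0.0                      # toy observation mask with a missing strip
ctx_pix = drop_pixels(m, 0.5, rng)
ctx_blk = drop_blocks(m, 4, 0.5, rng)
qry_blk = m * (1.0 - ctx_blk)       # observed entries withheld from the context
```

Multiplying by `m` keeps \(M_{ctx} \subseteq M\) in both cases, and the context and query masks together always recover \(M\).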
Ensemble inference
Inference faces a train/test mismatch: the model estimates \(\mathbb{E}[u_0 \mid M_{ctx} \odot u_{obs,t}, M_{ctx}]\), but we want \(\mathbb{E}[u_0 \mid u_{obs,t}, M]\). Bridge the gap by ensembling over context masks.
Single-step sampling. Apply minimal noise \(u_{obs,\delta} = M \odot (\alpha_\delta u_{obs} + \sigma_\delta \epsilon)\), \(0 < \delta \ll 1\) (so \(u_{obs,\delta} \approx u_{obs}\)), then average the denoiser output over \(K\) context masks \(M_{ctx}^{(k)} \subseteq M\):
\[\hat u^* = \mathbb{E}[u_0 \mid u_{obs}, M] \approx \frac{1}{K} \sum_{k=1}^{K} u_\phi\!\left(\delta,\, M_{ctx}^{(k)} \odot u_{obs,\delta},\, M_{ctx}^{(k)}\right)\]Why it works. With \(u_\phi(t, \text{ctx}) = \mathbb{E}[u_0 \mid \text{ctx}] + b(\text{ctx}) + \epsilon_{\text{bias}}(\text{ctx})\) and \(\hat\mu_K = \frac{1}{K} \sum_k u_\phi(t, \text{ctx}^{(k)})\):
\[\mathbb{E}\big[\|\hat\mu_K - \mathbb{E}[u_0 \mid \text{obs}]\|^2\big] = \big\|\underbrace{\mathbb{E}[\mathbb{E}[u_0 \mid \text{ctx}]] - \mathbb{E}[u_0 \mid \text{obs}]}_{\text{information gap}} + \underbrace{\mathbb{E}[b(\text{ctx})]}_{\text{model bias}}\big\|^2 + \frac{1}{K}\big(\underbrace{\mathrm{Var}[\mathbb{E}[u_0 \mid \text{ctx}]]}_{\text{data variance}} + \underbrace{\mathrm{Var}[b(\text{ctx})] + \mathrm{Var}[\epsilon_{\text{bias}}]}_{\text{model variance}}\big)\] \[\lim_{K \to \infty} \mathbb{E}\big[\|\hat\mu_K - \mathbb{E}[u_0 \mid \text{obs}]\|^2\big] = \big\|\mathbb{E}[\mathbb{E}[u_0 \mid \text{ctx}]] - \mathbb{E}[u_0 \mid \text{obs}] + \mathbb{E}[b(\text{ctx})]\big\|^2\]Averaging kills all variance terms; residual error = information gap + systematic bias only.
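The \(1/K\) behaviour of the variance terms can be seen in a small simulation. The sketch assumes i.i.d. context masks and uses a toy denoiser in place of \(u_\phi\); `ensemble_predict`, `toy_denoiser`, and `spread` are hypothetical names chosen for the example.

```python
import numpy as np

def ensemble_predict(denoiser, delta, u_obs_delta, m, n_masks, keep_prob, rng):
    """Average denoiser outputs over n_masks random context masks inside m."""
    total = np.zeros_like(u_obs_delta)
    for _ in range(n_masks):
        m_ctx = m * (rng.random(m.shape) < keep_prob)  # assumed i.i.d. context subset
        total += denoiser(delta, m_ctx * u_obs_delta, m_ctx)
    return total / n_masks

def toy_denoiser(t, x_ctx, m_ctx):
    # Hypothetical stand-in for u_phi: mean of the visible context values.
    n_ctx = max(float(m_ctx.sum()), 1.0)
    return np.full_like(x_ctx, x_ctx.sum() / n_ctx)

rng = np.random.default_rng(0)
m = np.array([1, 1, 1, 1, 1, 1, 0, 0], dtype=float)
u_obs_delta = m * rng.standard_normal(8)   # stands in for the lightly noised observation

def spread(n_masks, n_repeats=200):
    # Variance of the ensemble estimate across independent repeats.
    runs = [ensemble_predict(toy_denoiser, 1e-3, u_obs_delta, m, n_masks, 0.5, rng)
            for _ in range(n_repeats)]
    return float(np.var(np.stack(runs), axis=0).mean())

var_1, var_64 = spread(1), spread(64)
```

With 64 masks the spread across repeats should be far smaller than with a single mask, while any offset shared by all context draws (the information gap and systematic bias) is untouched by averaging.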
Multi-step sampling. When the posterior is not near-deterministic, replace each denoising step’s denoiser with a weighted combination of two ensemble estimates:
\[\hat u_\phi(t, u_t, u_{obs}, M) \approx \omega_t\, \mathbb{E}[u_0 \mid u_t] + (1 - \omega_t)\, \mathbb{E}[u_0 \mid u_{obs}, M]\]with \(\omega_t\) monotonically increasing \(0 \to 1\). Each term is a Monte-Carlo average over masks (random masks for the first, context masks \(\subseteq M\) for the second). Sampling follows a RePaint-style scheme: estimate noise from the model on unobserved entries, compute it directly from known clean values on observed entries, merge via
\[\epsilon_{\text{full}} = M \odot \epsilon_{obs} + (1 - M) \odot \epsilon_{\text{unobs}},\]and step the diffusion ODE — keeping observations consistent throughout.
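A minimal sketch of the merge and one deterministic step follows. The DDIM-style update in `ode_step` is an assumption (the text only says to step the diffusion ODE), and `u0_hat` is a hypothetical placeholder for the blended denoiser output \(\hat u_\phi\).

```python
import numpy as np

def merged_noise(u_t, u_obs, m, u0_hat, alpha_t, sigma_t):
    """eps_full = M * eps_obs + (1 - M) * eps_unobs."""
    eps_obs = (u_t - alpha_t * u_obs) / sigma_t     # exact, from the known clean values
    eps_unobs = (u_t - alpha_t * u0_hat) / sigma_t  # implied by the model's clean estimate
    return m * eps_obs + (1.0 - m) * eps_unobs

def ode_step(u_t, eps_full, alpha_t, sigma_t, alpha_s, sigma_s):
    # Assumed DDIM-style deterministic update from noise level t to s.
    u0_implied = (u_t - sigma_t * eps_full) / alpha_t
    return alpha_s * u0_implied + sigma_s * eps_full

rng = np.random.default_rng(0)
m = np.array([1, 1, 0, 0], dtype=float)
u_obs = m * np.array([0.5, -1.0, 2.0, 0.3])
u0_hat = np.array([9.0, 9.0, 1.5, 0.0])     # hypothetical denoiser output
alpha_t, sigma_t, alpha_s, sigma_s = 0.6, 0.8, 0.8, 0.6
u_t = rng.standard_normal(4)
eps_full = merged_noise(u_t, u_obs, m, u0_hat, alpha_t, sigma_t)
u_s = ode_step(u_t, eps_full, alpha_t, sigma_t, alpha_s, sigma_s)
```

On observed entries the clean value implied by `eps_full` is exactly `u_obs`, whatever the model predicts there, which is what keeps the trajectory consistent with the observations.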
4. Work II — Generative-Prior Partitioning
Motivation
The distribution-preserving rule must be hand-crafted per observation pattern (pixel dropout, block occlusion, …), introduces pattern-specific hyperparameters, and cannot easily be shown to guarantee positivity under complex real spatial dependencies. Goal: a single mechanism valid for any spatial topology.
Intersection gives positivity for free
Model the true mask distribution \(p(M)\) with a generative prior; draw two i.i.d. masks \(M_1, M_2 \sim p(M)\) and set
\[M_{ctx} = M_1 \odot M_2, \qquad M_{qry} = M_1 \odot (1 - M_{ctx})\]Theorem (strict positivity via intersection): for any valid \(M_{ctx} = m\) and any \(i\) with \(m_i = 0\),
\[P((M_{qry})_i = 1 \mid M_{ctx} = m) > 0.\]This removes all heuristic design — but now requires sampling the complex discrete \(p(M)\).
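The intersection construction itself is only a couple of lines once a mask sampler is available. In the sketch below the sampler is a toy stand-in for the learned \(p(M)\) (a single contiguous occlusion in one dimension); `sample_block_mask` and `intersection_partition` are hypothetical names.

```python
import numpy as np

def sample_block_mask(rng, d=8, hole=3):
    # Assumed toy p(M): one contiguous occlusion of length `hole` at a random offset.
    m = np.ones(d)
    start = rng.integers(0, d - hole + 1)
    m[start:start + hole] = 0.0
    return m

def intersection_partition(sample_mask, rng):
    """M_ctx = M_1 * M_2 and M_qry = M_1 * (1 - M_ctx) from two i.i.d. draws."""
    m_1, m_2 = sample_mask(rng), sample_mask(rng)
    m_ctx = m_1 * m_2
    m_qry = m_1 * (1.0 - m_ctx)
    return m_1, m_ctx, m_qry

rng = np.random.default_rng(0)
m_1, m_ctx, m_qry = intersection_partition(sample_block_mask, rng)
```

By construction the context and query masks are disjoint and together reproduce \(M_1\); the query region is exactly the part of \(M_1\) that \(M_2\) occludes.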
Modeling p(M) with a BFN
A Bayesian Flow Network handles discrete binary masks, recast into a continuous diffusion-style form so we can intervene by gradients in latent space:
- Scaled-logit target: encode class \(c\) as \(x_0 = Ke_c\), where \(K\) here is a logit scale (not the ensemble size of Work I), since \(\mathrm{softmax}(Ke_c) \to e_c\) for large \(K\); discrete generation becomes continuous regression.
- Forward process: \(x_t = \alpha_t x_0 + \sigma_t \epsilon\), \(\epsilon \sim \mathcal{N}(0, I)\).
- Shift invariance: the score inherits invariance to adding a constant across logits, enforced architecturally by feeding \(\mathrm{softmax}(x_t)\) as input.
- Discrete data-matching objective:
- Sampling: integrate the probability-flow ODE backward via Tweedie,
decode \(e_c = \arg\max \frac{1}{K}(x_0 + 1)\).
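The scaled-logit encoding, the forward process, and the shift invariance from the list above can be illustrated numerically. This is a sketch under assumed values (logit scale 10, \(\alpha_t = 0.9\), \(\sigma_t = 0.1\)); `encode` and `softmax` are helper names introduced here, and decoding uses a plain argmax, which gives the same result as an argmax over any positively scaled and shifted copy of the logits.

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)   # stabilised; does not change the result
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def encode(classes, n_classes, scale):
    """Scaled-logit target x_0 = scale * e_c for each entry."""
    return scale * np.eye(n_classes)[classes]

rng = np.random.default_rng(0)
mask = np.array([1, 0, 1, 1, 0])            # binary mask entries as class labels
x_0 = encode(mask, n_classes=2, scale=10.0)
probs = softmax(x_0)                         # close to one-hot for a large scale
alpha_t, sigma_t = 0.9, 0.1                  # assumed noise level for the example
x_t = alpha_t * x_0 + sigma_t * rng.standard_normal(x_0.shape)
shifted = softmax(x_t + 3.7)                 # constant added across the logits
decoded = np.argmax(x_t, axis=-1)
```

Feeding `softmax(x_t)` to the network therefore discards exactly the constant shift that carries no class information.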
Observation-aligned conditioning
Unconditional intersection can overlap the real \(M\) too little → context too sparse. Anchor the generation to the actual observations via classifier guidance:
- Stochastic anchor: \(y_i = \mathbf{1}[r_i < \rho] \cdot M_i\), \(r_i \sim \mathrm{Uniform}(0,1)\) — randomly retaining a fraction \(\rho\) of observed points injects diversity (full anchoring collapses to one deterministic mask).
- Globally-normalized guidance loss:
Global normalization stabilizes the gradient across very different sparsity levels.
- Latent intervention:
Theorem (positivity preserved under guidance): under the ratio-guided intersection constraint, for any observed dimension \(i\) (\(M_i = 1\)), \(P((M_{qry})_i = 1 \mid C_k) > 0\) — conditioning does not break the guarantee.
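The stochastic anchor defined in the list above is simple to sketch; the guidance loss and latent update are not reproduced here. `stochastic_anchor` is a hypothetical name and the toy mask density of 0.6 is an assumption for the example.

```python
import numpy as np

def stochastic_anchor(m, rho, rng):
    """y_i = 1[r_i < rho] * M_i with r_i ~ Uniform(0, 1)."""
    r = rng.random(m.shape)
    return (r < rho) * m

rng = np.random.default_rng(0)
m = (rng.random(1000) < 0.6).astype(float)   # toy observation mask
y_a = stochastic_anchor(m, 0.5, rng)
y_b = stochastic_anchor(m, 0.5, rng)         # a second draw gives a different anchor
y_full = stochastic_anchor(m, 1.0, rng)      # rho = 1 keeps every observed point
```

With \(\rho = 1\) the anchor equals \(M\) on every draw, which is the deterministic collapse the random retention is meant to avoid.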
Training
Same context-query objective as the backbone, with the generated partition:
\[M_{ctx} = \hat M \odot M, \quad M_{qry} = M \odot (1 - M_{ctx}), \qquad \mathcal{L} = \mathbb{E}\big[\| M_{qry} \odot (u_\phi(M_{ctx} \odot u_{obs}, M_{ctx}) - u_{obs}) \|^2 \big]\]
5. Summary
| Route | How positivity is guaranteed | Trade-off |
|---|---|---|
| Work I (distribution-preserving) | Sample \(M_{ctx}\) with the same structure as \(p_{\text{mask}}\); ensemble at inference | Pattern-specific heuristics; hard to cover complex spatial dependencies |
| Work II (generative-prior) | Intersection of two i.i.d. masks from \(p(M)\); observation-aligned guidance | Requires learning and sampling \(p(M)\), but holds for any topology |
Positivity now holds by construction for every valid topology, so the model transfers reconstruction from synthetic training masks to the genuinely-missing regions at test time — without ever seeing a complete field.