SpriteDX - Pixel Alignment - Lab Note 7

We built anti-corruption model V1 and integrated it into SpriteDX, but we are seeing some quality issues. I want to start a Lab Note to go over the issues and plan mitigations.
V1 Model
The V1 model is a U-Net with skip connections. Given a corrupted RGB pixel-art image, it computes the corrections needed to produce a true pixel art in RGBA. Here are the building blocks of the U-Net.
```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.pool_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)
        self.conv = DoubleConv(in_channels + in_channels // 2, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
```
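As a standalone sanity check (toy tensors, not part of the model code), the channel math in `Up` can be verified directly: a stride-2 transposed conv doubles H and W while halving the channels, so after concatenation with the skip tensor the `DoubleConv` sees `in_channels + in_channels // 2` channels.

```python
import torch
import torch.nn as nn

# Mirror of the up1 stage: bottleneck (512ch, 8x8) meets skip (512ch, 16x16).
up = nn.ConvTranspose2d(512, 256, 2, stride=2)
x1 = torch.randn(1, 512, 8, 8)    # bottleneck feature map
x2 = torch.randn(1, 512, 16, 16)  # skip connection from the encoder
x = torch.cat([x2, up(x1)], dim=1)
print(x.shape)  # torch.Size([1, 768, 16, 16])
```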
Then here is the U-Net definition:
```python
class AntiCorruptionUNet(nn.Module):
    """
    U-Net for anti-corruption: RGB (3 channels) -> RGBA (4 channels)
    Input: 128x128x3 (corrupted RGB)
    Output: 128x128x4 (clean RGBA)
    """
    def __init__(self):
        super().__init__()                    # (B, 3, 128, 128) input
        self.inc = DoubleConv(3, 64)          # (B, 64, 128, 128) full-res DoubleConv
        self.down1 = Down(64, 128)            # (B, 128, 64, 64)  2x2 patch
        self.down2 = Down(128, 256)           # (B, 256, 32, 32)  4x4 patch
        self.down3 = Down(256, 512)           # (B, 512, 16, 16)  8x8 patch
        self.down4 = Down(512, 512)           # (B, 512, 8, 8)    16x16 patch
        self.up1 = Up(512, 256)               # (B, 256, 16, 16)  8x8 patch
        self.up2 = Up(256, 128)               # (B, 128, 32, 32)  4x4 patch
        self.up3 = Up(128, 64)                # (B, 64, 64, 64)   2x2 patch
        self.up4 = Up(64, 64)                 # (B, 64, 128, 128) 1x1 patch
        self.out_rgb = nn.Conv2d(64, 3, 1)    # (B, 3, 128, 128)  RGB out
        self.out_alpha = nn.Conv2d(64, 1, 1)  # (B, 1, 128, 128)  alpha out
        # Initialize output layers for identity-like behavior.
        # RGB: small weights so residual corrections start small.
        nn.init.normal_(self.out_rgb.weight, mean=0.0, std=0.001)
        nn.init.zeros_(self.out_rgb.bias)
        # Alpha: same small init, so alpha starts near sigmoid(0) = 0.5.
        nn.init.normal_(self.out_alpha.weight, mean=0.0, std=0.001)
        nn.init.zeros_(self.out_alpha.bias)

    def forward(self, x):
        input_rgb = x            # save input for the residual connection
        # Encoder
        x1 = self.inc(x)         # 128x128x64
        x2 = self.down1(x1)      # 64x64x128
        x3 = self.down2(x2)      # 32x32x256
        x4 = self.down3(x3)      # 16x16x512
        x5 = self.down4(x4)      # 8x8x512
        # Decoder with skip connections
        x = self.up1(x5, x4)     # 16x16x256
        x = self.up2(x, x3)      # 32x32x128
        x = self.up3(x, x2)      # 64x64x64
        x = self.up4(x, x1)      # 128x128x64
        # Separate output heads, with a residual connection for RGB
        rgb_residual = self.out_rgb(x)
        rgb = input_rgb + rgb_residual  # identity + learned corrections, unbounded for full gradient flow
        alpha = torch.sigmoid(self.out_alpha(x))
        return torch.cat([rgb, alpha], dim=1)
```
The training dataset has 909 training samples and 16 validation samples; we want to grow it however we can. The samples are true-pixel sprites, which the pipeline programmatically corrupts (noise, blur, sub-pixel translation, downsampling, etc.) to produce the input images; the original true-pixel sprites serve as the training targets. Quite a few hyperparameters are hand-tuned so that the corrupted samples look like downsampled versions of AI-generated pixel art. We train for 200 epochs and keep the checkpoint with the lowest validation loss. Data augmentation is used heavily: random crops plus hue, saturation, and similar color jitter. We do not use rotations yet.
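The corruption pass itself is not listed in this note, so here is a minimal sketch of what one might look like. The blur/shift/noise steps and all the magnitudes below are made-up stand-ins, not our tuned values:

```python
import torch
import torch.nn.functional as F

def corrupt(sprite_rgb: torch.Tensor, seed: int = 0) -> torch.Tensor:
    """Hypothetical corruption pass: blur + sub-pixel shift + noise.
    sprite_rgb is (1, 3, H, W) in [0, 1]; every knob here is illustrative."""
    g = torch.Generator().manual_seed(seed)
    x = sprite_rgb
    # 1) Mild blur: average-pool then upsample (a stand-in for a real Gaussian).
    x = F.avg_pool2d(x, kernel_size=2)
    x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
    # 2) Sub-pixel translation via a slightly shifted affine sampling grid.
    theta = torch.tensor([[[1.0, 0.0, 0.3 / x.shape[-1]],
                           [0.0, 1.0, 0.3 / x.shape[-2]]]])
    grid = F.affine_grid(theta, list(x.shape), align_corners=False)
    x = F.grid_sample(x, grid, mode="bilinear", padding_mode="border",
                      align_corners=False)
    # 3) Additive Gaussian noise.
    x = x + 0.02 * torch.randn(x.shape, generator=g)
    return x.clamp(0.0, 1.0)
```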
Loss Function: A lot of effort went into designing the loss function. It is a weighted L1 loss with focal-style weighting that concentrates on the hard regions (silhouette boundary, high-contrast areas, near-white pixels). Let’s review what exactly happens there.
```python
import torch
import torch.nn.functional as F

def compute_weighted_loss(output, target, alpha_weight=1.0):
    rgb_diff = torch.abs(output[:, :3] - target[:, :3])
    alpha_mask = target[:, 3:4]
    alpha_binary = (alpha_mask > 0.5).float()
    # Dilate the opaque mask by one pixel so the boundary band is covered.
    dilated_mask = F.max_pool2d(alpha_binary, kernel_size=3, stride=1, padding=1)
    target_rgb = target[:, :3]
    # Local contrast from forward/backward differences in x and y.
    grad_left = torch.abs(target_rgb[:, :, :, 1:] - target_rgb[:, :, :, :-1])
    grad_right = torch.abs(target_rgb[:, :, :, :-1] - target_rgb[:, :, :, 1:])
    grad_top = torch.abs(target_rgb[:, :, 1:, :] - target_rgb[:, :, :-1, :])
    grad_bottom = torch.abs(target_rgb[:, :, :-1, :] - target_rgb[:, :, 1:, :])
    grad_left = F.pad(grad_left, (0, 1, 0, 0))
    grad_right = F.pad(grad_right, (1, 0, 0, 0))
    grad_top = F.pad(grad_top, (0, 0, 0, 1))
    grad_bottom = F.pad(grad_bottom, (0, 0, 1, 0))
    gradient_magnitude = (grad_left + grad_right + grad_top + grad_bottom).mean(dim=1, keepdim=True)
    grad_normalized = gradient_magnitude / (gradient_magnitude.mean() + 1e-8)
    contrast_weight = 1.0 + 2.0 * grad_normalized
    combined_weight = dilated_mask * contrast_weight
    # RGB loss is high when a pixel is opaque or adjacent to an opaque
    # pixel; the higher the local contrast, the higher the loss.
    rgb_loss = (rgb_diff * combined_weight).mean()
    whiteness = target[:, :3].mean(dim=1, keepdim=True)
    white_threshold = 240.0 / 255.0
    white_boost = torch.clamp((whiteness - white_threshold) / (1.0 - white_threshold), 0, 1)
    white_weight = 1.0 + 1.0 * white_boost
    alpha_diff = torch.abs(output[:, 3:4] - target[:, 3:4])
    # Alpha loss is boosted where the target color is near-white ("white
    # boost"): inference inputs often have white clothes on a white
    # background. Retro sprites love pure white. LOL.
    alpha_loss = (alpha_diff * white_weight).mean()
    # Weighted sum of the two losses.
    loss = rgb_loss + alpha_weight * alpha_loss
    return loss, rgb_loss, alpha_loss
```
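To make the dilation step concrete, here is a tiny standalone demo (toy tensors, not our data): a single opaque pixel's weight region grows to its 3x3 neighborhood, so RGB error just outside the silhouette is still penalized.

```python
import torch
import torch.nn.functional as F

# One opaque pixel at the center of a 5x5 alpha mask.
alpha = torch.zeros(1, 1, 5, 5)
alpha[0, 0, 2, 2] = 1.0
# Same max-pool dilation as in compute_weighted_loss: the weight mask
# expands from 1 pixel to the full 3x3 neighborhood around it.
dilated = F.max_pool2d(alpha, kernel_size=3, stride=1, padding=1)
print(int(alpha.sum()), int(dilated.sum()))  # 1 9
```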
Metrics
After training for 200 epochs, the smallest validation loss was 0.215206, at epoch 194.
Success Metric
If “success” means SpriteDX can reliably turn corrupted RGB into clean pixel-art RGBA (crisp silhouette + stable colors), then a single scalar “accuracy” should reward (1) a correct alpha mask, (2) correct colors where opaque, and (3) extra correctness on the boundary band (where most artifacts show up).
A good, debuggable metric bundle we could use:
Components:
- Mask IoU (or F1) between the predicted alpha mask and the target alpha mask.
- Opaque color accuracy: among pixels where the target is opaque, the fraction whose RGB error is within a small tolerance (e.g. <= 2/255 per channel, i.e. L∞ <= 2/255).
- Edge-band accuracy: same as (2), but on a thin band around the silhouette (dilated − eroded). Use a stricter tolerance, because edge errors are very visible.
- Background suppression accuracy (optional but useful): among pixels where the target is transparent, the fraction where the predicted alpha is below a small threshold (e.g. < 0.05).

Then combine them into one weighted score so you can track a single number while still seeing why it moved.
```python
def compute_accuracy(
    pred_rgba: torch.Tensor,
    tgt_rgba: torch.Tensor,
):
    pred_rgb = pred_rgba[:, :3].clamp(0.0, 1.0)
    pred_a = pred_rgba[:, 3:4].clamp(0.0, 1.0)
    tgt_rgb = tgt_rgba[:, :3].clamp(0.0, 1.0)
    tgt_a = tgt_rgba[:, 3:4].clamp(0.0, 1.0)
    tgt_mask = (tgt_a > 0.5).float()
    pred_mask = (pred_a > 0.5).float()
    # IoU avoids "predict all zeros and look good."
    inter = (tgt_mask * pred_mask).sum(dim=(2, 3))
    union = (tgt_mask + pred_mask - tgt_mask * pred_mask).sum(dim=(2, 3)).clamp_min(1e-8)
    iou = inter / union
    # Dice (overall mass test)
    pred_sum = pred_mask.sum(dim=(2, 3))
    tgt_sum = tgt_mask.sum(dim=(2, 3))
    dice = 2 * inter / (pred_sum + tgt_sum).clamp_min(1e-8)
    # RGB L-infinity error (caveat: this one is very, very noisy)
    rgb_tol = 2.0 / 255.0
    rgb_err = (pred_rgb - tgt_rgb).abs().amax(dim=1, keepdim=True)
    opaque = (tgt_a > 0.5).float()
    opaque_ok = ((rgb_err <= rgb_tol).float() * opaque).sum(dim=(2, 3))
    opaque_den = opaque.sum(dim=(2, 3)).clamp_min(1e-8)
    opaque_rgb_acc = opaque_ok / opaque_den
    # Edge band (caveat: this also measures opaque areas)
    edge_rgb_tol = 1.0 / 255.0
    edge_band_radius = 1
    k = 2 * edge_band_radius + 1
    dil = F.max_pool2d(tgt_mask, kernel_size=k, stride=1, padding=edge_band_radius)
    ero = -F.max_pool2d(-tgt_mask, kernel_size=k, stride=1, padding=edge_band_radius)
    edge_band = (dil - ero).clamp(0.0, 1.0)
    edge_ok = ((rgb_err <= edge_rgb_tol).float() * edge_band).sum(dim=(2, 3))
    edge_den = edge_band.sum(dim=(2, 3)).clamp_min(1e-8)
    edge_rgb_acc = edge_ok / edge_den
    # Background suppression
    bg_alpha_tol = 0.05
    transparent = (tgt_a <= 0.5).float()
    bg_ok = ((pred_a <= bg_alpha_tol).float() * transparent).sum(dim=(2, 3))
    bg_den = transparent.sum(dim=(2, 3)).clamp_min(1e-8)
    bg_alpha_acc = bg_ok / bg_den
    accuracy = (
        0.35 * iou +
        0.25 * edge_rgb_acc +
        0.25 * opaque_rgb_acc +
        0.15 * bg_alpha_acc
    )
    # Return the components too, so we can see why the scalar moved.
    return {
        "accuracy": accuracy.mean(),
        "iou": iou.mean(),
        "dice": dice.mean(),
        "opaque_rgb_acc": opaque_rgb_acc.mean(),
        "edge_rgb_acc": edge_rgb_acc.mean(),
        "bg_alpha_acc": bg_alpha_acc.mean(),
    }
```
This design is insightful but I want to come up with something on my own.
Related Papers & Tools
ToonOut: https://arxiv.org/pdf/2509.06839
BiRefNet: https://arxiv.org/pdf/2401.03407
PhotoRoom: https://www.photoroom.com/tools/background-remover
Notes
PhotoRoom provides a low-cost, high-quality option for making images transparent: a 1024×1024 spritesheet costs about 2 cents, at much better quality than Flux.1 Kontext (~10 cents). I think we should adopt it.
PhotoRoom also fits into the picture of dataset generation.
BiRefNet mentions “non-salient features in image objects can be well reflected by obtaining gradient features through derivative operations on the original image.”
It also mentions, “when certain positions exhibit high similarity in color and texture to the background, the gradient features are probably too weak.” In these cases, they introduce ground-truth features for side supervision.
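To make the idea concrete, here is a quick sketch of "gradient features through derivative operations"; the choice of Sobel kernels is my assumption for illustration, not necessarily what BiRefNet uses.

```python
import torch
import torch.nn.functional as F

# Sobel kernels as the derivative operators (assumed, for illustration).
sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]]).view(1, 1, 3, 3)
sobel_y = sobel_x.transpose(2, 3)

def gradient_features(gray: torch.Tensor) -> torch.Tensor:
    """gray: (B, 1, H, W) -> per-pixel gradient magnitude (B, 1, H, W)."""
    gx = F.conv2d(gray, sobel_x, padding=1)
    gy = F.conv2d(gray, sobel_y, padding=1)
    return torch.sqrt(gx * gx + gy * gy + 1e-12)

# A hard vertical edge gives a strong response on the boundary column and a
# near-zero response in flat regions - the "weak gradient" failure mode the
# paper warns about when foreground color matches the background.
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0  # right half white
g = gradient_features(img)
print(round(float(g[0, 0, 4, 4]), 2))  # 4.0 (on the edge)
```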
I think reading and analyzing BiRefNet will provide important insights. Let’s focus on that.
I’ll read it first, then come back to this task.
— Sprited Dev 🐛



