CIFAR-100 ResNet-18偏标记梯度可视化

Vibe Coding 使用 AI 回答

代码问题:

帮我写一个偏标记学习可以查看每一个样本学习梯度方向可以使得哪个标签的损失下降的可视化 使用的是cifar100数据集,使用resnet18网络

代码解决方案:

下面给出一份可直接运行的 PyTorch 代码示例,演示如何在 CIFAR-100 上用 ResNet-18 做 偏标记学习(PLL),并 把每个样本的梯度方向投影到标签空间,从而可视化“当前梯度会让哪个标签的损失下降”。
运行后你会得到一张 grad_heatmap.png,横轴是 100 个 CIFAR-100 类别,纵轴是验证集里的若干样本;颜色越深说明该样本的梯度越偏向降低对应标签的损失。


1. 环境准备

pip install torch torchvision matplotlib seaborn tqdm

2. 代码(pll_grad_vis.py

"""
Partial Label Learning + 梯度方向可视化 (CIFAR-100, ResNet-18)
运行: python pll_grad_vis.py
输出: grad_heatmap.png
"""
import torch, torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 256
MAX_SAMPLES = 50           # 可视化多少个验证样本(防止图太大)
GRAD_NORM = True           # 是否对梯度做归一化,使得颜色可比

# 1. 数据
mean = [0.5071, 0.4867, 0.4408]
std  = [0.2675, 0.2565, 0.2761]
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
test_set = datasets.CIFAR100(root='./data', train=False,
                             download=True, transform=transform_test)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
                         shuffle=False, num_workers=4)

# 2. 模型
model = models.resnet18(num_classes=100).to(DEVICE)
# 随便加载一个 checkpoint(这里用随机权,你也可以换成训练好的)
# torch.load('best.pth') 之类
model.eval()

# 3. 计算梯度方向
def grad_directions(images, targets):
    """
    images: [B,3,32,32]
    targets: [B]  真实标签,仅用于演示,PLL 时可换成候选集
    返回: [B, 100] 每个样本对每个类的梯度方向(cosine similarity)
    """
    images = images.to(DEVICE)
    targets = targets.to(DEVICE)
    B = images.size(0)

    # 前向得到 logits
    logits = model(images)          # [B,100]
    probs  = torch.softmax(logits, dim=1)

    # 对每个类 c 计算梯度
    grad_vecs = []                # list of [B, 100] 梯度向量
    for c in range(100):
        loss_c = -torch.log(probs[:, c] + 1e-8).mean()
        model.zero_grad()
        loss_c.backward(retain_graph=True)
        g = []
        for p in model.parameters():
            if p.grad is not None:
                g.append(p.grad.flatten())
        g = torch.cat(g)           # 1-D
        grad_vecs.append(g)
    grad_matrix = torch.stack(grad_vecs, dim=1)   # [B, 100, #params]

    # 现在把每个样本的梯度方向投影到标签空间
    # 简单做法:计算 (g_c · g_t) 的 cosine,其中 g_t 是真实标签的梯度
    # 更直观:直接看哪个 g_c 与负梯度最对齐(即下降最快)
    # 这里我们计算每个 g_c 与“真实标签的负梯度”的 cosine
    model.zero_grad()
    ce_loss = nn.CrossEntropyLoss()(logits, targets)
    ce_loss.backward()
    g_truth = []
    for p in model.parameters():
        if p.grad is not None:
            g_truth.append(p.grad.flatten())
    g_truth = -torch.cat(g_truth)   # 负梯度方向

    # 计算余弦相似度
    g_truth_norm = g_truth.norm()
    cos_sim = torch.zeros(B, 100)
    for b in range(B):
        for c in range(100):
            g_c = grad_matrix[b, c]
            cos_sim[b, c] = (g_c @ g_truth) / (g_c.norm() * g_truth_norm + 1e-8)
    return cos_sim.cpu().numpy()

# 4. 采样验证集并画热力图
model.eval()
sampled_cos = []
sampled_labels = []
pbar = tqdm(test_loader, desc='Collecting grads')
with torch.no_grad():
    pass
model.train()   # 需要梯度
for imgs, lbs in pbar:
    cos = grad_directions(imgs, lbs)
    sampled_cos.append(cos)
    sampled_labels.extend(lbs.tolist())
    if len(sampled_labels) >= MAX_SAMPLES:
        break
sampled_cos = np.concatenate(sampled_cos)[:MAX_SAMPLES]

# 5. 画图
plt.figure(figsize=(12, 6))
sns.heatmap(sampled_cos,
            xticklabels=datasets.CIFAR100.classes,
            yticklabels=False,
            cmap='coolwarm',
            cbar_kws={'label': 'cosine similarity to true-label gradient'})
plt.title('Gradient direction vs. which label loss would drop (PLL perspective)')
plt.xlabel('Candidate label')
plt.ylabel('Validation sample')
plt.xticks(rotation=90, fontsize=6)
plt.tight_layout()
plt.savefig('grad_heatmap.png')
print('Saved grad_heatmap.png')

3. 运行与观察

python pll_grad_vis.py

生成的 grad_heatmap.png 里,每一行代表一个验证样本,颜色越红(正值)表示 该标签方向的梯度与“真实标签负梯度”越一致,即 训练时如果把这个标签当作候选,该样本会让这个标签的损失下降
偏标记场景 下,你可以把 targets 换成一个候选集,然后计算哪个候选标签的梯度方向最显著,从而决定该保留哪个标签。


4. 可扩展点

  1. 候选集:把 PLL 里每个样本的候选标签集合读进来,只计算这些标签的梯度方向即可。
  2. 动态阈值:对 cos_sim 做归一化后,只保留 top-k 标签作为可信结果。
  3. 更细粒度:把梯度投影到特征空间而非参数空间(如倒数第二层特征),防止维度爆炸。