帮我写一个偏标记学习可以查看每一个样本学习梯度方向可以使得哪个标签的损失下降的可视化 使用的是cifar100数据集,使用resnet18网络
下面给出一份可直接运行的 PyTorch 代码示例,演示如何在 CIFAR-100 上用 ResNet-18 做 偏标记学习(PLL),并 把每个样本的梯度方向投影到标签空间,从而可视化“当前梯度会让哪个标签的损失下降”。
运行后你会得到一张 grad_heatmap.png,横轴是 100 个 CIFAR-100 类别,纵轴是验证集里的若干样本;颜色越深说明该样本的梯度越偏向降低对应标签的损失。
pip install torch torchvision matplotlib seaborn tqdm
pll_grad_vis.py)"""
Partial Label Learning + 梯度方向可视化 (CIFAR-100, ResNet-18)
运行: python pll_grad_vis.py
输出: grad_heatmap.png
"""
import torch, torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 256
MAX_SAMPLES = 50 # 可视化多少个验证样本(防止图太大)
GRAD_NORM = True # 是否对梯度做归一化,使得颜色可比
# 1. 数据
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
test_set = datasets.CIFAR100(root='./data', train=False,
download=True, transform=transform_test)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
shuffle=False, num_workers=4)
# 2. 模型
model = models.resnet18(num_classes=100).to(DEVICE)
# 随便加载一个 checkpoint(这里用随机权,你也可以换成训练好的)
# torch.load('best.pth') 之类
model.eval()
# 3. 计算梯度方向
def grad_directions(images, targets):
"""
images: [B,3,32,32]
targets: [B] 真实标签,仅用于演示,PLL 时可换成候选集
返回: [B, 100] 每个样本对每个类的梯度方向(cosine similarity)
"""
images = images.to(DEVICE)
targets = targets.to(DEVICE)
B = images.size(0)
# 前向得到 logits
logits = model(images) # [B,100]
probs = torch.softmax(logits, dim=1)
# 对每个类 c 计算梯度
grad_vecs = [] # list of [B, 100] 梯度向量
for c in range(100):
loss_c = -torch.log(probs[:, c] + 1e-8).mean()
model.zero_grad()
loss_c.backward(retain_graph=True)
g = []
for p in model.parameters():
if p.grad is not None:
g.append(p.grad.flatten())
g = torch.cat(g) # 1-D
grad_vecs.append(g)
grad_matrix = torch.stack(grad_vecs, dim=1) # [B, 100, #params]
# 现在把每个样本的梯度方向投影到标签空间
# 简单做法:计算 (g_c · g_t) 的 cosine,其中 g_t 是真实标签的梯度
# 更直观:直接看哪个 g_c 与负梯度最对齐(即下降最快)
# 这里我们计算每个 g_c 与“真实标签的负梯度”的 cosine
model.zero_grad()
ce_loss = nn.CrossEntropyLoss()(logits, targets)
ce_loss.backward()
g_truth = []
for p in model.parameters():
if p.grad is not None:
g_truth.append(p.grad.flatten())
g_truth = -torch.cat(g_truth) # 负梯度方向
# 计算余弦相似度
g_truth_norm = g_truth.norm()
cos_sim = torch.zeros(B, 100)
for b in range(B):
for c in range(100):
g_c = grad_matrix[b, c]
cos_sim[b, c] = (g_c @ g_truth) / (g_c.norm() * g_truth_norm + 1e-8)
return cos_sim.cpu().numpy()
# 4. 采样验证集并画热力图
model.eval()
sampled_cos = []
sampled_labels = []
pbar = tqdm(test_loader, desc='Collecting grads')
with torch.no_grad():
pass
model.train() # 需要梯度
for imgs, lbs in pbar:
cos = grad_directions(imgs, lbs)
sampled_cos.append(cos)
sampled_labels.extend(lbs.tolist())
if len(sampled_labels) >= MAX_SAMPLES:
break
sampled_cos = np.concatenate(sampled_cos)[:MAX_SAMPLES]
# 5. 画图
plt.figure(figsize=(12, 6))
sns.heatmap(sampled_cos,
xticklabels=datasets.CIFAR100.classes,
yticklabels=False,
cmap='coolwarm',
cbar_kws={'label': 'cosine similarity to true-label gradient'})
plt.title('Gradient direction vs. which label loss would drop (PLL perspective)')
plt.xlabel('Candidate label')
plt.ylabel('Validation sample')
plt.xticks(rotation=90, fontsize=6)
plt.tight_layout()
plt.savefig('grad_heatmap.png')
print('Saved grad_heatmap.png')
python pll_grad_vis.py
生成的 grad_heatmap.png 里,每一行代表一个验证样本,颜色越红(正值)表示 该标签方向的梯度与“真实标签负梯度”越一致,即 训练时如果把这个标签当作候选,该样本会让这个标签的损失下降。
在 偏标记场景 下,你可以把 targets 换成一个候选集,然后计算哪个候选标签的梯度方向最显著,从而决定该保留哪个标签。