# -*- coding: utf-8 -*-
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ========== Reproducibility ==========
def set_seed(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
set_seed(0)
device = torch.device("cpu")
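# This demo pins everything to CPU. A common alternative (a sketch, not used
# here): device = torch.device("cuda" if torch.cuda.is_available() else "cpu").
# If you do move to GPU, also seed it with torch.cuda.manual_seed_all(0).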
# ========== Model: Small LeNet ==========
class LeNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5) # 28 -> 24
self.pool = nn.MaxPool2d(2, 2) # 24 -> 12
self.conv2 = nn.Conv2d(6, 16, 5) # 12 -> 8
# 8 -> 4 after pool
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool(x) # 12x12
x = F.relu(self.conv2(x))
x = self.pool(x) # 4x4
x = torch.flatten(x, 1) # B x (16*4*4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x) # logits
return x
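# The network above has roughly 44k trainable parameters (about 2.6k in the
# conv layers, 41.9k in the fully connected layers), small enough for CPU.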
# ========== Data ==========
transform = transforms.ToTensor() # pixels in [0,1]
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=2, pin_memory=False)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2, pin_memory=False)
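# NOTE: inputs are left unnormalized (raw [0, 1] pixels), so the epsilon
# budgets below act directly in pixel space and clamping adversarial images
# back to [0, 1] keeps them valid.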
# ========== Train / Eval ==========
def train(model, loader, epochs=5, lr=5e-4):
model.train()
opt = torch.optim.Adam(model.parameters(), lr=lr)
for ep in range(epochs):
total, correct, loss_sum = 0, 0, 0.0
for x, y in loader:
x, y = x.to(device), y.to(device)
logits = model(x)
loss = F.cross_entropy(logits, y)
opt.zero_grad()
loss.backward()
opt.step()
loss_sum += loss.item() * x.size(0)
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += x.size(0)
print(f"Epoch {ep+1}/{epochs} - loss={loss_sum/total:.4f} acc={correct/total:.4f}")
@torch.no_grad()
def eval_clean_acc(model, loader):
model.eval()
total, correct = 0, 0
for x, y in loader:
x, y = x.to(device), y.to(device)
logits = model(x)
pred = logits.argmax(dim=1)
total += x.size(0)
correct += (pred == y).sum().item()
return correct / total
# ========== Gradient + FGSM directions ==========
def grad_wrt_x(model, x, y):
    """Return (dL/dx, clean logits) for a batch; safe to call under no_grad."""
    model.eval()
    x = x.clone().detach().to(device)
    x.requires_grad_(True)
    with torch.enable_grad():  # ensure the computation graph is built
        logits = model(x)
        loss = F.cross_entropy(logits, y.to(device))
        model.zero_grad(set_to_none=True)
        loss.backward()
    g = x.grad.detach()
    return g, logits.detach()
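# FGSM recipe used throughout: x_adv = clip(x + eps * d(g), 0, 1), where d(g)
# is the steepest-ascent direction under the chosen norm ball:
#   linf : d = sign(g)             (maximizes <g, d> s.t. ||d||_inf  <= 1)
#   l2   : d = g / ||g||_2         (maximizes <g, d> s.t. ||d||_2    <= 1)
#   spec : d = U @ Vh from SVD(g)  (maximizes <g, d> s.t. ||d||_spec <= 1)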
def dir_linf(g):
return g.sign()
def dir_l2(g, eps=1e-12):
g_flat = g.view(g.size(0), -1)
g_norm = g_flat.norm(p=2, dim=1).view(-1, 1, 1, 1)
return g / (g_norm + eps)
def dir_spec(g):
    # Batched SVD over the 28x28 gradient images; each U @ Vh has all singular
    # values equal to 1, i.e. unit spectral norm.
    B, C, H, W = g.shape
    assert (C, H, W) == (1, 28, 28), "This demo assumes MNIST 1x28x28"
    U, _, Vh = torch.linalg.svd(g[:, 0], full_matrices=False)
    return (U @ Vh).unsqueeze(1)
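# Why U @ Vh: by nuclear/spectral norm duality, max <G, D> over ||D||_spec <= 1
# equals the nuclear norm of G and is attained at D = U @ Vh for G = U S Vh.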
def get_direction(method, g):
if method == "linf":
return dir_linf(g)
elif method == "l2":
return dir_l2(g)
elif method == "spec":
return dir_spec(g)
else:
raise ValueError("Unknown method")
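# A minimal convenience wrapper tying the pieces together (a sketch; the name
# fgsm_attack is ours, and the pipeline below inlines these same three steps):
def fgsm_attack(model, x, y, method, eps):
    g, _ = grad_wrt_x(model, x, y)               # input gradient
    d = get_direction(method, g)                 # norm-dependent direction
    return torch.clamp(x + eps * d, 0.0, 1.0)    # step and stay in pixel range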
# ========== Attack evaluation over a loader for a list of eps ==========
@torch.no_grad()
def eval_attack_grid(model, loader, method, eps_list):
"""
Returns dict with per-eps: acc, mean_maxprob, mean_trueprob, time_sec
"""
model.eval()
eps_list = list(eps_list)
K = len(eps_list)
total = 0
correct = [0 for _ in range(K)]
sum_maxprob = [0.0 for _ in range(K)]
sum_trueprob = [0.0 for _ in range(K)]
times = [0.0 for _ in range(K)]
for x, y in loader:
x, y = x.to(device), y.to(device)
            # Compute the input gradient once per batch; grad_wrt_x re-enables
            # autograd internally, so no set_grad_enabled toggling is needed.
            g, _ = grad_wrt_x(model, x, y)
            d = get_direction(method, g)
for j, eps in enumerate(eps_list):
t0 = time.perf_counter()
x_adv = torch.clamp(x + eps * d, 0.0, 1.0)
logits = model(x_adv)
probs = logits.softmax(dim=1)
pred = probs.argmax(dim=1)
correct[j] += (pred == y).sum().item()
sum_maxprob[j] += probs.max(dim=1).values.sum().item()
sum_trueprob[j] += probs[torch.arange(y.size(0)), y].sum().item()
times[j] += (time.perf_counter() - t0)
total += x.size(0)
out = []
for j, eps in enumerate(eps_list):
out.append({
"eps": float(eps),
"acc": correct[j] / total,
"mean_maxprob": sum_maxprob[j] / total,
"mean_trueprob": sum_trueprob[j] / total,
"time_sec": times[j],
"n_total": total
})
return out
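# Illustrative usage (values taken from the grids defined in __main__ below):
#   stats = eval_attack_grid(model, test_loader, "linf", [0.05, 0.10])
#   stats[0]["acc"]  # robust test accuracy at eps = 0.05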
# ========== Fixed sample picker ==========
@torch.no_grad()
def pick_fixed_samples(model, dataset, k=6, seed=0):
"""
Pick k correctly-classified test samples with fixed seed; returns indices list.
"""
set_seed(seed)
idxs = list(range(len(dataset)))
random.shuffle(idxs)
chosen = []
for idx in idxs:
x, y = dataset[idx]
x_in = x.unsqueeze(0).to(device)
logits = model(x_in)
pred = logits.argmax(dim=1).item()
if pred == y:
chosen.append(idx)
if len(chosen) >= k:
break
return chosen
# ========== Build visualization figure per method ==========
@torch.no_grad()
def visualize_method(
model, dataset, method, eps_list, fixed_indices,
train_stats, test_stats, figsize_scale=2.0
):
"""
Build a big figure:
rows = len(eps_list)
cols = len(fixed_indices) + 1 (last col is metrics summary)
Each cell (sample) shows x_adv at the given eps; last col shows train/test acc, conf, time.
"""
k = len(fixed_indices)
R = len(eps_list)
C = k + 1
fig_w = max(8, int(figsize_scale * C))
fig_h = max(4, int(figsize_scale * R))
fig, axes = plt.subplots(R, C, figsize=(fig_w, fig_h))
if R == 1:
axes = np.expand_dims(axes, axis=0)
if C == 1:
axes = np.expand_dims(axes, axis=1)
# Header titles (top row)
for j, idx in enumerate(fixed_indices):
x0, y0 = dataset[idx]
# show clean label in column title
axes[0, j].set_title(f"Sample {j+1}\nidx={idx}, true={y0}", fontsize=9)
axes[0, -1].set_title("Summary (train/test acc, conf, time)", fontsize=9)
# For each eps row
for r, eps in enumerate(eps_list):
# Left side: adversarial images for fixed samples
for c, idx in enumerate(fixed_indices):
x0, y0 = dataset[idx]
x = x0.unsqueeze(0).to(device)
y = torch.tensor([y0], dtype=torch.long).to(device)
# grad & direction for this single sample
g, _ = grad_wrt_x(model, x, y)
d = get_direction(method, g)
x_adv = torch.clamp(x + eps * d, 0.0, 1.0)
logits = model(x_adv)
probs = logits.softmax(dim=1)
conf, pred = probs.max(dim=1)
ax = axes[r, c]
ax.imshow(x_adv[0, 0].cpu(), cmap="gray", vmin=0, vmax=1)
ax.set_xticks([]); ax.set_yticks([])
ax.set_xlabel(f"ε={eps:.3f}\n{pred.item()} ({conf.item()*100:.1f}%)", fontsize=8)
# Rightmost summary cell
ax_sum = axes[r, -1]
ax_sum.axis("off")
tr = train_stats[r]; te = test_stats[r]
text = (
f"Norm={method.upper()} | ε={eps:.3f}\n"
f"Train acc: {tr['acc']*100:.2f}% (N={tr['n_total']})\n"
f"Test acc: {te['acc']*100:.2f}% (N={te['n_total']})\n"
f"Test mean max prob: {te['mean_maxprob']*100:.1f}%\n"
f"Time (train/test): {tr['time_sec']:.2f}s / {te['time_sec']:.2f}s"
)
ax_sum.text(0.02, 0.5, text, va="center", ha="left", fontsize=9, family="monospace")
fig.suptitle(f"FGSM under {method.upper()} norm | rows: eps, cols: fixed samples + summary", fontsize=12)
fig.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
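    # Headless alternative (hypothetical filename):
    # fig.savefig(f"fgsm_{method}.png", dpi=150)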
# ========== Main Pipeline ==========
if __name__ == "__main__":
set_seed(0)
model = LeNet().to(device)
print("Training LeNet on MNIST (CPU)...")
train(model, train_loader, epochs=5, lr=5e-4)
clean_train_acc = eval_clean_acc(model, train_loader)
clean_test_acc = eval_clean_acc(model, test_loader)
print(f"Clean acc - train={clean_train_acc:.4f}, test={clean_test_acc:.4f}")
# ---- Define epsilon grids per norm ----
eps_grid = {
"linf": [0.05, 0.10, 0.20, 0.30],
"l2": [0.50, 2.00, 3.00, 6.00],
"spec": [0.10, 0.60, 1.50, 2.20],
}
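    # Budgets are norm-specific: for a d = 784-pixel image, ||delta||_2 can be
    # up to sqrt(d) = 28 times ||delta||_inf, so comparable L2 radii are much
    # larger. The values above are illustrative choices, not calibrated to be
    # equivalent across norms.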
# ---- Pick fixed samples (from test set) ----
fixed_indices = pick_fixed_samples(model, test_set, k=6, seed=0)
print("Fixed sample indices (test set):", fixed_indices)
# ---- For each norm: evaluate grid on train/test, then visualize ----
for method, eps_list in eps_grid.items():
print(f"\n=== Evaluating {method.upper()} with eps list: {eps_list} ===")
train_stats = eval_attack_grid(model, train_loader, method, eps_list)
test_stats = eval_attack_grid(model, test_loader, method, eps_list)
# Console summary
print("eps | train_acc | test_acc | test_mean_max_prob | time_train(s) | time_test(s)")
for tr, te in zip(train_stats, test_stats):
print(f"{te['eps']:.3f} | {tr['acc']*100:8.2f}% | {te['acc']*100:7.2f}% | "
f"{te['mean_maxprob']*100:7.2f}% | {tr['time_sec']:.2f} | {te['time_sec']:.2f}")
# Visualization big figure
visualize_method(
model, test_set, method, eps_list, fixed_indices,
train_stats, test_stats, figsize_scale=2.0
)