Note
Click here to download the full example code
Example of robust training on CIFAR10.

Out:
Files already downloaded and verified
Files already downloaded and verified
Training on Linf ball(0.03137254901960784).
Train Accuracy: 32.6%
Train Adv Accuracy: 22.9%
Test Accuracy: 30.7%
Test Adv Accuracy: 21.2%
Train Accuracy: 46.3%
Train Adv Accuracy: 27.6%
Test Accuracy: 44.9%
Test Adv Accuracy: 27.0%
Train Accuracy: 54.6%
Train Adv Accuracy: 29.2%
Test Accuracy: 38.1%
Test Adv Accuracy: 27.2%
Train Accuracy: 59.5%
Train Adv Accuracy: 30.7%
Test Accuracy: 53.8%
Test Adv Accuracy: 27.7%
Train Accuracy: 63.4%
Train Adv Accuracy: 31.9%
Test Accuracy: 52.7%
Test Adv Accuracy: 26.3%
Train Accuracy: 65.8%
Train Adv Accuracy: 31.9%
Test Accuracy: 44.2%
Test Adv Accuracy: 28.0%
Train Accuracy: 67.4%
Train Adv Accuracy: 32.4%
Test Accuracy: 57.8%
Test Adv Accuracy: 32.5%
Train Accuracy: 69.1%
Train Adv Accuracy: 32.1%
Test Accuracy: 59.3%
Test Adv Accuracy: 26.5%
Train Accuracy: 70.2%
Train Adv Accuracy: 33.0%
Test Accuracy: 58.7%
Test Adv Accuracy: 29.1%
Train Accuracy: 71.2%
Train Adv Accuracy: 33.3%
Test Accuracy: 62.5%
Test Adv Accuracy: 31.9%
Train Accuracy: 71.6%
Train Adv Accuracy: 33.1%
Test Accuracy: 60.6%
Test Adv Accuracy: 28.7%
Train Accuracy: 72.2%
Train Adv Accuracy: 33.3%
Test Accuracy: 60.7%
Test Adv Accuracy: 30.0%
Train Accuracy: 72.6%
Train Adv Accuracy: 33.1%
Test Accuracy: 66.1%
Test Adv Accuracy: 23.6%
Train Accuracy: 73.2%
Train Adv Accuracy: 33.6%
Test Accuracy: 60.9%
Test Adv Accuracy: 28.9%
Train Accuracy: 73.6%
Train Adv Accuracy: 34.0%
Test Accuracy: 60.7%
Test Adv Accuracy: 31.1%
Train Accuracy: 74.0%
Train Adv Accuracy: 34.0%
Test Accuracy: 63.8%
Test Adv Accuracy: 27.6%
Train Accuracy: 74.4%
Train Adv Accuracy: 34.1%
Test Accuracy: 63.6%
Test Adv Accuracy: 28.6%
Train Accuracy: 74.7%
Train Adv Accuracy: 33.8%
Test Accuracy: 62.9%
Test Adv Accuracy: 26.4%
Train Accuracy: 74.9%
Train Adv Accuracy: 33.9%
Test Accuracy: 61.2%
Test Adv Accuracy: 28.0%
Train Accuracy: 75.1%
Train Adv Accuracy: 34.2%
Test Accuracy: 61.8%
Test Adv Accuracy: 28.1%
Train Accuracy: 75.0%
Train Adv Accuracy: 34.1%
Test Accuracy: 63.5%
Test Adv Accuracy: 32.7%
Train Accuracy: 75.4%
Train Adv Accuracy: 34.2%
Test Accuracy: 62.5%
Test Adv Accuracy: 28.4%
Train Accuracy: 75.3%
Train Adv Accuracy: 34.1%
Test Accuracy: 62.2%
Test Adv Accuracy: 30.5%
Train Accuracy: 75.8%
Train Adv Accuracy: 34.2%
Test Accuracy: 56.4%
Test Adv Accuracy: 29.4%
Train Accuracy: 76.0%
Train Adv Accuracy: 33.9%
Test Accuracy: 62.2%
Test Adv Accuracy: 27.7%
Train Accuracy: 76.0%
Train Adv Accuracy: 34.6%
Test Accuracy: 63.8%
Test Adv Accuracy: 28.6%
Train Accuracy: 76.1%
Train Adv Accuracy: 34.3%
Test Accuracy: 58.4%
Test Adv Accuracy: 30.3%
Train Accuracy: 76.1%
Train Adv Accuracy: 34.6%
Test Accuracy: 66.1%
Test Adv Accuracy: 30.0%
Train Accuracy: 75.9%
Train Adv Accuracy: 34.6%
Test Accuracy: 64.3%
Test Adv Accuracy: 31.4%
Train Accuracy: 76.3%
Train Adv Accuracy: 34.3%
Test Accuracy: 60.1%
Test Adv Accuracy: 31.6%
Train Accuracy: 76.4%
Train Adv Accuracy: 34.4%
Test Accuracy: 51.9%
Test Adv Accuracy: 31.9%
Train Accuracy: 76.6%
Train Adv Accuracy: 34.8%
Test Accuracy: 65.5%
Test Adv Accuracy: 33.0%
Train Accuracy: 76.5%
Train Adv Accuracy: 34.5%
Test Accuracy: 56.4%
Test Adv Accuracy: 31.8%
Train Accuracy: 76.5%
Train Adv Accuracy: 34.2%
Test Accuracy: 65.2%
Test Adv Accuracy: 29.7%
Train Accuracy: 76.6%
Train Adv Accuracy: 34.6%
Test Accuracy: 62.9%
Test Adv Accuracy: 31.8%
Train Accuracy: 76.5%
Train Adv Accuracy: 34.5%
Test Accuracy: 58.2%
Test Adv Accuracy: 25.9%
Train Accuracy: 76.4%
Train Adv Accuracy: 34.7%
Test Accuracy: 63.8%
Test Adv Accuracy: 29.7%
Train Accuracy: 77.0%
Train Adv Accuracy: 34.7%
Test Accuracy: 57.9%
Test Adv Accuracy: 30.2%
Train Accuracy: 77.0%
Train Adv Accuracy: 34.0%
Test Accuracy: 64.6%
Test Adv Accuracy: 29.4%
Train Accuracy: 76.7%
Train Adv Accuracy: 34.5%
Test Accuracy: 62.4%
Test Adv Accuracy: 28.2%
Train Accuracy: 76.8%
Train Adv Accuracy: 34.7%
Test Accuracy: 59.6%
Test Adv Accuracy: 26.4%
Train Accuracy: 77.3%
Train Adv Accuracy: 34.7%
Test Accuracy: 64.6%
Test Adv Accuracy: 27.7%
Train Accuracy: 77.0%
Train Adv Accuracy: 35.0%
Test Accuracy: 57.5%
Test Adv Accuracy: 30.7%
Train Accuracy: 76.8%
Train Adv Accuracy: 34.8%
Test Accuracy: 65.8%
Test Adv Accuracy: 32.9%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.1%
Test Accuracy: 61.3%
Test Adv Accuracy: 27.5%
Train Accuracy: 77.4%
Train Adv Accuracy: 34.6%
Test Accuracy: 62.4%
Test Adv Accuracy: 31.8%
Train Accuracy: 77.0%
Train Adv Accuracy: 34.6%
Test Accuracy: 57.0%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.2%
Train Adv Accuracy: 35.2%
Test Accuracy: 57.7%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.2%
Train Adv Accuracy: 35.0%
Test Accuracy: 65.6%
Test Adv Accuracy: 24.2%
Train Accuracy: 77.6%
Train Adv Accuracy: 34.8%
Test Accuracy: 63.5%
Test Adv Accuracy: 30.3%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 63.7%
Test Adv Accuracy: 30.8%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.0%
Test Accuracy: 63.5%
Test Adv Accuracy: 30.8%
Train Accuracy: 77.4%
Train Adv Accuracy: 35.1%
Test Accuracy: 59.3%
Test Adv Accuracy: 32.5%
Train Accuracy: 77.4%
Train Adv Accuracy: 34.8%
Test Accuracy: 56.7%
Test Adv Accuracy: 30.5%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.3%
Test Accuracy: 67.7%
Test Adv Accuracy: 32.0%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 61.4%
Test Adv Accuracy: 32.5%
Train Accuracy: 77.6%
Train Adv Accuracy: 34.9%
Test Accuracy: 63.1%
Test Adv Accuracy: 31.6%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.0%
Test Accuracy: 63.5%
Test Adv Accuracy: 29.2%
Train Accuracy: 77.2%
Train Adv Accuracy: 35.5%
Test Accuracy: 61.5%
Test Adv Accuracy: 33.4%
Train Accuracy: 77.6%
Train Adv Accuracy: 35.2%
Test Accuracy: 66.1%
Test Adv Accuracy: 27.1%
Train Accuracy: 77.6%
Train Adv Accuracy: 34.9%
Test Accuracy: 63.4%
Test Adv Accuracy: 31.2%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 62.4%
Test Adv Accuracy: 31.3%
Train Accuracy: 77.3%
Train Adv Accuracy: 35.0%
Test Accuracy: 65.2%
Test Adv Accuracy: 28.5%
Train Accuracy: 77.9%
Train Adv Accuracy: 34.8%
Test Accuracy: 63.9%
Test Adv Accuracy: 27.9%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.1%
Test Accuracy: 63.7%
Test Adv Accuracy: 28.0%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.5%
Test Accuracy: 58.6%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.1%
Test Accuracy: 63.4%
Test Adv Accuracy: 31.5%
Train Accuracy: 77.9%
Train Adv Accuracy: 34.8%
Test Accuracy: 57.1%
Test Adv Accuracy: 31.2%
Train Accuracy: 77.7%
Train Adv Accuracy: 35.2%
Test Accuracy: 62.8%
Test Adv Accuracy: 31.0%
Train Accuracy: 77.1%
Train Adv Accuracy: 34.7%
Test Accuracy: 57.0%
Test Adv Accuracy: 32.3%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.1%
Test Accuracy: 64.8%
Test Adv Accuracy: 29.7%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.3%
Test Accuracy: 65.1%
Test Adv Accuracy: 29.5%
Train Accuracy: 78.0%
Train Adv Accuracy: 34.9%
Test Accuracy: 64.1%
Test Adv Accuracy: 31.1%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.7%
Test Accuracy: 58.1%
Test Adv Accuracy: 33.8%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.2%
Test Accuracy: 66.6%
Test Adv Accuracy: 30.8%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.4%
Test Accuracy: 57.5%
Test Adv Accuracy: 27.9%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.1%
Test Accuracy: 60.9%
Test Adv Accuracy: 29.9%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.1%
Test Accuracy: 64.3%
Test Adv Accuracy: 26.8%
Train Accuracy: 77.8%
Train Adv Accuracy: 34.9%
Test Accuracy: 54.2%
Test Adv Accuracy: 30.5%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.0%
Test Accuracy: 54.6%
Test Adv Accuracy: 31.0%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.2%
Test Accuracy: 62.4%
Test Adv Accuracy: 31.7%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.4%
Test Accuracy: 61.1%
Test Adv Accuracy: 34.1%
Train Accuracy: 78.1%
Train Adv Accuracy: 34.8%
Test Accuracy: 51.0%
Test Adv Accuracy: 31.5%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.0%
Test Accuracy: 64.0%
Test Adv Accuracy: 26.5%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.1%
Test Accuracy: 54.2%
Test Adv Accuracy: 33.4%
Train Accuracy: 77.6%
Train Adv Accuracy: 35.3%
Test Accuracy: 63.6%
Test Adv Accuracy: 33.1%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.4%
Test Accuracy: 65.6%
Test Adv Accuracy: 29.6%
Train Accuracy: 78.0%
Train Adv Accuracy: 35.1%
Test Accuracy: 65.7%
Test Adv Accuracy: 27.6%
Train Accuracy: 77.8%
Train Adv Accuracy: 35.5%
Test Accuracy: 57.6%
Test Adv Accuracy: 32.7%
Train Accuracy: 78.0%
Train Adv Accuracy: 35.0%
Test Accuracy: 50.4%
Test Adv Accuracy: 30.7%
Train Accuracy: 78.3%
Train Adv Accuracy: 35.0%
Test Accuracy: 59.2%
Test Adv Accuracy: 27.7%
Train Accuracy: 78.0%
Train Adv Accuracy: 34.5%
Test Accuracy: 59.7%
Test Adv Accuracy: 32.5%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.4%
Test Accuracy: 57.2%
Test Adv Accuracy: 29.3%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.3%
Test Accuracy: 53.1%
Test Adv Accuracy: 30.8%
Train Accuracy: 77.7%
Train Adv Accuracy: 34.9%
Test Accuracy: 59.7%
Test Adv Accuracy: 29.9%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.3%
Test Accuracy: 56.8%
Test Adv Accuracy: 28.0%
Train Accuracy: 77.9%
Train Adv Accuracy: 35.0%
Test Accuracy: 57.8%
Test Adv Accuracy: 31.8%
Train Accuracy: 78.1%
Train Adv Accuracy: 35.0%
Test Accuracy: 55.6%
Test Adv Accuracy: 30.0%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.1%
Test Accuracy: 63.7%
Test Adv Accuracy: 30.3%
Train Accuracy: 78.2%
Train Adv Accuracy: 35.6%
Test Accuracy: 63.1%
Test Adv Accuracy: 35.0%
import matplotlib.pyplot as plt
from chop.adversary import Adversary
import torch
from tqdm import tqdm
from easydict import EasyDict
import chop
from torch.optim import SGD
from torchvision import models
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'cpu')

n_epochs = 100
batch_size = 128
batch_size_test = 100

loaders = chop.data.load_cifar10(train_batch_size=batch_size,
                                 test_batch_size=batch_size_test,
                                 data_dir='~/datasets',
                                 augment_train=True)
trainloader, testloader = loaders.train, loaders.test
n_train = len(trainloader.dataset)
n_test = len(testloader.dataset)

# CIFAR-10 has 10 classes. Without num_classes the torchvision ResNet-18
# head emits 1000 logits, 990 of which can never match a label.
model = models.resnet18(pretrained=False, num_classes=10)
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4)
# T_max counts scheduler.step() calls. Tie it to n_epochs so that a
# once-per-epoch step anneals the learning rate over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=n_epochs)

# Define the perturbation constraint set
max_iter_train = 7  # attack iterations per training batch
max_iter_test = 20  # stronger attack (more iterations) at evaluation
# Radius of the Linf ball. Assumes images are scaled to [0, 1].
alpha = 8. / 255
constraint = chop.constraints.LinfBall(alpha)
# Per-sample losses (reduction='none'), presumably as expected by
# Adversary.perturb -- TODO confirm against chop's documentation.
criterion_adv = torch.nn.CrossEntropyLoss(reduction='none')

print(f"Training on L{constraint.p} ball({alpha}).")

# The adversary searches for perturbations with chop's Madry-style PGD.
adversary = Adversary(chop.optim.minimize_pgd_madry)

# Per-epoch metrics, filled in by the training loop.
results = EasyDict(train_acc=[], test_acc=[],
                   train_acc_adv=[], test_acc_adv=[],
                   train_adv_loss=[],
                   test_adv_loss=[])
def _make_prox(data):
    """Build the perturbation projection used by the adversary for one batch.

    The returned ``prox`` first applies the constraint set's prox (Linf ball)
    to ``delta``. It then clips ``delta`` so that ``data + delta`` stays
    within [0, 1].
    """
    @torch.no_grad()
    def image_constraint_prox(delta, step_size=None):
        """Projects perturbation delta
        so that 0. <= data + delta <= 1."""
        adv_img = torch.clamp(data + delta, 0, 1)
        return adv_img - data

    @torch.no_grad()
    def prox(delta, step_size=None):
        delta = constraint.prox(delta, step_size)
        return image_constraint_prox(delta, step_size)

    return prox


for _ in range(n_epochs):
    # Train
    n_correct = 0
    n_correct_adv = 0
    adv_loss_sum = 0.
    model.train()
    for data, target in trainloader:
        data = data.to(device)
        target = target.to(device)

        # Inner maximization: search the constraint set for a perturbation
        # that increases the loss of the current model.
        _, delta = adversary.perturb(data, target, model,
                                     criterion_adv,
                                     prox=_make_prox(data),
                                     lmo=constraint.lmo,
                                     step=2. / max_iter_train,
                                     max_iter=max_iter_train)

        optimizer.zero_grad()
        # The clean pass only reports clean accuracy. No graph is needed.
        with torch.no_grad():
            output = model(data)
        output_adv = model(data + delta)
        # Robust (adversarial) training minimizes the loss on the
        # perturbed inputs, not on the clean ones.
        loss = criterion(output_adv, target)
        loss.backward()
        optimizer.step()

        pred = torch.argmax(output, dim=-1)
        pred_adv = torch.argmax(output_adv, dim=-1)
        n_correct += (pred == target).sum().item()
        n_correct_adv += (pred_adv == target).sum().item()
        # criterion averages over the batch. Rescale to a per-sample sum.
        adv_loss_sum += loss.item() * target.size(0)

    results.train_acc.append(100. * n_correct / n_train)
    results.train_acc_adv.append(100. * n_correct_adv / n_train)
    results.train_adv_loss.append(adv_loss_sum / n_train)
    print(f"Train Accuracy: {results.train_acc[-1]:.1f}%")
    print(f"Train Adv Accuracy: {results.train_acc_adv[-1]:.1f}%")

    # Advance the learning-rate schedule once per epoch. Without this call
    # the scheduler never changes the learning rate.
    scheduler.step()

    # Test
    n_correct = 0
    n_correct_adv = 0
    adv_loss_sum = 0.
    model.eval()
    for data, target in testloader:
        data = data.to(device)
        target = target.to(device)

        # The attack runs outside no_grad. It presumably needs input
        # gradients from the model.
        _, delta = adversary.perturb(data, target, model,
                                     criterion_adv,
                                     prox=_make_prox(data),
                                     lmo=constraint.lmo,
                                     step=2. / max_iter_test,
                                     max_iter=max_iter_test)

        with torch.no_grad():
            output = model(data)
            output_adv = model(data + delta)
            adv_loss_sum += criterion(output_adv, target).item() * target.size(0)

        pred = torch.argmax(output, dim=-1)
        pred_adv = torch.argmax(output_adv, dim=-1)
        n_correct += (pred == target).sum().item()
        n_correct_adv += (pred_adv == target).sum().item()

    results.test_acc.append(100. * n_correct / n_test)
    results.test_acc_adv.append(100. * n_correct_adv / n_test)
    results.test_adv_loss.append(adv_loss_sum / n_test)
    print(f"Test Accuracy: {results.test_acc[-1]:.1f}%")
    print(f"Test Adv Accuracy: {results.test_acc_adv[-1]:.1f}%")
# Plot per-epoch clean and adversarial accuracies in two stacked panels.
fig, ax = plt.subplots(nrows=2, sharex=True)

ax[0].set_title("Clean data accuracies")
ax[0].plot(results.train_acc, label='Train Acc')
ax[0].plot(results.test_acc, label='Test Acc')
ax[0].set_ylabel("Accuracy (%)")
# A bare plt.legend() only targets the current axes (the lower panel),
# so each panel gets its own legend.
ax[0].legend()

ax[1].set_title("Adversarial data accuracies")
ax[1].plot(results.train_acc_adv, label='Train Acc Adv')
ax[1].plot(results.test_acc_adv, label='Test Acc Adv')
ax[1].set_ylabel("Accuracy (%)")
ax[1].set_xlabel("Epoch")
ax[1].legend()

plt.show()
Total running time of the script: (624 minutes 28.935 seconds)
Estimated memory usage: 2425 MB