Thanks! I didn’t realize that we are only using the diagonal of the Fisher Information Matrix.
As a follow-up question, I would like to ask the community what is wrong with my implementation of EWC:
def train_ewc(
    model,
    dataloader,
    fisher_matrices,
    opt_params,
    ewc_weight,
    optimizer,
    criterion,
    device
):
    """Train one epoch of the model using the Elastic Weight Consolidation strategy."""
    model = model.to(device)
    running_loss = 0.0
    data_size = len(dataloader)
    if not optimizer:
        # Default optimizer if one is not provided
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for data in tqdm(dataloader):
        imgs, labels = data
        optimizer.zero_grad()
        output = model(imgs.to(device))
        loss = criterion(output, labels.to(device))
        # Regularize the loss with the diagonal Fisher Information Matrix penalty
        for name, param in model.named_parameters():
            fisher = fisher_matrices[name].to(device)
            opt_param = opt_params[name].to(device)
            loss += (fisher * (opt_param - param).pow(2)).sum() * ewc_weight
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / data_size
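If it helps, the penalty that the inner loop above is meant to compute is (as I understand it) the standard EWC regularizer from Kirkpatrick et al. (2017), where $F_i$ are the diagonal Fisher entries in fisher_matrices, $\theta^*_i$ are the stored opt_params, and $\lambda$ is my ewc_weight (I fold the paper's factor of 1/2 into $\lambda$):

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda \sum_i F_i \left(\theta_i - \theta^*_i\right)^2$$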
def ewc_update(
    model, dataloader,
    criterion=torch.nn.CrossEntropyLoss(),
    device=torch.device("cpu")
):
    """Estimate the diagonal Fisher Information Matrix and store the current parameters."""
    model = model.to(device)
    fisher_matrices = {}
    opt_params = {}
    for name, param in model.named_parameters():
        fisher_matrices[name] = torch.zeros(param.data.size())
    model.eval()
    # Accumulate gradients over the dataset
    for data in dataloader:
        model.zero_grad()
        imgs, labels = data
        output = model(imgs.to(device))
        loss = criterion(output, labels.to(device))
        loss.backward()
        # The accumulated gradients are used to estimate the Fisher Information Matrix (FIM).
        # We only keep the diagonal of the FIM, which is just the squared gradients.
        for name, param in model.named_parameters():
            opt_params[name] = param.data.clone().cpu()
            fisher_matrices[name] += param.grad.data.clone().pow(2).cpu() / len(dataloader)
    return fisher_matrices, opt_params
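For context, the diagonal Fisher approximation I am relying on here is the empirical Fisher: the squared per-parameter gradient of the log-likelihood, averaged over the data (in my code, over the batches of the dataloader, with the cross-entropy loss on the true labels standing in for the negative log-likelihood):

$$F_i \approx \frac{1}{N} \sum_{n=1}^{N} \left( \frac{\partial \log p(y_n \mid x_n, \theta)}{\partial \theta_i} \right)^2$$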
The whole training loop looks like this:
for task in range(trainset.num_tasks()):
    tqdm.write(f"Training on task {trainset.get_current_task()}")
    trainloader = DataLoader(
        trainset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4
    )
    # Update the Fisher dict and the optimal-parameter dict
    fisher_matrices, opt_params = ewc_update(
        model, trainloader,
        criterion=criterion,
        device=device)
    # Train with the EWC-regularized loss
    for epoch in tqdm(range(config['epochs'])):
        loss = train_ewc(
            model, trainloader,
            fisher_matrices=fisher_matrices,
            opt_params=opt_params,
            ewc_weight=config['ewc_weight'],
            optimizer=optimizer,
            criterion=criterion,
            device=device)
    # Evaluate the error rate on the current and all previous tasks
    for eval_task in range(trainset.get_current_task() + 1):
        evalloader = DataLoader(
            evalset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=4
        )
        vloss, verror = validate(model, evalloader, criterion=criterion, device=device)
        tqdm.write(f"Evaluated task {eval_task}")
        tqdm.write(
            f"Training loss: {loss: .3f}, Validation loss: {vloss: .3f}, "
            f"Validation error: {verror: .3f}")
        evalset.next_task()
    # Progress to the next task
    trainset.next_task()
    evalset.restart()
With the code above, my class-incremental scenario still forgets catastrophically after training on subsequent tasks (the error rate on earlier tasks goes to 1). I apologize for the code dump, but I am really not sure where the bug is.