Computing the Fisher Information Matrix in Elastic Weight Consolidation

Hi,

I have seen many PyTorch implementations of EWC that compute the Fisher Information Matrix by taking the gradient of the loss and squaring it. Why is this the case?

# Accumulated gradients can be used to calculate the Fisher Information Matrix
for name, param in model.named_parameters():
    opt_params[name] = param.data.clone().cpu()
    fisher_matrices[name] += param.grad.data.clone().pow(2).cpu() / len(dataloader)

I have read the original paper and also [2105.04093] Elastic Weight Consolidation (EWC): Nuts and Bolts (arxiv.org), but I am still unsure why the Fisher (which stands in for the Hessian here) is computed this way.

EWC uses only the diagonal of an approximation to the Fisher Information Matrix. That chain of approximations leads to each diagonal entry being estimated as the squared gradient, averaged over the mini-batches in a single pass over the training set. The paper does not spell out all the steps, but you can back-track from the details it gives about the approximations used.
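Concretely, here is my own reconstruction of that chain (the notation is mine, not the paper's). The Fisher is the expected outer product of the score; keeping only its diagonal and replacing the expectation with a sample average over the training data gives exactly the squared-gradient estimate in your snippet:

$$F = \mathbb{E}\!\left[\nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^{\top}\right], \qquad F_i = \mathbb{E}\!\left[\left(\frac{\partial \log p_\theta(y \mid x)}{\partial \theta_i}\right)^{2}\right] \approx \frac{1}{N} \sum_{n=1}^{N} \left(\frac{\partial \log p_\theta(y_n \mid x_n)}{\partial \theta_i}\right)^{2}$$

Since cross-entropy is the negative log-likelihood, the gradient PyTorch gives you differs from the score only in sign, and squaring removes the sign, which is why param.grad.pow(2) averaged over the batches shows up in these implementations. (Strictly, the expectation should be over labels drawn from the model's own predictions; using the dataset labels instead gives the so-called empirical Fisher, which is what most EWC code computes.)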

Thanks! I didn’t realize that we are only using the diagonal of the Fisher Information Matrix.

As a follow-up question, I wonder if I can ask the community what is wrong with my implementation of EWC, shown below.
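For reference, the regularized loss I am trying to reproduce is the one from the original paper (my transcription; $\theta^{*}_{A}$ are the parameters saved after the previous task, $F_i$ is the Fisher diagonal, and in my code the factor of $\tfrac{1}{2}$ is folded into ewc_weight):

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2}\, F_i \left(\theta_i - \theta^{*}_{A,i}\right)^{2}$$

My training and Fisher-estimation code follows: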

import torch
from tqdm import tqdm


def train_ewc(
        model,
        dataloader, 
        fisher_matrices,
        opt_params,
        ewc_weight,
        optimizer, 
        criterion,
        device
    ):
    """ Train one epoch of the model using Elastic Weight Consolidation strategy.
    """
    model = model.to(device)

    running_loss = 0.0
    data_size = len(dataloader)

    if not optimizer:
        # Default optimizer if one is not provided
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for data in tqdm(dataloader):
        imgs, labels = data

        optimizer.zero_grad()

        output = model(imgs.to(device))
        loss = criterion(output, labels.to(device))

        # Regularize loss with Fisher Information Matrix
        for name, param in model.named_parameters():
            fisher = fisher_matrices[name].to(device)
            opt_param = opt_params[name].to(device)
            loss += (fisher * (opt_param - param).pow(2)).sum() * ewc_weight
        
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    return running_loss / data_size


def ewc_update(
        model, dataloader,
        criterion=torch.nn.CrossEntropyLoss(),
        device=torch.device("cpu")
    ):

    model = model.to(device)

    fisher_matrices = {}
    opt_params = {}
    for name, param in model.named_parameters():
        fisher_matrices[name] = torch.zeros(param.data.size())

    model.eval()
    # accumulating gradients
    for data in dataloader:
        model.zero_grad()
        imgs, labels = data
        output = model(imgs.to(device))
        loss = criterion(output, labels.to(device))
        loss.backward()

        # Accumulated gradients can be used to calculate the Fisher Information Matrix (FIM).
        # We only want the diagonal of the FIM, which is just the squared gradients.
        for name, param in model.named_parameters():
            opt_params[name] = param.data.clone().cpu()
            fisher_matrices[name] += param.grad.data.clone().pow(2).cpu() / len(dataloader)

    return fisher_matrices, opt_params

The whole training loop looks like this:

for task in range(trainset.num_tasks()):
    tqdm.write(f"Training on task {trainset.get_current_task()}")
    trainloader = DataLoader(
        trainset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4
    )

    # Update fisher dict and optimal parameter dict
    fisher_matrices, opt_params = ewc_update(
                                    model, trainloader,
                                    criterion=criterion,
                                    device=device)

    # Train with EWC regularized weights
    for epoch in tqdm(range(config['epochs'])):
        loss = train_ewc(
                    model, trainloader,
                    fisher_matrices=fisher_matrices, 
                    opt_params=opt_params, 
                    ewc_weight=config['ewc_weight'],
                    optimizer=optimizer,
                    criterion=criterion,
                    device=device)

    # Evaluate error rate on current and previous tasks
    for task in range(trainset.get_current_task() + 1):
        evalloader = DataLoader(
                        evalset,
                        batch_size=config['batch_size'],
                        shuffle=True,
                        num_workers=4
                    )
        vloss, verror = validate(model, evalloader, criterion=criterion, device=device)
        tqdm.write(f"Evaluated task {task}")
        tqdm.write(
            f"Training loss: {loss: .3f}, Validation loss: {vloss: .3f}, " 
            f"Validation error: {verror: .3f}")
        evalset.next_task()

    # Progress to next task
    trainset.next_task()
    evalset.restart()

With the code above, my class-incremental scenario still forgets catastrophically: the error rate on previously seen tasks goes to 1 after training on subsequent tasks.

I apologize for the code dump, but I am really not sure where the bug is.