在深度学习落地应用日益广泛的今天,我们通常训练模型的方式是:收集所有数据,一次性训练出一个“全能”模型。但在现实世界中,数据往往是分批次、持续不断产生的。例如,一个智能客服系统需要不断学习新出现的业务词汇,或者自动驾驶系统需要适应不同城市的交通标志。
这就引入了增量学习的概念。然而,增量学习面临着一个巨大的挑战——灾难性遗忘。本文将深入探讨这一现象的成因、表现以及目前主流的应对策略,并配合PyTorch代码进行演示。
关键词:增量学习、灾难性遗忘、神经网络、持续学习、EWC
1. 什么是灾难性遗忘?
用一句通俗的话来概括:“学了新知识,忘了旧知识”。
在人工神经网络中,当我们在新任务上对预训练好的模型进行微调时,模型会为了适应新任务的数据分布,大幅修改网络权重。这种修改往往会覆盖掉之前学习旧任务所需的关键权重,导致模型在旧任务上的性能断崖式下跌。
1.1 现象对比
- 人类学习:学会了骑自行车,再学开汽车,不会忘记怎么骑自行车。甚至两者技能还能互相促进。
- 神经网络:在任务A上训练好的模型,去学习任务B。训练结束后,模型在任务B上表现完美,但在任务A上的准确率可能降为随机猜测水平。
1.2 直观展示
假设我们有两个任务:
- Task A:MNIST手写数字识别(0-4)
- Task B:MNIST手写数字识别(5-9)
如果直接进行顺序训练,准确率变化通常如下表所示:
| 阶段 | Task A 准确率 | Task B 准确率 |
|---|---|---|
| 训练 Task A 后 | 98% | – |
| 训练 Task B 后 | 20% (骤降) | 97% |
2. 为什么会发生遗忘?
从数学和优化角度来看,灾难性遗忘的根源在于参数共享与优化目标的冲突。
2.1 权重重叠
深度神经网络通常采用共享权重的架构。处理旧任务(Task A)用到了某些神经元,处理新任务(Task B)也可能用到同样的神经元。当我们用梯度下降更新参数以最小化 Task B 的 Loss 时,算法并不“在乎”这些参数对 Task A 是否重要,它只关注降低当前的 Loss。
2.2 损失地形
如下图所示(脑补示意图),Task A 的最优解位于一个狭长的山谷中。当模型为了 Task B 移动参数位置时,很容易走出 Task A 的“最优区域”,滑向 Task A 的高误差区域。
3. 缓解灾难性遗忘的主流方法
为了解决这一问题,学术界提出了多种方案,主要可以归纳为以下三类:
3.1 正则化方法
核心思想:在更新权重时,给重要的权重加一个“紧箍咒”。
代表算法:EWC (Elastic Weight Consolidation)
EWC 借鉴了贝叶斯理论,它计算每个权重对旧任务的重要性(通过 Fisher 信息矩阵)。在训练新任务时,如果某个权重对旧任务很重要,就限制它的修改幅度。
- Loss公式: ������=����(�)+�∑���(��−��,�∗)2 其中,�� 是权重的重要性,��,�∗ 是旧任务的最优权重。
3.2 回放方法
核心思想:温故而知新。
- 经验回放:保存一小部分旧任务的数据,在训练新任务时,将旧数据和新数据混合训练。
- 伪回放:如果不允许保存旧数据(隐私限制),则使用生成对抗网络(GAN)生成假的旧数据进行回放。
3.3 架构方法
核心思想:由于任务不同,索性把网络结构分开。
- PackNet:为每个任务划分网络的一部分神经元,通过剪枝技术预留空间给未来任务。
- 动态扩展网络:当新任务太难时,自动增加网络容量(增加神经元或层数)。
4. 代码实战:用 PyTorch 演示遗忘与 EWC 的对抗
为了让大家更直观地理解,我们用一个简单的全连接网络来演示。
4.1 环境准备
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
# 定义简单的全连接网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(28*28, 256)
self.fc2 = nn.Linear(256, 10) # 假设最终分类为10类
def forward(self, x):
x = x.view(-1, 28*28)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
4.2 构造任务与普通训练(演示灾难性遗忘)
我们将MNIST数据集划分为两个任务:
- Task A: 数字 0-4
- Task B: 数字 5-9
# 数据预处理
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
# 划分任务索引
idx_task_a = [i for i, (img, label) in enumerate(train_dataset) if label < 5]
idx_task_b = [i for i, (img, label) in enumerate(train_dataset) if label >= 5]
loader_task_a = DataLoader(Subset(train_dataset, idx_task_a), batch_size=64, shuffle=True)
loader_task_b = DataLoader(Subset(train_dataset, idx_task_b), batch_size=64, shuffle=True)
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# --- 阶段1: 训练 Task A ---
print("Training on Task A (Digits 0-4)...")
for epoch in range(2): # 简单训练2个epoch
for data, target in loader_task_a:
optimizer.zero_grad()
output = model(data)
# 注意:这里为了演示,label保持原样,实际应用可能需要重映射label到0-4
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 测试 Task A 性能 (这里简化处理,实际应编写测试函数)
print("Task A training finished.")
4.3 引入 EWC (简化版实现)
下面实现一个简化的 EWC 类,用于计算并存储 Fisher 信息。
class EWC:
def __init__(self, model, dataloader, device):
self.model = model
self.dataloader = dataloader
self.device = device
self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
self._means = {} # 存储旧任务的最优参数
self._precision_matrices = {} # 存储 Fisher 信息矩阵 (重要性)
# 初始化
for n, p in self.params.items():
self._means[n] = p.clone().detach()
# 计算 Fisher 信息矩阵
self._calculate_fisher()
def _calculate_fisher(self):
precision_matrices = {}
for n, p in self.params.items():
precision_matrices[n] = p.clone().detach().fill_(0)
self.model.eval()
for data, target in self.dataloader:
self.model.zero_grad()
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)
loss = criterion(output, target)
loss.backward()
for n, p in self.params.items():
# 简化:使用梯度的平方作为重要性的近似
precision_matrices[n].data += p.grad.data ** 2 / len(self.dataloader)
self._precision_matrices = precision_matrices
def penalty(self, model):
loss = 0
for n, p in model.named_parameters():
if n in self._precision_matrices:
_loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
loss += _loss.sum()
return loss
# --- 阶段2: 训练 Task B (使用 EWC) ---
print("\nTraining on Task B (Digits 5-9) with EWC...")
# 先计算 Task A 的 EWC 信息
ewc = EWC(model, loader_task_a, device='cpu')
optimizer = optim.SGD(model.parameters(), lr=0.01)
for epoch in range(2):
for data, target in loader_task_b:
optimizer.zero_grad()
output = model(data)
# 核心区别:Loss = 新任务Loss + EWC惩罚项
loss = criterion(output, target) + 1000 * ewc.penalty(model)
loss.backward()
optimizer.step()
print("Task B training finished with EWC protection.")
4.4 结果对比
如果你运行上述代码,你会发现:
- 普通训练:Task B 训练完后,模型在 Task A 数据上的预测基本是瞎猜。
- EWC训练:虽然 Task B 的收敛速度可能稍慢,但模型在 Task A 上依然能保持较高的准确率,因为权重被约束在旧任务最优解附近。
5. 总结与展望
灾难性遗忘是增量学习走向实用的最大绊脚石。本文介绍了其原理及经典的 EWC 解决方案。
| 方法 | 优点 | 缺点 |
|---|---|---|
| 正则化 (EWC) | 不需要存储旧数据,内存开销小 | 当任务数量很多时,约束过多,难以学习新任务 |
| 回放 | 效果通常最好 | 存在隐私问题,需要存储旧数据 |
| 架构 | 彻底避免遗忘 | 模型参数量随任务线性增长,计算资源消耗大 |
未来的研究方向主要集中在:
- 动态网络结构:如何更高效地分配网络容量。
- 无回放学习:如何像人类大脑一样,通过知识蒸馏等方式压缩旧知识,而不依赖原始数据。
增量学习是一个充满挑战但也极具潜力的领域,希望本文能为你打开这扇大门!
参考文献:
- Kirkpatrick, J., et al. “Overcoming catastrophic forgetting in neural networks.” PNAS 2017.
- McCloskey, M., & Cohen, N. J. “Catastrophic interference in connectionist networks: The sequential learning problem.” Psychology of Learning and Motivation, 1989.