```python
for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss = loss / accumulation_steps  # normalize the loss
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
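If the number of batches is not divisible by `accumulation_steps`, the last few gradients never trigger an optimizer step. A minimal sketch of one way to flush them at the end of the epoch (reusing `dataloader`, `model`, `criterion`, `optimizer`, and `accumulation_steps` from the snippet above):

```python
num_batches = len(dataloader)

for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()

    # Step either on a full accumulation window or on the final batch,
    # so leftover gradients at the end of the epoch are not discarded.
    if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
        optimizer.step()
        optimizer.zero_grad()
```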
Method 3: Clear the GPU cache
```python
import torch
import gc

# Periodically clear the cache inside the training loop
if batch_idx % 100 == 0:
    torch.cuda.empty_cache()
    gc.collect()
```
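Note that `torch.cuda.empty_cache()` only returns cached blocks to the driver; it cannot free tensors that are still referenced, and calling it too often slows training. A hedged sketch for logging memory usage so you can check whether clearing actually helps (the 100-step interval is just an example):

```python
import torch

if batch_idx % 100 == 0:
    # memory_allocated: memory occupied by live tensors
    # memory_reserved: memory held by PyTorch's caching allocator
    allocated_mb = torch.cuda.memory_allocated() / 1024**2
    reserved_mb = torch.cuda.memory_reserved() / 1024**2
    print(f"step {batch_idx}: allocated={allocated_mb:.1f} MiB, reserved={reserved_mb:.1f} MiB")
```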
3. DataParallel vs DistributedDataParallel
**Problem:** Which parallelization approach should you choose?
DataParallel (simple but inefficient):
```python
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.to(device)
```
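One practical detail: `DataParallel` wraps the original model, so its parameters and `state_dict` keys live under `model.module`. A small sketch for saving a checkpoint that works whether or not the wrapper is present (`checkpoint.pt` is a placeholder path):

```python
import torch

# Unwrap DataParallel (or DDP) before saving so the checkpoint
# keys are not prefixed with "module."
to_save = model.module if hasattr(model, "module") else model
torch.save(to_save.state_dict(), "checkpoint.pt")
```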
DistributedDataParallel (recommended):
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Wrap the model
model = model.to(rank)
model = DDP(model, device_ids=[rank])
```
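The two lines above assume the process group is already initialized. A minimal single-node sketch of the surrounding setup, assuming one process per GPU launched with `torchrun` (`MyModel`, `train_dataset`, and `num_epochs` are placeholders):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def main():
    # torchrun sets LOCAL_RANK for each spawned process
    rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(rank)

    model = MyModel().to(rank)                   # placeholder model
    model = DDP(model, device_ids=[rank])

    # DistributedSampler gives each process a distinct shard of the data
    sampler = DistributedSampler(train_dataset)  # placeholder dataset
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

    for epoch in range(num_epochs):              # placeholder epoch count
        sampler.set_epoch(epoch)                 # reshuffle shards across epochs
        for inputs, labels in loader:
            ...  # forward / backward / optimizer step as usual

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```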
4. Mixed Precision Training
**Problem:** Training is slow and GPU utilization is low
Solution:
```python
from torch.cuda.amp import autocast, GradScaler

model = model.to(device)
scaler = GradScaler()

for inputs, labels in dataloader:
    optimizer.zero_grad()
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
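If you also clip gradients, they must be unscaled first; otherwise the clipping threshold is applied to the scaled values. A hedged sketch of the common pattern, replacing the backward/step portion of the loop above (`max_norm=1.0` is an arbitrary example):

```python
scaler.scale(loss).backward()

# Unscale gradients in place before clipping so the norm is measured
# on the true (unscaled) gradients.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)
scaler.update()
```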