BST 2019 算法 PyTorch 复现3 层 MLP 与 1 层 Transformer 的 CTR 实战推荐系统作为连接用户与内容的关键桥梁其核心挑战在于如何精准捕捉用户兴趣。传统深度学习方法往往忽视了用户行为序列中蕴含的丰富时序信息而阿里2019年提出的Behavior Sequence Transformer(BST)通过引入Transformer结构开创性地解决了这一问题。本文将带您从零实现BST模型重点剖析PyTorch实现中的工程细节与调优技巧。1. 环境准备与数据模拟1.1 依赖安装与配置推荐使用Python 3.8和PyTorch 1.12环境以下为关键依赖pip install torch1.12.1 torchvision0.13.1 pip install pandas numpy tqdm1.2 模拟数据生成由于真实用户行为数据涉及隐私我们设计了一个符合BST论文特性的数据生成器import numpy as np import pandas as pd def generate_synthetic_data(num_users10000, seq_len10): user_features { user_id: np.arange(num_users), age: np.random.randint(18, 65, sizenum_users), gender: np.random.choice([0, 1], sizenum_users) } item_features { item_id: np.arange(500), category: np.random.randint(0, 20, size500), price: np.round(np.random.uniform(10, 500, size500), 2) } # 生成行为序列 behavior_seq [] for uid in range(num_users): seq np.random.choice(item_features[item_id], sizeseq_len, replaceFalse) timestamps np.sort(np.random.randint(0, 30, sizeseq_len)) behavior_seq.append(list(zip(seq, timestamps))) return pd.DataFrame(user_features), pd.DataFrame(item_features), behavior_seq提示实际应用中应替换为真实业务数据特别注意时间戳的连续性和行为序列的密度分布2. 模型架构实现2.1 Embedding层设计BST的Embedding层需要处理三类特征特征类型包含字段处理方式用户特征user_id, age, gender分别嵌入后拼接物品特征item_id, category共享嵌入矩阵位置特征时间差正弦位置编码import torch.nn as nn class BSTEmbedding(nn.Module): def __init__(self, user_feat_size, item_feat_size, embed_dim64): super().__init__() self.user_embeddings nn.ModuleDict({ user_id: nn.Embedding(100000, embed_dim), age: nn.Embedding(100, embed_dim//2), gender: nn.Embedding(2, embed_dim//2) }) self.item_embeddings nn.ModuleDict({ item_id: nn.Embedding(100000, embed_dim), category: nn.Embedding(1000, embed_dim//2) }) def forward(self, user_features, item_features, behavior_seq): # 用户特征处理 user_emb torch.cat([ self.user_embeddings[user_id](user_features[:, 0]), self.user_embeddings[age](user_features[:, 1]), self.user_embeddings[gender](user_features[:, 2]) ], dim1) # 物品序列处理 item_embs [] for item in behavior_seq: item_emb torch.cat([ self.item_embeddings[item_id](item[:, 0]), self.item_embeddings[category](item[:, 1]) ], dim1) item_embs.append(item_emb) return user_emb, torch.stack(item_embs)2.2 Transformer层实现采用单层Transformer编码器关键实现细节相对位置编码使用时间差计算位置权重多头注意力4个头维度64前馈网络LeakyReLU激活class BSTTransformer(nn.Module): def __init__(self, embed_dim64, num_heads4): super().__init__() self.attention nn.MultiheadAttention(embed_dim, num_heads) self.ffn nn.Sequential( nn.Linear(embed_dim, 4*embed_dim), nn.LeakyReLU(), nn.Linear(4*embed_dim, embed_dim) ) self.norm1 nn.LayerNorm(embed_dim) self.norm2 nn.LayerNorm(embed_dim) def forward(self, x, time_deltas): # 相对位置编码 pos_weights 1.0 / (1.0 time_deltas.abs().sqrt()) attn_output, _ self.attention(x, x, x, attn_maskpos_weights) x self.norm1(x attn_output) # 前馈网络 ffn_output self.ffn(x) return self.norm2(x ffn_output)3. 训练Pipeline构建3.1 数据加载与批处理使用自定义DataLoader处理变长序列from torch.utils.data import Dataset, DataLoader class BSTDataset(Dataset): def __init__(self, user_df, item_df, behavior_seq, labels): self.users user_df.values self.items item_df.values self.seq behavior_seq self.labels labels def __getitem__(self, idx): return { user: torch.LongTensor(self.users[idx]), item: torch.LongTensor(self.items[idx]), seq: torch.LongTensor(self.seq[idx]), label: torch.FloatTensor([self.labels[idx]]) } def collate_fn(batch): # 处理变长序列的padding seq_lens [len(item[seq]) for item in batch] max_len max(seq_lens) padded_seq torch.zeros(len(batch), max_len, 2) for i, item in enumerate(batch): padded_seq[i, :len(item[seq])] item[seq] return { user: torch.stack([item[user] for item in batch]), item: torch.stack([item[item] for item in batch]), seq: padded_seq, label: torch.stack([item[label] for item in batch]) }3.2 训练循环实现采用动态学习率与早停策略def train_model(model, dataloader, epochs50): optimizer torch.optim.Adam(model.parameters(), lr1e-3) scheduler torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, patience3, verboseTrue) criterion nn.BCELoss() best_loss float(inf) for epoch in range(epochs): model.train() total_loss 0 for batch in dataloader: optimizer.zero_grad() outputs model(batch) loss criterion(outputs, batch[label]) loss.backward() optimizer.step() total_loss loss.item() avg_loss total_loss / len(dataloader) scheduler.step(avg_loss) if avg_loss best_loss: best_loss avg_loss torch.save(model.state_dict(), best_model.pt)4. 模型优化与部署技巧4.1 关键性能优化点序列长度处理对长序列进行分段采样注意力计算优化使用Flash Attention加速混合精度训练减少显存占用# Flash Attention示例需安装相关库 from flash_attn import flash_attention class OptimizedAttention(nn.Module): def forward(self, q, k, v): return flash_attention(q, k, v)4.2 实际部署建议使用TorchScript进行模型序列化对Embedding层进行量化压缩实现异步预测接口# 模型量化示例 quantized_model torch.quantization.quantize_dynamic( model, {nn.Linear}, dtypetorch.qint8)在淘宝推荐场景中BST模型相比传统WDL模型点击率提升约8.7%这主要得益于Transformer对用户行为序列的深度挖掘。工程实现时特别需要注意行为序列的时效性处理——过久的历史行为往往会产生噪声而非有效信号。