之前在网上找到了一个文本匹配实现仓库,但是没有提供DSSM的代码,我就根据那个代码实现以下DSSM。数据集采用的是蚂蚁金服的数据集。也参考过别人的代码,但是总感觉怪怪的,DSSM原文中,一个query有对应的正样本和负样本,因此在实现的时候分别计算query与正负样本的余弦相似度,最后拼接再接softmax,但是蚂蚁金服数据集中每一个样本都已一个query和doc,对应一个label,并没有成对的正负样本,因此在实现中遇到了困难,因此最后我索性直接将余弦值作为网络输出,貌似还取得了不错的效果,那么代码会有些许不同。
第一,损失函数采用了二分类损失函数:
class torch.nn.BCELoss(weight=None, size_average=True)
第二,判断类别时:
def correct_predictions(output_probabilities, targets):
"""
Compute the number of predictions that match some target classes in the
output of a model.
Args:
output_probabilities: A tensor of probabilities for different output
classes.
targets: The indices of the actual target classes.
Returns:
The number of correct predictions in 'output_probabilities'.
"""
# _, out_classes = output_probabilities.max(dim=1)
out_classes = output_probabilities.ge(0.5).byte().float()
correct = (out_classes == targets).sum()
return correct.item()
第三,网络结构设计如下:
class DSSM(nn.Module):
def __init__(self, dropout=0.2,device="gpu"):
super(DSSM, self).__init__()
self.device = device
self.embed = nn.Embedding(7901, 100)
self.fc1 = nn.Linear(100, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512,256)
self.dropout = nn.Dropout(dropout)
self.Sigmoid = nn.Sigmoid() #method1
self.relu = nn.ReLU()
def forward(self, a, b):
a = self.embed(a).sum(1)
b = self.embed(b).sum(1)
a = self.relu(self.fc1(a)) #torch.tanh
# a = self.dropout(a)
a = self.relu(self.fc2(a))
# a = self.dropout(a)
a = self.relu(self.fc3(a))
# a = self.dropout(a)
b = self.relu(self.fc1(b))
# b = self.dropout(b)
b = self.relu(self.fc2(b))
# b = self.dropout(b)
b = self.relu(self.fc3(b))
# b = self.dropout(b)
cosine = torch.cosine_similarity(a, b, dim=1, eps=1e-8) #计算两个句子的余弦相似度
# cosine = self.Sigmoid(cosine-0.5)
cosine = self.relu(cosine)
cosine = torch.clamp(cosine,0,1)
return cosine
这样在蚂蚁金服测试集的准确率可以达到77以上,如果cosine后面不接relu,我跑到了78以上,但是总感觉出现了过拟合现象。此外,加入dropout效果反而不好,可能这个网络本身就不复杂吧。
其他的训练代码我参考了:https://github.com/zhaogaofeng611/TextMatch
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)