两种用于模型训练的损失函数:nn.CrossEntropyLoss
和 nn.TripletMarginLoss
。它们在对比学习和分类任务中各自扮演不同的角色。接下来是对这两种损失函数的详细介绍。
1. nn.CrossEntropyLoss
nn.CrossEntropyLoss
是 PyTorch 提供的交叉熵损失函数,通常用于多分类任务中。它结合了 softmax 激活函数和负对数似然损失(Negative Log Likelihood Loss, NLLLoss),计算模型预测与真实标签之间的差距。
工作原理:
- 输入:模型的输出(logits)和真实的类别标签。
- 输出:一个标量值,表示预测分布与真实分布之间的差异。
交叉熵损失函数通过以下公式计算:
[ \text{Loss} = -\sum_{i} y_i \log(p_i) ]
其中 ( y_i ) 是真实的标签(one-hot 编码),而 ( p_i ) 是预测的概率分布(softmax 后的输出)。
- 使用场景:多分类任务中,比如文本分类、图像分类。
- 应用:在这段代码中,
CrossEntropyLoss
用于计算模型输出与目标标签之间的分类损失,适合用来处理分类任务的目标。
示例:
假设有三个类别的分类任务,模型输出的 logits 是:
[ [2.0, 1.0, 0.1] ]
真实标签是类别 0,交叉熵损失函数将计算该类别对应的 softmax 概率,并与真实标签对比,得出损失。
2. nn.TripletMarginLoss
nn.TripletMarginLoss
是用于对比学习(contrastive learning)的损失函数,尤其适用于度量学习(metric learning)任务。它处理 三元组(triplet) 数据,即由锚点(anchor)、正样本(positive)和负样本(negative)组成的三元组。
工作原理:
- 输入:三组嵌入向量(特征表示):锚点(anchor)、正样本(positive)和负样本(negative)。
- 输出:一个标量值,表示正样本与锚点之间的距离与负样本与锚点之间的距离之间的差距。
公式为:
[ \text{Loss} = \max(0, d(a, p) - d(a, n) + \text{margin}) ]
其中:
-
( a ) 表示锚点的特征向量,
-
( p ) 表示正样本的特征向量,
-
( n ) 表示负样本的特征向量,
-
( d(·,·) ) 是锚点与正样本/负样本之间的距离(通常是欧氏距离或余弦距离),
-
margin
是一个超参数,确保正样本与锚点的距离比负样本更近,且最小差值为margin
。 -
使用场景:主要用于度量学习和对比学习,如面部识别、文本匹配、图像检索等任务。它确保模型在嵌入空间中将相似的样本拉近,将不相似的样本推远。
-
应用:在这段代码中,
TripletMarginLoss
用于处理对比学习中的三元组损失,确保正样本和锚点的表示比负样本更接近。
示例:
假设锚点、正样本、负样本的嵌入向量分别为:
- Anchor:
[1.0, 2.0]
- Positive:
[1.1, 2.1]
- Negative:
[3.0, 4.0]
如果 margin 设为 1.0,TripletMarginLoss
将确保锚点与正样本的距离比锚点与负样本的距离更小,并且至少差 1.0。
总结
CrossEntropyLoss
适用于分类任务,用于衡量模型输出的类别分布与真实标签之间的差距。TripletMarginLoss
适用于对比学习,通过比较锚点、正样本和负样本的嵌入向量,确保正样本更接近锚点,负样本远离锚点。
在训练流程中,两者结合使用,以同时优化分类任务和对比学习任务中的关系表示。