[D-75] Understanding background.
Understanding the Causal LLM Head
Language Modeling Head: This is the final layer (or layers) of a transformer model. It takes the hidden states from the encoder (or decoder) and projects each one to a vector whose length is the vocabulary size (the total number of possible tokens). The outputs are called logits: unnormalized scores over the vocabulary at every position. In a causal model, the logits at a given position represent the model’s prediction of the next token in the sequence.
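As a minimal sketch of that projection, assuming PyTorch: the sizes below and the single nn.Linear layer are illustrative choices only, not the actual BiGS prediction head.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, hidden_size, vocab_size = 2, 5, 16, 100

# Stand-in for the hidden states produced by the transformer body.
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# The simplest possible LM head: one linear projection onto the vocabulary.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# One vector of vocab_size scores (logits) per position in the sequence.
logits = lm_head(hidden_states)
print(logits.shape)  # (batch_size, seq_len, vocab_size)
```

Real heads often add a dense layer, an activation and a LayerNorm before this projection, but the output shape is the same.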
Causal (Autoregressive) Prediction: In a Causal Language Model (CLM), the model is trained to predict the next token based only on the preceding tokens in the sequence. This is typically done by offsetting the labels by one position relative to the logits, so that the logits at position t are scored against the token at position t+1, and then applying a standard cross-entropy loss.
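The one-position offset can be sketched as follows, assuming PyTorch. The tensors are random placeholders and the variable names (shift_logits, shift_labels) are illustrative, not taken from your file.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes and random data, for illustration only.
batch_size, seq_len, vocab_size = 2, 6, 50
logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# For causal LM training the labels are simply the input tokens.
labels = input_ids.clone()

# Drop the last logit (nothing follows it) and the first label
# (nothing precedes it), so position t is scored against token t+1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

# Flatten to (N, vocab_size) and (N,) as cross_entropy expects.
# Positions labeled -100 (e.g. padding) would be ignored.
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
    ignore_index=-100,
)
print(loss.item())
```

Note that the shift alone does not make a model causal: the layers underneath must also be prevented from seeing future tokens.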
The current Masked Language Model (MLM) Head: Your file already contains a head for Masked Language Modeling (MLM), which is used for fill-in-the-blank type tasks:
class BiGSOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BiGSLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores