@@ -429,6 +429,17 @@ class BertModelTMix(BertPreTrainedModel):
class BertTMixEncoder(torch.nn.Module):
"""Used for Tmix. When using Tmix the only change that has to be done, is to be able to modify layers in model.
This way, we can apply the Mixup function to a batch of hidden states at a certain layer """
"""Module for Tmix that modifies the model layers to apply Mixup function to a batch of hidden states at a certain layer.
Params:
config (BertConfig): Configuration class for BertTMixEncoder.
Returns:
- If return_dict is True, returns a BaseModelOutputWithPastAndCrossAttentions object containing the last hidden state,
past key values, all hidden states, self-attention weights, and cross-attention weights (if add_cross_attention=True in the model configuration).
- If return_dict is False, returns a tuple containing the last hidden state and the last state of labels after interpolation.
Any None values are omitted.
"""
    def __init__(self, config):
        super().__init__()
        self.config = config
...
...
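# A minimal sketch of the TMix interpolation step the docstring above describes,
# not code from this diff: at the chosen layer, each example's hidden state and
# label are convexly mixed with a randomly paired partner from the same batch.
# The helper name interpolate_hidden_states and the Beta-distributed weight
# (alpha=0.75) are assumptions based on the standard Mixup formulation.
import torch

def interpolate_hidden_states(hidden_states, labels, alpha=0.75):
    """Mix each example's hidden state and label with a random batch partner."""
    # Sample the interpolation weight from Beta(alpha, alpha), as in Mixup,
    # and keep the original example dominant.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    lam = max(lam, 1.0 - lam)

    # Pair each example with a randomly permuted partner from the batch.
    index = torch.randperm(hidden_states.size(0))

    mixed_hidden = lam * hidden_states + (1.0 - lam) * hidden_states[index]
    mixed_labels = lam * labels + (1.0 - lam) * labels[index]
    return mixed_hidden, mixed_labels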
@@ -536,6 +547,20 @@ class BertTMixEncoder(torch.nn.Module):
# Monkey patching the forward function of BertLayer for mixup -> use decorators here to call the old forward function on the newly computed hidden_state
def forward_new(forward):
"""
Decorator function that monkey patches the `forward` method of a `BertLayer` instance to implement mixup data
augmentation during training. When `nowlayer` and `mixepoch` arguments are specified, performs mixup data
augmentation on the hidden states and labels of the layer specified by `nowlayer`. Otherwise, calls the original
`forward` method of the `BertLayer` instance on the input tensors and returns the output tensor along with the
original labels.
Params:
forward (callable): The original `forward` method of the `BertLayer` instance.
Returns:
callable: A new `forward` method for the `BertLayer` instance that performs mixup data augmentation if
specified, otherwise calls the original `forward` method.
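# A hedged sketch of the decorator pattern the docstring above describes, wrapping
# the original BertLayer.forward; the parameter names layer_index, nowlayer,
# mixepoch, current_epoch and mixup_fn are illustrative, not taken from this diff.
import functools

def forward_new_sketch(forward):
    @functools.wraps(forward)
    def wrapped(hidden_states, labels=None, layer_index=None, nowlayer=None,
                mixepoch=None, current_epoch=None, mixup_fn=None, **kwargs):
        # Run the original BertLayer.forward on the incoming hidden states.
        outputs = forward(hidden_states, **kwargs)
        # At the designated layer and epoch, mix the freshly computed hidden
        # states (and labels) before passing them up the encoder stack.
        if (nowlayer is not None and layer_index == nowlayer
                and current_epoch == mixepoch and mixup_fn is not None):
            mixed_hidden, labels = mixup_fn(outputs[0], labels)
            outputs = (mixed_hidden,) + outputs[1:]
        # Return the (possibly mixed) layer outputs along with the labels,
        # which are interpolated only when mixup was applied.
        return outputs, labels
    return wrapped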