Commit fe709a4b authored by friebolin's avatar friebolin

Update docstrings

parent 427c3c7d
@@ -429,6 +429,17 @@ class BertModelTMix(BertPreTrainedModel):
class BertTMixEncoder(torch.nn.Module):
    """Used for Tmix. When using Tmix, the only change needed is the ability to modify layers in the model,
    so that the Mixup function can be applied to a batch of hidden states at a certain layer."""
    """Module for Tmix that modifies the model layers so the Mixup function can be applied to a batch of hidden states at a certain layer.
Params:
config (BertConfig): Configuration class for BertTMixEncoder.
Returns:
- If return_dict is True, returns a BaseModelOutputWithPastAndCrossAttentions object containing the last hidden state,
past key values, all hidden states, self-attention weights, and cross-attention weights (if add_cross_attention=True in the model configuration).
- If return_dict is False, returns a tuple containing the last hidden state and the last state of labels after interpolation.
Any None values are omitted.
"""
def __init__(self, config):
super().__init__()
self.config = config
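The docstring above describes interpolating a batch of hidden states and labels at a chosen layer. The commit does not show the interpolation itself, so the following is a minimal standalone sketch of the standard mixup operation; the helper name `mixup_hidden_states` and the `alpha` Beta-distribution parameter are illustrative assumptions, not identifiers from this repository.

```python
import torch

def mixup_hidden_states(hidden_states, labels, alpha=0.4):
    """Hypothetical helper: mix a batch of hidden states and labels.

    Standard mixup draws a coefficient lam ~ Beta(alpha, alpha) and combines
    each example with a randomly permuted partner from the same batch.
    `labels` is assumed to be a float tensor (e.g. one-hot) so it can be
    interpolated the same way as the hidden states.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(hidden_states.size(0))
    # Convex combination of each example with its shuffled partner.
    mixed_hidden = lam * hidden_states + (1 - lam) * hidden_states[index]
    mixed_labels = lam * labels + (1 - lam) * labels[index]
    return mixed_hidden, mixed_labels
```

Because the operation is a convex combination, the mixed tensors keep the original batch, sequence, and hidden dimensions, which is what lets the encoder pass them on to the remaining layers unchanged.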
@@ -536,6 +547,20 @@ class BertTMixEncoder(torch.nn.Module):
#Monkey patching the forward function of BertLayer for mixup -> use decorators here to call the old forward function on the newly computed hidden_state
def forward_new(forward):
"""
Decorator function that monkey patches the `forward` method of a `BertLayer` instance to implement mixup data
    augmentation during training. When the `nowlayer` and `mixepoch` arguments are specified, it performs mixup data
    augmentation on the hidden states and labels of the layer specified by `nowlayer`. Otherwise, it calls the original
`forward` method of the `BertLayer` instance on the input tensors and returns the output tensor along with the
original labels.
Params:
forward (callable): The original `forward` method of the `BertLayer` instance.
Returns:
callable: A new `forward` method for the `BertLayer` instance that performs mixup data augmentation if
specified, otherwise calls the original `forward` method.
"""
def forward_mix(self, hidden_states: torch.Tensor,
labels: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
......
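The body of `forward_mix` is elided in the diff above. As a rough illustration of the decorator pattern the docstring describes, here is a self-contained sketch: the wrapper mixes hidden states and labels before delegating to the captured original `forward`. The mixing logic, the `epoch` and `lambda_value` parameters, and the dummy layer used below are assumptions for illustration, not the repository's actual implementation.

```python
import torch

def forward_new(forward):
    """Sketch of the decorator: wrap a layer's original `forward` so that,
    when `nowlayer`/`mixepoch` are given and the current epoch has reached
    `mixepoch`, hidden states and labels are interpolated first."""
    def forward_mix(self, hidden_states, labels, attention_mask=None,
                    nowlayer=None, mixepoch=None, epoch=0,
                    lambda_value=0.5, **kwargs):
        if nowlayer is not None and mixepoch is not None and epoch >= mixepoch:
            # Mix each example with a randomly permuted partner (assumed logic).
            index = torch.randperm(hidden_states.size(0))
            hidden_states = (lambda_value * hidden_states
                             + (1 - lambda_value) * hidden_states[index])
            labels = lambda_value * labels + (1 - lambda_value) * labels[index]
        # Call the captured original forward on the (possibly mixed) states,
        # returning its output together with the (possibly mixed) labels.
        outputs = forward(self, hidden_states, attention_mask=attention_mask,
                          **kwargs)
        return outputs, labels
    return forward_mix

# Applied by monkey patching, e.g.:
#   BertLayer.forward = forward_new(BertLayer.forward)
```

Because the decorator closes over the original unbound `forward`, patching the class attribute replaces the method for every layer instance while still delegating the actual transformer computation to the unmodified code.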