Commit 2634b50a authored by holzinger

added Dataloader module

parent 756704ed
# imports
import os
import re
import string
import tensorflow as tf
import tensorflow_datasets as tfds


class Data:
    """Load a text-classification dataset and expose train/val/test splits."""

    def __init__(self, name, batch_size=32, vectorize=True) -> None:
        self.name = name
        self.bs = batch_size
        self.num_labels = 1
        self.vocab_size = int(1e6)
        self.sequence_length = 512
        self.vectorize = vectorize
        self.path = "/home/students/holzinger/tmp/arxiv"

    def load(self):
        # Turn a raw (text, label) pair into (token ids, label) using the
        # fitted TextVectorization layer.
        def vectorize(text, label):
            text = tf.expand_dims(text, -1)
            return self.vectorize_data(text), label

        # Map a TFDS example dict onto a plain (text, label) pair.
        def transform_data(dictionary):
            text = dictionary['sentence']
            label = dictionary['label']
            return text, label

        if self.name == "imdb":
            self.sequence_length = 3072
            self.vectorize_data = tf.keras.layers.TextVectorization(
                standardize=custom_standardization,
                max_tokens=self.vocab_size,
                output_mode='int',
                output_sequence_length=self.sequence_length)
            # load data
            train_ds = tfds.load('imdb_reviews', split='train[:80%]', batch_size=self.bs, as_supervised=True)
            val_ds = tfds.load('imdb_reviews', split='train[80%:]', batch_size=self.bs, as_supervised=True)
            test_ds = tfds.load('imdb_reviews', split='test', batch_size=self.bs, as_supervised=True)
            if self.vectorize:
                # vectorize data
                train_text = train_ds.map(lambda text, labels: text)
                self.vectorize_data.adapt(train_text)
                train_ds = train_ds.map(vectorize)
                val_ds = val_ds.map(vectorize)
                test_ds = test_ds.map(vectorize)
            # configure datasets for performance
            train_ds = configure_dataset(train_ds)
            val_ds = configure_dataset(val_ds)
            test_ds = configure_dataset(test_ds)
        elif self.name == 'sst2':
            self.sequence_length = 64
            self.vectorize_data = tf.keras.layers.TextVectorization(
                standardize=custom_standardization,
                max_tokens=self.vocab_size,
                output_mode='int',
                output_sequence_length=self.sequence_length)
            # load data
            train_ds = tfds.load('glue/sst2', split='train[:80%]', batch_size=self.bs)
            val_ds = tfds.load('glue/sst2', split='train[80%:]', batch_size=self.bs)
            test_ds = tfds.load('glue/sst2', split='validation', batch_size=self.bs)
            train_ds = train_ds.map(transform_data)
            val_ds = val_ds.map(transform_data)
            test_ds = test_ds.map(transform_data)
            if self.vectorize:
                # vectorize data
                train_text = train_ds.map(lambda text, labels: text)
                self.vectorize_data.adapt(train_text)
                train_ds = train_ds.map(vectorize)
                val_ds = val_ds.map(vectorize)
                test_ds = test_ds.map(vectorize)
            # configure datasets for performance
            train_ds = configure_dataset(train_ds)
            val_ds = configure_dataset(val_ds)
            test_ds = configure_dataset(test_ds)
        elif self.name == 'arxiv':
            self.sequence_length = 8192
            self.vectorize_data = tf.keras.layers.TextVectorization(
                standardize=custom_standardization,
                max_tokens=self.vocab_size,
                output_mode='int',
                output_sequence_length=self.sequence_length)
            # load data
            dataset = tf.keras.utils.text_dataset_from_directory(self.path)
            train_ds, val_ds, test_ds = get_dataset_partitions_tf(dataset, len(dataset))
            if self.vectorize:
                # vectorize data
                train_text = train_ds.map(lambda text, labels: text)
                self.vectorize_data.adapt(train_text)
                train_ds = train_ds.map(vectorize)
                val_ds = val_ds.map(vectorize)
                test_ds = test_ds.map(vectorize)
            # configure datasets for performance
            train_ds = configure_dataset(train_ds)
            val_ds = configure_dataset(val_ds)
            test_ds = configure_dataset(test_ds)
            # self.num_labels = len(dataset.class_names)
        else:
            raise ValueError(f"Unknown dataset name: {self.name}")
        return {'train': train_ds,
                'val': val_ds,
                'test': test_ds}


def custom_standardization(input_data):
    # Lowercase, replace HTML line breaks with spaces, and strip punctuation.
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
    return tf.strings.regex_replace(stripped_html, '[%s]' % re.escape(string.punctuation), '')


def configure_dataset(dataset):
    # Cache in memory and prefetch batches to overlap I/O with training.
    return dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)


def get_dataset_partitions_tf(ds, ds_size, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=1000):
    # Tolerate floating-point rounding when checking that the splits sum to 1.
    assert abs(train_split + val_split + test_split - 1.0) < 1e-6
    if shuffle:
        # Specify seed to always have the same split distribution between runs
        ds = ds.shuffle(shuffle_size, seed=12)
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size).skip(val_size)
    return train_ds, val_ds, test_ds


def main():
    pass
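

# Minimal usage sketch: a hypothetical helper (not wired into main) showing
# how the loader is meant to be driven; assumes 'imdb_reviews' is available
# to tensorflow_datasets. The other names handled in Data.load are 'sst2'
# (fetched via TFDS) and 'arxiv' (read from the local directory in self.path).
def _demo():
    data = Data("imdb", batch_size=32)
    splits = data.load()
    train_ds, val_ds, test_ds = splits['train'], splits['val'], splits['test']
    # With vectorize=True, each split yields (token-id batch, label batch).
    print(train_ds.element_spec)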


if __name__ == "__main__":
    main()