Skip to content
Snippets Groups Projects
Commit 00175ac7 authored by finn
Browse files

Reformat script to extract domain knowledge

parent 3f9a4fb7
No related branches found
No related tags found
No related merge requests found
import json
import os
import argparse
def read_json_file(file_path):
    """
    Load and decode a UTF-8 encoded JSON document from disk.

    Args:
        file_path (str): Location of the JSON file to read.

    Returns:
        The decoded JSON content (whatever top-level value the file holds).
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def extract_dialogue_info(dialogues):
    """
    Extracts and formats dialogue information from the dialogues data.

    For every dialogue the full transcript is collected, and for each service
    mentioned in a USER turn the active intents ("Actions") and slot values
    ("Preferences") are merged across turns. Frames whose intent is "NONE"
    and that carry no slot values are ignored.

    All lists in the result are in first-seen order, so the output is
    reproducible between runs, and entry i of "Services Involved" describes
    the same service as entry i of "Details".

    Args:
        dialogues (list): A list of dialogues, each a dict with a
            "dialogue_id" and a list of "turns". Each turn has "speaker",
            "utterance" and, optionally, "frames".

    Returns:
        list: One dict per dialogue with the keys "Dialogue ID",
            "Complete Dialogue" and "Service Information".
    """
    formatted_dialogues = []
    for dialogue in dialogues:
        # Plain dicts keep insertion order, so their keys act as ordered sets.
        # Using sets here would make the output order depend on string hashing.
        service_details = {}
        complete_dialogue = []  # Store the complete dialogue

        for turn in dialogue["turns"]:
            # Add the speaker and utterance to the complete dialogue
            complete_dialogue.append(f"{turn['speaker']}: {turn['utterance']}")
            if turn["speaker"] != "USER":
                continue

            for frame in turn.get("frames", []):
                service = frame["service"]
                if "state" not in frame:
                    continue
                state = frame["state"]
                intent = state.get("active_intent", "NONE")
                slot_values = state.get("slot_values", {})

                # Skip services with no actions or empty preferences
                if intent == "NONE" and not slot_values:
                    continue

                # Consolidate information for each service type
                details = service_details.setdefault(
                    service, {"Actions": {}, "Preferences": {}}
                )
                details["Actions"][intent] = None
                for key, value in slot_values.items():
                    # A bare string is one value, not an iterable of characters.
                    values = [value] if isinstance(value, str) else value
                    seen = details["Preferences"].setdefault(key, {})
                    seen.update(dict.fromkeys(values))

        # Convert the ordered-set dicts into plain lists for serialisation.
        details_list = [
            {
                "Actions": list(details["Actions"]),
                "Preferences": {
                    key: list(values)
                    for key, values in details["Preferences"].items()
                },
            }
            for details in service_details.values()
        ]

        formatted_dialogues.append({
            "Dialogue ID": dialogue['dialogue_id'],
            "Complete Dialogue": complete_dialogue,
            # Combined services and details; both derive from the same dict,
            # so their orders line up.
            "Service Information": {
                "Services Involved": list(service_details),
                "Details": details_list,
            },
        })
    return formatted_dialogues
def write_to_file(filename, data):
    """
    Writes the given data to a file in JSON format.

    The parent directory of ``filename`` is created when it does not exist,
    so a missing output folder does not discard already-computed results.

    Args:
        filename (str): The name of the file where data will be written.
        data (dict or list): JSON-serialisable data to be written into the file.
    """
    parent_dir = os.path.dirname(filename)
    # dirname is '' for a bare file name; makedirs('') would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)
def process_dialogues(dialogue_count='FULL',
                      base_path='../../data/multiwoz_2.2/train',
                      output_base_path='../../data/own_data/one-shot/domain_knowledge'):
    """
    Processes dialogues from the MultiWOZ dataset and writes the processed information to a file.

    Input shards named ``dialogues_001.json``, ``dialogues_002.json``, ... are
    read from ``base_path`` in order until enough dialogues are collected.
    The result is written to ``domain_knowledge_<N>.json`` (or
    ``domain_knowledge_ALL.json`` for 'FULL') inside ``output_base_path``.

    Args:
        dialogue_count (str or int): The number of dialogues to process. 'FULL' for all dialogues,
            or a non-negative integer for a specific number.
        base_path (str): Directory holding the input shards. Relative paths
            are resolved against the current working directory.
        output_base_path (str): Directory the output file is written to.

    Raises:
        TypeError: If ``dialogue_count`` is neither 'FULL' nor an integer.
        ValueError: If ``dialogue_count`` is a negative integer.
    """
    file_template = 'dialogues_{:03d}.json'
    # NOTE(review): assumes the train split consists of 16 shards of 512
    # dialogues each; confirm against the dataset directory. Any further
    # shard would be silently skipped by 'FULL'.
    max_files = 16
    num_dialogues_per_file = 512
    max_dialogues = max_files * num_dialogues_per_file

    process_all = dialogue_count == 'FULL'
    if process_all:
        files_to_process = range(1, max_files + 1)
    else:
        if not isinstance(dialogue_count, int):
            raise TypeError("dialogue_count must be an integer or 'FULL'.")
        if dialogue_count < 0:
            raise ValueError("dialogue_count cannot be negative.")
        total_dialogues = min(dialogue_count, max_dialogues)
        # Ceiling division: number of shards needed to cover the request.
        files_needed = (total_dialogues + num_dialogues_per_file - 1) // num_dialogues_per_file
        files_to_process = range(1, files_needed + 1)

    all_formatted_dialogues = []
    for file_num in files_to_process:
        file_path = os.path.join(base_path, file_template.format(file_num))
        dialogues = read_json_file(file_path)
        all_formatted_dialogues.extend(extract_dialogue_info(dialogues))
        # Stop reading further shards once the requested amount is reached.
        if not process_all and len(all_formatted_dialogues) >= dialogue_count:
            break

    if not process_all:
        # The last shard read may overshoot the requested count.
        del all_formatted_dialogues[dialogue_count:]

    output_file_name = f"domain_knowledge_{dialogue_count if dialogue_count != 'FULL' else 'ALL'}.json"
    output_file_path = os.path.join(output_base_path, output_file_name)
    write_to_file(output_file_path, all_formatted_dialogues)
def main(argv=None):
    """
    Command-line entry point: parse the requested quantity and run the processing.

    Args:
        argv (list of str, optional): Argument list to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``.

    Raises:
        ValueError: If the quantity is not 'FULL' and is not an integer
            between 1 and 8192 inclusive.
    """
    parser = argparse.ArgumentParser(description="Process MultiWOZ dialogues.")
    parser.add_argument('-q', '--quantity', type=str, default='FULL',
                        help="Number of dialogues to process. Use 'FULL' to process all. Max: 8192.")
    args = parser.parse_args(argv)

    dialogue_count = args.quantity
    if dialogue_count != 'FULL':
        try:
            dialogue_count = int(dialogue_count)
        except ValueError as err:
            raise ValueError("Count must be an integer or 'FULL'.") from err
        # Zero or negative requests would only produce an empty output file.
        if dialogue_count < 1:
            raise ValueError("Count must be at least 1.")
        if dialogue_count > 8192:
            raise ValueError("Count cannot exceed 8192.")

    process_dialogues(dialogue_count)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment