From 00175ac721f79e9292f4bd5854594b22b13199c9 Mon Sep 17 00:00:00 2001 From: finn <finn@hillengass.de> Date: Tue, 27 Feb 2024 00:26:17 +0100 Subject: [PATCH] Reformat script to extract domain knowledge --- src/one-shot/extract_domain_knowledge.py | 160 +++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 src/one-shot/extract_domain_knowledge.py diff --git a/src/one-shot/extract_domain_knowledge.py b/src/one-shot/extract_domain_knowledge.py new file mode 100644 index 0000000..2469d8a --- /dev/null +++ b/src/one-shot/extract_domain_knowledge.py @@ -0,0 +1,160 @@ +import json +import os +import argparse + + +def read_json_file(file_path): + """ + Reads a JSON file and returns its content. + + Args: + file_path (str): Path to the JSON file. + + Returns: + dict: Content of the JSON file. + """ + with open(file_path, 'r', encoding='utf-8') as file: + return json.load(file) + + +def extract_dialogue_info(dialogues): + """ + Extracts and formats dialogue information from the dialogues data. + + Args: + dialogues (list): A list of dialogues, each containing multiple turns. + + Returns: + list: A list of formatted dialogues with service details. + """ + formatted_dialogues = [] + + for dialogue in dialogues: + services_involved = set() + service_details = {} + complete_dialogue = [] # Store the complete dialogue + + for turn in dialogue["turns"]: + # Add the speaker and utterance to the complete dialogue + complete_dialogue.append(f"{turn['speaker']}: {turn['utterance']}") + + if turn["speaker"] == "USER": + for frame in turn.get("frames", []): + service = frame["service"] + + if "state" in frame: + intent = frame["state"].get("active_intent", "NONE") + slot_values = frame["state"].get("slot_values", {}) + + # Skip services with no actions or empty preferences + if intent == "NONE" and not slot_values: + continue + + services_involved.add(service) + + # Consolidate information for each service type + if service not in service_details: + service_details[service] = { + "Actions": set(), + "Preferences": {} + } + + service_details[service]["Actions"].add(intent) + for key, value in slot_values.items(): + if key not in service_details[service]["Preferences"]: + service_details[service]["Preferences"][key] = set() + service_details[service]["Preferences"][key].update(value) + + # Format the details for each service + for details in service_details.values(): + details["Actions"] = list(details["Actions"]) + for key in details["Preferences"]: + details["Preferences"][key] = list(details["Preferences"][key]) + + service_info = { + "Services Involved": list(services_involved), + "Details": list(service_details.values()) + } + + dialogue_info = { + "Dialogue ID": dialogue['dialogue_id'], + "Complete Dialogue": complete_dialogue, + "Service Information": service_info # Combined services and details + } + + formatted_dialogues.append(dialogue_info) + + return formatted_dialogues + + +def write_to_file(filename, data): + """ + Writes the given data to a file in JSON format. + + Args: + filename (str): The name of the file where data will be written. + data (dict): The data to be written into the file. + """ + with open(filename, 'w', encoding='utf-8') as file: + json.dump(data, file, indent=4) + + +def process_dialogues(dialogue_count='FULL'): + """ + Processes dialogues from the MultiWOZ dataset and writes the processed information to a file. + + Args: + dialogue_count (str or int): The number of dialogues to process. 'FULL' for all dialogues, + or an integer for a specific number. + """ + base_path = '../../data/multiwoz_2.2/train' + file_template = 'dialogues_{:03d}.json' + output_base_path = '../../data/own_data/one-shot/domain_knowledge' + max_files = 16 + num_dialogues_per_file = 512 + max_dialogues = max_files * num_dialogues_per_file + + if dialogue_count == 'FULL': + files_to_process = range(1, max_files + 1) + else: + total_dialogues = min(dialogue_count, max_dialogues) + files_needed = (total_dialogues + num_dialogues_per_file - 1) // num_dialogues_per_file + files_to_process = range(1, files_needed + 1) + + all_formatted_dialogues = [] + + for file_num in files_to_process: + file_path = os.path.join(base_path, file_template.format(file_num)) + dialogues = read_json_file(file_path) + formatted_dialogues = extract_dialogue_info(dialogues) + all_formatted_dialogues.extend(formatted_dialogues) + + if dialogue_count != 'FULL' and len(all_formatted_dialogues) >= dialogue_count: + all_formatted_dialogues = all_formatted_dialogues[:dialogue_count] + break + + output_file_name = f"domain_knowledge_{dialogue_count if dialogue_count != 'FULL' else 'ALL'}.json" + output_file_path = os.path.join(output_base_path, output_file_name) + + write_to_file(output_file_path, all_formatted_dialogues) + +def main(): + parser = argparse.ArgumentParser(description="Process MultiWOZ dialogues.") + parser.add_argument('-q', '--quantity', type=str, default='FULL', + help="Number of dialogues to process. Use 'FULL' to process all. Max: 8192.") + + args = parser.parse_args() + + dialogue_count = args.quantity + if dialogue_count != 'FULL': + try: + dialogue_count = int(dialogue_count) + except ValueError: + raise ValueError("Count must be an integer or 'FULL'.") + if dialogue_count > 8192: + raise ValueError("Count cannot exceed 8192.") + + process_dialogues(dialogue_count) + +if __name__ == "__main__": + main() -- GitLab