From 00175ac721f79e9292f4bd5854594b22b13199c9 Mon Sep 17 00:00:00 2001
From: finn <finn@hillengass.de>
Date: Tue, 27 Feb 2024 00:26:17 +0100
Subject: [PATCH] Reformat script to extract domain knowledge

---
 src/one-shot/extract_domain_knowledge.py | 160 +++++++++++++++++++++++
 1 file changed, 160 insertions(+)
 create mode 100644 src/one-shot/extract_domain_knowledge.py

diff --git a/src/one-shot/extract_domain_knowledge.py b/src/one-shot/extract_domain_knowledge.py
new file mode 100644
index 0000000..2469d8a
--- /dev/null
+++ b/src/one-shot/extract_domain_knowledge.py
@@ -0,0 +1,160 @@
+import json
+import os
+import argparse
+
+
+def read_json_file(file_path):
+    """
+    Reads a JSON file and returns its content.
+
+    Args:
+    file_path (str): Path to the JSON file.
+
+    Returns:
+    dict: Content of the JSON file.
+    """
+    with open(file_path, 'r', encoding='utf-8') as file:
+        return json.load(file)
+
+
+def extract_dialogue_info(dialogues):
+    """
+    Extracts and formats dialogue information from the dialogues data.
+
+    Args:
+    dialogues (list): A list of dialogues, each containing multiple turns.
+
+    Returns:
+    list: A list of formatted dialogues with service details.
+    """
+    formatted_dialogues = []
+
+    for dialogue in dialogues:
+        services_involved = set()
+        service_details = {}
+        complete_dialogue = []  # Store the complete dialogue
+
+        for turn in dialogue["turns"]:
+            # Add the speaker and utterance to the complete dialogue
+            complete_dialogue.append(f"{turn['speaker']}: {turn['utterance']}")
+
+            if turn["speaker"] == "USER":
+                for frame in turn.get("frames", []):
+                    service = frame["service"]
+
+                    if "state" in frame:
+                        intent = frame["state"].get("active_intent", "NONE")
+                        slot_values = frame["state"].get("slot_values", {})
+
+                        # Skip services with no actions or empty preferences
+                        if intent == "NONE" and not slot_values:
+                            continue
+
+                        services_involved.add(service)
+
+                        # Consolidate information for each service type
+                        if service not in service_details:
+                            service_details[service] = {
+                                "Actions": set(),
+                                "Preferences": {}
+                            }
+                        
+                        service_details[service]["Actions"].add(intent)
+                        for key, value in slot_values.items():
+                            if key not in service_details[service]["Preferences"]:
+                                service_details[service]["Preferences"][key] = set()
+                            service_details[service]["Preferences"][key].update(value)
+
+        # Format the details for each service
+        for details in service_details.values():
+            details["Actions"] = list(details["Actions"])
+            for key in details["Preferences"]:
+                details["Preferences"][key] = list(details["Preferences"][key])
+
+        service_info = {
+            "Services Involved": list(services_involved),
+            "Details": list(service_details.values())
+        }
+
+        dialogue_info = {
+            "Dialogue ID": dialogue['dialogue_id'],
+            "Complete Dialogue": complete_dialogue,
+            "Service Information": service_info  # Combined services and details
+        }
+
+        formatted_dialogues.append(dialogue_info)
+
+    return formatted_dialogues
+
+
+def write_to_file(filename, data):
+    """
+    Writes the given data to a file in JSON format.
+
+    Args:
+    filename (str): The name of the file where data will be written.
+    data (dict): The data to be written into the file.
+    """
+    with open(filename, 'w', encoding='utf-8') as file:
+        json.dump(data, file, indent=4)
+
+
+def process_dialogues(dialogue_count='FULL'):
+    """
+    Processes dialogues from the MultiWOZ dataset and writes the processed information to a file.
+
+    Args:
+    dialogue_count (str or int): The number of dialogues to process. 'FULL' for all dialogues, 
+                                 or an integer for a specific number.
+    """
+    base_path = '../../data/multiwoz_2.2/train'
+    file_template = 'dialogues_{:03d}.json'
+    output_base_path = '../../data/own_data/one-shot/domain_knowledge'
+    max_files = 16
+    num_dialogues_per_file = 512
+    max_dialogues = max_files * num_dialogues_per_file
+
+    if dialogue_count == 'FULL':
+        files_to_process = range(1, max_files + 1)
+    else:
+        total_dialogues = min(dialogue_count, max_dialogues)
+        files_needed = (total_dialogues + num_dialogues_per_file - 1) // num_dialogues_per_file
+        files_to_process = range(1, files_needed + 1)
+
+    all_formatted_dialogues = []
+
+    for file_num in files_to_process:
+        file_path = os.path.join(base_path, file_template.format(file_num))
+        dialogues = read_json_file(file_path)
+        formatted_dialogues = extract_dialogue_info(dialogues)
+        all_formatted_dialogues.extend(formatted_dialogues)
+
+        if dialogue_count != 'FULL' and len(all_formatted_dialogues) >= dialogue_count:
+            all_formatted_dialogues = all_formatted_dialogues[:dialogue_count]
+            break
+
+    output_file_name = f"domain_knowledge_{dialogue_count if dialogue_count != 'FULL' else 'ALL'}.json"
+    output_file_path = os.path.join(output_base_path, output_file_name)
+
+    write_to_file(output_file_path, all_formatted_dialogues)
+
+def main():
+    parser = argparse.ArgumentParser(description="Process MultiWOZ dialogues.")
+    parser.add_argument('-q', '--quantity', type=str, default='FULL', 
+                        help="Number of dialogues to process. Use 'FULL' to process all. Max: 8192.")
+
+    args = parser.parse_args()
+
+    dialogue_count = args.quantity
+    if dialogue_count != 'FULL':
+        try:
+            dialogue_count = int(dialogue_count)
+        except ValueError:
+            raise ValueError("Count must be an integer or 'FULL'.")
+        if dialogue_count > 8192:
+            raise ValueError("Count cannot exceed 8192.")
+
+    process_dialogues(dialogue_count)
+
+if __name__ == "__main__":
+    main()
-- 
GitLab