From cf32ef25afefbf1b1afef80e1896b60e85ebcd85 Mon Sep 17 00:00:00 2001
From: finn <finn@hillengass.de>
Date: Tue, 27 Feb 2024 19:33:53 +0100
Subject: [PATCH] Add script to generate annotations for dialogues

---
 src/one-shot/generate_annotations.py | 158 +++++++++++++++++++++++++++
 1 file changed, 158 insertions(+)
 create mode 100644 src/one-shot/generate_annotations.py

diff --git a/src/one-shot/generate_annotations.py b/src/one-shot/generate_annotations.py
new file mode 100644
index 0000000..be5250f
--- /dev/null
+++ b/src/one-shot/generate_annotations.py
@@ -0,0 +1,158 @@
+import json
+import argparse
+import requests
+import pandas as pd
+import csv
+import re
+from tqdm import tqdm
+from schema import MultiWOZ
+
+def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = False) -> "requests.Response | None":
+    """
+    Send the prompt and the MultiWOZ JSON schema to the API in a POST request.
+
+    Args:
+        prompt (str): The prompt to send in the request.
+        api_url (str): The API endpoint URL.
+        n (int): Number of responses to generate.
+        stream (bool): Whether to stream the response (payload flag and HTTP transfer).
+
+    Returns:
+        requests.Response | None: The response, or None if the request failed.
+    """
+    headers = {"User-Agent": "Test Client"}
+    payload = {
+        "prompt": prompt,
+        "n": n,
+        "use_beam_search": False,
+        "temperature": 0.0,
+        "schema": MultiWOZ.model_json_schema(),
+        "max_tokens": 4000,
+        "stream": stream,
+    }
+    try:
+        response = requests.post(api_url, headers=headers, json=payload, stream=stream)
+        response.raise_for_status()
+    except requests.RequestException as e:
+        print(f"HTTP Request failed: {e}")
+        return None
+    return response
+
+def get_response(response: requests.Response) -> list:
+    """
+    Extract the generated texts from the body of the API response.
+
+    Args:
+        response (requests.Response): The response object from the API.
+
+    Returns:
+        list: The body's "text" field, or [] if the body is empty or malformed.
+    """
+    if not response.content:
+        print("Error: Received empty response from the API.")
+        return []
+
+    try:
+        return json.loads(response.content)["text"]
+    except (ValueError, KeyError, TypeError) as e:
+        # ValueError: not JSON; KeyError/TypeError: JSON without a "text" field.
+        print(f"Invalid API response ({type(e).__name__}): {e}")
+        print("Received response content:", response.content)
+        return []
+
+def process_annotation(annotation: str) -> dict:
+    """
+    Clean a raw model output and parse it into a JSON object.
+
+    Args:
+        annotation (str): Raw text; a list of texts is reduced to its first item.
+
+    Returns:
+        dict: The parsed annotation, or {} if it is not a valid JSON object.
+    """
+    if isinstance(annotation, (list, tuple)):
+        annotation = annotation[0] if annotation else ""
+    annotation = re.sub(r"\s+", " ", annotation.replace("\n", ""))
+    # Generation can stop before the final brace, so retry with one appended.
+    for candidate in (annotation, annotation + "}"):
+        try:
+            parsed = json.loads(candidate)
+        except json.JSONDecodeError:
+            continue
+        if isinstance(parsed, dict):
+            return parsed
+    print(f"PARSING ERROR: {annotation}}}")
+    return {}
+
+def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
+    """
+    Annotate the original and the generated dialogue of every row via the API.
+
+    Args:
+        df (pd.DataFrame): Needs the columns 'Complete Dialogue' and 'utterances_joined'.
+        api_url (str): The API endpoint URL for generating annotations.
+
+    Returns:
+        pd.DataFrame: The same DataFrame, with the annotations stored as JSON strings
+            in the columns 'original_annotation' and 'generated_annotation'.
+    """
+    column_pairs = (
+        ('Complete Dialogue', 'original_annotation'),
+        ('utterances_joined', 'generated_annotation'),
+    )
+    for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"):
+        for source, target in column_pairs:
+            prompt = llama_prompt_no_schema(row[source])
+            response = post_http_request(prompt, api_url, n=1, stream=False)
+            if not response:
+                continue  # request failed; this cell is left unset
+            # get_response returns a list of texts, but the parser expects one string.
+            texts = get_response(response)
+            df.at[index, target] = json.dumps(process_annotation(texts[0] if texts else ""))
+
+    return df
+
+def llama_prompt_no_schema(dialog: str) -> str:
+    """
+    Build the Llama-style [INST]/<<SYS>> prompt for a dialogue; the schema is not part of the text.
+
+    Args:
+        dialog (str): The dialogue to include in the prompt.
+
+    Returns:
+        str: The formatted prompt string.
+    """
+    return f"""
+    <s>[INST] <<SYS>>
+    You are a helpful annotator. You read the text carefully and annotate all valid fields in the schema.
+
+    Make sure to only annotate attractions like museums, clubs or other tourist attractions as such.
+
+    If you are not sure with an annotation you should annotate None instead.
+    <</SYS>>
+
+    {dialog} [/INST]
+    """
+
+def main(args):
+    """
+    Load the dialogue CSV, annotate every row and store the annotated CSV.
+
+    Args:
+        args: Parsed command line arguments (host, port, input_name, output_name).
+    """
+    endpoint = f"http://{args.host}:{args.port}/generate"
+    # Quoting is disabled for reading and writing, so '/' serves as the escape character.
+    csv_format = {"sep": ',', "quoting": csv.QUOTE_NONE, "escapechar": '/'}
+    dialogues = pd.read_csv(f"../../data/own_data/dialogues/{args.input_name}", **csv_format)
+    annotated = annotate_dialogues(dialogues, endpoint)
+    annotated.to_csv(f"../../data/own_data/dialogues/{args.output_name}", index=False, **csv_format)
+
+if __name__ == "__main__":
+    # Command line interface: generation server address and the CSV file names.
+    cli = argparse.ArgumentParser(description="Script to annotate dialogues.")
+    cli.add_argument("--host", default="localhost", type=str)
+    cli.add_argument("--port", default=8000, type=int)
+    cli.add_argument('--input_name', required=True, type=str, help='input dataframe name.')
+    cli.add_argument('--output_name', required=True, type=str, help='Output DataFrame name.')
+    main(cli.parse_args())
-- 
GitLab