From cf32ef25afefbf1b1afef80e1896b60e85ebcd85 Mon Sep 17 00:00:00 2001 From: finn <finn@hillengass.de> Date: Tue, 27 Feb 2024 19:33:53 +0100 Subject: [PATCH] Add script to generate annotations for dialogues --- src/one-shot/generate_annotations.py | 158 +++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 src/one-shot/generate_annotations.py diff --git a/src/one-shot/generate_annotations.py b/src/one-shot/generate_annotations.py new file mode 100644 index 0000000..be5250f --- /dev/null +++ b/src/one-shot/generate_annotations.py @@ -0,0 +1,158 @@ +import json +import argparse +import requests +import pandas as pd +import csv +import re +from tqdm import tqdm +from schema import MultiWOZ + +def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = False) -> requests.Response: + """ + Send a POST request to the specified API URL with the given prompt. + + Args: + prompt (str): The prompt to send in the request. + api_url (str): The API endpoint URL. + n (int): Number of responses to generate. + stream (bool): Whether to stream the response. + + Returns: + requests.Response: The response from the API. + """ + headers = {"User-Agent": "Test Client"} + payload = { + "prompt": prompt, + "n": n, + "use_beam_search": False, + "temperature": 0.0, + "schema": MultiWOZ.model_json_schema(), + "max_tokens": 4000, + "stream": stream, + } + try: + response = requests.post(api_url, headers=headers, json=payload, stream=True) + response.raise_for_status() + return response + except requests.RequestException as e: + print(f"HTTP Request failed: {e}") + return None + +def get_response(response: requests.Response) -> list: + """ + Extract the response text from the API response. + + Args: + response (requests.Response): The response object from the API. + + Returns: + list: A list of response texts. + """ + if not response.content: + print("Error: Received empty response from the API.") + return [] + + try: + data = json.loads(response.content) + return data["text"] + except json.JSONDecodeError as e: + print(f"JSON Decode Error: {e}") + print("Received response content:", response.content) + return [] + +def process_annotation(annotation: str) -> dict: + """ + Processes and cleans up the annotation string to convert it into a valid JSON object. + + Args: + annotation (str): The annotation string to process. + + Returns: + dict: A JSON object representing the annotation. + """ + annotation = annotation.replace("\n", "") + annotation = re.sub(r"\s+", " ", annotation) + try: + annotation_json = json.loads(annotation) + except json.JSONDecodeError: + annotation += "}" + try: + annotation_json = json.loads(annotation) + except json.JSONDecodeError: + print(f"PARSING ERROR: {annotation}") + return {} + + return annotation_json + +def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame: + """ + Annotates the dialogues in the DataFrame using the specified API and schema. + + Args: + df (pd.DataFrame): The DataFrame containing the dialogues. + api_url (str): The API endpoint URL for generating annotations. + schema (dict): The schema used for annotation. + + Returns: + pd.DataFrame: The DataFrame with added annotations. + """ + for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"): + prompt_original = llama_prompt_no_schema(row['Complete Dialogue']) + response_original = post_http_request(prompt_original, api_url, n=1, stream=False) + prompt_generated = llama_prompt_no_schema(row['utterances_joined']) + response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False) + if response_original: + output_original = get_response(response_original) + annotation_original = process_annotation(output_original) + df.at[index, 'original_annotation'] = json.dumps(annotation_original) + if response_generated: + output_generated = get_response(response_generated) + annotation_generated = process_annotation(output_generated) + df.at[index, 'generated_annotation'] = json.dumps(annotation_generated) + + return df + +def llama_prompt_no_schema(dialog: str) -> str: + """ + Creates a prompt for the llama model without using a schema. + + Args: + dialog (str): The dialogue to include in the prompt. + + Returns: + str: The formatted prompt string. + """ + return f""" + <s>[INST] <<SYS>> + You are a helpful annotator. You read the text carefully and annotate all valid feels in the schema. + + Make sure to only annotate attractions like museums, clubs or other tourist attractions as such. + + If you are not sure with an annotation you should annotate None instead. + <</SYS>> + + {dialog} [/INST] + """ + +def main(args): + """ + Main function to execute the script. + + Args: + args: Command line arguments. + """ + api_url = f"http://{args.host}:{args.port}/generate" + input_file = f"../../data/own_data/dialogues/{args.input_name}" + df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/') + df = annotate_dialogues(df, api_url) + output_file = f"../../data/own_data/dialogues/{args.output_name}" + df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Script to annotate dialogues.") + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument('--input_name', type=str, required=True, help='input dataframe name.') + parser.add_argument('--output_name', type=str, required=True, help='Output DataFrame name.') + args = parser.parse_args() + main(args) -- GitLab