Skip to content
Snippets Groups Projects
Commit cf32ef25 authored by finn's avatar finn
Browse files

Add script to generate annotations for dialogues

parent f9391e40
No related branches found
No related tags found
No related merge requests found
import json
import argparse
import requests
import pandas as pd
import csv
import re
from tqdm import tqdm
from schema import MultiWOZ
def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = False) -> requests.Response:
"""
Send a POST request to the specified API URL with the given prompt.
Args:
prompt (str): The prompt to send in the request.
api_url (str): The API endpoint URL.
n (int): Number of responses to generate.
stream (bool): Whether to stream the response.
Returns:
requests.Response: The response from the API.
"""
headers = {"User-Agent": "Test Client"}
payload = {
"prompt": prompt,
"n": n,
"use_beam_search": False,
"temperature": 0.0,
"schema": MultiWOZ.model_json_schema(),
"max_tokens": 4000,
"stream": stream,
}
try:
response = requests.post(api_url, headers=headers, json=payload, stream=True)
response.raise_for_status()
return response
except requests.RequestException as e:
print(f"HTTP Request failed: {e}")
return None
def get_response(response: requests.Response) -> list:
"""
Extract the response text from the API response.
Args:
response (requests.Response): The response object from the API.
Returns:
list: A list of response texts.
"""
if not response.content:
print("Error: Received empty response from the API.")
return []
try:
data = json.loads(response.content)
return data["text"]
except json.JSONDecodeError as e:
print(f"JSON Decode Error: {e}")
print("Received response content:", response.content)
return []
def process_annotation(annotation: str) -> dict:
"""
Processes and cleans up the annotation string to convert it into a valid JSON object.
Args:
annotation (str): The annotation string to process.
Returns:
dict: A JSON object representing the annotation.
"""
annotation = annotation.replace("\n", "")
annotation = re.sub(r"\s+", " ", annotation)
try:
annotation_json = json.loads(annotation)
except json.JSONDecodeError:
annotation += "}"
try:
annotation_json = json.loads(annotation)
except json.JSONDecodeError:
print(f"PARSING ERROR: {annotation}")
return {}
return annotation_json
def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
"""
Annotates the dialogues in the DataFrame using the specified API and schema.
Args:
df (pd.DataFrame): The DataFrame containing the dialogues.
api_url (str): The API endpoint URL for generating annotations.
schema (dict): The schema used for annotation.
Returns:
pd.DataFrame: The DataFrame with added annotations.
"""
for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"):
prompt_original = llama_prompt_no_schema(row['Complete Dialogue'])
response_original = post_http_request(prompt_original, api_url, n=1, stream=False)
prompt_generated = llama_prompt_no_schema(row['utterances_joined'])
response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False)
if response_original:
output_original = get_response(response_original)
annotation_original = process_annotation(output_original)
df.at[index, 'original_annotation'] = json.dumps(annotation_original)
if response_generated:
output_generated = get_response(response_generated)
annotation_generated = process_annotation(output_generated)
df.at[index, 'generated_annotation'] = json.dumps(annotation_generated)
return df
def llama_prompt_no_schema(dialog: str) -> str:
"""
Creates a prompt for the llama model without using a schema.
Args:
dialog (str): The dialogue to include in the prompt.
Returns:
str: The formatted prompt string.
"""
return f"""
<s>[INST] <<SYS>>
You are a helpful annotator. You read the text carefully and annotate all valid feels in the schema.
Make sure to only annotate attractions like museums, clubs or other tourist attractions as such.
If you are not sure with an annotation you should annotate None instead.
<</SYS>>
{dialog} [/INST]
"""
def main(args):
"""
Main function to execute the script.
Args:
args: Command line arguments.
"""
api_url = f"http://{args.host}:{args.port}/generate"
input_file = f"../../data/own_data/dialogues/{args.input_name}"
df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/')
df = annotate_dialogues(df, api_url)
output_file = f"../../data/own_data/dialogues/{args.output_name}"
df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Script to annotate dialogues.")
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument('--input_name', type=str, required=True, help='input dataframe name.')
parser.add_argument('--output_name', type=str, required=True, help='Output DataFrame name.')
args = parser.parse_args()
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment