diff --git a/src/one-shot/generate_annotations.py b/src/one-shot/generate_annotations.py index 3746412b5b741045a509a53e5d608d78daeb5913..62e50376c3cbece2ba5829bbf89de17e1e3d8c1a 100644 --- a/src/one-shot/generate_annotations.py +++ b/src/one-shot/generate_annotations.py @@ -6,6 +6,7 @@ import csv import re from tqdm import tqdm from schema import MultiWOZ +from outlines import models, prompt def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = False) -> requests.Response: """ @@ -20,33 +21,32 @@ def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = Fals Returns: requests.Response: The response from the API. """ - headers = {"User-Agent": "Test Client"} + #headers = {"User-Agent": "Test Client"} payload = { "prompt": prompt, "n": n, - "use_beam_search": False, - "temperature": 0.0, "schema": MultiWOZ.model_json_schema(), - "max_tokens": 4000, + "max_tokens": 1200, "stream": stream, } try: - response = requests.post(api_url, headers=headers, json=payload, stream=True) + response = requests.post(api_url, json=payload) response.raise_for_status() return response except requests.RequestException as e: print(f"HTTP Request failed: {e}") return None -def get_response(response: requests.Response) -> list: +def get_response(response: requests.Response, prompt: str) -> list: """ - Extract the response text from the API response. + Extract the response text from the API response and remove the prompt if present. Args: response (requests.Response): The response object from the API. + prompt (str): The prompt that was sent in the request. Returns: - list: A list of response texts. + list: A list of response texts without the prompt. """ if not response.content: print("Error: Received empty response from the API.") @@ -54,35 +54,60 @@ def get_response(response: requests.Response) -> list: try: data = json.loads(response.content) - return data["text"] + texts = data["text"] + cleaned_texts = [] + for text in texts: + # Remove prompt from beginning of response if present + if text.startswith(prompt): + text = text[len(prompt):] + cleaned_texts.append(text) + return cleaned_texts except json.JSONDecodeError as e: print(f"JSON Decode Error: {e}") print("Received response content:", response.content) return [] + def process_annotation(annotation: str) -> dict: """ Processes and cleans up the annotation string to convert it into a valid JSON object. + In case of a JSON parsing error, tries to fix by removing the last element. Args: annotation (str): The annotation string to process. Returns: - dict: A JSON object representing the annotation. + dict: A JSON object representing the annotation, or an empty dict if parsing fails. """ - annotation = annotation.replace("\n", "") + + annotation = annotation[0] + # Basic cleanup + annotation = annotation.strip().replace("\n", "") annotation = re.sub(r"\s+", " ", annotation) + + # Handling potentially unclosed JSON and removing trailing commas + if not annotation.endswith('}'): + annotation += "}" + annotation = re.sub(r",\s*}", "}", annotation) + annotation = re.sub(r",\s*\]", "]", annotation) + + # Attempt JSON parsing try: annotation_json = json.loads(annotation) - except json.JSONDecodeError: - annotation += "}" + return annotation_json + except json.JSONDecodeError as e: + # Try to remove the last element and re-parse try: - annotation_json = json.loads(annotation) + last_comma_index = annotation.rfind(',') + if last_comma_index != -1: + fixed_annotation = annotation[:last_comma_index] + "}" + annotation_json = json.loads(fixed_annotation) + return annotation_json except json.JSONDecodeError: - print(f"PARSING ERROR: {annotation}") + # Log the error for debugging + print(f"JSON PARSING ERROR: {e.msg} in {e.doc} at position {e.pos}") return {} - return annotation_json def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame: """ @@ -96,42 +121,44 @@ def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame: Returns: pd.DataFrame: The DataFrame with added annotations. """ + if 'original_annotation' not in df.columns: + df['original_annotation'] = pd.Series(dtype='object') + if 'generated_annotation' not in df.columns: + df['generated_annotation'] = pd.Series(dtype='object') + for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"): prompt_original = llama_prompt_no_schema(row['Original Dialogue']) response_original = post_http_request(prompt_original, api_url, n=1, stream=False) - prompt_generated = llama_prompt_no_schema(row['utterances_joined']) - response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False) if response_original: - output_original = get_response(response_original) + output_original = get_response(response_original, prompt_original) # Pass prompt annotation_original = process_annotation(output_original) + print(json.dumps(annotation_original)) df.at[index, 'original_annotation'] = json.dumps(annotation_original) + print("original done") + prompt_generated = llama_prompt_no_schema(row['utterances_joined']) + response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False) if response_generated: - output_generated = get_response(response_generated) + output_generated = get_response(response_generated, prompt_generated) annotation_generated = process_annotation(output_generated) + print(json.dumps(annotation_generated)) df.at[index, 'generated_annotation'] = json.dumps(annotation_generated) + print("generated done") return df - -def llama_prompt_no_schema(dialog: str) -> str: - """ - Creates a prompt for the llama model without using a schema. - - Args: - dialog (str): The dialogue to include in the prompt. - - Returns: - str: The formatted prompt string. - """ - return f""" - <s>[INST] <<SYS>> - You are a helpful annotator. You read the text carefully and annotate all valid feels in the schema. - - Make sure to only annotate attractions like museums, clubs or other tourist attractions as such. - - If you are not sure with an annotation you should annotate None instead. - <</SYS>> - - {dialog} [/INST] + +### Prompt +@prompt +def llama_prompt_no_schema(dialog): + """ + [INST] <<SYS>> + Your task is to fill out the schema with the information from the dialogue. + Choose the correct value based on the Dialogue content for each slot. + Make sure to only annotate attractions like museums, clubs or other tourist attractions as "attractions". + + If you are not sure about an annotation you should annotate "None". + \n<</SYS>>\n\n + + {{dialog}} [/INST] """ def main(args): @@ -141,11 +168,12 @@ def main(args): Args: args: Command line arguments. """ + schema = MultiWOZ.model_json_schema() api_url = f"http://{args.host}:{args.port}/generate" input_file = f"../../data/own_data/one-shot/dialogues/{args.input_name}" df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/') df = annotate_dialogues(df, api_url) - output_file = f"../../data/own_data/one-shot/dialogues/{args.output_name}" + output_file = f"../../data/own_data/one-shot/annotations/{args.output_name}" df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/') if __name__ == "__main__": @@ -155,4 +183,5 @@ if __name__ == "__main__": parser.add_argument('--input_name', type=str, required=True, help='input dataframe name.') parser.add_argument('--output_name', type=str, required=True, help='Output DataFrame name.') args = parser.parse_args() + main(args)