diff --git a/elise/elise/mwoz_annotation.py b/elise/elise/mwoz_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..8c10297c42ab215c588a71d14f3c85459a77d3d8 --- /dev/null +++ b/elise/elise/mwoz_annotation.py @@ -0,0 +1,70 @@ +from enum import Enum +from outlines import models, prompt +from outlines.fsm import json_schema +from pydantic import BaseModel +from typing import Union, Optional +from typing_extensions import List, Literal +import json +import requests +import re + +from .schema import MultiWOZ + + +### Prompt +@prompt +def llama_prompt_no_schema(dialog): + """ +<s>[INST] <<SYS>> +You are a helpful annotator. You read the text carefully and annotate all valid feels in the schema. + +Make sure to only annotate attractions like museums, clubs or other tourist attractions as such. + +If you are not sure with an annotation you should annotate None instead. +<</SYS>> + +{{dialog}} [/INST] +""" + +### Requests +def main(): + # Read dialogue data + with open("elise/output_dialogues_50.json", "r") as file: + data = json.load(file) + dialogues = [d["dialogue"][0] for d in data] + + # Request annotations + for dia in dialogues: + prompt = llama_prompt_no_schema(dia) + + # Send request to vLLM server + response = requests.post( + "http://localhost:8000/generate", + json = { + "prompt": prompt, + "schema": MultiWOZ.model_json_schema(), + "max_tokens": 1024, # Find reasonable limit + "n": 1 + } + ) + + for reply in response.json()["text"]: + annotation = reply.split("[/INST]")[1] + # Cleanup the whitespace by some erroneous generations + annotation = annotation.replace("\n", "") + annotation = re.sub(r"\s+", " ", annotation) + with open("output_annotations.txt", "a") as file: + # Catch annotation errors like invalid json. + # Most of the time only a closing bracket is missing. + try: + annotation_json = json.loads(annotation) + file.write(json.dumps(annotation_json)) + file.write("\n") + except: + annotation = annotation + "}" + try: + annotation_json = json.loads(annotation) + file.write(json.dumps(annotation_json)) + file.write("\n") + except: + file.write(f"PARSING ERROR: {annotation}\n") \ No newline at end of file