diff --git a/src/one-shot/generate_annotations.py b/src/one-shot/generate_annotations.py
index 3746412b5b741045a509a53e5d608d78daeb5913..62e50376c3cbece2ba5829bbf89de17e1e3d8c1a 100644
--- a/src/one-shot/generate_annotations.py
+++ b/src/one-shot/generate_annotations.py
@@ -6,6 +6,7 @@ import csv
import re
from tqdm import tqdm
from schema import MultiWOZ
+from outlines import models, prompt
def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = False) -> requests.Response:
"""
@@ -20,33 +21,32 @@ def post_http_request(prompt: str, api_url: str, n: int = 1, stream: bool = Fals
Returns:
requests.Response: The response from the API.
"""
- headers = {"User-Agent": "Test Client"}
+ #headers = {"User-Agent": "Test Client"}
payload = {
"prompt": prompt,
"n": n,
- "use_beam_search": False,
- "temperature": 0.0,
"schema": MultiWOZ.model_json_schema(),
- "max_tokens": 4000,
+ "max_tokens": 1200,
"stream": stream,
}
try:
- response = requests.post(api_url, headers=headers, json=payload, stream=True)
+ response = requests.post(api_url, json=payload)
response.raise_for_status()
return response
except requests.RequestException as e:
print(f"HTTP Request failed: {e}")
return None
-def get_response(response: requests.Response) -> list:
+def get_response(response: requests.Response, prompt: str) -> list:
"""
- Extract the response text from the API response.
+ Extract the response text from the API response and remove the prompt if present.
Args:
response (requests.Response): The response object from the API.
+ prompt (str): The prompt that was sent in the request.
Returns:
- list: A list of response texts.
+ list: A list of response texts without the prompt.
"""
if not response.content:
print("Error: Received empty response from the API.")
@@ -54,35 +54,60 @@ def get_response(response: requests.Response) -> list:
try:
data = json.loads(response.content)
- return data["text"]
+ texts = data["text"]
+ cleaned_texts = []
+ for text in texts:
+ # Remove prompt from beginning of response if present
+ if text.startswith(prompt):
+ text = text[len(prompt):]
+ cleaned_texts.append(text)
+ return cleaned_texts
except json.JSONDecodeError as e:
print(f"JSON Decode Error: {e}")
print("Received response content:", response.content)
return []
+
def process_annotation(annotation: str) -> dict:
"""
Processes and cleans up the annotation string to convert it into a valid JSON object.
+ In case of a JSON parsing error, tries to fix by removing the last element.
Args:
annotation (str): The annotation string to process.
Returns:
- dict: A JSON object representing the annotation.
+ dict: A JSON object representing the annotation, or an empty dict if parsing fails.
"""
- annotation = annotation.replace("\n", "")
+
+ annotation = annotation[0]
+ # Basic cleanup
+ annotation = annotation.strip().replace("\n", "")
annotation = re.sub(r"\s+", " ", annotation)
+
+ # Handling potentially unclosed JSON and removing trailing commas
+ if not annotation.endswith('}'):
+ annotation += "}"
+ annotation = re.sub(r",\s*}", "}", annotation)
+ annotation = re.sub(r",\s*\]", "]", annotation)
+
+ # Attempt JSON parsing
try:
annotation_json = json.loads(annotation)
- except json.JSONDecodeError:
- annotation += "}"
+ return annotation_json
+ except json.JSONDecodeError as e:
+ # Try to remove the last element and re-parse
try:
- annotation_json = json.loads(annotation)
+ last_comma_index = annotation.rfind(',')
+ if last_comma_index != -1:
+ fixed_annotation = annotation[:last_comma_index] + "}"
+ annotation_json = json.loads(fixed_annotation)
+ return annotation_json
except json.JSONDecodeError:
- print(f"PARSING ERROR: {annotation}")
+ # Log the error for debugging
+ print(f"JSON PARSING ERROR: {e.msg} in {e.doc} at position {e.pos}")
return {}
- return annotation_json
def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
"""
@@ -96,42 +121,44 @@ def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
Returns:
pd.DataFrame: The DataFrame with added annotations.
"""
+ if 'original_annotation' not in df.columns:
+ df['original_annotation'] = pd.Series(dtype='object')
+ if 'generated_annotation' not in df.columns:
+ df['generated_annotation'] = pd.Series(dtype='object')
+
for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"):
prompt_original = llama_prompt_no_schema(row['Original Dialogue'])
response_original = post_http_request(prompt_original, api_url, n=1, stream=False)
- prompt_generated = llama_prompt_no_schema(row['utterances_joined'])
- response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False)
if response_original:
- output_original = get_response(response_original)
+ output_original = get_response(response_original, prompt_original) # Pass prompt
annotation_original = process_annotation(output_original)
+ print(json.dumps(annotation_original))
df.at[index, 'original_annotation'] = json.dumps(annotation_original)
+ print("original done")
+ prompt_generated = llama_prompt_no_schema(row['utterances_joined'])
+ response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False)
if response_generated:
- output_generated = get_response(response_generated)
+ output_generated = get_response(response_generated, prompt_generated)
annotation_generated = process_annotation(output_generated)
+ print(json.dumps(annotation_generated))
df.at[index, 'generated_annotation'] = json.dumps(annotation_generated)
+ print("generated done")
return df
-
-def llama_prompt_no_schema(dialog: str) -> str:
- """
- Creates a prompt for the llama model without using a schema.
-
- Args:
- dialog (str): The dialogue to include in the prompt.
-
- Returns:
- str: The formatted prompt string.
- """
- return f"""
- <s>[INST] <<SYS>>
- You are a helpful annotator. You read the text carefully and annotate all valid feels in the schema.
-
- Make sure to only annotate attractions like museums, clubs or other tourist attractions as such.
-
- If you are not sure with an annotation you should annotate None instead.
- <</SYS>>
-
- {dialog} [/INST]
+
+### Prompt
+@prompt
+def llama_prompt_no_schema(dialog):
+ """
+ [INST] <<SYS>>
+ Your task is to fill out the schema with the information from the dialogue.
+ Choose the correct value based on the Dialogue content for each slot.
+ Make sure to only annotate attractions like museums, clubs or other tourist attractions as "attractions".
+
+ If you are not sure about an annotation you should annotate "None".
+ \n<</SYS>>\n\n
+
+ {{dialog}} [/INST]
"""
def main(args):
@@ -141,11 +168,12 @@ def main(args):
Args:
args: Command line arguments.
"""
+ schema = MultiWOZ.model_json_schema()
api_url = f"http://{args.host}:{args.port}/generate"
input_file = f"../../data/own_data/one-shot/dialogues/{args.input_name}"
df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/')
df = annotate_dialogues(df, api_url)
- output_file = f"../../data/own_data/one-shot/dialogues/{args.output_name}"
+ output_file = f"../../data/own_data/one-shot/annotations/{args.output_name}"
df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/')
if __name__ == "__main__":
@@ -155,4 +183,5 @@ if __name__ == "__main__":
parser.add_argument('--input_name', type=str, required=True, help='input dataframe name.')
parser.add_argument('--output_name', type=str, required=True, help='Output DataFrame name.')
args = parser.parse_args()
+
main(args)