"""
Generate dialogues from domain knowledge and save them in a structured format.

The script takes the number of dialogues to generate as input and looks up the
matching domain knowledge file. If that file does not exist, the script exits
with a message asking the user to run the separate domain knowledge extraction
script first. For every domain knowledge entry a dialogue is requested from
the generation API; the dialogues are then processed and saved in CSV format.
"""

import argparse
import csv
import json
import os
import re
from collections.abc import Iterator
from typing import TYPE_CHECKING, Optional

import pandas as pd

if TYPE_CHECKING:
    # Needed for annotations only. The runtime import lives inside
    # post_http_request() so that the parsing helpers in this module stay
    # importable when the HTTP stack is not installed.
    import requests

# Matches lines that start with a speaker tag ("User:" / "System:"),
# case-insensitively. Compiled once because it is applied to every line.
_SPEAKER_LINE = re.compile(r'(?i)^(User|System):')


def clear_line(n: int = 1) -> None:
    """
    Clears the specified number of lines from the console.

    Args:
        n (int): Number of lines to clear.
    """
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for _ in range(n):
        # Move the cursor up one line, then erase that line.
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def clean_dialogue(dialogue_str: str) -> list[str]:
    """
    Cleans a dialogue string by parsing JSON and removing extra quotes.

    The parsed JSON list may contain utterances directly or nested lists of
    utterances; nested lists are flattened by one level.

    Args:
        dialogue_str (str): A string representing the dialogue in JSON format.

    Returns:
        list[str]: A list of cleaned dialogue utterances.
    """
    cleaned_dialogue = []
    for entry in json.loads(dialogue_str):
        utterances = entry if isinstance(entry, list) else [entry]
        for utterance in utterances:
            # Collapse doubled quotes, then drop quotes wrapping the utterance.
            cleaned_dialogue.append(utterance.replace('""', '"').strip('"'))
    return cleaned_dialogue


def process_dialogues(dialogues: list[str]) -> list[dict]:
    """
    Processes a list of dialogues by separating and joining utterances.

    Only lines that start with "User:" or "System:" (any letter case) are
    kept; every other line of the generated text is discarded.

    Args:
        dialogues (list[str]): A list of dialogue strings.

    Returns:
        list[dict]: One dictionary per dialogue with the keys
        "utterances_separated" (list of kept lines) and "utterances_joined"
        (the same lines joined with single spaces).
    """
    processed_dialogues = []
    for dialogue in dialogues:
        utterances = [
            line.replace('""', '"')
            for line in dialogue.split('\n')
            if _SPEAKER_LINE.match(line)
        ]
        processed_dialogues.append({
            "utterances_separated": utterances,
            "utterances_joined": ' '.join(utterances)
        })

    return processed_dialogues


def post_http_request(
    prompt: str,
    api_url: str,
    n: int = 1,
    stream: bool = False,
    timeout: Optional[float] = None,
) -> Optional["requests.Response"]:
    """
    Sends a POST HTTP request to the specified URL with given parameters.

    Args:
        prompt (str): The prompt to send in the request.
        api_url (str): The URL to which the request is sent.
        n (int): Number of responses to request.
        stream (bool): Whether the response should be streamed. The flag is
            sent to the server in the payload and also controls whether the
            HTTP client defers downloading the response body.
        timeout (Optional[float]): Seconds to wait for the server before
            giving up. ``None`` (the default) waits indefinitely.

    Returns:
        Optional[requests.Response]: The HTTP response, or ``None`` when the
        request failed (the error is printed).
    """
    # Imported here so the rest of the module can be used without `requests`.
    import requests

    headers = {"User-Agent": "Test Client"}
    payload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": False,
        "temperature": 0.0,
        "max_tokens": 4000,
        "stream": stream,
    }
    try:
        response = requests.post(
            api_url,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=timeout,
        )
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"HTTP Request failed: {e}")
        return None
    return response


def get_streaming_response(response: "requests.Response") -> Iterator[list[str]]:
    """
    Yields the decoded "text" field of every chunk of a streamed response.

    Chunks are separated by NUL bytes and each non-empty chunk is parsed as a
    JSON object.

    Args:
        response (requests.Response): The HTTP response to process.

    Yields:
        list[str]: The "text" value of each chunk (presumably a list of
        strings, one per requested completion - confirm against the server).
    """
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            yield data["text"]


def get_last_streaming_response(response: "requests.Response") -> list[str]:
    """
    Consumes a streamed response and returns its final chunk.

    NOTE(review): this assumes every chunk carries the complete text generated
    so far, so the last chunk is the finished output. If the server streams
    incremental deltas instead, the chunks must be concatenated - confirm
    against the /generate-dialogue implementation.

    Args:
        response (requests.Response): The streamed HTTP response.

    Returns:
        list[str]: The "text" value of the last chunk, or an empty list when
        the stream contained no chunks.
    """
    last_chunk: list[str] = []
    for last_chunk in get_streaming_response(response):
        pass
    # Tolerate a server that sends a bare string instead of a list.
    return [last_chunk] if isinstance(last_chunk, str) else last_chunk


def get_response(response: Optional["requests.Response"]) -> list[str]:
    """
    Retrieves the response content from an HTTP response as a list of strings.

    Args:
        response (Optional[requests.Response]): The HTTP response to process,
            or ``None`` when the request failed.

    Returns:
        list[str]: The "text" field of the JSON body. An empty list is
        returned (and a message printed) when there is no response, the body
        is empty, the body is not valid JSON, or it has no usable "text"
        field.
    """
    if response is None or not response.content:
        print("Error: Received empty response from the API.")
        return []

    try:
        data = json.loads(response.content)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"JSON Decode Error: {e}")
        print("Received response content:", response.content)
        return []

    text = data.get("text") if isinstance(data, dict) else None
    if isinstance(text, list):
        return text
    if isinstance(text, str):
        # Callers iterate over the result, so wrap a single string.
        return [text]
    print('Error: API response does not contain a usable "text" field.')
    print("Received response content:", response.content)
    return []


def load_domain_knowledge(file_path: str) -> list[dict]:
    """
    Loads domain knowledge from a JSON file.

    Args:
        file_path (str): Path to the JSON file containing domain knowledge.

    Returns:
        list[dict]: The parsed JSON content. main() iterates over it and reads
        the keys "Service Information", "Dialogue ID" and "Complete Dialogue"
        from every entry.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def save_dialogues_to_dataframe(dialogues: list[dict], file_path: str) -> None:
    """
    Saves dialogues to a DataFrame and then to a CSV file.

    The parent directory of ``file_path`` is created when it does not exist.
    The CSV is written without quoting; special characters are escaped with
    "/" instead, so it must be read back with the same settings.

    Args:
        dialogues (list[dict]): A list of dialogues to save. Every entry must
            provide the keys "Dialogue ID", "Complete Dialogue", "prompt",
            "domain_knowledge" and "processed_dialogue".
        file_path (str): Path to the CSV file where dialogues will be saved.
    """
    data = []
    for dialogue in dialogues:
        processed_dialogue = dialogue["processed_dialogue"]
        data.append({
            "Dialogue ID": dialogue["Dialogue ID"],
            "Original Dialogue": " ".join(dialogue["Complete Dialogue"]),
            "Prompt": dialogue["prompt"].replace('"', ''),
            "Domain Knowledge": json.dumps(dialogue["domain_knowledge"]),
            "utterances_separated": processed_dialogue["utterances_separated"],
            "utterances_joined": processed_dialogue["utterances_joined"]
        })

    output_dir = os.path.dirname(file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    df = pd.DataFrame(data)
    df.to_csv(file_path, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar="/")


def build_prompt(service_information) -> str:
    """
    Builds the generation prompt for one domain knowledge entry.

    Args:
        service_information: JSON-serialisable service information that is
            embedded in the prompt as the "Domain Knowledge".

    Returns:
        str: The complete prompt, wrapped in [INST] / <<SYS>> markers.
    """
    # NOTE: the whitespace inside this literal is part of the prompt that is
    # sent to the model, so do not re-indent it casually.
    # NOTE(review): "bewteen" is a typo for "between"; it is left untouched so
    # that new generations stay comparable with earlier runs.
    return f""" [INST] <<SYS>>
        Generate a complete dialogue bewteen a user and a system. The dialogue should consist solely of exchanges between the user and the system, with no action descriptions or external narrative.
        The user will request specific services typical of an online service system. The system must ask the user for necessary details and provide the requested service, ensuring that all responses are based on the provided Domain Knowledge. However, do not explicitly mention 'Domain Knowledge' in the dialogue.
        Each user message must be followed by a system response. The system's tone should be polite and courteous, ending the conversation with a goodbye. Both participants should remain professional and not overly enthusiastic.
        It's crucial that the system's statements are precise and complete, avoiding placeholders like '[insert information here]'. Instead, provide direct and specific information based on the given Domain Knowledge.
        Create a complete dialogue between the System and the User, focusing exclusively on their interaction, nothing else should be generated.
        Domain Knowledge: {json.dumps(service_information)}
        \n<</SYS>>\n\n [/INST]"""


def main() -> None:
    """
    Parses the command line, generates one dialogue per domain knowledge
    entry and writes all successfully generated dialogues to a CSV file.
    """
    # Imported here so the helpers above can be used without `tqdm`.
    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--num_dialogues", type=int, required=True, help="Number of dialogues to generate (Domain Knowledge should be available for this number of dialogues)")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()

    domain_file = f'../../data/own_data/one-shot/domain_knowledge/domain_knowledge_{args.num_dialogues}.json'
    output_file = f'../../data/own_data/one-shot/dialogues/dialogues/output_dialogues_{args.num_dialogues}.csv'

    # Stop early when the domain knowledge is missing instead of failing
    # later with a traceback while opening the file.
    if not os.path.exists(domain_file):
        raise SystemExit(
            f"File does not exist: {domain_file}. "
            "Please run the domain knowledge extraction script first."
        )

    domain_knowledge = load_domain_knowledge(domain_file)

    # Set up the API endpoint
    api_url = f"http://{args.host}:{args.port}/generate-dialogue"

    dialogues = []
    for domain in tqdm(domain_knowledge, desc="Generating dialogues"):
        prompt = build_prompt(domain['Service Information'])
        response = post_http_request(prompt, api_url, 1, args.stream)
        if response is None:
            # post_http_request already printed the reason.
            tqdm.write(f"Skipping dialogue {domain['Dialogue ID']}: request failed.")
            continue

        if args.stream:
            output = get_last_streaming_response(response)
        else:
            output = get_response(response)

        # Remove prompt from response if it's included
        output_without_prompt = [line.replace(prompt, '') for line in output]
        processed_output = process_dialogues(output_without_prompt)
        if not processed_output:
            tqdm.write(f"Skipping dialogue {domain['Dialogue ID']}: no output received.")
            continue

        dialogues.append({
            "prompt": prompt,
            "domain_knowledge": domain['Service Information'],
            # Only one completion is requested (n=1), so use the first one.
            "processed_dialogue": processed_output[0],
            "Dialogue ID": domain['Dialogue ID'],
            "Complete Dialogue": domain['Complete Dialogue']
        })

    if not dialogues:
        print("No dialogues were generated; nothing to save.")
        return

    save_dialogues_to_dataframe(dialogues, output_file)


if __name__ == "__main__":
    main()