diff --git a/src/one-shot/generate_annotations.py b/src/one-shot/generate_annotations.py index 62e50376c3cbece2ba5829bbf89de17e1e3d8c1a..dfe673a162eee7f6f47010839c0fc9194e1dd07f 100644 --- a/src/one-shot/generate_annotations.py +++ b/src/one-shot/generate_annotations.py @@ -109,7 +109,7 @@ def process_annotation(annotation: str) -> dict: return {} -def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame: +def annotate_dialogues(df: pd.DataFrame, api_url: str, output_file: str) -> pd.DataFrame: """ Annotates the dialogues in the DataFrame using the specified API and schema. @@ -125,6 +125,8 @@ def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame: df['original_annotation'] = pd.Series(dtype='object') if 'generated_annotation' not in df.columns: df['generated_annotation'] = pd.Series(dtype='object') + + iteration_counter = 0 # Counter for iterations for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"): prompt_original = llama_prompt_no_schema(row['Original Dialogue']) @@ -132,18 +134,31 @@ def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame: if response_original: output_original = get_response(response_original, prompt_original) # Pass prompt annotation_original = process_annotation(output_original) - print(json.dumps(annotation_original)) + #print(json.dumps(annotation_original)) df.at[index, 'original_annotation'] = json.dumps(annotation_original) - print("original done") + #print("original done") prompt_generated = llama_prompt_no_schema(row['utterances_joined']) response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False) if response_generated: output_generated = get_response(response_generated, prompt_generated) annotation_generated = process_annotation(output_generated) - print(json.dumps(annotation_generated)) + #print(json.dumps(annotation_generated)) df.at[index, 'generated_annotation'] = json.dumps(annotation_generated) - print("generated done") + #print("generated done") + + + # Increment the counter + iteration_counter += 1 + + # Check if the counter has reached 50 iterations + if iteration_counter % 50 == 0: + # Save the DataFrame + df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/') + print(f"Checkpoint: Saved after {iteration_counter} iterations.") + + # Save the DataFrame at the end of the process + df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/') return df ### Prompt @@ -172,9 +187,9 @@ def main(args): api_url = f"http://{args.host}:{args.port}/generate" input_file = f"../../data/own_data/one-shot/dialogues/{args.input_name}" df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/') - df = annotate_dialogues(df, api_url) output_file = f"../../data/own_data/one-shot/annotations/{args.output_name}" - df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/') + df = annotate_dialogues(df, api_url, output_file) + #df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/') if __name__ == "__main__": parser = argparse.ArgumentParser(description="Script to annotate dialogues.")