Skip to content
Snippets Groups Projects
Commit a8666ba1 authored by finn
Browse files

Add checkpoint saving

parent 43cb8fd2
No related branches found
No related tags found
No related merge requests found
......@@ -109,7 +109,7 @@ def process_annotation(annotation: str) -> dict:
return {}
def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
def annotate_dialogues(df: pd.DataFrame, api_url: str, output_file: str) -> pd.DataFrame:
"""
Annotates the dialogues in the DataFrame using the specified API and schema.
......@@ -125,6 +125,8 @@ def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
df['original_annotation'] = pd.Series(dtype='object')
if 'generated_annotation' not in df.columns:
df['generated_annotation'] = pd.Series(dtype='object')
iteration_counter = 0 # Counter for iterations
for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating Dialogues"):
prompt_original = llama_prompt_no_schema(row['Original Dialogue'])
......@@ -132,18 +134,31 @@ def annotate_dialogues(df: pd.DataFrame, api_url: str) -> pd.DataFrame:
if response_original:
output_original = get_response(response_original, prompt_original) # Pass prompt
annotation_original = process_annotation(output_original)
print(json.dumps(annotation_original))
#print(json.dumps(annotation_original))
df.at[index, 'original_annotation'] = json.dumps(annotation_original)
print("original done")
#print("original done")
prompt_generated = llama_prompt_no_schema(row['utterances_joined'])
response_generated = post_http_request(prompt_generated, api_url, n=1, stream=False)
if response_generated:
output_generated = get_response(response_generated, prompt_generated)
annotation_generated = process_annotation(output_generated)
print(json.dumps(annotation_generated))
#print(json.dumps(annotation_generated))
df.at[index, 'generated_annotation'] = json.dumps(annotation_generated)
print("generated done")
#print("generated done")
# Increment the counter
iteration_counter += 1
# Check if the counter has reached 50 iterations
if iteration_counter % 50 == 0:
# Save the DataFrame
df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/')
print(f"Checkpoint: Saved after {iteration_counter} iterations.")
# Save the DataFrame at the end of the process
df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/')
return df
### Prompt
......@@ -172,9 +187,9 @@ def main(args):
api_url = f"http://{args.host}:{args.port}/generate"
input_file = f"../../data/own_data/one-shot/dialogues/{args.input_name}"
df = pd.read_csv(input_file, sep=',', quoting=csv.QUOTE_NONE, escapechar='/')
df = annotate_dialogues(df, api_url)
output_file = f"../../data/own_data/one-shot/annotations/{args.output_name}"
df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/')
df = annotate_dialogues(df, api_url, output_file)
#df.to_csv(output_file, sep=',', index=False, quoting=csv.QUOTE_NONE, escapechar='/')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Script to annotate dialogues.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment