From 1c72cc6d5c36ca10db124d7b63e7bed6b6cba9ba Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Wed, 24 Mar 2021 10:07:23 +0100 Subject: [PATCH] Add endpoint for retrieving evaluation results --- joeynmt_server/models/evaluation_results.py | 6 +++ joeynmt_server/views/__init__.py | 1 + joeynmt_server/views/train.py | 28 +----------- joeynmt_server/views/validate.py | 48 +++++++++++++++++++++ 4 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 joeynmt_server/views/validate.py diff --git a/joeynmt_server/models/evaluation_results.py b/joeynmt_server/models/evaluation_results.py index c72fc09..a2f6391 100644 --- a/joeynmt_server/models/evaluation_results.py +++ b/joeynmt_server/models/evaluation_results.py @@ -13,3 +13,9 @@ class EvaluationResult(BaseModel): @property def accuracy(self): return self.correct / self.total + + + def json_ready_dict(self) -> dict: + d = super().json_ready_dict() + d['accuracy'] = self.accuracy + return d diff --git a/joeynmt_server/views/__init__.py b/joeynmt_server/views/__init__.py index 0d709ad..13a3149 100644 --- a/joeynmt_server/views/__init__.py +++ b/joeynmt_server/views/__init__.py @@ -3,3 +3,4 @@ from .feedback import (get_feedback, edit_feedback, query_feedback, save_feedback) from .translate import translate from .train import train +from .validate import validate, validations diff --git a/joeynmt_server/views/train.py b/joeynmt_server/views/train.py index b1b72d3..075a943 100644 --- a/joeynmt_server/views/train.py +++ b/joeynmt_server/views/train.py @@ -7,7 +7,7 @@ from flask import current_app, jsonify, request from joeynmt_server.app import create_app from joeynmt_server.models import Lock -from joeynmt_server.trainer import train_n_rounds, validate as validate_on_data +from joeynmt_server.trainer import train_n_rounds from joeynmt_server.utils.helper import get_utc_now @@ -49,29 +49,3 @@ def check_train_status(): still_training = False return jsonify({'still_training': still_training}) - - -@current_app.route('/validate', methods=['POST']) -def validate(): - data = request.json - config_basename = data.get('model') - - dataset = data.get('dataset', 'dev') - - def validate_in_thread(): - app = create_app() - with app.app_context(): - try: - validate_on_data(config_basename, dataset) - except: - logging.error('Training failed.') - logging.error(traceback.format_exc()) - - thread = threading.Thread(target=validate_in_thread) - thread.start() - - time.sleep(0.1) - response = {'validating': thread.is_alive()} - status = 200 if response['validating'] else 500 - return jsonify(response), status - diff --git a/joeynmt_server/views/validate.py b/joeynmt_server/views/validate.py new file mode 100644 index 0000000..e0c3a8f --- /dev/null +++ b/joeynmt_server/views/validate.py @@ -0,0 +1,48 @@ +import logging +import threading +import time +import traceback + +from flask import current_app, jsonify, request + +from joeynmt_server.app import create_app +from joeynmt_server.models import EvaluationResult +from joeynmt_server.trainer import validate as validate_on_data + + +@current_app.route('/validate', methods=['POST']) +def validate(): + data = request.json + config_basename = data.get('model') + + dataset = data.get('dataset', 'dev') + + def validate_in_thread(): + app = create_app() + with app.app_context(): + try: + validate_on_data(config_basename, dataset) + except: + logging.error('Training failed.') + logging.error(traceback.format_exc()) + + thread = threading.Thread(target=validate_in_thread) + thread.start() + + time.sleep(0.1) + response = {'validating': thread.is_alive()} + status = 200 if response['validating'] else 500 + return jsonify(response), status + + +@current_app.route('/validations', methods='GET') +def validations(): + label = request.args.get('label') + + if label: + results = EvaluationResult.query.filter_by(label=label).all() + else: + results = EvaluationResult.query.all() + + results = [result.json_ready_dict() for result in results] + return jsonify(results) -- GitLab