Skip to content
Snippets Groups Projects
Commit 61510ca0 authored by Ke Li's avatar Ke Li Committed by Vijayaditya Peddinti
Browse files

nnet3/report : Modified directory specification options in generate_plots.py (#1368)

parent 04404176
Branches
No related tags found
No related merge requests found
......@@ -48,13 +48,16 @@ def get_args():
parser = argparse.ArgumentParser(
description="""Parses the training logs and generates a variety of
plots.
e.g.: steps/nnet3/report/generate_plots.py \\
e.g. (deprecated): steps/nnet3/report/generate_plots.py \\
--comparison-dir exp/nnet3/tdnn1 --comparison-dir exp/nnet3/tdnn2 \\
exp/nnet3/tdnn exp/nnet3/tdnn/report""")
exp/nnet3/tdnn exp/nnet3/tdnn/report
e.g. (current): steps/nnet3/report/generate_plots.py \\
exp/nnet3/tdnn exp/nnet3/tdnn1 exp/nnet3/tdnn2 exp/nnet3/tdnn/report""")
parser.add_argument("--comparison-dir", type=str, action='append',
help="other experiment directories for comparison. "
"These will only be used for plots, not tables")
"These will only be used for plots, not tables"
"Note: this option is deprecated.")
parser.add_argument("--start-iter", type=int,
help="Iteration from which plotting will start",
default=1)
......@@ -66,16 +69,19 @@ def get_args():
help="""List of space separated
<output-node>:<objective-type> entities,
one for each output node""")
parser.add_argument("exp_dir",
help="experiment directory, e.g. exp/nnet3/tdnn")
parser.add_argument("exp_dir", nargs='+',
help="the first dir is the experiment directory, "
"e.g. exp/nnet3/tdnn, the rest dirs (if exist) "
"are other experiment directories for comparison.")
parser.add_argument("output_dir",
help="experiment directory, "
"e.g. exp/nnet3/tdnn/report")
args = parser.parse_args()
if args.comparison_dir is not None and len(args.comparison_dir) > 6:
if (args.comparison_dir is not None and len(args.comparison_dir) > 6) or \
(args.exp_dir is not None and len(args.exp_dir) > 7):
raise Exception(
"""max 6 --comparison-dir options can be specified.
"""max 6 comparison directories can be specified.
If you want to compare with more comparison_dir, you would have to
carefully tune the plot_colors variable which specified colors used
for plotting.""")
......@@ -654,9 +660,18 @@ def main():
else:
output_nodes.append(('output', 'linear'))
generate_plots(args.exp_dir, args.output_dir, output_nodes,
if args.comparison_dir is not None:
generate_plots(args.exp_dir[0], args.output_dir, output_nodes,
comparison_dir=args.comparison_dir,
start_iter=args.start_iter)
else:
if len(args.exp_dir) == 1:
generate_plots(args.exp_dir[0], args.output_dir, output_nodes,
start_iter=args.start_iter)
if len(args.exp_dir) > 1:
generate_plots(args.exp_dir[0], args.output_dir, output_nodes,
comparison_dir=args.exp_dir[1:],
start_iter=args.start_iter)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment