diff --git a/src/seq2seq/main.py b/src/seq2seq/main.py index 1fa24d5c3dd12e8fbc720df7693e0171d1d9da4d..8202af9941175cc5b33cde4dc6e18b97db0bf7b7 100644 --- a/src/seq2seq/main.py +++ b/src/seq2seq/main.py @@ -49,6 +49,7 @@ model_args = { "num_beams": 3, "early_stopping": False, "learning_rate": 1e-4, + "early_stopping_metric": True, } model = Seq2SeqModel("bert", "bert-base-cased", "bert-base-cased", args=model_args)