diff --git a/src/seq2seq/main.py b/src/seq2seq/main.py
index 5b85fd63bb21be3304ea18cdeed495ed11302ec8..7e5d7672ed3b6e1490cf5f72a0d04ee940522ae0 100644
--- a/src/seq2seq/main.py
+++ b/src/seq2seq/main.py
@@ -22,22 +22,23 @@ logging.basicConfig(level=logging.INFO)
 transformers_logger = logging.getLogger("transformers")
 transformers_logger.setLevel(logging.WARNING)
 
-train_df = get_dataset("../../dataset/twitter-dev-small/twitter_en.train.txt")
+train_df = get_dataset("../../dataset/twitter-dev/twitter_en.train.txt")
 
 model_args = {
     "fp16": False,
     "overwrite_output_dir": True,
     "max_seq_length": 128,
-    "train_batch_size": 8,
+    "train_batch_size": 4,
     "eval_batch_size": 1,
-    "num_train_epochs": 64,
+    "num_train_epochs": 16,
     "max_length": 128,
     "num_beams": 3,
     "early_stopping": False,
     "learning_rate": 1e-4,
     "save_eval_checkpoints": False,
     "save_model_every_epoch": False,
-    "save_best_model": False,
+    "save_best_model": True,
+    "gradient_accumulation_steps": 4,
 }
 
 model = Seq2SeqModel(encoder_type="bert", encoder_name="bert-base-cased", decoder_name="bert-base-cased", args=model_args)
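Note: although `train_batch_size` drops from 8 to 4, the new `gradient_accumulation_steps: 4` means the optimizer steps on a larger effective batch than before. A minimal sketch of that arithmetic, using the values from this diff (the variable names below are illustrative, not part of main.py):

```python
# Values taken from model_args in this diff.
train_batch_size = 4
gradient_accumulation_steps = 4

# Gradients are accumulated over 4 micro-batches before each optimizer step,
# so each update sees an effective batch of 4 * 4 = 16 examples
# (versus 8 before this change, with batch size 8 and no accumulation).
effective_batch_size = train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 16
```

This trades per-step GPU memory (smaller micro-batches) for a larger effective batch, while the epoch count is cut from 64 to 16.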