diff --git a/src/seq2seq/main.py b/src/seq2seq/main.py index f5c24df137f4835b5bfeb91b124e030b580b1490..25f270330e918a4f7808a7455cca3fbb4e8babd9 100644 --- a/src/seq2seq/main.py +++ b/src/seq2seq/main.py @@ -9,7 +9,7 @@ transformers_logger = logging.getLogger("transformers") transformers_logger.setLevel(logging.WARNING) file_path = os.path.dirname(os.path.abspath(__file__)) -train_df = pd.read_csv(os.path.join(file_path, "../../dataset/input-target.csv")) +train_df = pd.read_csv(os.path.join(file_path, "../../dataset/input-target-256-dev.csv")) train_df.drop(train_df.columns[[0]], axis=1, inplace=True) train_df.dropna(subset=["input_text", "target_text"], inplace=True)