diff --git a/.gitignore b/.gitignore index d5ab9eea74db40c466e2ef38e8f01d25c559556b..d0399047e5dfbbe0badbe83f76f544563a75dbc1 100644 --- a/.gitignore +++ b/.gitignore @@ -149,4 +149,8 @@ dmypy.json # End of https://www.gitignore.io/api/linux,python,jupyternotebooks,visualstudiocode -!src/lib \ No newline at end of file +!src/lib + +cache_dir +runs +outputs \ No newline at end of file diff --git a/dataset/twitter-dev/twitter_en.txt b/dataset/twitter-dev/twitter_en.txt new file mode 100644 index 0000000000000000000000000000000000000000..ddc784d4de3514eb3fef384b660bd8f555c07df3 --- /dev/null +++ b/dataset/twitter-dev/twitter_en.txt @@ -0,0 +1,30 @@ +yeah i'm preparing myself to drop a lot on this man, but definitely need something reliable +yeah dude i would definitely consider a daniel defence super reliable and they are just bad ass +i'm about to meet my mans ex friend with benefit, tune in next week to see if i have to put hands on +i'm dead not looking forward to this +shouldn't the supporter's natural answer to 's hashtag be ? +or just insert itl to make . +you want to turn twitter followers into blog readers. +how do you do this? +besides if trump say his condolences it won't sound genuine, ex: (dwayne wade cousin) it will sound all political and petty +yea you right. but we do live in a world where republicans will harass obama about a birth certificate but won't say +jill stein & her fan club can now officially go to hell -just vote trump & be done with it +i love green but 3d parties might elect trump like nader elected bush in 2000 with gore there would not have been iraq war no wmd +well, i finally finished watching all the episodes of breaking the magician's code: magic's biggest secrets finally revealed on netflix. +now you are a walking spoiler... +ask about this. it's the whole reason he built . 
i've been hoping to see you get into it and start teaching it 😊 +i’m waiting for to wake up :p +then again, some sf hipsters would rather get crushed by a rack of fixie bikes in an earthquake than have to move out. +in seriousness, if the next 8.0 happens in my lifetime, i really hope i'm not walking around downtown sf when it hits. +does the dog also have polio +i think he is an old pupper ian he just wants some pizza tbh +accelerate operation in . best response for any action pakistan and army. +while accelerating corruption schemes before next topi wala gets the charge +bre takin shots 🔫🔫🔫 +just standing up for my friends yo ❤️ i figured i'd take it for them +núñez is a tough customer. steals second, takes third on e2. +with a bad back andy! miss duffy but love eduardo too!!! +hillary had nothing to do with sid blumenthal, other than her slush fund paying him $10k/month while using him as a shadow intel source. +such dishonesty is unimaginable! +it’s a genuine evil and you will be first against the wall in the nuremberg trials that result from this movie. +what was vibe at p&i? at *packed* premiere w/ kids it was like muppet theatre, place went nuts applauding. 
\ No newline at end of file diff --git a/dataset/twitter/twitter_en.txt b/dataset/twitter/twitter_en.txt new file mode 100644 index 0000000000000000000000000000000000000000..c9cb607619ebc14d33ce50e19907b86820e580f4 Binary files /dev/null and b/dataset/twitter/twitter_en.txt differ diff --git a/notebooks/examples-bert-pre-trained.ipynb b/notebooks/examples-bert-pre-trained.ipynb deleted file mode 100644 index ee52ab8c28fc5a672a5e00eacbf98a56cc7ec87c..0000000000000000000000000000000000000000 --- a/notebooks/examples-bert-pre-trained.ipynb +++ /dev/null @@ -1,1280 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## BERT" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "I am using the pre-trained model provided by [https://github.com/huggingface/transformers](https://github.com/huggingface/transformers)." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T12:54:58.533369Z", - "start_time": "2020-06-01T12:54:55.586790Z" - } - }, - "outputs": [], - "source": [ - "from transformers import BertTokenizer, BertForQuestionAnswering\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T12:54:58.638538Z", - "start_time": "2020-06-01T12:54:58.536031Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cuda')" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "I think we can use this dataset: 
[https://www.kaggle.com/stanfordu/stanford-question-answering-dataset](https://www.kaggle.com/stanfordu/stanford-question-answering-dataset)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The question and context below is a sample extracted of the Stanford Question Answering Dataset (SQuAD). In Section 1.2, I show an example without pre-training the model in the SQuAD 1.1 dataset. Section 1.3 shows an example with the model pre-trained in the SQuAD 1.1 dataset." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T12:55:52.111755Z", - "start_time": "2020-06-01T12:55:52.105342Z" - } - }, - "outputs": [], - "source": [ - "question, context = (\n", - " \"Where did Super Bowl 50 take place?\",\n", - " 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.',\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Example 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Load word tokenizer." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T12:55:54.141311Z", - "start_time": "2020-06-01T12:55:53.427832Z" - } - }, - "outputs": [], - "source": [ - "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Encode the input." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T13:09:03.661043Z", - "start_time": "2020-06-01T13:09:03.649446Z" - } - }, - "outputs": [], - "source": [ - "encoding = tokenizer.encode_plus(question, context)\n", - "input_ids, token_type_ids, attention_mask = encoding[\"input_ids\"], encoding[\"token_type_ids\"], encoding[\"attention_mask\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Load the pre-trained model." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T01:44:28.488195Z", - "start_time": "2020-06-01T01:44:16.895311Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "BertForQuestionAnswering(\n", - " (bert): BertModel(\n", - " (embeddings): BertEmbeddings(\n", - " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", - " (position_embeddings): Embedding(512, 768)\n", - " (token_type_embeddings): Embedding(2, 768)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (encoder): BertEncoder(\n", - " (layer): ModuleList(\n", - " (0): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): 
BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (1): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (2): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): 
Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (3): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (4): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, 
bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (5): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (6): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (7): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (8): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (9): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (10): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " 
(dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (11): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (pooler): BertPooler(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (activation): Tanh()\n", - " )\n", - " )\n", - " (qa_outputs): Linear(in_features=768, out_features=2, bias=True)\n", - ")" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = BertForQuestionAnswering.from_pretrained(\"bert-base-uncased\").to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T01:44:31.391516Z", - 
"start_time": "2020-06-01T01:44:30.598653Z" - } - }, - "outputs": [], - "source": [ - "start_scores, end_scores = model(\n", - " torch.tensor([input_ids]).to(device),\n", - " token_type_ids=torch.tensor([token_type_ids]).to(device),\n", - " attention_mask=torch.tensor([attention_mask]).to(device),\n", - ")\n", - "all_tokens = tokenizer.convert_ids_to_tokens(input_ids)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Get the output." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T01:44:34.171591Z", - "start_time": "2020-06-01T01:44:34.159855Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "''" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "answer = \" \".join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1])\n", - "answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Example 2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Load word tokenizer." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T01:44:43.675118Z", - "start_time": "2020-06-01T01:44:42.936909Z" - } - }, - "outputs": [], - "source": [ - "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Encode the input." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T13:07:35.342111Z", - "start_time": "2020-06-01T13:07:35.325161Z" - } - }, - "outputs": [], - "source": [ - "encoding = tokenizer.encode_plus(question, context)\n", - "input_ids, token_type_ids, attention_mask = encoding[\"input_ids\"], encoding[\"token_type_ids\"], encoding[\"attention_mask\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Load the pre-trained model." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T13:08:05.418557Z", - "start_time": "2020-06-01T13:07:38.020314Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "BertForQuestionAnswering(\n", - " (bert): BertModel(\n", - " (embeddings): BertEmbeddings(\n", - " (word_embeddings): Embedding(30522, 1024, padding_idx=0)\n", - " (position_embeddings): Embedding(512, 1024)\n", - " (token_type_embeddings): Embedding(2, 1024)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (encoder): BertEncoder(\n", - " (layer): ModuleList(\n", - " (0): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): 
BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (1): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (2): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " 
(dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (3): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (4): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (5): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (6): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (7): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (8): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (9): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (10): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (11): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (12): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (13): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (14): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (15): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (16): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (17): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (18): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (19): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (20): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (21): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (22): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): 
Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (23): BertLayer(\n", - " (attention): BertAttention(\n", - " (self): BertSelfAttention(\n", - " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (output): BertSelfOutput(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): BertIntermediate(\n", - " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", - " )\n", - " (output): BertOutput(\n", - " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (pooler): BertPooler(\n", - " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (activation): Tanh()\n", - " )\n", - " )\n", - " (qa_outputs): Linear(in_features=1024, out_features=2, bias=True)\n", - ")" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = BertForQuestionAnswering.from_pretrained(\"bert-large-uncased-whole-word-masking-finetuned-squad\")\n", - "model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T13:08:41.447756Z", - "start_time": "2020-06-01T13:08:41.361791Z" - } - }, - "outputs": [], - "source": [ - "start_scores, end_scores = model(\n", - " 
torch.tensor([input_ids]).to(device),\n", - " token_type_ids=torch.tensor([token_type_ids]).to(device),\n", - " attention_mask=torch.tensor([attention_mask]).to(device),\n", - ")\n", - "all_tokens = tokenizer.convert_ids_to_tokens(input_ids)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Get the output." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "ExecuteTime": { - "end_time": "2020-06-01T13:08:43.336326Z", - "start_time": "2020-06-01T13:08:43.326257Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\"levi ' s stadium in the san francisco bay area at santa clara , california\"" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "answer = \" \".join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1])\n", - "answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Conclusions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "I think we must use the model from Section 1.2 and fine-tune it in Stanford Question Answering Dataset (SQuAD)." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.3" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/examples-bert-qa-pre-trained.ipynb b/notebooks/examples-bert-qa-pre-trained.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..36e712c07dff83c4e32488a56163c4a5dc784284 --- /dev/null +++ b/notebooks/examples-bert-qa-pre-trained.ipynb @@ -0,0 +1,697 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## BERT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I am using the pre-trained model provided by [https://github.com/huggingface/transformers](https://github.com/huggingface/transformers)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:12:10.854131Z", + "start_time": "2020-06-02T13:12:07.896720Z" + } + }, + "outputs": [], + "source": [ + "from transformers import BertTokenizer, BertForQuestionAnswering\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:12:11.001557Z", + "start_time": "2020-06-02T13:12:10.856167Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "device" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I think we can use this dataset: [https://www.kaggle.com/stanfordu/stanford-question-answering-dataset](https://www.kaggle.com/stanfordu/stanford-question-answering-dataset)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The question and context below are a sample extracted from the Stanford Question Answering Dataset (SQuAD). In Section 1.2, I show an example using a model that has not been fine-tuned on the SQuAD 1.1 dataset. Section 1.3 shows an example using a model that has been fine-tuned on the SQuAD 1.1 dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:12:13.576423Z", + "start_time": "2020-06-02T13:12:13.570731Z" + } + }, + "outputs": [], + "source": [ + "question, context = (\n", + " \"Where did Super Bowl 50 take place?\",\n", + " '',\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load word tokenizer."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:12:17.354161Z", + "start_time": "2020-06-02T13:12:16.433649Z" + } + }, + "outputs": [], + "source": [ + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Encode the input." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-01T13:09:03.661043Z", + "start_time": "2020-06-01T13:09:03.649446Z" + } + }, + "outputs": [], + "source": [ + "encoding = tokenizer.encode_plus(question, context)\n", + "input_ids, token_type_ids, attention_mask = encoding[\"input_ids\"], encoding[\"token_type_ids\"], encoding[\"attention_mask\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the pre-trained model." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-01T01:44:28.488195Z", + "start_time": "2020-06-01T01:44:16.895311Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertForQuestionAnswering(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): 
BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): 
Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, 
bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " 
(dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (qa_outputs): Linear(in_features=768, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = BertForQuestionAnswering.from_pretrained(\"bert-base-uncased\").to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-01T01:44:31.391516Z", + 
"start_time": "2020-06-01T01:44:30.598653Z" + } + }, + "outputs": [], + "source": [ + "start_scores, end_scores = model(\n", + " torch.tensor([input_ids]).to(device),\n", + " token_type_ids=torch.tensor([token_type_ids]).to(device),\n", + " attention_mask=torch.tensor([attention_mask]).to(device),\n", + ")\n", + "all_tokens = tokenizer.convert_ids_to_tokens(input_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get the output." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-01T01:44:34.171591Z", + "start_time": "2020-06-01T01:44:34.159855Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "''" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "answer = \" \".join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1])\n", + "answer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load word tokenizer." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:12:23.194797Z", + "start_time": "2020-06-02T13:12:22.318515Z" + } + }, + "outputs": [], + "source": [ + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Encode the input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:13:19.402933Z", + "start_time": "2020-06-02T13:13:19.395224Z" + } + }, + "outputs": [], + "source": [ + "encoding = tokenizer.encode_plus(question)\n", + "input_ids, token_type_ids, attention_mask = encoding[\"input_ids\"], encoding[\"token_type_ids\"], encoding[\"attention_mask\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the pre-trained model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:13:33.053603Z", + "start_time": "2020-06-02T13:13:20.897917Z" + } + }, + "outputs": [], + "source": [ + "model = BertForQuestionAnswering.from_pretrained(\"bert-large-uncased-whole-word-masking-finetuned-squad\").to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:13:33.104373Z", + "start_time": "2020-06-02T13:13:33.055736Z" + } + }, + "outputs": [], + "source": [ + "start_scores, end_scores = model(\n", + " torch.tensor([input_ids]).to(device),\n", + " token_type_ids=torch.tensor([token_type_ids]).to(device),\n", + " attention_mask=torch.tensor([attention_mask]).to(device),\n", + ")\n", + "all_tokens = tokenizer.convert_ids_to_tokens(input_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get the output." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:13:33.114156Z", + "start_time": "2020-06-02T13:13:33.106611Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "''" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "answer = \" \".join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1])\n", + "answer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I think we must use the model from Section 1.2 and fine-tune it on the Stanford Question Answering Dataset (SQuAD)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/examples-bert2bert-pre-trained.ipynb b/notebooks/examples-bert2bert-pre-trained.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..46607f0da07f8bf82ae118aa32bbc52d2abb5f73 --- /dev/null +++ b/notebooks/examples-bert2bert-pre-trained.ipynb @@ -0,0 +1,948 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Seq2seq with pre-trained BERT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook shows some 
examples of the BERT model used for seq2seq tasks." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:23:46.254478Z", + "start_time": "2020-06-02T13:23:45.068527Z" + } + }, + "outputs": [], + "source": [ + "from transformers import EncoderDecoderModel, BertTokenizer\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load tokenizer." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:23:59.789187Z", + "start_time": "2020-06-02T13:23:58.948040Z" + } + }, + "outputs": [], + "source": [ + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the encoder-decoder model." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:26:23.639591Z", + "start_time": "2020-06-02T13:26:12.052485Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "EncoderDecoderModel(\n", + " (encoder): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", 
+ " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", 
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, 
out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " 
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " 
(dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (decoder): BertForMaskedLM(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): 
BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " 
)\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " 
(attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): 
BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): 
Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, 
bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (cls): 
BertOnlyMLMHead(\n", + " (predictions): BertLMPredictionHead(\n", + " (transform): BertPredictionHeadTransform(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " )\n", + " (decoder): Linear(in_features=768, out_features=30522, bias=True)\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"bert-base-uncased\", \"bert-base-uncased\")\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T13:24:56.549156Z", + "start_time": "2020-06-02T13:24:56.540737Z" + } + }, + "outputs": [], + "source": [ + "input_ids = torch.tensor(tokenizer.encode(\"Hello, my dog is cute\", add_special_tokens=True)).unsqueeze(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T14:16:04.070633Z", + "start_time": "2020-06-02T14:16:03.990291Z" + } + }, + "outputs": [], + "source": [ + "outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-02T14:16:07.693806Z", + "start_time": "2020-06-02T14:16:07.653070Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[[ -6.3390, -6.3665, -6.4600, ..., -5.5355, -4.1787, -5.8384],\n", + " [ -6.0605, -6.0980, -6.1492, ..., -5.0190, -3.6619, -5.6481],\n", + " [ -6.2835, -6.1857, -6.2198, ..., -5.8243, -3.9650, -4.2239],\n", + " ...,\n", + " [ -8.6994, -8.6061, -8.6930, ..., -8.4026, -7.0615, -6.1120],\n", + " [ -7.7221, -7.7373, -7.7094, ..., -7.6440, -6.1568, -5.5106],\n", + " [-13.5756, -13.0523, -12.9125, ..., -10.4893, -11.9085, -9.3556]]],\n", + " grad_fn=),\n", + " tensor([[[-0.1144, 0.1937, 0.1250, 
..., -0.3827, 0.2107, 0.5407],\n", + " [ 0.5308, 0.3207, 0.3665, ..., -0.0036, 0.7579, 0.0388],\n", + " [-0.4877, 0.8849, 0.4256, ..., -0.6976, 0.4458, 0.1231],\n", + " ...,\n", + " [-0.7003, -0.1815, 0.3297, ..., -0.4838, 0.0680, 0.8901],\n", + " [-1.0355, -0.2567, -0.0317, ..., 0.3197, 0.3999, 0.1795],\n", + " [ 0.6080, 0.2610, -0.3131, ..., 0.0311, -0.6283, -0.1994]]],\n", + " grad_fn=),\n", + " tensor([[-7.1946e-01, -2.1445e-01, -2.9576e-01, 3.6603e-01, 2.7968e-01,\n", + " 2.2184e-02, 5.7299e-01, 6.2331e-02, 5.9587e-02, -9.9965e-01,\n", + " 5.0147e-02, 4.4756e-01, 9.7612e-01, 3.3988e-02, 8.4494e-01,\n", + " -3.6905e-01, 9.8650e-02, -3.7169e-01, 1.7371e-01, 1.1515e-01,\n", + " 4.4133e-01, 9.9525e-01, 3.7221e-01, 8.2881e-02, 2.1402e-01,\n", + " 6.8965e-01, -6.1042e-01, 8.7136e-01, 9.4158e-01, 5.7372e-01,\n", + " -3.2187e-01, 8.6671e-03, -9.8611e-01, -2.0542e-02, -4.3755e-01,\n", + " -9.8012e-01, 1.1142e-01, -6.7587e-01, 1.3499e-01, 3.1130e-01,\n", + " -8.2997e-01, 1.9006e-01, 9.9896e-01, -3.1798e-01, 2.1517e-02,\n", + " -1.6531e-01, -9.9943e-01, 1.0173e-01, -8.1811e-01, 3.3118e-02,\n", + " 3.6740e-01, -7.3231e-02, -1.4261e-01, 1.8907e-01, 2.6119e-01,\n", + " 4.1582e-01, -2.4427e-01, -5.9846e-02, -7.3492e-02, -3.4202e-01,\n", + " -5.8001e-01, 2.8331e-01, -5.0513e-01, -8.1967e-01, 1.9813e-01,\n", + " 1.9108e-01, 3.7011e-02, -1.1327e-01, 1.3472e-01, -2.1614e-01,\n", + " 6.3494e-01, 2.4869e-02, 3.8287e-01, -8.1779e-01, -2.4874e-01,\n", + " 8.4982e-02, -5.2998e-01, 1.0000e+00, -5.2154e-02, -9.7052e-01,\n", + " 3.9848e-01, 2.1359e-02, 3.9035e-01, 3.5588e-01, -1.7881e-01,\n", + " -9.9997e-01, 2.6939e-01, -3.8057e-02, -9.8657e-01, 6.9322e-02,\n", + " 3.9138e-01, -2.1884e-02, -9.6332e-02, 3.8545e-01, -3.4136e-01,\n", + " -8.0362e-02, -3.2022e-02, -3.6328e-01, -7.8129e-02, 1.9192e-02,\n", + " -1.3429e-01, -1.6013e-02, -5.2640e-02, -2.8006e-01, 9.3611e-02,\n", + " -2.2885e-01, -1.2305e-01, -1.1002e-01, -3.2808e-01, 4.0356e-01,\n", + " 2.8048e-01, -2.0102e-01, 2.7685e-01, 
-9.4023e-01, 4.1756e-01,\n", + " -1.5473e-01, -9.7553e-01, -4.3003e-01, -9.8546e-01, 5.9158e-01,\n", + " 3.7344e-02, -1.9320e-01, 9.1691e-01, 3.6012e-01, 1.4505e-01,\n", + " 1.5398e-01, -1.0656e-02, -1.0000e+00, -3.1573e-01, -3.1037e-01,\n", + " 1.6523e-01, -8.0330e-02, -9.6650e-01, -9.4546e-01, 3.6145e-01,\n", + " 9.0138e-01, -7.2696e-02, 9.9774e-01, 3.7289e-02, 9.3599e-01,\n", + " 2.5317e-01, -2.0184e-01, 2.9532e-02, -2.3162e-01, 3.4632e-01,\n", + " -1.0763e-01, -2.6565e-01, 1.0874e-01, 1.2985e-01, 2.1134e-02,\n", + " -9.6283e-02, -7.6357e-02, -6.5149e-02, -8.9277e-01, -2.3465e-01,\n", + " 9.1176e-01, 7.0430e-02, -2.1429e-01, 3.8197e-01, 3.5892e-02,\n", + " -1.6971e-01, 7.0654e-01, 2.4045e-01, 1.5014e-01, -1.9477e-02,\n", + " 2.1369e-01, -1.7977e-01, 3.5112e-01, -6.0260e-01, 4.1683e-01,\n", + " 1.8090e-01, -3.2496e-02, -3.0137e-01, -9.7103e-01, -1.3917e-01,\n", + " 3.5130e-01, 9.8326e-01, 5.2702e-01, 4.8811e-02, 1.3990e-02,\n", + " -6.7964e-02, 2.9717e-01, -9.4136e-01, 9.7219e-01, -2.4773e-02,\n", + " 1.5224e-01, -1.8241e-01, 5.5584e-02, -7.7306e-01, -9.9000e-02,\n", + " 4.7058e-01, -1.7022e-01, -7.7803e-01, 5.2834e-02, -3.7679e-01,\n", + " -4.1296e-02, -4.9612e-01, 1.4171e-01, -1.1803e-01, -1.8995e-01,\n", + " 5.0384e-02, 9.0623e-01, 7.8828e-01, 5.2288e-01, -3.5274e-01,\n", + " 2.8563e-01, -8.1494e-01, -1.9622e-01, -9.2975e-02, 5.9311e-02,\n", + " 3.1902e-02, 9.8860e-01, -3.9452e-01, 1.1867e-01, -8.6977e-01,\n", + " -9.7789e-01, -1.4859e-01, -7.7064e-01, -4.0616e-03, -4.1152e-01,\n", + " 3.2578e-01, 1.8777e-01, -2.4501e-01, 2.6668e-01, -7.9329e-01,\n", + " -4.8133e-01, 9.3245e-02, -1.7010e-01, 2.7043e-01, -3.5880e-02,\n", + " 7.7973e-01, 4.6696e-01, -3.4636e-01, 5.5237e-02, 9.0312e-01,\n", + " -2.4115e-01, -6.4200e-01, 4.1441e-01, -9.7797e-02, 6.2983e-01,\n", + " -4.1787e-01, 9.4069e-01, 4.9285e-01, 3.6058e-01, -8.7901e-01,\n", + " -2.6726e-01, -5.4679e-01, 9.4008e-04, -1.0502e-02, -4.6837e-01,\n", + " 3.1116e-01, 3.6999e-01, 1.3306e-01, 6.4092e-01, 
-3.5630e-01,\n", + " 8.8549e-01, -8.9036e-01, -9.3865e-01, -8.1215e-01, 2.7362e-01,\n", + " -9.8566e-01, 4.0362e-01, 2.1223e-01, -1.4316e-01, -2.4553e-01,\n", + " -2.1144e-01, -9.4728e-01, 5.0806e-01, -9.6622e-02, 8.5571e-01,\n", + " -1.0133e-01, -6.7768e-01, -2.8500e-01, -8.9905e-01, -3.3577e-01,\n", + " 8.9155e-02, 3.2600e-01, -2.6467e-01, -9.2032e-01, 3.4629e-01,\n", + " 3.3430e-01, 2.1397e-01, 3.0630e-02, 9.3878e-01, 9.9986e-01,\n", + " 9.6385e-01, 8.3159e-01, 6.2250e-01, -9.8055e-01, -7.3623e-01,\n", + " 9.9986e-01, -7.8395e-01, -9.9998e-01, -8.7800e-01, -5.0893e-01,\n", + " 2.3399e-02, -1.0000e+00, -6.1938e-02, 1.9563e-01, -9.0552e-01,\n", + " -1.4008e-01, 9.5264e-01, 7.9837e-01, -1.0000e+00, 7.6343e-01,\n", + " 8.3670e-01, -4.5859e-01, 5.4410e-01, -2.4073e-01, 9.6085e-01,\n", + " 1.9164e-01, 3.2135e-01, -1.3064e-02, 2.4534e-01, -5.3001e-01,\n", + " -5.9538e-01, 3.7464e-01, -2.1189e-01, 8.8024e-01, 1.9647e-02,\n", + " -3.8349e-01, -8.4779e-01, 1.4676e-02, -2.8375e-02, -4.4313e-01,\n", + " -9.4966e-01, -6.5704e-02, -7.2328e-02, 6.5967e-01, -1.1504e-01,\n", + " 2.1876e-01, -5.5254e-01, 9.2218e-02, -5.0583e-01, -5.2826e-02,\n", + " 5.1425e-01, -8.9533e-01, -1.2744e-01, 9.7846e-02, -6.0145e-01,\n", + " -3.1651e-02, -9.5186e-01, 9.4685e-01, -2.2341e-01, 1.8390e-01,\n", + " 1.0000e+00, 1.1755e-01, -7.0390e-01, 3.2502e-01, -1.0898e-02,\n", + " -1.8308e-01, 9.9999e-01, 5.8376e-01, -9.7387e-01, -3.3783e-01,\n", + " 2.9640e-01, -2.7002e-01, -2.2243e-01, 9.9711e-01, 1.4422e-02,\n", + " 7.8269e-02, 3.8660e-01, 9.7787e-01, -9.8501e-01, 8.7459e-01,\n", + " -7.2276e-01, -9.5249e-01, 9.4567e-01, 9.1005e-01, -5.0722e-01,\n", + " -4.9026e-01, -1.2517e-01, -3.9076e-02, 8.8128e-02, -8.2481e-01,\n", + " 3.8301e-01, 1.8045e-01, 5.4796e-02, 8.0041e-01, -3.3501e-01,\n", + " -3.9115e-01, 1.4233e-01, -9.0140e-02, 3.4585e-01, 4.4044e-01,\n", + " 3.1044e-01, -1.3280e-01, -1.3614e-01, -3.0303e-01, -4.8794e-01,\n", + " -9.4950e-01, 1.0887e-01, 1.0000e+00, 6.0752e-02, 8.3374e-02,\n", + " 
-3.1297e-03, 8.5578e-02, -3.1288e-01, 2.6283e-01, 2.6870e-01,\n", + " -1.4267e-01, -7.4000e-01, 2.2856e-01, -7.9441e-01, -9.8812e-01,\n", + " 4.3592e-01, 7.7229e-02, -3.8084e-02, 9.9490e-01, 3.2615e-01,\n", + " 6.7989e-02, 8.2887e-02, 4.7390e-01, -2.1855e-01, 3.9278e-01,\n", + " 3.7664e-02, 9.6440e-01, -1.8374e-01, 3.9259e-01, 4.3319e-01,\n", + " -1.8618e-01, -2.1584e-01, -4.9610e-01, -9.7025e-02, -8.8006e-01,\n", + " 2.4995e-01, -9.3940e-01, 9.3827e-01, 3.2001e-01, 1.1919e-01,\n", + " 7.3959e-02, 3.1272e-02, 1.0000e+00, -7.5631e-01, 3.5396e-01,\n", + " 5.3290e-01, 3.2035e-01, -9.7538e-01, -4.7482e-01, -2.3322e-01,\n", + " 3.5377e-02, -4.6059e-02, -1.2863e-01, 8.3798e-02, -9.5139e-01,\n", + " 3.4661e-02, 4.5213e-03, -8.8296e-01, -9.8300e-01, 1.6468e-01,\n", + " 3.3595e-01, -1.0217e-01, -7.0275e-01, -4.3307e-01, -5.4169e-01,\n", + " 1.8883e-01, -5.5797e-02, -9.2162e-01, 4.4790e-01, -3.5256e-02,\n", + " 2.1131e-01, -4.6267e-02, 4.1688e-01, 1.9311e-01, 8.2643e-01,\n", + " 3.1897e-02, 1.8036e-02, 2.2502e-02, -5.6261e-01, 5.2690e-01,\n", + " -4.1523e-01, -2.0335e-01, 5.0975e-03, 1.0000e+00, -1.3769e-01,\n", + " 4.0090e-01, 4.8580e-01, 3.0547e-01, 1.0161e-01, 1.1372e-01,\n", + " 5.4687e-01, 1.7282e-01, -1.1611e-01, 1.1692e-01, 3.3706e-01,\n", + " -9.4995e-02, 3.3125e-01, -1.1600e-01, 5.5663e-02, 6.9017e-01,\n", + " 5.2775e-01, -7.8248e-02, 7.7874e-02, -2.5570e-01, 9.5441e-01,\n", + " 4.4725e-02, 7.5062e-02, -1.6521e-01, 9.8572e-02, -1.2673e-01,\n", + " 4.2396e-01, 9.9999e-01, 1.4011e-01, -6.5118e-02, -9.8683e-01,\n", + " -3.4659e-01, -6.9549e-01, 9.9968e-01, 7.8693e-01, -6.2560e-01,\n", + " 4.0561e-01, 5.1398e-01, -7.1926e-03, 3.7469e-01, -4.9920e-02,\n", + " -1.8379e-01, 1.0699e-01, 6.4271e-02, 9.4363e-01, -4.5982e-01,\n", + " -9.6684e-01, -4.8714e-01, 1.6233e-01, -9.2982e-01, 9.8976e-01,\n", + " -2.8241e-01, -3.9526e-02, -2.8969e-01, 2.2178e-01, -7.3322e-01,\n", + " -1.9752e-01, -9.7385e-01, 1.4625e-01, 1.7384e-02, 9.4459e-01,\n", + " 8.0070e-02, -4.1026e-01, 
-7.2363e-01, 6.5493e-02, 2.9531e-01,\n", + " -2.0402e-01, -9.4453e-01, 9.4867e-01, -9.6224e-01, 4.1987e-01,\n", + " 9.9992e-01, 2.0182e-01, -5.9719e-01, 6.7061e-02, -1.3560e-01,\n", + " 1.1140e-01, -7.1070e-02, 3.3843e-01, -9.1928e-01, -1.1785e-01,\n", + " 7.1904e-03, 9.3812e-02, 1.2718e-01, -4.2175e-01, 6.2383e-01,\n", + " -3.0948e-02, -3.9573e-01, -4.9911e-01, 1.9713e-01, 1.9574e-01,\n", + " 5.2774e-01, -6.4998e-02, 3.8218e-02, -1.3764e-01, 1.3114e-01,\n", + " -8.2896e-01, -6.2801e-02, -1.3077e-01, -9.9745e-01, 3.8189e-01,\n", + " -1.0000e+00, -4.9528e-02, -3.3011e-01, -9.7047e-03, 7.4031e-01,\n", + " 4.5588e-01, -4.3039e-02, -5.9485e-01, 3.5139e-02, 8.4290e-01,\n", + " 7.0024e-01, 4.9505e-03, 1.5221e-01, -4.8182e-01, 3.4912e-02,\n", + " 6.8681e-02, 5.9797e-02, 9.4147e-02, 5.7532e-01, 3.5063e-02,\n", + " 1.0000e+00, -4.4786e-03, -3.4757e-01, -7.9309e-01, 5.7241e-02,\n", + " -4.8241e-02, 9.9991e-01, -3.6963e-01, -9.2729e-01, 2.2610e-01,\n", + " -3.2602e-01, -6.5948e-01, 2.3506e-01, -6.6027e-02, -6.2875e-01,\n", + " -4.7124e-01, 8.3105e-01, 4.3462e-01, -5.2237e-01, 2.1811e-01,\n", + " -1.1176e-01, -2.7027e-01, -6.8502e-02, 5.0502e-02, 9.8319e-01,\n", + " 3.3888e-01, 5.6442e-01, 1.0517e-01, 6.1442e-02, 9.3666e-01,\n", + " 7.3988e-02, -2.4528e-01, -8.5207e-02, 9.9998e-01, 1.4210e-01,\n", + " -8.2488e-01, 2.2405e-01, -9.2098e-01, -1.0235e-01, -8.4105e-01,\n", + " 2.1140e-01, -3.4107e-02, 8.0942e-01, 4.9842e-03, 8.9624e-01,\n", + " 6.7186e-02, -1.7137e-01, -2.7561e-01, 2.6385e-01, 1.9073e-01,\n", + " -8.6307e-01, -9.8238e-01, -9.8035e-01, 2.2370e-01, -3.5154e-01,\n", + " 1.9181e-01, 8.9503e-02, -9.8139e-02, 8.3593e-02, 3.0373e-01,\n", + " -9.9998e-01, 9.0944e-01, 2.9007e-01, 4.4585e-01, 9.4631e-01,\n", + " 4.1260e-01, 1.9621e-01, 2.4693e-01, -9.7562e-01, -7.6957e-01,\n", + " -1.7996e-01, -5.8601e-02, 4.2949e-01, 3.3341e-01, 8.0547e-01,\n", + " 2.5306e-01, -4.0736e-01, -3.4586e-02, 4.1000e-01, -8.3874e-01,\n", + " -9.9092e-01, 3.0937e-01, 3.3917e-01, -6.2679e-01, 
9.4565e-01,\n", + " -5.9613e-01, -1.9438e-03, 3.7971e-01, -2.2250e-01, 5.2158e-01,\n", + " 5.9324e-01, -1.8357e-02, -6.8000e-03, 2.1554e-01, 8.2484e-01,\n", + " 8.0068e-01, 9.7795e-01, -1.0868e-01, 4.3963e-01, 2.2388e-01,\n", + " 2.7078e-01, 8.5065e-01, -9.2567e-01, 4.3628e-03, -3.2063e-02,\n", + " -1.9565e-01, 1.1169e-01, -9.4711e-02, -7.2644e-01, 6.3986e-01,\n", + " -1.7955e-01, 4.2939e-01, -2.0787e-01, 2.2294e-01, -2.3857e-01,\n", + " 6.7195e-02, -5.1772e-01, -3.6389e-01, 5.3170e-01, 5.3484e-02,\n", + " 8.5309e-01, 6.4611e-01, 1.2341e-02, -2.4756e-01, 1.4719e-02,\n", + " -5.3293e-02, -9.2566e-01, 5.0771e-01, 1.2492e-01, 2.1458e-01,\n", + " -6.7960e-02, -2.7113e-01, 9.0946e-01, -1.9032e-01, -2.1274e-01,\n", + " -6.4846e-02, -4.3871e-01, 6.3752e-01, -2.1017e-01, -2.9291e-01,\n", + " -3.1616e-01, 5.4117e-01, 1.6768e-01, 9.9424e-01, -9.4508e-02,\n", + " -2.9022e-01, -2.1879e-03, -1.5720e-01, 2.8317e-01, -2.9364e-01,\n", + " -9.9998e-01, 1.4066e-01, 9.1606e-02, 1.1457e-01, -2.1965e-01,\n", + " 3.0746e-01, -5.7719e-02, -8.7692e-01, -9.3891e-02, 2.2809e-01,\n", + " 3.8766e-02, -3.2828e-01, -3.1138e-01, 4.1117e-01, 4.6004e-01,\n", + " 5.5266e-01, 7.2535e-01, 2.5635e-01, 5.2958e-01, 4.7964e-01,\n", + " -1.0402e-01, -5.4204e-01, 8.4934e-01]], grad_fn=))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outputs" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": 
{}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/fine-tune-bert.ipynb b/notebooks/fine-tune-bert.ipynb deleted file mode 100644 index 9517733822194f1f77a235bbbc546e375a75416f..0000000000000000000000000000000000000000 --- a/notebooks/fine-tune-bert.ipynb +++ /dev/null @@ -1,45 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.3" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}