From 4a8d44734ca5961466d34d05d5c71e112a880706 Mon Sep 17 00:00:00 2001 From: Claudio Scheer Date: Fri, 22 May 2020 22:51:48 -0300 Subject: [PATCH] Add notebook with implementation from tutorial --- src/seq2seq.ipynb | 858 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 858 insertions(+) create mode 100644 src/seq2seq.ipynb diff --git a/src/seq2seq.ipynb b/src/seq2seq.ipynb new file mode 100644 index 0000000..fb26aec --- /dev/null +++ b/src/seq2seq.ipynb @@ -0,0 +1,858 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## About" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A sequence to sequence model, also known as encode-decoder." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import requeriments:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:31:48.965223Z", + "start_time": "2020-05-23T00:31:48.954325Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from __future__ import unicode_literals, print_function, division\n", + "from io import open\n", + "import unicodedata\n", + "import string\n", + "import re\n", + "import random\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch import optim\n", + "import torch.nn.functional as F\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "device" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:33:41.443416Z", + "start_time": "2020-05-23T00:33:41.435219Z" + } + }, + "outputs": [], + "source": [ + "SOS_token = 0\n", + "EOS_token = 1\n", + "\n", + "\n", + "class 
Lang:\n", + " def __init__(self, name):\n", + " self.name = name\n", + " self.word2index = {}\n", + " self.word2count = {}\n", + " self.index2word = {0: \"SOS\", 1: \"EOS\"}\n", + " self.n_words = 2 # Count SOS and EOS\n", + "\n", + " def addSentence(self, sentence):\n", + " for word in sentence.split(' '):\n", + " self.addWord(word)\n", + "\n", + " def addWord(self, word):\n", + " if word not in self.word2index:\n", + " self.word2index[word] = self.n_words\n", + " self.word2count[word] = 1\n", + " self.index2word[self.n_words] = word\n", + " self.n_words += 1\n", + " else:\n", + " self.word2count[word] += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:35:56.907752Z", + "start_time": "2020-05-23T00:35:56.897354Z" + } + }, + "outputs": [], + "source": [ + "# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427.\n", + "def unicodeToAscii(s):\n", + " return \"\".join(\n", + " c for c in unicodedata.normalize(\"NFD\", s)\n", + " if unicodedata.category(c) != \"Mn\"\n", + " )\n", + "\n", + "\n", + "# Lowercase, trim, and remove non-letter characters.\n", + "def normalizeString(s):\n", + " s = unicodeToAscii(s.lower().strip())\n", + " s = re.sub(r\"([.!?])\", r\" \\1\", s)\n", + " s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n", + " return s" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:39:35.205497Z", + "start_time": "2020-05-23T00:39:35.198294Z" + } + }, + "outputs": [], + "source": [ + "def readLangs(lang1, lang2, reverse=False):\n", + " print(\"Reading lines...\")\n", + "\n", + " # Read the file and split into lines.\n", + " lines = open(\"../data/%s-%s.txt\" % (lang1, lang2), encoding=\"utf-8\").\\\n", + " read().strip().split(\"\\n\")\n", + "\n", + " # Split every line into pairs and normalize.\n", + " pairs = [[normalizeString(s) for s in l.split(\"\\t\")] for l in 
lines]\n", + "\n", + " # Reverse pairs, make Lang instances.\n", + " if reverse:\n", + " pairs = [list(reversed(p)) for p in pairs]\n", + " input_lang = Lang(lang2)\n", + " output_lang = Lang(lang1)\n", + " else:\n", + " input_lang = Lang(lang1)\n", + " output_lang = Lang(lang2)\n", + "\n", + " return input_lang, output_lang, pairs" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:39:35.401126Z", + "start_time": "2020-05-23T00:39:35.381227Z" + } + }, + "outputs": [], + "source": [ + "MAX_LENGTH = 10\n", + "\n", + "eng_prefixes = (\n", + " \"i am \", \"i m \",\n", + " \"he is\", \"he s \",\n", + " \"she is\", \"she s \",\n", + " \"you are\", \"you re \",\n", + " \"we are\", \"we re \",\n", + " \"they are\", \"they re \"\n", + ")\n", + "\n", + "\n", + "def filterPair(p):\n", + " return len(p[0].split(\" \")) < MAX_LENGTH and \\\n", + " len(p[1].split(\" \")) < MAX_LENGTH and \\\n", + " p[1].startswith(eng_prefixes)\n", + "\n", + "\n", + "def filterPairs(pairs):\n", + " return [pair for pair in pairs if filterPair(pair)]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:41:46.640103Z", + "start_time": "2020-05-23T00:41:43.093695Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reading lines...\n", + "Read 135842 sentence pairs\n", + "Trimmed to 10599 sentence pairs\n", + "Counting words...\n", + "Counted words:\n", + "fra 4345\n", + "eng 2803\n", + "['ce ne sont pas tous des criminels .', 'they re not all criminals .']\n" + ] + } + ], + "source": [ + "def prepareData(lang1, lang2, reverse=False):\n", + " input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n", + " print(\"Read %s sentence pairs\" % len(pairs))\n", + " pairs = filterPairs(pairs)\n", + " print(\"Trimmed to %s sentence pairs\" % len(pairs))\n", + " print(\"Counting words...\")\n", + " for pair in 
pairs:\n", + " input_lang.addSentence(pair[0])\n", + " output_lang.addSentence(pair[1])\n", + " print(\"Counted words:\")\n", + " print(input_lang.name, input_lang.n_words)\n", + " print(output_lang.name, output_lang.n_words)\n", + " return input_lang, output_lang, pairs\n", + "\n", + "\n", + "input_lang, output_lang, pairs = prepareData(\"eng\", \"fra\", True)\n", + "print(random.choice(pairs))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:42:10.169849Z", + "start_time": "2020-05-23T00:42:10.162724Z" + } + }, + "outputs": [], + "source": [ + "class EncoderRNN(nn.Module):\n", + " def __init__(self, input_size, hidden_size):\n", + " super(EncoderRNN, self).__init__()\n", + " self.hidden_size = hidden_size\n", + "\n", + " self.embedding = nn.Embedding(input_size, hidden_size)\n", + " self.gru = nn.GRU(hidden_size, hidden_size)\n", + "\n", + " def forward(self, input, hidden):\n", + " embedded = self.embedding(input).view(1, 1, -1)\n", + " output = embedded\n", + " output, hidden = self.gru(output, hidden)\n", + " return output, hidden\n", + "\n", + " def initHidden(self):\n", + " return torch.zeros(1, 1, self.hidden_size, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:42:23.098766Z", + "start_time": "2020-05-23T00:42:23.094127Z" + } + }, + "outputs": [], + "source": [ + "class DecoderRNN(nn.Module):\n", + " def __init__(self, hidden_size, output_size):\n", + " super(DecoderRNN, self).__init__()\n", + " self.hidden_size = hidden_size\n", + "\n", + " self.embedding = nn.Embedding(output_size, hidden_size)\n", + " self.gru = nn.GRU(hidden_size, hidden_size)\n", + " self.out = nn.Linear(hidden_size, output_size)\n", + " self.softmax = nn.LogSoftmax(dim=1)\n", + "\n", + " def forward(self, input, hidden):\n", + " output = self.embedding(input).view(1, 1, -1)\n", + " output = F.relu(output)\n", + " 
output, hidden = self.gru(output, hidden)\n", + " output = self.softmax(self.out(output[0]))\n", + " return output, hidden\n", + "\n", + " def initHidden(self):\n", + " return torch.zeros(1, 1, self.hidden_size, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "ExecuteTime": { + "end_time": "2020-05-23T00:44:50.268979Z", + "start_time": "2020-05-23T00:44:50.254209Z" + } + }, + "outputs": [], + "source": [ + "class AttnDecoderRNN(nn.Module):\n", + " def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n", + " super(AttnDecoderRNN, self).__init__()\n", + " self.hidden_size = hidden_size\n", + " self.output_size = output_size\n", + " self.dropout_p = dropout_p\n", + " self.max_length = max_length\n", + "\n", + " self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n", + " self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n", + " self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n", + " self.dropout = nn.Dropout(self.dropout_p)\n", + " self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n", + " self.out = nn.Linear(self.hidden_size, self.output_size)\n", + "\n", + " def forward(self, input, hidden, encoder_outputs):\n", + " embedded = self.embedding(input).view(1, 1, -1)\n", + " embedded = self.dropout(embedded)\n", + "\n", + " attn_weights = F.softmax(\n", + " self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\n", + " attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n", + " encoder_outputs.unsqueeze(0))\n", + "\n", + " output = torch.cat((embedded[0], attn_applied[0]), 1)\n", + " output = self.attn_combine(output).unsqueeze(0)\n", + "\n", + " output = F.relu(output)\n", + " output, hidden = self.gru(output, hidden)\n", + "\n", + " output = F.log_softmax(self.out(output[0]), dim=1)\n", + " return output, hidden, attn_weights\n", + "\n", + " def initHidden(self):\n", + " return torch.zeros(1, 1, self.hidden_size, device=device)" + ] + }, + 
"""Training loop, timing helpers, and greedy-decoding evaluation for the
seq2seq translation model."""

import time
import math
import random

import numpy as np
import torch
import torch.nn as nn
from torch import optim

# Shared constants — these repeat the values defined in the data-preparation
# section so this section is importable on its own; keep them in sync.
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher_forcing_ratio = 0.5  # probability of feeding the target token back in


def indexesFromSentence(lang, sentence):
    """Map each space-separated word of *sentence* to its vocabulary index."""
    return [lang.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(lang, sentence):
    """Sentence -> (seq_len + 1, 1) LongTensor of indices, EOS-terminated."""
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair):
    """(source sentence, target sentence) -> (input tensor, target tensor).

    Uses the module-level input_lang / output_lang built by prepareData.
    """
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, max_length=MAX_LENGTH):
    """Run one (input, target) pair through encoder+decoder and take an
    optimizer step; return the loss averaged over target tokens."""
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(
        max_length, encoder.hidden_size, device=device)

    loss = 0

    # Encode the source sentence one token at a time.
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    # The encoder's final hidden state seeds the decoder.
    decoder_hidden = encoder_hidden

    # FIX: replaced the `True if ... else False` anti-idiom.
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: feed the target as the next input.
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing

    else:
        # Without teacher forcing: use its own predictions as the next input.
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length


def asMinutes(s):
    """Seconds -> human-readable '<m>m <s>s' string."""
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    """Elapsed time since *since* plus a linear estimate of time remaining
    given the fraction *percent* of the work already done."""
    now = time.time()
    s = now - since
    es = s / (percent)  # estimated total duration
    rs = es - s         # estimated remaining
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100,
               learning_rate=0.01):
    """Train on *n_iters* randomly sampled pairs with SGD, logging progress
    every *print_every* steps and recording the loss every *plot_every*."""
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                      for i in range(n_iters)]
    criterion = nn.NLLLoss()

    # FIX: loop variable renamed from `iter` (shadowed the builtin).
    for step in range(1, n_iters + 1):
        training_pair = training_pairs[step - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if step % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, step / n_iters),
                                         step, step / n_iters * 100,
                                         print_loss_avg))

        if step % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)


def showPlot(points):
    """Plot the recorded loss values (agg backend: renders off-screen)."""
    # Imported locally so the module is importable without matplotlib/display.
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    plt.switch_backend('agg')

    plt.figure()
    fig, ax = plt.subplots()
    # This locator puts ticks at regular intervals.
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)


def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedy-decode *sentence*; return (decoded words, attention rows used)."""
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(
            max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                     encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                # NOTE(review): the source text showed append('') here; the
                # '<EOS>' marker appears stripped by extraction — restored.
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]


def evaluateRandomly(encoder, decoder, n=10):
    """Translate *n* random training pairs and print source/target/prediction."""
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')


def showAttention(input_sentence, output_words, attentions):
    """Render the attention matrix: input words on x, output words on y."""
    # Imported locally so the module is importable without matplotlib/display.
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    plt.switch_backend('agg')

    # Set up figure with colorbar.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes.  NOTE(review): the trailing x label showed as '' in the
    # source text; the '<EOS>' marker appears stripped by extraction — restored.
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick.
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    """Translate *input_sentence* with the trained encoder1/attn_decoder1
    globals and visualize the attention it used."""
    output_words, attentions = evaluate(
        encoder1, attn_decoder1, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions)


if __name__ == "__main__":
    # FIX: the ~18-minute training run was executed unconditionally in the
    # original; guarded so importing this module stays side-effect free.
    hidden_size = 256
    encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words,
                                   dropout_p=0.1).to(device)

    trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

    evaluateRandomly(encoder1, attn_decoder1)

    output_words, attentions = evaluate(
        encoder1, attn_decoder1, "je suis trop froid .")
    import matplotlib.pyplot as plt
    plt.switch_backend('agg')
    plt.matshow(attentions.numpy())

    evaluateAndShowAttention("elle a cinq ans de moins que moi .")
    evaluateAndShowAttention("elle est trop petit .")
    evaluateAndShowAttention("je ne crains pas de mourir .")
    evaluateAndShowAttention("c est un jeune directeur plein de talent .")
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} -- GitLab