Skip to content

Commit

Permalink
remove caption mask from rainbow_dalle example, as it is no longer needed in most recent version
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 6, 2022
1 parent d7bd745 commit daf30d0
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions examples/rainbow_dalle.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,7 @@
"for i in range(len(caption_tokens)):\n",
" captions_array[i, :len(caption_tokens[i])] = caption_tokens[i]\n",
" \n",
"captions_array = torch.from_numpy(captions_array).to(device)\n",
"captions_mask = captions_array != 0"
"captions_array = torch.from_numpy(captions_array).to(device)\n"
]
},
{
Expand Down Expand Up @@ -912,8 +911,8 @@
"outputs": [],
"source": [
"def train_dalle_batch(vae, train_data, _, idx, __):\n",
" text, image_codes, mask = train_data\n",
" loss = dalle(text[idx, ...], image_codes[idx, ...], mask=mask[idx, ...], return_loss=True)\n",
" text, image_codes = train_data\n",
" loss = dalle(text[idx, ...], image_codes[idx, ...], return_loss=True)\n",
" return loss"
]
},
Expand Down Expand Up @@ -990,7 +989,7 @@
"dalle_model_file = \"data/rainbow_dalle.model\"\n",
"if not os.path.exists(dalle_model_file):\n",
" dalle, loss_history = fit(dalle, opt, None, scheduler, \n",
" (captions_array[train_idx, ...], all_image_codes[train_idx, ...], captions_mask[train_idx, ...]), None, 200, 256, \n",
" (captions_array[train_idx, ...], all_image_codes[train_idx, ...]), None, 200, 256, \n",
" dalle_model_file, train_dalle_batch, \n",
" n_train_samples=len(train_idx))\n",
"\n",
Expand Down Expand Up @@ -1043,7 +1042,7 @@
"generated_images = []\n",
"with torch.no_grad():\n",
" for i in trange(0, len(captions), 128):\n",
" generated = dalle.generate_images(captions_array[i:i + 128, ...], mask=captions_mask[i:i + 128, ...], temperature=0.00001)\n",
" generated = dalle.generate_images(captions_array[i:i + 128, ...], temperature=0.00001)\n",
" generated_images.append(generated)"
]
},
Expand Down Expand Up @@ -1212,7 +1211,7 @@
"source": [
"from torch.nn import functional as F\n",
"\n",
"def generate_image_code(dalle, text, mask):\n",
"def generate_image_code(dalle, text):\n",
" vae, text_seq_len, image_seq_len, num_text_tokens = dalle.vae, dalle.text_seq_len, dalle.image_seq_len, dalle.num_text_tokens\n",
" total_len = text_seq_len + image_seq_len\n",
" out = text\n",
Expand All @@ -1222,7 +1221,7 @@
"\n",
" text, image = out[:, :text_seq_len], out[:, text_seq_len:]\n",
"\n",
" logits = dalle(text, image, mask = mask)[:, -1, :]\n",
" logits = dalle(text, image)[:, -1, :]\n",
" chosen = torch.argmax(logits, dim=1, keepdim=True)\n",
" chosen -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens\n",
" out = torch.cat((out, chosen), dim=-1)\n",
Expand Down Expand Up @@ -1278,7 +1277,7 @@
"generated_image_codes = []\n",
"with torch.no_grad():\n",
" for i in trange(0, len(captions), 128):\n",
" generated = generate_image_code(dalle, captions_array[i:i + 128, ...], mask=captions_mask[i:i + 128, ...])\n",
" generated = generate_image_code(dalle, captions_array[i:i + 128, ...])\n",
" generated_image_codes.append(generated)\n",
" \n",
"generated_image_codes = torch.cat(generated_image_codes, axis=0).cpu()"
Expand Down

0 comments on commit daf30d0

Please sign in to comment.