See the official CookBERT paper.
CookBERT is a domain-specific BERT model created by domain-adaptive pretraining on the instructions of the RecipeNLG corpus and by extending BERT's default vocabulary with 1,229 cooking-specific words. As a result, CookBERT is geared more towards the cooking domain than the default model:
| Input | Model | Top 5 predictions for [MASK] token |
|---|---|---|
| "Do I have to [MASK] the apple?" | BERTbase | eat, take, have, touch, get |
| | CookBERT | peel, slice, use, dice, chop |
| "[MASK] the water." | BERTbase | in, drink, under, into, on |
| | CookBERT | boil, heat, add, scald, chill |
| "Cut the [MASK] into small pieces." | BERTbase | wood, paper, leaves, meat, bark |
| | CookBERT | chicken, cheese, fruit, cabbage, sausage |
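Extending the vocabulary matters because BERT splits any word missing from its vocabulary into WordPiece sub-tokens using greedy longest-match-first matching. The sketch below implements that matching for a single word with a small, hypothetical toy vocabulary (not BERT's or CookBERT's real vocabulary) to show how adding a whole cooking word keeps it as one token. In the Hugging Face Transformers library, words are commonly added with `tokenizer.add_tokens(...)` followed by `model.resize_token_embeddings(len(tokenizer))`; the source does not say which mechanism was used for CookBERT's 1,229 added words.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece tokenization of one word, as in
    BERT: pieces after the first carry a '##' continuation prefix."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # No vocabulary entry matches: the whole word becomes [UNK].
            return [unk_token]
        tokens.append(piece)
        start = end
    return tokens


# Hypothetical toy vocabulary, NOT the real BERT or CookBERT vocabulary.
base_vocab = {"sc", "##ald", "##al", "##d", "boil", "water"}
print(wordpiece_tokenize("scald", base_vocab))      # ['sc', '##ald']

# Adding the whole cooking-specific word keeps it as a single token.
extended_vocab = base_vocab | {"scald"}
print(wordpiece_tokenize("scald", extended_vocab))  # ['scald']
```

A word kept as a single token gets its own embedding, which the model can then learn during domain-adaptive pretraining instead of composing it from generic sub-tokens.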
CookBERT's domain specificity has proven superior in text classification and named entity recognition on cooking-related data, where it outperformed both BERTbase and FoodBERT.
To obtain CookBERT, BERTbase (uncased version) was used as the starting point and further pretrained for three additional epochs on the masked language modeling (MLM) task using the RecipeNLG instructions, with 5% of them held out as validation data. Training used a learning rate of 2e-5, an effective batch size of 32, and a maximum sequence length of 256. It took approximately five full days on a single NVIDIA Tesla P100 GPU provided by Google Colab Pro.
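The pretraining setup can be sketched in plain Python as follows. The source states only that the MLM task was used with 5% of the instructions held out for validation; the random hold-out and the masking rule (select about 15% of tokens, then replace 80% of those with [MASK], 10% with a random token, and leave 10% unchanged) are assumptions taken from the original BERT recipe, which is also the default behavior of the Hugging Face `DataCollatorForLanguageModeling`. The split of the effective batch size of 32 into a per-device batch size and gradient accumulation steps is hypothetical.

```python
import random

# Hyperparameters reported for CookBERT.
LEARNING_RATE = 2e-5
MAX_SEQ_LENGTH = 256
EPOCHS = 3
VALIDATION_FRACTION = 0.05

# Hypothetical split of the effective batch size of 32 (not stated in the source).
PER_DEVICE_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
EFFECTIVE_BATCH_SIZE = PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS

MASK_TOKEN = "[MASK]"


def split_train_validation(instructions, validation_fraction=VALIDATION_FRACTION, seed=42):
    """Randomly hold out a fraction of the recipe instructions for validation."""
    rng = random.Random(seed)
    shuffled = list(instructions)
    rng.shuffle(shuffled)
    n_validation = int(len(shuffled) * validation_fraction)
    return shuffled[n_validation:], shuffled[:n_validation]


def mask_tokens(tokens, vocab, mlm_probability=0.15, seed=None):
    """BERT-style MLM corruption (assumed standard BERT recipe).

    Each position becomes a prediction target with probability
    mlm_probability; a target is replaced by [MASK] 80% of the time,
    by a random vocabulary token 10% of the time, and kept 10% of the time.
    Returns the corrupted tokens and the labels (None = not a target)."""
    rng = random.Random(seed)
    inputs = list(tokens)
    labels = [None] * len(tokens)
    for i, token in enumerate(tokens):
        if rng.random() >= mlm_probability:
            continue
        labels[i] = token  # the model is trained to recover this token
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK_TOKEN
        elif roll < 0.9:
            inputs[i] = rng.choice(vocab)
        # otherwise the original token stays in place
    return inputs, labels


train, validation = split_train_validation([f"instruction {i}" for i in range(1000)])
print(len(train), len(validation))  # 950 50

tokens = "cut the chicken into small pieces".split()
print(mask_tokens(tokens, vocab=tokens, seed=0))
```

Keeping the original token in some target positions means the model cannot assume that only [MASK] positions need predicting, which matters because [MASK] never appears during fine-tuning.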
CookBERT was fine-tuned and evaluated on three different tasks: information need classification, food entity tagging, and question answering. In addition, BERTbase (uncased version) and FoodBERT were applied to the same tasks to compare and rank CookBERT's performance.
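Each results table reports a 95% confidence interval. The intervals are symmetric around the reported F-measure (for example, 46.15% in [41.15%;51.16%]), which is consistent with a normal-approximation interval of the form mean ± 1.96 · s / √n over repeated runs or folds. The source does not say how the intervals were computed, so the following is only a sketch of that common method, applied to hypothetical per-fold scores.

```python
import math
import statistics


def normal_ci(scores, z=1.96):
    """Normal-approximation confidence interval for the mean of a list of
    scores: mean +/- z * s / sqrt(n), with s the sample standard deviation.
    z = 1.96 gives a two-sided 95% interval."""
    if len(scores) < 2:
        raise ValueError("need at least two scores to estimate the spread")
    mean = statistics.mean(scores)
    half_width = z * statistics.stdev(scores) / math.sqrt(len(scores))
    return mean - half_width, mean + half_width


# Hypothetical per-fold F-measures, NOT results from the CookBERT evaluation.
fold_f1 = [0.50, 0.55, 0.52, 0.48, 0.53]
low, high = normal_ci(fold_f1)
print(f"F-measure {statistics.mean(fold_f1):.2%}, 95%-CI [{low:.2%};{high:.2%}]")
```

With only a handful of folds, a Student's t quantile (about 2.78 for n = 5) is more appropriate than 1.96 and yields a wider interval.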
Results of classifying the user information needs that arise during cooking, based on the Cookversational dataset.
| Model | Condition | Precision | Recall | F-Measure | 95%-CI |
|---|---|---|---|---|---|
| BERTbase | no context | 47.94% | 48.68% | 46.15% | [41.15%;51.16%] |
| | 1 prev turn | 46.29% | 49.84% | 45.38% | [40.06%;50.70%] |
| CookBERT | no context | 48.58% | 55.65% | 50.72% | [45.54%;55.90%] |
| | 1 prev turn | 52.26% | 59.30% | 54.05% | [48.93%;59.16%] |
| FoodBERT | no context | 42.41% | 49.81% | 44.32% | [38.92%;49.73%] |
| | 1 prev turn | 36.89% | 44.49% | 38.09% | [32.64%;43.55%] |
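In several rows of the classification table the F-measure is lower than both precision and recall (e.g., BERTbase without context: 46.15% vs. 47.94% and 48.68%). The harmonic mean of two numbers can never be lower than the smaller of them, so these F-measures cannot be the harmonic mean of the two reported averages; the numbers are consistent with averaging per-class (or per-run) scores, as in macro averaging. The source does not state the averaging method, so the sketch below is an assumption: it computes macro-averaged precision, recall, and F1 for hypothetical information-need labels, chosen so that the same effect appears.

```python
from collections import Counter


def macro_precision_recall_f1(y_true, y_pred):
    """Macro averaging: compute precision, recall and F1 per class,
    then take the unweighted mean of each over all classes."""
    labels = sorted(set(y_true) | set(y_pred))
    true_counts = Counter(y_true)
    pred_counts = Counter(y_pred)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)

    precisions, recalls, f1s = [], [], []
    for label in labels:
        tp = correct[label]
        p = tp / pred_counts[label] if pred_counts[label] else 0.0
        r = tp / true_counts[label] if true_counts[label] else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)

    n = len(labels)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


# Hypothetical information-need labels, NOT the Cookversational label set.
y_true = ["recipe", "recipe", "recipe", "tool", "tool", "time"]
y_pred = ["recipe", "recipe", "tool", "tool", "time", "time"]

p, r, f = macro_precision_recall_f1(y_true, y_pred)
print(f"P={p:.4f} R={r:.4f} F={f:.4f}")  # F is below both P and R
```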
Results of the food entity tagging task using the curated version of the FoodBase corpus, as well as the labels provided by Stojanov et al. (2021) for five different tagging schemes.
| Model | Tagging-Task | Precision | Recall | F-Measure | 95%-CI |
|---|---|---|---|---|---|
| BERTbase | Food-classification | 90.68% | 96.06% | 93.29% | [92.87%;93.71%] |
| | FoodOn | 65.24% | 73.10% | 68.94% | [67.04%;70.83%] |
| | Hansard-parent | 80.35% | 88.68% | 84.31% | [83.54%;85.08%] |
| | Hansard-closest | 70.79% | 79.98% | 75.10% | [73.87%;76.34%] |
| | SNOMED CT | 63.04% | 70.65% | 66.62% | [64.49%;68.75%] |
| CookBERT | Food-classification | 92.25% | 96.52% | 94.47% | [94.17%;94.76%] |
| | FoodOn | 69.75% | 77.51% | 73.42% | [71.91%;74.93%] |
| | Hansard-parent | 82.72% | 89.18% | 85.83% | [84.69%;86.97%] |
| | Hansard-closest | 72.21% | 80.41% | 76.08% | [74.60%;77.56%] |
| | SNOMED CT | 68.58% | 75.51% | 71.87% | [69.99%;73.75%] |
| FoodBERT | Food-classification | 85.28% | 94.24% | 89.53% | [88.90%;90.17%] |
| | FoodOn | 58.73% | 61.03% | 59.85% | [56.56%;63.13%] |
| | Hansard-parent | 68.41% | 80.62% | 74.01% | [72.13%;75.90%] |
| | Hansard-closest | 59.55% | 67.52% | 63.28% | [60.43%;66.13%] |
| | SNOMED CT | 53.63% | 51.84% | 52.67% | [49.17%;56.17%] |
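Tagging results like these are usually scored at the entity level: a predicted entity counts as correct only if both its span and its label match the gold annotation exactly. The sketch below assumes BIO-encoded tags and micro-averaged, exact-match scoring (the convention of tools such as seqeval); the source does not state the encoding or scoring details used for FoodBase, and the `FOOD` label and example sentences are hypothetical.

```python
def bio_to_spans(tags):
    """Convert a BIO tag sequence into a set of (label, start, end) spans
    (end exclusive). An I- tag that does not continue a span with the same
    label opens a new span, a lenient convention shared by common scorers."""
    spans = set()
    label, start = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # trailing "O" closes the last span
        if tag.startswith("I-") and tag[2:] == label:
            continue  # still inside the current span
        if label is not None:
            spans.add((label, start, i))
            label, start = None, None
        if tag.startswith(("B-", "I-")):
            label, start = tag[2:], i
    return spans


def entity_prf(gold_sequences, pred_sequences):
    """Micro-averaged entity-level precision, recall and F1 with exact span
    and label matching, accumulated over all sentences."""
    tp = n_pred = n_gold = 0
    for gold_tags, pred_tags in zip(gold_sequences, pred_sequences):
        gold_spans = bio_to_spans(gold_tags)
        pred_spans = bio_to_spans(pred_tags)
        tp += len(gold_spans & pred_spans)
        n_pred += len(pred_spans)
        n_gold += len(gold_spans)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


# Hypothetical sentences "milk and chicken breast" / "add garlic",
# tagged with an illustrative FOOD label.
gold = [["B-FOOD", "O", "B-FOOD", "I-FOOD"], ["O", "B-FOOD"]]
predicted = [["B-FOOD", "O", "B-FOOD", "O"], ["O", "B-FOOD"]]

# "chicken breast" is only partially found, so it counts as a miss
# (and the truncated "chicken" span as a false positive).
print(entity_prf(gold, predicted))
```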
Results of the question answering task (answer span extraction), based on the cooking subset of the DoQA dataset.
Model | Exact match | F-measure | 95%-CI |
---|---|---|---|
BERTbase | 14.06% | 32.39% | [31.25%;33.54%] |
CookBERT | 12.51% | 30.64% | [29.50%;31.78%] |
FoodBERT | 10.81% | 27.51% | [26.51%;28.50%] |
Best performances printed in bold
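In answer span extraction, exact match gives credit only when the normalized predicted span equals the gold answer, while the F-measure also rewards partial token overlap; this is consistent with the F-measure values in the question answering table being far higher than the exact-match values. The sketch below assumes SQuAD-style definitions (lowercasing, removing punctuation and the articles a/an/the, then comparing tokens); the source does not specify the scoring script used for DoQA, and the example answers are hypothetical.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """SQuAD-style normalization: lowercase, drop punctuation, drop the
    articles a/an/the, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, gold):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(gold))


def token_f1(prediction, gold):
    """Harmonic mean of token-level precision and recall between the
    normalized prediction and gold answer (bag-of-tokens overlap)."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


gold_answer = "Let it simmer for ten minutes."
predicted_answer = "simmer for ten minutes"
print(exact_match(predicted_answer, gold_answer))         # 0.0
print(round(token_f1(predicted_answer, gold_answer), 4))  # 0.8
```

When several gold answers exist for a question, the usual convention is to take the maximum score over them.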
The CookBERT PyTorch model checkpoint can be downloaded from this Google Drive folder. With the Hugging Face Transformers library, the model can be set up easily:
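```python
from transformers import (
    BertTokenizerFast,
    BertForMaskedLM,
    pipeline,
)

# "CookBERT-checkpoint" is assumed to be the local directory that holds
# the checkpoint downloaded from the Google Drive folder.
CookBERT_tokenizer = BertTokenizerFast.from_pretrained("CookBERT-checkpoint")
CookBERT = BertForMaskedLM.from_pretrained("CookBERT-checkpoint")

# Fill-mask pipeline: ranks candidate tokens for the [MASK] position.
CookBERT_pipeline = pipeline("fill-mask", model=CookBERT, tokenizer=CookBERT_tokenizer)

masked_text = "Cut the [MASK] into small pieces."
print("Predictions: ", CookBERT_pipeline(masked_text, top_k=5))
```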
Note that the Google Drive folder contains only the CookBERT checkpoint trained on the MLM task. To apply CookBERT to other tasks (NER, QA, ...), it can be fine-tuned with the fine-tuning scripts provided by Hugging Face.
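When fine-tuning CookBERT for a token-level task such as food entity tagging, word-level labels have to be aligned with the WordPiece tokens. The Hugging Face token-classification example does this by default by labeling only the first sub-token of each word and assigning -100 (ignored by the loss) to the remaining sub-tokens and to special tokens. Below is a minimal, dependency-free sketch of that step; the function name, label ids, and tokenization shown are hypothetical, not CookBERT's actual output.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def align_labels_with_subwords(word_labels, word_ids, label_all_subwords=False):
    """Map one label per word onto WordPiece tokens.

    word_ids gives, for each token, the index of the word it came from,
    or None for special tokens such as [CLS] and [SEP]. Only the first
    sub-token of a word keeps its label unless label_all_subwords is True,
    in which case the label is simply copied to every sub-token."""
    aligned = []
    previous_word_id = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(IGNORE_INDEX)  # special token
        elif word_id != previous_word_id:
            aligned.append(word_labels[word_id])  # first sub-token of a word
        else:
            aligned.append(word_labels[word_id] if label_all_subwords else IGNORE_INDEX)
        previous_word_id = word_id
    return aligned


# Hypothetical example: "Cut the zucchini" with label ids 0 = O, 1 = B-FOOD,
# assuming the tokens [CLS] cut the zu ##cch ##ini [SEP].
word_labels = [0, 0, 1]
word_ids = [None, 0, 1, 2, 2, 2, None]
print(align_labels_with_subwords(word_labels, word_ids))
# [-100, 0, 0, 1, -100, -100, -100]
```

In practice, `word_ids` would come from the `word_ids()` method of the encoding returned by a fast tokenizer such as `BertTokenizerFast`.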