{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the BLOOM tokenizer and model (350ish GB).\n",
    "# `tokenizer` and `model` are not referenced by any later cell in this notebook;\n",
    "# the generation pipeline is built separately from the same model id.\n",
    "# NOTE(review): AutoModel.from_pretrained presumably also instantiates the weights\n",
    "# in memory, not just on disk - confirm the host has room for this plus the\n",
    "# pipeline load in the next cell.\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bigscience/bloom\")\n",
    "model = AutoModel.from_pretrained(\"bigscience/bloom\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "import torch\n",
    "import time\n",
    "\n",
    "\n",
    "# Measure the load with a monotonic clock: time.time() follows the wall clock,\n",
    "# which can be adjusted while a long load is running and skew the difference.\n",
    "load_start = time.perf_counter()\n",
    "# `pipe` is the text-generation entry point used by local_inf() further down.\n",
    "# bfloat16 weights take half the memory of float32 ones.\n",
    "pipe = pipeline(model=\"bigscience/bloom\", torch_dtype=torch.bfloat16)\n",
    "print(f\"Time to load model: {time.perf_counter()-load_start}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# src: https://stackoverflow.com/questions/16816013/is-it-possible-to-print-using-different-colors-in-ipythons-notebook\n",
    "from IPython.display import HTML as html_print\n",
    "\n",
    "def cstr(s, color='black'):\n",
    "    \"\"\"Return ``s`` wrapped in a ``<text>`` tag styled with ``color``.\n",
    "\n",
    "    The text is HTML-escaped first, so characters such as ``<``, ``>`` and\n",
    "    ``&`` are displayed literally instead of being parsed as markup. Newlines\n",
    "    are then converted to ``<br>`` so line breaks survive HTML rendering.\n",
    "\n",
    "    Args:\n",
    "        s: Plain text to render.\n",
    "        color: CSS color value placed in the ``style`` attribute.\n",
    "\n",
    "    Returns:\n",
    "        The HTML fragment as a string.\n",
    "    \"\"\"\n",
    "    # Function-scope import keeps this cell runnable on its own.\n",
    "    from html import escape\n",
    "\n",
    "    # Escape before inserting <br>; the other order would escape the tags too.\n",
    "    # quote=False: the text lands in element content, so quotes are safe as-is.\n",
    "    body = escape(s, quote=False).replace('\\n', '<br>')\n",
    "    return \"<text style=color:{}>{}</text>\".format(color, body)\n",
    "\n",
    "\n",
    "def cstr_with_newlines(s, color='black'):\n",
    "    \"\"\"Same as :func:`cstr`; kept so references to this name keep working.\n",
    "\n",
    "    Args:\n",
    "        s: Plain text to render.\n",
    "        color: CSS color value placed in the ``style`` attribute.\n",
    "\n",
    "    Returns:\n",
    "        The HTML fragment as a string.\n",
    "    \"\"\"\n",
    "    return cstr(s, color=color)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def local_inf(prompt, temperature=0.7, top_p=None, max_new_tokens=32, repetition_penalty=None, do_sample=False, num_return_sequences=1):\n",
    "    \"\"\"Run the notebook-level ``pipe`` on ``prompt`` and return the completion.\n",
    "\n",
    "    Every argument after ``prompt`` is forwarded unchanged as a keyword\n",
    "    argument to ``pipe``, together with ``return_full_text=False``.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text to continue.\n",
    "        temperature: Forwarded to ``pipe`` (original note: 0 to 1).\n",
    "        top_p: Forwarded to ``pipe`` (original note: None or 0-1).\n",
    "        max_new_tokens: Forwarded to ``pipe`` (original note: up to 2047\n",
    "            theoretically).\n",
    "        repetition_penalty: Forwarded to ``pipe`` (original note: None or\n",
    "            0-100, a penalty for repeated tokens).\n",
    "        do_sample: Forwarded to ``pipe`` (original note: True uses sampling,\n",
    "            False uses greedy decoding).\n",
    "        num_return_sequences: Forwarded to ``pipe``.\n",
    "\n",
    "    Returns:\n",
    "        A 2-tuple ``(display, text)``. ``text`` is the ``'generated_text'``\n",
    "        entry of the first result from ``pipe``. ``display`` is an\n",
    "        ``html_print`` object showing the prompt in ``#f1f1c7`` followed by\n",
    "        ``text`` in ``#a1d8eb``. Results after the first are not returned.\n",
    "    \"\"\"\n",
    "    generation_options = dict(\n",
    "        temperature=temperature,\n",
    "        top_p=top_p,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        return_full_text=False,  # original note: include prompt or not\n",
    "        repetition_penalty=repetition_penalty,\n",
    "        do_sample=do_sample,\n",
    "        num_return_sequences=num_return_sequences,\n",
    "    )\n",
    "    results = pipe(f\"{prompt}\", **generation_options)\n",
    "\n",
    "    # Only the first returned sequence is rendered and handed back.\n",
    "    completion = results[0]['generated_text']\n",
    "    colored = cstr(prompt, color='#f1f1c7') + cstr(completion, color='#a1d8eb')\n",
    "    return html_print(colored), completion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example prompt: a single Python comment line for the model to continue.\n",
    "inp = \"\"\"# Use OpenCV in Python\"\"\"\n",
    "# local_inf returns (colored HTML rendering, raw generated text).\n",
    "color_resp, resp = local_inf(inp, max_new_tokens=64)\n",
    "# A bare last expression makes the notebook display the HTML object.\n",
    "color_resp"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}