Skip to content

Commit

Permalink
add example: autort.examples.03_llama_tiny
Browse files Browse the repository at this point in the history
  • Loading branch information
ghostplant committed Dec 26, 2023
1 parent 2608465 commit 10e2784
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ AutoRT is a compiler solution that helps runtime users to invent, benchmark and

| Platform | OS Requirement | Python Requirement | Download Link |
| --- | --- | --- | --- |
| DirectX 12 | Windows >= 10 / Microsoft XBox | [Python3.12](https://www.python.org/ftp/python/3.12.0/python-3.12.0-amd64.exe) (Windows) | python3.12 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.3/autort_for_dxwin.py312 |
| DirectX 12 | Windows >= 10 / Microsoft XBox | [Python3.12](https://www.python.org/ftp/python/3.12.0/python-3.12.0-amd64.exe) (Windows) | python3.12 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.4/autort_for_dxwin.py312 |
| Vulkan 1.3 | Ubuntu >= 18.04 (or images) | [Python3.12](https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb) (Linux) | python3.12 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.3/autort_for_vklinux.py312 |
| CUDA >= 11 | Ubuntu >= 18.04 (or images) | Python 3.8/3.9/3.10/3.11/3.12 | python3 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.3/autort_for_cuda_linux.py3x |
| .. | .. | .. | .. (More coming soon) .. |
Expand Down
45 changes: 31 additions & 14 deletions docker/Dockerfile.c-mcpu_avx512
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
FROM intelaipg/intel-optimized-tensorflow:2.3.0-avx512-mkl
FROM ubuntu:18.04

ENV DEBIAN_FRONTEND noninteractive
ENV PYTHONDONTWRITEBYTECODE 1
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH

RUN env > /etc/environment

RUN rm -f /etc/apt/sources.list.d/*

RUN /bin/echo "deb http://archive.ubuntu.com/ubuntu bionic main restricted universe multiverse" > /etc/apt/sources.list
RUN /bin/echo "deb http://archive.ubuntu.com/ubuntu bionic-updates main restricted universe multiverse" >> /etc/apt/sources.list

RUN apt-get update && apt install -y --no-install-recommends git ca-certificates \
python3-pip python3-wheel python3-setuptools python3-dev python3-pytest \
vim-tiny less netcat-openbsd inetutils-ping curl patch iproute2 \
g++ libpci3 libnuma-dev make file openssh-server kmod gdb libopenmpi-dev openmpi-bin psmisc \
autoconf automake autotools-dev libtool \
zlib1g-dev rename zip unzip librdmacm-dev gnupg \
clang-10 \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
make file openssh-server gdb psmisc zlib1g-dev rename zip unzip gnupg rsync p7zip-full clang-10 libomp-dev

RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny
RUN /bin/echo "deb http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu bionic main" >> /etc/apt/sources.list
RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 1E9377A2BA9EF27F
RUN apt-get update && apt install -y --no-install-recommends g++-9 xz-utils

RUN pip3 install --upgrade antares && mkdir -p /root/.local/antares && ln -s $(antares pwd)/../3rdparty /root/.local/antares/3rdparty
RUN ln -sf /usr/bin/g++-9 /usr/bin/g++
RUN ln -sf /usr/bin/gcc-9 /usr/bin/gcc
RUN ln -sf /usr/bin/g++ /usr/bin/c++
RUN ln -sf /usr/bin/gcc /usr/bin/cc
RUN ln -sf /usr/bin/gcc /usr/bin/x86_64-linux-gnu-gcc
RUN ln -sf /usr/bin/g++ /usr/bin/x86_64-linux-gnu-g++

RUN curl -LO https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb && dpkg -i python-3.12-linux-x86_64.deb && rm -f python-3.12-linux-x86_64.deb
RUN ln -sf /usr/local/bin/python3.12 /usr/local/bin/python3
RUN ln -sf python3 /usr/local/bin/python
RUN ln -sf python /usr/local/bin/python.exe
RUN /bin/echo -e 'exec python3 -m pip "$@"' > /usr/local/bin/pip3 && chmod a+x /usr/local/bin/pip3
RUN ln -sf pip3 /usr/local/bin/pip
# RUN curl -LO https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && rm -f get-pip.py

RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny

RUN ln -s clang++-10 /usr/bin/clang++ || true
RUN python3 -m pip install mpi4py
RUN mv /usr/bin/mpiexec /usr/bin/mpiexec.real && \
echo 'exec mpiexec.real --allow-run-as-root "$@"' > /usr/bin/mpiexec && \
chmod a+x /usr/bin/mpiexec
RUN python3 -m pip install cython setuptools
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.2.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.18.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
18 changes: 13 additions & 5 deletions docker/Dockerfile.c-sycl_intel
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@ RUN apt-get update && apt install -y --no-install-recommends git ca-certificates
RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny
RUN cp -r /opt/intel/oneapi/compiler/latest/linux/include/sycl/CL /opt/intel/oneapi/compiler/latest/linux/include/

RUN apt update && apt-get install -y python3-pip
RUN curl -LO https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb && dpkg -i python-3.12-linux-x86_64.deb && rm -f python-3.12-linux-x86_64.deb
RUN ln -sf /usr/local/bin/python3.12 /usr/local/bin/python3
RUN ln -sf python3 /usr/local/bin/python
RUN ln -sf python /usr/local/bin/python.exe
RUN /bin/echo -e 'exec python3 -m pip "$@"' > /usr/local/bin/pip3 && chmod a+x /usr/local/bin/pip3
RUN ln -sf pip3 /usr/local/bin/pip

RUN pip3 install --upgrade antares && mkdir -p /root/.local/antares && mv $(antares pwd)/../3rdparty /root/.local/antares/3rdparty && pip3 uninstall antares -y && echo 'exec /antares/main.py "$@"' > /usr/local/bin/antares && chmod a+x /usr/local/bin/antares

RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.2.0.dev20231118%2Bcpu-cp38-cp38-linux_x86_64.whl
RUN python3 -m pip install cython
RUN ln -s python3 /usr/bin/python
RUN ln -s python /usr/bin/python.exe
RUN python3 -m pip install cython setuptools
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.2.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.18.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl

RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg
8 changes: 5 additions & 3 deletions docker/Dockerfile.c-vulkan
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ RUN ln -sf /usr/bin/gcc /usr/bin/cc
RUN ln -sf /usr/bin/gcc /usr/bin/x86_64-linux-gnu-gcc
RUN ln -sf /usr/bin/g++ /usr/bin/x86_64-linux-gnu-g++

RUN curl -L https://github.com/ghostplant/collections/releases/download/utilities/python3.12-linux-x86_64.tar.xz | xz -d | tar xvf - -C /usr/local >/dev/null
RUN curl -LO https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb && dpkg -i python-3.12-linux-x86_64.deb && rm -f python-3.12-linux-x86_64.deb
RUN ln -sf /usr/local/bin/python3.12 /usr/local/bin/python3
RUN ln -sf python3 /usr/local/bin/python
RUN ln -sf python /usr/local/bin/python.exe
Expand All @@ -38,6 +38,8 @@ RUN ln -sf pip3 /usr/local/bin/pip
RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny

RUN python3 -m pip install cython setuptools
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.2.0.dev20231118%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.2.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl
RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.18.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl

RUN apt-get install -y --no-install-recommends mesa-vulkan-drivers
RUN apt-get install -y --no-install-recommends mesa-vulkan-drivers ffmpeg
155 changes: 155 additions & 0 deletions samples/03_llama_tiny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import os, sys, math, random
import autort


def download_pt(data_path, url):
if not os.path.exists(data_path):
print(f'Downloading dataset to {data_path} ..')
import urllib.request, zipfile, io
with urllib.request.urlopen(url) as fp:
r = fp.read()
with open(data_path, 'wb') as fp:
fp.write(r)
return torch.load(data_path)

pt = download_pt('llama_story_110m.pt', 'https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt?download=true')
vocab = download_pt('vocab_32K.pt', 'https://huggingface.co/datasets/ghostplant/data-collections/resolve/main/vocab_32K.pt?download=true')

device = autort.device()

args, param = pt['model_args'], pt['model']
n_heads, seq_len = args['n_heads'], args['max_seq_len']
head_size = args['dim'] // n_heads
token_embedding_table = param['tok_embeddings.weight']

data_type = token_embedding_table.dtype

rms_att_w, rms_ffn_w = [], []
weight_q, weight_k, weight_v, weight_o, weight_f1, weight_f2, weight_f3 = [], [], [], [], [], [], []
for i in range(1024):
try:
rms_att_w += [param[f'layers.{i}.attention_norm.weight'].unsqueeze(0)]
rms_ffn_w += [param[f'layers.{i}.ffn_norm.weight'].unsqueeze(0)]
weight_q += [param[f'layers.{i}.attention.wq.weight'].unsqueeze(0)]
weight_k += [param[f'layers.{i}.attention.wk.weight'].unsqueeze(0)]
weight_v += [param[f'layers.{i}.attention.wv.weight'].unsqueeze(0)]
weight_o += [param[f'layers.{i}.attention.wo.weight'].unsqueeze(0)]
weight_f1 += [param[f'layers.{i}.feed_forward.w1.weight'].unsqueeze(0)]
weight_f2 += [param[f'layers.{i}.feed_forward.w2.weight'].unsqueeze(0)]
weight_f3 += [param[f'layers.{i}.feed_forward.w3.weight'].unsqueeze(0)]
except KeyError:
break

rms_att_w = torch.cat(rms_att_w, dim=0).to(data_type).to(device)
rms_ffn_w = torch.cat(rms_ffn_w, dim=0).to(data_type).to(device)
rms_end_w = param['norm.weight'].to(data_type).to(device)
weight_classify = param['output.weight'].to(data_type).to(device)
weight_q = torch.cat(weight_q, dim=0).to(data_type).to(device)
weight_k = torch.cat(weight_k, dim=0).to(data_type).to(device)
weight_v = torch.cat(weight_v, dim=0).to(data_type).to(device)
weight_o = torch.cat(weight_o, dim=0).to(data_type).to(device)
weight_f1 = torch.cat(weight_f1, dim=0).to(data_type).to(device)
weight_f2 = torch.cat(weight_f2, dim=0).to(data_type).to(device)
weight_f3 = torch.cat(weight_f3, dim=0).to(data_type).to(device)
token_embedding_table = token_embedding_table.view([token_embedding_table.size(0), n_heads, head_size]).to(data_type).to(device)

n_layers = weight_q.size(0)
vocab_size, n_heads, head_size, = token_embedding_table.size(0), token_embedding_table.size(1), token_embedding_table.size(2)
n_layers, hidden, = rms_att_w.size(0), weight_f1.size(1)
kv_heads, dim = n_heads, n_heads * head_size

assert n_heads // kv_heads == 1 and head_size % 2 == 0

key_cache = torch.zeros([n_layers, seq_len, dim], dtype=data_type, device=weight_o.device)
val_cache = torch.zeros_like(key_cache)

ceof = 1 / torch.pow(1e4, torch.arange(0, dim, 2, dtype=torch.int64) % head_size / head_size).view(1, -1).to(data_type).to(weight_o.device)
att_f = torch.tensor([1 / math.sqrt(head_size)], dtype=data_type, device=weight_o.device)

def rmsnorm(x, weight):
x = x.float()
vsum = (x * x).sum()
return autort.ops.rmsnorm_f32(x.view(-1), vsum, weight, extra=[1.0 / int(x.numel())])

def rotate(data, ceof, pos, out):
autort.ops.rotate_f32(ceof, data.view(-1), out.view(-1), extra=[pos,])
return out

def forward(token, pos):
x = token_embedding_table.select(0, token).view(1, dim)

for l in range(n_layers):
xb = rmsnorm(x, rms_att_w.select(0, l))

sq = torch.matmul(xb, weight_q.select(0, l).t())

sk = torch.matmul(xb, weight_k.select(0, l).t())

sv = val_cache.select(0, l).narrow(0, pos, 1)
torch.matmul(xb, weight_v.select(0, l).t(), out=sv)

sq_out = torch.empty_like(sq)
sk_out = key_cache.select(0, l).narrow(0, pos, 1)
rotate(sq, ceof, pos, out=sq_out)
rotate(sk, ceof, pos, out=sk_out)
sq, sk = sq_out, sk_out

b_sq = sq.view(n_heads, head_size)
b_sk = key_cache.select(0, l).view(seq_len, n_heads, head_size).narrow(0, 0, pos + 1)

att = torch.einsum('hm,shm->hs', b_sq, b_sk) * att_f

att = torch.nn.functional.softmax(att, dim=-1)
b_sv = val_cache.select(0, l).view(seq_len, n_heads, head_size).narrow(0, 0, pos + 1)

xb = torch.einsum('hs,shm->hm', att, b_sv)
xb = xb.view(1, dim)
xb2 = torch.matmul(xb, weight_o.select(0, l).t())
x = x + xb2

xb = rmsnorm(x, rms_ffn_w.select(0, l))

data = torch.matmul(xb, weight_f1.select(0, l).t())
hb = torch.nn.functional.silu(data)

hb = hb * torch.matmul(xb, weight_f3.select(0, l).t())
xb = torch.matmul(hb, weight_f2.select(0, l).t())
x = x + xb

x = rmsnorm(x, rms_end_w)
logits = torch.matmul(x, weight_classify.t())
return logits

def sampling(logits):
index = 2 if random.random() < 0.25 else 1
return int(torch.topk(logits, k=index).indices.view(-1)[-1])

def decode(prev, next):
piece = vocab[next]
if prev == 1 and piece.startswith(' '):
piece = piece[1:]
return piece

if __name__ == '__main__':
with torch.no_grad():
prompt_tokens, pos = [1, 1724], 0
token = prompt_tokens[pos]

while pos < seq_len:
logits = forward(token, pos)
if pos < len(prompt_tokens) - 1:
next = prompt_tokens[pos + 1]
else:
next = sampling(logits)
if next == 1:
break
sys.stdout.write(decode(token, next))
sys.stdout.flush()
pos, token = pos + 1, next

print('\n')

0 comments on commit 10e2784

Please sign in to comment.