diff --git a/README.md b/README.md index 1ddca43e..a55b211c 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ AutoRT is a compiler solution that helps runtime users to invent, benchmark and | Platform | OS Requirement | Python Requirement | Download Link | | --- | --- | --- | --- | -| DirectX 12 | Windows >= 10 / Microsoft XBox | [Python3.12](https://www.python.org/ftp/python/3.12.0/python-3.12.0-amd64.exe) (Windows) | python3.12 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.3/autort_for_dxwin.py312 | +| DirectX 12 | Windows >= 10 / Microsoft XBox | [Python3.12](https://www.python.org/ftp/python/3.12.0/python-3.12.0-amd64.exe) (Windows) | python3.12 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.4/autort_for_dxwin.py312 | | Vulkan 1.3 | Ubuntu >= 18.04 (or images) | [Python3.12](https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb) (Linux) | python3.12 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.3/autort_for_vklinux.py312 | | CUDA >= 11 | Ubuntu >= 18.04 (or images) | Python 3.8/3.9/3.10/3.11/3.12 | python3 -m pip install -r https://github.com/microsoft/antares/releases/download/v0.9.3/autort_for_cuda_linux.py3x | | .. | .. | .. | .. (More coming soon) .. | diff --git a/docker/Dockerfile.c-mcpu_avx512 b/docker/Dockerfile.c-mcpu_avx512 index 83b67232..e3f8e670 100644 --- a/docker/Dockerfile.c-mcpu_avx512 +++ b/docker/Dockerfile.c-mcpu_avx512 @@ -1,25 +1,42 @@ -FROM intelaipg/intel-optimized-tensorflow:2.3.0-avx512-mkl +FROM ubuntu:18.04 ENV DEBIAN_FRONTEND noninteractive ENV PYTHONDONTWRITEBYTECODE 1 +ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH RUN env > /etc/environment +RUN rm -f /etc/apt/sources.list.d/* + +RUN /bin/echo "deb http://archive.ubuntu.com/ubuntu bionic main restricted universe multiverse" > /etc/apt/sources.list +RUN /bin/echo "deb http://archive.ubuntu.com/ubuntu bionic-updates main restricted universe multiverse" >> /etc/apt/sources.list + RUN apt-get update && apt install -y --no-install-recommends git ca-certificates \ - python3-pip python3-wheel python3-setuptools python3-dev python3-pytest \ vim-tiny less netcat-openbsd inetutils-ping curl patch iproute2 \ - g++ libpci3 libnuma-dev make file openssh-server kmod gdb libopenmpi-dev openmpi-bin psmisc \ - autoconf automake autotools-dev libtool \ - zlib1g-dev rename zip unzip librdmacm-dev gnupg \ - clang-10 \ - && apt-get clean && rm -rf /var/lib/apt/lists/* + make file openssh-server gdb psmisc zlib1g-dev rename zip unzip gnupg rsync p7zip-full clang-10 libomp-dev -RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny +RUN /bin/echo "deb http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu bionic main" >> /etc/apt/sources.list +RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 1E9377A2BA9EF27F +RUN apt-get update && apt install -y --no-install-recommends g++-9 xz-utils -RUN pip3 install --upgrade antares && mkdir -p /root/.local/antares && ln -s $(antares pwd)/../3rdparty /root/.local/antares/3rdparty +RUN ln -sf /usr/bin/g++-9 /usr/bin/g++ +RUN ln -sf /usr/bin/gcc-9 /usr/bin/gcc +RUN ln -sf /usr/bin/g++ /usr/bin/c++ +RUN ln -sf /usr/bin/gcc /usr/bin/cc +RUN ln -sf /usr/bin/gcc /usr/bin/x86_64-linux-gnu-gcc +RUN ln -sf /usr/bin/g++ /usr/bin/x86_64-linux-gnu-g++ + +RUN curl -LO https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb && dpkg -i python-3.12-linux-x86_64.deb && rm -f python-3.12-linux-x86_64.deb +RUN ln -sf /usr/local/bin/python3.12 /usr/local/bin/python3 +RUN ln -sf python3 /usr/local/bin/python +RUN ln -sf python /usr/local/bin/python.exe +RUN /bin/echo -e 'exec python3 -m pip "$@"' > /usr/local/bin/pip3 && chmod a+x /usr/local/bin/pip3 +RUN ln -sf pip3 /usr/local/bin/pip +# RUN curl -LO https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && rm -f get-pip.py + +RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny -RUN ln -s clang++-10 /usr/bin/clang++ || true -RUN python3 -m pip install mpi4py -RUN mv /usr/bin/mpiexec /usr/bin/mpiexec.real && \ - echo 'exec mpiexec.real --allow-run-as-root "$@"' > /usr/bin/mpiexec && \ - chmod a+x /usr/bin/mpiexec +RUN python3 -m pip install cython setuptools +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.2.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.18.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl diff --git a/docker/Dockerfile.c-sycl_intel b/docker/Dockerfile.c-sycl_intel index dc1a585e..9f439e8b 100644 --- a/docker/Dockerfile.c-sycl_intel +++ b/docker/Dockerfile.c-sycl_intel @@ -18,10 +18,18 @@ RUN apt-get update && apt install -y --no-install-recommends git ca-certificates RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny RUN cp -r /opt/intel/oneapi/compiler/latest/linux/include/sycl/CL /opt/intel/oneapi/compiler/latest/linux/include/ -RUN apt update && apt-get install -y python3-pip +RUN curl -LO https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb && dpkg -i python-3.12-linux-x86_64.deb && rm -f python-3.12-linux-x86_64.deb +RUN ln -sf /usr/local/bin/python3.12 /usr/local/bin/python3 +RUN ln -sf python3 /usr/local/bin/python +RUN ln -sf python /usr/local/bin/python.exe +RUN /bin/echo -e 'exec python3 -m pip "$@"' > /usr/local/bin/pip3 && chmod a+x /usr/local/bin/pip3 +RUN ln -sf pip3 /usr/local/bin/pip + RUN pip3 install --upgrade antares && mkdir -p /root/.local/antares && mv $(antares pwd)/../3rdparty /root/.local/antares/3rdparty && pip3 uninstall antares -y && echo 'exec /antares/main.py "$@"' > /usr/local/bin/antares && chmod a+x /usr/local/bin/antares -RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.2.0.dev20231118%2Bcpu-cp38-cp38-linux_x86_64.whl -RUN python3 -m pip install cython -RUN ln -s python3 /usr/bin/python -RUN ln -s python /usr/bin/python.exe +RUN python3 -m pip install cython setuptools +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.2.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.18.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl + +RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg diff --git a/docker/Dockerfile.c-vulkan b/docker/Dockerfile.c-vulkan index 50ad3849..d4d14360 100644 --- a/docker/Dockerfile.c-vulkan +++ b/docker/Dockerfile.c-vulkan @@ -27,7 +27,7 @@ RUN ln -sf /usr/bin/gcc /usr/bin/cc RUN ln -sf /usr/bin/gcc /usr/bin/x86_64-linux-gnu-gcc RUN ln -sf /usr/bin/g++ /usr/bin/x86_64-linux-gnu-g++ -RUN curl -L https://github.com/ghostplant/collections/releases/download/utilities/python3.12-linux-x86_64.tar.xz | xz -d | tar xvf - -C /usr/local >/dev/null +RUN curl -LO https://github.com/ghostplant/collections/releases/download/utilities/python-3.12-linux-x86_64.deb && dpkg -i python-3.12-linux-x86_64.deb && rm -f python-3.12-linux-x86_64.deb RUN ln -sf /usr/local/bin/python3.12 /usr/local/bin/python3 RUN ln -sf python3 /usr/local/bin/python RUN ln -sf python /usr/local/bin/python.exe @@ -38,6 +38,8 @@ RUN ln -sf pip3 /usr/local/bin/pip RUN /bin/echo -e "set backspace=indent,eol,start\nset nocompatible\nset ts=4" > /etc/vim/vimrc.tiny RUN python3 -m pip install cython setuptools -RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.2.0.dev20231118%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.3.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.2.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl +RUN python3 -m pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.18.0.dev20231220%2Bcpu-cp312-cp312-linux_x86_64.whl -RUN apt-get install -y --no-install-recommends mesa-vulkan-drivers +RUN apt-get install -y --no-install-recommends mesa-vulkan-drivers ffmpeg diff --git a/samples/03_llama_tiny.py b/samples/03_llama_tiny.py new file mode 100644 index 00000000..9ec53605 --- /dev/null +++ b/samples/03_llama_tiny.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import os, sys, math, random +import autort + + +def download_pt(data_path, url): + if not os.path.exists(data_path): + print(f'Downloading dataset to {data_path} ..') + import urllib.request, zipfile, io + with urllib.request.urlopen(url) as fp: + r = fp.read() + with open(data_path, 'wb') as fp: + fp.write(r) + return torch.load(data_path) + +pt = download_pt('llama_story_110m.pt', 'https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt?download=true') +vocab = download_pt('vocab_32K.pt', 'https://huggingface.co/datasets/ghostplant/data-collections/resolve/main/vocab_32K.pt?download=true') + +device = autort.device() + +args, param = pt['model_args'], pt['model'] +n_heads, seq_len = args['n_heads'], args['max_seq_len'] +head_size = args['dim'] // n_heads +token_embedding_table = param['tok_embeddings.weight'] + +data_type = token_embedding_table.dtype + +rms_att_w, rms_ffn_w = [], [] +weight_q, weight_k, weight_v, weight_o, weight_f1, weight_f2, weight_f3 = [], [], [], [], [], [], [] +for i in range(1024): + try: + rms_att_w += [param[f'layers.{i}.attention_norm.weight'].unsqueeze(0)] + rms_ffn_w += [param[f'layers.{i}.ffn_norm.weight'].unsqueeze(0)] + weight_q += [param[f'layers.{i}.attention.wq.weight'].unsqueeze(0)] + weight_k += [param[f'layers.{i}.attention.wk.weight'].unsqueeze(0)] + weight_v += [param[f'layers.{i}.attention.wv.weight'].unsqueeze(0)] + weight_o += [param[f'layers.{i}.attention.wo.weight'].unsqueeze(0)] + weight_f1 += [param[f'layers.{i}.feed_forward.w1.weight'].unsqueeze(0)] + weight_f2 += [param[f'layers.{i}.feed_forward.w2.weight'].unsqueeze(0)] + weight_f3 += [param[f'layers.{i}.feed_forward.w3.weight'].unsqueeze(0)] + except KeyError: + break + +rms_att_w = torch.cat(rms_att_w, dim=0).to(data_type).to(device) +rms_ffn_w = torch.cat(rms_ffn_w, dim=0).to(data_type).to(device) +rms_end_w = param['norm.weight'].to(data_type).to(device) +weight_classify = param['output.weight'].to(data_type).to(device) +weight_q = torch.cat(weight_q, dim=0).to(data_type).to(device) +weight_k = torch.cat(weight_k, dim=0).to(data_type).to(device) +weight_v = torch.cat(weight_v, dim=0).to(data_type).to(device) +weight_o = torch.cat(weight_o, dim=0).to(data_type).to(device) +weight_f1 = torch.cat(weight_f1, dim=0).to(data_type).to(device) +weight_f2 = torch.cat(weight_f2, dim=0).to(data_type).to(device) +weight_f3 = torch.cat(weight_f3, dim=0).to(data_type).to(device) +token_embedding_table = token_embedding_table.view([token_embedding_table.size(0), n_heads, head_size]).to(data_type).to(device) + +n_layers = weight_q.size(0) +vocab_size, n_heads, head_size, = token_embedding_table.size(0), token_embedding_table.size(1), token_embedding_table.size(2) +n_layers, hidden, = rms_att_w.size(0), weight_f1.size(1) +kv_heads, dim = n_heads, n_heads * head_size + +assert n_heads // kv_heads == 1 and head_size % 2 == 0 + +key_cache = torch.zeros([n_layers, seq_len, dim], dtype=data_type, device=weight_o.device) +val_cache = torch.zeros_like(key_cache) + +ceof = 1 / torch.pow(1e4, torch.arange(0, dim, 2, dtype=torch.int64) % head_size / head_size).view(1, -1).to(data_type).to(weight_o.device) +att_f = torch.tensor([1 / math.sqrt(head_size)], dtype=data_type, device=weight_o.device) + +def rmsnorm(x, weight): + x = x.float() + vsum = (x * x).sum() + return autort.ops.rmsnorm_f32(x.view(-1), vsum, weight, extra=[1.0 / int(x.numel())]) + +def rotate(data, ceof, pos, out): + autort.ops.rotate_f32(ceof, data.view(-1), out.view(-1), extra=[pos,]) + return out + +def forward(token, pos): + x = token_embedding_table.select(0, token).view(1, dim) + + for l in range(n_layers): + xb = rmsnorm(x, rms_att_w.select(0, l)) + + sq = torch.matmul(xb, weight_q.select(0, l).t()) + + sk = torch.matmul(xb, weight_k.select(0, l).t()) + + sv = val_cache.select(0, l).narrow(0, pos, 1) + torch.matmul(xb, weight_v.select(0, l).t(), out=sv) + + sq_out = torch.empty_like(sq) + sk_out = key_cache.select(0, l).narrow(0, pos, 1) + rotate(sq, ceof, pos, out=sq_out) + rotate(sk, ceof, pos, out=sk_out) + sq, sk = sq_out, sk_out + + b_sq = sq.view(n_heads, head_size) + b_sk = key_cache.select(0, l).view(seq_len, n_heads, head_size).narrow(0, 0, pos + 1) + + att = torch.einsum('hm,shm->hs', b_sq, b_sk) * att_f + + att = torch.nn.functional.softmax(att, dim=-1) + b_sv = val_cache.select(0, l).view(seq_len, n_heads, head_size).narrow(0, 0, pos + 1) + + xb = torch.einsum('hs,shm->hm', att, b_sv) + xb = xb.view(1, dim) + xb2 = torch.matmul(xb, weight_o.select(0, l).t()) + x = x + xb2 + + xb = rmsnorm(x, rms_ffn_w.select(0, l)) + + data = torch.matmul(xb, weight_f1.select(0, l).t()) + hb = torch.nn.functional.silu(data) + + hb = hb * torch.matmul(xb, weight_f3.select(0, l).t()) + xb = torch.matmul(hb, weight_f2.select(0, l).t()) + x = x + xb + + x = rmsnorm(x, rms_end_w) + logits = torch.matmul(x, weight_classify.t()) + return logits + +def sampling(logits): + index = 2 if random.random() < 0.25 else 1 + return int(torch.topk(logits, k=index).indices.view(-1)[-1]) + +def decode(prev, next): + piece = vocab[next] + if prev == 1 and piece.startswith(' '): + piece = piece[1:] + return piece + +if __name__ == '__main__': + with torch.no_grad(): + prompt_tokens, pos = [1, 1724], 0 + token = prompt_tokens[pos] + + while pos < seq_len: + logits = forward(token, pos) + if pos < len(prompt_tokens) - 1: + next = prompt_tokens[pos + 1] + else: + next = sampling(logits) + if next == 1: + break + sys.stdout.write(decode(token, next)) + sys.stdout.flush() + pos, token = pos + 1, next + + print('\n')