From 4b8a0ca1acf84d5991726d559a4e905952d674aa Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Thu, 25 Jul 2024 13:53:52 +0200 Subject: [PATCH] feat: support llama 3.1. (#106) --- .gitignore | 1 + converter/convert-hf.py | 16 ++++++++++ converter/writer.py | 7 ++++- launch.py | 7 ++++- src/app.cpp | 62 +++++++++++++++++++++----------------- src/app.hpp | 3 +- src/apps/dllama/dllama.cpp | 7 +++-- src/commands.cpp | 51 +++++++++++++++++++++++++++++++ src/commands.hpp | 15 ++++++++- src/grok1-tasks-test.cpp | 6 +++- src/llama2-tasks-test.cpp | 6 +++- src/transformer.cpp | 61 ++++++++++++++++++++++++++++--------- src/transformer.hpp | 33 +++++++++++++++++--- src/utils.cpp | 26 ++++++++++++++++ src/utils.hpp | 3 ++ 15 files changed, 249 insertions(+), 55 deletions(-) diff --git a/.gitignore b/.gitignore index da9e5a2..a253f44 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.0 *.dSYM *.data +*.temp __pycache__ *-test diff --git a/converter/convert-hf.py b/converter/convert-hf.py index bbbda77..f637ddb 100644 --- a/converter/convert-hf.py +++ b/converter/convert-hf.py @@ -143,6 +143,14 @@ def parseHiddenAct(act: str): raise Exception(f'Unsupported hidden act: {act}') return hiddenAct +def parseRopeType(rt: str): + ropeType = { + 'llama3': 2, # LLAMA3_1 + }.get(rt) + if (ropeType is None): + raise Exception(f'Unsupported rope type: {ropeType}') + return ropeType + def loadConfig(folderPath: str, weightsFloatType: int): allFiles = os.listdir(folderPath) allFiles.sort() @@ -178,6 +186,14 @@ def loadConfig(folderPath: str, weightsFloatType: int): ropeTheta = config.get('rope_theta') if (ropeTheta is not None): result['rope_theta'] = int(ropeTheta) + + ropeScaling = config.get('rope_scaling') + if (ropeScaling is not None): + result['rope_scaling_factor'] = int(ropeScaling['factor']) + result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor']) + result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor']) + result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings']) + result['rope_type'] = parseRopeType(ropeScaling['rope_type']) return result def printUsage(): diff --git a/converter/writer.py b/converter/writer.py index 56e6dd4..4bd319c 100644 --- a/converter/writer.py +++ b/converter/writer.py @@ -125,7 +125,12 @@ def writeHeader(file, params): 'max_seq_len': 10, 'hidden_act': 11, 'rope_theta': 12, - 'weights_float_type': 13 + 'weights_float_type': 13, + 'rope_scaling_factor': 14, + 'rope_scaling_low_freq_factor': 15, + 'rope_scaling_high_freq_factory': 16, + 'rope_scaling_orig_max_seq_len': 17, + 'rope_type': 18, } header = struct.pack('i', 0xA00ABCD) diff --git a/launch.py b/launch.py index bfd3010..2f93e59 100644 --- a/launch.py +++ b/launch.py @@ -18,7 +18,12 @@ 'https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_lama3_instruct_q40.m?download=true', 'https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3.t?download=true', 'q40', 'q80', 'chat' - ] + ], + 'llama3_1_8b_instruct_q40': [ + 'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true', + 'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true', + 'q40', 'q80', 'chat' + ], } def downloadFile(url: str, path: str): diff --git a/src/app.cpp b/src/app.cpp index 018ca2b..f4a3b4b 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -41,6 +41,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { args.steps = 0; args.seed = (unsigned long long)time(NULL); args.chatTemplateType = TEMPLATE_UNKNOWN; + args.useDiscForKvCache = false; int i = 1; if (hasMode && argc > 1) { @@ -48,17 +49,19 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { i++; } for (; i + 1 < argc; i += 2) { - if (strcmp(argv[i], "--model") == 0) { - args.modelPath = argv[i + 1]; - } else if (strcmp(argv[i], "--tokenizer") == 0) { - args.tokenizerPath = argv[i + 1]; - } else if (strcmp(argv[i], "--prompt") == 0) { - args.prompt = argv[i + 1]; - } else if (strcmp(argv[i], "--weights-float-type") == 0) { - args.weightsFloatType = parseFloatType(argv[i + 1]); - } else if (strcmp(argv[i], "--buffer-float-type") == 0) { - args.bufferFloatType = parseFloatType(argv[i + 1]); - } else if (strcmp(argv[i], "--workers") == 0) { + char* name = argv[i]; + char* value = argv[i + 1]; + if (strcmp(name, "--model") == 0) { + args.modelPath = value; + } else if (strcmp(name, "--tokenizer") == 0) { + args.tokenizerPath = value; + } else if (strcmp(name, "--prompt") == 0) { + args.prompt = value; + } else if (strcmp(name, "--weights-float-type") == 0) { + args.weightsFloatType = parseFloatType(value); + } else if (strcmp(name, "--buffer-float-type") == 0) { + args.bufferFloatType = parseFloatType(value); + } else if (strcmp(name, "--workers") == 0) { int j = i + 1; for (; j < argc && argv[j][0] != '-'; j++); int count = j - i - 1; @@ -82,22 +85,24 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { } i += count - 1; - } else if (strcmp(argv[i], "--port") == 0) { - args.port = atoi(argv[i + 1]); - } else if (strcmp(argv[i], "--nthreads") == 0) { - args.nThreads = atoi(argv[i + 1]); - } else if (strcmp(argv[i], "--steps") == 0) { - args.steps = atoi(argv[i + 1]); - } else if (strcmp(argv[i], "--temperature") == 0) { - args.temperature = atof(argv[i + 1]); - } else if (strcmp(argv[i], "--topp") == 0) { - args.topp = atof(argv[i + 1]); - } else if (strcmp(argv[i], "--seed") == 0) { - args.seed = atoll(argv[i + 1]); - } else if (strcmp(argv[i], "--chat-template") == 0) { - args.chatTemplateType = parseChatTemplateType(argv[i + 1]); + } else if (strcmp(name, "--port") == 0) { + args.port = atoi(value); + } else if (strcmp(name, "--nthreads") == 0) { + args.nThreads = atoi(value); + } else if (strcmp(name, "--steps") == 0) { + args.steps = atoi(value); + } else if (strcmp(name, "--temperature") == 0) { + args.temperature = atof(value); + } else if (strcmp(name, "--topp") == 0) { + args.topp = atof(value); + } else if (strcmp(name, "--seed") == 0) { + args.seed = atoll(value); + } else if (strcmp(name, "--chat-template") == 0) { + args.chatTemplateType = parseChatTemplateType(value); + } else if (strcmp(name, "--kv-cache-storage") == 0) { + args.useDiscForKvCache = strcmp(value, "disc") == 0; } else { - printf("Unknown option %s\n", argv[i]); + printf("Unknown option %s\n", name); exit(EXIT_FAILURE); } } @@ -131,7 +136,10 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s args->steps = spec.seqLen; } - Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool); + TransformerConfig config; + config.useDiscForKvCache = args->useDiscForKvCache; + + Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, &config, socketPool); socketPool->setTurbo(true); Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool); diff --git a/src/app.hpp b/src/app.hpp index d53e9e7..e7f239c 100644 --- a/src/app.hpp +++ b/src/app.hpp @@ -16,7 +16,8 @@ class AppArgs { public: char* mode; - int nThreads; + int nThreads; + bool useDiscForKvCache; // inference char* modelPath; diff --git a/src/apps/dllama/dllama.cpp b/src/apps/dllama/dllama.cpp index f28c123..7d83823 100644 --- a/src/apps/dllama/dllama.cpp +++ b/src/apps/dllama/dllama.cpp @@ -160,7 +160,7 @@ class Chat { int nInputTokens; tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false); - pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1); + pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, (int)pos + nInputTokens - 1); for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) { inference->infer(inputTokens[i], pos); token = inputTokens[i + 1]; @@ -207,10 +207,13 @@ void worker(AppArgs* args) { throw std::runtime_error("Invalid port number"); } + TransformerConfig config; + config.useDiscForKvCache = args->useDiscForKvCache; + SocketServer server(args->port); Socket socket = server.accept(); TransformerSpec spec; - Transformer transformer = Transformer::loadSlice(&spec, &socket); + Transformer transformer = Transformer::loadSlice(&spec, &config, &socket); TransformerArch arch = TransformerArchFactory::create(&spec); Worker worker = Worker(&arch, args->nThreads, &transformer, &socket); diff --git a/src/commands.cpp b/src/commands.cpp index c91aec1..26cb1c6 100644 --- a/src/commands.cpp +++ b/src/commands.cpp @@ -1,5 +1,8 @@ #include #include +#ifdef _WIN32 + #define _USE_MATH_DEFINES +#endif #include #include "utils.hpp" #include "funcs.hpp" @@ -167,6 +170,54 @@ void LlamaRopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nT } } +Llama3_1RopeCommand::Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen) { + this->slice = slice; + this->ropeScalingFactor = ropeScalingFactor; + this->ropeScalingLowFreqFactor = ropeScalingLowFreqFactor; + this->ropeScalingHighFreqFactory = ropeScalingHighFreqFactory; + this->ropeScalingOrigMaxSeqLen = ropeScalingOrigMaxSeqLen; + printf("🕒 ropeScalingFactor: %f\n", ropeScalingFactor); + printf("🕒 ropeScalingLowFreqFactor: %f\n", ropeScalingLowFreqFactor); + printf("🕒 ropeScalingHighFreqFactory: %f\n", ropeScalingHighFreqFactory); + printf("🕒 ropeScalingOrigMaxSeqLen: %d\n", ropeScalingOrigMaxSeqLen); +} + +float Llama3_1RopeCommand::scale(float freq) { + float waveLen = 2.0f * M_PI * freq; + float lowFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingLowFreqFactor; + float highFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingHighFreqFactory; + if (waveLen < highFreqWavelen) { + return freq; + } else if (waveLen > lowFreqWavelen) { + return freq / ropeScalingFactor; + } else { + float smooth = (ropeScalingOrigMaxSeqLen / waveLen - ropeScalingLowFreqFactor) / (ropeScalingHighFreqFactory - ropeScalingLowFreqFactor); + return (1 - smooth) * freq / ropeScalingFactor + smooth * freq; + } +} + +void Llama3_1RopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) { + const unsigned int dim0Half = (isQ ? slice->qDim0 : slice->kvDim0) / 2; + const unsigned int shift = isQ ? slice->qShift : 0; + SPLIT_RANGE_TO_THREADS(s, e, 0, dim0Half, nThreads, threadIndex); + const unsigned int iStart = s * 2; + const unsigned int iEnd = e * 2; + + for (unsigned int i = iStart; i < iEnd; i += 2) { + const unsigned int headDim = i % slice->headSize; + const float freq = 1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize); + const float val = pos * freq; + const float fcr = cosf(val); + const float fci = sinf(val); + + float v0 = qOrK[i]; + float v1 = qOrK[i + 1]; + + qOrK[i] = scale(v0 * fcr - v1 * fci); + qOrK[i + 1] = scale(v0 * fci + v1 * fcr); + } +} + FalconRopeCommand::FalconRopeCommand(RopeSlice *slice) { this->slice = slice; } diff --git a/src/commands.hpp b/src/commands.hpp index c105d4e..7df1292 100644 --- a/src/commands.hpp +++ b/src/commands.hpp @@ -9,7 +9,7 @@ // *Slice - calculates sizes, offsets, slice sizes etc. It is not responsible for memory allocation. It may help in the loading of data. // *Command - allocates memory for weights, performs calculations. -typedef unsigned short pos_t; +typedef unsigned int pos_t; typedef uint8_t slice_index_t; class MatmulSlice { @@ -106,6 +106,19 @@ class LlamaRopeCommand : public RopeCommand { void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex); }; +class Llama3_1RopeCommand : public RopeCommand { +private: + RopeSlice* slice; + float ropeScalingFactor; + float ropeScalingLowFreqFactor; + float ropeScalingHighFreqFactory; + int ropeScalingOrigMaxSeqLen; +public: + Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen); + void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex); + float scale(float freq); +}; + class FalconRopeCommand : public RopeCommand { private: RopeSlice* slice; diff --git a/src/grok1-tasks-test.cpp b/src/grok1-tasks-test.cpp index db0fd8c..e93b626 100644 --- a/src/grok1-tasks-test.cpp +++ b/src/grok1-tasks-test.cpp @@ -30,6 +30,7 @@ int main() { TransformerSpec spec; spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int); spec.archType = GROK1; + spec.ropeType = ROPE_FALCON; spec.dim = 6144; spec.nLayers = 1; spec.nHeads = 48; @@ -47,6 +48,9 @@ int main() { spec.hiddenAct = GELU; spec.ropeTheta = 10000.0f; + TransformerConfig config; + config.useDiscForKvCache = false; + size_t beforeBlockBytes = spec.dim * spec.vocabSize * sizeof(float); size_t blockBytes = 956596224; size_t afterBlockBytes = (spec.dim + spec.dim * spec.vocabSize) * sizeof(float); @@ -60,7 +64,7 @@ int main() { for (int f = 0; f < nFloats; f++) block[f] = randomF32(&state) / 100.0; SocketPool socketPool(0, NULL); - Transformer transformer = Transformer::loadRoot(weights, &spec, &socketPool); + Transformer transformer = Transformer::loadRoot(weights, &spec, &config, &socketPool); transformer.pos = 0; float* x = transformer.x; diff --git a/src/llama2-tasks-test.cpp b/src/llama2-tasks-test.cpp index 6775900..ae7851c 100644 --- a/src/llama2-tasks-test.cpp +++ b/src/llama2-tasks-test.cpp @@ -528,6 +528,7 @@ int main() { TransformerSpec spec; spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int); spec.archType = LLAMA; + spec.ropeType = ROPE_LLAMA; spec.dim = 4096; spec.nLayers = 1; spec.headSize = 128; @@ -545,6 +546,9 @@ int main() { spec.hiddenAct = SILU; spec.ropeTheta = 10000.0f; + TransformerConfig config; + config.useDiscForKvCache = false; + size_t beforeBlockBytes = /* embedding */ 524288000; size_t blockBytes = 809533440; size_t afterBlockBytes = /* norm */ 16384 + /* embedding */ 524288000; @@ -562,7 +566,7 @@ int main() { for (int i = 0; i < mm; i++) mmData[i] = randomF32(&state) / 120.0; SocketPool socketPool(0, NULL); - Transformer transformer = Transformer::loadRoot((char*)data, &spec, &socketPool); + Transformer transformer = Transformer::loadRoot((char*)data, &spec, &config, &socketPool); transformer.pos = 0; float* x = transformer.x; diff --git a/src/transformer.cpp b/src/transformer.cpp index 9c85d44..92040a4 100644 --- a/src/transformer.cpp +++ b/src/transformer.cpp @@ -13,6 +13,7 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i TransformerSpec spec; memset(&spec, 0, sizeof(TransformerSpec)); spec.hiddenAct = SILU; + spec.ropeType = ROPE_UNKNOWN; spec.ropeTheta = 10000.0f; FILE* fd = fopen(path, "rb"); @@ -68,6 +69,11 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i else if (key == HIDDEN_ACT) spec.hiddenAct = (TransformerHiddenAct)value; else if (key == ROPE_THETA) spec.ropeTheta = (float)value; else if (key == WEIGHTS_FLOAT_TYPE) weightsFloatType = (FloatType)value; + else if (key == ROPE_SCALING_FACTOR) spec.ropeScalingFactor = (float)value; + else if (key == ROPE_SCALING_LOW_FREQ_FACTOR) spec.ropeScalingLowFreqFactor = (float)value; + else if (key == ROPE_SCALING_HIGH_FREQ_FACTORY) spec.ropeScalingHighFreqFactory = (float)value; + else if (key == ROPE_SCALING_ORIG_MAX_SEQ_LEN) spec.ropeScalingOrigMaxSeqLen = value; + else if (key == ROPE_TYPE) spec.ropeType = (TransformerRopeType)value; else { throw std::runtime_error("Unsupported header key"); } @@ -79,6 +85,16 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i if (weightsFloatType == FUNK) throw std::runtime_error("Not specified weights float type"); + if (spec.ropeType == ROPE_UNKNOWN) { + if (spec.archType == LLAMA) { + spec.ropeType = ROPE_LLAMA; + } else if (spec.archType == GROK1 || spec.archType == MIXTRAL) { + spec.ropeType = ROPE_FALCON; + } else { + throw std::runtime_error("Cannot resolve rope type from architecture"); + } + } + spec.headSize = spec.dim / spec.nHeads; spec.kvDim = (spec.dim * spec.nKvHeads) / spec.nHeads; spec.weightsFloatType = weightsFloatType; @@ -199,14 +215,14 @@ size_t TransformerBuffer::getSlicedBytes(uint8_t bufferIndex) { return bufferBytes[bufferIndex] / nSlices; } -Transformer::Transformer(TransformerSpec* spec, slice_index_t sliceIndex) { +Transformer::Transformer(TransformerSpec* spec, TransformerConfig* config, slice_index_t sliceIndex) { this->spec = spec; this->sliceIndex = sliceIndex; buffer = new TransformerBuffer(spec); blocks = new TransformerBlock*[spec->nLayers]; for (int i = 0; i < spec->nLayers; i++) { - blocks[i] = new TransformerBlock(spec, sliceIndex); + blocks[i] = new TransformerBlock(spec, config, sliceIndex); } if (IS_ROOT_SLICE(sliceIndex)) { @@ -223,10 +239,14 @@ Transformer::Transformer(TransformerSpec* spec, slice_index_t sliceIndex) { } ropeSlice = new RopeSlice(spec->dim, spec->kvDim, spec->nKvHeads, spec->nSlices, spec->seqLen, spec->headSize, spec->ropeTheta, sliceIndex); - if (spec->archType == GROK1 || spec->archType == MIXTRAL) { + if (spec->ropeType == ROPE_FALCON) { rope = new FalconRopeCommand(ropeSlice); - } else { + } else if (spec->ropeType == ROPE_LLAMA) { rope = new LlamaRopeCommand(ropeSlice); + } else if (spec->ropeType == ROPE_LLAMA3_1) { + rope = new Llama3_1RopeCommand(ropeSlice, spec->ropeScalingFactor, spec->ropeScalingLowFreqFactor, spec->ropeScalingHighFreqFactory, spec->ropeScalingOrigMaxSeqLen); + } else { + throw std::runtime_error("Unsupported rope type"); } TransformerBlock* b = blocks[0]; @@ -257,9 +277,10 @@ Transformer::~Transformer() { delete rope; } -TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceIndex) { +TransformerBlock::TransformerBlock(TransformerSpec* spec, TransformerConfig* config, slice_index_t sliceIndex) { this->sliceIndex = sliceIndex; this->spec = spec; + this->config = config; if (IS_ROOT_SLICE(sliceIndex)) { rmsAttBytes = spec->dim * sizeof(float); @@ -276,8 +297,13 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceInd } kvCacheSlice = new KvCacheSlice(spec->kvDim, spec->seqLen, spec->nSlices); - keyCache = (float*)newBuffer(kvCacheSlice->keyCacheSize); - valueCache = (float*)newBuffer(kvCacheSlice->valueCacheSize); + if (config->useDiscForKvCache) { + keyCache = (float*)newMmapFileBuffer(sliceIndex, kvCacheSlice->keyCacheSize); + valueCache = (float*)newMmapFileBuffer(sliceIndex, kvCacheSlice->valueCacheSize); + } else { + keyCache = (float*)newBuffer(kvCacheSlice->keyCacheSize); + valueCache = (float*)newBuffer(kvCacheSlice->valueCacheSize); + } multiHeadAttSlice = new MultiHeadAttSlice(spec->nHeads, spec->seqLen, spec->nSlices, sliceIndex); att = (float*)newBuffer(multiHeadAttSlice->attSize); @@ -337,8 +363,13 @@ TransformerBlock::~TransformerBlock() { } delete kvCacheSlice; - freeBuffer(keyCache); - freeBuffer(valueCache); + if (config->useDiscForKvCache) { + freeMmapFileBuffer(keyCache); + freeMmapFileBuffer(valueCache); + } else { + freeBuffer(keyCache); + freeBuffer(valueCache); + } delete multiHeadAttSlice; freeBuffer(att); @@ -411,23 +442,23 @@ static size_t readSlicedMatmulWeights(MatmulSlice* slice, char* weights0, Socket return slice->sliceBytes; } -Transformer Transformer::loadRootFromFile(const char* path, TransformerSpec* spec, SocketPool* socketPool) { +Transformer Transformer::loadRootFromFile(const char* path, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool) { MmapFile file; openMmapFile(&file, path, spec->fileSize); char* weights = ((char*)file.data) + spec->headerSize; - Transformer transformer = Transformer::loadRoot((char*)weights, spec, socketPool); + Transformer transformer = Transformer::loadRoot((char*)weights, spec, config, socketPool); closeMmapFile(&file); return transformer; } -Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool* socketPool) { +Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool) { assert(socketPool->nSockets == spec->nSlices - 1); const slice_index_t sliceIndex = 0; // Root slice - Transformer transformer(spec, sliceIndex); + Transformer transformer(spec, config, sliceIndex); if (spec->nSlices > 1) { for (slice_index_t sliceIndex = 1; sliceIndex < spec->nSlices; sliceIndex++) { @@ -484,7 +515,7 @@ Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool* return transformer; } -Transformer Transformer::loadSlice(TransformerSpec* spec, Socket* socket) { +Transformer Transformer::loadSlice(TransformerSpec* spec, TransformerConfig* config, Socket* socket) { slice_index_t sliceIndex; socket->read((char*)&sliceIndex, sizeof(uint8_t)); socket->read((char*)spec, sizeof(TransformerSpec)); @@ -493,7 +524,7 @@ Transformer Transformer::loadSlice(TransformerSpec* spec, Socket* socket) { printf("💡 nSlices: %d\n", spec->nSlices); assert(sliceIndex >= 1); - Transformer transformer(spec, sliceIndex); + Transformer transformer(spec, config, sliceIndex); size_t bufferSize = 0; // TODO: this is ugly diff --git a/src/transformer.hpp b/src/transformer.hpp index cb7bfa2..d7b296e 100644 --- a/src/transformer.hpp +++ b/src/transformer.hpp @@ -22,6 +22,11 @@ enum TransformerHeaderKey { HIDDEN_ACT = 11, ROPE_THETA = 12, WEIGHTS_FLOAT_TYPE = 13, + ROPE_SCALING_FACTOR = 14, + ROPE_SCALING_LOW_FREQ_FACTOR = 15, + ROPE_SCALING_HIGH_FREQ_FACTORY = 16, + ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17, + ROPE_TYPE = 18, }; struct TransformerFileOldHeader { @@ -47,6 +52,13 @@ enum TransformerHiddenAct { SILU = 1, }; +enum TransformerRopeType { + ROPE_UNKNOWN = -1, + ROPE_LLAMA = 0, + ROPE_FALCON = 1, + ROPE_LLAMA3_1 = 2, +}; + struct TransformerSpec { size_t headerSize; size_t fileSize; @@ -65,16 +77,26 @@ struct TransformerSpec { int kvDim; int vocabSize; float ropeTheta; + TransformerRopeType ropeType; + float ropeScalingFactor; + float ropeScalingLowFreqFactor; + float ropeScalingHighFreqFactory; + int ropeScalingOrigMaxSeqLen; FloatType weightsFloatType; FloatType bufferFloatType; uint8_t nSlices; }; +struct TransformerConfig { + bool useDiscForKvCache; +}; + class TransformerBlock { public: slice_index_t sliceIndex; TransformerSpec *spec; + TransformerConfig* config; size_t rmsAttBytes; float* rmsAtt; @@ -120,7 +142,7 @@ class TransformerBlock { float* att; float* qo0; - TransformerBlock(TransformerSpec* spec, slice_index_t sliceIndex); + TransformerBlock(TransformerSpec* spec, TransformerConfig* config, slice_index_t sliceIndex); ~TransformerBlock(); }; @@ -155,6 +177,7 @@ class TransformerBuffer { class Transformer { public: TransformerSpec* spec; + TransformerConfig* config; TransformerBlock** blocks; TransformerBuffer* buffer; slice_index_t sliceIndex; @@ -175,12 +198,12 @@ class Transformer { ~Transformer(); static TransformerSpec loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType); - static Transformer loadRootFromFile(const char* path, TransformerSpec* spec, SocketPool* socketPool); - static Transformer loadRoot(char* data, TransformerSpec* spec, SocketPool* socketPool); - static Transformer loadSlice(TransformerSpec* spec, Socket* socket); + static Transformer loadRootFromFile(const char* path, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool); + static Transformer loadRoot(char* data, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool); + static Transformer loadSlice(TransformerSpec* spec, TransformerConfig* config, Socket* socket); private: - Transformer(TransformerSpec* spec, slice_index_t sliceIndex); + Transformer(TransformerSpec* spec, TransformerConfig* config, slice_index_t sliceIndex); }; #endif diff --git a/src/utils.cpp b/src/utils.cpp index fb42519..1365795 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include "utils.hpp" @@ -44,6 +45,31 @@ void freeBuffer(void* buffer) { #endif } +unsigned int lastMmapFileBufferIndex = 0; + +void* newMmapFileBuffer(unsigned int appInstanceId, size_t size) { +#ifdef _WIN32 + throw new std::runtime_error("Mmap file buffer is not supported on Windows yet"); +#else + char path[256]; + snprintf(path, 256, "mmap-buffer-%d-%d.temp", appInstanceId, lastMmapFileBufferIndex++); + int fd = open(path, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); + if (fd == -1) + throw new std::runtime_error("Cannot create mmap buffer file"); + if (ftruncate(fd, size) == -1) + throw new std::runtime_error("Cannot truncate mmap buffer file. Not enough disk space?"); + void *addr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) + throw new std::runtime_error("Cannot mmap buffer file"); + close(fd); + return addr; +#endif +} + +void freeMmapFileBuffer(void* addr) { + // TODO +} + unsigned long timeMs() { struct timeval te; gettimeofday(&te, NULL); diff --git a/src/utils.hpp b/src/utils.hpp index 4323fdc..1c61994 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -21,6 +21,9 @@ void* newBuffer(size_t size); void freeBuffer(void* buffer); +void* newMmapFileBuffer(unsigned int appInstanceId, size_t size); +void freeMmapFileBuffer(void* addr); + unsigned long timeMs(); unsigned int randomU32(unsigned long long *state); float randomF32(unsigned long long *state);