feat: support llama 3.1. (#106)
b4rtaz authored Jul 25, 2024
1 parent 8c57298 commit 4b8a0ca
Showing 15 changed files with 249 additions and 55 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
*.0
*.dSYM
*.data
*.temp
__pycache__

*-test
16 changes: 16 additions & 0 deletions converter/convert-hf.py
@@ -143,6 +143,14 @@ def parseHiddenAct(act: str):
raise Exception(f'Unsupported hidden act: {act}')
return hiddenAct

def parseRopeType(rt: str):
ropeType = {
'llama3': 2, # LLAMA3_1
}.get(rt)
if (ropeType is None):
raise Exception(f'Unsupported rope type: {rt}')
return ropeType

def loadConfig(folderPath: str, weightsFloatType: int):
allFiles = os.listdir(folderPath)
allFiles.sort()
@@ -178,6 +186,14 @@ def loadConfig(folderPath: str, weightsFloatType: int):
ropeTheta = config.get('rope_theta')
if (ropeTheta is not None):
result['rope_theta'] = int(ropeTheta)

ropeScaling = config.get('rope_scaling')
if (ropeScaling is not None):
result['rope_scaling_factor'] = int(ropeScaling['factor'])
result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor'])
result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor'])
result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings'])
result['rope_type'] = parseRopeType(ropeScaling['rope_type'])
return result

def printUsage():
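For reference, the rope_scaling parsing added above consumes a block like the one Hugging Face ships in Llama 3.1's config.json. A minimal sketch, assuming the usual Llama 3.1 8B values (not taken from this commit):

    # Hypothetical excerpt of the `config` dict read by loadConfig():
    config = {
        "rope_theta": 500000.0,
        "rope_scaling": {
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3"
        }
    }
    # With these values the int() casts above are lossless, and parseRopeType('llama3') returns 2 (LLAMA3_1).
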
7 changes: 6 additions & 1 deletion converter/writer.py
@@ -125,7 +125,12 @@ def writeHeader(file, params):
'max_seq_len': 10,
'hidden_act': 11,
'rope_theta': 12,
'weights_float_type': 13
'weights_float_type': 13,
'rope_scaling_factor': 14,
'rope_scaling_low_freq_factor': 15,
'rope_scaling_high_freq_factory': 16,
'rope_scaling_orig_max_seq_len': 17,
'rope_type': 18,
}
header = struct.pack('i', 0xA00ABCD)

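The hunk above only shows the new key ids and the magic; assuming each header field is serialized as an (id, value) pair of 32-bit ints after the magic (an assumption, since the packing loop is not part of this hunk), the new Llama 3.1 fields would be written roughly like this sketch:

    import struct

    # Illustrative only: key ids copied from the map above, values from a typical Llama 3.1 config.
    new_fields = [
        (14, 8),     # rope_scaling_factor
        (15, 1),     # rope_scaling_low_freq_factor
        (16, 4),     # rope_scaling_high_freq_factory
        (17, 8192),  # rope_scaling_orig_max_seq_len
        (18, 2),     # rope_type = LLAMA3_1
    ]
    payload = b''.join(struct.pack('ii', key, value) for key, value in new_fields)
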
7 changes: 6 additions & 1 deletion launch.py
@@ -18,7 +18,12 @@
'https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_lama3_instruct_q40.m?download=true',
'https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3.t?download=true',
'q40', 'q80', 'chat'
]
],
'llama3_1_8b_instruct_q40': [
'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true',
'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true',
'q40', 'q80', 'chat'
],
}

def downloadFile(url: str, path: str):
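With this entry added, the new model can presumably be downloaded and set up the same way as the existing ones, e.g. (command assumed from the structure of this script, not shown in the diff):

    python launch.py llama3_1_8b_instruct_q40
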
62 changes: 35 additions & 27 deletions src/app.cpp
@@ -41,24 +41,27 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.steps = 0;
args.seed = (unsigned long long)time(NULL);
args.chatTemplateType = TEMPLATE_UNKNOWN;
args.useDiscForKvCache = false;

int i = 1;
if (hasMode && argc > 1) {
args.mode = argv[1];
i++;
}
for (; i + 1 < argc; i += 2) {
if (strcmp(argv[i], "--model") == 0) {
args.modelPath = argv[i + 1];
} else if (strcmp(argv[i], "--tokenizer") == 0) {
args.tokenizerPath = argv[i + 1];
} else if (strcmp(argv[i], "--prompt") == 0) {
args.prompt = argv[i + 1];
} else if (strcmp(argv[i], "--weights-float-type") == 0) {
args.weightsFloatType = parseFloatType(argv[i + 1]);
} else if (strcmp(argv[i], "--buffer-float-type") == 0) {
args.bufferFloatType = parseFloatType(argv[i + 1]);
} else if (strcmp(argv[i], "--workers") == 0) {
char* name = argv[i];
char* value = argv[i + 1];
if (strcmp(name, "--model") == 0) {
args.modelPath = value;
} else if (strcmp(name, "--tokenizer") == 0) {
args.tokenizerPath = value;
} else if (strcmp(name, "--prompt") == 0) {
args.prompt = value;
} else if (strcmp(name, "--weights-float-type") == 0) {
args.weightsFloatType = parseFloatType(value);
} else if (strcmp(name, "--buffer-float-type") == 0) {
args.bufferFloatType = parseFloatType(value);
} else if (strcmp(name, "--workers") == 0) {
int j = i + 1;
for (; j < argc && argv[j][0] != '-'; j++);
int count = j - i - 1;
@@ -82,22 +85,24 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
}

i += count - 1;
} else if (strcmp(argv[i], "--port") == 0) {
args.port = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "--nthreads") == 0) {
args.nThreads = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "--steps") == 0) {
args.steps = atoi(argv[i + 1]);
} else if (strcmp(argv[i], "--temperature") == 0) {
args.temperature = atof(argv[i + 1]);
} else if (strcmp(argv[i], "--topp") == 0) {
args.topp = atof(argv[i + 1]);
} else if (strcmp(argv[i], "--seed") == 0) {
args.seed = atoll(argv[i + 1]);
} else if (strcmp(argv[i], "--chat-template") == 0) {
args.chatTemplateType = parseChatTemplateType(argv[i + 1]);
} else if (strcmp(name, "--port") == 0) {
args.port = atoi(value);
} else if (strcmp(name, "--nthreads") == 0) {
args.nThreads = atoi(value);
} else if (strcmp(name, "--steps") == 0) {
args.steps = atoi(value);
} else if (strcmp(name, "--temperature") == 0) {
args.temperature = atof(value);
} else if (strcmp(name, "--topp") == 0) {
args.topp = atof(value);
} else if (strcmp(name, "--seed") == 0) {
args.seed = atoll(value);
} else if (strcmp(name, "--chat-template") == 0) {
args.chatTemplateType = parseChatTemplateType(value);
} else if (strcmp(name, "--kv-cache-storage") == 0) {
args.useDiscForKvCache = strcmp(value, "disc") == 0;
} else {
printf("Unknown option %s\n", argv[i]);
printf("Unknown option %s\n", name);
exit(EXIT_FAILURE);
}
}
@@ -131,7 +136,10 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
args->steps = spec.seqLen;
}

Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
TransformerConfig config;
config.useDiscForKvCache = args->useDiscForKvCache;

Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, &config, socketPool);
socketPool->setTurbo(true);

Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);
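Besides the refactor to name/value pairs, the parser gains a --kv-cache-storage option whose value "disc" enables useDiscForKvCache. A hypothetical invocation combining it with the existing flags (model and tokenizer paths are placeholders, not from this commit):

    ./dllama inference --model dllama_model_llama3.1_instruct_q40.m --tokenizer dllama_tokenizer_llama_3_1.t --buffer-float-type q80 --nthreads 4 --steps 64 --prompt "Hello" --kv-cache-storage disc
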
3 changes: 2 additions & 1 deletion src/app.hpp
@@ -16,7 +16,8 @@
class AppArgs {
public:
char* mode;
int nThreads;
int nThreads;
bool useDiscForKvCache;

// inference
char* modelPath;
7 changes: 5 additions & 2 deletions src/apps/dllama/dllama.cpp
@@ -160,7 +160,7 @@ class Chat {
int nInputTokens;
tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false);

pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1);
pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, (int)pos + nInputTokens - 1);
for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) {
inference->infer(inputTokens[i], pos);
token = inputTokens[i + 1];
@@ -207,10 +207,13 @@ void worker(AppArgs* args) {
throw std::runtime_error("Invalid port number");
}

TransformerConfig config;
config.useDiscForKvCache = args->useDiscForKvCache;

SocketServer server(args->port);
Socket socket = server.accept();
TransformerSpec spec;
Transformer transformer = Transformer::loadSlice(&spec, &socket);
Transformer transformer = Transformer::loadSlice(&spec, &config, &socket);
TransformerArch arch = TransformerArchFactory::create(&spec);

Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
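The same flag applies on worker nodes, where it now reaches Transformer::loadSlice through the new TransformerConfig; a hypothetical invocation (flag names taken from AppArgs::parse above):

    ./dllama worker --port 9998 --nthreads 4 --kv-cache-storage disc
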
51 changes: 51 additions & 0 deletions src/commands.cpp
@@ -1,5 +1,8 @@
#include <cassert>
#include <cstring>
#ifdef _WIN32
#define _USE_MATH_DEFINES
#endif
#include <cmath>
#include "utils.hpp"
#include "funcs.hpp"
@@ -167,6 +170,54 @@ void LlamaRopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nT
}
}

Llama3_1RopeCommand::Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen) {
this->slice = slice;
this->ropeScalingFactor = ropeScalingFactor;
this->ropeScalingLowFreqFactor = ropeScalingLowFreqFactor;
this->ropeScalingHighFreqFactory = ropeScalingHighFreqFactory;
this->ropeScalingOrigMaxSeqLen = ropeScalingOrigMaxSeqLen;
printf("🕒 ropeScalingFactor: %f\n", ropeScalingFactor);
printf("🕒 ropeScalingLowFreqFactor: %f\n", ropeScalingLowFreqFactor);
printf("🕒 ropeScalingHighFreqFactory: %f\n", ropeScalingHighFreqFactory);
printf("🕒 ropeScalingOrigMaxSeqLen: %d\n", ropeScalingOrigMaxSeqLen);
}

float Llama3_1RopeCommand::scale(float freq) {
float waveLen = 2.0f * M_PI * freq;
float lowFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingLowFreqFactor;
float highFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingHighFreqFactory;
if (waveLen < highFreqWavelen) {
return freq;
} else if (waveLen > lowFreqWavelen) {
return freq / ropeScalingFactor;
} else {
float smooth = (ropeScalingOrigMaxSeqLen / waveLen - ropeScalingLowFreqFactor) / (ropeScalingHighFreqFactory - ropeScalingLowFreqFactor);
return (1 - smooth) * freq / ropeScalingFactor + smooth * freq;
}
}

void Llama3_1RopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
const unsigned int dim0Half = (isQ ? slice->qDim0 : slice->kvDim0) / 2;
const unsigned int shift = isQ ? slice->qShift : 0;
SPLIT_RANGE_TO_THREADS(s, e, 0, dim0Half, nThreads, threadIndex);
const unsigned int iStart = s * 2;
const unsigned int iEnd = e * 2;

for (unsigned int i = iStart; i < iEnd; i += 2) {
const unsigned int headDim = i % slice->headSize;
const float freq = 1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize);
const float val = pos * freq;
const float fcr = cosf(val);
const float fci = sinf(val);

float v0 = qOrK[i];
float v1 = qOrK[i + 1];

qOrK[i] = scale(v0 * fcr - v1 * fci);
qOrK[i + 1] = scale(v0 * fci + v1 * fcr);
}
}

FalconRopeCommand::FalconRopeCommand(RopeSlice *slice) {
this->slice = slice;
}
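For comparison, here is a minimal Python sketch of the Llama 3.1 frequency scaling as usually described for the reference implementation, assuming the standard config values (factor 8.0, low_freq_factor 1.0, high_freq_factor 4.0, original context 8192). In that formulation the wavelength is 2*pi divided by the inverse frequency, and the scaling is applied to each frequency before the rotation is computed:

    import math

    def scale_freq(freq, factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, orig_max_seq_len=8192):
        # freq is an inverse frequency: ropeTheta ** (-headDim / headSize)
        wavelen = 2.0 * math.pi / freq
        low_freq_wavelen = orig_max_seq_len / low_freq_factor
        high_freq_wavelen = orig_max_seq_len / high_freq_factor
        if wavelen < high_freq_wavelen:
            return freq                  # high-frequency band: unchanged
        if wavelen > low_freq_wavelen:
            return freq / factor         # low-frequency band: scaled down
        # smooth interpolation between the two bands
        smooth = (orig_max_seq_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        return (1 - smooth) * freq / factor + smooth * freq
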
15 changes: 14 additions & 1 deletion src/commands.hpp
@@ -9,7 +9,7 @@
// *Slice - calculates sizes, offsets, slice sizes etc. It is not responsible for memory allocation. It may help in the loading of data.
// *Command - allocates memory for weights, performs calculations.

typedef unsigned short pos_t;
typedef unsigned int pos_t;
typedef uint8_t slice_index_t;

class MatmulSlice {
@@ -106,6 +106,19 @@ class LlamaRopeCommand : public RopeCommand {
void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
};

class Llama3_1RopeCommand : public RopeCommand {
private:
RopeSlice* slice;
float ropeScalingFactor;
float ropeScalingLowFreqFactor;
float ropeScalingHighFreqFactory;
int ropeScalingOrigMaxSeqLen;
public:
Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen);
void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
float scale(float freq);
};

class FalconRopeCommand : public RopeCommand {
private:
RopeSlice* slice;
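A side note on the pos_t change at the top of this file: widening it from unsigned short to unsigned int is presumably needed because Llama 3.1's 131072-token context exceeds the 65535 positions an unsigned short can address.
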
6 changes: 5 additions & 1 deletion src/grok1-tasks-test.cpp
@@ -30,6 +30,7 @@ int main() {
TransformerSpec spec;
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
spec.archType = GROK1;
spec.ropeType = ROPE_FALCON;
spec.dim = 6144;
spec.nLayers = 1;
spec.nHeads = 48;
@@ -47,6 +48,9 @@
spec.hiddenAct = GELU;
spec.ropeTheta = 10000.0f;

TransformerConfig config;
config.useDiscForKvCache = false;

size_t beforeBlockBytes = spec.dim * spec.vocabSize * sizeof(float);
size_t blockBytes = 956596224;
size_t afterBlockBytes = (spec.dim + spec.dim * spec.vocabSize) * sizeof(float);
@@ -60,7 +64,7 @@
for (int f = 0; f < nFloats; f++) block[f] = randomF32(&state) / 100.0;

SocketPool socketPool(0, NULL);
Transformer transformer = Transformer::loadRoot(weights, &spec, &socketPool);
Transformer transformer = Transformer::loadRoot(weights, &spec, &config, &socketPool);
transformer.pos = 0;

float* x = transformer.x;
6 changes: 5 additions & 1 deletion src/llama2-tasks-test.cpp
@@ -528,6 +528,7 @@ int main() {
TransformerSpec spec;
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
spec.archType = LLAMA;
spec.ropeType = ROPE_LLAMA;
spec.dim = 4096;
spec.nLayers = 1;
spec.headSize = 128;
@@ -545,6 +546,9 @@
spec.hiddenAct = SILU;
spec.ropeTheta = 10000.0f;

TransformerConfig config;
config.useDiscForKvCache = false;

size_t beforeBlockBytes = /* embedding */ 524288000;
size_t blockBytes = 809533440;
size_t afterBlockBytes = /* norm */ 16384 + /* embedding */ 524288000;
@@ -562,7 +566,7 @@
for (int i = 0; i < mm; i++) mmData[i] = randomF32(&state) / 120.0;

SocketPool socketPool(0, NULL);
Transformer transformer = Transformer::loadRoot((char*)data, &spec, &socketPool);
Transformer transformer = Transformer::loadRoot((char*)data, &spec, &config, &socketPool);
transformer.pos = 0;

float* x = transformer.x;