diff --git a/NeoX/__init__.py b/NeoX/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/NeoX/convert.py b/NeoX/convert.py
new file mode 100644
index 0000000..6d38a3a
--- /dev/null
+++ b/NeoX/convert.py
@@ -0,0 +1,47 @@
+"""Swap GPT-NeoX attention projections for LoRA-enabled layers."""
+from torch import nn
+import loralib as lora
+
+
+def convert_model_lora(model, r=4, enable_lora=(True,)):
+    """Replace each ``query_key_value`` linear layer with a LoRA layer.
+
+    Recursively walks ``model`` and swaps, in place, every ``nn.Linear`` child
+    named ``query_key_value`` for a ``lora.MergedLinear`` that reuses the
+    original weight and bias parameters. All other modules are left as-is.
+
+    The reused weight keeps its own ``requires_grad`` flag; call
+    ``lora.mark_only_lora_as_trainable(model)`` before training if only the
+    LoRA matrices should be updated.
+
+    Args:
+        model: Module to convert; it is modified in place.
+        r: Rank of the LoRA update matrices.
+        enable_lora: One flag per equal-sized slice of the output features,
+            selecting which slices receive a LoRA update. The default adapts
+            the whole fused projection as a single slice.
+
+    Returns:
+        The same ``model`` object.
+    """
+    for child_name, child in model.named_children():
+        if isinstance(child, nn.Linear) and child_name == "query_key_value":
+            # MergedLinear defaults to enable_lora=[False], which creates no
+            # LoRA parameters at all, so the mask must be passed explicitly.
+            # NOTE(review): Hugging Face GPT-NeoX appears to interleave q/k/v
+            # per head in this projection, so a three-flag mask would not line
+            # up with q, k and v -- confirm before using one.
+            new = lora.MergedLinear(
+                child.in_features,
+                child.out_features,
+                r=r,
+                enable_lora=list(enable_lora),
+            )
+            new.weight = child.weight
+            new.bias = child.bias
+            setattr(model, child_name, new)
+        else:
+            # TODO: nn.Conv2d and nn.Embedding children are not converted yet;
+            # loralib provides lora.Conv2d and lora.Embedding for them.
+            convert_model_lora(child, r=r, enable_lora=enable_lora)
+    return model
diff --git a/NeoX/load_model.py b/NeoX/load_model.py
new file mode 100644
index 0000000..a4167b9
--- /dev/null
+++ b/NeoX/load_model.py
@@ -0,0 +1,33 @@
+"""Convert a Pythia model to LoRA and save it, or reload a saved conversion."""
+import torch
+from safetensors.torch import load_file, save_file
+from transformers import AutoTokenizer, GPTNeoXForCausalLM
+
+from convert import convert_model_lora
+
+# Dummy config.
+# Hugging Face Hub id of the base checkpoint used by both branches.
+MODEL_NAME = "EleutherAI/pythia-70m-deduped"
+# True: convert the base model and save it into the working directory.
+# False: rebuild the converted architecture and load the saved weights.
+initial = True
+
+if initial:
+    model = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = convert_model_lora(model)
+
+    # Alternative: torch.save(model.state_dict(), "./model.pt")
+    # safe_serialization expects a bool; the old string "True" only worked
+    # because any non-empty string is truthy.
+    model.save_pretrained("./", safe_serialization=True)
+else:
+    # The base model is converted first so that its parameter names match
+    # the saved state dict.
+    # TODO: find a way to build the converted model from config alone; doing
+    # so would need the new architecture coded once per adapter method.
+    model = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)
+    model = convert_model_lora(model)
+
+    loaded = load_file("./model.safetensors")
+    model.load_state_dict(loaded)