Layer extensions #11

Open · wants to merge 9 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
__pycache__
1 change: 1 addition & 0 deletions LayerExtender/.gitignore
@@ -0,0 +1 @@
__pycache__
1 change: 1 addition & 0 deletions LayerExtender/__init__.py
@@ -0,0 +1 @@
from .utils import *
13 changes: 13 additions & 0 deletions LayerExtender/config.json
@@ -0,0 +1,13 @@
{
    "extend_layers" : [
        {
            "new_layer": "LoraExtender",
            "params": {"test": 1},
            "use_default_match": true
        },
        {
            "match_type": "*GPTNeoXAttention",
            "new_layer": "BaseLayerExtender"
        }
    ]
}
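For orientation, here is a minimal sketch of how a config like this is meant to be consumed. It assumes the script is run from inside LayerExtender/ so utils.py is importable directly; convert_model and the extender classes named by "new_layer" are defined in the files below, and the path and model name here are only illustrative.

import json
from transformers import GPTNeoXForCausalLM
from utils import convert_model  # defined in LayerExtender/utils.py below

# Path and model name are assumptions for illustration.
with open("config.json", "r") as file:
    config = json.load(file)

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
# Every sub-module whose name/type path matches an "extend_layers" entry is wrapped
# by the class named in "new_layer" (e.g. LoraExtender).
model = convert_model(model, config)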
78 changes: 78 additions & 0 deletions LayerExtender/layer_extender.py
@@ -0,0 +1,78 @@
from torch import nn, Tensor
from typing import Optional, Tuple, Union
import traceback
import loralib as lora
import fnmatch


class BaseModuleExtender(nn.Module):
    """
    Base code to wrap an arbitrary nn.Module into a new class.

    The original module is saved as self.wrapped_module and called automatically via the forward pass.

    You can modify the values passed into the wrapped layer via the *args and **kwargs params.
    """
    def __init__(self, config, wrapped_module: nn.Module, **kwargs):
        super().__init__()

        self.wrapped_module = wrapped_module

    def forward(self, input: Tensor, *args, **kwargs):
        wrapped_outputs = self.wrapped_module(input, *args, **kwargs)
        return wrapped_outputs

    @staticmethod
    def is_match(name_list: str = "", type_list: str = ""):
        return False


class BaseLayerExtender(BaseModuleExtender):
    """
    Base code for wrapping LLM layers.

    Layers are the fundamental unit for pipeline parallelism.

    This class is designed to interface our custom layers with the Hugging Face Trainer class and DeepSpeed.
    """
    def __init__(self, config, wrapped_layer: nn.Module, **kwargs):
        super().__init__(config, wrapped_layer)

    def forward(self, input: Tensor, *args, **kwargs):
        wrapped_outputs = self.wrapped_module(input, *args, **kwargs)
        return wrapped_outputs


class LoraExtender(BaseModuleExtender):
    """
    Wrapper for LoRA adapters.
    """
    def __init__(self, config, wrapped_module: nn.Module, **kwargs):
        # Replace the wrapped Linear with a LoRA MergedLinear, reusing its existing weights.
        weight = wrapped_module.weight
        bias = wrapped_module.bias
        wrapped_module = lora.MergedLinear(wrapped_module.in_features, wrapped_module.out_features, r=4, enable_lora=[True])
        wrapped_module.weight = weight
        wrapped_module.bias = bias

        if "test" in kwargs:
            print(kwargs["test"])

        super().__init__(config, wrapped_module)

    def forward(self, input: Tensor, *args, **kwargs):
        wrapped_outputs = self.wrapped_module(input, *args, **kwargs)
        return wrapped_outputs

    @staticmethod
    def is_match(name_list: str = "", type_list: str = ""):
        # Default match: the query_key_value Linear projection inside a GPTNeoX model.
        model_name = type_list
        first_token = model_name.find('.')
        if first_token >= 0:
            model_name = model_name[0:first_token]

        name_match = False
        type_match = False
        if model_name == "GPTNeoXModel":
            name_match = fnmatch.fnmatchcase(name_list, "*query_key_value")
            type_match = fnmatch.fnmatchcase(type_list, "*Linear")

        return name_match and type_match
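The classes above are meant to be subclassed for new adapter types. As a hedged sketch of that pattern (FreezeExtender is a hypothetical name, not part of this PR), a custom extender only needs a constructor that wraps the module and, if it should work with use_default_match, an is_match rule; to be discovered by convert_model it would have to live in layer_extender.py, since the loader imports classes from that module by name.

import fnmatch
from torch import nn

class FreezeExtender(BaseModuleExtender):
    """Hypothetical example: wrap a module and freeze its parameters."""
    def __init__(self, config, wrapped_module: nn.Module, **kwargs):
        super().__init__(config, wrapped_module)
        for param in self.wrapped_module.parameters():
            param.requires_grad = False

    @staticmethod
    def is_match(name_list: str = "", type_list: str = ""):
        # Match any Linear by the dotted type path built in convert_model_internal.
        return fnmatch.fnmatchcase(type_list, "*Linear")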
61 changes: 61 additions & 0 deletions LayerExtender/load_model.py
@@ -0,0 +1,61 @@

# dummy config
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from utils import print_model, convert_model
import torch
from safetensors.torch import save_file, load_file
import json

with open('videorl/LayerExtender/config.json', 'r') as file:
    config = json.load(file)

initial = True

if initial:
    # Load a GPT-NeoX model and convert it to the LoRA-extended model specified by the config.
    model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
    model = convert_model(model, config)

    print(config)

    print_model(model)

    model.save_pretrained("./", safe_serialization=True)

    prompt = "Test"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.000001,
        max_length=100,
    )

    print(tokenizer.batch_decode(gen_tokens)[0])


else:
    # We want to load a previously converted model.
    model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")  # Is it possible to just load from config without this issue...
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
    model = convert_model(model, {})
    # We could skip the above step if we coded something that has the new architecture - this seems bad though because we'd need to do a per-adapter method.

    loaded = load_file("./model.safetensors")
    model.load_state_dict(loaded)

    prompt = "Test"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.000001,
        max_length=100,
    )

    print(tokenizer.batch_decode(gen_tokens)[0])
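The script above only converts, saves, and samples from the model. If the wrapped LoRA layers were actually trained, loralib's standard helpers could restrict updates to the adapter weights and save only those; this is a sketch under that assumption, not something the PR does, and the output filename is hypothetical.

import loralib as lora
from safetensors.torch import save_file

# Freeze everything except the LoRA matrices inside the wrapped MergedLinear modules.
lora.mark_only_lora_as_trainable(model)

# ... a training loop over the remaining trainable parameters would go here ...

# Persist just the adapter weights instead of the full model.
save_file(lora.lora_state_dict(model), "./lora_adapters.safetensors")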

60 changes: 60 additions & 0 deletions LayerExtender/utils.py
@@ -0,0 +1,60 @@
from torch import nn
from layer_extender import BaseLayerExtender, LoraExtender
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
import loralib as lora
from dataclasses import dataclass, field
import importlib
import fnmatch

@dataclass
class ExtensionDefinition:
    new_layer: str
    match_name: str = "*"
    match_type: str = "*"
    params: dict = field(default_factory=dict)
    use_default_match: bool = False


def convert_model(model, config):
    extension_defs = config["extend_layers"]

    extensions = [ExtensionDefinition(**extension) for extension in extension_defs]

    return convert_model_internal(model, config, extensions)

def convert_model_internal(model, config, extensions, name_list: str = "", type_list: str = ""):
    for child_name, child in model.named_children():
        new_layer = None

        # Build the dotted name/type paths for this child locally, so sibling names
        # do not accumulate into each other's paths.
        child_name_list = child_name if name_list == "" else f'{name_list}.{child_name}'
        child_type_list = type(child).__name__ if type_list == "" else f'{type_list}.{type(child).__name__}'

        for extension in extensions:
            class_type = getattr(importlib.import_module("layer_extender"), extension.new_layer)
            if extension.use_default_match:
                # Let the extender class decide whether it applies to this module.
                if class_type.is_match(child_name_list, child_type_list):
                    new_layer = class_type(config, child, **extension.params)
            else:
                # Otherwise match the config's wildcard patterns against the dotted paths.
                name_match = True
                if extension.match_name != "*":
                    name_match = fnmatch.fnmatchcase(child_name_list, extension.match_name)

                type_match = True
                if extension.match_type != "*":
                    type_match = fnmatch.fnmatchcase(child_type_list, extension.match_type)

                if type_match and name_match:
                    new_layer = class_type(config, child, **extension.params)

        if new_layer is not None:
            setattr(model, child_name, new_layer)

        convert_model_internal(child, config, extensions, child_name_list, child_type_list)
    return model


def print_model(model, indent = ""):
    for child_name, child in model.named_children():
        print(f'{indent}{child_name} ({type(child).__name__}):')
        print_model(child, f'{indent} ')
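To make the matching concrete: for a GPTNeoX checkpoint the recursion builds dotted name and type paths roughly like the ones below (exact module names depend on the transformers version, so treat these strings as illustrative), and the patterns from config.json are wildcard-matched against them.

import fnmatch

# Approximate paths for the attention projection in pythia-70m-deduped.
qkv_name = "gpt_neox.layers.0.attention.query_key_value"
qkv_type = "GPTNeoXModel.ModuleList.GPTNeoXLayer.GPTNeoXAttention.Linear"
attn_type = "GPTNeoXModel.ModuleList.GPTNeoXLayer.GPTNeoXAttention"

# LoraExtender.is_match: the first type token is GPTNeoXModel, the name ends in
# query_key_value and the type ends in Linear, so this Linear gets a LoRA wrapper.
print(fnmatch.fnmatchcase(qkv_name, "*query_key_value"))   # True
print(fnmatch.fnmatchcase(qkv_type, "*Linear"))            # True

# The second config entry ("match_type": "*GPTNeoXAttention") matches the attention
# module itself, which is then wrapped in BaseLayerExtender.
print(fnmatch.fnmatchcase(attn_type, "*GPTNeoXAttention")) # True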
Empty file added __init__.py
Empty file.