Commit
Enable inference on CPU and Mac GPU using pytorch support for MPS (#48)
maralski authored Mar 17, 2023
1 parent 9bff21c commit db4af6a
Showing 1 changed file with 46 additions and 11 deletions.
57 changes: 46 additions & 11 deletions generate.py
@@ -10,16 +10,51 @@

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except:
    pass

if device == "cuda":
    model = LlamaForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model, "tloen/alpaca-lora-7b",
        torch_dtype=torch.float16
    )
elif device == "mps":
    model = LlamaForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",
        device_map={"": device},
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(
        model,
        "tloen/alpaca-lora-7b",
        device_map={"": device},
        torch_dtype=torch.float16,
    )
else:
    model = LlamaForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",
        device_map={"": device},
        low_cpu_mem_usage=True
    )
    model = PeftModel.from_pretrained(
        model,
        "tloen/alpaca-lora-7b",
        device_map={"": device},
    )

def generate_prompt(instruction, input=None):
    if input:
@@ -55,7 +90,7 @@ def evaluate(
):
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].cuda()
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
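For readers who want to try this change locally, the device-selection logic introduced by the commit can be sanity-checked on its own. The snippet below is a minimal sketch, not part of the commit, that mirrors the same CUDA/MPS/CPU fallback and moves a small tensor to the selected device:

import torch

# Prefer CUDA, then Apple's MPS backend, otherwise fall back to CPU,
# mirroring the selection order used in generate.py above.
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older PyTorch builds do not expose torch.backends.mps at all.
    pass

# Move a small tensor to the selected device to confirm it is usable.
x = torch.ones(2, 2).to(device)
print(f"Selected device: {device}; tensor is on: {x.device}")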
