[Bounty] PyTorch & HuggingFace Interface #139

Open · wants to merge 680 commits into base: main

Commits (680)
89f1be0
Merge branch 'main' into HEAD
AlexCheema Oct 10, 2024
b6f6afc
Merge remote-tracking branch 'origin/main' into HEAD
AlexCheema Oct 10, 2024
5eb6c34
fixed torch device selection
risingsunomi Oct 11, 2024
ed64437
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Oct 11, 2024
18d41eb
fixing imports
risingsunomi Oct 11, 2024
c73ed76
Merge pull request #21 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 11, 2024
9ecbf0c
fixing chatgpt_api mistake
risingsunomi Oct 11, 2024
79c9e70
Merge branch 'exo-explore:main' into main
risingsunomi Oct 11, 2024
ebfd44a
Merge pull request #22 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 11, 2024
dae2cbe
removing old pytorch folder
risingsunomi Oct 11, 2024
1c1dd06
Merge pull request #23 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 11, 2024
55ae027
Update README.md
risingsunomi Oct 11, 2024
4b6a86d
set all torch models in models.py
AlexCheema Oct 11, 2024
830d33d
in torch, explicitly set the device when initializing the model
AlexCheema Oct 11, 2024
074dfe3
spacing
AlexCheema Oct 11, 2024
d9cfcc4
add model mlx-community/Qwen2-0.5B-Instruct-4bit
AlexCheema Oct 11, 2024
c3e1934
Merge branch 'exo-explore:main' into main
risingsunomi Oct 12, 2024
2c056b4
code changes from PR feedback, working on splitting of weights
risingsunomi Oct 12, 2024
da5c28d
Merge branch 'exo-explore:main' into pr139-dev-oct24
risingsunomi Oct 12, 2024
83a723b
doing more work toward individual safetensor loading, adding back dev…
risingsunomi Oct 13, 2024
47be250
working on split model, moving to server for more vram
risingsunomi Oct 13, 2024
ea0d4b1
change to hf downloader as was not getting all safetensor files
risingsunomi Oct 13, 2024
30b7991
splitting model still work in progress as transformers still seems to…
risingsunomi Oct 13, 2024
3a2c431
updating readme
risingsunomi Oct 13, 2024
4def538
Merge branch 'main' into pr139-dev-oct24
risingsunomi Oct 13, 2024
b35224c
Merge pull request #24 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 13, 2024
6c6e7b2
successful splitting model test with only loading needed weights, imp…
risingsunomi Oct 14, 2024
55ffdc7
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Oct 14, 2024
aacdeb5
adding model sharding to inference engine, doing testing with inferen…
risingsunomi Oct 14, 2024
ce702d1
fixing layer range issue
risingsunomi Oct 14, 2024
e387a79
fixing layer range issue
risingsunomi Oct 14, 2024
e0ba2bb
fixing layer range issue
risingsunomi Oct 14, 2024
5b9638f
checking if ram over usage even if reducing layers on large models
risingsunomi Oct 14, 2024
664f29f
half layer inference engine testing
risingsunomi Oct 14, 2024
2591fab
fixing layer amount with sharded modeling
risingsunomi Oct 14, 2024
99dac57
adding qwen2.5 3B for testing
risingsunomi Oct 14, 2024
c12526f
Merge pull request #25 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 14, 2024
493cd3e
updating inference engine test
risingsunomi Oct 14, 2024
de23294
cleaning up utils and split model
risingsunomi Oct 14, 2024
d5a02be
Merge pull request #26 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 14, 2024
e7470b1
bugfix in llm setup
dtnewman Oct 15, 2024
fa24f46
Merge pull request #27 from dtnewman/main
risingsunomi Oct 15, 2024
5c69f3f
Merge remote-tracking branch 'origin/main' into HEAD
AlexCheema Oct 16, 2024
f5a1cef
handle range not satisfiable edge case
AlexCheema Oct 16, 2024
751bd1c
updating to use automodelforcausallm instead of autoconfig
risingsunomi Oct 16, 2024
7d866d8
removing meta model
risingsunomi Oct 16, 2024
253237b
updating split model test
risingsunomi Oct 16, 2024
e46ffa4
updating split model test
risingsunomi Oct 16, 2024
476b6ba
automodel fix
risingsunomi Oct 16, 2024
f7e02e9
fixing split model test
risingsunomi Oct 16, 2024
bd6322f
pytorch offload buffers error
risingsunomi Oct 17, 2024
c51bd91
device_map any issue with split model
risingsunomi Oct 17, 2024
4a2aef4
updating split model test
risingsunomi Oct 17, 2024
79f0763
fixing split model issue
risingsunomi Oct 17, 2024
cbbc9cf
fixing node issues
risingsunomi Oct 17, 2024
58cebab
fixing node issues
risingsunomi Oct 17, 2024
7f9b1bb
fixing node issues
risingsunomi Oct 17, 2024
c3adec5
fixing node issues
risingsunomi Oct 17, 2024
c8e6acc
fixing node issues
risingsunomi Oct 17, 2024
df028e2
fixing node issues, range issue
risingsunomi Oct 17, 2024
e5a1939
fixing node issues, range issue
risingsunomi Oct 17, 2024
d03a85c
Merge branch 'main' into pr139-dev-oct24
risingsunomi Oct 17, 2024
69a8955
Merge pull request #28 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 17, 2024
d07b825
adding num hidden layers manipulation for all models
risingsunomi Oct 18, 2024
a840e7f
updating to use shard_num_hidden_layers
risingsunomi Oct 18, 2024
bf5f22d
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Oct 18, 2024
52fa3f8
adding in better layer manipulation
risingsunomi Oct 18, 2024
ec49e31
adding in safe tensor sharding, generate model.safetensors.index.json…
risingsunomi Oct 19, 2024
f45b514
implementing sharding tests, fixing bugs with safetensor recompile
risingsunomi Oct 19, 2024
f90c24a
adding safetensor sharding, implementing it into model inference engine
risingsunomi Oct 20, 2024
696c264
updating backup and backup restore
risingsunomi Oct 20, 2024
9514e92
added removing backup when restoring
risingsunomi Oct 20, 2024
d65505e
added generating weight map if none, did updates to backup and restor…
risingsunomi Oct 20, 2024
d5b6113
cleaning up logging
risingsunomi Oct 20, 2024
d2302cc
updating docstring in newest class file
risingsunomi Oct 20, 2024
35c32eb
Merge pull request #29 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 20, 2024
72fcf9b
starting write of llama3 model outside of transformers and using pytorch
risingsunomi Oct 21, 2024
9cac5ab
moving llama3 modeling source code, updating readme file
risingsunomi Oct 21, 2024
8012008
adding pytorch based llama model, added testing and working through bugs
risingsunomi Oct 23, 2024
291aa10
Merge branch 'exo-explore:main' into main
risingsunomi Oct 23, 2024
76323d7
Update llama3.py
risingsunomi Oct 23, 2024
0d66acd
updating pytorch llama model still, currently broken but backing up a…
risingsunomi Oct 23, 2024
fcb298b
Merge branch 'main' into main
AlexCheema Oct 23, 2024
1512d13
updated llamablock and llamamodel, created a MLP helper class to use …
risingsunomi Oct 25, 2024
0eb8044
fixing causual mask loading error, updated testing, working on logit …
risingsunomi Oct 25, 2024
a6768b4
adding a chat template from tokenizer to test, looking at padding ids t…
risingsunomi Oct 25, 2024
8ba24e2
fixing parameter definition on 4d mask method, committing before trying…
risingsunomi Oct 26, 2024
6e32be6
merge fixing
risingsunomi Oct 27, 2024
df13fbc
Merge branch 'main' into pr/risingsunomi/30
risingsunomi Oct 27, 2024
6b3af3f
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Oct 27, 2024
cfb10ba
added in more base llm functions like multiheadattention and rotate e…
risingsunomi Oct 28, 2024
ea868c6
updating attentions, changed model struct, fixing kv cache
risingsunomi Oct 30, 2024
f1822e2
fixing kvcache for multiheadattention, fixing layers names for loadin…
risingsunomi Oct 30, 2024
38028c0
doing work with position_id and causal mask
risingsunomi Oct 31, 2024
0fd1797
updating torch readme with current model in development
risingsunomi Nov 1, 2024
5aaffe6
implemented using torchtune multiheadattention, added dot product att…
risingsunomi Nov 2, 2024
b2b63c3
FINALLY A WORKING PYTORCH ONLY MODEL, working on logit gen, shard tes…
risingsunomi Nov 3, 2024
f53ebd1
cleaning up custom dot product attention but might be removed, buildi…
risingsunomi Nov 3, 2024
e8db8ee
first layer run fixes, variable layer length weight loading fixes, wo…
risingsunomi Nov 10, 2024
22bc6a7
made it so weight for last output layer is only loaded when shard is …
risingsunomi Nov 11, 2024
7f2abc3
working on sharding issue where hidden state is not working when bein…
risingsunomi Nov 12, 2024
bdf3240
fixing last hidden value handling
risingsunomi Nov 15, 2024
227199f
update test
risingsunomi Nov 15, 2024
5af6302
update test
risingsunomi Nov 15, 2024
d7e5aca
update test
risingsunomi Nov 15, 2024
1874d23
update test, turn on caching
risingsunomi Nov 15, 2024
3a0ad62
test safetensor load
risingsunomi Nov 15, 2024
6098ae5
test hidden alignment
risingsunomi Nov 15, 2024
fa1e70f
updates to torchtune model, fixing non-generation errors, created spl…
risingsunomi Nov 17, 2024
d958bf9
split model working, updates to safetensor loading letting shard control
risingsunomi Nov 17, 2024
c8bdb09
reduced model loading ram by loading only some layers in layer list, …
risingsunomi Nov 17, 2024
75817eb
updating readme
risingsunomi Nov 17, 2024
73630d1
building out torch inference engine
risingsunomi Nov 18, 2024
ad99332
creating torch inference engine, separated torch and hf torch engines…
risingsunomi Nov 23, 2024
9ff2cc8
Merge github.com:exo-explore/exo into fork-merge
risingsunomi Nov 23, 2024
6ab6f1c
merge
risingsunomi Nov 23, 2024
811befc
Merge pull request #32 from risingsunomi/fork-merge
risingsunomi Nov 23, 2024
0e9f42a
fixing torchtune module issues
risingsunomi Nov 24, 2024
a170cc6
adding torchtune install
risingsunomi Nov 24, 2024
05f3e52
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Nov 24, 2024
ff78688
adding torchao install
risingsunomi Nov 24, 2024
3ce2df0
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Nov 24, 2024
4da7377
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Nov 24, 2024
9f57e45
building out test inference engine for pytorch, adding torch engine t…
risingsunomi Nov 24, 2024
9f52f24
Merge branch 'exo-explore:main' into main
risingsunomi Nov 24, 2024
96c3eb5
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Nov 24, 2024
4455224
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Nov 24, 2024
fbf106e
removing last shard check for return of hidden state from infer_prompt
risingsunomi Nov 24, 2024
2ecc629
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Nov 24, 2024
405b5ae
fixing vram/ram issue, switched to using float16 for dtype
risingsunomi Nov 24, 2024
0320c50
trying to offload as I can
risingsunomi Nov 24, 2024
84f4131
adding detach
risingsunomi Nov 24, 2024
596c715
debug process issue
risingsunomi Nov 24, 2024
a5eb1be
putting back hidden state pass
risingsunomi Nov 24, 2024
e550af2
Merge branch 'pr139-dev-oct24' of github.com:risingsunomi/exo-nvidia …
risingsunomi Nov 24, 2024
21e626e
fixing torch inference engine selection not working when adding more …
risingsunomi Nov 24, 2024
e8f689c
fixing typo
risingsunomi Nov 24, 2024
d0cc3b0
fixing llama3 tests, removing mask and input_ids for going through mo…
risingsunomi Nov 24, 2024
5f085dc
optional caching as cache might not work with how sharding works
risingsunomi Nov 24, 2024
f667735
fix cache assignment
risingsunomi Nov 24, 2024
cb847e4
fixing, set cache to false for inference for now
risingsunomi Nov 25, 2024
8325975
model double loading vram issue
risingsunomi Nov 25, 2024
c8308b8
model double loading vram issue
risingsunomi Nov 25, 2024
2dfce95
setting class shard
risingsunomi Nov 25, 2024
39cfbf5
working on inference engine issue of too much vram
risingsunomi Nov 26, 2024
185502a
fixing vram issue from total_response_length being set with max_seq_l…
risingsunomi Nov 28, 2024
738e931
fixing split and full model generation, finetune for nodes and genera…
risingsunomi Nov 29, 2024
907ba0b
updated full test to generate to stop or max tokens for testing, upda…
risingsunomi Nov 30, 2024
f3c868b
Merge branch 'main' of github.com:exo-explore/exo into exo-explore-main
risingsunomi Nov 30, 2024
f0c8fb1
Merge branch 'exo-explore-main'
risingsunomi Nov 30, 2024
c6806f9
Merge branch 'main' of github.com:risingsunomi/exo-nvidia
risingsunomi Nov 30, 2024
4c93855
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Nov 30, 2024
b538cd2
Merge branch 'main' of github.com:risingsunomi/exo-nvidia into pr139-…
risingsunomi Nov 30, 2024
8c29d27
fixing inference sampling
risingsunomi Nov 30, 2024
80122a0
changing back temp and top_k passing
risingsunomi Dec 1, 2024
3e0d117
moving back to unsloth llama version for 3.2-1B
risingsunomi Dec 1, 2024
34ca2ad
Merge pull request #34 from risingsunomi/pr139-dev-oct24
risingsunomi Dec 1, 2024
28d9900
cleaning up code, doing more testing as some bugs remain
risingsunomi Dec 1, 2024
30651ea
having a check for llama or Llama for loading tensors, will add suppo…
risingsunomi Dec 1, 2024
58122ba
Merge branch 'main' of github.com:exo-explore/exo into fork-update-de…
risingsunomi Dec 28, 2024
fd1d469
Merge pull request #36 from risingsunomi/fork-update-dec272024
risingsunomi Dec 28, 2024
d4f39fc
Merge pull request #37 from risingsunomi/main
risingsunomi Dec 28, 2024
39ffe70
Merge pull request #38 from risingsunomi/pr139-dev-oct24
risingsunomi Dec 28, 2024
9494949
adding load_checkpoint to TorchDynamicShardInferenceEngine
risingsunomi Dec 29, 2024
5085adb
fixing formatting in code, adding in logging to debug, changing full …
risingsunomi Jan 2, 2025
c7a2b6b
Merge branch 'exo-explore:main' into main
risingsunomi Jan 2, 2025
522825f
Merge branch 'exo-explore:main' into main
risingsunomi Jan 8, 2025
88aeae7
Merge branch 'main' of github.com:risingsunomi/exo-pt into dev-jan-2025
risingsunomi Jan 12, 2025
73b71d5
Updating model to use max_position_embeddings, testing mono logit pas…
risingsunomi Jan 16, 2025
1de87fb
updating llama model caching inference, updating to latest torchtune …
risingsunomi Jan 16, 2025
028f305
improved weight loading, switched to torchtune based norm and llama f…
risingsunomi Jan 16, 2025
f3bd881
adding TORCH_USE_ORG_SEQ to use the origin max positions embeds for m…
risingsunomi Jan 16, 2025
099b6c1
Merge branch 'main' of github.com:exo-explore/exo into dev-jan-2025
risingsunomi Jan 16, 2025
027de6f
updating readme before merge into main
risingsunomi Jan 16, 2025
6f6b167
Merge pull request #40 from risingsunomi/dev-jan-2025
risingsunomi Jan 16, 2025
eabcdaa
adding inference state back to inference engine
risingsunomi Jan 16, 2025
6d29ba6
fixing inference engine selection by adding torch to supported engine…
risingsunomi Jan 16, 2025
38878cb
adding in tracking of selected exo inference engine via env var EXO_INF…
risingsunomi Jan 17, 2025
0a985a7
delete build folder, add it to .gitignore, added OOM error fix to res…
risingsunomi Jan 17, 2025
48a75c1
updating torchao version in setup.py
risingsunomi Jan 17, 2025
bceeaf5
updating for node to node hidden state passing and adding back infere…
risingsunomi Jan 17, 2025
4f9f038
syntax error fix
risingsunomi Jan 17, 2025
51184d4
mask and input_pos passing between nodes fix
risingsunomi Jan 17, 2025
bc7d699
mask and input_pos passing between nodes fix
risingsunomi Jan 18, 2025
c0d9f57
mask and input pos issue when passing to sharded node fix
risingsunomi Jan 18, 2025
c0d0c71
mask and input pos issue when passing to sharded node fix
risingsunomi Jan 18, 2025
05d4c9d
mask and input pos issue when passing to sharded node fix
risingsunomi Jan 18, 2025
464b9cf
removing mask and input_pos from inference state passing for nodes
risingsunomi Jan 18, 2025
d469e3e
removing mask and input_pos from inference state passing for nodes
risingsunomi Jan 18, 2025
eeee605
Merge pull request #41 from risingsunomi/dev-jan24-2
risingsunomi Jan 18, 2025
5302b73
testing reloading model at prompt encode to fix OOM issue
risingsunomi Jan 18, 2025
f1b05cd
Merge pull request #42 from risingsunomi/dev-jan25-2
risingsunomi Jan 18, 2025
13ea82d
fixing output weight for llama3.1
risingsunomi Jan 18, 2025
5e31e3d
changing top_k to 25, making reset of cache and model at every initia…
risingsunomi Jan 19, 2025
646c14a
putting in better clearing model functions, adding in clearing after …
risingsunomi Jan 19, 2025
25bdfb7
fixing env variable boolean
risingsunomi Jan 19, 2025
d0680b6
syntax fix
risingsunomi Jan 19, 2025
58a190a
remove log message
risingsunomi Jan 19, 2025
d7d5590
removing clearing model on non-primary nodes
risingsunomi Jan 19, 2025
ddaab5a
changing out tensor loading, separating out initial mask and input_po…
risingsunomi Jan 23, 2025
216ee1b
still moving and upgrading token generation, broken
risingsunomi Jan 24, 2025
106e56e
updated inference engine, added InferenceState class, added in tweaks…
risingsunomi Jan 24, 2025
754608a
trying oom tweaks
risingsunomi Jan 24, 2025
3ac345e
Merge pull request #43 from risingsunomi/dev-jan25-03
risingsunomi Jan 24, 2025
9b1ce15
Merge branch 'main' of github.com:exo-explore/exo into exo-explore-main
risingsunomi Jan 24, 2025
8ddcaea
Merge branch 'exo-explore-main'
risingsunomi Jan 24, 2025
81a27c5
Adding two new requirements not in fork main new updates
risingsunomi Jan 24, 2025
192f0c5
fixing numpy no bfloat16 support issue
risingsunomi Jan 24, 2025
fea4b31
fixing numpy no bfloat16 support issue
risingsunomi Jan 24, 2025
ffd2907
debugging grpc peer handle issue for serialization of inference state
risingsunomi Jan 24, 2025
726ef0d
adding in bool to check if apple or not for checking array type
risingsunomi Jan 24, 2025
f24397f
fixing other mx.array call
risingsunomi Jan 24, 2025
386ac0b
Merge pull request #45 from risingsunomi/grpc-fix-jan242025
risingsunomi Jan 24, 2025
6bbbb04
adding tokens pass to hidden state passing
risingsunomi Jan 24, 2025
0d5779e
fixing type conversion and device conversion
risingsunomi Jan 24, 2025
fe1e8ef
fixing types and adding clones
risingsunomi Jan 24, 2025
f43af1b
tensor shape error
risingsunomi Jan 24, 2025
9ec9b23
tensor shape error
risingsunomi Jan 24, 2025
2f31a7b
tensor shape error
risingsunomi Jan 24, 2025
b7cfece
turning back on mask calculation
risingsunomi Jan 24, 2025
332ed2a
tensor none called
risingsunomi Jan 24, 2025
983c341
tensor repeating in state cache
risingsunomi Jan 24, 2025
73e13b4
tensor repeating in state cache
risingsunomi Jan 24, 2025
f65766e
hidden state issues
risingsunomi Jan 24, 2025
f80597c
hidden state issues
risingsunomi Jan 24, 2025
58a7d3c
use_cache fix
risingsunomi Jan 24, 2025
efcb5b9
model load issue
risingsunomi Jan 24, 2025
4bf752b
pass token issue
risingsunomi Jan 25, 2025
7cc42f1
pass token issue
risingsunomi Jan 25, 2025
4e3e53e
pass token issue
risingsunomi Jan 25, 2025
c3bde74
Merge pull request #46 from risingsunomi/node-fixes-jan242025
risingsunomi Jan 25, 2025
a09956e
removing screenshot
risingsunomi Jan 25, 2025
f508dff
fix for sharded weights
risingsunomi Jan 26, 2025
8920a87
adding torch support for llama-3.2-3b
risingsunomi Jan 26, 2025
1d7262d
fixing tok_embeddings
risingsunomi Jan 26, 2025
ec91e09
fixing caching setup
risingsunomi Jan 27, 2025
a7757d3
adding qwen2 model, creating a general multihead attention transforma…
risingsunomi Jan 31, 2025
1431d48
removing duplicate mlp for single layer_mlp
risingsunomi Jan 31, 2025
7ad4b1c
fixes to general mha for detecting to use_tied or not
risingsunomi Jan 31, 2025
fbb6e55
Merge pull request #47 from risingsunomi/qwen-dev-jan25
risingsunomi Jan 31, 2025
f7028c7
Merge branch 'main' of github.com:exo-explore/exo into exo-explore-main
risingsunomi Jan 31, 2025
5f6b22d
Merge branch 'exo-explore-main'
risingsunomi Jan 31, 2025
76e141a
adding new shard download method, adding all llama models for torch
risingsunomi Jan 31, 2025
85d25c1
adding torch support for qwen models
risingsunomi Jan 31, 2025
57b43f7
updating torch to 2.6.0 latest stable
risingsunomi Jan 31, 2025
611bffb
test for mistral support
risingsunomi Feb 1, 2025
0523893
--help list --inference-engine=torch
divinity76 Feb 7, 2025
2bc2b3d
Merge branch 'exo-explore:main' into main
risingsunomi Feb 8, 2025
aac3e75
Merge pull request #49 from divinity76/patch-2
risingsunomi Feb 8, 2025
2d1a9a6
Merge branch 'exo-explore:main' into main
risingsunomi Feb 14, 2025
5 changes: 5 additions & 0 deletions .gitignore
@@ -169,7 +169,12 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# XCode
**/*.xcodeproj/*

# Aider
.aider*

exo/tinychat/images/*.png
.vscode/
build/
3 changes: 3 additions & 0 deletions exo/api/chatgpt_api.py
@@ -81,6 +81,9 @@ def generate_completion(
    }],
  }

  if DEBUG >= 3:
    print(f"completion: {completion}")

  if not stream:
    completion["usage"] = {
      "prompt_tokens": len(tokenizer.encode(prompt)),
1 change: 0 additions & 1 deletion exo/helpers.py
@@ -42,7 +42,6 @@ def get_system_info():
    return "Linux"
  return "Non-Mac, non-Linux system"

def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
  used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")

5 changes: 5 additions & 0 deletions exo/inference/inference_engine.py
@@ -55,6 +55,7 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inferen
  "mlx": "MLXDynamicShardInferenceEngine",
  "tinygrad": "TinygradDynamicShardInferenceEngine",
  "dummy": "DummyInferenceEngine",
  "torch": "TorchDynamicShardInferenceEngine"
}


@@ -71,6 +72,10 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDown
    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

    return TinygradDynamicShardInferenceEngine(shard_downloader)
  elif inference_engine_name == "torch":
    from exo.inference.torch.sharded_inference_engine import TorchDynamicShardInferenceEngine

    return TorchDynamicShardInferenceEngine(shard_downloader)
  elif inference_engine_name == "dummy":
    from exo.inference.dummy_inference_engine import DummyInferenceEngine
    return DummyInferenceEngine()
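With the torch entry registered above, the engine is selected like the existing ones; on the command line that corresponds to `--inference-engine torch` (see the CLI help commit in this PR). A hedged sketch of programmatic selection, assuming the HFShardDownloader import path of this era of the codebase:

```python
# Sketch only: get_inference_engine dispatches on the name, as in the diff above.
from exo.inference.inference_engine import get_inference_engine
from exo.download.hf.hf_shard_download import HFShardDownloader  # assumed path

engine = get_inference_engine("torch", HFShardDownloader())
```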
2 changes: 2 additions & 0 deletions exo/inference/torch/.gitignore
@@ -0,0 +1,2 @@
data/
model/archive/
103 changes: 103 additions & 0 deletions exo/inference/torch/README.md
@@ -0,0 +1,103 @@
# PyTorch inference engine

## Devs
- [Vincent Castro](https://x.com/t0kenl1mit)

## ENV Vars
```bash
# Use the original max position embeddings amount, if present in the rope_scaling - default is False
TORCH_USE_ORG_SEQ=True # or False

# Use cache - default is True
TORCH_USE_CACHE=True # or False
```
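
For reference, a minimal sketch of how these boolean flags might be parsed, assuming the engine reads them with `os.environ` (the helper name is hypothetical; the actual parsing lives in the engine code):

```python
import os

def env_flag(name: str, default: bool) -> bool:
  # Hypothetical helper: treats "1", "true", or "yes" (any case) as True.
  value = os.environ.get(name)
  if value is None:
    return default
  return value.strip().lower() in ("1", "true", "yes")

USE_ORG_SEQ = env_flag("TORCH_USE_ORG_SEQ", default=False)  # default False, per above
USE_CACHE = env_flag("TORCH_USE_CACHE", default=True)       # default True, per above
```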

## Notes/Issues
### 10/10/2024
- To select a PyTorch device via environment variables, set the variable TORCH_DEVICE (see the sketch after this list)
- XLA is currently not installed and will need to be added to inference.py; looking into doing this on a TPU VM
- With PyTorch, ROCm is exposed through the CUDA device type, so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373)
- Looking into adding mobile device support properly
- If the device is not CPU, the data type defaults to float32; otherwise float16
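
A minimal sketch of the device and dtype selection described above, assuming TORCH_DEVICE holds a torch device string; the CUDA fallback is an assumption, and the real logic lives in the engine:

```python
import os
import torch

# Hypothetical sketch of TORCH_DEVICE handling.
device_name = os.environ.get("TORCH_DEVICE")
if device_name is None:
  device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Per the last note above: non-CPU devices default to float32, CPU to float16.
dtype = torch.float32 if device.type != "cpu" else torch.float16
```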

### 10/13/2024
Still working on split-model development (see test_split_model.py). The split itself appears to work, but transformers still loads extra weights into RAM and VRAM as models get larger, causing an OOM. Will research and address this in the next update. Tests are added and still in development.

### 10/21/2024
Working on removing transformers due to inference and VRAM usage [issues](https://github.com/exo-explore/exo/pull/139#issuecomment-2424953962). Creating a pure PyTorch implementation of llama3, as using transformers won't work for exo. Using some code from Meta but also implementing the use of torchtune.

### 10/27/2024
Still working on the llama3 model, but wanted to note that a better KVCache needs to be investigated.

### 11/17/2024
The sharded Llama model is now working, and the next step is the inference engine. Still testing on the small Llama 3.2 1B but will try larger models.

### 01/16/2025
Torchtune has replaced Hugging Face transformers everywhere except the tokenizer. Inference on Meta Llama 3.2 1B seems okay, but my GPU runs into a wall quickly. Will be trying a bigger VM and a LAN server to split up the model.

## Tech
```bash
# Laptop/PC
Distributor ID: Ubuntu
Description: Ubuntu 24.04.1 LTS
Release: 24.04
Codename: noble
CUDA Version: 12.4
Nvidia Driver Version: 550.107.02

CPU: 11th Gen Intel® Core™ i7-11800H × 16
RAM: 16GB
GPU 1: Nvidia GeForce RTX 3060 6GB Laptop
```
```bash
# Server
Distributor ID: Pop
Description: Pop!_OS 22.04 LTS
Release: 22.04
Codename: jammy
CUDA Version: 12.4
Nvidia Driver Version: 550.90.07

GPU 1: NVIDIA T1000 8GB
GPU 2: NVIDIA Quadro M2000 4GB
GPU 3: NVIDIA Quadro M2000 4GB
GPU 4: NVIDIA Quadro P400 2GB
GPU 5: NVIDIA Quadro P400 2GB
```

## Current Model

WIP PyTorch Llama model

```
# Llama-3.2-1B-Instruct #

ShardedLlamaModel(
  (model): ShardTransformerDecoder(
    (tok_embeddings): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x TransformerSelfAttentionLayer(
        (attn): MultiHeadAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (output_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (pos_embeddings): Llama3ScaledRoPE()
        )
        (mlp): FeedForward(
          (w1): Linear(in_features=2048, out_features=8192, bias=False)
          (w2): Linear(in_features=8192, out_features=2048, bias=False)
          (w3): Linear(in_features=2048, out_features=8192, bias=False)
          (activation): SiLU()
        )
        (sa_norm): RMSNorm()
        (mlp_norm): RMSNorm()
        (sa_scale): Identity()
        (mlp_scale): Identity()
      )
    )
    (norm): RMSNorm()
  )
)

```
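
The attention shapes in the dump above imply grouped-query attention. A quick consistency check, assuming the standard Llama 3.2 1B head dimension of 64:

```python
# q_proj maps 2048 -> 2048; k_proj/v_proj map 2048 -> 512 (from the dump above).
embed_dim, kv_out, head_dim = 2048, 512, 64

n_q_heads = embed_dim // head_dim  # 32 query heads
n_kv_heads = kv_out // head_dim    # 8 key/value heads shared across query groups
print(n_q_heads, n_kv_heads)       # -> 32 8
```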
Empty file added exo/inference/torch/__init__.py