This is an implementation of the Falcon model in JAX, using a functional approach for improved performance. The project is inspired by https://github.com/ayaka14732/llama-2-jax, and a large amount of that project's code has been reused. Newly implemented features are:
- Model architecture
- Training
- Generation
- Early stopping
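
To illustrate what "functional approach" typically means in JAX, here is a minimal sketch. It is not taken from this project's code: `LinearParams`, `init_linear`, `forward_linear`, and `loss_fn` are hypothetical names chosen for the example. The idea is that parameters live in a plain immutable container (a pytree) and the model is a pure function of those parameters and its inputs, which lets `jax.jit` and `jax.grad` be applied directly.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LinearParams(NamedTuple):
    # Hypothetical parameter container; NamedTuples are valid JAX pytrees.
    weight: jax.Array
    bias: jax.Array


def init_linear(key: jax.Array, d_in: int, d_out: int) -> LinearParams:
    # Parameters are created explicitly and returned, never stored on an object.
    weight = jax.random.normal(key, (d_in, d_out)) * 0.02
    bias = jnp.zeros((d_out,))
    return LinearParams(weight, bias)


@jax.jit
def forward_linear(params: LinearParams, x: jax.Array) -> jax.Array:
    # Pure function: the output depends only on the arguments.
    return x @ params.weight + params.bias


def loss_fn(params: LinearParams, x: jax.Array, y: jax.Array) -> jax.Array:
    # Mean squared error, used here only to demonstrate jax.grad.
    pred = forward_linear(params, x)
    return jnp.mean((pred - y) ** 2)


key = jax.random.PRNGKey(0)
params = init_linear(key, d_in=4, d_out=2)
x = jnp.ones((3, 4))
y = jnp.zeros((3, 2))

out = forward_linear(params, x)
# Gradients come back in the same pytree structure as the parameters.
grads = jax.grad(loss_fn)(params, x, y)
```

Because the gradients have the same structure as `params`, an update step can be written as another pure function that maps old parameters and gradients to new parameters.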