Skip to content

Commit

Permalink
refactor(routes): split routes registration (mudler#2077)
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Apr 20, 2024
1 parent afa1bca commit 284ad02
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 167 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -714,4 +714,4 @@ docker-image-intel-xpu:

.PHONY: swagger
swagger:
swag init -g core/http/api.go --output swagger
swag init -g core/http/app.go --output swagger
135 changes: 6 additions & 129 deletions core/http/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,21 @@ import (

"github.com/go-skynet/LocalAI/pkg/utils"

"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/http/endpoints/openai"
"github.com/go-skynet/LocalAI/core/http/routes"

"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"

"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/swagger" // swagger handler

// swagger handler
"github.com/rs/zerolog/log"
)

Expand Down Expand Up @@ -175,16 +174,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
app.Use(c)
}

// LocalAI API endpoints
galleryService := services.NewGalleryService(appConfig.ModelPath)
galleryService.Start(appConfig.Context, cl)

app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})

// Make sure directories exists
os.MkdirAll(appConfig.ImageDir, 0755)
os.MkdirAll(appConfig.AudioDir, 0755)
Expand All @@ -197,122 +186,10 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants)
utils.LoadConfig(appConfig.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles)

app.Get("/swagger/*", swagger.HandlerDefault) // default

welcomeRoute(
app,
cl,
ml,
appConfig,
auth,
)

modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())

app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))

// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))

// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))

// openAI compatible API endpoint

// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig))

// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig))
app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig))

// assistant
app.Get("/v1/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Get("/assistants", auth, openai.ListAssistantsEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants", auth, openai.CreateAssistantEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id", auth, openai.DeleteAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id", auth, openai.GetAssistantEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id", auth, openai.ModifyAssistantEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files", auth, openai.ListAssistantFilesEndpoint(cl, ml, appConfig))
app.Post("/v1/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Post("/assistants/:assistant_id/files", auth, openai.CreateAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/v1/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Delete("/assistants/:assistant_id/files/:file_id", auth, openai.DeleteAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/v1/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))
app.Get("/assistants/:assistant_id/files/:file_id", auth, openai.GetAssistantFileEndpoint(cl, ml, appConfig))

// files
app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig))
app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig))
app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig))
app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))
app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig))

// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig))

// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig))

// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig))
app.Post("/v1/audio/speech", auth, localai.TTSEndpoint(cl, ml, appConfig))

// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig))

if appConfig.ImageDir != "" {
app.Static("/generated-images", appConfig.ImageDir)
}

if appConfig.AudioDir != "" {
app.Static("/generated-audio", appConfig.AudioDir)
}

ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}

// Kubernetes health checks
app.Get("/healthz", ok)
app.Get("/readyz", ok)

// Experimental Backend Statistics Module
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor))

// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml))
app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml))

app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig, auth)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, auth)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
routes.RegisterPagesRoutes(app, cl, ml, appConfig, auth)

// Define a custom 404 handler
// Note: keep this at the bottom!
Expand Down
28 changes: 28 additions & 0 deletions core/http/endpoints/localai/welcome.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package localai

import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/internal"
"github.com/gofiber/fiber/v2"
)

func WelcomeEndpoint(appConfig *config.ApplicationConfig,
models []string, backendConfigs []config.BackendConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"Models": models,
"ModelsConfig": backendConfigs,
"ApplicationConfig": appConfig,
}

if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
} else {
// Render index
return c.Render("views/index", summary)
}
}
}
37 changes: 0 additions & 37 deletions core/http/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ import (
"net/http"

"github.com/Masterminds/sprig/v3"
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/schema"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
fiberhtml "github.com/gofiber/template/html/v2"
"github.com/russross/blackfriday"
Expand All @@ -33,40 +30,6 @@ func notFoundHandler(c *fiber.Ctx) error {
return nil
}

func welcomeRoute(
app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error,
) {
if appConfig.DisableWelcomePage {
return
}

models, _ := ml.ListModels()
backendConfigs := cl.GetAllBackendConfigs()

app.Get("/", auth, func(c *fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"Models": models,
"ModelsConfig": backendConfigs,
"ApplicationConfig": appConfig,
}

if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
} else {
// Render index
return c.Render("views/index", summary)
}
})

}

func renderEngine() *fiberhtml.Engine {
engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
engine.AddFuncMap(sprig.FuncMap())
Expand Down
19 changes: 19 additions & 0 deletions core/http/routes/elevenlabs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package routes

import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/elevenlabs"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
)

func RegisterElevenLabsRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {

// Elevenlabs
app.Post("/v1/text-to-speech/:voice-id", auth, elevenlabs.TTSEndpoint(cl, ml, appConfig))

}
64 changes: 64 additions & 0 deletions core/http/routes/localai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package routes

import (
"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/core/http/endpoints/localai"
"github.com/go-skynet/LocalAI/core/services"
"github.com/go-skynet/LocalAI/internal"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/swagger"
)

func RegisterLocalAIRoutes(app *fiber.App,
cl *config.BackendConfigLoader,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
auth func(*fiber.Ctx) error) {

app.Get("/swagger/*", swagger.HandlerDefault) // default

// LocalAI API endpoints
galleryService := services.NewGalleryService(appConfig.ModelPath)
galleryService.Start(appConfig.Context, cl)

modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
app.Delete("/models/galleries", auth, modelGalleryEndpointService.RemoveModelGalleryEndpoint())
app.Get("/models/jobs/:uuid", auth, modelGalleryEndpointService.GetOpStatusEndpoint())
app.Get("/models/jobs", auth, modelGalleryEndpointService.GetAllStatusEndpoint())

app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig))

// Stores
sl := model.NewModelLoader("")
app.Post("/stores/set", auth, localai.StoresSetEndpoint(sl, appConfig))
app.Post("/stores/delete", auth, localai.StoresDeleteEndpoint(sl, appConfig))
app.Post("/stores/get", auth, localai.StoresGetEndpoint(sl, appConfig))
app.Post("/stores/find", auth, localai.StoresFindEndpoint(sl, appConfig))

// Kubernetes health checks
ok := func(c *fiber.Ctx) error {
return c.SendStatus(200)
}

app.Get("/healthz", ok)
app.Get("/readyz", ok)

app.Get("/metrics", auth, localai.LocalAIMetricsEndpoint())

// Experimental Backend Statistics Module
backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now
app.Get("/backend/monitor", auth, localai.BackendMonitorEndpoint(backendMonitor))
app.Post("/backend/shutdown", auth, localai.BackendShutdownEndpoint(backendMonitor))

app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})

}
Loading

0 comments on commit 284ad02

Please sign in to comment.