Skip to content

Commit

Permalink
feat: Handle errors, validate model names, and calculate quota usage (s…
Browse files Browse the repository at this point in the history
…ongquanpeng#978)

- Improved error handling in various modules for better stability and responsiveness.
- Optimized code in several files for improved efficiency and readability.
- Enhanced user experience by providing more detailed error responses in the controller.
- Strengthened security by ignoring sensitive files in `.gitignore`.
  • Loading branch information
Laisky authored Feb 12, 2024
1 parent 2cd1a78 commit d548a01
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ upload
build
*.db-journal
logs
data
data
/web/node_modules
5 changes: 1 addition & 4 deletions common/embed-file-system.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ type embedFileSystem struct {

func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path)
if err != nil {
return false
}
return true
return err == nil
}

func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
Expand Down
8 changes: 4 additions & 4 deletions common/helper/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ func Seconds2Time(num int) (time string) {
}

func Interface2String(inter interface{}) string {
switch inter.(type) {
switch inter := inter.(type) {
case string:
return inter.(string)
return inter
case int:
return fmt.Sprintf("%d", inter.(int))
return fmt.Sprintf("%d", inter)
case float64:
return fmt.Sprintf("%f", inter.(float64))
return fmt.Sprintf("%f", inter)
}
return "Not Implemented"
}
Expand Down
6 changes: 3 additions & 3 deletions common/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ func Error(ctx context.Context, msg string) {
}

func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a))
Info(ctx, fmt.Sprintf(format, a...))
}

func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a))
Warn(ctx, fmt.Sprintf(format, a...))
}

func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a))
Error(ctx, fmt.Sprintf(format, a...))
}

func logHelper(ctx context.Context, level string, msg string) {
Expand Down
4 changes: 3 additions & 1 deletion controller/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ func GetSubscription(c *gin.Context) {
} else {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
usedQuota, err = model.GetUserUsedQuota(userId)
if err != nil {
usedQuota, err = model.GetUserUsedQuota(userId)
}
}
if expiredTime <= 0 {
expiredTime = 0
Expand Down
2 changes: 1 addition & 1 deletion controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func testChannel(channel *model.Channel, request openai.ChatRequest) (err error,
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"
}
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
return fmt.Errorf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message), &response.Error
}
return nil, nil
}
Expand Down
2 changes: 1 addition & 1 deletion controller/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
Expand Down
3 changes: 3 additions & 0 deletions relay/channel/aiproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}
2 changes: 1 addition & 1 deletion relay/channel/openai/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func InitTokenEncoders() {
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
for model := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
Expand Down
5 changes: 4 additions & 1 deletion relay/channel/tencent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

Expand Down Expand Up @@ -224,7 +227,7 @@ func GetSign(req ChatRequest, secretKey string) string {
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")

sort.Sort(sort.StringSlice(params))
sort.Strings(params)
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
Expand Down
2 changes: 1 addition & 1 deletion relay/controller/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
}

// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
if !isWithinRange(imageModel, imageRequest.N) {
// channel not azure
if channelType != common.ChannelTypeAzure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
Expand Down

0 comments on commit d548a01

Please sign in to comment.