Skip to content

Commit

Permalink
Merge pull request #38 from kynmh69/35-feature-add-key-auth
Browse files Browse the repository at this point in the history
35 feature add key auth
  • Loading branch information
kynmh69 authored Oct 13, 2024
2 parents d9b1271 + dc94634 commit 59934b6
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 25 deletions.
6 changes: 6 additions & 0 deletions src/api/.air.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@ tmp_dir = "tmp"

[log]
main_only = false
silent = false
time = false

[misc]
clean_on_exit = false

[proxy]
app_port = 0
enabled = false
proxy_port = 0

[screen]
clear_on_rebuild = false
keep_scroll = true
17 changes: 4 additions & 13 deletions src/api/handler/count_holidays.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package handler

import (
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/database"
"github.com/kynmh69/go-ja-holidays/logging"
"github.com/kynmh69/go-ja-holidays/model"
"net/http"
"time"

"github.com/doug-martin/goqu/v9"
"github.com/kynmh69/go-ja-holidays/database"
)

type CountStruct struct {
Expand All @@ -18,12 +15,6 @@ type CountStruct struct {
func CountHolidays(c *gin.Context) {
logger := logging.GetLogger()
var request HolidaysRequest
if location, err := time.LoadLocation(LOCATION); err != nil {
BadRequestJson(c, err.Error())
return
} else {
goqu.SetTimeLocation(location)
}

if err := c.ShouldBindQuery(&request); err != nil {
logger.Error(err)
Expand All @@ -35,11 +26,11 @@ func CountHolidays(c *gin.Context) {

dataSet := db.Model(&model.HolidayData{})
if !request.StartDay.IsZero() && !request.EndDay.IsZero() {
dataSet = dataSet.Where("created_at BETWEEN ? AND ?", request.StartDay, request.EndDay)
dataSet = dataSet.Where("date BETWEEN ? AND ?", request.StartDay, request.EndDay)
} else if !request.StartDay.IsZero() {
dataSet = dataSet.Where("created_at >= ?", request.StartDay)
dataSet = dataSet.Where("date >= ?", request.StartDay)
} else if !request.EndDay.IsZero() {
dataSet = dataSet.Where("created_at <= ?", request.EndDay)
dataSet = dataSet.Where("date <= ?", request.EndDay)
}
var count int64
err := dataSet.Count(&count).Error
Expand Down
17 changes: 12 additions & 5 deletions src/api/handler/get_holidays.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package handler
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/logging"
"net/http"
"time"

Expand All @@ -12,17 +13,16 @@ import (
)

const (
ColumnDate = "holiday_date"
LOCATION = "Asia/Tokyo"
LOCATION = "Asia/Tokyo"
)

type HolidaysRequest struct {
StartDay time.Time `form:"start_day" time_format:"2006-01-02"`
EndDay time.Time `form:"end_day" time_format:"2006-01-02"`
}

func (receiver HolidaysRequest) String() string {
return fmt.Sprintf("StartDay: \"%s\", EndDay: \"%s\"", receiver.StartDay, receiver.EndDay)
func (h *HolidaysRequest) String() string {
return fmt.Sprintf("StartDay: \"%s\", EndDay: \"%s\"", h.StartDay, h.EndDay)
}

func GetHolidays(c *gin.Context) {
Expand All @@ -32,6 +32,8 @@ func GetHolidays(c *gin.Context) {
)
// DB接続
db := database.GetDbConnection()
// get logger
logger := logging.GetLogger()
// タイムゾーン設定
if location, err := time.LoadLocation(LOCATION); err != nil {
BadRequestJson(c, err.Error())
Expand All @@ -46,13 +48,18 @@ func GetHolidays(c *gin.Context) {
return
}
// リクエストパラメータから開始日と終了日を取得
dataSet := db.Model(&model.HolidayData{}).Order(goqu.C(ColumnDate).Asc())
dataSet := db.Model(&model.HolidayData{})
if !reqParams.StartDay.IsZero() && !reqParams.EndDay.IsZero() {
logger.Debugf("StartDay: %s, EndDay: %s", reqParams.StartDay, reqParams.EndDay)
dataSet = dataSet.Where("date between ? and ?", reqParams.StartDay, reqParams.EndDay)
} else if !reqParams.StartDay.IsZero() {
logger.Debugf("StartDay: %s", reqParams.StartDay)
dataSet = dataSet.Where("date >= ?", reqParams.StartDay)
} else if !reqParams.EndDay.IsZero() {
logger.Debugf("EndDay: %s", reqParams.EndDay)
dataSet = dataSet.Where("date <= ?", reqParams.EndDay)
} else {
logger.Debug("No request parameters")
}
// データ取得
if err := dataSet.Find(&holidays).Error; err != nil {
Expand Down
9 changes: 2 additions & 7 deletions src/api/handler/is_holiday.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"time"

"github.com/doug-martin/goqu/v9"
"github.com/gin-gonic/gin"

"github.com/kynmh69/go-ja-holidays/database"
Expand All @@ -32,15 +31,11 @@ func IsHoliday(c *gin.Context) {
BadRequestJson(c, err.Error())
return
}
logger.Debug(request.Date)

// Set the time zone to JST.
loc := request.Date.Location()
goqu.SetTimeLocation(loc)
logger.Debug("requested date ", request.Date)

// Get the holiday data for the specified day.
holiday.Date = request.Date
result := db.First(&holiday)
result := db.Where(&holiday).Take(&holiday)
err := result.Error

var isHoliday model.IsHoliday
Expand Down
2 changes: 2 additions & 0 deletions src/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package router
import (
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/api/handler"
"github.com/kynmh69/go-ja-holidays/middleware"
)

func MakeRoute(r *gin.Engine) {
r.Use(middleware.Auth)
r.GET("/holidays", handler.GetHolidays)
r.GET("/holidays/:date", handler.IsHoliday)
r.GET("/holidays/count", handler.CountHolidays)
Expand Down
26 changes: 26 additions & 0 deletions src/middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package middleware

import (
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/logging"
"github.com/kynmh69/go-ja-holidays/model"
"net/http"
)

func Auth(c *gin.Context) {
logger := logging.GetLogger()
// get api key from request header
key := c.GetHeader("X-API-KEY")
if key == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "API key is required."})
return
}
apiKey, err := model.GetApiKey(key)
if err != nil {
logger.Warnln("API key is invalid. ", key)
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "API key is invalid."})
return
}
logger.Debugln("API key is valid. ", apiKey.Id)
c.Next()
}
79 changes: 79 additions & 0 deletions src/middleware/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package middleware

import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/kynmh69/go-ja-holidays/database"
"github.com/kynmh69/go-ja-holidays/logging"
"github.com/kynmh69/go-ja-holidays/model"
"gorm.io/gorm"
"net/http"
"net/http/httptest"
"os"
"testing"
)

func TestMain(m *testing.M) {
_ = os.Setenv("PSQL_HOSTNAME", "localhost")
_ = os.Setenv("DATABASE", "unittest")
logging.LoggerInitialize()
database.ConnectDatabase()

defer tearDown()

db := database.GetDbConnection()
if err := db.AutoMigrate(
&model.ApiKey{},
&model.HolidayData{},
); err != nil {
logging.GetLogger().Panicln(err)
}
if code := m.Run(); code > 0 {
logging.GetLogger().Panicln("Test failed with code ", code)
}
logging.GetLogger().Infoln("Test passed")
}

func TestAuth(t *testing.T) {
r := gin.Default()
u, _ := uuid.NewUUID()
var apiKey model.ApiKey
apiKeyError := model.ApiKey{Key: u}
db := database.GetDbConnection()
db.Create(&apiKey)
tests := []struct {
name string
apiKey model.ApiKey
statusCode int
}{
{
name: "Test Auth",
apiKey: apiKey,
statusCode: http.StatusOK,
},
{
name: "Test Auth Error",
apiKey: apiKeyError,
statusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-KEY", tt.apiKey.Key.String())
w := httptest.NewRecorder()
ctx := gin.CreateTestContextOnly(w, r)
ctx.Request = req
Auth(ctx)
if tt.statusCode != w.Code {
t.Errorf("expected %d but got %d", tt.statusCode, w.Code)
}
})
}
}

func tearDown() {
db := database.GetDbConnection()
db.Session(&gorm.Session{AllowGlobalUpdate: true}).
Delete(&model.ApiKey{})
}
11 changes: 11 additions & 0 deletions src/model/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ func GetApiKeys() ([]ApiKey, error) {
return apiKeys, err
}

func GetApiKey(apiKeyStr string) (ApiKey, error) {
logger := logging.GetLogger()
var apiKey ApiKey
db := database.GetDbConnection()
err := db.Where("key = ?", apiKeyStr).First(&apiKey).Error
if err != nil {
logger.Debug("API key is invalid. ")
}
return apiKey, err
}

func CreateApiKey() error {
logger := logging.GetLogger()
db := database.GetDbConnection()
Expand Down

0 comments on commit 59934b6

Please sign in to comment.