Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

35 feature add key auth #38

Merged
merged 13 commits into from
Oct 13, 2024
6 changes: 6 additions & 0 deletions src/api/.air.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,17 @@ tmp_dir = "tmp"

[log]
main_only = false
silent = false
time = false

[misc]
clean_on_exit = false

[proxy]
app_port = 0
enabled = false
proxy_port = 0

[screen]
clear_on_rebuild = false
keep_scroll = true
17 changes: 4 additions & 13 deletions src/api/handler/count_holidays.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package handler

import (
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/database"
"github.com/kynmh69/go-ja-holidays/logging"
"github.com/kynmh69/go-ja-holidays/model"
"net/http"
"time"

"github.com/doug-martin/goqu/v9"
"github.com/kynmh69/go-ja-holidays/database"
)

type CountStruct struct {
Expand All @@ -18,12 +15,6 @@ type CountStruct struct {
func CountHolidays(c *gin.Context) {
logger := logging.GetLogger()
var request HolidaysRequest
if location, err := time.LoadLocation(LOCATION); err != nil {
BadRequestJson(c, err.Error())
return
} else {
goqu.SetTimeLocation(location)
}

if err := c.ShouldBindQuery(&request); err != nil {
logger.Error(err)
Expand All @@ -35,11 +26,11 @@ func CountHolidays(c *gin.Context) {

dataSet := db.Model(&model.HolidayData{})
if !request.StartDay.IsZero() && !request.EndDay.IsZero() {
dataSet = dataSet.Where("created_at BETWEEN ? AND ?", request.StartDay, request.EndDay)
dataSet = dataSet.Where("date BETWEEN ? AND ?", request.StartDay, request.EndDay)
} else if !request.StartDay.IsZero() {
dataSet = dataSet.Where("created_at >= ?", request.StartDay)
dataSet = dataSet.Where("date >= ?", request.StartDay)
} else if !request.EndDay.IsZero() {
dataSet = dataSet.Where("created_at <= ?", request.EndDay)
dataSet = dataSet.Where("date <= ?", request.EndDay)
}
var count int64
err := dataSet.Count(&count).Error
Expand Down
17 changes: 12 additions & 5 deletions src/api/handler/get_holidays.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package handler
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/logging"
"net/http"
"time"

Expand All @@ -12,17 +13,16 @@ import (
)

const (
ColumnDate = "holiday_date"
LOCATION = "Asia/Tokyo"
LOCATION = "Asia/Tokyo"
)

type HolidaysRequest struct {
StartDay time.Time `form:"start_day" time_format:"2006-01-02"`
EndDay time.Time `form:"end_day" time_format:"2006-01-02"`
}

func (receiver HolidaysRequest) String() string {
return fmt.Sprintf("StartDay: \"%s\", EndDay: \"%s\"", receiver.StartDay, receiver.EndDay)
func (h *HolidaysRequest) String() string {
return fmt.Sprintf("StartDay: \"%s\", EndDay: \"%s\"", h.StartDay, h.EndDay)
}

func GetHolidays(c *gin.Context) {
Expand All @@ -32,6 +32,8 @@ func GetHolidays(c *gin.Context) {
)
// DB接続
db := database.GetDbConnection()
// get logger
logger := logging.GetLogger()
// タイムゾーン設定
if location, err := time.LoadLocation(LOCATION); err != nil {
BadRequestJson(c, err.Error())
Expand All @@ -46,13 +48,18 @@ func GetHolidays(c *gin.Context) {
return
}
// リクエストパラメータから開始日と終了日を取得
dataSet := db.Model(&model.HolidayData{}).Order(goqu.C(ColumnDate).Asc())
dataSet := db.Model(&model.HolidayData{})
if !reqParams.StartDay.IsZero() && !reqParams.EndDay.IsZero() {
logger.Debugf("StartDay: %s, EndDay: %s", reqParams.StartDay, reqParams.EndDay)
dataSet = dataSet.Where("date between ? and ?", reqParams.StartDay, reqParams.EndDay)
} else if !reqParams.StartDay.IsZero() {
logger.Debugf("StartDay: %s", reqParams.StartDay)
dataSet = dataSet.Where("date >= ?", reqParams.StartDay)
} else if !reqParams.EndDay.IsZero() {
logger.Debugf("EndDay: %s", reqParams.EndDay)
dataSet = dataSet.Where("date <= ?", reqParams.EndDay)
} else {
logger.Debug("No request parameters")
}
// データ取得
if err := dataSet.Find(&holidays).Error; err != nil {
Expand Down
9 changes: 2 additions & 7 deletions src/api/handler/is_holiday.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"time"

"github.com/doug-martin/goqu/v9"
"github.com/gin-gonic/gin"

"github.com/kynmh69/go-ja-holidays/database"
Expand All @@ -32,15 +31,11 @@ func IsHoliday(c *gin.Context) {
BadRequestJson(c, err.Error())
return
}
logger.Debug(request.Date)

// Set the time zone to JST.
loc := request.Date.Location()
goqu.SetTimeLocation(loc)
logger.Debug("requested date ", request.Date)

// Get the holiday data for the specified day.
holiday.Date = request.Date
result := db.First(&holiday)
result := db.Where(&holiday).Take(&holiday)
err := result.Error

var isHoliday model.IsHoliday
Expand Down
2 changes: 2 additions & 0 deletions src/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package router
import (
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/api/handler"
"github.com/kynmh69/go-ja-holidays/middleware"
)

func MakeRoute(r *gin.Engine) {
r.Use(middleware.Auth)
r.GET("/holidays", handler.GetHolidays)
r.GET("/holidays/:date", handler.IsHoliday)
r.GET("/holidays/count", handler.CountHolidays)
Expand Down
26 changes: 26 additions & 0 deletions src/middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package middleware

import (
"github.com/gin-gonic/gin"
"github.com/kynmh69/go-ja-holidays/logging"
"github.com/kynmh69/go-ja-holidays/model"
"net/http"
)

func Auth(c *gin.Context) {
logger := logging.GetLogger()
// get api key from request header
key := c.GetHeader("X-API-KEY")
if key == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "API key is required."})
return
}
apiKey, err := model.GetApiKey(key)
if err != nil {
logger.Warnln("API key is invalid. ", key)
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "API key is invalid."})
return
}
logger.Debugln("API key is valid. ", apiKey.Id)
c.Next()
}
79 changes: 79 additions & 0 deletions src/middleware/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package middleware

import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/kynmh69/go-ja-holidays/database"
"github.com/kynmh69/go-ja-holidays/logging"
"github.com/kynmh69/go-ja-holidays/model"
"gorm.io/gorm"
"net/http"
"net/http/httptest"
"os"
"testing"
)

func TestMain(m *testing.M) {
_ = os.Setenv("PSQL_HOSTNAME", "localhost")
_ = os.Setenv("DATABASE", "unittest")
logging.LoggerInitialize()
database.ConnectDatabase()

defer tearDown()

db := database.GetDbConnection()
if err := db.AutoMigrate(
&model.ApiKey{},
&model.HolidayData{},
); err != nil {
logging.GetLogger().Panicln(err)
}
if code := m.Run(); code > 0 {
logging.GetLogger().Panicln("Test failed with code ", code)
}
logging.GetLogger().Infoln("Test passed")
}

func TestAuth(t *testing.T) {
r := gin.Default()
u, _ := uuid.NewUUID()
var apiKey model.ApiKey
apiKeyError := model.ApiKey{Key: u}
db := database.GetDbConnection()
db.Create(&apiKey)
tests := []struct {
name string
apiKey model.ApiKey
statusCode int
}{
{
name: "Test Auth",
apiKey: apiKey,
statusCode: http.StatusOK,
},
{
name: "Test Auth Error",
apiKey: apiKeyError,
statusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-KEY", tt.apiKey.Key.String())
w := httptest.NewRecorder()
ctx := gin.CreateTestContextOnly(w, r)
ctx.Request = req
Auth(ctx)
if tt.statusCode != w.Code {
t.Errorf("expected %d but got %d", tt.statusCode, w.Code)
}
})
}
}

func tearDown() {
db := database.GetDbConnection()
db.Session(&gorm.Session{AllowGlobalUpdate: true}).
Delete(&model.ApiKey{})
}
11 changes: 11 additions & 0 deletions src/model/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
return apiKeys, err
}

func GetApiKey(apiKeyStr string) (ApiKey, error) {
logger := logging.GetLogger()
var apiKey ApiKey
db := database.GetDbConnection()
err := db.Where("key = ?", apiKeyStr).First(&apiKey).Error
if err != nil {
logger.Debug("API key is invalid. ", apiKeyStr)
}
return apiKey, err
}

func CreateApiKey() error {
logger := logging.GetLogger()
db := database.GetDbConnection()
Expand Down
Loading