Skip to content

Commit

Permalink
chore: in-progress add sqlite source impl
Browse files Browse the repository at this point in the history
  • Loading branch information
jirevwe committed Nov 18, 2024
1 parent 0ec67a9 commit ad3563f
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 243 deletions.
46 changes: 46 additions & 0 deletions database/sqlite3/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,49 @@ type EndpointSecret struct {
Endpoint datastore.Endpoint `json:"endpoint"`
Secret datastore.Secret `db:"secret"`
}

type dbEndpoint struct {
UID string `db:"id"`
Name string `db:"name"`
Status datastore.EndpointStatus `db:"status"`
OwnerID string `db:"owner_id"`
Url string `db:"url"`
Description string `db:"description"`
HttpTimeout uint64 `db:"http_timeout"`
RateLimit int `db:"rate_limit"`
RateLimitDuration uint64 `db:"rate_limit_duration"`
AdvancedSignatures bool `db:"advanced_signatures"`
SlackWebhookURL string `db:"slack_webhook_url"`
SupportEmail string `db:"support_email"`
AppID string `db:"app_id"`
ProjectID string `db:"project_id"`
Secrets datastore.Secrets `db:"secrets"`
Authentication *datastore.EndpointAuthentication `db:"authentication"`
CreatedAt string `db:"created_at"`
UpdatedAt string `db:"updated_at"`
DeletedAt *string `db:"deleted_at"`
}

func (e *dbEndpoint) toDatastoreEndpoint() *datastore.Endpoint {
return &datastore.Endpoint{
UID: e.UID,
Name: e.Name,
Status: e.Status,
OwnerID: e.OwnerID,
Url: e.Url,
Description: e.Description,
HttpTimeout: e.HttpTimeout,
RateLimit: e.RateLimit,
RateLimitDuration: e.RateLimitDuration,
AdvancedSignatures: e.AdvancedSignatures,
SlackWebhookURL: e.SlackWebhookURL,
SupportEmail: e.SupportEmail,
AppID: e.AppID,
ProjectID: e.ProjectID,
Secrets: e.Secrets,
Authentication: e.Authentication,
CreatedAt: asTime(e.CreatedAt),
UpdatedAt: asTime(e.UpdatedAt),
DeletedAt: asNullTime(e.DeletedAt),
}
}
154 changes: 107 additions & 47 deletions database/sqlite3/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"github.com/lib/pq"
"time"

"github.com/oklog/ulid/v2"

Expand Down Expand Up @@ -33,6 +34,7 @@ const (

updateSourceById = `
UPDATE sources SET
updated_at = $1,
name= $2,
type=$3,
mask_id=$4,
Expand All @@ -45,13 +47,13 @@ const (
custom_response_content_type = $11,
idempotency_keys = $12,
body_function = $13,
header_function = $14,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL ;
header_function = $14
WHERE id = $15 AND deleted_at IS NULL ;
`

updateSourceVerifierById = `
UPDATE source_verifiers SET
updated_at = $1,
type=$2,
basic_username=$3,
basic_password=$4,
Expand All @@ -60,9 +62,8 @@ const (
hmac_hash=$7,
hmac_header=$8,
hmac_secret=$9,
hmac_encoding=$10,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL;
hmac_encoding=$10
WHERE id = $11 AND deleted_at IS NULL;
`

baseFetchSource = `
Expand Down Expand Up @@ -123,26 +124,26 @@ const (

deleteSource = `
UPDATE sources SET
deleted_at = NOW()
WHERE id = $1 AND project_id = $2 AND deleted_at IS NULL;
deleted_at = $1
WHERE id = $2 AND project_id = $3 AND deleted_at IS NULL;
`

deleteSourceVerifier = `
UPDATE source_verifiers SET
deleted_at = NOW()
WHERE id = $1 AND deleted_at IS NULL;
deleted_at = $1
WHERE id = $2 AND deleted_at IS NULL;
`

deleteSourceSubscription = `
UPDATE subscriptions SET
deleted_at = NOW()
WHERE source_id = $1 AND project_id = $2 AND deleted_at IS NULL;
deleted_at = $1
WHERE source_id = $2 AND project_id = $3 AND deleted_at IS NULL;
`

fetchSourcesPagedFilter = `
AND (s.type = :type OR :type = '')
AND (s.provider = :provider OR :provider = '')
AND s.name ILIKE :query
AND s.name LIKE :query
AND s.project_id = :project_id
`

Expand Down Expand Up @@ -197,7 +198,6 @@ func NewSourceRepo(db database.Database) datastore.SourceRepository {
}

func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source) error {
var sourceVerifierID *string
tx, err := s.db.BeginTxx(ctx, &sql.TxOptions{})
if err != nil {
return err
Expand All @@ -221,10 +221,9 @@ func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source)

if !util.IsStringEmpty(string(source.Verifier.Type)) {
id := ulid.Make().String()
sourceVerifierID = &id

result2, err := tx.ExecContext(
ctx, createSourceVerifier, sourceVerifierID, source.Verifier.Type, basic.UserName, basic.Password,
ctx, createSourceVerifier, id, source.Verifier.Type, basic.UserName, basic.Password,
apiKey.HeaderName, apiKey.HeaderValue, hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding,
)
if err != nil {
Expand All @@ -239,14 +238,12 @@ func (s *sourceRepo) CreateSource(ctx context.Context, source *datastore.Source)
if rowsAffected < 1 {
return ErrSourceVerifierNotCreated
}
}

if !util.IsStringEmpty(string(source.Verifier.Type)) {
source.VerifierID = *sourceVerifierID
source.VerifierID = id
}

result1, err := tx.ExecContext(
ctx, createSource, source.UID, sourceVerifierID, source.Name, source.Type, source.MaskID,
ctx, createSource, source.UID, source.VerifierID, source.Name, source.Type, source.MaskID,
source.Provider, source.IsDisabled, pq.Array(source.ForwardHeaders), source.ProjectID,
source.PubSub, source.CustomResponse.Body, source.CustomResponse.ContentType,
source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction,
Expand Down Expand Up @@ -280,10 +277,10 @@ func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source
defer rollbackTx(tx)

result, err := tx.ExecContext(
ctx, updateSourceById, source.UID, source.Name, source.Type, source.MaskID,
ctx, updateSourceById, time.Now(), source.Name, source.Type, source.MaskID,
source.Provider, source.IsDisabled, source.ForwardHeaders, projectID,
source.PubSub, source.CustomResponse.Body, source.CustomResponse.ContentType,
source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction,
source.IdempotencyKeys, source.BodyFunction, source.HeaderFunction, source.UID,
)
if err != nil {
return err
Expand Down Expand Up @@ -314,8 +311,9 @@ func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source

if !util.IsStringEmpty(string(source.Verifier.Type)) {
result2, err := tx.ExecContext(
ctx, updateSourceVerifierById, source.VerifierID, source.Verifier.Type, basic.UserName, basic.Password,
apiKey.HeaderName, apiKey.HeaderValue, hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding,
ctx, updateSourceVerifierById, time.Now(), source.Verifier.Type,
basic.UserName, basic.Password, apiKey.HeaderName, apiKey.HeaderValue,
hmac.Hash, hmac.Header, hmac.Secret, hmac.Encoding, source.VerifierID,
)
if err != nil {
return err
Expand All @@ -340,7 +338,7 @@ func (s *sourceRepo) UpdateSource(ctx context.Context, projectID string, source
}

func (s *sourceRepo) FindSourceByID(ctx context.Context, projectId string, id string) (*datastore.Source, error) {
source := &datastore.Source{}
source := &dbSource{}
err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSource, "s.id"), id).StructScan(source)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -349,11 +347,11 @@ func (s *sourceRepo) FindSourceByID(ctx context.Context, projectId string, id st
return nil, err
}

return source, nil
return source.toDatastoreSource(), nil
}

func (s *sourceRepo) FindSourceByName(ctx context.Context, projectID string, name string) (*datastore.Source, error) {
source := &datastore.Source{}
source := &dbSource{}
err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSourceByName, "s.project_id", "s.name"), projectID, name).StructScan(source)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -362,11 +360,11 @@ func (s *sourceRepo) FindSourceByName(ctx context.Context, projectID string, nam
return nil, err
}

return source, nil
return source.toDatastoreSource(), nil
}

func (s *sourceRepo) FindSourceByMaskID(ctx context.Context, maskID string) (*datastore.Source, error) {
source := &datastore.Source{}
source := &dbSource{}
err := s.db.QueryRowxContext(ctx, fmt.Sprintf(fetchSource, "s.mask_id"), maskID).StructScan(source)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -375,7 +373,7 @@ func (s *sourceRepo) FindSourceByMaskID(ctx context.Context, maskID string) (*da
return nil, err
}

return source, nil
return source.toDatastoreSource(), nil
}

func (s *sourceRepo) DeleteSourceByID(ctx context.Context, projectId string, id, sourceVerifierId string) error {
Expand All @@ -385,17 +383,17 @@ func (s *sourceRepo) DeleteSourceByID(ctx context.Context, projectId string, id,
}
defer rollbackTx(tx)

_, err = tx.ExecContext(ctx, deleteSourceVerifier, sourceVerifierId)
_, err = tx.ExecContext(ctx, deleteSourceVerifier, time.Now(), sourceVerifierId)
if err != nil {
return err
}

_, err = tx.ExecContext(ctx, deleteSource, id, projectId)
_, err = tx.ExecContext(ctx, deleteSource, time.Now(), id, projectId)
if err != nil {
return err
}

_, err = tx.ExecContext(ctx, deleteSourceSubscription, id, projectId)
_, err = tx.ExecContext(ctx, deleteSourceSubscription, time.Now(), id, projectId)
if err != nil {
return err
}
Expand Down Expand Up @@ -447,16 +445,16 @@ func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, fil

sources := make([]datastore.Source, 0)
for rows.Next() {
var source datastore.Source
source := dbSource{}
err = rows.StructScan(&source)
if err != nil {
return nil, datastore.PaginationData{}, err
}

sources = append(sources, source)
sources = append(sources, *source.toDatastoreSource())
}

var count datastore.PrevRowCount
var rowCount datastore.PrevRowCount
if len(sources) > 0 {
var countQuery string
var qargs []interface{}
Expand All @@ -473,16 +471,16 @@ func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, fil
countQuery = s.db.Rebind(countQuery)

// count the row number before the first row
rows, err := s.db.QueryxContext(ctx, countQuery, qargs...)
if err != nil {
return nil, datastore.PaginationData{}, err
resRows, innerErr := s.db.QueryxContext(ctx, countQuery, qargs...)
if innerErr != nil {
return nil, datastore.PaginationData{}, innerErr
}
defer closeWithError(rows)
defer closeWithError(resRows)

if rows.Next() {
err = rows.StructScan(&count)
if err != nil {
return nil, datastore.PaginationData{}, err
if resRows.Next() {
innerErr = resRows.StructScan(&rowCount)
if innerErr != nil {
return nil, datastore.PaginationData{}, innerErr
}
}
}
Expand All @@ -496,7 +494,7 @@ func (s *sourceRepo) LoadSourcesPaged(ctx context.Context, projectID string, fil
sources = sources[:len(sources)-1]
}

pagination := &datastore.PaginationData{PrevRowCount: count}
pagination := &datastore.PaginationData{PrevRowCount: rowCount}
pagination = pagination.Build(pageable, ids)

return sources, *pagination, nil
Expand Down Expand Up @@ -531,13 +529,13 @@ func (s *sourceRepo) LoadPubSubSourcesByProjectIDs(ctx context.Context, projectI

sources := make([]datastore.Source, 0)
for rows.Next() {
var source datastore.Source
source := dbSource{}
err = rows.StructScan(&source)
if err != nil {
return nil, datastore.PaginationData{}, err
}

sources = append(sources, source)
sources = append(sources, *source.toDatastoreSource())
}

// Bypass pagination.Build here since we're only dealing with forward paging here
Expand All @@ -557,3 +555,65 @@ func (s *sourceRepo) LoadPubSubSourcesByProjectIDs(ctx context.Context, projectI

return sources, *pagination, nil
}

type dbSource struct {
UID string `db:"id"`
Name string `db:"name"`
Type string `db:"type"`
Provider string `db:"provider"`
MaskID string `db:"mask_id"`
ProjectID string `db:"project_id"`
IsDisabled bool `db:"is_disabled"`
ForwardHeaders *string `db:"forward_headers"`
PubSub *datastore.PubSubConfig `db:"pub_sub"`
VerifierID string `db:"source_verifier_id"`
Verifier *datastore.VerifierConfig `db:"verifier"`
CustomResponse datastore.CustomResponse `db:"custom_response"`
IdempotencyKeys *string `db:"idempotency_keys"`
BodyFunction *string `db:"body_function"`
HeaderFunction *string `db:"header_function"`
CreatedAt string `db:"created_at"`
UpdatedAt string `db:"updated_at"`
DeletedAt *string `db:"deleted_at"`
}

func (s *dbSource) toDatastoreSource() *datastore.Source {
return &datastore.Source{
UID: s.UID,
Name: s.Name,
Type: datastore.SourceType(s.Type),
Provider: datastore.SourceProvider(s.Provider),
MaskID: s.MaskID,
ProjectID: s.ProjectID,
IsDisabled: s.IsDisabled,
ForwardHeaders: asStringArray(s.ForwardHeaders),
PubSub: s.PubSub,
VerifierID: s.VerifierID,
Verifier: s.Verifier,
CustomResponse: s.CustomResponse,
IdempotencyKeys: asStringArray(s.IdempotencyKeys),
BodyFunction: s.BodyFunction,
HeaderFunction: s.HeaderFunction,
CreatedAt: asTime(s.CreatedAt),
UpdatedAt: asTime(s.UpdatedAt),
DeletedAt: asNullTime(s.DeletedAt),
}
}

func scanSources(rows *sqlx.Rows) ([]datastore.Source, error) {
sources := make([]datastore.Source, 0)
var err error
defer closeWithError(rows)

for rows.Next() {
source := dbSource{}
err = rows.StructScan(&source)
if err != nil {
return nil, err
}

sources = append(sources, *source.toDatastoreSource())
}

return sources, nil
}
Loading

0 comments on commit ad3563f

Please sign in to comment.