Skip to content

Commit

Permalink
fix(api): allow device listing without tenant
Browse files Browse the repository at this point in the history
  • Loading branch information
heiytor committed Jun 26, 2024
1 parent 6c93607 commit 4a2ee07
Show file tree
Hide file tree
Showing 6 changed files with 456 additions and 1,130 deletions.
38 changes: 11 additions & 27 deletions api/routes/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strconv"

"github.com/shellhub-io/shellhub/api/pkg/gateway"
"github.com/shellhub-io/shellhub/pkg/api/query"
"github.com/shellhub-io/shellhub/pkg/api/requests"
"github.com/shellhub-io/shellhub/pkg/models"
)
Expand All @@ -32,46 +31,31 @@ const (
)

func (h *Handler) GetDeviceList(c gateway.Context) error {
type Query struct {
Status models.DeviceStatus `query:"status"`
query.Paginator
query.Sorter
query.Filters
}

query := Query{}
req := new(requests.DeviceList)

if err := c.Bind(&query); err != nil {
if err := c.Bind(req); err != nil {
return err
}

query.Paginator.Normalize()
query.Sorter.Normalize()
req.Paginator.Normalize()
req.Sorter.Normalize()

if err := query.Filters.Unmarshal(); err != nil {
if err := req.Filters.Unmarshal(); err != nil {
return err
}

var tenant string
if c.Tenant() != nil {
tenant = c.Tenant().ID
if err := c.Validate(req); err != nil {
return err
}

devices, count, err := h.service.ListDevices(
c.Ctx(),
tenant,
query.Status,
query.Paginator,
query.Filters,
query.Sorter,
)
res, count, err := h.service.ListDevices(c.Ctx(), req)
c.Response().Header().Set("X-Total-Count", strconv.Itoa(count))

if err != nil {
return err
}

c.Response().Header().Set("X-Total-Count", strconv.Itoa(count))

return c.JSON(http.StatusOK, devices)
return c.JSON(http.StatusOK, res)
}

func (h *Handler) GetDevice(c gateway.Context) error {
Expand Down
139 changes: 49 additions & 90 deletions api/routes/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"

Expand All @@ -17,6 +19,7 @@ import (
"github.com/shellhub-io/shellhub/pkg/models"
"github.com/stretchr/testify/assert"
gomock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

func TestGetDevice(t *testing.T) {
Expand Down Expand Up @@ -281,130 +284,86 @@ func TestGetDeviceList(t *testing.T) {
mock := new(mocks.Service)

type Expected struct {
session []models.Device
devices []models.Device
status int
}

cases := []struct {
description string
paginator query.Paginator
sorter query.Sorter
filters query.Filters
status models.DeviceStatus
tenant string
requiredMocks func(status models.DeviceStatus, paginator query.Paginator, filters query.Filters, sorter query.Sorter)
req *requests.DeviceList
requiredMocks func()
expected Expected
}{
{
description: "fails when try to get a device list existing",
tenant: "tenant-id",
status: models.DeviceStatus("online"),
paginator: query.Paginator{Page: 1, PerPage: 10},
sorter: query.Sorter{By: "name", Order: query.OrderAsc},
filters: query.Filters{
Raw: "Wwp7CiAgInR5cGUiOiAicHJvcGVydHkiLAogICJwYXJhbXMiOiB7CiAgICAibmFtZSI6ICJuYW1lIiwKICAgICJvcGVyYXRvciI6ICJjb250YWlucyIsCiAgICAidmFsdWUiOiAiZXhhbXBsZXNwYWNlIgogIH0KfQpd",
Data: []query.Filter{
{
Type: "property",
Params: &query.FilterProperty{
Name: "name",
Operator: "contains",
Value: "examplespace",
},
},
},
},
requiredMocks: func(status models.DeviceStatus, paginator query.Paginator, filters query.Filters, sorter query.Sorter) {
mock.On("ListDevices",
gomock.Anything,
"tenant-id",
status,
paginator,
filters,
sorter,
).Return(nil, 0, svc.ErrDeviceNotFound).Once()
req: &requests.DeviceList{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceStatus: models.DeviceStatus("online"),
Paginator: query.Paginator{Page: 1, PerPage: 10},
Sorter: query.Sorter{By: "name", Order: "asc"},
Filters: query.Filters{},
},
requiredMocks: func() {
mock.
On("ListDevices", gomock.Anything, gomock.AnythingOfType("*requests.DeviceList")).
Return(nil, 0, svc.ErrDeviceNotFound).
Once()
},
expected: Expected{
session: nil,
devices: []models.Device{},
status: http.StatusNotFound,
},
},
{
description: "fails when try to get a device list existing",
tenant: "tenant-id",
status: models.DeviceStatus("online"),
paginator: query.Paginator{Page: 1, PerPage: 10},
sorter: query.Sorter{By: "name", Order: query.OrderAsc},
filters: query.Filters{
Raw: "Wwp7CiAgInR5cGUiOiAicHJvcGVydHkiLAogICJwYXJhbXMiOiB7CiAgICAibmFtZSI6ICJuYW1lIiwKICAgICJvcGVyYXRvciI6ICJjb250YWlucyIsCiAgICAidmFsdWUiOiAiZXhhbXBsZXNwYWNlIgogIH0KfQpd",
Data: []query.Filter{
{
Type: "property",
Params: &query.FilterProperty{
Name: "name",
Operator: "contains",
Value: "examplespace",
},
},
},
},
requiredMocks: func(status models.DeviceStatus, paginator query.Paginator, filters query.Filters, sorter query.Sorter) {
mock.On("ListDevices",
gomock.Anything,
"tenant-id",
status,
paginator,
filters,
sorter,
).Return([]models.Device{}, 1, nil).Once()
req: &requests.DeviceList{
TenantID: "00000000-0000-4000-0000-000000000000",
DeviceStatus: models.DeviceStatus("online"),
Paginator: query.Paginator{Page: 1, PerPage: 10},
Sorter: query.Sorter{By: "name", Order: "asc"},
Filters: query.Filters{},
},
requiredMocks: func() {
mock.
On("ListDevices", gomock.Anything, gomock.AnythingOfType("*requests.DeviceList")).
Return([]models.Device{}, 0, nil).
Once()
},
expected: Expected{
session: []models.Device{},
devices: []models.Device{},
status: http.StatusOK,
},
},
}

for _, tc := range cases {
t.Run(tc.description, func(t *testing.T) {
tc.requiredMocks(tc.status, tc.paginator, tc.filters, tc.sorter)

type Query struct {
Status models.DeviceStatus `query:"status"`
query.Paginator
query.Sorter
query.Filters
}

b := Query{
Status: tc.status,
Paginator: tc.paginator,
Sorter: tc.sorter,
Filters: tc.filters,
}
tc.requiredMocks()

jsonData, err := json.Marshal(b)
if err != nil {
assert.NoError(t, err)
}
urlVal := &url.Values{}
urlVal.Set("page", strconv.Itoa(tc.req.Page))
urlVal.Set("per_page", strconv.Itoa(tc.req.PerPage))
urlVal.Set("sort_by", tc.req.By)
urlVal.Set("order_by", tc.req.Order)
urlVal.Set("status", string(tc.req.DeviceStatus))

req := httptest.NewRequest(http.MethodGet, "/api/devices", strings.NewReader(string(jsonData)))
req.Header.Set("Content-Type", "application/json")
req := httptest.NewRequest(http.MethodGet, "/api/devices?"+urlVal.Encode(), nil)
req.Header.Set("X-Role", authorizer.RoleOwner.String())
req.Header.Set("X-Tenant-ID", tc.tenant)
rec := httptest.NewRecorder()
req.Header.Set("X-Tenant-ID", tc.req.TenantID)

rec := httptest.NewRecorder()
e := NewRouter(mock)
e.ServeHTTP(rec, req)

assert.Equal(t, tc.expected.status, rec.Result().StatusCode)

var session []models.Device
if err := json.NewDecoder(rec.Result().Body).Decode(&session); err != nil {
assert.ErrorIs(t, io.EOF, err)
devices := make([]models.Device, 0)
if len(tc.expected.devices) != 0 {
if err := json.NewDecoder(rec.Result().Body).Decode(&devices); err != nil {
require.ErrorIs(t, io.EOF, err)
}
}

assert.Equal(t, tc.expected.session, session)
require.Equal(t, tc.expected.status, rec.Result().StatusCode)
require.Equal(t, tc.expected.devices, devices)
})
}
}
Expand Down
53 changes: 29 additions & 24 deletions api/services/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"github.com/shellhub-io/shellhub/api/store"
req "github.com/shellhub-io/shellhub/pkg/api/internalclient"
"github.com/shellhub-io/shellhub/pkg/api/query"
"github.com/shellhub-io/shellhub/pkg/api/requests"
"github.com/shellhub-io/shellhub/pkg/envs"
"github.com/shellhub-io/shellhub/pkg/models"
"github.com/shellhub-io/shellhub/pkg/validator"
Expand All @@ -18,7 +18,7 @@ import (
const StatusAccepted = "accepted"

type DeviceService interface {
ListDevices(ctx context.Context, tenant string, status models.DeviceStatus, paginator query.Paginator, filter query.Filters, sorter query.Sorter) ([]models.Device, int, error)
ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error)
GetDevice(ctx context.Context, uid models.UID) (*models.Device, error)
GetDeviceByPublicURLAddress(ctx context.Context, address string) (*models.Device, error)
DeleteDevice(ctx context.Context, uid models.UID, tenant string) error
Expand All @@ -29,14 +29,10 @@ type DeviceService interface {
UpdateDevice(ctx context.Context, tenant string, uid models.UID, name *string, publicURL *bool) error
}

func (s *service) ListDevices(ctx context.Context, tenant string, status models.DeviceStatus, paginator query.Paginator, filter query.Filters, sorter query.Sorter) ([]models.Device, int, error) {
ns, err := s.store.NamespaceGet(ctx, tenant, true)
if err != nil {
return nil, 0, NewErrNamespaceNotFound(tenant, err)
}

if status == models.DeviceStatusRemoved {
removed, count, err := s.store.DeviceRemovedList(ctx, tenant, paginator, filter, sorter)
func (s *service) ListDevices(ctx context.Context, req *requests.DeviceList) ([]models.Device, int, error) {
if req.DeviceStatus == models.DeviceStatusRemoved {
// TODO: unique DeviceList
removed, count, err := s.store.DeviceRemovedList(ctx, req.TenantID, req.Paginator, req.Filters, req.Sorter)
if err != nil {
return nil, 0, err
}
Expand All @@ -49,25 +45,34 @@ func (s *service) ListDevices(ctx context.Context, tenant string, status models.
return devices, count, nil
}

if ns.HasMaxDevices() {
switch {
case envs.IsCloud():
removed, err := s.store.DeviceRemovedCount(ctx, ns.TenantID)
if err != nil {
return nil, 0, NewErrDeviceRemovedCount(err)
}
if req.TenantID != "" {
ns, err := s.store.NamespaceGet(ctx, req.TenantID, true)
if err != nil {
return nil, 0, NewErrNamespaceNotFound(req.TenantID, err)
}

if ns.HasLimitDevicesReached(removed) {
return s.store.DeviceList(ctx, status, paginator, filter, sorter, store.DeviceAcceptableFromRemoved)
}
case envs.IsCommunity(), envs.IsEnterprise():
if ns.HasMaxDevicesReached() {
return s.store.DeviceList(ctx, status, paginator, filter, sorter, store.DeviceAcceptableAsFalse)
if ns.HasMaxDevices() {
switch {
case envs.IsCloud():
removed, err := s.store.DeviceRemovedCount(ctx, ns.TenantID)
if err != nil {
return nil, 0, NewErrDeviceRemovedCount(err)
}

if ns.HasLimitDevicesReached(removed) {
return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, store.DeviceAcceptableFromRemoved)
}
case envs.IsEnterprise():
fallthrough
case envs.IsCommunity():
if ns.HasMaxDevicesReached() {
return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, store.DeviceAcceptableAsFalse)
}
}
}
}

return s.store.DeviceList(ctx, status, paginator, filter, sorter, store.DeviceAcceptableIfNotAccepted)
return s.store.DeviceList(ctx, req.DeviceStatus, req.Paginator, req.Filters, req.Sorter, store.DeviceAcceptableIfNotAccepted)
}

func (s *service) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) {
Expand Down
Loading

0 comments on commit 4a2ee07

Please sign in to comment.