diff --git a/.changelog/1090.txt b/.changelog/1090.txt new file mode 100644 index 00000000000..10d1d303fc3 --- /dev/null +++ b/.changelog/1090.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +devices_policy: Add support for additional device settings policies +``` diff --git a/devices_policy.go b/devices_policy.go index ae2d3cfd73e..96d542125e8 100644 --- a/devices_policy.go +++ b/devices_policy.go @@ -17,6 +17,59 @@ type DeviceClientCertificatesZone struct { Result Enabled } +type ServiceModeV2 struct { + Mode string `json:"mode,omitempty"` + Port int `json:"port,omitempty"` +} + +type DeviceSettingsPolicy struct { + ServiceModeV2 *ServiceModeV2 `json:"service_mode_v2"` + DisableAutoFallback *bool `json:"disable_auto_fallback"` + FallbackDomains *[]FallbackDomain `json:"fallback_domains"` + Include *[]SplitTunnel `json:"include"` + Exclude *[]SplitTunnel `json:"exclude"` + GatewayUniqueID *string `json:"gateway_unique_id"` + SupportURL *string `json:"support_url"` + CaptivePortal *int `json:"captive_portal"` + AllowModeSwitch *bool `json:"allow_mode_switch"` + SwitchLocked *bool `json:"switch_locked"` + AllowUpdates *bool `json:"allow_updates"` + AutoConnect *int `json:"auto_connect"` + AllowedToLeave *bool `json:"allowed_to_leave"` + PolicyID *string `json:"policy_id"` + Enabled *bool `json:"enabled"` + Name *string `json:"name"` + Match *string `json:"match"` + Precedence *int `json:"precedence"` + Default bool `json:"default"` +} + +type DeviceSettingsPolicyResponse struct { + Response + Result DeviceSettingsPolicy +} + +type DeleteDeviceSettingsPolicyResponse struct { + Response + Result []DeviceSettingsPolicy +} + +type DeviceSettingsPolicyRequest struct { + DisableAutoFallback *bool `json:"disable_auto_fallback,omitempty"` + CaptivePortal *int `json:"captive_portal,omitempty"` + AllowModeSwitch *bool `json:"allow_mode_switch,omitempty"` + SwitchLocked *bool `json:"switch_locked,omitempty"` + AllowUpdates *bool `json:"allow_updates,omitempty"` + AutoConnect *int `json:"auto_connect,omitempty"` + AllowedToLeave *bool `json:"allowed_to_leave,omitempty"` + SupportURL *string `json:"support_url,omitempty"` + ServiceModeV2 *ServiceModeV2 `json:"service_mode_v2,omitempty"` + Precedence *int `json:"precedence,omitempty"` + Name *string `json:"name,omitempty"` + Match *string `json:"match,omitempty"` + Enabled *bool `json:"enabled,omitempty"` +} + // UpdateDeviceClientCertificates controls the zero trust zone used to provision client certificates. // // API reference: https://api.cloudflare.com/#device-client-certificates @@ -54,3 +107,117 @@ func (api *API) GetDeviceClientCertificatesZone(ctx context.Context, zoneID stri return result, err } + +// CreateDeviceSettingsPolicy creates a settings policy against devices that match the policy +// +// API reference: https://api.cloudflare.com/#devices-create-device-settings-policy +func (api *API) CreateDeviceSettingsPolicy(ctx context.Context, accountID string, req DeviceSettingsPolicyRequest) (DeviceSettingsPolicyResponse, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy", AccountRouteRoot, accountID) + + result := DeviceSettingsPolicyResponse{} + res, err := api.makeRequestContext(ctx, http.MethodPost, uri, req) + if err != nil { + return result, err + } + + if err := json.Unmarshal(res, &result); err != nil { + return result, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return result, err +} + +// UpdateDefaultDeviceSettingsPolicy updates the default settings policy for an account +// +// API reference: https://api.cloudflare.com/#devices-update-default-device-settings-policy +func (api *API) UpdateDefaultDeviceSettingsPolicy(ctx context.Context, accountID string, req DeviceSettingsPolicyRequest) (DeviceSettingsPolicyResponse, error) { + result := DeviceSettingsPolicyResponse{} + uri := fmt.Sprintf("/%s/%s/devices/policy", AccountRouteRoot, accountID) + res, err := api.makeRequestContext(ctx, http.MethodPatch, uri, req) + if err != nil { + return result, err + } + + if err := json.Unmarshal(res, &result); err != nil { + return result, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return result, err +} + +// UpdateDeviceSettingsPolicy updates a settings policy +// +// API reference: https://api.cloudflare.com/#devices-update-device-settings-policy +func (api *API) UpdateDeviceSettingsPolicy(ctx context.Context, accountID, policyID string, req DeviceSettingsPolicyRequest) (DeviceSettingsPolicyResponse, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s", AccountRouteRoot, accountID, policyID) + + result := DeviceSettingsPolicyResponse{} + res, err := api.makeRequestContext(ctx, http.MethodPatch, uri, req) + if err != nil { + return result, err + } + + if err := json.Unmarshal(res, &result); err != nil { + return result, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return result, err +} + +// DeleteDeviceSettingsPolicy deletes a settings policy and returns a list +// of all of the other policies in the account +// +// API reference: https://api.cloudflare.com/#devices-delete-device-settings-policy +func (api *API) DeleteDeviceSettingsPolicy(ctx context.Context, accountID, policyID string) (DeleteDeviceSettingsPolicyResponse, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s", AccountRouteRoot, accountID, policyID) + + result := DeleteDeviceSettingsPolicyResponse{} + res, err := api.makeRequestContext(ctx, http.MethodDelete, uri, nil) + if err != nil { + return result, err + } + + if err := json.Unmarshal(res, &result); err != nil { + return result, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return result, err +} + +// GetDefaultDeviceSettings gets the default device settings policy +// +// API reference: https://api.cloudflare.com/#devices-get-default-device-settings-policy +func (api *API) GetDefaultDeviceSettingsPolicy(ctx context.Context, accountID string) (DeviceSettingsPolicyResponse, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy", AccountRouteRoot, accountID) + + result := DeviceSettingsPolicyResponse{} + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return result, err + } + + if err := json.Unmarshal(res, &result); err != nil { + return result, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return result, err +} + +// GetDefaultDeviceSettings gets the device settings policy by its policyID +// +// API reference: https://api.cloudflare.com/#devices-get-device-settings-policy-by-id +func (api *API) GetDeviceSettingsPolicy(ctx context.Context, accountID, policyID string) (DeviceSettingsPolicyResponse, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s", AccountRouteRoot, accountID, policyID) + + result := DeviceSettingsPolicyResponse{} + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return result, err + } + + if err := json.Unmarshal(res, &result); err != nil { + return result, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return result, err +} diff --git a/devices_policy_test.go b/devices_policy_test.go index 22912db8ed1..12bd073549b 100644 --- a/devices_policy_test.go +++ b/devices_policy_test.go @@ -9,6 +9,140 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + deviceSettingsPolicyID = "a842fa8a-a583-482e-9cd9-eb43362949fd" + deviceSettingsPolicyMatch = "identity.email == \"test@example.com\"" + deviceSettingsPolicyPrecedence = 10 + + defaultDeviceSettingsPolicy = DeviceSettingsPolicy{ + ServiceModeV2: &ServiceModeV2{ + Mode: "warp", + }, + DisableAutoFallback: BoolPtr(false), + FallbackDomains: &[]FallbackDomain{ + {Suffix: "invalid"}, + {Suffix: "test"}, + }, + Exclude: &[]SplitTunnel{ + {Address: "10.0.0.0/8"}, + {Address: "100.64.0.0/10"}, + }, + GatewayUniqueID: StringPtr("t1235"), + SupportURL: StringPtr(""), + CaptivePortal: IntPtr(180), + AllowModeSwitch: BoolPtr(false), + SwitchLocked: BoolPtr(false), + AllowUpdates: BoolPtr(false), + AutoConnect: IntPtr(0), + AllowedToLeave: BoolPtr(true), + Enabled: BoolPtr(true), + PolicyID: nil, + Name: nil, + Match: nil, + Precedence: nil, + Default: true, + } + + nonDefaultDeviceSettingsPolicy = DeviceSettingsPolicy{ + ServiceModeV2: &ServiceModeV2{ + Mode: "warp", + }, + DisableAutoFallback: BoolPtr(false), + FallbackDomains: &[]FallbackDomain{ + {Suffix: "invalid"}, + {Suffix: "test"}, + }, + Exclude: &[]SplitTunnel{ + {Address: "10.0.0.0/8"}, + {Address: "100.64.0.0/10"}, + }, + GatewayUniqueID: StringPtr("t1235"), + SupportURL: StringPtr(""), + CaptivePortal: IntPtr(180), + AllowModeSwitch: BoolPtr(false), + SwitchLocked: BoolPtr(false), + AllowUpdates: BoolPtr(false), + AutoConnect: IntPtr(0), + AllowedToLeave: BoolPtr(true), + PolicyID: &deviceSettingsPolicyID, + Enabled: BoolPtr(true), + Name: StringPtr("test"), + Match: &deviceSettingsPolicyMatch, + Precedence: &deviceSettingsPolicyPrecedence, + Default: false, + } + + defaultDeviceSettingsPolicyJson = `{ + "service_mode_v2": { + "mode": "warp" + }, + "disable_auto_fallback": false, + "fallback_domains": [ + { + "suffix": "invalid" + }, + { + "suffix": "test" + } + ], + "exclude": [ + { + "address": "10.0.0.0/8" + }, + { + "address": "100.64.0.0/10" + } + ], + "gateway_unique_id": "t1235", + "support_url": "", + "captive_portal": 180, + "allow_mode_switch": false, + "switch_locked": false, + "allow_updates": false, + "auto_connect": 0, + "allowed_to_leave": true, + "enabled": true, + "default": true + }` + + nonDefaultDeviceSettingsPolicyJson = fmt.Sprintf(`{ + "service_mode_v2": { + "mode": "warp" + }, + "disable_auto_fallback": false, + "fallback_domains": [ + { + "suffix": "invalid" + }, + { + "suffix": "test" + } + ], + "exclude": [ + { + "address": "10.0.0.0/8" + }, + { + "address": "100.64.0.0/10" + } + ], + "gateway_unique_id": "t1235", + "support_url": "", + "captive_portal": 180, + "allow_mode_switch": false, + "switch_locked": false, + "allow_updates": false, + "auto_connect": 0, + "allowed_to_leave": true, + "policy_id": "%s", + "enabled": true, + "name": "test", + "match": %#v, + "precedence": 10, + "default": false + }`, deviceSettingsPolicyID, deviceSettingsPolicyMatch) +) + func TestUpdateDeviceClientCertificatesZone(t *testing.T) { setup() defer teardown() @@ -74,3 +208,208 @@ func TestGetDeviceClientCertificatesZone(t *testing.T) { assert.Equal(t, want, actual) } } + +func TestCreateDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method, "Expected method 'POST', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, `{ + "success": true, + "errors": null, + "messages": null, + "result": %s + }`, nonDefaultDeviceSettingsPolicyJson) + } + + want := DeviceSettingsPolicyResponse{ + Response: Response{ + Success: true, + Errors: nil, + Messages: nil, + }, + Result: nonDefaultDeviceSettingsPolicy, + } + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy", handler) + + actual, err := client.CreateDeviceSettingsPolicy(context.Background(), testAccountID, DeviceSettingsPolicyRequest{ + Precedence: IntPtr(10), + Match: &deviceSettingsPolicyMatch, + Name: StringPtr("test"), + }) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} + +func TestUpdateDefaultDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPatch, r.Method, "Expected method 'PATCH', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, `{ + "success": true, + "errors": null, + "messages": null, + "result": %s + }`, defaultDeviceSettingsPolicyJson) + } + + want := DeviceSettingsPolicyResponse{ + Response: Response{ + Success: true, + Errors: nil, + Messages: nil, + }, + Result: defaultDeviceSettingsPolicy, + } + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy", handler) + + actual, err := client.UpdateDefaultDeviceSettingsPolicy(context.Background(), testAccountID, DeviceSettingsPolicyRequest{}) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} + +func TestUpdateDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPatch, r.Method, "Expected method 'PATCH', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, `{ + "success": true, + "errors": null, + "messages": null, + "result": %s + }`, nonDefaultDeviceSettingsPolicyJson) + } + + precedence := 10 + want := DeviceSettingsPolicyResponse{ + Response: Response{ + Success: true, + Errors: nil, + Messages: nil, + }, + Result: nonDefaultDeviceSettingsPolicy, + } + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy/"+deviceSettingsPolicyID, handler) + + actual, err := client.UpdateDeviceSettingsPolicy(context.Background(), testAccountID, deviceSettingsPolicyID, DeviceSettingsPolicyRequest{ + Precedence: &precedence, + }) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} + +func TestDeleteDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodDelete, r.Method, "Expected method 'DELETE', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, `{ + "success": true, + "errors": null, + "messages": null, + "result": [ %s ] + }`, defaultDeviceSettingsPolicyJson) + } + + want := DeleteDeviceSettingsPolicyResponse{ + Response: Response{ + Success: true, + Errors: nil, + Messages: nil, + }, + Result: []DeviceSettingsPolicy{defaultDeviceSettingsPolicy}, + } + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy/"+deviceSettingsPolicyID, handler) + + actual, err := client.DeleteDeviceSettingsPolicy(context.Background(), testAccountID, deviceSettingsPolicyID) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} + +func TestGetDefaultDeviceSettings(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method, "Expected method 'GET', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, `{ + "success": true, + "errors": null, + "messages": null, + "result": %s + }`, defaultDeviceSettingsPolicyJson) + } + + want := DeviceSettingsPolicyResponse{ + Response: Response{ + Success: true, + Errors: nil, + Messages: nil, + }, + Result: defaultDeviceSettingsPolicy, + } + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy", handler) + + actual, err := client.GetDefaultDeviceSettingsPolicy(context.Background(), testAccountID) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} + +func TestGetDeviceSettings(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method, "Expected method 'GET', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, `{ + "success": true, + "errors": null, + "messages": null, + "result": %s + }`, nonDefaultDeviceSettingsPolicyJson) + } + + want := DeviceSettingsPolicyResponse{ + Response: Response{ + Success: true, + Errors: nil, + Messages: nil, + }, + Result: nonDefaultDeviceSettingsPolicy, + } + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy/"+deviceSettingsPolicyID, handler) + + actual, err := client.GetDeviceSettingsPolicy(context.Background(), testAccountID, deviceSettingsPolicyID) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} diff --git a/fallback_domain.go b/fallback_domain.go index b0aa4a78b60..e0b9132eabc 100644 --- a/fallback_domain.go +++ b/fallback_domain.go @@ -41,6 +41,26 @@ func (api *API) ListFallbackDomains(ctx context.Context, accountID string) ([]Fa return fallbackDomainResponse.Result, nil } +// ListFallbackDomainsDeviceSettingsPolicy returns all fallback domains within an account for a specific device settings policy. +// +// API reference: https://api.cloudflare.com/#devices-get-local-domain-fallback-list +func (api *API) ListFallbackDomainsDeviceSettingsPolicy(ctx context.Context, accountID, policyID string) ([]FallbackDomain, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s/fallback_domains", AccountRouteRoot, accountID, policyID) + + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return []FallbackDomain{}, err + } + + var fallbackDomainResponse FallbackDomainResponse + err = json.Unmarshal(res, &fallbackDomainResponse) + if err != nil { + return []FallbackDomain{}, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return fallbackDomainResponse.Result, nil +} + // UpdateFallbackDomain updates the existing fallback domain policy. // // API reference: https://api.cloudflare.com/#devices-set-local-domain-fallback-list @@ -61,8 +81,28 @@ func (api *API) UpdateFallbackDomain(ctx context.Context, accountID string, doma return fallbackDomainResponse.Result, nil } -// RestoreFallbackDomainDefaults resets the domain fallback values to the default -// list. +// UpdateFallbackDomainDeviceSettingsPolicy updates the existing fallback domain policy for a specific device settings policy. +// +// API reference: https://api.cloudflare.com/#devices-set-local-domain-fallback-list +func (api *API) UpdateFallbackDomainDeviceSettingsPolicy(ctx context.Context, accountID, policyID string, domains []FallbackDomain) ([]FallbackDomain, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s/fallback_domains", AccountRouteRoot, accountID, policyID) + + res, err := api.makeRequestContext(ctx, http.MethodPut, uri, domains) + if err != nil { + return []FallbackDomain{}, err + } + + var fallbackDomainResponse FallbackDomainResponse + err = json.Unmarshal(res, &fallbackDomainResponse) + if err != nil { + return []FallbackDomain{}, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return fallbackDomainResponse.Result, nil +} + +// RestoreFallbackDomainDefaultsDeviceSettingsPolicy resets the domain fallback values to the default +// list for a specific device settings policy. // // API reference: TBA. func (api *API) RestoreFallbackDomainDefaults(ctx context.Context, accountID string) error { @@ -75,3 +115,18 @@ func (api *API) RestoreFallbackDomainDefaults(ctx context.Context, accountID str return nil } + +// RestoreFallbackDomainDefaults resets the domain fallback values to the default +// list. +// +// API reference: TBA. +func (api *API) RestoreFallbackDomainDefaultsDeviceSettingsPolicy(ctx context.Context, accountID, policyID string) error { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s/fallback_domains?reset_defaults=true", AccountRouteRoot, accountID, policyID) + + _, err := api.makeRequestContext(ctx, http.MethodDelete, uri, []string{}) + if err != nil { + return err + } + + return nil +} diff --git a/fallback_domain_test.go b/fallback_domain_test.go index fb27b2d8526..159e909e64c 100644 --- a/fallback_domain_test.go +++ b/fallback_domain_test.go @@ -45,6 +45,44 @@ func TestListFallbackDomain(t *testing.T) { } } +func TestListFallbackDomainsDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method, "Expected method 'GET', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, ` + { + "success": true, + "errors": [], + "messages": [], + "result": [ + { + "suffix": "example.com", + "description": "Domain bypass for local development" + } + ] + } + `) + } + + want := []FallbackDomain{{ + Suffix: "example.com", + Description: "Domain bypass for local development", + }} + + policyID := "a842fa8a-a583-482e-9cd9-eb43362949fd" + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy/"+policyID+"/fallback_domains", handler) + + actual, err := client.ListFallbackDomainsDeviceSettingsPolicy(context.Background(), testAccountID, policyID) + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +} + func TestFallbackDomainDNSServer(t *testing.T) { setup() defer teardown() @@ -138,3 +176,61 @@ func TestUpdateFallbackDomain(t *testing.T) { assert.Equal(t, domains, actual) } } + +func TestUpdateFallbackDomainDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPut, r.Method, "Expected method 'PUT', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, ` + { + "success": true, + "errors": [], + "messages": [], + "result": [ + { + "suffix": "example_one.com", + "description": "example one", + "dns_server": ["192.168.0.1", "10.1.1.1"] + }, + { + "suffix": "example_two.com", + "description": "example two" + }, + { + "suffix": "example_three.com", + "description": "example three" + } + ] + } + `) + } + + domains := []FallbackDomain{ + { + Suffix: "example_one.com", + Description: "example one", + DNSServer: []string{"192.168.0.1", "10.1.1.1"}, + }, + { + Suffix: "example_two.com", + Description: "example two", + }, + { + Suffix: "example_three.com", + Description: "example three", + }, + } + + policyID := "a842fa8a-a583-482e-9cd9-eb43362949fd" + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy/"+policyID+"/fallback_domains", handler) + + actual, err := client.UpdateFallbackDomainDeviceSettingsPolicy(context.Background(), testAccountID, policyID, domains) + + if assert.NoError(t, err) { + assert.Equal(t, domains, actual) + } +} diff --git a/split_tunnel.go b/split_tunnel.go index 21d9297c6df..3855cae0fab 100644 --- a/split_tunnel.go +++ b/split_tunnel.go @@ -62,3 +62,45 @@ func (api *API) UpdateSplitTunnel(ctx context.Context, accountID string, mode st return splitTunnelResponse.Result, nil } + +// ListSplitTunnelDeviceSettingsPolicy returns all include or exclude split tunnel within a device settings policy +// +// API reference for include: https://api.cloudflare.com/#device-policy-get-split-tunnel-include-list +// API reference for exclude: https://api.cloudflare.com/#device-policy-get-split-tunnel-exclude-list +func (api *API) ListSplitTunnelsDeviceSettingsPolicy(ctx context.Context, accountID, policyID string, mode string) ([]SplitTunnel, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s/%s", AccountRouteRoot, accountID, policyID, mode) + + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return []SplitTunnel{}, err + } + + var splitTunnelResponse SplitTunnelResponse + err = json.Unmarshal(res, &splitTunnelResponse) + if err != nil { + return []SplitTunnel{}, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return splitTunnelResponse.Result, nil +} + +// UpdateSplitTunnelDeviceSettingsPolicy updates the existing split tunnel policy within a device settings policy +// +// API reference for include: https://api.cloudflare.com/#device-policy-set-split-tunnel-include-list +// API reference for exclude: https://api.cloudflare.com/#device-policy-set-split-tunnel-exclude-list +func (api *API) UpdateSplitTunnelDeviceSettingsPolicy(ctx context.Context, accountID, policyID string, mode string, tunnels []SplitTunnel) ([]SplitTunnel, error) { + uri := fmt.Sprintf("/%s/%s/devices/policy/%s/%s", AccountRouteRoot, accountID, policyID, mode) + + res, err := api.makeRequestContext(ctx, http.MethodPut, uri, tunnels) + if err != nil { + return []SplitTunnel{}, err + } + + var splitTunnelResponse SplitTunnelResponse + err = json.Unmarshal(res, &splitTunnelResponse) + if err != nil { + return []SplitTunnel{}, fmt.Errorf("%s: %w", errUnmarshalError, err) + } + + return splitTunnelResponse.Result, nil +} diff --git a/split_tunnel_test.go b/split_tunnel_test.go index 0b75a7a605c..63c6448fb31 100644 --- a/split_tunnel_test.go +++ b/split_tunnel_test.go @@ -260,3 +260,41 @@ func TestUpdateSplitTunnelExclude(t *testing.T) { assert.Equal(t, tunnels, actual) } } + +func TestSplitTunnelsDeviceSettingsPolicy(t *testing.T) { + setup() + defer teardown() + + handler := func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method, "Expected method 'GET', got %s", r.Method) + w.Header().Set("content-type", "application/json") + fmt.Fprintf(w, ` + { + "success": true, + "errors": [], + "messages": [], + "result": [ + { + "host": "*.example.com", + "description": "default" + } + ] + } + `) + } + + want := []SplitTunnel{{ + Host: "*.example.com", + Description: "default", + }} + + policyID := "a842fa8a-a583-482e-9cd9-eb43362949fd" + + mux.HandleFunc("/accounts/"+testAccountID+"/devices/policy/"+policyID+"/include", handler) + + actual, err := client.ListSplitTunnelsDeviceSettingsPolicy(context.Background(), testAccountID, policyID, "include") + + if assert.NoError(t, err) { + assert.Equal(t, want, actual) + } +}