Skip to content

Commit

Permalink
Use proxy in country check
Browse files Browse the repository at this point in the history
  • Loading branch information
arriven committed Mar 30, 2022
1 parent 989f057 commit 70f741c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/Arriven/db1000n/src/utils"
"github.com/Arriven/db1000n/src/utils/metrics"
"github.com/Arriven/db1000n/src/utils/ota"
"github.com/Arriven/db1000n/src/utils/templates"
)

func main() {
Expand Down Expand Up @@ -81,7 +82,7 @@ func main() {
setUpPprof(*pprof, *debug)
rand.Seed(time.Now().UnixNano())

country := utils.CheckCountryOrFail(countryCheckerConfig)
country := utils.CheckCountryOrFail(countryCheckerConfig, templates.ParseAndExecute(zap.NewNop(), jobsGlobalConfig.ProxyURLs, nil))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
42 changes: 34 additions & 8 deletions src/utils/countrychecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,37 @@ import (
"encoding/json"
"flag"
"log"
"math/rand"
"net/http"
"net/url"
"strings"
"time"
)

type CountryCheckerConfig struct {
countryBlackListCSV string
strictCountryCheck bool
maxRetries int
}

// NewGlobalConfigWithFlags returns a GlobalConfig initialized with command line flags.
func NewCountryCheckerConfigWithFlags() *CountryCheckerConfig {
const maxFetchRetries = 3

var res CountryCheckerConfig

flag.StringVar(&res.countryBlackListCSV, "country-list", GetEnvStringDefault("COUNTRY_LIST", "Ukraine"), "comma-separated list of countries")
flag.BoolVar(&res.strictCountryCheck, "strict-country-check", GetEnvBoolDefault("STRICT_COUNTRY_CHECK", false),
"enable strict country check; will also exit if IP can't be determined")
flag.IntVar(&res.maxRetries, "country-check-retries", GetEnvIntDefault("COUNTRY_CHECK_RETRIES", maxFetchRetries),
"how much retries should be made when checking the country")

return &res
}

// CheckCountryOrFail checks the country of client origin by IP and exits the program if it is in the blacklist.
func CheckCountryOrFail(cfg *CountryCheckerConfig) string {
isCountryAllowed, country := CheckCountry(strings.Split(cfg.countryBlackListCSV, ","), cfg.strictCountryCheck)
func CheckCountryOrFail(cfg *CountryCheckerConfig, proxyURLs string) string {
isCountryAllowed, country := CheckCountry(strings.Split(cfg.countryBlackListCSV, ","), cfg.strictCountryCheck, proxyURLs, cfg.maxRetries)
if !isCountryAllowed {
log.Fatalf("%q is not an allowed country, exiting", country)
}
Expand All @@ -37,9 +44,7 @@ func CheckCountryOrFail(cfg *CountryCheckerConfig) string {
}

// CheckCountry checks which country the app is running from and whether it is in the blacklist.
func CheckCountry(countriesToAvoid []string, strictCountryCheck bool) (bool, string) {
const maxFetchRetries = 3

func CheckCountry(countriesToAvoid []string, strictCountryCheck bool, proxyURLs string, maxFetchRetries int) (bool, string) {
var (
country, ip string
err error
Expand All @@ -51,8 +56,10 @@ func CheckCountry(countriesToAvoid []string, strictCountryCheck bool) (bool, str
for counter.Next() {
log.Printf("Checking IP address, attempt #%d", counter.iter)

if country, ip, err = fetchLocationInfo(); err != nil {
if country, ip, err = fetchLocationInfo(proxyURLs); err != nil {
Sleep(context.Background(), backoffController.Increment().GetTimeout())
} else {
break
}
}

Expand Down Expand Up @@ -80,7 +87,7 @@ func CheckCountry(countriesToAvoid []string, strictCountryCheck bool) (bool, str
return true, country
}

func fetchLocationInfo() (country, ip string, err error) {
func fetchLocationInfo(proxyURLs string) (country, ip string, err error) {
const (
ipCheckerURI = "https://api.myip.com/"
requestTimeout = 3 * time.Second
Expand All @@ -94,7 +101,26 @@ func fetchLocationInfo() (country, ip string, err error) {
return "", "", err
}

resp, err := http.DefaultClient.Do(req)
client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyFromEnvironment}}

if proxyURLs != "" {
log.Println("proxy config detected, using it to check country")

proxies := strings.Split(proxyURLs, ",")

proxy := proxies[rand.Intn(len(proxies))] //nolint:gosec // Cryptographically secure random not required

log.Println("using proxy", proxy)

u, err := url.Parse(proxy)
if err != nil {
return "", "", err
}

client = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(u)}}
}

resp, err := client.Do(req)
if err != nil {
log.Println("Can't check users country. Please manually check that VPN is enabled or that you have non Ukrainian IP address.")

Expand Down

0 comments on commit 70f741c

Please sign in to comment.