Skip to content

Commit

Permalink
[nspcc-dev#228] add support of multiple sockets
Browse files Browse the repository at this point in the history
Signed-off-by: Pavel Pogodaev <[email protected]>
  • Loading branch information
Pavel Pogodaev committed Nov 24, 2022
1 parent 408d914 commit 79cb6f5
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 103 deletions.
132 changes: 52 additions & 80 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package main
import (
"context"
"crypto/ecdsa"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strconv"
Expand Down Expand Up @@ -45,12 +43,12 @@ type (
metrics *gateMetrics
services []*metrics.Service
settings *appSettings
servers []Server
}

appSettings struct {
Uploader *uploader.Settings
Downloader *downloader.Settings
TLSProvider *certProvider
Uploader *uploader.Settings
Downloader *downloader.Settings
}

// App is an interface for the main gateway function.
Expand Down Expand Up @@ -179,9 +177,8 @@ func newApp(ctx context.Context, opt ...Option) App {

func (a *app) initAppSettings() {
a.settings = &appSettings{
Uploader: &uploader.Settings{},
Downloader: &downloader.Settings{},
TLSProvider: &certProvider{Enabled: a.cfg.IsSet(cfgTLSCertificate) || a.cfg.IsSet(cfgTLSKey)},
Uploader: &uploader.Settings{},
Downloader: &downloader.Settings{},
}

a.updateSettings()
Expand Down Expand Up @@ -341,43 +338,6 @@ func (a *app) setHealthStatus() {
a.metrics.SetHealth(1)
}

type certProvider struct {
Enabled bool

mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}

func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}

p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}

func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}

cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}

p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}

func (a *app) Serve(ctx context.Context) {
uploadRoutes := uploader.New(ctx, a.AppParams(), a.settings.Uploader)
downloadRoutes := downloader.New(ctx, a.AppParams(), a.settings.Downloader)
Expand All @@ -386,38 +346,16 @@ func (a *app) Serve(ctx context.Context) {
a.configureRouter(uploadRoutes, downloadRoutes)

a.startServices()
a.initServers(ctx)

go func() {
var err error
defer func() {
if err != nil {
a.log.Fatal("could not start server", zap.Error(err))
}
}()

bind := a.cfg.GetString(cfgListenAddress)

if a.settings.TLSProvider.Enabled {
if err = a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil {
return
for i := range a.servers {
go func(i int) {
a.log.Info("starting server", zap.String("address", a.servers[i].Address()))
if err := a.webServer.Serve(a.servers[i].Listener()); err != nil && err != http.ErrServerClosed {
a.log.Fatal("listen and serve", zap.Error(err))
}

var lnConf net.ListenConfig
var ln net.Listener
if ln, err = lnConf.Listen(ctx, "tcp4", bind); err != nil {
return
}
lnTLS := tls.NewListener(ln, &tls.Config{
GetCertificate: a.settings.TLSProvider.GetCertificate,
})

a.log.Info("running web server (TLS-enabled)", zap.String("address", bind))
err = a.webServer.Serve(lnTLS)
} else {
a.log.Info("running web server", zap.String("address", bind))
err = a.webServer.ListenAndServe(bind)
}
}()
}(i)
}

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP)
Expand Down Expand Up @@ -460,6 +398,10 @@ func (a *app) configReload() {
a.log.Warn("failed to update resolvers", zap.Error(err))
}

if err := a.updateServers(); err != nil {
a.log.Warn("failed to reload server parameters", zap.Error(err))
}

a.stopServices()
a.startServices()

Expand All @@ -474,10 +416,6 @@ func (a *app) configReload() {
func (a *app) updateSettings() {
a.settings.Uploader.SetDefaultTimestamp(a.cfg.GetBool(cfgUploaderHeaderEnableDefaultTimestamp))
a.settings.Downloader.SetZipCompression(a.cfg.GetBool(cfgZipCompression))

if err := a.settings.TLSProvider.UpdateCert(a.cfg.GetString(cfgTLSCertificate), a.cfg.GetString(cfgTLSKey)); err != nil {
a.log.Warn("failed to reload TLS certs", zap.Error(err))
}
}

func (a *app) startServices() {
Expand Down Expand Up @@ -543,3 +481,37 @@ func (a *app) AppParams() *utils.AppParams {
Resolver: a.resolver,
}
}

func (a *app) initServers(ctx context.Context) {
serversInfo := fetchServers(a.cfg)

a.servers = make([]Server, len(serversInfo))
for i, serverInfo := range serversInfo {
a.log.Info("added server",
zap.String("address", serverInfo.Address), zap.Bool("tls enabled", serverInfo.TLS.Enabled),
zap.String("tls cert", serverInfo.TLS.CertFile), zap.String("tls key", serverInfo.TLS.KeyFile))
a.servers[i] = newServer(ctx, serverInfo, a.log)
}
}

func (a *app) updateServers() error {
serversInfo := fetchServers(a.cfg)

if len(serversInfo) != len(a.servers) {
return fmt.Errorf("invalid servers configuration: length mismatch: old '%d', new '%d", len(a.servers), len(serversInfo))
}

for i, serverInfo := range serversInfo {
if serverInfo.Address != a.servers[i].Address() {
return fmt.Errorf("invalid servers configuration: addresses mismatch: old '%s', new '%s", a.servers[i].Address(), serverInfo.Address)
}

if serverInfo.TLS.Enabled {
if err := a.servers[i].UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
return fmt.Errorf("failed to update tls certs: %w", err)
}
}
}

return nil
}
14 changes: 8 additions & 6 deletions config/config.env
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ HTTP_GW_PROMETHEUS_ADDRESS=localhost:8084
# Log level.
HTTP_GW_LOGGER_LEVEL=debug

# Address to bind.
HTTP_GW_LISTEN_ADDRESS=0.0.0.0:443
# Provide cert to enable TLS.
HTTP_GW_TLS_CERTIFICATE=/path/to/tls/cert
# Provide key to enable TLS.
HTTP_GW_TLS_KEY=/path/to/tls/key
HTTP_GW_SERVER_0_ADDRESS=0.0.0.0:443
HTTP_GW_SERVER_0_TLS_ENABLED=false
HTTP_GW_SERVER_0_TLS_CERT_FILE=/path/to/tls/cert
HTTP_GW_SERVER_0_TLS_KEY_FILE=/path/to/tls/key
HTTP_GW_SERVER_1_ADDRESS=0.0.0.0:444
HTTP_GW_SERVER_1_TLS_ENABLED=true
HTTP_GW_SERVER_1_TLS_CERT_FILE=/path/to/tls/cert
HTTP_GW_SERVER_1_TLS_KEY_FILE=/path/to/tls/key

# Nodes configuration.
# This configuration make the gateway use the first node (grpc://s01.neofs.devenv:8080)
Expand Down
14 changes: 11 additions & 3 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,17 @@ prometheus:
logger:
level: debug # Log level.

listen_address: 0.0.0.0:443 # Address to bind.
tls_certificate: /path/to/tls/cert # Provide cert to enable TLS.
tls_key: /path/to/tls/key # Provide key to enable TLS.
server:
- address: 0.0.0.0:8080
tls:
enabled: false
cert_file: /path/to/cert
key_file: /path/to/key
- address: 0.0.0.0:8081
tls:
enabled: false
cert_file: /path/to/cert
key_file: /path/to/key

# Nodes configuration.
# This configuration make the gateway use the first node (grpc://s01.neofs.devenv:8080)
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func getDefaultConfig() *viper.Viper {
v.SetDefault(cfgPeers+".0.priority", 1)

v.SetDefault(cfgRPCEndpoint, "http://localhost:30333")
v.SetDefault(cfgListenAddress, testListenAddress)
v.SetDefault("server.0.address", testListenAddress)

return v
}
Expand Down
124 changes: 124 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package main

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"

"go.uber.org/zap"
)

type (
ServerInfo struct {
Address string
TLS ServerTLSInfo
}

ServerTLSInfo struct {
Enabled bool
CertFile string
KeyFile string
}

Server interface {
Address() string
Listener() net.Listener
UpdateCert(certFile, keyFile string) error
}

server struct {
address string
listener net.Listener
tlsProvider *certProvider
}

certProvider struct {
Enabled bool

mu sync.RWMutex
certPath string
keyPath string
cert *tls.Certificate
}
)

func (s *server) Address() string {
return s.address
}

func (s *server) Listener() net.Listener {
return s.listener
}

func (s *server) UpdateCert(certFile, keyFile string) error {
return s.tlsProvider.UpdateCert(certFile, keyFile)
}

func newServer(ctx context.Context, serverInfo ServerInfo, logger *zap.Logger) *server {
var lic net.ListenConfig
ln, err := lic.Listen(ctx, "tcp", serverInfo.Address)
if err != nil {
logger.Fatal("could not prepare listener", zap.String("address", serverInfo.Address), zap.Error(err))
}

tlsProvider := &certProvider{
Enabled: serverInfo.TLS.Enabled,
}

if serverInfo.TLS.Enabled {
if err = tlsProvider.UpdateCert(serverInfo.TLS.CertFile, serverInfo.TLS.KeyFile); err != nil {
logger.Fatal("failed to update cert", zap.Error(err))
}

ln = tls.NewListener(ln, &tls.Config{
GetCertificate: tlsProvider.GetCertificate,
})
}

return &server{
address: serverInfo.Address,
listener: ln,
tlsProvider: tlsProvider,
}
}

func (p *certProvider) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if !p.Enabled {
return nil, errors.New("cert provider: disabled")
}

p.mu.RLock()
defer p.mu.RUnlock()
return p.cert, nil
}

func (p *certProvider) UpdateCert(certPath, keyPath string) error {
if !p.Enabled {
return fmt.Errorf("tls disabled")
}

cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile '%s' and keyFile '%s': %w", certPath, keyPath, err)
}

p.mu.Lock()
p.certPath = certPath
p.keyPath = keyPath
p.cert = &cert
p.mu.Unlock()
return nil
}

func (p *certProvider) FilePaths() (string, string) {
if !p.Enabled {
return "", ""
}

p.mu.RLock()
defer p.mu.RUnlock()
return p.certPath, p.keyPath
}
Loading

0 comments on commit 79cb6f5

Please sign in to comment.