forked from antoniomika/sish
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
267 lines (228 loc) · 8.79 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
package main
import (
"flag"
"log"
"net"
"os"
"os/signal"
"runtime"
"strings"
"sync"
"time"
"github.com/jpillora/ipfilter"
"golang.org/x/crypto/ssh"
)
// SSHConnection handles state for a single client SSH connection: the
// server-side connection itself plus the channels and listener registry that
// the rest of the server uses to talk to it and tear it down.
type SSHConnection struct {
	// SSHConn is the server side of the established SSH connection, as
	// returned by ssh.NewServerConn in main.
	SSHConn *ssh.ServerConn
	// Listeners presumably holds the listeners opened on behalf of this
	// client (TODO confirm against its writers). main treats an empty map,
	// one second after the handshake, as "no forwarding requests sent".
	Listeners *sync.Map
	// Close is closed by CleanUp to signal that the connection is being torn
	// down; main's unbound-cleanup goroutine selects on it.
	Close chan bool
	// Messages carries text for the client (presumably consumed by a session
	// handler and fed by sendMessage — TODO confirm). CleanUp closes it.
	Messages chan string
	// ProxyProto presumably records the PROXY protocol version chosen for this
	// connection (see the sish.proxyprotoversion flag) — TODO confirm.
	ProxyProto byte
	// Session is created in main; it is not read in this file — TODO
	// document its protocol.
	Session chan bool
	// CleanupHandler is not referenced in this file; presumably it marks
	// whether a cleanup handler has already been registered — TODO confirm.
	CleanupHandler bool
}
// State handles overall server state shared between the SSH accept loop, the
// per-connection goroutines and the HTTP handler. All registries are
// sync.Maps because they are accessed from multiple goroutines.
type State struct {
	// SSHConnections maps a client's remote address (net.Addr from
	// sshConn.RemoteAddr()) to its *SSHConnection.
	SSHConnections *sync.Map
	// Listeners maps a listener address (net.Addr) to its net.Listener; main
	// registers the SSH listener here.
	Listeners *sync.Map
	// HTTPListeners is printed by the debug logger; it is populated outside
	// this file — presumably by the HTTP forwarding code (TODO confirm).
	HTTPListeners *sync.Map
	// TCPListeners is not used in this file; presumably it tracks TCP
	// forwards/aliases — TODO confirm.
	TCPListeners *sync.Map
	// IPFilter decides which client IPs are rejected; main builds it from the
	// banned/whitelisted IP and country flags.
	IPFilter *ipfilter.IPFilter
}
// Build metadata printed by -sish.version. They are vars rather than consts,
// presumably so a release build can override them with -ldflags "-X" — TODO
// confirm against the build tooling.
var (
	version = "dev"
	commit  = "none"
	date    = "unknown"
)

// Listener addresses. The *Port flags only affect the text shown to users.
var (
	serverAddr   = flag.String("sish.addr", "localhost:2222", "The address to listen for SSH connections")
	httpAddr     = flag.String("sish.http", "localhost:80", "The address to listen for HTTP connections")
	httpPort     = flag.Int("sish.httpport", 80, "The port for HTTP connections. This is only for output messages")
	httpsAddr    = flag.String("sish.https", "localhost:443", "The address to listen for HTTPS connections")
	httpsPort    = flag.Int("sish.httpsport", 443, "The port for HTTPS connections. This is only for output messages")
	httpsEnabled = flag.Bool("sish.httpsenabled", false, "Whether or not to listen for HTTPS connections")
	httpsPems    = flag.String("sish.httpspems", "ssl/", "The location of pem files for HTTPS (fullchain.pem and privkey.pem)")
)

// HTTP proxying behaviour.
var (
	verifyOrigin         = flag.Bool("sish.verifyorigin", true, "Whether or not to verify origin on websocket connection")
	verifySSL            = flag.Bool("sish.verifyssl", true, "Whether or not to verify SSL on proxy connection")
	redirectRoot         = flag.Bool("sish.redirectroot", true, "Whether or not to redirect the root domain")
	redirectRootLocation = flag.String("sish.redirectrootlocation", "https://github.com/antoniomika/sish", "Where to redirect the root domain to")
)

// Domain and subdomain allocation.
var (
	rootDomain           = flag.String("sish.domain", "ssi.sh", "The domain for HTTP(S) multiplexing")
	domainLen            = flag.Int("sish.subdomainlen", 3, "The length of the random subdomain to generate")
	forceRandomSubdomain = flag.Bool("sish.forcerandomsubdomain", true, "Whether or not to force a random subdomain")
	bannedSubdomains     = flag.String("sish.bannedsubdomains", "localhost", "A comma separated list of banned subdomains")
)

// Client IP filtering (comma separated lists).
var (
	bannedIPs            = flag.String("sish.bannedips", "", "A comma separated list of banned ips")
	bannedCountries      = flag.String("sish.bannedcountries", "", "A comma separated list of banned countries")
	whitelistedIPs       = flag.String("sish.whitelistedips", "", "A comma separated list of whitelisted ips")
	whitelistedCountries = flag.String("sish.whitelistedcountries", "", "A comma separated list of whitelisted countries")
	useGeoDB             = flag.Bool("sish.usegeodb", false, "Whether or not to use the maxmind geodb")
)

// SSH server key and client authentication.
// NOTE(review): the passphrase/password defaults are published in source;
// any deployment that enables sish.auth should override them.
var (
	pkPass       = flag.String("sish.pkpass", "S3Cr3tP4$$phrAsE", "Passphrase to use for the server private key")
	pkLoc        = flag.String("sish.pkloc", "keys/ssh_key", "SSH server private key")
	authEnabled  = flag.Bool("sish.auth", false, "Whether or not to require auth on the SSH service")
	authPassword = flag.String("sish.password", "S3Cr3tP4$$W0rD", "Password to use for password auth")
	authKeysDir  = flag.String("sish.keysdir", "pubkeys/", "Directory for public keys for pubkey auth")
)

// Port forwarding behaviour.
var (
	bindRange         = flag.String("sish.bindrange", "0,1024-65535", "Ports that are allowed to be bound")
	cleanupUnbound    = flag.Bool("sish.cleanupunbound", true, "Whether or not to cleanup unbound (forwarded) SSH connections")
	bindRandom        = flag.Bool("sish.bindrandom", true, "Bind ports randomly (OS chooses)")
	tcpAlias          = flag.Bool("sish.tcpalias", false, "Whether or not to allow the use of TCP aliasing")
	proxyProtoEnabled = flag.Bool("sish.proxyprotoenabled", false, "Whether or not to enable the use of the proxy protocol")
	proxyProtoVersion = flag.String("sish.proxyprotoversion", "1", "What version of the proxy protocol to use. Can either be 1, 2, or userdefined. If userdefined, the user needs to add a command to SSH called proxyproto:version (ie proxyproto:1)")
)

// Diagnostics.
var (
	debug        = flag.Bool("sish.debug", false, "Whether or not to print debug information")
	versionCheck = flag.Bool("sish.version", false, "Print version and exit")
)

// Values derived from the flags in main.
var (
	// bannedSubdomainList is seeded with "" and extended in main with the
	// entries of -sish.bannedsubdomains; main then suffixes every entry with
	// "." + rootDomain, so the seed becomes "." + rootDomain.
	bannedSubdomainList = []string{""}
	// filter is built in main from the IP filtering flags.
	filter *ipfilter.IPFilter
)
// main parses the command-line flags, builds the IP filter and the shared
// State, starts the HTTP(S) handler and then accepts SSH connections forever.
// Each accepted connection is handshaked and served on its own goroutines.
func main() {
	flag.Parse()

	if *versionCheck {
		log.Printf("Version: %v\nCommit: %v\nDate: %v\n", version, commit, date)
		os.Exit(0)
	}

	commaSplitFields := func(c rune) bool {
		return c == ','
	}

	// bannedSubdomainList starts out as []string{""}, so after the root
	// domain suffix is applied below it also contains "." + *rootDomain.
	bannedSubdomainList = append(bannedSubdomainList, strings.FieldsFunc(*bannedSubdomains, commaSplitFields)...)
	for k, v := range bannedSubdomainList {
		bannedSubdomainList[k] = strings.ToLower(strings.TrimSpace(v) + "." + *rootDomain)
	}

	// upperList splits a comma separated flag value and upper-cases every
	// entry (used for the country lists).
	upperList := func(stringList string) []string {
		list := strings.FieldsFunc(stringList, commaSplitFields)
		for k, v := range list {
			list[k] = strings.ToUpper(v)
		}
		return list
	}

	whitelistedCountriesList := upperList(*whitelistedCountries)
	whitelistedIPList := strings.FieldsFunc(*whitelistedIPs, commaSplitFields)

	// Configuring any whitelist entry switches the filter to block-by-default.
	ipfilterOpts := ipfilter.Options{
		BlockedCountries: upperList(*bannedCountries),
		AllowedCountries: whitelistedCountriesList,
		BlockedIPs:       strings.FieldsFunc(*bannedIPs, commaSplitFields),
		AllowedIPs:       whitelistedIPList,
		BlockByDefault:   len(whitelistedIPList) > 0 || len(whitelistedCountriesList) > 0,
	}

	if *useGeoDB {
		filter = ipfilter.NewLazy(ipfilterOpts)
	} else {
		filter = ipfilter.NewNoDB(ipfilterOpts)
	}

	watchCerts()

	state := &State{
		SSHConnections: &sync.Map{},
		Listeners:      &sync.Map{},
		HTTPListeners:  &sync.Map{},
		TCPListeners:   &sync.Map{},
		IPFilter:       filter,
	}

	go startHTTPHandler(state)

	if *debug {
		// Periodically dump goroutine count and registry contents.
		go func() {
			for {
				log.Println("=======Start=========")
				log.Println("===Goroutines=====")
				log.Println(runtime.NumGoroutine())
				log.Println("===Listeners======")
				state.Listeners.Range(func(key, value interface{}) bool {
					log.Println(key, value)
					return true
				})
				log.Println("===Clients========")
				state.SSHConnections.Range(func(key, value interface{}) bool {
					log.Println(key, value)
					return true
				})
				log.Println("===HTTP Clients===")
				state.HTTPListeners.Range(func(key, value interface{}) bool {
					log.Println(key, value)
					return true
				})
				log.Print("========End==========\n\n")
				time.Sleep(2 * time.Second)
			}
		}()
	}

	log.Println("Starting SSH service on address:", *serverAddr)

	sshConfig := getSSHConfig()

	listener, err := net.Listen("tcp", *serverAddr)
	if err != nil {
		log.Fatal(err)
	}

	state.Listeners.Store(listener.Addr(), listener)

	defer func() {
		listener.Close()
		state.Listeners.Delete(listener.Addr())
	}()

	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)
	go func() {
		for range c {
			os.Exit(0)
		}
	}()

	// acceptRetryDelay throttles the accept loop after an error so that a
	// persistent failure (e.g. running out of file descriptors) does not turn
	// it into a busy spin that floods the log.
	const acceptRetryDelay = 50 * time.Millisecond

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Println(err)
			time.Sleep(acceptRetryDelay)
			continue
		}

		clientRemote, _, err := net.SplitHostPort(conn.RemoteAddr().String())
		if err != nil || filter.Blocked(clientRemote) {
			conn.Close()
			continue
		}

		// handshakeDone is closed once ssh.NewServerConn returns (successfully
		// or not). A channel is used instead of a shared bool so the handshake
		// goroutine and the watchdog below do not race on plain memory.
		handshakeDone := make(chan struct{})
		if *cleanupUnbound {
			go func() {
				timer := time.NewTimer(5 * time.Second)
				defer timer.Stop()
				select {
				case <-timer.C:
					// The client did not finish the handshake in time.
					conn.Close()
				case <-handshakeDone:
				}
			}()
		}

		log.Println("Accepted SSH connection for:", conn.RemoteAddr())

		go func() {
			sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig)
			close(handshakeDone)
			if err != nil {
				conn.Close()
				log.Println(err)
				return
			}

			holderConn := &SSHConnection{
				SSHConn:   sshConn,
				Listeners: &sync.Map{},
				Close:     make(chan bool),
				Messages:  make(chan string),
				Session:   make(chan bool),
			}

			state.SSHConnections.Store(sshConn.RemoteAddr(), holderConn)

			go handleRequests(reqs, holderConn, state)
			go handleChannels(chans, holderConn, state)

			if *cleanupUnbound {
				// Drop clients that have not registered any listener one
				// second after the handshake.
				go func() {
					select {
					case <-time.After(1 * time.Second):
						hasListener := false
						holderConn.Listeners.Range(func(key, value interface{}) bool {
							hasListener = true
							return false // one entry is enough
						})
						if !hasListener {
							sendMessage(holderConn, "No forwarding requests sent. Closing connection.")
							// Presumably gives the message a moment to reach
							// the client before teardown — TODO confirm.
							time.Sleep(1 * time.Millisecond)
							holderConn.CleanUp(state)
						}
					case <-holderConn.Close:
						return
					}
				}()
			}
		}()
	}
}
// sshConnCleanupMu serializes the check-and-close step of
// SSHConnection.CleanUp so that concurrent or repeated calls can never both
// reach close() on the same channel (which would panic).
var sshConnCleanupMu sync.Mutex

// CleanUp closes all allocated resources for the connection and removes it
// from state.SSHConnections. It closes s.Close and s.Messages, closes the
// underlying SSH connection and logs the disconect. Calls after the first one
// are no-ops, so every teardown path may call it without coordinating.
func (s *SSHConnection) CleanUp(state *State) {
	sshConnCleanupMu.Lock()
	alreadyClosed := false
	select {
	case <-s.Close:
		// s.Close is only readable once closed: a previous call already ran.
		alreadyClosed = true
	default:
		close(s.Close)
		close(s.Messages)
	}
	sshConnCleanupMu.Unlock()

	if alreadyClosed {
		return
	}

	s.SSHConn.Close()
	state.SSHConnections.Delete(s.SSHConn.RemoteAddr())
	log.Println("Closed SSH connection for:", s.SSHConn.RemoteAddr(), "user:", s.SSHConn.User())
}