forked from YiQiu1984/lightsocks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsecuretcp.go
124 lines (112 loc) · 2.73 KB
/
securetcp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package lightsocks
import (
	"io"
	"log"
	"net"
	"time"
)
// bufSize is the size in bytes of the scratch buffer that EncodeCopy and
// DecodeCopy allocate for each copy loop.
const bufSize = 1024
// SecureTCPConn is a TCP socket whose traffic is encrypted: it pairs an
// underlying byte stream (a *net.TCPConn when created by DialTCPSecure or
// ListenSecureTCP) with the cipher used to encode and decode that traffic.
//
// The Read and Write methods promoted from the embedded stream operate on raw,
// still-encoded bytes; DecodeRead and EncodeWrite are the cipher-aware versions.
type SecureTCPConn struct {
	io.ReadWriteCloser
	Cipher *cipher
}
// DecodeRead reads encrypted bytes from the underlying stream into bs and
// decodes them in place, so bs[:n] holds the original (decoded) data.
//
// As with any io.Reader, n may be positive even when err is non-nil (for
// example together with io.EOF). Those n bytes are still decoded, so callers
// should process bs[:n] before inspecting err.
func (secureSocket *SecureTCPConn) DecodeRead(bs []byte) (n int, err error) {
	n, err = secureSocket.Read(bs)
	if n > 0 {
		// Decode whatever arrived, even if the read also reported an error;
		// otherwise the caller would receive ciphertext in bs[:n].
		secureSocket.Cipher.decode(bs[:n])
	}
	return n, err
}
// EncodeWrite encodes bs with the connection's cipher and immediately writes
// all of it to the underlying stream in a single Write call, returning that
// call's results.
//
// Encoding is applied to bs itself (the encode result is not captured
// separately, and bs is what gets written), so the caller's slice is
// overwritten. Pass a copy if the plaintext is still needed afterwards.
func (secureSocket *SecureTCPConn) EncodeWrite(bs []byte) (int, error) {
	c := secureSocket.Cipher
	c.encode(bs)
	written, err := secureSocket.Write(bs)
	return written, err
}
// EncodeCopy repeatedly reads raw data from the underlying stream (the source),
// encodes it with the connection's cipher and writes it to dst. It stops when
// the source has no more data to read.
//
// It returns nil when the source reports io.EOF, and otherwise the first read
// or write error (io.ErrShortWrite if dst accepts fewer bytes than it was
// given). dst is not closed.
func (secureSocket *SecureTCPConn) EncodeCopy(dst io.ReadWriteCloser) error {
	buf := make([]byte, bufSize)
	// One wrapper for the whole copy, instead of allocating one per chunk.
	encodedDst := &SecureTCPConn{
		ReadWriteCloser: dst,
		Cipher:          secureSocket.Cipher,
	}
	for {
		readCount, errRead := secureSocket.Read(buf)
		// Forward any bytes before looking at errRead: io.Reader may return
		// n > 0 together with an error (including io.EOF).
		if readCount > 0 {
			writeCount, errWrite := encodedDst.EncodeWrite(buf[:readCount])
			if errWrite != nil {
				return errWrite
			}
			if writeCount != readCount {
				return io.ErrShortWrite
			}
		}
		if errRead != nil {
			if errRead == io.EOF {
				return nil
			}
			return errRead
		}
	}
}
// DecodeCopy repeatedly reads encrypted data from the underlying stream (the
// source), decodes it with the connection's cipher and writes the original
// data to dst. It stops when the source has no more data to read.
//
// It returns nil when the source reports io.EOF, and otherwise the first read
// or write error (io.ErrShortWrite if dst accepts fewer bytes than it was
// given).
func (secureSocket *SecureTCPConn) DecodeCopy(dst io.Writer) error {
	buf := make([]byte, bufSize)
	for {
		// Read and decode here, rather than via DecodeRead, so that bytes
		// delivered together with an error are decoded before being forwarded.
		readCount, errRead := secureSocket.Read(buf)
		if readCount > 0 {
			chunk := buf[:readCount]
			secureSocket.Cipher.decode(chunk)
			writeCount, errWrite := dst.Write(chunk)
			if errWrite != nil {
				return errWrite
			}
			if writeCount != readCount {
				return io.ErrShortWrite
			}
		}
		if errRead != nil {
			if errRead == io.EOF {
				return nil
			}
			return errRead
		}
	}
}
// DialTCPSecure opens a TCP connection to raddr (see net.DialTCP) and returns it
// wrapped in a SecureTCPConn that uses cipher. On failure it returns a nil
// connection and the dial error.
func DialTCPSecure(raddr *net.TCPAddr, cipher *cipher) (*SecureTCPConn, error) {
	tcpConn, err := net.DialTCP("tcp", nil, raddr)
	if err != nil {
		return nil, err
	}
	secure := &SecureTCPConn{Cipher: cipher}
	secure.ReadWriteCloser = tcpConn
	return secure, nil
}
// ListenSecureTCP listens for TCP connections on laddr (see net.ListenTCP). It
// hands each accepted connection to handleConn on a new goroutine, wrapped in a
// SecureTCPConn that uses cipher.
//
// If didListen is non-nil, it is called once with the bound address after the
// listener is open.
//
// It returns an error only if the listener cannot be created; after that it
// keeps accepting indefinitely. Accept errors are logged and retried after an
// exponentially growing pause (5ms, doubling, capped at 1s, reset after the
// next successful accept). This keeps a persistent failure, such as running
// out of file descriptors, from turning into a busy loop.
func ListenSecureTCP(laddr *net.TCPAddr, cipher *cipher, handleConn func(localConn *SecureTCPConn), didListen func(listenAddr net.Addr)) error {
	listener, err := net.ListenTCP("tcp", laddr)
	if err != nil {
		return err
	}
	defer listener.Close()
	if didListen != nil {
		didListen(listener.Addr())
	}

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration
	for {
		localConn, err := listener.AcceptTCP()
		if err != nil {
			log.Println(err)
			if acceptDelay == 0 {
				acceptDelay = minAcceptDelay
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0

		// When localConn is closed, drop any unsent data immediately instead of
		// lingering to flush it (SO_LINGER = 0).
		if err := localConn.SetLinger(0); err != nil {
			log.Println(err)
		}
		go handleConn(&SecureTCPConn{
			ReadWriteCloser: localConn,
			Cipher:          cipher,
		})
	}
}