Skip to content

Commit

Permalink
server: support unix socket
Browse files Browse the repository at this point in the history
Fixes #1415
  • Loading branch information
wdvxdr1123 committed Mar 23, 2022
1 parent d42d8dd commit 40a765b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 15 deletions.
33 changes: 28 additions & 5 deletions server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package server

import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -65,6 +67,7 @@ type HTTPClient struct {
filter string
apiPort int
timeout int32
client *http.Client
MaxRetries uint64
RetriesInterval uint64
}
Expand All @@ -77,8 +80,7 @@ type httpCtx struct {

const httpDefault = `
- http: # HTTP 通信设置
host: 127.0.0.1 # 服务端监听地址
port: 5700 # 服务端监听端口
address: 0.0.0.0:5700 # HTTP监听地址
timeout: 5 # 反向 HTTP 超时时间, 单位秒,<5 时将被忽略
long-polling: # 长轮询拓展
enabled: false # 是否开启
Expand Down Expand Up @@ -300,8 +302,30 @@ func (c HTTPClient) Run() {
if c.timeout < 5 {
c.timeout = 5
}
rawAddress := c.addr
network, address := resolveURI(c.addr)
client := &http.Client{
Timeout: time.Second * time.Duration(c.timeout),
Transport: &http.Transport{
DialContext: func(_ context.Context, _, addr string) (net.Conn, error) {
if network == "unix" {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
filepath, err := base64.RawURLEncoding.DecodeString(host)
if err == nil {
addr = string(filepath)
}
}
return net.Dial(network, addr)
},
},
}
c.addr = address // clean path
c.client = client
log.Infof("HTTP POST上报器已启动: %v", rawAddress)
c.bot.OnEventPush(c.onBotPushEvent)
log.Infof("HTTP POST上报器已启动: %v", c.addr)
}

func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
Expand All @@ -313,7 +337,6 @@ func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
}
}

client := http.Client{Timeout: time.Second * time.Duration(c.timeout)}
header := make(http.Header)
header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10))
header.Set("User-Agent", "CQHttp/4.15.0")
Expand All @@ -338,7 +361,7 @@ func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
}
req.Header = header

res, err = client.Do(req)
res, err = c.client.Do(req)
if res != nil {
//goland:noinspection GoDeferInLoop
defer res.Body.Close()
Expand Down
54 changes: 44 additions & 10 deletions server/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -77,9 +78,7 @@ var upgrader = websocket.Upgrader{
const wsDefault = ` # 正向WS设置
- ws:
# 正向WS服务器监听地址
host: 127.0.0.1
# 正向WS服务器监听端口
port: 6700
address: 0.0.0.0:8080
middlewares:
<<: *default # 引用默认中间件
`
Expand Down Expand Up @@ -213,8 +212,25 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) {
}
}

func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, url)
func resolveURI(addr string) (network, address string) {
network, address = "tcp", addr
uri, err := url.Parse(addr)
if err == nil && uri.Scheme != "" {
scheme, ext, _ := strings.Cut(uri.Scheme, "+")
if ext != "" {
network = ext
uri.Scheme = scheme // remove `+unix`/`+tcp4`
if ext == "unix" {
uri.Host = base64.StdEncoding.EncodeToString([]byte(uri.Host + uri.Path))
}
address = uri.String()
}
}
return
}

func (c *websocketClient) connect(typ, addr string, conptr **wsConn) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, addr)
header := http.Header{
"X-Client-Role": []string{typ},
"X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)},
Expand All @@ -223,12 +239,30 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
if c.token != "" {
header["Authorization"] = []string{"Token " + c.token}
}
conn, _, err := websocket.DefaultDialer.Dial(url, header) // nolint

network, address := resolveURI(addr)
dialer := websocket.Dialer{
NetDial: func(_, addr string) (net.Conn, error) {
if network == "unix" {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
filepath, err := base64.RawURLEncoding.DecodeString(host)
if err == nil {
addr = string(filepath)
}
}
return net.Dial(network, addr) // support unix socket transport
},
}

conn, _, err := dialer.Dial(address, header) // nolint
if err != nil {
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err)
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, addr, err)
if c.reconnectInterval != 0 {
time.Sleep(c.reconnectInterval)
c.connect(typ, url, conptr)
c.connect(typ, addr, conptr)
}
return
}
Expand All @@ -242,7 +276,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
}
}

log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url)
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, addr)

var wrappedConn *wsConn
if conptr != nil && *conptr != nil {
Expand All @@ -261,7 +295,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
}

if typ != "Event" {
go c.listenAPI(typ, url, wrappedConn)
go c.listenAPI(typ, addr, wrappedConn)
}
}

Expand Down

0 comments on commit 40a765b

Please sign in to comment.