Skip to content

Commit

Permalink
Merge pull request pathwar#446 from pathwar/dev/moul/panic-stack
Browse files Browse the repository at this point in the history
fix: panic for sso token with multiple clients
  • Loading branch information
moul authored Mar 12, 2020
2 parents 74471e6 + 64de5de commit f27ffdd
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
45 changes: 38 additions & 7 deletions go/pkg/pwapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package pwapi
import (
"context"
"fmt"
"log"
"net"
"net/http"
"net/http/pprof"
"runtime/debug"
"strings"
"time"

Expand All @@ -26,6 +28,8 @@ import (
chilogger "github.com/treastech/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"pathwar.land/v2/go/pkg/errcode"
)

Expand Down Expand Up @@ -159,12 +163,15 @@ func grpcServer(svc Service, opts ServerOpts) (*grpc.Server, error) {
authFunc := func(context.Context) (context.Context, error) {
return nil, errcode.ErrNotImplemented
}
serverStreamOpts := []grpc.StreamServerInterceptor{
grpc_recovery.StreamServerInterceptor(),
}
serverUnaryOpts := []grpc.UnaryServerInterceptor{
grpc_recovery.UnaryServerInterceptor(),
recoveryOpts := []grpc_recovery.Option{}
if opts.Logger.Check(zap.DebugLevel, "") != nil {
recoveryOpts = append(recoveryOpts, grpc_recovery.WithRecoveryHandlerContext(func(ctx context.Context, p interface{}) error {
log.Println("stacktrace from panic: \n" + string(debug.Stack()))
return status.Errorf(codes.Internal, "recover: %s", p)
}))
}
serverStreamOpts := []grpc.StreamServerInterceptor{grpc_recovery.StreamServerInterceptor(recoveryOpts...)}
serverUnaryOpts := []grpc.UnaryServerInterceptor{grpc_recovery.UnaryServerInterceptor(recoveryOpts...)}
if opts.Tracer != nil {
tracingOpts := []grpc_opentracing.Option{grpc_opentracing.WithTracer(opts.Tracer)}
serverStreamOpts = append(serverStreamOpts, grpc_opentracing.StreamServerInterceptor(tracingOpts...))
Expand All @@ -174,15 +181,19 @@ func grpcServer(svc Service, opts ServerOpts) (*grpc.Server, error) {
grpc_auth.StreamServerInterceptor(authFunc),
//grpc_ctxtags.StreamServerInterceptor(),
grpc_zap.StreamServerInterceptor(logger),
grpc_recovery.StreamServerInterceptor(),
)
serverUnaryOpts = append(
serverUnaryOpts,
grpc_auth.UnaryServerInterceptor(authFunc),
//grpc_ctxtags.UnaryServerInterceptor(),
grpc_zap.UnaryServerInterceptor(logger),
grpc_recovery.UnaryServerInterceptor(),
)
if opts.Logger.Check(zap.DebugLevel, "") != nil {
serverStreamOpts = append(serverStreamOpts, grpcServerStreamInterceptor())
serverUnaryOpts = append(serverUnaryOpts, grpcServerUnaryInterceptor())
}
serverStreamOpts = append(serverStreamOpts, grpc_recovery.StreamServerInterceptor(recoveryOpts...))
serverUnaryOpts = append(serverUnaryOpts, grpc_recovery.UnaryServerInterceptor(recoveryOpts...))
grpcServer := grpc.NewServer(
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(serverStreamOpts...)),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(serverUnaryOpts...)),
Expand All @@ -192,6 +203,26 @@ func grpcServer(svc Service, opts ServerOpts) (*grpc.Server, error) {
return grpcServer, nil
}

func grpcServerStreamInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := handler(srv, stream)
if err != nil {
log.Printf("%+v", err)
}
return err
}
}

func grpcServerUnaryInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ret, err := handler(ctx, req)
if err != nil {
log.Printf("%+v", err)
}
return ret, err
}
}

func httpServer(ctx context.Context, serverListenerAddr string, opts ServerOpts) (*http.Server, error) {
logger := opts.Logger.Named("http")
r := chi.NewRouter()
Expand Down
8 changes: 7 additions & 1 deletion go/pkg/pwsso/token.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pwsso

import (
"fmt"
time "time"

jwt "github.com/dgrijalva/jwt-go"
Expand Down Expand Up @@ -104,7 +105,12 @@ func ClaimsFromToken(token *jwt.Token) *Claims {
claims.ActionToken.Iss = v.(string)
}
if v := mc["aud"]; v != nil {
claims.ActionToken.Aud = v.(string)
switch typed := v.(type) {
case string:
claims.ActionToken.Aud = typed
default:
claims.ActionToken.Aud = fmt.Sprintf("%v", typed)
}
}
if v := mc["asid"]; v != nil {
claims.ActionToken.Asid = v.(string)
Expand Down

0 comments on commit f27ffdd

Please sign in to comment.