Skip to content

Commit

Permalink
refactor oidc flow/auth to DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
dovholuknf committed Jul 26, 2024
1 parent c39a1aa commit f0abfa2
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 86 deletions.
17 changes: 3 additions & 14 deletions zssh/zscp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package main

import (
"context"
"fmt"
"github.com/openziti/ziti/common/enrollment"
"github.com/openziti/ziti/ziti/cmd/common"
Expand Down Expand Up @@ -46,15 +45,6 @@ var rootCmd = &cobra.Command{
var localFilePaths []string
var isCopyToRemote bool

oidcToken := ""
var err error
if flags.OIDC.Mode {
oidcToken, err = zsshlib.OIDCFlow(context.Background(), &flags.SshFlags)
if err != nil {
logrus.Fatalf("error performing OIDC flow: %v", err)
}
}

if strings.ContainsAny(args[0], ":") {
remoteFilePath = args[0]
localFilePaths = args[1:]
Expand All @@ -70,10 +60,9 @@ var rootCmd = &cobra.Command{
} else {
logrus.Fatal(`cannot determine remote file PATH use ":" for remote path`)
}

var err error
for i, path := range localFilePaths {
localFilePaths[i], err = filepath.Abs(path)
if err != nil {
if localFilePaths[i], err = filepath.Abs(path); err != nil {
logrus.Fatalf("cannot determine absolute local file path, unrecognized file name: %s", path)
}
if _, err := os.Stat(localFilePaths[i]); err != nil {
Expand All @@ -88,7 +77,7 @@ var rootCmd = &cobra.Command{

remoteFilePath = zsshlib.ParseFilePath(remoteFilePath)

sshConn := zsshlib.EstablishClient(flags.SshFlags, remoteFilePath, targetIdentity, oidcToken)
sshConn := zsshlib.EstablishClient(&flags.SshFlags, remoteFilePath, targetIdentity)
defer func() { _ = sshConn.Close() }()

client, err := sftp.NewClient(sshConn)
Expand Down
13 changes: 2 additions & 11 deletions zssh/zssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,9 @@ var rootCmd = &cobra.Command{
zsshlib.Combine(cmd, &flags, cfg)

cmdArgs := args[1:]
token := ""
var err error
if flags.OIDC.Mode {
token, err = zsshlib.OIDCFlow(context.Background(), &flags)
if err != nil {
logrus.Fatalf("error performing OIDC flow: %v", err)
}
}
sshConn := zsshlib.EstablishClient(flags, args[0], targetIdentity, token)
sshConn := zsshlib.EstablishClient(&flags, args[0], targetIdentity)
defer func() { _ = sshConn.Close() }()
err = zsshlib.RemoteShell(sshConn, cmdArgs)
if err != nil {
if err := zsshlib.RemoteShell(sshConn, cmdArgs); err != nil {
logrus.Fatalf("error opening remote shell: %v", err)
}
},
Expand Down
64 changes: 64 additions & 0 deletions zsshlib/authenticate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package zsshlib

import (
"bufio"
"context"
"fmt"
"github.com/openziti/edge-api/rest_model"
"github.com/openziti/sdk-golang/ziti"
"github.com/sirupsen/logrus"
"os"
"strings"
)

func Auth(flags *SshFlags) ziti.Context {
oidcToken := ""
var oidcErr error
if flags.OIDC.Mode {
oidcToken, oidcErr = OIDCFlow(context.Background(), flags)
if oidcErr != nil {
logrus.Fatalf("error performing OIDC flow: %v", oidcErr)
}
}

conf := getConfig(flags.ZConfig)
ctx, err := ziti.NewContext(conf)
conf.Credentials.AddJWT(oidcToken)
if err != nil {
logrus.Fatalf("error creating ziti context: %v", err)
}

ctx.Events().AddMfaTotpCodeListener(func(c ziti.Context, detail *rest_model.AuthQueryDetail, response ziti.MfaCodeResponse) {
reader := bufio.NewReader(os.Stdin)
codeok := false
for !codeok {
fmt.Print("Enter MFA: ")
code, _ := reader.ReadString('\n')
code = strings.TrimSpace(code)
fmt.Println("You entered:" + code + " - verifying")
if err := response(code); err != nil {
fmt.Println("error verifying MFA TOTP: ", err)
} else {
codeok = true
}
}
})

if err = ctx.Authenticate(); err != nil {
logrus.Errorf("error creating ziti context: %v", err)
logrus.Fatalf("could not authenticate. verify your identity is correct and matches all necessary authentication conditions.")
}

return ctx
}

func ReadCode() string {
code := ""
reader := bufio.NewReader(os.Stdin)
for code == "" {
fmt.Print("Enter MFA: ")
code, _ = reader.ReadString('\n')
code = strings.TrimSpace(code)
}
return code
}
66 changes: 31 additions & 35 deletions zsshlib/mfa-enrollment.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
package zsshlib

import (
"bufio"
"context"
"fmt"
"github.com/openziti/sdk-golang/ziti"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"net/url"
"os"
"strings"
"zssh/config"
)

Expand All @@ -19,7 +14,7 @@ func NewMfaCmd(flags *SshFlags) *cobra.Command {
Short: "Manage MFA for the provided identity",
}

mfaCmd.AddCommand(NewEnableCmd(flags))
mfaCmd.AddCommand(NewEnableCmd(flags), NewRemoveMfaCmd(flags))
return mfaCmd
}

Expand All @@ -30,36 +25,32 @@ func NewEnableCmd(flags *SshFlags) *cobra.Command {
Run: func(cmd *cobra.Command, args []string) {
cfg := config.DefaultConfig()
Combine(cmd, flags, cfg)

oidcToken := ""
var err error
if flags.OIDC.Mode {
oidcToken, err = OIDCFlow(context.Background(), flags)
if err != nil {
logrus.Fatalf("error performing OIDC flow: %v", err)
}
}
EnableMFA(flags, oidcToken)
EnableMFA(flags)
},
}

flags.OIDCFlags(cmd)
return cmd
}

func EnableMFA(flags *SshFlags, oidcToken string) {
conf := getConfig(flags.ZConfig)
ctx, err := ziti.NewContext(conf)
conf.Credentials.AddJWT(oidcToken)
if err != nil {
logrus.Fatalf("error creating ziti context: %v", err)
func NewRemoveMfaCmd(flags *SshFlags) *cobra.Command {
cmd := &cobra.Command{
Use: "remove",
Short: "Remove MFA. Removes the MFA TOTP enablement for the provided identity",
Run: func(cmd *cobra.Command, args []string) {
cfg := config.DefaultConfig()
Combine(cmd, flags, cfg)
RemoveMfa(flags)
},
}

if err = ctx.Authenticate(); err != nil {
logrus.Errorf("error creating ziti context: %v", err)
logrus.Fatalf("could not authenticate. verify your identity is correct and matches all necessary authentication conditions.")
}

flags.OIDCFlags(cmd)
return cmd
}

func EnableMFA(flags *SshFlags) {
ctx := Auth(flags)

if deet, err := ctx.EnrollZitiMfa(); err != nil {
logrus.Fatalf("error enrolling ziti context: %v", err)
} else {
Expand All @@ -76,14 +67,10 @@ func EnableMFA(flags *SshFlags, oidcToken string) {
fmt.Println()
fmt.Println(" MFA TOTP Secret: ", secret)
fmt.Println()
reader := bufio.NewReader(os.Stdin)
code := ""
for code == "" {
fmt.Print("Enter MFA: ")
code, _ = reader.ReadString('\n')
code = strings.TrimSpace(code)
fmt.Println("You entered: " + code + " - verifying")
}

code := ReadCode()
fmt.Println("You entered: " + code + " - attempting to verify MFA TOTP")

if err := ctx.VerifyZitiMfa(code); err != nil {
logrus.Fatalf("error verifying ziti context: %v", err)
}
Expand All @@ -109,3 +96,12 @@ func EnableMFA(flags *SshFlags, oidcToken string) {
fmt.Println("└────────┴────────┴────────┴────────┴────────┘")
}
}

func RemoveMfa(flags *SshFlags) {
ctx := Auth(flags)
code := ReadCode()
fmt.Println("You entered: " + code + " - attempting to remove MFA TOTP")
if err := ctx.RemoveZitiMfa(code); err != nil {
logrus.Fatalf("error removing MFA TOTP: %v", err)
}
}
29 changes: 3 additions & 26 deletions zsshlib/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
package zsshlib

import (
"bufio"
"context"
"fmt"
"github.com/google/uuid"
"github.com/gorilla/securecookie"
"github.com/openziti/edge-api/rest_model"
"github.com/zitadel/oidc/v2/pkg/client/rp/cli"
"github.com/zitadel/oidc/v2/pkg/oidc"
"io"
Expand Down Expand Up @@ -363,30 +361,9 @@ func RetrieveRemoteFiles(client *sftp.Client, localPath string, remotePath strin
return nil
}

func EstablishClient(f SshFlags, target, targetIdentity, oidcToken string) *ssh.Client {
conf := getConfig(f.ZConfig)
ctx, err := ziti.NewContext(conf)
conf.Credentials.AddJWT(oidcToken)
if err != nil {
logrus.Fatalf("error creating ziti context: %v", err)
}

ctx.Events().AddMfaTotpCodeListener(func(c ziti.Context, detail *rest_model.AuthQueryDetail, response ziti.MfaCodeResponse) {
reader := bufio.NewReader(os.Stdin)
codeok := false
for !codeok {
fmt.Print("Enter MFA: ")
code, _ := reader.ReadString('\n')
code = strings.TrimSpace(code)
fmt.Println("You entered:" + code + " - verifying")
if err := response(code); err != nil {
fmt.Println("error verifying MFA TOTP: ", err)
} else {
codeok = true
}
}
})
if err = ctx.Authenticate(); err != nil {
func EstablishClient(f *SshFlags, target string, targetIdentity string) *ssh.Client {
ctx := Auth(f)
if err := ctx.Authenticate(); err != nil {
logrus.Errorf("error creating ziti context: %v", err)
logrus.Fatalf("could not authenticate. verify your identity is correct and matches all necessary authentication conditions.")
}
Expand Down

0 comments on commit f0abfa2

Please sign in to comment.