Skip to content

Commit

Permalink
fixing internal (#15315)
Browse files Browse the repository at this point in the history
  • Loading branch information
seankane-msft authored Aug 18, 2021
1 parent ee892d5 commit ecb0e72
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 340 deletions.
306 changes: 19 additions & 287 deletions sdk/internal/recording/recording.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
package recording

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
Expand All @@ -20,7 +16,6 @@ import (
"path/filepath"
"strconv"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
Expand All @@ -40,6 +35,7 @@ type Recording struct {
src rand.Source
now *time.Time
Sanitizer *Sanitizer
Matcher *RequestMatcher
c TestContext
}

Expand Down Expand Up @@ -69,8 +65,11 @@ const (
type VariableType string

const (
Default VariableType = "default"
Secret_String VariableType = "secret_string"
// NoSanitization indicates that the recorded value should not be sanitized.
NoSanitization VariableType = "default"
// Secret_String indicates that the recorded value should be replaced with a sanitized value.
Secret_String VariableType = "secret_string"
// Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value.
Secret_Base64String VariableType = "secret_base64String"
)

Expand Down Expand Up @@ -107,17 +106,18 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) {
}

// set the recorder Matcher
recording.Matcher = defaultMatcher(c)
rec.SetMatcher(recording.matchRequest)

// wire up the sanitizer
recording.Sanitizer = DefaultSanitizer(rec)
recording.Sanitizer = defaultSanitizer(rec)

return recording, err
}

// GetRecordedVariable returns a recorded variable. If the variable is not found we return an error
// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation.
func (r *Recording) GetRecordedVariable(name string, variableType VariableType) (string, error) {
// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error.
// variableType determines how the recorded variable will be saved.
func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) {
var err error
result, ok := r.previousSessionVariables[name]
if !ok || r.Mode == Live {
Expand All @@ -132,9 +132,10 @@ func (r *Recording) GetRecordedVariable(name string, variableType VariableType)
return *result, err
}

// GetOptionalRecordedVariable returns a recorded variable with a fallback default value
// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation.
func (r *Recording) GetOptionalRecordedVariable(name string, defaultValue string, variableType VariableType) string {
// GetOptionalEnvVar returns a recorded environment variable with a fallback default value.
// default Value configures the fallback value to be returned if the environment variable is not set.
// variableType determines how the recorded variable will be saved.
func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string {
result, ok := r.previousSessionVariables[name]
if !ok || r.Mode == Live {
result = getOptionalEnv(name, defaultValue)
Expand Down Expand Up @@ -280,10 +281,10 @@ func getOptionalEnv(name string, defaultValue string) *string {
}

func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool {
isMatch := compareMethods(req, rec, r.c) &&
compareURLs(req, rec, r.c) &&
compareHeaders(req, rec, r.c) &&
compareBodies(req, rec, r.c)
isMatch := r.Matcher.compareMethods(req, rec.Method) &&
r.Matcher.compareURLs(req, rec.URL) &&
r.Matcher.compareHeaders(req, rec) &&
r.Matcher.compareBodies(req, rec.Body)

return isMatch
}
Expand Down Expand Up @@ -432,272 +433,3 @@ var modeMap = map[RecordMode]recorder.Mode{
Live: recorder.ModeDisabled,
Playback: recorder.ModeReplaying,
}

var recordMode, _ = os.LookupEnv("AZURE_RECORD_MODE")
var ModeRecording = "record"
var ModePlayback = "playback"

var baseProxyURLSecure = "localhost:5001"
var baseProxyURL = "localhost:5000"
var startURL = baseProxyURLSecure + "/record/start"
var stopURL = baseProxyURLSecure + "/record/stop"

var recordingId string
var IdHeader = "x-recording-id"
var ModeHeader = "x-recording-mode"
var UpstreamUriHeader = "x-recording-upstream-base-uri"

var tr = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
var client = http.Client{
Transport: tr,
}

type RecordingOptions struct {
MaxRetries int32
UseHTTPS bool
Host string
Scheme string
}

func defaultOptions() *RecordingOptions {
return &RecordingOptions{
MaxRetries: 0,
UseHTTPS: true,
Host: "localhost:5001",
Scheme: "https",
}
}

func (r RecordingOptions) HostScheme() string {
if r.UseHTTPS {
return "https://localhost:5001"
}
return "http://localhost:5000"
}

func getTestId(t *testing.T) string {
cwd, err := os.Getwd()
if err != nil {
t.Errorf("Could not find current working directory")
}
cwd = "./recordings/" + t.Name() + ".json"
return cwd
}

func StartRecording(t *testing.T, options *RecordingOptions) error {
if options == nil {
options = defaultOptions()
}
if recordMode == "" {
t.Log("AZURE_RECORD_MODE was not set, options are \"record\" or \"playback\". \nDefaulting to playback")
recordMode = "playback"
} else {
t.Log("AZURE_RECORD_MODE: ", recordMode)
}
testId := getTestId(t)

url := fmt.Sprintf("%v/%v/start", options.HostScheme(), recordMode)

req, err := http.NewRequest("POST", url, nil)
if err != nil {
return err
}

req.Header.Set("x-recording-file", testId)

resp, err := client.Do(req)
if err != nil {
return err
}
recordingId = resp.Header.Get(IdHeader)
return nil
}

func StopRecording(t *testing.T, options *RecordingOptions) error {
if options == nil {
options = defaultOptions()
}

url := fmt.Sprintf("%v/%v/stop", options.HostScheme(), recordMode)
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return err
}
if recordingId == "" {
return errors.New("Recording ID was never set. Did you call StartRecording?")
}
req.Header.Set("x-recording-id", recordingId)
_, err = client.Do(req)
if err != nil {
t.Errorf(err.Error())
}
return nil
}

func AddUriSanitizer(replacement, regex string, options *RecordingOptions) error {
if options == nil {
options = defaultOptions()
}
url := fmt.Sprintf("%v/Admin/AddSanitizer", options.HostScheme())
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return err
}
req.Header.Set("x-abstraction-identifier", "UriRegexSanitizer")
bodyContent := map[string]string{
"value": replacement,
"regex": regex,
}
marshalled, err := json.Marshal(bodyContent)
if err != nil {
return err
}
req.Body = ioutil.NopCloser(bytes.NewReader(marshalled))
req.ContentLength = int64(len(marshalled))
_, err = client.Do(req)
return err
}

func (o *RecordingOptions) Init() {
if o.MaxRetries != 0 {
o.MaxRetries = 0
}
if o.UseHTTPS {
o.Host = baseProxyURLSecure
o.Scheme = "https"
} else {
o.Host = baseProxyURL
o.Scheme = "http"
}
}

// type recordingPolicy struct {
// options RecordingOptions
// }

// func NewRecordingPolicy(o *RecordingOptions) azcore.Policy {
// if o == nil {
// o = &RecordingOptions{}
// }
// p := &recordingPolicy{options: *o}
// p.options.init()
// return p
// }

// func (p *recordingPolicy) Do(req *azcore.Request) (resp *azcore.Response, err error) {
// originalURLHost := req.URL.Host
// req.URL.Scheme = "https"
// req.URL.Host = p.options.host
// req.Host = p.options.host

// req.Header.Set(UpstreamUriHeader, fmt.Sprintf("%v://%v", p.options.scheme, originalURLHost))
// req.Header.Set(ModeHeader, recordMode)
// req.Header.Set(recordingIdHeader, recordingId)

// return req.Next()
// }

// This looks up an environment variable and if it is not found, returns the recordedValue
func GetEnvVariable(t *testing.T, varName string, recordedValue string) string {
val, ok := os.LookupEnv(varName)
if !ok {
t.Logf("Could not find environment variable: %v", varName)
return recordedValue
}
return val
}

func LiveOnly(t *testing.T) {
if GetRecordMode() != ModeRecording {
t.Skip("Live Test Only")
}
}

// Function for sleeping during a test for `duration` seconds. This method will only execute when
// AZURE_RECORD_MODE = "record", if a test is running in playback this will be a noop.
func Sleep(duration int) {
if GetRecordMode() == ModeRecording {
time.Sleep(time.Duration(duration) * time.Second)
}
}

func GetRecordingId() string {
return recordingId
}

func GetRecordMode() string {
return recordMode
}

func InPlayback() bool {
return GetRecordMode() == ModePlayback
}

func InRecord() bool {
return GetRecordMode() == ModeRecording
}

// type FakeCredential struct {
// accountName string
// accountKey string
// }

// func NewFakeCredential(accountName, accountKey string) *FakeCredential {
// return &FakeCredential{
// accountName: accountName,
// accountKey: accountKey,
// }
// }

// func (f *FakeCredential) AuthenticationPolicy(azcore.AuthenticationPolicyOptions) azcore.Policy {
// return azcore.PolicyFunc(func(req *azcore.Request) (*azcore.Response, error) {
// authHeader := strings.Join([]string{"Authorization ", f.accountName, ":", f.accountKey}, "")
// req.Request.Header.Set(azcore.HeaderAuthorization, authHeader)
// return req.Next()
// })
// }

func getRootCas() (*x509.CertPool, error) {
localFile, ok := os.LookupEnv("PROXY_CERT")

rootCAs, err := x509.SystemCertPool()
if err != nil {
rootCAs = x509.NewCertPool()
}

if !ok {
fmt.Println("Could not find path to proxy certificate, set the environment variable 'PROXY_CERT' to the location of your certificate")
return rootCAs, nil
}

cert, err := ioutil.ReadFile(*&localFile)
if err != nil {
fmt.Println("error opening cert file")
return nil, err
}

if ok := rootCAs.AppendCertsFromPEM(cert); !ok {
fmt.Println("No certs appended, using system certs only")
}

return rootCAs, nil
}

func GetHTTPClient() (*http.Client, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()

rootCAs, err := getRootCas()
if err != nil {
return nil, err
}

transport.TLSClientConfig.RootCAs = rootCAs
transport.TLSClientConfig.MinVersion = tls.VersionTLS12

defaultHttpClient := &http.Client{
Transport: transport,
}
return defaultHttpClient, nil
}
Loading

0 comments on commit ecb0e72

Please sign in to comment.