Skip to content

Commit

Permalink
Fix Context
Browse files Browse the repository at this point in the history
  • Loading branch information
spiegel-im-spiegel committed Jan 13, 2021
1 parent 2dc5f97 commit cc74ee5
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 26 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ import "github.com/spiegel-im-spiegel/fetch"
package main

import (
"context"
"fmt"
"io"
"net/http"
"os"

"github.com/spiegel-im-spiegel/fetch"
Expand All @@ -30,13 +32,18 @@ func main() {
fmt.Fprintln(os.Stderr, err)
return
}
resp, err := fetch.New().Get(u)
resp, err := fetch.New(
fetch.WithHTTPClient(&http.Client{}),
).Get(
u,
fetch.WithContext(context.Background()),
)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
defer resp.Close()
if _, err := io.Copy(os.Stdout, resp.Body); err != nil {
if _, err := io.Copy(os.Stdout, resp.Body()); err != nil {
fmt.Fprintln(os.Stderr, err)
}
}
Expand Down
39 changes: 20 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,20 @@ import (

// client is client class for fetching (internal).
type client struct {
ctx context.Context
client *http.Client
}

type ClientOpts func(*client)

// New function returns Client instance.
func New(opts ...ClientOpts) Client {
cli := &client{ctx: context.Background(), client: http.DefaultClient}
cli := &client{client: http.DefaultClient}
for _, opt := range opts {
opt(cli)
}
return cli
}

// WithProtocol returns function for setting context.Context.
func WithContext(ctx context.Context) ClientOpts {
return func(c *client) {
c.ctx = ctx
}
}

// WithProtocol returns function for setting http.Client.
func WithHTTPClient(cli *http.Client) ClientOpts {
return func(c *client) {
Expand All @@ -42,7 +34,7 @@ func WithHTTPClient(cli *http.Client) ClientOpts {

// Get method returns respons data from URL by GET method.
func (c *client) Get(u *url.URL, opts ...RequestOpts) (Response, error) {
req, err := c.request(http.MethodGet, u, nil, opts...)
req, err := request(http.MethodGet, u, nil, opts...)
if err != nil {
return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", u.String()))
}
Expand All @@ -55,7 +47,7 @@ func (c *client) Get(u *url.URL, opts ...RequestOpts) (Response, error) {

// Post method returns respons data from URL by POST method.
func (c *client) Post(u *url.URL, payload io.Reader, opts ...RequestOpts) (Response, error) {
req, err := c.request(http.MethodPost, u, payload, opts...)
req, err := request(http.MethodPost, u, payload, opts...)
if err != nil {
return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", u.String()))
}
Expand All @@ -66,30 +58,39 @@ func (c *client) Post(u *url.URL, payload io.Reader, opts ...RequestOpts) (Respo
return resp, nil
}

// WithProtocol returns function for setting context.Context.
func WithContext(ctx context.Context) RequestOpts {
return func(req *http.Request) *http.Request {
if ctx != nil {
req = req.WithContext(ctx)
}
return req
}
}

// WithRequestHeaderAdd returns function for adding request header in http.Request.
func WithRequestHeaderAdd(name, value string) RequestOpts {
return func(req *http.Request) {
return func(req *http.Request) *http.Request {
req.Header.Add(name, value)
return req
}
}

// WithRequestHeaderSet returns function for setting request header in http.Request.
func WithRequestHeaderSet(name, value string) RequestOpts {
return func(req *http.Request) {
return func(req *http.Request) *http.Request {
req.Header.Set(name, value)
return req
}
}

func (c *client) request(method string, u *url.URL, payload io.Reader, opts ...RequestOpts) (*http.Request, error) {
if c == nil {
c = New().(*client)
}
req, err := http.NewRequestWithContext(c.ctx, method, u.String(), payload)
func request(method string, u *url.URL, payload io.Reader, opts ...RequestOpts) (*http.Request, error) {
req, err := http.NewRequest(method, u.String(), payload)
if err != nil {
return nil, errs.Wrap(err)
}
for _, opt := range opts {
opt(req)
req = opt(req)
}
return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net/url"
)

type RequestOpts func(*http.Request)
type RequestOpts func(*http.Request) *http.Request

// Client is inteface class for HTTP client.
type Client interface {
Expand Down
6 changes: 4 additions & 2 deletions fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ func TestGet(t *testing.T) {
fmt.Printf("Info: %+v\n", err)
} else {
resp, err := fetch.New(
fetch.WithContext(context.Background()),
fetch.WithHTTPClient(&http.Client{}),
).Get(u)
).Get(
u,
fetch.WithContext(context.Background()),
)
if err != nil {
if !errors.Is(err, tc.err2) {
t.Errorf("fetch.Client.Get() is \"%v\", want \"%+v\"", err, tc.err2)
Expand Down
11 changes: 9 additions & 2 deletions sample/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package main

import (
"context"
"fmt"
"io"
"net/http"
"os"

"github.com/spiegel-im-spiegel/fetch"
Expand All @@ -16,13 +18,18 @@ func main() {
fmt.Fprintln(os.Stderr, err)
return
}
resp, err := fetch.New().Get(u)
resp, err := fetch.New(
fetch.WithHTTPClient(&http.Client{}),
).Get(
u,
fetch.WithContext(context.Background()),
)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
defer resp.Close()
if _, err := io.Copy(os.Stdout, resp.Body); err != nil {
if _, err := io.Copy(os.Stdout, resp.Body()); err != nil {
fmt.Fprintln(os.Stderr, err)
}
}

0 comments on commit cc74ee5

Please sign in to comment.