diff --git a/attachments.go b/attachments.go index ffb631a73..3a8da80a2 100644 --- a/attachments.go +++ b/attachments.go @@ -113,23 +113,10 @@ func FetchAndStoreAttachment(ctx context.Context, b Backend, channel Channel, at extension = extension[1:] } - // first try getting our mime type from the first 300 bytes of our body - fileType, _ := filetype.Match(trace.ResponseBody[:300]) - if fileType != filetype.Unknown { - mimeType = fileType.MIME.Value - extension = fileType.Extension - } else { - // if that didn't work, try from our extension - fileType = filetype.GetType(extension) - if fileType != filetype.Unknown { - mimeType = fileType.MIME.Value - extension = fileType.Extension - } - } - - // we still don't know our mime type, use our content header instead - if mimeType == "" { - mimeType, _, _ = mime.ParseMediaType(trace.Response.Header.Get("Content-Type")) + // prioritize to use the response content type header if provided + contentTypeHeader := trace.Response.Header.Get("Content-Type") + if contentTypeHeader != "" { + mimeType, _, _ = mime.ParseMediaType(contentTypeHeader) if extension == "" { extensions, err := mime.ExtensionsByType(mimeType) if extensions == nil || err != nil { @@ -138,6 +125,21 @@ func FetchAndStoreAttachment(ctx context.Context, b Backend, channel Channel, at extension = extensions[0][1:] } } + } else { + + // first try getting our mime type from the first 300 bytes of our body + fileType, _ := filetype.Match(trace.ResponseBody[:300]) + if fileType != filetype.Unknown { + mimeType = fileType.MIME.Value + extension = fileType.Extension + } else { + // if that didn't work, try from our extension + fileType = filetype.GetType(extension) + if fileType != filetype.Unknown { + mimeType = fileType.MIME.Value + extension = fileType.Extension + } + } } storageURL, err := b.SaveAttachment(ctx, channel, mimeType, trace.ResponseBody, extension) diff --git a/attachments_test.go b/attachments_test.go index 6f2f9f5fa..e4f3f4e8d 100644 --- a/attachments_test.go +++ b/attachments_test.go @@ -20,6 +20,9 @@ func TestFetchAndStoreAttachment(t *testing.T) { "http://mock.com/media/hello.jpg": { httpx.NewMockResponse(200, nil, testJPG), }, + "http://mock.com/media/hello2": { + httpx.NewMockResponse(200, map[string]string{"Content-Type": "image/jpeg"}, testJPG), + }, "http://mock.com/media/hello.mp3": { httpx.NewMockResponse(502, nil, []byte(`My gateways!`)), }, @@ -53,15 +56,26 @@ func TestFetchAndStoreAttachment(t *testing.T) { assert.Len(t, clog.HTTPLogs(), 1) assert.Equal(t, "http://mock.com/media/hello.jpg", clog.HTTPLogs()[0].URL) + att, err = courier.FetchAndStoreAttachment(ctx, mb, mockChannel, "http://mock.com/media/hello2", clog) + assert.NoError(t, err) + assert.Equal(t, "image/jpeg", att.ContentType) + assert.Equal(t, "https://backend.com/attachments/547deaf7-7620-4434-95b3-58675999c4b7.jpe", att.URL) + assert.Equal(t, 17301, att.Size) + + assert.Len(t, mb.SavedAttachments(), 2) + assert.Equal(t, &test.SavedAttachment{Channel: mockChannel, ContentType: "image/jpeg", Data: testJPG, Extension: "jpg"}, mb.SavedAttachments()[0]) + assert.Len(t, clog.HTTPLogs(), 2) + assert.Equal(t, "http://mock.com/media/hello2", clog.HTTPLogs()[1].URL) + // a non-200 response should return an unavailable attachment att, err = courier.FetchAndStoreAttachment(ctx, mb, mockChannel, "http://mock.com/media/hello.mp3", clog) assert.NoError(t, err) assert.Equal(t, &courier.Attachment{ContentType: "unavailable", URL: "http://mock.com/media/hello.mp3"}, att) // should have a logged HTTP request but no attachments will have been saved to storage - assert.Len(t, clog.HTTPLogs(), 2) - assert.Equal(t, "http://mock.com/media/hello.mp3", clog.HTTPLogs()[1].URL) - assert.Len(t, mb.SavedAttachments(), 1) + assert.Len(t, clog.HTTPLogs(), 3) + assert.Equal(t, "http://mock.com/media/hello.mp3", clog.HTTPLogs()[2].URL) + assert.Len(t, mb.SavedAttachments(), 2) // same for a connection error att, err = courier.FetchAndStoreAttachment(ctx, mb, mockChannel, "http://mock.com/media/hello.pdf", clog)