Skip to content

Commit

Permalink
Merge pull request #818 from gchq/feature/download-file
Browse files Browse the repository at this point in the history
  • Loading branch information
a3957273 authored Nov 4, 2023
2 parents 0ef684d + ebe81a4 commit 18df916
Show file tree
Hide file tree
Showing 15 changed files with 299 additions and 66 deletions.
23 changes: 6 additions & 17 deletions backend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"chalk": "^5.2.0",
"config": "^3.3.9",
"connect-mongo": "^5.1.0",
"content-disposition": "^0.5.4",
"cross-fetch": "^3.1.8",
"dedent-js": "^1.0.1",
"dev-null": "^0.1.1",
Expand Down
20 changes: 19 additions & 1 deletion backend/src/clients/s3.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { S3Client } from '@aws-sdk/client-s3'
import { GetObjectCommand, GetObjectRequest, S3Client } from '@aws-sdk/client-s3'
import { Upload } from '@aws-sdk/lib-storage'

import config from '../utils/v2/config.js'
Expand Down Expand Up @@ -31,3 +31,21 @@ export async function putObjectStream(bucket: string, key: string, body: Readabl
fileSize,
}
}

export async function getObjectStream(bucket: string, key: string, range?: { start: number; end: number }) {
const client = await getS3Client()

const input: GetObjectRequest = {
Bucket: bucket,
Key: key,
}

if (range) {
input.Range = `bytes=${range.start}-${range.end}`
}

const command = new GetObjectCommand(input)
const response = await client.send(command)

return response
}
21 changes: 20 additions & 1 deletion backend/src/connectors/v2/authorisation/Base.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { AccessRequestDoc } from '../../../models/v2/AccessRequest.js'
import { FileInterfaceDoc } from '../../../models/v2/File.js'
import { ModelDoc, ModelVisibility } from '../../../models/v2/Model.js'
import { ReleaseDoc } from '../../../models/v2/Release.js'
import { SchemaDoc } from '../../../models/v2/Schema.js'
import { UserDoc } from '../../../models/v2/User.js'
import { Access } from '../../../routes/v1/registryAuth.js'
import authentication from '../authentication/index.js'

export const ModelAction = {
Expand Down Expand Up @@ -35,6 +37,17 @@ export const SchemaAction = {
}
export type SchemaActionKeys = (typeof SchemaAction)[keyof typeof SchemaAction]

export const FileAction = {
Download: 'download',
}
export type FileActionKeys = (typeof FileAction)[keyof typeof FileAction]

export const ImageAction = {
Pull: 'pull',
Push: 'push',
}
export type ImageActionKeys = (typeof ImageAction)[keyof typeof ImageAction]

export abstract class BaseAuthorisationConnector {
abstract userModelAction(user: UserDoc, model: ModelDoc, action: ModelActionKeys): Promise<boolean>
abstract userSchemaAction(user: UserDoc, Schema: SchemaDoc, action: SchemaActionKeys): Promise<boolean>
Expand All @@ -50,7 +63,13 @@ export abstract class BaseAuthorisationConnector {
accessRequest: AccessRequestDoc,
action: AccessRequestActionKeys,
): Promise<boolean>

abstract userFileAction(
user: UserDoc,
model: ModelDoc,
file: FileInterfaceDoc,
action: FileActionKeys,
): Promise<boolean>
abstract userImageAction(user: UserDoc, model: ModelDoc, access: Access, action: ImageActionKeys): Promise<boolean>
async hasModelVisibilityAccess(user: UserDoc, model: ModelDoc) {
if (model.visibility === ModelVisibility.Public) {
return true
Expand Down
75 changes: 75 additions & 0 deletions backend/src/connectors/v2/authorisation/silly.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import { AccessRequestDoc } from '../../../models/v2/AccessRequest.js'
import { FileInterfaceDoc } from '../../../models/v2/File.js'
import { ModelDoc } from '../../../models/v2/Model.js'
import { ReleaseDoc } from '../../../models/v2/Release.js'
import { SchemaDoc } from '../../../models/v2/Schema.js'
import { UserDoc } from '../../../models/v2/User.js'
import { Access } from '../../../routes/v1/registryAuth.js'
import { getAccessRequestsByModel } from '../../../services/v2/accessRequest.js'
import log from '../../../services/v2/log.js'
import { Roles } from '../authentication/Base.js'
import authentication from '../authentication/index.js'
import {
AccessRequestActionKeys,
BaseAuthorisationConnector,
FileAction,
FileActionKeys,
ImageAction,
ImageActionKeys,
ModelActionKeys,
ReleaseActionKeys,
SchemaActionKeys,
Expand Down Expand Up @@ -58,6 +66,73 @@ export class SillyAuthorisationConnector extends BaseAuthorisationConnector {
return true
}

async userFileAction(
user: UserDoc,
model: ModelDoc,
file: FileInterfaceDoc,
action: FileActionKeys,
): Promise<boolean> {
// Prohibit non-collaborators from seeing private models
if (!(await this.hasModelVisibilityAccess(user, model))) {
return false
}

const entities = await authentication.getEntities(user)
if (model.collaborators.some((collaborator) => entities.includes(collaborator.entity))) {
// Collaborators can upload or download files
return true
}

if (action !== FileAction.Download) {
log.warn({ userDn: user.dn, file: file._id }, 'Non-collaborator can only download artefacts')
return false
}

const accessRequests = await getAccessRequestsByModel(user, model.id)
const accessRequest = accessRequests.find((accessRequest) =>
accessRequest.metadata.overview.entities.some((entity) => entities.includes(entity)),
)

if (!accessRequest) {
// User does not have a valid access request
log.warn({ userDn: user.dn, file: file._id }, 'No valid access request found')
return false
}

return true
}

async userImageAction(user: UserDoc, model: ModelDoc, access: Access, action: ImageActionKeys): Promise<boolean> {
// Prohibit non-collaborators from seeing private models
if (!(await this.hasModelVisibilityAccess(user, model))) {
return false
}

const entities = await authentication.getEntities(user)
if (model.collaborators.some((collaborator) => entities.includes(collaborator.entity))) {
// Collaborators can upload or download files
return true
}

if (action !== ImageAction.Pull) {
log.warn({ userDn: user.dn, access }, 'Non-collaborator can only pull models')
return false
}

const accessRequests = await getAccessRequestsByModel(user, model.id)
const accessRequest = accessRequests.find((accessRequest) =>
accessRequest.metadata.overview.entities.some((entity) => entities.includes(entity)),
)

if (!accessRequest) {
// User does not have a valid access request
log.warn({ userDn: user.dn, access }, 'No valid access request found')
return false
}

return true
}

async userSchemaAction(user: UserDoc, _schema: SchemaDoc, _action: SchemaActionKeys) {
return authentication.hasRole(user, Roles.Admin)
}
Expand Down
2 changes: 2 additions & 0 deletions backend/src/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import { getModelAccessRequests } from './routes/v2/model/accessRequest/getModel
import { patchAccessRequest } from './routes/v2/model/accessRequest/patchAccessRequest.js'
import { postAccessRequest } from './routes/v2/model/accessRequest/postAccessRequest.js'
import { deleteFile } from './routes/v2/model/file/deleteFile.js'
import { getDownloadFile } from './routes/v2/model/file/getDownloadFile.js'
import { getFiles } from './routes/v2/model/file/getFiles.js'
import { postFinishMultipartUpload } from './routes/v2/model/file/postFinishMultipartUpload.js'
import { postSimpleUpload } from './routes/v2/model/file/postSimpleUpload.js'
Expand Down Expand Up @@ -226,6 +227,7 @@ if (config.experimental.v2) {
server.post('/api/v2/model/:modelId/access-request/:accessRequestId/review', ...postAccessRequestReviewResponse)

server.get('/api/v2/model/:modelId/files', ...getFiles)
server.get('/api/v2/model/:modelId/file/:fileId/download', ...getDownloadFile)
server.post('/api/v2/model/:modelId/files/upload/simple', ...postSimpleUpload)
server.post('/api/v2/model/:modelId/files/upload/multipart/start', ...postStartMultipartUpload)
server.post('/api/v2/model/:modelId/files/upload/multipart/finish', ...postFinishMultipartUpload)
Expand Down
32 changes: 4 additions & 28 deletions backend/src/routes/v1/registryAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import jwt from 'jsonwebtoken'
import { isEqual } from 'lodash-es'
import { stringify as uuidStringify, v4 as uuidv4 } from 'uuid'

import authentication from '../../connectors/v2/authentication/index.js'
import { ImageAction } from '../../connectors/v2/authorisation/Base.js'
import authorisation from '../../connectors/v2/authorisation/index.js'
import { ModelDoc } from '../../models/v2/Model.js'
import { UserDoc as UserDocV2 } from '../../models/v2/User.js'
import { findDeploymentByUuid } from '../../services/deployment.js'
import { getAccessRequestsByModel } from '../../services/v2/accessRequest.js'
import log from '../../services/v2/log.js'
import { getModelById } from '../../services/v2/model.js'
import { ModelId, UserDoc } from '../../types/types.js'
Expand Down Expand Up @@ -175,32 +175,8 @@ async function checkAccessV2(access: Access, user: UserDocV2) {
return false
}

const entities = await authentication.getEntities(user)
if (model.collaborators.some((collaborator) => entities.includes(collaborator.entity))) {
// They are a collaborator to the model, let them push or pull.
return true
}

if (!isEqual(access.actions, ['pull'])) {
// If users are not collaborators, they should only be able to pull
log.warn({ userDn: user.dn, access }, 'Non-collaborator can only pull models')
return false
}

// TODO: If the model is 'public access' automatically approve pulls.

const accessRequests = await getAccessRequestsByModel(user, modelId)
const accessRequest = accessRequests.find((accessRequest) =>
accessRequest.metadata.overview.entities.some((entity) => entities.includes(entity)),
)

if (!accessRequest) {
// User does not have a valid access request
log.warn({ userDn: user.dn, access }, 'No valid access request found')
return false
}

return true
const action = isEqual(access.actions, ['pull']) ? ImageAction.Pull : ImageAction.Push
return authorisation.userImageAction(user, model, access, action)
}

async function checkAccess(access: Access, user: UserDoc) {
Expand Down
57 changes: 57 additions & 0 deletions backend/src/routes/v2/model/file/getDownloadFile.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import bodyParser from 'body-parser'
import contentDisposition from 'content-disposition'
import { Request, Response } from 'express'
import stream from 'stream'
import { z } from 'zod'

import { FileInterface } from '../../../../models/v2/File.js'
import { downloadFile, getFileById } from '../../../../services/v2/file.js'
import { BadReq, InternalError } from '../../../../utils/v2/error.js'
import { parse } from '../../../../utils/validate.js'

export const getDownloadFileSchema = z.object({
params: z.object({
modelId: z.string(),
fileId: z.string(),
}),
})

interface GetDownloadFileResponse {
files: Array<FileInterface>
}

export const getDownloadFile = [
bodyParser.json(),
async (req: Request, res: Response<GetDownloadFileResponse>) => {
const {
params: { fileId },
} = parse(req, getDownloadFileSchema)

const file = await getFileById(req.user, fileId)

// required to support utf-8 file names
res.set('Content-Disposition', contentDisposition(file.name, { type: 'inline' }))
res.set('Content-Type', file.mime)
res.set('Cache-Control', 'public, max-age=604800, immutable')

if (req.headers.range) {
// TODO: support ranges
throw BadReq('Ranges are not supported', { fileId })
}

res.set('Content-Length', String(file.size))
// TODO: support ranges
// res.set('Accept-Ranges', 'bytes')

const stream = await downloadFile(req.user, fileId)

if (!stream.Body) {
throw InternalError('We were not able to retrieve the body of this file', { fileId })
}

res.writeHead(200)

// The AWS library doesn't seem to properly type 'Body' as being pipeable?
;(stream.Body as stream.Readable).pipe(res)
},
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
"description": "A description of what the model does.",
"type": "string",
"minLength": 1,
"maxLength": 5000,
"widget": "customTextInput"
"maxLength": 5000
},
"tags": {
"title": "Descriptive tags for the model.",
Expand Down
16 changes: 14 additions & 2 deletions backend/src/services/v2/file.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { putObjectStream } from '../../clients/s3.js'
import { ModelAction } from '../../connectors/v2/authorisation/Base.js'
import { getObjectStream, putObjectStream } from '../../clients/s3.js'
import { FileAction, ModelAction } from '../../connectors/v2/authorisation/Base.js'
import authorisation from '../../connectors/v2/authorisation/index.js'
import FileModel from '../../models/v2/File.js'
import { UserDoc } from '../../models/v2/User.js'
Expand Down Expand Up @@ -37,6 +37,18 @@ export async function uploadFile(user: UserDoc, modelId: string, name: string, m
return file
}

export async function downloadFile(user: UserDoc, fileId: string, range?: { start: number; end: number }) {
const file = await getFileById(user, fileId)
const model = await getModelById(user, file.modelId)

const access = await authorisation.userFileAction(user, model, file, FileAction.Download)
if (!access) {
throw Forbidden(`You do not have permission to download this model.`, { user: user.dn, fileId })
}

return getObjectStream(file.bucket, file.path, range)
}

export async function getFileById(user: UserDoc, fileId: string) {
const file = await FileModel.findOne({
_id: fileId,
Expand Down
4 changes: 4 additions & 0 deletions backend/src/utils/v2/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ export function Forbidden(message: string, context?: BailoError['context'], logg
export function NotFound(message: string, context?: BailoError['context'], logger?: Logger) {
return GenericError(404, message, context, logger)
}

export function InternalError(message: string, context?: BailoError['context'], logger?: Logger) {
return GenericError(500, message, context, logger)
}
Loading

0 comments on commit 18df916

Please sign in to comment.