diff --git a/bin/cml/runner/launch.js b/bin/cml/runner/launch.js index 20c0b9dd0..b548680fb 100755 --- a/bin/cml/runner/launch.js +++ b/bin/cml/runner/launch.js @@ -135,6 +135,7 @@ const runCloud = async (opts) => { cloudStartupScript: startupScript, cloudAwsSecurityGroup: awsSecurityGroup, cloudAwsSubnet: awsSubnet, + cloudKubernetesNodeSelector: kubernetesNodeSelector, cloudImage: image, workdir } = opts; @@ -172,6 +173,7 @@ const runCloud = async (opts) => { startupScript, awsSecurityGroup, awsSubnet, + kubernetesNodeSelector, image, dockerVolumes }); @@ -216,7 +218,8 @@ const runCloud = async (opts) => { single: attributes.single, spot: attributes.spot, spotPrice: attributes.spot_price, - timeouts: attributes.timeouts + timeouts: attributes.timeouts, + kubernetesNodeSelector: attributes.kubernetes_node_selector }; winston.info(JSON.stringify(nonSensitiveValues)); } @@ -588,6 +591,17 @@ exports.options = kebabcaseKeys({ description: 'Specifies the subnet to use within AWS', alias: 'cloud-aws-subnet-id' }, + cloudKubernetesNodeSelector: { + type: 'array', + string: true, + default: [], + coerce: (items) => { + const keyValuePairs = items.map((item) => [...item.split(/=(.+)/), null]); + return Object.fromEntries(keyValuePairs); + }, + description: + 'Key Value pairs to specify the node selector to use within Kubernetes i.e. tags/labels "key=value". If not provided a default "accelerator=infer" key pair will be used' + }, cloudImage: { type: 'string', description: 'Custom machine/container image', diff --git a/src/terraform.js b/src/terraform.js index 4b46c6212..344438f4d 100644 --- a/src/terraform.js +++ b/src/terraform.js @@ -115,7 +115,10 @@ const iterativeCmlRunnerTpl = (opts = {}) => { ...(opts.sshPrivate && { ssh_private: opts.sshPrivate }), ...(opts.startupScript && { startup_script: opts.startupScript }), ...(opts.token && { token: opts.token }), - ...(opts.type && { instance_type: opts.type }) + ...(opts.type && { instance_type: opts.type }), + ...(opts.kubernetesNodeSelector && { + kubernetes_node_selector: opts.kubernetesNodeSelector + }) } } } diff --git a/src/terraform.test.js b/src/terraform.test.js index e61a72ff2..c165cdd04 100644 --- a/src/terraform.test.js +++ b/src/terraform.test.js @@ -209,6 +209,71 @@ describe('Terraform tests', () => { `); }); + test('basic settings with kubernetes node selector', async () => { + const output = iterativeCmlRunnerTpl({ + repo: 'https://', + token: 'abc', + driver: 'gitlab', + labels: 'mylabel', + idleTimeout: 300, + name: 'myrunner', + single: true, + cloud: 'aws', + region: 'west', + type: 'mymachinetype', + gpu: 'mygputype', + hddSize: 50, + sshPrivate: 'myprivate', + spot: true, + spotPrice: '0.0001', + kubernetesNodeSelector: { + accelerator: 'infer', + ram: null, + 'disk type': 'hard drives' + } + }); + expect(JSON.stringify(output, null, 2)).toMatchInlineSnapshot(` + "{ + \\"terraform\\": { + \\"required_providers\\": { + \\"iterative\\": { + \\"source\\": \\"iterative/iterative\\" + } + } + }, + \\"provider\\": { + \\"iterative\\": {} + }, + \\"resource\\": { + \\"iterative_cml_runner\\": { + \\"runner\\": { + \\"cloud\\": \\"aws\\", + \\"driver\\": \\"gitlab\\", + \\"instance_gpu\\": \\"mygputype\\", + \\"instance_hdd_size\\": 50, + \\"idle_timeout\\": 300, + \\"labels\\": \\"mylabel\\", + \\"name\\": \\"myrunner\\", + \\"region\\": \\"west\\", + \\"repo\\": \\"https://\\", + \\"single\\": true, + \\"spot\\": true, + \\"spot_price\\": \\"0.0001\\", + \\"ssh_private\\": \\"myprivate\\", + \\"token\\": \\"abc\\", + \\"instance_type\\": \\"mymachinetype\\", + \\"kubernetes_node_selector\\": { + \\"accelerator\\": \\"infer\\", + \\"ram\\": null, + \\"disk type\\": \\"hard drives\\" + } + } + } + } + }" + `); + }); + test('basic settings with docker volumes', async () => { const output = iterativeCmlRunnerTpl({ repo: 'https://',