Skip to content

Commit

Permalink
feat(inventory): support variables mapping (#10)
Browse files Browse the repository at this point in the history
* feat(inventory): support variables mapping

* fix: variables description
  • Loading branch information
quantumsheep authored Jun 6, 2023
1 parent 12255b6 commit 28052c1
Showing 1 changed file with 58 additions and 8 deletions.
66 changes: 58 additions & 8 deletions plugins/inventory/scaleway.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@
- public_ipv6
- hostname
- id
variables:
description:
- "Variables mapping to apply to hosts in the format destination_variable: host_variable."
- "You can use the following host variables:"
- " - C(id): The server id."
- " - C(tags): The server tags."
- " - C(zone): The server zone."
- " - C(state): The server state."
- " - C(hostname): The server hostname."
- " - C(public_ipv4): The server public ipv4."
- " - C(private_ipv4): The server private ipv4."
- " - C(public_ipv6): The server public ipv6."
- " - C(public_dns): The server public dns."
- " - C(private_dns): The server private dns."
- ""
- "If the variable is not found, the host will be ignored."
type: dict
"""

EXAMPLES = r"""
Expand All @@ -68,14 +85,17 @@
secret_key: <your secret key>
api_url: https://api.scaleway.com
regions:
- fr-par-2
- nl-ams-1
- fr-par-2
- nl-ams-1
tags:
- dev
- dev
variables:
ansible_host: public_ipv4
"""


from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, Optional

from ansible.errors import AnsibleError
Expand Down Expand Up @@ -115,14 +135,19 @@ class _Filters:
@dataclass
class _Host:
id: str
hostname: str
tags: List[str]
zone: "Zone"
state: str

hostname: str
public_ipv4: Optional[str]
private_ipv4: Optional[str]
public_ipv6: Optional[str]

# Instances-only
public_dns: Optional[str] = None
private_dns: Optional[str] = None


class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
NAME = "scaleway.scaleway.scaleway"
Expand Down Expand Up @@ -174,6 +199,7 @@ def parse(self, inventory, loader, path: str, cache=True):

def populate(self, results: List[_Host]):
hostnames = self.get_option("hostnames")
variables = self.get_option("variables") or {}

for result in results:
groups = self.get_host_groups(result)
Expand All @@ -184,9 +210,28 @@ def populate(self, results: List[_Host]):
self.display.warning(f"Skipping host {result.id}: {e}")
continue

self.inventory.add_host(host=hostname)

should_skip = False
for variable, source in variables.items():
value = getattr(result, source, None)
if not value:
self.display.warning(
f"Skipping host {result.id}: "
f"Field {source} is not available."
)
self.inventory.remove_host(SimpleNamespace(name=hostname))
should_skip = True
break

self.inventory.set_variable(hostname, variable, value)

if should_skip:
continue

for group in groups:
self.inventory.add_group(group=group)
self.inventory.add_host(group=group, host=hostname)
self.inventory.add_child(group=group, child=hostname)

def get_host_groups(self, host: _Host):
return set(host.tags).union(set([host.zone.replace("-", "_")]))
Expand Down Expand Up @@ -218,12 +263,15 @@ def _get_instances(self, client: "Client", filters: _Filters) -> List[_Host]:
for server in servers:
host = _Host(
id=server.id,
hostname=server.hostname,
tags=["instance", *server.tags],
zone=server.zone,
state=str(server.state),
hostname=server.hostname,
public_ipv4=server.public_ip.address if server.public_ip else None,
private_ipv4=server.private_ip,
public_ipv6=server.ipv6.address if server.ipv6 else None,
public_dns=f"{server.id}.pub.instances.scw.cloud",
private_dns=f"{server.id}.priv.instances.scw.cloud",
)

results.append(host)
Expand Down Expand Up @@ -256,9 +304,10 @@ def _get_elastic_metal(self, client: "Client", filters: _Filters) -> List[_Host]

host = _Host(
id=server.id,
hostname=server.name,
tags=["elastic_metal", *server.tags],
zone=server.zone,
state=str(server.status),
hostname=server.name,
public_ipv4=public_ipv4.address if public_ipv4 else None,
private_ipv4=None,
public_ipv6=public_ipv6.address if public_ipv6 else None,
Expand Down Expand Up @@ -287,9 +336,10 @@ def _get_apple_sillicon(self, client: "Client", filters: _Filters) -> List[_Host
for server in servers:
host = _Host(
id=server.id,
hostname=server.name,
tags=["apple_sillicon"],
zone=server.zone,
state=str(server.status),
hostname=server.name,
public_ipv4=server.ip,
private_ipv4=None,
public_ipv6=None,
Expand Down

0 comments on commit 28052c1

Please sign in to comment.