Skip to content

Commit

Permalink
fix fleet api
Browse files Browse the repository at this point in the history
  • Loading branch information
gurevichdmitry committed Nov 26, 2024
1 parent ceb0048 commit d6fedb3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 8 deletions.
8 changes: 6 additions & 2 deletions tests/fleet_api/agent_policy_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

from typing import Optional

from fleet_api.base_call_api import APICallException, perform_api_call
from fleet_api.base_call_api import (
APICallException,
perform_api_call,
uses_new_fleet_api_response,
)
from loguru import logger
from munch import Munch, munchify

Expand Down Expand Up @@ -152,7 +156,7 @@ def get_agents(cfg: Munch) -> list:
url=url,
auth=cfg.auth,
)
if cfg.stack_version.startswith("9."):
if uses_new_fleet_api_response(cfg.stack_version):
return munchify(response.get("items", []))
return munchify(response.get("list", []))
except APICallException as api_ex:
Expand Down
13 changes: 13 additions & 0 deletions tests/fleet_api/base_call_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,16 @@ def perform_api_call(method, url, return_json=True, headers=None, auth=None, par
if not return_json:
return response.content
return response.json()


def uses_new_fleet_api_response(version: str) -> bool:
"""
Determine if the specified version uses the new Fleet API response format.
Args:
version (str): Elastic stack version.
Returns:
bool: True if the version uses the new Fleet API response format, False otherwise.
"""
return version.startswith("9.") or version.startswith("8.17")
12 changes: 8 additions & 4 deletions tests/fleet_api/common_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import time
from typing import Any, Dict, List

from fleet_api.base_call_api import APICallException, perform_api_call
from fleet_api.base_call_api import (
APICallException,
perform_api_call,
uses_new_fleet_api_response,
)
from fleet_api.utils import add_capabilities, add_tags, replace_image_field
from loguru import logger
from munch import Munch, munchify
Expand Down Expand Up @@ -39,7 +43,7 @@ def get_enrollment_token(cfg: Munch, policy_id: str) -> str:
auth=cfg.auth,
)
api_keys = munchify(response.get("list", []))
if cfg.stack_version.startswith("9."):
if uses_new_fleet_api_response(cfg.stack_version):
api_keys = munchify(response.get("items", []))
api_key = ""
for item in api_keys:
Expand Down Expand Up @@ -320,7 +324,7 @@ def get_package_version(

cloud_security_posture_version = None
packages = response.get("response", [])
if cfg.stack_version.startswith("9."):
if uses_new_fleet_api_response(cfg.stack_version):
packages = response.get("items", [])
for package in packages:
if package.get("name", "") == package_name:
Expand Down Expand Up @@ -370,7 +374,7 @@ def get_package(
params={"params": request_params},
)
package_data = response.get("response", {})
if cfg.stack_version.startswith("9."):
if uses_new_fleet_api_response(cfg.stack_version):
package_data = response.get("item", {})
return package_data
except APICallException as api_ex:
Expand Down
8 changes: 6 additions & 2 deletions tests/fleet_api/package_policy_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
This module contains API calls related to the package policy API.
"""

from fleet_api.base_call_api import APICallException, perform_api_call
from fleet_api.base_call_api import (
APICallException,
perform_api_call,
uses_new_fleet_api_response,
)
from fleet_api.utils import delete_key, update_key
from loguru import logger
from munch import Munch, munchify
Expand Down Expand Up @@ -129,7 +133,7 @@ def create_integration(cfg: Munch, pkg_policy: dict, agent_policy_id: str, data:
params={"json": pkg_policy},
)
policy_data = response.get("response", {}).get("item", {})
if cfg.stack_version.startswith("9."):
if uses_new_fleet_api_response(cfg.stack_version):
policy_data = response.get("item", {})
package_policy_id = policy_data.get("id", "")
logger.info(f"Package policy '{package_policy_id}' created successfully")
Expand Down

0 comments on commit d6fedb3

Please sign in to comment.