Skip to content

Commit

Permalink
Merge pull request #108 from NASA-AMMOS/download-upload-plan-tags
Browse files Browse the repository at this point in the history
Download/Upload Plan Tags
  • Loading branch information
cartermak authored Jan 26, 2024
2 parents 5a5916b + b0787f0 commit d1896bb
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 5 deletions.
71 changes: 71 additions & 0 deletions src/aerie_cli/aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def get_activity_plan_by_id(self, plan_id: int, full_args: str = None) -> Activi
simulations{
id
}
tags {
tag {
id
name
}
}
activity_directives(order_by: { start_offset: asc }) {
id
name
Expand Down Expand Up @@ -92,6 +98,12 @@ def list_all_activity_plans(self) -> List[ActivityPlanRead]:
simulations{
id
}
tags {
tag {
id
name
}
}
}
}
"""
Expand Down Expand Up @@ -147,6 +159,60 @@ def get_plan_id_by_sim_id(self, simulation_dataset_id: int) -> int:
simulation_dataset_id=simulation_dataset_id
)
return resp['simulation']['plan']['id']

def get_tag_id_by_name(self, tag_name: str):
get_tags_by_name_query = """
query GetTagByName($name: String) {
tags(where: {name: {_eq: $name}}) {
id
}
}
"""

#make default color of tag white
create_new_tag = """
mutation CreateNewTag($name: String, $color: String = "#FFFFFF") {
insert_tags_one(object: {name: $name, color: $color}) {
id
}
}
"""

resp = self.aerie_host.post_to_graphql(
get_tags_by_name_query,
name=tag_name
)

#if a tag with the specified name exists then returns the ID, else creates a new tag with this name
if len(resp) > 0:
return resp[0]["id"]
else:
new_tag_resp = self.aerie_host.post_to_graphql(
create_new_tag,
name=tag_name
)

return new_tag_resp["id"]

def add_plan_tag(self, plan_id: int, tag_name: str):
add_tag_to_plan = """
mutation AddTagToPlan($plan_id: Int, $tag_id: Int) {
insert_plan_tags(objects: {plan_id: $plan_id, tag_id: $tag_id}) {
returning {
tag_id
}
}
}
"""

#add tag to plan
resp = self.aerie_host.post_to_graphql(
add_tag_to_plan,
plan_id=plan_id,
tag_id=self.get_tag_id_by_name(tag_name)
)

return resp['returning'][0]

def create_activity_plan(
self, model_id: int, plan_to_create: ActivityPlanCreate
Expand All @@ -167,6 +233,11 @@ def create_activity_plan(
)
plan_id = plan_resp["id"]
plan_revision = plan_resp["revision"]

#add plan tags if exists from plan_to_create
for tag in plan_to_create.tags:
self.add_plan_tag(plan_id, tag["tag"]["name"])

# This loop exists to make sure all anchor IDs are updated as necessary

# Deep copy activities so we can augment and pop from the list
Expand Down
6 changes: 6 additions & 0 deletions src/aerie_cli/schemas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ class ApiActivityPlanRead(ApiActivityPlanBase):
converter=converters.optional(
lambda listOfDicts: [ApiActivityRead.from_dict(d) if isinstance(d, dict) else d for d in listOfDicts])
)
tags: Optional[List[Dict]] = field(
default = [],
converter=converters.optional(
lambda listOfDicts: [d for d in listOfDicts]
)
)


@define
Expand Down
16 changes: 15 additions & 1 deletion src/aerie_cli/schemas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class EmptyActivityPlan(ClientSerialize):
)
end_time: Arrow = field(
converter = arrow.get)
tags: Optional[List[Dict]]

def duration(self) -> timedelta:
return self.end_time - self.start_time
Expand Down Expand Up @@ -145,7 +146,12 @@ class ActivityPlanCreate(EmptyActivityPlan):
sim_id: Optional[int] = field(
default=None
)

tags: Optional[List[Dict]] = field(
default=[],
converter=converters.optional(
lambda listOfDicts: [d for d in listOfDicts]
)
)

@classmethod
def from_plan_read(cls, plan_read: "ActivityPlanRead") -> "ActivityPlanCreate":
Expand All @@ -154,6 +160,7 @@ def from_plan_read(cls, plan_read: "ActivityPlanRead") -> "ActivityPlanCreate":
start_time=plan_read.start_time,
end_time=plan_read.end_time,
activities=plan_read.activities,
tags=plan_read.tags
)

def to_api_create(self, model_id: int) -> "ApiActivityPlanCreate":
Expand All @@ -170,6 +177,12 @@ class ActivityPlanRead(EmptyActivityPlan):
id: int
model_id: int
sim_id: int
tags: Optional[List[Dict]] = field(
default = [],
converter=converters.optional(
lambda listOfDicts: [d for d in listOfDicts]
)
)
activities: Optional[List[Activity]] = field(
default = None,
converter=converters.optional(
Expand Down Expand Up @@ -220,6 +233,7 @@ def from_api_read(cls, api_plan_read: ApiActivityPlanRead) -> "ActivityPlanRead"
sim_id=api_plan_read.simulations[0]["id"],
start_time=plan_start,
end_time=plan_start + api_plan_read.duration,
tags=api_plan_read.tags,
activities= None if api_plan_read.activity_directives is None else [
Activity.from_api_read(api_activity)
for api_activity in api_plan_read.activity_directives
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"id": 1,
"model_id": 1,
"sim_id": 1,
"tags": [],
"activities": [
{
"type": "ACT_One",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"request": {
"query": "query get_plans ($plan_id: Int!) { plan_by_pk(id: $plan_id) { id model_id name start_time duration simulations{ id } activity_directives(order_by: { start_offset: asc }) { id name type start_offset arguments metadata anchor_id anchored_to_start } } }",
"query": "query get_plans ($plan_id: Int!) { plan_by_pk(id: $plan_id) { id model_id name start_time duration simulations{ id } tags { tag { id name } } activity_directives(order_by: { start_offset: asc }) { id name type start_offset arguments metadata anchor_id anchored_to_start } } }",
"variables": {
"plan_id": 1
}
Expand Down
13 changes: 11 additions & 2 deletions tests/unit_tests/files/mock_responses/list_all_activity_plans.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"request": {
"query": "\n query list_all_plans {\n plan(order_by: { id: asc }) {\n id\n model_id\n name\n start_time\n duration\n simulations{\n id\n }\n }\n }",
"query": "\n query list_all_plans {\n plan(order_by: { id: asc }) {\n id\n model_id\n name\n start_time\n duration\n simulations{\n id\n }\n tags {\n tag {\n id\n name\n }\n }\n }\n }",
"variables": {}
},
"response": [
Expand All @@ -15,7 +15,8 @@
{
"id": 1
}
]
],
"tags": []
},
{
"id": 2,
Expand All @@ -27,6 +28,14 @@
{
"id": 2
}
],
"tags": [
{
"tag": {
"id": 1,
"name": "Test"
}
}
]
}
]
Expand Down
11 changes: 10 additions & 1 deletion tests/unit_tests/test_aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def test_list_all_activity_plans():
{
"id": 1
}
]
],
"tags": []
},
{
"id": 2,
Expand All @@ -100,6 +101,14 @@ def test_list_all_activity_plans():
{
"id": 2
}
],
"tags": [
{
"tag": {
"id": 1,
"name": "Test"
}
}
]
}
]""")
Expand Down

0 comments on commit d1896bb

Please sign in to comment.