Skip to content

Commit

Permalink
fix(clients): fix metadata overwrite (#26)
Browse files Browse the repository at this point in the history
Because

- metadata will be overwritten if switch `instance`

This commit

- fix metada overwrite
  • Loading branch information
heiruwu authored Sep 28, 2023
1 parent 8ba2260 commit e332cb0
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 63 deletions.
24 changes: 12 additions & 12 deletions instill/clients/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,20 @@

class ConnectorClient(Client):
def __init__(self, namespace: str) -> None:
self.hosts = defaultdict(dict)
self.hosts: defaultdict = defaultdict(dict)
self.instance = "default"
self.namespace = namespace
self.metadata: str = ""

if global_config.hosts is not None:
for instance, config in global_config.hosts.items():
if not config.secure:
self.metadata = (
channel = grpc.insecure_channel(config.url)
self.hosts[instance]["metadata"] = (
(
"authorization",
f"Bearer {config.token}",
),
)
channel = grpc.insecure_channel(config.url)
else:
ssl_creds = grpc.ssl_channel_credentials()
call_creds = grpc.access_token_call_credentials(config.token)
Expand All @@ -43,6 +42,7 @@ def __init__(self, namespace: str) -> None:
target=config.url,
credentials=creds,
)
self.hosts[instance]["metadata"] = ""
self.hosts[instance]["token"] = config.token
self.hosts[instance]["channel"] = channel
self.hosts[instance][
Expand Down Expand Up @@ -107,7 +107,7 @@ def create_connector(
request=connector_interface.CreateUserConnectorResourceRequest(
connector_resource=connector, parent=self.namespace
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

return resp.connector_resource
Expand All @@ -120,7 +120,7 @@ def get_connector(self, name: str) -> connector_interface.ConnectorResource:
request=connector_interface.GetUserConnectorResourceRequest(
name=f"{self.namespace}/connector-resources/{name}"
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
.connector_resource
)
Expand All @@ -133,7 +133,7 @@ def test_connector(self, name: str) -> connector_interface.ConnectorResource.Sta
request=connector_interface.TestUserConnectorResourceRequest(
name=f"{self.namespace}/connector-resources/{name}"
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
.state
)
Expand All @@ -146,7 +146,7 @@ def execute_connector(self, name: str, inputs: list) -> list:
request=connector_interface.ExecuteUserConnectorResourceRequest(
name=f"{self.namespace}/connector-resources/{name}", inputs=inputs
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
.outputs
)
Expand All @@ -159,7 +159,7 @@ def watch_connector(self, name: str) -> connector_interface.ConnectorResource.St
request=connector_interface.WatchUserConnectorResourceRequest(
name=f"{self.namespace}/connector-resources/{name}"
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
.state
)
Expand All @@ -170,7 +170,7 @@ def delete_connector(self, name: str):
request=connector_interface.DeleteUserConnectorResourceRequest(
name=f"{self.namespace}/connector-resources/{name}"
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

@grpc_handler
Expand All @@ -180,12 +180,12 @@ def list_connectors(self, public=False) -> Tuple[list, str, int]:
request=connector_interface.ListUserConnectorResourcesRequest(
parent=self.namespace
),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
else:
resp = self.hosts[self.instance]["client"].ListConnectorResources(
request=connector_interface.ListConnectorResourcesRequest(),
metadata=(self.metadata,),
metadata=(self.hosts[self.instance]["metadata"],),
)

return resp.connector_resources, resp.next_page_token, resp.total_size
24 changes: 12 additions & 12 deletions instill/clients/mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,19 @@

class MgmtClient(Client):
def __init__(self) -> None:
self.hosts = defaultdict(dict)
self.hosts: defaultdict = defaultdict(dict)
self.instance: str = "default"
self.metadata: str = ""

if global_config.hosts is not None:
for instance, config in global_config.hosts.items():
if not config.secure:
self.metadata = (
channel = grpc.insecure_channel(config.url)
self.hosts[instance]["metadata"] = (
(
"authorization",
f"Bearer {config.token}",
),
)
channel = grpc.insecure_channel(config.url)
else:
ssl_creds = grpc.ssl_channel_credentials()
call_creds = grpc.access_token_call_credentials(config.token)
Expand All @@ -38,6 +37,7 @@ def __init__(self) -> None:
target=config.url,
credentials=creds,
)
self.hosts[instance]["metadata"] = ""
self.hosts[instance]["token"] = config.token
self.hosts[instance]["channel"] = channel
self.hosts[instance]["client"] = mgmt_service.MgmtPublicServiceStub(
Expand Down Expand Up @@ -103,15 +103,15 @@ def login(self, username="admin", password="password") -> str:
def get_token(self, name: str) -> mgmt_interface.ApiToken:
response = self.hosts[self.instance]["client"].GetToken(
request=mgmt_interface.GetTokenRequest(name=name),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
return response.token

@grpc_handler
def get_user(self) -> mgmt_interface.User:
response = self.hosts[self.instance]["client"].QueryAuthenticatedUser(
request=mgmt_interface.QueryAuthenticatedUserRequest(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
return response.user

Expand All @@ -121,7 +121,7 @@ def list_pipeline_trigger_records(
) -> metric_interface.ListPipelineTriggerRecordsResponse:
return self.hosts[self.instance]["client"].ListPipelineTriggerRecords(
request=metric_interface.ListPipelineTriggerChartRecordsRequest(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

@grpc_handler
Expand All @@ -130,7 +130,7 @@ def list_pipeline_trigger_table_records(
) -> metric_interface.ListPipelineTriggerTableRecordsRequest:
return self.hosts[self.instance]["client"].ListPipelineTriggerRecords(
request=metric_interface.ListPipelineTriggerTableRecordsResponse(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

@grpc_handler
Expand All @@ -139,7 +139,7 @@ def list_pipeline_trigger_chart_records(
) -> metric_interface.ListPipelineTriggerChartRecordsResponse:
return self.hosts[self.instance]["client"].ListPipelineTriggerRecords(
request=metric_interface.ListPipelineTriggerChartRecordsRequest(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

@grpc_handler
Expand All @@ -148,7 +148,7 @@ def list_connector_execute_records(
) -> metric_interface.ListConnectorExecuteRecordsResponse:
return self.hosts[self.instance]["client"].ListPipelineTriggerRecords(
request=metric_interface.ListConnectorExecuteRecordsRequest(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

@grpc_handler
Expand All @@ -157,7 +157,7 @@ def list_connector_execute_table_records(
) -> metric_interface.ListConnectorExecuteTableRecordsResponse:
return self.hosts[self.instance]["client"].ListPipelineTriggerRecords(
request=metric_interface.ListConnectorExecuteTableRecordsRequest(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)

@grpc_handler
Expand All @@ -166,5 +166,5 @@ def list_connector_execute_chart_records(
) -> metric_interface.ListConnectorExecuteChartRecordsResponse:
return self.hosts[self.instance]["client"].ListPipelineTriggerRecords(
request=metric_interface.ListConnectorExecuteChartRecordsRequest(),
metadata=self.metadata,
metadata=self.hosts[self.instance]["metadata"],
)
Loading

0 comments on commit e332cb0

Please sign in to comment.