Commit
Merge pull request #2151 from FedML-AI/dimitris/fix_pending_requests_counter

Adding hash set for counting the number of pending requests per endpoint.
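
In outline, the fix replaces a single global Redis counter with one Redis hash whose fields are endpoint ids, so each endpoint's pending-request count can be bumped and read independently and atomically via HINCRBY/HGET. A minimal, illustrative sketch of that pattern follows; the key name, redis-py client setup, and helper names are assumptions made for this example, not code from the repository.

import redis

# Hypothetical hash key for this sketch; the repository defines its own constant.
PENDING_REQUESTS_COUNTER = "FEDML_PENDING_REQUESTS_COUNTER"

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

def incr_pending(end_point_id: str) -> int:
    # HINCRBY creates the field (starting at 0) if it does not exist yet.
    return r.hincrby(PENDING_REQUESTS_COUNTER, end_point_id, 1)

def decr_pending(end_point_id: str) -> int:
    # There is no HDECRBY in Redis; decrement by passing a negative amount.
    value = r.hincrby(PENDING_REQUESTS_COUNTER, end_point_id, -1)
    if value < 0:
        # Clamp to zero so an unpaired decrement can never leave a negative count.
        r.hset(PENDING_REQUESTS_COUNTER, end_point_id, 0)
        value = 0
    return value

def get_pending(end_point_id: str) -> int:
    # Unknown endpoints simply report zero pending requests.
    raw = r.hget(PENDING_REQUESTS_COUNTER, end_point_id)
    return int(raw) if raw is not None else 0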
Raphael-Jin authored Jun 10, 2024
2 parents 6b33065 + c29cf1d commit e667ded
Showing 3 changed files with 30 additions and 25 deletions.
@@ -139,7 +139,7 @@ def set_user_setting_replica_num(self, end_point_id,
             "target_queries_per_replica": target_queries_per_replica,
             "aggregation_window_size_seconds": aggregation_window_size_seconds,
             "scale_down_delay_seconds": scale_down_delay_seconds,
-            "request_timeout_sec": timeout_s
+            ServerConstants.INFERENCE_REQUEST_TIMEOUT_KEY: timeout_s
         }
         try:
             self.redis_connection.set(self.get_user_setting_replica_num_key(end_point_id), json.dumps(replica_num_dict))
@@ -974,20 +974,21 @@ def delete_endpoint_scaling_down_decision_time(self, end_point_id) -> bool:
                 self.FEDML_MODEL_ENDPOINT_SCALING_DOWN_DECISION_TIME_TAG,
                 end_point_id))
 
-    def get_pending_requests_counter(self) -> int:
-        if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
-            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
-        return int(self.redis_connection.get(self.FEDML_PENDING_REQUESTS_COUNTER))
+    def get_pending_requests_counter(self, end_point_id) -> int:
+        # If the endpoint does not exist inside the Hash collection, set its counter to 0.
+        if self.redis_connection.hexists(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id):
+            return int(self.redis_connection.hget(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id))
+        return 0
 
-    def update_pending_requests_counter(self, increase=False, decrease=False) -> int:
-        if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
-            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+    def update_pending_requests_counter(self, end_point_id, increase=False, decrease=False) -> int:
+        if not self.redis_connection.hexists(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id):
+            self.redis_connection.hset(self.FEDML_PENDING_REQUESTS_COUNTER, mapping={end_point_id: 0})
         if increase:
-            self.redis_connection.incr(self.FEDML_PENDING_REQUESTS_COUNTER)
+            self.redis_connection.hincrby(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id, 1)
         if decrease:
+            # Careful on the negative, there is no native function for hash decreases.
+            self.redis_connection.hincrby(self.FEDML_PENDING_REQUESTS_COUNTER, end_point_id, -1)
             # Making sure the counter never becomes negative!
-            if self.get_pending_requests_counter() < 0:
-                self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
-            else:
-                self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
-        return self.get_pending_requests_counter()
+            if self.get_pending_requests_counter(end_point_id) < 0:
+                self.redis_connection.hset(self.FEDML_PENDING_REQUESTS_COUNTER, mapping={end_point_id: 0})
+        return self.get_pending_requests_counter(end_point_id)
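
Taken together, the rewritten cache methods are meant to be called with the endpoint id on every code path. The inference service below uses them roughly as in this sequence (the calls are copied from the diff; the surrounding sequence is only illustrative):

# A request arrives for an endpoint: count it as pending.
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, increase=True)

# ... handle the request ...

# A response (or error) is produced: release the slot again.
FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)

# Endpoints never seen before simply report zero pending requests.
pending = FEDML_MODEL_CACHE.get_pending_requests_counter(end_point_id)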
@@ -55,18 +55,18 @@ async def auth_middleware(request: Request, call_next):
                 {"error": True, "message": "Invalid JSON."},
                 status_code=status.HTTP_400_BAD_REQUEST)
 
-        # Get total pending requests.
-        pending_requests_num = FEDML_MODEL_CACHE.get_pending_requests_counter()
+        # Get endpoint's total pending requests.
+        end_point_id = request_json.get("end_point_id", None)
+        pending_requests_num = FEDML_MODEL_CACHE.get_pending_requests_counter(end_point_id)
         if pending_requests_num:
-            end_point_id = request_json.get("end_point_id", None)
             # Fetch metrics of the past k=3 requests.
             pask_k_metrics = FEDML_MODEL_CACHE.get_endpoint_metrics(
                 end_point_id=end_point_id,
                 k_recent=3)
 
             # Get the request timeout from the endpoint settings.
             request_timeout_s = FEDML_MODEL_CACHE.get_endpoint_settings(end_point_id) \
-                .get("request_timeout_s", ClientConstants.INFERENCE_REQUEST_TIMEOUT)
+                .get(ServerConstants.INFERENCE_REQUEST_TIMEOUT_KEY, ServerConstants.INFERENCE_REQUEST_TIMEOUT_DEFAULT)
 
             # Only proceed if the past k metrics collection is not empty.
             if pask_k_metrics:
@@ -76,7 +76,8 @@ async def auth_middleware(request: Request, call_next):
                 mean_latency = sum(past_k_latencies_sec) / len(past_k_latencies_sec)
 
                 # If timeout threshold is exceeded then cancel and return time out error.
-                if (mean_latency * pending_requests_num) > request_timeout_s:
+                should_block = (mean_latency * pending_requests_num) > request_timeout_s
+                if should_block:
                     return JSONResponse(
                         {"error": True, "message": "Request timed out."},
                         status_code=status.HTTP_504_GATEWAY_TIMEOUT)
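
The gate above estimates the wait a newly queued request would see: the mean latency of the past k=3 finished requests multiplied by the number of requests already pending for that endpoint, compared against the endpoint's configured timeout (30 s by default per the new ServerConstants value). A small numeric sketch with made-up figures:

# Illustrative numbers only.
mean_latency = 2.5            # seconds, averaged over the last 3 requests
pending_requests_num = 15     # requests already pending for this endpoint
request_timeout_s = 30        # ServerConstants.INFERENCE_REQUEST_TIMEOUT_DEFAULT

should_block = (mean_latency * pending_requests_num) > request_timeout_s
# 2.5 * 15 = 37.5 > 30, so the middleware returns HTTP 504 instead of queueing the request.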
@@ -173,7 +174,7 @@ async def _predict(
         header=None
 ) -> Union[MutableMapping[str, Any], Response, StreamingResponse]:
     # Always increase the pending requests counter on a new incoming request.
-    FEDML_MODEL_CACHE.update_pending_requests_counter(increase=True)
+    FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, increase=True)
     inference_response = {}
 
     try:
@@ -205,14 +206,14 @@ async def _predict(
         if not is_endpoint_activated(in_end_point_id):
             inference_response = {"error": True, "message": "endpoint is not activated."}
             logging_inference_request(input_json, inference_response)
-            FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+            FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
             return inference_response
 
         # Found idle inference device
         idle_device, end_point_id, model_id, model_name, model_version, inference_host, inference_output_url = \
             found_idle_inference_device(in_end_point_id, in_end_point_name, in_model_name, in_model_version)
         if idle_device is None or idle_device == "":
-            FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+            FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
             return {"error": True, "error_code": status.HTTP_404_NOT_FOUND,
                     "message": "can not found active inference worker for this endpoint."}
 
@@ -252,18 +253,18 @@ async def _predict(
                 pass
 
             logging_inference_request(input_json, inference_response)
-            FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+            FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
             return inference_response
         else:
             inference_response = {"error": True, "message": "token is not valid."}
             logging_inference_request(input_json, inference_response)
-            FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+            FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
             return inference_response
 
     except Exception as e:
         logging.error("Inference Exception: {}".format(traceback.format_exc()))
         # Need to reduce the pending requests counter in whatever exception that may be raised.
-        FEDML_MODEL_CACHE.update_pending_requests_counter(decrease=True)
+        FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)
 
 
 def retrieve_info_by_endpoint_id(end_point_id, in_end_point_name=None, in_model_name=None,
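
Every exit path of _predict above pairs the entry-time increase with a decrease, including the generic exception handler. A hypothetical try/finally wrapper expresses the same invariant more compactly; it is offered only as an illustration of the pattern, not as code from this PR:

def counted_predict(end_point_id, handler, *args, **kwargs):
    # Count the request as pending for the whole time the handler runs.
    FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, increase=True)
    try:
        return handler(*args, **kwargs)
    finally:
        # Runs on normal return and on any exception, keeping increase/decrease paired.
        FEDML_MODEL_CACHE.update_pending_requests_counter(end_point_id, decrease=True)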
@@ -104,6 +104,9 @@ class ServerConstants(object):
     AUTO_DETECT_PUBLIC_IP = "auto_detect_public_ip"
     MODEL_INFERENCE_DEFAULT_PORT = 2203
     MODEL_CACHE_KEY_EXPIRE_TIME = 1 * 10
+
+    INFERENCE_REQUEST_TIMEOUT_KEY = "request_timeout_sec"
+    INFERENCE_REQUEST_TIMEOUT_DEFAULT = 30
     # -----End-----
 
     MODEL_DEPLOYMENT_STAGE1 = {"index": 1, "text": "ReceivedRequest"}
