Skip to content

Commit

Permalink
Merge branch 'master' into fix
Browse files Browse the repository at this point in the history
  • Loading branch information
shifucun authored Feb 23, 2023
2 parents c2ce2fc + c1c2425 commit 2e51d11
Show file tree
Hide file tree
Showing 17 changed files with 158 additions and 58 deletions.
2 changes: 1 addition & 1 deletion deploy/overlays/climatetrace/custom_bigtable_info.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
project: datcom-mixer-encode
instance: dc-graph
tables:
-
- climatetrace_2023_02_22_18_50_22
1 change: 1 addition & 0 deletions deploy/overlays/climatetrace/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ configMapGenerator:
literals:
- flaskEnv=custom
- secretProject=datcom-mixer-encode
- enableModel=true
name: website-configmap
- behavior: create
literals:
Expand Down
4 changes: 3 additions & 1 deletion server/lib/nl/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,4 +435,6 @@
}

QUERY_OK = 'ok'
QUERY_FAILED = 'failed'
QUERY_FAILED = 'failed'

TEST_SESSION_ID = '007_999999999'
7 changes: 4 additions & 3 deletions server/lib/nl/fulfiller.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
# Compute a new Utterance given the classifications for a user query
# and past utterances.
#
def fulfill(query_detection: Detection,
currentUtterance: Utterance) -> Utterance:
def fulfill(query_detection: Detection, currentUtterance: Utterance,
session_id: str) -> Utterance:

filtered_svs = filter_svs(query_detection.svs_detected.sv_dcids,
query_detection.svs_detected.sv_scores)
Expand All @@ -48,7 +48,8 @@ def fulfill(query_detection: Detection,
svs=filtered_svs,
chartCandidates=[],
rankedCharts=[],
answerPlaces=[])
answerPlaces=[],
session_id=session_id)
uttr.counters['filtered_svs'] = filtered_svs

# Add detected places.
Expand Down
23 changes: 22 additions & 1 deletion server/lib/nl/fulfillment/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Dict, List

from server.lib.nl import constants
from server.lib.nl.detection import ClassificationAttributes
from server.lib.nl.detection import ClassificationType
from server.lib.nl.detection import Place
Expand Down Expand Up @@ -132,3 +133,23 @@ def classifications_of_type_from_utterance(
uttr: Utterance,
ctype: ClassificationType) -> List[ClassificationAttributes]:
return [cl for cl in uttr.classifications if cl.type == ctype]


# `context_history` contains utterances in a given session.
def get_session_info(context_history: List[Dict]) -> Dict:
  """Summarize the utterances of one session for query logging.

  Args:
    context_history: Serialized utterance dicts of a session, most recent
      first. Each entry is read for its 'query' key and, when present, its
      'ranked_charts' and 'session_id' keys.

  Returns:
    A dict with an 'items' list holding one {'query', 'status'} entry per
    utterance, ordered oldest first. 'status' is constants.QUERY_OK when the
    utterance has ranked charts and constants.QUERY_FAILED otherwise. An 'id'
    key is set from the oldest utterance carrying a non-empty 'session_id';
    it is absent when the history is empty or no utterance has one.
  """
  session_info = {'items': []}
  # The first entry in context_history is the most recent, so walk it
  # backwards to emit items oldest first.
  for uttr in reversed(context_history):
    # Utterance dicts serialized without a session ID must not break logging:
    # leave 'id' unset so the consumer can detect the missing ID.
    session_id = uttr.get('session_id')
    if session_id and 'id' not in session_info:
      session_info['id'] = session_id
    if uttr.get('ranked_charts'):
      status = constants.QUERY_OK
    else:
      status = constants.QUERY_FAILED
    session_info['items'].append({
        'query': uttr['query'],
        'status': status,
    })
  return session_info
10 changes: 10 additions & 0 deletions server/lib/nl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import logging
import os
import random
import re
from typing import Dict, List, Set, Union

Expand Down Expand Up @@ -798,3 +799,12 @@ def has_map(place_type: any) -> bool:
if isinstance(place_type, str):
place_type = detection.ContainedInPlaceType(place_type)
return place_type in constants.MAP_PLACE_TYPES


def new_session_id() -> str:
  """Return a new session ID of the form '<rand>_<micros>'.

  '<rand>' is a random integer in [0, 1000) and '<micros>' is the current
  timestamp in whole microseconds.
  """
  # Current time in seconds, scaled to whole microseconds.
  now_seconds = datetime.datetime.now().timestamp()
  micros = int(now_seconds * 1000000)
  # A small random component makes clashes between sessions less likely.
  prefix = random.randrange(1000)
  # The random part goes first because the session_id gets used as a BT key.
  return '{}_{}'.format(prefix, micros)
9 changes: 7 additions & 2 deletions server/lib/nl/utterance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
from typing import Dict, List

from server.lib.nl import constants
from server.lib.nl.detection import ClassificationType
from server.lib.nl.detection import ContainedInClassificationAttributes
from server.lib.nl.detection import ContainedInPlaceType
Expand All @@ -37,7 +38,7 @@
from server.lib.nl.detection import TimeDeltaType

# How far back the context goes.
CTX_LOOKBACK_LIMIT = 8
CTX_LOOKBACK_LIMIT = 15


# Forward declaration since Utterance contains a pointer to itself.
Expand Down Expand Up @@ -123,6 +124,8 @@ class Utterance:
answerPlaces: List[str]
# Linked list of past utterances
prev_utterance: Utterance
# A unique ID to identify sessions
session_id: str
# Debug counters that are cleared out before serializing.
# Some of these might be promoted to the main Debug Info display,
# but everything else will appear in the raw output.
Expand Down Expand Up @@ -246,6 +249,7 @@ def save_utterance(uttr: Utterance) -> List[Dict]:
udict['places'] = _place_to_dict(u.places)
udict['classifications'] = _classification_to_dict(u.classifications)
udict['ranked_charts'] = _chart_spec_to_dict(u.rankedCharts)
udict['session_id'] = u.session_id
uttr_dicts.append(udict)
u = u.prev_utterance
cnt += 1
Expand Down Expand Up @@ -273,6 +277,7 @@ def load_utterance(uttr_dicts: List[Dict]) -> Utterance:
rankedCharts=_dict_to_chart_spec(udict['ranked_charts']),
detection=None,
chartCandidates=None,
answerPlaces=None)
answerPlaces=None,
session_id=udict['session_id'])
prev_uttr = uttr
return uttr
16 changes: 12 additions & 4 deletions server/routes/nl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from server.lib.nl.detection import SimpleClassificationAttributes
from server.lib.nl.detection import SVDetection
import server.lib.nl.fulfiller as fulfillment
import server.lib.nl.fulfillment.context as context
import server.lib.nl.page_config_builder as nl_page_config
import server.lib.nl.utils as utils
import server.lib.nl.utterance as nl_utterance
Expand Down Expand Up @@ -421,7 +422,15 @@ def data():

# Generate new utterance.
prev_utterance = nl_utterance.load_utterance(context_history)
utterance = fulfillment.fulfill(query_detection, prev_utterance)
if prev_utterance:
session_id = prev_utterance.session_id
else:
if current_app.config['LOG_QUERY']:
session_id = utils.new_session_id()
else:
session_id = constants.TEST_SESSION_ID

utterance = fulfillment.fulfill(query_detection, prev_utterance, session_id)

if utterance.rankedCharts:
page_config_pb = nl_page_config.build_page_config(utterance,
Expand Down Expand Up @@ -451,9 +460,7 @@ def data():
status_str = "Successful"
if utterance.rankedCharts:
status_str = ""
status = constants.QUERY_OK
else:
status = constants.QUERY_FAILED
if not utterance.places:
status_str += '**No Place Found**.'
if not utterance.svs:
Expand All @@ -462,7 +469,8 @@ def data():
if current_app.config['LOG_QUERY']:
# Asynchronously log as bigtable write takes O(100ms)
loop = asyncio.new_event_loop()
loop.run_until_complete(bt.write_row(original_query, status))
session_info = context.get_session_info(context_history)
loop.run_until_complete(bt.write_row(session_info))

data_dict = _result_with_debug_info(data_dict, status_str, query_detection,
context_history, dbg_counters)
Expand Down
35 changes: 18 additions & 17 deletions server/services/bigtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from datetime import datetime
from datetime import timedelta
import json

from flask import current_app
import google.auth
Expand All @@ -28,8 +29,7 @@
_COLUMN_FAMILY = 'all'

_COL_PROJECT = 'project'
_COL_QUERY = 'query'
_COL_STATUS = 'status'
_COL_SESSION = 'session_info'

_SPAN_IN_DAYS = 3

Expand All @@ -46,15 +46,16 @@ def get_project_id():
return project_id


async def write_row(query, status):
async def write_row(session_info):
if not session_info.get('id', None):
return
project_id = get_project_id()
ts = datetime.utcnow()
# use length of query as prefix to avoid Bigtable hotspot nodes.
row_key = '{}#{}#{}'.format(len(query), project_id, ts.timestamp()).encode()
# The session_id starts with a rand to avoid hotspots.
row_key = '{}#{}'.format(session_info['id'], project_id).encode()
row = table.direct_row(row_key)
row.set_cell(_COLUMN_FAMILY, _COL_PROJECT.encode(), project_id, timestamp=ts)
row.set_cell(_COLUMN_FAMILY, _COL_QUERY.encode(), query, timestamp=ts)
row.set_cell(_COLUMN_FAMILY, _COL_STATUS.encode(), status, timestamp=ts)
# Rely on timestamp in BT server
row.set_cell(_COLUMN_FAMILY, _COL_PROJECT.encode(), project_id)
row.set_cell(_COLUMN_FAMILY, _COL_SESSION.encode(), json.dumps(session_info))
table.mutate_rows([row])


Expand All @@ -67,25 +68,25 @@ def read_success_rows():
result = []
for row in rows:
project = ''
query = ''
status = ''
session_info = {}
timestamp = 0
for _, cols in row.cells.items():
for col, cells in cols.items():
if col.decode('utf-8') == _COL_PROJECT:
project = cells[0].value.decode('utf-8')
if col.decode('utf-8') == _COL_STATUS:
status = cells[0].value.decode('utf-8')
elif col.decode('utf-8') == _COL_QUERY:
query = cells[0].value.decode('utf-8')
if col.decode('utf-8') == _COL_SESSION:
session_info = json.loads(cells[0].value.decode('utf-8'))
timestamp = cells[0].timestamp.timestamp()
if project != project_id:
continue
if status == nl_constants.QUERY_FAILED:
if not session_info or not session_info.get('items', []):
continue
if session_info['items'][0]['status'] == nl_constants.QUERY_FAILED:
continue
query_list = [it['query'] for it in session_info['items'] if 'query' in it]
result.append({
'project': project,
'query': query,
'query_list': query_list,
'timestamp': timestamp,
})
result.sort(key=lambda x: x['timestamp'], reverse=True)
Expand Down
6 changes: 4 additions & 2 deletions server/tests/lib/nl/fulfiller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest
from unittest.mock import patch

from server.lib.nl import constants
from server.lib.nl import fulfiller
from server.lib.nl import utils
from server.lib.nl import utterance
Expand Down Expand Up @@ -424,7 +425,7 @@ def test_counters_simple(self, mock_sv_existence, mock_single_datapoint,
mock_sv_existence.side_effect = [['Count_Person_Male'],
['Count_Person_Female']]

got = fulfiller.fulfill(detection, None).counters
got = fulfiller.fulfill(detection, None, constants.TEST_SESSION_ID).counters

self.maxDiff = None
_COUNTERS = {
Expand Down Expand Up @@ -519,4 +520,5 @@ def _run(detection: Detection, uttr_dict: List[Dict]):
prev_uttr = None
if uttr_dict:
prev_uttr = utterance.load_utterance(uttr_dict)
return utterance.save_utterance(fulfiller.fulfill(detection, prev_uttr))[0]
return utterance.save_utterance(
fulfiller.fulfill(detection, prev_uttr, constants.TEST_SESSION_ID))[0]
Loading

0 comments on commit 2e51d11

Please sign in to comment.