
Commit

Fix issues #289 and #280 to have year and cohorts applied to multivariate analysis endpoint (#293)

* fixed issue #289

* make multivariate analysis endpoint honor cohorts
hyi authored Oct 19, 2023
1 parent 0faa160 commit 3eac09b
Showing 1 changed file with 40 additions and 19 deletions.
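
In short, the diff factors the temporary cohort-view handling out of select_feature_matrix into reusable create_cohort_view/drop_cohort_view helpers and threads the requested year through get_feature_levels, so that compute_multivariate_table can honor both. A minimal sketch of the call pattern the new helpers enable (the wrapper function below is hypothetical and not part of this commit; the try/finally cleanup is only defensive, since the commit itself drops the view after the query completes):

# Hypothetical usage sketch -- not part of this commit. create_cohort_view,
# drop_cohort_view, and count_unique are the helpers from icees_api/features/sql.py.
def count_on_cohort(conn, table_name, year, columns, cohort_features):
    # With a non-empty cohort, queries run against a temporary view named "tmp";
    # with no cohort features, the original table name is returned unchanged.
    view_name = create_cohort_view(conn, table_name, cohort_features)
    try:
        return count_unique(conn, view_name, year, *columns)
    finally:
        # Drops the "tmp" view only if one was created.
        drop_cohort_view(conn, cohort_features)
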
59 changes: 40 additions & 19 deletions icees_api/features/sql.py
@@ -536,16 +536,7 @@ def count_unique(conn, table_name, year, *columns):
    ).fetchall()]


-def select_feature_matrix(
-        conn,
-        table_name,
-        year,
-        cohort_features,
-        cohort_year,
-        feature_a,
-        feature_b,
-):
-    """Select feature matrix."""
+def create_cohort_view(conn, table_name, cohort_features):
    if not isinstance(cohort_features, list):
        cohort_features = [
            {
@@ -576,7 +567,28 @@ def select_feature_matrix(
        )
        conn.execute("DROP VIEW if exists tmp")
        conn.execute(view_query)
-        table_name = "tmp"
+        # return the view table name 'tmp' for subsequent operations
+        return 'tmp'
+    else:
+        return table_name
+
+
+def drop_cohort_view(conn, cohort_features):
+    if cohort_features:
+        conn.execute("DROP VIEW tmp")
+
+
+def select_feature_matrix(
+        conn,
+        table_name,
+        year,
+        cohort_features,
+        cohort_year,
+        feature_a,
+        feature_b,
+):
+    """Select feature matrix."""
+    table_name = create_cohort_view(conn, table_name, cohort_features)

    # start_time = time.time()
    # cohort_features_norm = normalize_features(cohort_year, cohort_features)
@@ -735,8 +747,7 @@ def select_feature_matrix(

    # association_json = json.dumps(association, sort_keys=True)

-    if cohort_features:
-        conn.execute("DROP VIEW tmp")
+    drop_cohort_view(conn, cohort_features)

    # start_time = time.time()
    # conn.execute(cache.insert().values(digest=digest, association=association_json, table=table_name, cohort_features=cohort_features_json, feature_a=feature_a_json, feature_b=feature_b_json, access_time=timestamp))
@@ -791,9 +802,13 @@ def select_feature_count_all_values(
    return count


-def get_feature_levels(feature):
+def get_feature_levels(feature, year=None):
    """Get feature levels."""
-    return get_value_sets().get(feature, [])
+    feat_levs = get_value_sets().get(feature, [])
+    if year and feature == 'year' and int(year) in feat_levs:
+        # only include the passed-in year in the corresponding year feature level list
+        feat_levs = [int(year)]
+    return feat_levs


def apply_correction(ret, correction=None):
@@ -990,7 +1005,7 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
    associations = []
    # get feature_constraint list from the first feature variable
    feat_constraint_list = []
-    levels0 = get_feature_levels(feature_variables[0])
+    levels0 = get_feature_levels(feature_variables[0], year=year)
    for level in levels0:
        non_op_idx = 0
        for lev in level:
@@ -1015,7 +1030,7 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
                "feature_name": feature_variables[index],
                "feature_qualifiers": list(map(
                    lambda level: {"operator": "=", "value": level},
-                    get_feature_levels(feature_variables[index]),
+                    get_feature_levels(feature_variables[index], year=year),
                ))
            }
        ]
@@ -1024,7 +1039,7 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
                "feature_name": feature_variables[index + 1],
                "feature_qualifiers": list(map(
                    lambda level: {"operator": "=", "value": level},
-                    get_feature_levels(feature_variables[index + 1]),
+                    get_feature_levels(feature_variables[index + 1], year=year),
                ))
            }
        ]
@@ -1053,7 +1068,7 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
        if index < feat_len:
            feature_qualifiers = list(map(
                lambda level: {"operator": "=", "value": level},
-                get_feature_levels(feature_variables[index])
+                get_feature_levels(feature_variables[index], year=year)
            ))
            more_constraint_list = []
            for feature_constraint in feat_constraint_list:
@@ -1066,6 +1081,10 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
                    more_constraint_list.append(base_dict)
            feat_constraint_list = more_constraint_list
    # compute frequency for each feature constraint
+    if cohort_features:
+        # compute frequency on the cohort view
+        table_name = create_cohort_view(conn, table_name, cohort_features)
+
    if len(feat_constraint_list) > 0:
        columns = list(feat_constraint_list[0].keys())
        result = count_unique(conn, table_name, year, *columns)
@@ -1078,4 +1097,6 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
        for fc in feat_constraint_list:
            fc['frequency'] = get_count(result, **fc)

+    if cohort_features:
+        drop_cohort_view(conn, cohort_features)
    return feat_constraint_list
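
As a quick illustration of the new year handling (the value-set contents below are assumed for the example; real levels come from get_value_sets()):

# Hypothetical example -- assumes get_value_sets() lists [2010, 2011, 2012]
# as the levels of the "year" feature.
from icees_api.features.sql import get_feature_levels

get_feature_levels("year")              # [2010, 2011, 2012] - all known years
get_feature_levels("year", year=2011)   # [2011] - restricted to the requested year
get_feature_levels("year", year=1999)   # [2010, 2011, 2012] - unknown years leave the list unchanged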
