diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 5c0e2ef695d..26a8f399a5d 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -75,10 +75,12 @@ def __init__(self): connPool = ConnectionPool(infinity_uri) inf_conn = connPool.get_conn() res = inf_conn.show_current_node() - connPool.release_conn(inf_conn) - self.connPool = connPool if res.error_code == ErrorCode.OK and res.server_status=="started": + self._migrate_db(inf_conn) + self.connPool = connPool + connPool.release_conn(inf_conn) break + connPool.release_conn(inf_conn) logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.") time.sleep(5) except Exception as e: @@ -90,6 +92,45 @@ def __init__(self): raise Exception(msg) logger.info(f"Infinity {infinity_uri} is healthy.") + def _migrate_db(self, inf_conn): + inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) + fp_mapping = os.path.join( + get_project_base_directory(), "conf", "infinity_mapping.json" + ) + if not os.path.exists(fp_mapping): + raise Exception(f"Mapping file not found at {fp_mapping}") + schema = json.load(open(fp_mapping)) + table_names = inf_db.list_tables().table_names + for table_name in table_names: + inf_table = inf_db.get_table(table_name) + index_names = inf_table.list_indexes().index_names + if "q_vec_idx" not in index_names: + # Skip tables not created by me + continue + column_names = inf_table.show_columns()["name"] + column_names = set(column_names) + text_suffix = ["_tks", "_ltks", "_kwd"] + for field_name, field_info in schema.items(): + if field_name in column_names: + continue + res = inf_table.add_columns({field_name: field_info}) + assert res.error_code == infinity.ErrorCode.OK + logger.info( + f"INFINITY added following column to table {table_name}: {field_name} {field_info}" + ) + if field_info["type"] != "varchar": + continue + for suffix in text_suffix: + if field_name.endswith(suffix): + inf_table.create_index( + f"text_idx_{field_name}", + IndexInfo( + field_name, IndexType.FullText, {"ANALYZER": "standard"} + ), + ConflictType.Ignore, + ) + break + """ Database operations """