Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fuzzy matching when searching dbgpts #2110

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions dbgpt/serve/dbgpts/hub/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc

from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util.pagination_utils import PaginationResult

from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
Expand Down Expand Up @@ -109,3 +110,42 @@ def to_response(self, entity: ServeEntity) -> ServerResponse:
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

def dbgpts_list(
self,
query_request: ServeRequest,
page: int,
page_size: int,
desc_order_column: Optional[str] = None,
) -> PaginationResult[ServerResponse]:
"""Get a page of dbgpts.

Args:
query_request (ServeRequest): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
desc_order_column(Optional[str]): The column for descending order.
Returns:
PaginationResult: The pagination result.
"""
session = self.get_raw_session()
try:
query = session.query(ServeEntity)
if query_request.name:
query = query.filter(ServeEntity.name.like(f"%{query_request.name}%"))
if desc_order_column:
query = query.order_by(desc(getattr(ServeEntity, desc_order_column)))
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
finally:
session.close()

return PaginationResult(
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
4 changes: 2 additions & 2 deletions dbgpt/serve/dbgpts/hub/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def init_app(self, system_app: SystemApp) -> None:
self._system_app = system_app

@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao

Expand Down Expand Up @@ -130,7 +130,7 @@ def get_list_by_page(
installed=request.installed,
)

return self.dao.get_list_page(query_request, page, page_size)
return self.dao.dbgpts_list(query_request, page, page_size)

def refresh_hub_from_git(
self,
Expand Down
Loading