From dfbf87b6ecaf91db9f239de1341a3d867ff490de Mon Sep 17 00:00:00 2001 From: kingzeus Date: Sat, 23 Nov 2024 16:02:26 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/apis/account.py | 28 +- backend/apis/fund.py | 9 +- backend/apis/portfolio.py | 35 +- data_source/implementations/eastmoney.py | 10 +- data_source/interface.py | 8 +- data_source/proxy.py | 14 +- models/account.py | 40 ++- models/database.py | 395 ++++++++--------------- models/fund.py | 6 +- pages/account/__init__.py | 2 +- pages/account/account_modal.py | 12 +- pages/account/delete_modal.py | 5 +- pages/account/portfolio_modal.py | 33 +- pages/account/table.py | 43 +-- pages/home/__init__.py | 2 +- pages/task/__init__.py | 4 +- pages/task/{page.py => task_page.py} | 2 +- pages/task/{table.py => task_table.py} | 0 pages/transaction/__init__.py | 2 +- pages/transaction/delete_modal.py | 5 +- pages/transaction/modal.py | 74 ++--- pages/transaction/table.py | 9 +- pages/transaction/utils.py | 15 +- pyproject.toml | 1 + scheduler/__init__.py | 11 +- scheduler/tasks/__init__.py | 14 +- scheduler/tasks/data_sync.py | 28 +- scheduler/tasks/fund_detail.py | 10 +- scheduler/tasks/fund_info.py | 28 +- scheduler/tasks/fund_nav.py | 27 +- start.sh | 24 +- utils/datetime_helper.py | 41 ++- 32 files changed, 450 insertions(+), 487 deletions(-) rename pages/task/{page.py => task_page.py} (98%) rename pages/task/{table.py => task_table.py} (100%) diff --git a/backend/apis/account.py b/backend/apis/account.py index 958b275..ed762c3 100644 --- a/backend/apis/account.py +++ b/backend/apis/account.py @@ -1,12 +1,13 @@ +import uuid from flask_restx import Namespace, Resource, fields from backend.apis.common import create_list_response_model, create_response_model +from models.account import ModelAccount, update_account from models.database import ( - add_account, - delete_account, - get_account, - get_accounts, - update_account, + delete_record, + get_record, + get_record_list, + update_record, ) from utils.response import format_response @@ -44,7 +45,7 @@ class AccountList(Resource): @api.marshal_with(account_list_response) def get(self): """获取所有账户列表""" - return format_response(data=get_accounts()) + return format_response(data=get_record_list(ModelAccount)) @api.doc("创建新账户") @api.expect(account_input) @@ -52,8 +53,10 @@ def get(self): def post(self): """创建新账户""" data = api.payload - account_id = add_account(data["name"], data.get("description")) - return format_response(data=get_account(account_id), message="账户创建成功") + account_id = update_account(str(uuid.uuid4()), data) + return format_response( + data=get_record(ModelAccount, {"id": account_id}), message="账户创建成功" + ) @api.route("/") @@ -63,7 +66,7 @@ class Account(Resource): @api.marshal_with(account_response) def get(self, account_id): """获取指定账户的详情""" - account = get_account(account_id) + account = get_record(ModelAccount, {"id": account_id}) if not account: return format_response(message="账户不存在", code=404) return format_response(data=account) @@ -74,7 +77,7 @@ def get(self, account_id): def put(self, account_id): """更新账户信息""" data = api.payload - account = update_account(account_id, data) + account = update_record(ModelAccount, {"id": account_id}, data) if not account: return format_response(message="账户不存在", code=404) return format_response(data=account, message="账户更新成功") @@ -83,5 +86,6 @@ def put(self, account_id): @api.marshal_with(account_response) def delete(self, account_id): """删除账户""" - delete_account(account_id) - return format_response(message="账户删除成功") + if delete_record(ModelAccount, {"id": account_id}): + return format_response(message="账户删除成功") + return format_response(message="账户不存在", code=404) diff --git a/backend/apis/fund.py b/backend/apis/fund.py index f20b241..f3aab34 100644 --- a/backend/apis/fund.py +++ b/backend/apis/fund.py @@ -3,11 +3,12 @@ from backend.apis.common import create_list_response_model, create_response_model from models.database import ( add_fund_position, - delete_fund_position, + delete_record, get_fund_positions, get_fund_transactions, - update_fund_position, + update_record, ) +from models.fund import ModelFundPosition from utils.response import format_response api = Namespace("funds", description="基金相关操作") @@ -94,7 +95,7 @@ class FundPosition(Resource): def put(self, position_id): """更新持仓信息""" data = api.payload - position = update_fund_position(position_id, data) + position = update_record(ModelFundPosition, {"id": position_id}, data) if not position: return format_response(message="持仓不存在", code=404) return format_response(data=position, message="持仓更新成功") @@ -103,7 +104,7 @@ def put(self, position_id): @api.marshal_with(position_response) def delete(self, position_id): """删除持仓""" - if delete_fund_position(position_id): + if delete_record(ModelFundPosition, {"id": position_id}): return format_response(message="持仓删除成功") return format_response(message="持仓不存在", code=404) diff --git a/backend/apis/portfolio.py b/backend/apis/portfolio.py index 7a2686e..3fc5116 100644 --- a/backend/apis/portfolio.py +++ b/backend/apis/portfolio.py @@ -1,12 +1,13 @@ +import uuid from flask_restx import Namespace, Resource, fields from backend.apis.common import create_list_response_model, create_response_model +from models.account import ModelPortfolio from models.database import ( - add_portfolio, - delete_portfolio, - get_portfolio, - get_portfolios, - update_portfolio, + delete_record, + get_record, + get_record_list, + update_record, ) from utils.response import format_response @@ -54,7 +55,7 @@ def get(self): account_id = api.payload.get("account_id") if not account_id: return format_response(message="缺少账户ID", code=400) - return format_response(data=get_portfolios(account_id)) + return format_response(data=get_record_list(ModelPortfolio, {"account_id": account_id})) @api.doc("创建新投资组合") @api.expect(portfolio_input) @@ -62,13 +63,12 @@ def get(self): def post(self): """创建新投资组合""" data = api.payload - portfolio_id = add_portfolio( - data["account_id"], - data["name"], - data.get("description"), - data.get("is_default", False), + portfolio_id = str(uuid.uuid4()) + update_record(ModelPortfolio, {"id": portfolio_id}, data) + + return format_response( + data=get_record(ModelPortfolio, {"id": portfolio_id}), message="组合创建成功" ) - return format_response(data=get_portfolio(portfolio_id), message="组合创建成功") @api.route("/") @@ -78,7 +78,7 @@ class Portfolio(Resource): @api.marshal_with(portfolio_response) def get(self, portfolio_id): """获取指定组合的详情""" - portfolio = get_portfolio(portfolio_id) + portfolio = get_record(ModelPortfolio, {"id": portfolio_id}) if not portfolio: return format_response(message="组合不存在", code=404) return format_response(data=portfolio) @@ -89,14 +89,15 @@ def get(self, portfolio_id): def put(self, portfolio_id): """更新组合信息""" data = api.payload - portfolio = update_portfolio(portfolio_id, data) + portfolio = update_record(ModelPortfolio, {"id": portfolio_id}, data) if not portfolio: - return format_response(message="组合不存在", code=404) + return format_response(message="更新失败", code=404) return format_response(data=portfolio, message="组合更新成功") @api.doc("删除组合") @api.marshal_with(portfolio_response) def delete(self, portfolio_id): """删除组合""" - delete_portfolio(portfolio_id) - return format_response(message="组合删除成功") + if delete_record(ModelPortfolio, {"id": portfolio_id}): + return format_response(message="组合删除成功") + return format_response(message="组合不存在", code=404) diff --git a/data_source/implementations/eastmoney.py b/data_source/implementations/eastmoney.py index 42721e5..3804e8b 100644 --- a/data_source/implementations/eastmoney.py +++ b/data_source/implementations/eastmoney.py @@ -220,8 +220,8 @@ def get_fund_nav_history_size(self) -> int: def get_fund_nav_history( self, fund_code: str, - start_date: datetime, - end_date: datetime, + start_date: str, + end_date: str, ) -> List[Dict[str, Any]]: """获取基金历史净值""" try: @@ -230,8 +230,8 @@ def get_fund_nav_history( "code": fund_code, "type": "lsjz", # 历史净值 "page": 1, # 页码 - "sdate": start_date.strftime("%Y-%m-%d"), - "edate": end_date.strftime("%Y-%m-%d"), + "sdate": start_date, + "edate": end_date, "per": 20, # 默认获取最近20条记录 } logger.debug("请求基金历史净值: %s, params: %s", url, params) @@ -293,6 +293,8 @@ def get_fund_nav_history( "daily_return": daily_return, "subscription_status": subscription_status, "redemption_status": redemption_status, + "data_source": self.get_name(), + "data_source_version": self.get_version(), } if dividend: item["dividend"] = dividend diff --git a/data_source/interface.py b/data_source/interface.py index fe49c42..34c7dbd 100644 --- a/data_source/interface.py +++ b/data_source/interface.py @@ -64,15 +64,15 @@ def get_fund_nav_history_size(self) -> int: def get_fund_nav_history( self, fund_code: str, - start_date: datetime, - end_date: datetime, + start_date: str, + end_date: str, ) -> List[Dict[str, Any]]: """ 获取基金历史净值 Args: fund_code: 基金代码 - start_date: 开始日期 - end_date: 结束日期 + start_date: 开始日期,格式: 2004-01-01 + end_date: 结束日期,格式: 2004-01-01 Returns: List[Dict]: [{ "date": "日期", diff --git a/data_source/proxy.py b/data_source/proxy.py index 2177597..8b03ee9 100644 --- a/data_source/proxy.py +++ b/data_source/proxy.py @@ -97,8 +97,8 @@ def get_fund_detail(self, fund_code: str) -> Dict[str, Any]: def get_fund_nav_history( self, fund_code: str, - start_date: datetime, - end_date: datetime, + start_date: str, + end_date: str, ) -> Dict[str, Any]: """获取基金历史净值""" return self._call_api( @@ -110,3 +110,13 @@ def get_fund_nav_history( end_date=end_date, is_array=True, ) + + def get_fund_nav_history_size( + self, + ) -> Dict[str, Any]: + """获取基金历史净值""" + return self._call_api( + func_name="get_fund_nav_history_size", + api_func=self._data_source.get_fund_nav_history_size, + error_msg="获取基金历史净值失败", + ) diff --git a/models/account.py b/models/account.py index 32255e1..d79e2de 100644 --- a/models/account.py +++ b/models/account.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Optional +import uuid from peewee import BooleanField, CharField, ForeignKeyField from models.base import BaseModel @@ -30,7 +32,7 @@ class ModelPortfolio(BaseModel): """投资组合模型""" id = CharField(primary_key=True) - account = ForeignKeyField(ModelAccount, backref="portfolio") + account = ForeignKeyField(ModelAccount, backref="portfolios") name = CharField(null=False) description = CharField(null=True) is_default = BooleanField(default=False) @@ -51,3 +53,39 @@ def to_dict(self) -> dict: } ) return result + + +def update_account(account_id: Optional[str], data: Dict[str, Any]) -> bool: + """更新账户信息""" + from models.database import update_record + + if not account_id: + account_id = str(uuid.uuid4()) + + def on_created(result): + # 创建默认投资组合 + ModelPortfolio.create( + id=str(uuid.uuid4()), + account=account_id, + name=f"{data['name']}-默认组合", + description=f"{data['name']}的默认投资组合", + is_default=True, + ) + print(f"账户创建完成: {result.to_dict()}") + + return update_record(ModelAccount, {"id": account_id}, data, on_created) + + +def delete_account(account_id: str) -> bool: + """更新账户信息""" + from models.database import delete_record + + def on_before(result): + # 删除默认投资组合 + portfolio_count = ( + ModelPortfolio.select().where(ModelPortfolio.account == account_id).count() + ) + if portfolio_count > 0: + raise ValueError("账户下存在投资组合,无法删除") + + return delete_record(ModelAccount, {"id": account_id}, on_before) diff --git a/models/database.py b/models/database.py index e7c8427..4b4ccd2 100644 --- a/models/database.py +++ b/models/database.py @@ -1,15 +1,17 @@ +from collections.abc import Callable import logging import uuid from datetime import datetime from typing import Any, Dict, List, Optional from peewee import JOIN, fn +from playhouse.shortcuts import update_model_from_dict + -from scheduler.job_manager import JobManager from utils.datetime_helper import format_datetime from .account import ModelAccount, ModelPortfolio -from .base import Database, db_connection +from .base import BaseModel, Database, db_connection from .fund import ModelFund, ModelFundNav, ModelFundPosition, ModelFundTransaction from .task import ModelTask @@ -32,202 +34,6 @@ def init_database(): ) -# 账户相关操作 -def get_accounts() -> List[Dict[str, Any]]: - """获取所有账户""" - with db_connection(): - accounts = ModelAccount.select() - # 如果没有数据,返回空列表 - if not accounts: - return [] - - return [ - { - "id": str(account.id), - "name": account.name, - "description": account.description, - "create_time": account.created_at, - "update_time": account.updated_at, - } - for account in accounts - ] - - -def add_account(name: str, description: Optional[str] = None) -> str: - """添加账户并创建默认投资组合""" - with db_connection(): - # 创建账户 - account_id = str(uuid.uuid4()) - ModelAccount.create( - id=account_id, - name=name, - description=description, - ) - - # 创建默认投资组合 - ModelPortfolio.create( - id=str(uuid.uuid4()), - account=account_id, - name=f"{name}-默认组合", - description=f"{name}的默认投资组合", - is_default=True, - ) - - return account_id - - -def get_account(account_id: str) -> Optional[Dict[str, Any]]: - """获取账户详情""" - with db_connection(): - try: - account = ModelAccount.get_by_id(account_id) - return { - "id": str(account.id), - "name": account.name, - "description": account.description, - "create_time": account.created_at, - "update_time": account.updated_at, - } - except ModelAccount.DoesNotExist: # pylint: disable=E1101 - return None - - -def update_account(account_id: str, name: str, description: str = None) -> bool: - """更新账户信息""" - try: - with db_connection(): - account = ModelAccount.get_by_id(account_id) - account.name = name - account.description = description - account.save() - return True - except Exception as e: - logger.error("更新账户失败: %s", str(e)) - return False - - -def delete_account(account_id: str) -> bool: - """删除账户""" - try: - with db_connection(): - # 首先检查是否有关联的组合 - portfolio_count = ( - ModelPortfolio.select().where(ModelPortfolio.account == account_id).count() - ) - if portfolio_count > 0: - return False - - account = ModelAccount.get_by_id(account_id) - account.delete_instance() - return True - except Exception as e: - logger.error("删除账户失败: %s", str(e)) - return False - - -# 投资组合相关操作 -def get_portfolios(account_id: str) -> List[Dict[str, Any]]: - """获取账户下的所有投资组合""" - with db_connection(): - portfolios = ( - ModelPortfolio.select( - ModelPortfolio, - fn.COUNT(ModelFundPosition.id).alias("fund_count"), - fn.COALESCE(fn.SUM(ModelFundPosition.market_value), 0).alias("total_market_value"), - ) - .join(ModelFundPosition, JOIN.LEFT_OUTER) - .where(ModelPortfolio.account == account_id) - .group_by(ModelPortfolio) - .order_by(ModelPortfolio.created_at.asc()) - ) - - if not portfolios: - return [] - - return [ - { - "id": str(p.id), - "name": p.name, - "description": p.description, - "is_default": p.is_default, - "create_time": p.created_at.isoformat(), - "update_time": p.updated_at.isoformat(), - "fund_count": p.fund_count, - "total_market_value": float(p.total_market_value), - } - for p in portfolios - ] - - -def add_portfolio( - account_id: str, - name: str, - description: Optional[str] = None, - is_default: bool = False, -) -> str: - """添加投资组合""" - with db_connection(): - portfolio_id = str(uuid.uuid4()) - ModelPortfolio.create( - id=portfolio_id, - account=account_id, - name=name, - description=description, - is_default=is_default, - ) - return portfolio_id - - -def get_portfolio(portfolio_id: str) -> Optional[Dict[str, Any]]: - """获取投资组合详情""" - with db_connection(): - try: - portfolio = ModelPortfolio.get_by_id(portfolio_id) - return { - "id": str(portfolio.id), - "account_id": str(portfolio.account.id), - "name": portfolio.name, - "description": portfolio.description, - "is_default": portfolio.is_default, - "create_time": portfolio.created_at, - "update_time": portfolio.updated_at, - } - except ModelPortfolio.DoesNotExist: # pylint: disable=E1101 - return None - - -def update_portfolio(portfolio_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """更新投资组合信息""" - with db_connection(): - try: - portfolio = ModelPortfolio.get_by_id(portfolio_id) - portfolio.name = data.get("name", portfolio.name) - portfolio.description = data.get("description", portfolio.description) - portfolio.is_default = data.get("is_default", portfolio.is_default) - portfolio.save() - return get_portfolio(portfolio_id) - except ModelPortfolio.DoesNotExist: # pylint: disable=E1101 - return None - - -def delete_portfolio(portfolio_id: str) -> bool: - """删除投资组合""" - try: - with db_connection(): - portfolio = ModelPortfolio.get_by_id(portfolio_id) - - # 检查是否为默认组合 - if portfolio.is_default: - return False - - # 删除组合及其关联的基金持仓 - portfolio.delete_instance(recursive=True) - return True - except Exception as e: - logger.error("删除组合失败: %s", str(e)) - return False - - # 基金持仓相关操作 def get_fund_positions(portfolio_id: str) -> List[Dict[str, Any]]: """获取组合的基金持仓""" @@ -278,38 +84,6 @@ def add_fund_position(data: Dict[str, Any]) -> str: return position_id -def update_fund_position(position_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """更新基金持仓信息""" - with db_connection(): - try: - position = ModelFundPosition.get_by_id(position_id) - for key, value in data.items(): - if hasattr(position, key): - setattr(position, key, value) - - # 更新市值和收益率 - position.market_value = float(position.shares) * float(position.nav) - position.return_rate = (position.market_value - float(position.cost)) / float( - position.cost - ) - position.save() - - return get_fund_positions(str(position.portfolio.id)) - except ModelFundPosition.DoesNotExist: # pylint: disable=E1101 - return None - - -def delete_fund_position(position_id: str) -> bool: - """删除基金持仓""" - with db_connection(): - try: - position = ModelFundPosition.get_by_id(position_id) - position.delete_instance() - return True - except ModelFundPosition.DoesNotExist: # pylint: disable=E1101 - return False - - def get_fund_transactions(portfolio_id: str) -> List[Dict[str, Any]]: """获取基金交易记录""" with db_connection(): @@ -435,6 +209,8 @@ def add_transaction( fund = ModelFund.get_or_none(ModelFund.code == fund_code) if not fund: # 如果基金不存在,触发更新基金详情任务 + from scheduler.job_manager import JobManager + task_id = JobManager().add_task( "fund_info", fund_code=fund_code, @@ -520,27 +296,6 @@ def update_transaction( return False -def delete_transaction(transaction_id: str) -> bool: - """删除交易记录""" - try: - with db_connection(): - # 获取交易记录 - transaction = ModelFundTransaction.get_by_id(transaction_id) - portfolio_id = transaction.portfolio.id - fund_code = transaction.code - - # 删除交易记录 - transaction.delete_instance() - - # 重新计算持仓 - recalculate_position(portfolio_id, fund_code) - - return True - except Exception as e: - logger.error("删除交易记录失败: %s", str(e)) - return False - - def update_position_after_transaction( portfolio_id: str, fund: ModelFund, @@ -693,30 +448,138 @@ def check_database_content(): ) -def get_fund_nav(fund_code: str, nav_date: datetime) -> Optional[ModelFundNav]: - """获取指定日期的基金净值 +######################################## +# 通用函数 +######################################## +def get_record(model_class, search_fields: Dict[str, Any]) -> Optional[BaseModel]: + """通用的获取记录函数 Args: - fund_code: 基金代码 - nav_date: 净值日期 + model_class: Peewee 模型类 + search_fields: 用于查找记录的字段和值的字典 + + Example: + get_record( + ModelFundNav, + {"fund_code": "000001", "nav_date": "2024-01-01"} + ) + """ + try: + with db_connection(): + return model_class.get_or_none(**search_fields) + except Exception as e: + logger.error("获取记录失败 - 模型: %s, 错误: %s", model_class.__name__, str(e)) + return None + + +def get_record_list(model_class, search_fields: Optional[Dict[str, Any]] = None) -> List[BaseModel]: + """通用的获取记录列表函数 + + Args: + model_class: Peewee 模型类 + search_fields: 用于过滤记录的字段和值的字典,为空时返回所有记录 Returns: - ModelFundNav | None: 基金净值记录,如果不存在则返回 None + List[BaseModel]: 记录列表 + + Example: + get_record_list( + ModelFundNav, + {"fund_code": "000001"} + ) """ try: with db_connection(): - nav_record = ( - ModelFundNav.select() - .where( - (ModelFundNav.fund == fund_code) & (ModelFundNav.nav_date == nav_date.date()) + query = model_class.select() + if search_fields: + query = query.where( + *[ + getattr(model_class, field) == value + for field, value in search_fields.items() + ] + ) + return list(query) + except Exception as e: + logger.error("获取记录列表失败 - 模型: %s, 错误: %s", model_class.__name__, str(e)) + return [] + + +def delete_record( + model_class, + search_fields: Dict[str, Any], + callback_before: Optional[Callable[[BaseModel], None]] = None, +) -> bool: + """通用的删除记录函数 + + Args: + model_class: Peewee 模型类 + search_fields: 用于查找记录的字段和值的字典 + callback_before: 删除记录前的回调函数 + Returns: + bool: 删除是否成功 + + Example: + delete_record( + ModelFundNav, + {"fund_code": "000001", "nav_date": "2024-01-01"} + ) + """ + try: + with db_connection(): + record = model_class.get_or_none(**search_fields) + if record: + if callback_before: + callback_before(record) + record.delete_instance() + logger.info( + "成功删除记录 - 模型: %s, 条件: %s", model_class.__name__, str(search_fields) ) - .first() + return True + logger.warning( + "未找到要删除的记录 - 模型: %s, 条件: %s", model_class.__name__, str(search_fields) ) + return False + except Exception as e: + logger.error( + "删除记录失败 - 模型: %s, 条件: %s, 错误: %s", + model_class.__name__, + str(search_fields), + str(e), + ) + return False + - if nav_record: - return nav_record - return None +def update_record( + model_class, + search_fields: Dict[str, Any], + update_data: Dict[str, Any], + callback_created: Optional[Callable[[BaseModel], None]] = None, +) -> bool: + """通用的更新或创建记录函数 + Args: + model_class: Peewee 模型类 + search_fields: 用于查找记录的字段和值的字典 + update_data: 需要更新的数据字典 + callback_created: 创建记录后的回调函数 + Example: + _update_record( + ModelFundNav, + {"fund_code": "000001", "nav_date": "2024-01-01"}, + {"nav": 1.234, "acc_nav": 2.345} + ) + """ + try: + with db_connection(): + existing_record, created = model_class.get_or_create( + **search_fields, defaults=update_data + ) + if not created: + update_model_from_dict(existing_record, update_data) + existing_record.save() + if created and callback_created: + callback_created(existing_record) + return True except Exception as e: - logger.error("获取基金净值失败: %s", str(e)) - return None + logger.error("更新记录失败 - 模型: %s, 错误: %s", model_class.__name__, str(e)) + return False diff --git a/models/fund.py b/models/fund.py index 598a9a4..15df6b0 100644 --- a/models/fund.py +++ b/models/fund.py @@ -216,7 +216,7 @@ class ModelFundNav(BaseModel): # 累计净值 acc_nav = DecimalField(max_digits=10, decimal_places=4) # 日收益率 - daily_return = DecimalField(max_digits=10, decimal_places=4) + daily_return = DecimalField(max_digits=10, decimal_places=4, auto_round=True) # 申购状态 subscription_status = CharField(max_length=20) # 赎回状态 @@ -230,7 +230,7 @@ class ModelFundNav(BaseModel): class Meta: table_name = "fund_nav_history" - primary_key = CompositeKey("fund_code", "nav_date") + primary_key = CompositeKey("fund", "nav_date") def to_dict(self) -> dict: """将基金净值历史实例转换为可JSON序列化的字典""" @@ -238,7 +238,7 @@ def to_dict(self) -> dict: result.update( { "fund_code": self.fund_code, - "nav_date": self.nav_date.isoformat() if self.nav_date else None, + "nav_date": self.nav_date.strftime("%Y-%m-%d") if self.nav_date else None, "nav": float(self.nav) if self.nav else None, "acc_nav": float(self.acc_nav) if self.acc_nav else None, "daily_return": float(self.daily_return) if self.daily_return else None, diff --git a/pages/account/__init__.py b/pages/account/__init__.py index 04dc78d..d6d67de 100644 --- a/pages/account/__init__.py +++ b/pages/account/__init__.py @@ -6,7 +6,7 @@ - 嵌套表格展示账户和组合的层级关系 文件结构: -- page.py: 页面主渲染函数 +- task_page.py: 页面主渲染函数 - table.py: 账户表格相关组件和回调 - account_modal.py: 账户编辑弹窗相关 - portfolio_modal.py: 组合编辑弹窗相关 diff --git a/pages/account/account_modal.py b/pages/account/account_modal.py index 7384529..3992822 100644 --- a/pages/account/account_modal.py +++ b/pages/account/account_modal.py @@ -20,7 +20,7 @@ from dash import Input, Output, State, callback # Local imports -from models.database import add_account, update_account +from models.account import update_account from pages.account.table import get_account_table_data from pages.account.utils import validate_name @@ -195,9 +195,11 @@ def handle_account_create_or_edit( - 账户描述输入框值 """ if ok_counts and name and validate_status == "success": - if editing_id: - update_account(editing_id, name, description) - else: - add_account(name, description) + + update_account( + editing_id, + {"name": name, "description": description}, + ) + return get_account_table_data(), False, "", "" return get_account_table_data(), dash.no_update, dash.no_update, dash.no_update diff --git a/pages/account/delete_modal.py b/pages/account/delete_modal.py index c2e11a7..f63c0d2 100644 --- a/pages/account/delete_modal.py +++ b/pages/account/delete_modal.py @@ -3,7 +3,8 @@ from dash import Input, Output, State, callback from dash.exceptions import PreventUpdate -from models.database import delete_account, delete_portfolio +from models.account import ModelPortfolio, delete_account +from models.database import delete_record from pages.account.table import get_account_table_data @@ -60,7 +61,7 @@ def handle_delete_confirm(ok_counts, object_id, custom_info): success = False if custom_info and custom_info.get("type") == "portfolio": - success = delete_portfolio(object_id) + success = delete_record(ModelPortfolio, {"id": object_id}) else: success = delete_account(object_id) diff --git a/pages/account/portfolio_modal.py b/pages/account/portfolio_modal.py index 8d2a2ac..474b712 100644 --- a/pages/account/portfolio_modal.py +++ b/pages/account/portfolio_modal.py @@ -1,10 +1,12 @@ from typing import Optional, Tuple +import uuid import feffery_antd_components as fac from dash import Input, Output, State, callback, dcc from dash.exceptions import PreventUpdate -from models.database import add_portfolio, update_portfolio +from models.account import ModelPortfolio +from models.database import update_record from pages.account.table import get_account_table_data from pages.account.utils import validate_name @@ -143,22 +145,17 @@ def handle_portfolio_create_or_edit( if not ok_counts or not name: raise PreventUpdate - if is_edit_mode and editing_id: - update_portfolio( - editing_id, - { - "name": name, - "description": description, - }, - ) - else: - if not account_id: - raise PreventUpdate - add_portfolio( - account_id=account_id, - name=name, - description=description, - is_default=False, - ) + update_record( + ModelPortfolio, + { + "id": str(uuid.uuid4()) if not editing_id else editing_id, + "account_id": account_id, + }, + { + "name": name, + "description": description, + "is_default": False, + }, + ) return get_account_table_data(), False, None, "", "" diff --git a/pages/account/table.py b/pages/account/table.py index 98f7775..5f59cbc 100644 --- a/pages/account/table.py +++ b/pages/account/table.py @@ -5,7 +5,8 @@ from dash import Input, Output, State, callback from dash.exceptions import PreventUpdate -from models.database import get_accounts, get_portfolio, get_portfolios +from models.account import ModelAccount, ModelPortfolio +from models.database import get_record, get_record_list from utils.datetime_helper import format_datetime from .utils import create_operation_buttons @@ -27,43 +28,41 @@ def get_account_table_data() -> List[Dict[str, Any]]: - 包含嵌套的组合数据 - 包含操作按钮配置 """ - accounts = get_accounts() + accounts = get_record_list(ModelAccount) table_data = [] for account in accounts: - portfolios = get_portfolios(account["id"]) + portfolios = account.portfolios portfolio_data = [] for p in portfolios: operation_buttons = [] - if not p["is_default"]: + if not p.is_default: operation_buttons = create_operation_buttons( - p["id"], "portfolio", account["id"], is_danger=True + p.id, "portfolio", account.id, is_danger=True ) portfolio_data.append( { - "key": p["id"], - "id": p["id"], - "name": p["name"], - "description": p.get("description", ""), - "create_time": format_datetime(p["create_time"]), - "market_value": ( - f"¥ {p['total_market_value']:,.2f}" if p["total_market_value"] else "¥ 0.00" - ), - "fund_count": p["fund_count"] or 0, + "key": p.id, + "id": p.id, + "name": p.name, + "description": p.description, + "create_time": format_datetime(p.created_at), + "market_value": "¥ 0.00", + "fund_count": 0, "operation": operation_buttons, } ) table_data.append( { - "key": account["id"], - "id": account["id"], - "name": account["name"], - "description": account.get("description", ""), - "create_time": format_datetime(account["create_time"]), - "operation": create_operation_buttons(account["id"], "account", is_danger=True), + "key": account.id, + "id": account.id, + "name": account.name, + "description": account.description, + "create_time": format_datetime(account.created_at), + "operation": create_operation_buttons(account.id, "account", is_danger=True), "children": portfolio_data, } ) @@ -269,10 +268,12 @@ def handle_button_click(nClicksButton, custom_info, accounts_data): # 处理组合操作 elif object_type == "portfolio": account_id = custom_info.get("accountId") - portfolio = get_portfolio(object_id) + portfolio = get_record(ModelPortfolio, {"id": object_id}) if not portfolio: raise PreventUpdate + portfolio = portfolio.to_dict() + if action == "edit": return ( False, # account modal visible diff --git a/pages/home/__init__.py b/pages/home/__init__.py index a39c460..80f12c2 100644 --- a/pages/home/__init__.py +++ b/pages/home/__init__.py @@ -6,7 +6,7 @@ - 收益走势图表 文件结构: -- page.py: 页面主渲染函数 +- task_page.py: 页面主渲染函数 - overview.py: 数据概览卡片相关 - charts.py: 图表相关组件 - utils.py: 通用工具函数 diff --git a/pages/task/__init__.py b/pages/task/__init__.py index ef97a0e..6a60d8d 100644 --- a/pages/task/__init__.py +++ b/pages/task/__init__.py @@ -7,13 +7,13 @@ - 任务详情查看 文件结构: -- page.py: 页面主渲染函数 +- task_page.py: 页面主渲染函数 - table.py: 任务列表表格相关 - modal.py: 任务创建弹窗及其回调 - detail_modal.py: 任务详情弹窗相关 - utils.py: 通用工具函数和常量 """ -from pages.task.page import render_task_page +from pages.task.task_page import render_task_page __all__ = ["render_task_page"] diff --git a/pages/task/page.py b/pages/task/task_page.py similarity index 98% rename from pages/task/page.py rename to pages/task/task_page.py index 3b0d96b..e2de01f 100644 --- a/pages/task/page.py +++ b/pages/task/task_page.py @@ -23,7 +23,7 @@ from pages.task.detail_modal import render_task_detail_modal from pages.task.modal import render_task_modal -from pages.task.table import render_task_table +from pages.task.task_table import render_task_table from pages.task.utils import ICON_STYLES, PAGE_PADDING from scheduler.job_manager import JobManager diff --git a/pages/task/table.py b/pages/task/task_table.py similarity index 100% rename from pages/task/table.py rename to pages/task/task_table.py diff --git a/pages/transaction/__init__.py b/pages/transaction/__init__.py index 66e5a89..5f3c926 100644 --- a/pages/transaction/__init__.py +++ b/pages/transaction/__init__.py @@ -6,7 +6,7 @@ - 表格展示所有交易记录 文件结构: -- page.py: 页面主渲染函数 +- task_page.py: 页面主渲染函数 - table.py: 交易记录表格相关组件和回调 - modal.py: 交易记录编辑弹窗及其回调 - delete_modal.py: 删除确认弹窗相关 diff --git a/pages/transaction/delete_modal.py b/pages/transaction/delete_modal.py index 2b944f0..37baa94 100644 --- a/pages/transaction/delete_modal.py +++ b/pages/transaction/delete_modal.py @@ -10,7 +10,8 @@ from dash import Input, Output, State, callback from dash.exceptions import PreventUpdate -from models.database import delete_transaction, get_transactions +from models.database import delete_record, get_transactions +from models.fund import ModelFundTransaction def render_delete_confirm_modal() -> fac.AntdModal: @@ -41,6 +42,6 @@ def handle_delete_confirm(ok_counts, transaction_id): if not ok_counts or not transaction_id: raise PreventUpdate - if delete_transaction(transaction_id): + if delete_record(ModelFundTransaction, {"id": transaction_id}): return get_transactions(), False return dash.no_update, False diff --git a/pages/transaction/modal.py b/pages/transaction/modal.py index 150f2f9..0ffc05d 100644 --- a/pages/transaction/modal.py +++ b/pages/transaction/modal.py @@ -7,6 +7,7 @@ """ import logging +import uuid from datetime import datetime from typing import Any, Dict, List, Optional, Tuple @@ -16,14 +17,16 @@ from dash.exceptions import PreventUpdate from components.fund_code_aio import FundCodeAIO -from data_source.proxy import DataSourceProxy from models.database import ( add_transaction, - get_fund_nav, + get_record, get_transactions, + update_record, update_transaction, ) -from models.fund import ModelFundTransaction +from models.fund import ModelFundNav, ModelFundTransaction +from scheduler import job_manager +from scheduler.job_manager import JobManager from utils.fac_helper import show_message from .utils import build_cascader_options @@ -380,44 +383,28 @@ def handle_transaction_save( raise PreventUpdate # 处理日期时间 - trade_datetime = None - if isinstance(trade_time, str): - if len(trade_time) == 10: - trade_datetime = datetime.strptime(trade_time, "%Y-%m-%d") - else: - trade_datetime = datetime.strptime(trade_time, "%Y-%m-%d %H:%M:%S") - - if not trade_datetime: + + if not trade_time: logger.error("无效的交易时间格式") raise PreventUpdate # 保存交易记录 - success = False - if editing_id: - success = update_transaction( - transaction_id=editing_id, - portfolio_id=portfolio_id, - fund_code=fund_code, - transaction_type=transaction_type, - amount=amount, - shares=shares, - nav=nav, - fee=fee, - fee_type=fee_type, - trade_time=trade_datetime, - ) - else: - success = add_transaction( - portfolio_id=portfolio_id, - fund_code=fund_code, - transaction_type=transaction_type, - amount=amount, - shares=shares, - nav=nav, - fee=fee, - fee_type=fee_type, - trade_time=trade_datetime, - ) + + # 交易记录数据 + transaction_data = { + "portfolio": portfolio_id, + "fund_code": fund_code, + "type": transaction_type, + "amount": amount, + "shares": shares or 0.0, + "nav": nav or 0.0, + "fee": fee or 0.0, + "fee_type": fee_type, + "transaction_date": trade_time, + } + + condition = {"id": editing_id if editing_id else str(uuid.uuid4())} + success = update_record(ModelFundTransaction, condition, transaction_data) if success: return get_transactions(), False @@ -489,7 +476,7 @@ def calculate_fee( @callback( - Output("nav-input", "value"), + Output("nav-input", "value", allow_duplicate=True), [ Input(FundCodeAIO.ids.select("fund-code-aio"), "value"), Input("trade-time-picker", "value"), @@ -506,13 +493,18 @@ def update_nav_value(fund_code: Optional[str], trade_time: Optional[str]) -> Opt try: # 1. 从数据库中获取基金净值 - trade_date = datetime.strptime(trade_time, "%Y-%m-%d") - nav = get_fund_nav(fund_code, trade_date) + nav: Optional[ModelFundNav] = get_record( + ModelFundNav, {"fund_code": fund_code, "nav_date": trade_time} + ) if nav: return str(nav.nav) show_message("未找到基金净值, 请手动输入", "info") # 2. 触发净值更新任务 - 异步更新数据库中的净值 - update_fund_nav.delay(fund_code, trade_date, nav) + JobManager().add_task( + "fund_nav", + fund_code=fund_code, + start_date=trade_time, + ) return dash.no_update diff --git a/pages/transaction/table.py b/pages/transaction/table.py index 5ac1536..f03e7bf 100644 --- a/pages/transaction/table.py +++ b/pages/transaction/table.py @@ -13,7 +13,8 @@ from dash.exceptions import PreventUpdate from components.fund_code_aio import FundCodeAIO -from models.database import get_portfolio +from models.account import ModelPortfolio +from models.database import get_record from .utils import build_cascader_options @@ -146,11 +147,11 @@ def update_transaction_table(store_data: List[Dict[str, Any]]) -> List[Dict[str, Output("transaction-type-select", "value"), Output("amount-input", "value"), Output("shares-input", "value"), - Output("nav-input", "value"), + Output("nav-input", "value", allow_duplicate=True), Output("fee-input", "value"), Output("trade-time-picker", "value"), Output("transaction-delete-confirm-modal", "visible"), - Output("editing-transaction-id", "data"), + Output("editing-transaction-id", "data", allow_duplicate=True), ], Input("transaction-list", "nClicksButton"), State("transaction-list", "clickedCustom"), @@ -182,7 +183,7 @@ def handle_button_click(nClicksButton, custom_info, store_data): # 找到当前交易记录对应的组合路径 portfolio_id = transaction["portfolio_id"] # 从组合信息中获取账户ID - portfolio = get_portfolio(portfolio_id) + portfolio = get_record(ModelPortfolio, {"id": portfolio_id}) if not portfolio: raise PreventUpdate diff --git a/pages/transaction/utils.py b/pages/transaction/utils.py index 1d580d9..7efe597 100644 --- a/pages/transaction/utils.py +++ b/pages/transaction/utils.py @@ -7,7 +7,8 @@ from typing import Any, Dict, List -from models.database import get_accounts, get_portfolios +from models.account import ModelAccount +from models.database import get_record_list def create_operation_buttons(transaction_id: str) -> List[Dict[str, Any]]: @@ -55,22 +56,22 @@ def build_cascader_options() -> List[Dict[str, Any]]: - value: 选项值 - children: 子选项列表(组合) """ - accounts = get_accounts() + accounts = get_record_list(ModelAccount) cascader_options = [] for account in accounts: - portfolios = get_portfolios(account["id"]) + portfolios = account.portfolios portfolio_children = [ { - "label": p["name"], - "value": p["id"], + "label": p.name, + "value": p.id, } for p in portfolios ] cascader_options.append( { - "label": account["name"], - "value": account["id"], + "label": account.name, + "value": account.id, "children": portfolio_children, } ) diff --git a/pyproject.toml b/pyproject.toml index 0aef00a..7c0fb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ numpy = "^1.26.2" # 数据库 peewee = "^3.17.0" +playhouse = "^0.1.0" [tool.poetry.group.dev.dependencies] # 开发工具 diff --git a/scheduler/__init__.py b/scheduler/__init__.py index ee69295..4dbe2c5 100644 --- a/scheduler/__init__.py +++ b/scheduler/__init__.py @@ -3,20 +3,11 @@ from flask_apscheduler import APScheduler -from config import SCHEDULER_CONFIG +from scheduler.tasks import TaskStatus scheduler = APScheduler() -class TaskStatus: - PENDING = "等待中" - RUNNING = "运行中" - COMPLETED = "已完成" - FAILED = "失败" - TIMEOUT = "超时" - PAUSED = "已暂停" - - class TaskResult: def __init__(self): self.status = TaskStatus.PENDING diff --git a/scheduler/tasks/__init__.py b/scheduler/tasks/__init__.py index e8520e8..2bc35f7 100644 --- a/scheduler/tasks/__init__.py +++ b/scheduler/tasks/__init__.py @@ -1,18 +1,19 @@ +from scheduler.tasks.fund_nav import FundNavTask from .data_sync import DataSyncTask from .fund_detail import FundDetailTask from .fund_info import FundInfoTask -from .portfolio import PortfolioUpdateTask from .task_factory import TaskFactory class TaskStatus: """任务状态常量""" - PENDING = "pending" # 等待执行 - RUNNING = "running" # 正在执行 - COMPLETED = "completed" # 执行完成 - FAILED = "failed" # 执行失败 - PAUSED = "paused" # 已暂停 + PENDING = "等待中" # 等待执行 + RUNNING = "运行中" # 正在执行 + COMPLETED = "已完成" # 执行完成 + FAILED = "失败" # 执行失败 + TIMEOUT = "超时" # 超时 + PAUSED = "已暂停" # 已暂停 def init_tasks(): @@ -23,6 +24,7 @@ def init_tasks(): # factory.register(PortfolioUpdateTask) factory.register(FundInfoTask) factory.register(FundDetailTask) + factory.register(FundNavTask) __all__ = ["TaskFactory", "TaskStatus", "init_tasks"] diff --git a/scheduler/tasks/data_sync.py b/scheduler/tasks/data_sync.py index d179d88..1f5e4da 100644 --- a/scheduler/tasks/data_sync.py +++ b/scheduler/tasks/data_sync.py @@ -19,20 +19,20 @@ def get_config(cls) -> Dict[str, Any]: "timeout": 7200, # 2小时 "priority": 2, "params": [ - { - "name": "同步类型", - "key": "sync_type", - "type": "select", - "required": True, - "description": "选择要同步的数据类型", - "default": "all", - "options": [ - {"label": "全部数据", "value": "all"}, - {"label": "基金信息", "value": "info"}, - {"label": "净值数据", "value": "nav"}, - {"label": "持仓数据", "value": "position"}, - ], - }, + # { + # "name": "同步类型", + # "key": "sync_type", + # "type": "select", + # "required": True, + # "description": "选择要同步的数据类型", + # "default": "all", + # "options": [ + # {"label": "全部数据", "value": "all"}, + # {"label": "基金信息", "value": "info"}, + # {"label": "净值数据", "value": "nav"}, + # {"label": "持仓数据", "value": "position"}, + # ], + # }, { "name": "开始日期", "key": "start_date", diff --git a/scheduler/tasks/fund_detail.py b/scheduler/tasks/fund_detail.py index 5695c4a..eb159fa 100644 --- a/scheduler/tasks/fund_detail.py +++ b/scheduler/tasks/fund_detail.py @@ -3,6 +3,7 @@ from typing import Any, Dict from data_source.proxy import DataSourceProxy +from models.database import update_record from models.fund import ModelFund from .base import BaseTask @@ -90,14 +91,7 @@ def execute(self, **kwargs) -> Dict[str, Any]: self.update_progress(80) logger.info("正在更新数据库...") - # 更新或创建基金记录 - fund, created = ModelFund.get_or_create(code=fund_code, defaults=fund_data) - - if not created: - # 更新现有记录 - for key, value in fund_data.items(): - setattr(fund, key, value) - fund.save() + update_record(ModelFund, {"code": fund_code}, fund_data) self.update_progress(100) logger.info("基金 %s 信息更新完成", fund_code) diff --git a/scheduler/tasks/fund_info.py b/scheduler/tasks/fund_info.py index a52072e..05235b8 100644 --- a/scheduler/tasks/fund_info.py +++ b/scheduler/tasks/fund_info.py @@ -3,7 +3,6 @@ from typing import Any, Dict from data_source.proxy import DataSourceProxy -from models.fund import ModelFund from .base import BaseTask @@ -66,20 +65,13 @@ def execute(self, **kwargs) -> Dict[str, Any]: "description": fund_info.get("description", ""), } - self.update_progress(50) - logger.info("正在更新数据库...") + # self.update_progress(50) + # logger.info("正在更新数据库...") - # 更新或创建基金记录 - fund, created = ModelFund.get_or_create(code=fund_code, defaults=fund_data) + # # update_record(ModelFund, {"code": fund_code}, fund_data) - if not created: - # 更新现有记录 - for key, value in fund_data.items(): - setattr(fund, key, value) - fund.save() - - self.update_progress(80) - logger.info("正在获取最新净值...") + # self.update_progress(80) + # logger.info("正在获取最新净值...") # # 获取最新净值 # nav_history = data_source.get_fund_nav_history( @@ -100,15 +92,7 @@ def execute(self, **kwargs) -> Dict[str, Any]: self.update_progress(100) logger.info("基金 %s 信息更新完成", fund_code) - return { - "message": "Fund info update completed", - "task_id": self.task_id, - "fund_code": fund_code, - "fund_name": fund_data["name"], - "created": created, - "nav": str(fund.nav) if fund.nav else None, - "nav_date": (fund.nav_date.strftime("%Y-%m-%d") if fund.nav_date else None), - } + return fund_data except Exception as e: logger.error("更新基金信息失败: %s", str(e), exc_info=True) # 添加完整的错误堆栈 diff --git a/scheduler/tasks/fund_nav.py b/scheduler/tasks/fund_nav.py index cb709fb..dffaf39 100644 --- a/scheduler/tasks/fund_nav.py +++ b/scheduler/tasks/fund_nav.py @@ -2,8 +2,11 @@ from datetime import datetime from typing import Any, Dict + from data_source.proxy import DataSourceProxy +from models.database import update_record from models.fund import ModelFundNav +from utils.datetime_helper import get_date_str_after_days from .base import BaseTask @@ -63,27 +66,41 @@ def execute(self, **kwargs) -> Dict[str, Any]: if not start_date: raise ValueError("start_date 不能为空") end_date = kwargs.get("end_date") - if not end_date: - raise ValueError("end_date 不能为空") try: # 初始化数据源 data_source = DataSourceProxy() + # 获取基金历史净值数据大小 + nav_history_size_response = data_source.get_fund_nav_history_size() + if nav_history_size_response["code"] != 200: + raise ValueError(nav_history_size_response["message"]) + self.update_progress(30) + + if not end_date: + # 默认结束日期是 开始日期 + end_date = get_date_str_after_days(start_date, nav_history_size_response["data"]) + # 更新进度 - self.update_progress(20) + self.update_progress(50) logger.info("正在获取基金 %s [%s-%s] 的净值...", fund_code, start_date, end_date) # 获取基金信息 nav_history_response = data_source.get_fund_nav_history(fund_code, start_date, end_date) - self.update_progress(50) + self.update_progress(70) if nav_history_response["code"] != 200: raise ValueError(nav_history_response["message"]) logger.info("正在更新数据库...") # 批量保存净值到数据库 - ModelFundNav.bulk_create(nav_history_response["data"]) + nav_data = nav_history_response["data"] + for nav_item in nav_data: + update_record( + ModelFundNav, + {"fund_code": fund_code, "nav_date": nav_item["nav_date"]}, + nav_item, + ) self.update_progress(100) logger.info("基金 %s 信息更新完成", fund_code) diff --git a/start.sh b/start.sh index 55c6963..a72ad8f 100755 --- a/start.sh +++ b/start.sh @@ -347,6 +347,7 @@ show_menu() { echo "4) 启动应用" echo "5) 完整安装(1-4步骤)" echo "------------" + echo "8) 清理日志" echo "9) 测试" echo "0) 退出" echo "------------------------" @@ -369,6 +370,23 @@ do_full_install() { return 0 } +# 添加清理日志文件的函数 +clean_logs() { + echo -e "${YELLOW}开始清理日志文件...${NC}" + if [ -d "logs" ]; then + rm -f logs/* + if [ $? -eq 0 ]; then + echo -e "${GREEN}日志文件清理完成${NC}" + else + echo -e "${RED}日志文件清理失败${NC}" + return 1 + fi + else + echo -e "${YELLOW}logs目录不存在,无需清理${NC}" + fi + return 0 +} + # 主循环 while true; do show_menu @@ -406,11 +424,15 @@ while true; do echo -e "${YELLOW}开始启动应用...${NC}" ensure_directories || continue activate_conda_env || continue - start_app || continue + clean_logs + start_app ;; 5) do_full_install || continue ;; + 8) + clean_logs + ;; 9) start_test ;; diff --git a/utils/datetime_helper.py b/utils/datetime_helper.py index ac9cdfa..356d063 100644 --- a/utils/datetime_helper.py +++ b/utils/datetime_helper.py @@ -1,6 +1,7 @@ -from datetime import date, datetime +from datetime import date, datetime, timedelta +import logging from typing import Optional, Union - +logger = logging.getLogger(__name__) def format_datetime( dt: Union[str, datetime, None], @@ -81,3 +82,39 @@ def get_timestamp() -> int: 1709251200 # 2024-03-01 00:00:00 UTC """ return int(datetime.now().timestamp()) + + +def get_date_str_after_days(start_date: Union[str, date], days: int) -> str: + """获取开始日期后几天的日期字符串""" + return format_date(get_date_after_days(start_date, days)) + + +def get_date_after_days(start_date: Union[str, date], days: int) -> date: + """获取开始日期后几天的日期 + + Args: + start_date: 开始日期,可以是date对象或ISO格式字符串 + days: 天数,正数表示往后,负数表示往前 + + Returns: + date: 计算后的日期 + + Examples: + >>> get_date_after_days('2024-03-01', 1) + datetime.date(2024, 3, 2) + >>> get_date_after_days(date(2024, 3, 1), -1) + datetime.date(2024, 2, 29) + """ + try: + if isinstance(start_date, str): + start_date = date.fromisoformat(start_date.strip()) + + if not isinstance(start_date, date): + logger.error("无效的日期格式: %s", start_date) + raise ValueError("无效的日期格式") + + return start_date + timedelta(days=days) + + except (ValueError, TypeError) as e: + logger.error("计算日期失败: %s", str(e)) + raise