Skip to content

Commit

Permalink
修复错误
Browse files Browse the repository at this point in the history
  • Loading branch information
kingzeus committed Nov 24, 2024
1 parent ce6f47a commit dfbf87b
Show file tree
Hide file tree
Showing 32 changed files with 450 additions and 487 deletions.
28 changes: 16 additions & 12 deletions backend/apis/account.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid
from flask_restx import Namespace, Resource, fields

from backend.apis.common import create_list_response_model, create_response_model
from models.account import ModelAccount, update_account
from models.database import (
add_account,
delete_account,
get_account,
get_accounts,
update_account,
delete_record,
get_record,
get_record_list,
update_record,
)
from utils.response import format_response

Expand Down Expand Up @@ -44,16 +45,18 @@ class AccountList(Resource):
@api.marshal_with(account_list_response)
def get(self):
"""获取所有账户列表"""
return format_response(data=get_accounts())
return format_response(data=get_record_list(ModelAccount))

@api.doc("创建新账户")
@api.expect(account_input)
@api.marshal_with(account_response)
def post(self):
"""创建新账户"""
data = api.payload
account_id = add_account(data["name"], data.get("description"))
return format_response(data=get_account(account_id), message="账户创建成功")
account_id = update_account(str(uuid.uuid4()), data)
return format_response(
data=get_record(ModelAccount, {"id": account_id}), message="账户创建成功"
)


@api.route("/<string:account_id>")
Expand All @@ -63,7 +66,7 @@ class Account(Resource):
@api.marshal_with(account_response)
def get(self, account_id):
"""获取指定账户的详情"""
account = get_account(account_id)
account = get_record(ModelAccount, {"id": account_id})
if not account:
return format_response(message="账户不存在", code=404)
return format_response(data=account)
Expand All @@ -74,7 +77,7 @@ def get(self, account_id):
def put(self, account_id):
"""更新账户信息"""
data = api.payload
account = update_account(account_id, data)
account = update_record(ModelAccount, {"id": account_id}, data)
if not account:
return format_response(message="账户不存在", code=404)
return format_response(data=account, message="账户更新成功")
Expand All @@ -83,5 +86,6 @@ def put(self, account_id):
@api.marshal_with(account_response)
def delete(self, account_id):
"""删除账户"""
delete_account(account_id)
return format_response(message="账户删除成功")
if delete_record(ModelAccount, {"id": account_id}):
return format_response(message="账户删除成功")
return format_response(message="账户不存在", code=404)
9 changes: 5 additions & 4 deletions backend/apis/fund.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from backend.apis.common import create_list_response_model, create_response_model
from models.database import (
add_fund_position,
delete_fund_position,
delete_record,
get_fund_positions,
get_fund_transactions,
update_fund_position,
update_record,
)
from models.fund import ModelFundPosition
from utils.response import format_response

api = Namespace("funds", description="基金相关操作")
Expand Down Expand Up @@ -94,7 +95,7 @@ class FundPosition(Resource):
def put(self, position_id):
"""更新持仓信息"""
data = api.payload
position = update_fund_position(position_id, data)
position = update_record(ModelFundPosition, {"id": position_id}, data)
if not position:
return format_response(message="持仓不存在", code=404)
return format_response(data=position, message="持仓更新成功")
Expand All @@ -103,7 +104,7 @@ def put(self, position_id):
@api.marshal_with(position_response)
def delete(self, position_id):
"""删除持仓"""
if delete_fund_position(position_id):
if delete_record(ModelFundPosition, {"id": position_id}):
return format_response(message="持仓删除成功")
return format_response(message="持仓不存在", code=404)

Expand Down
35 changes: 18 additions & 17 deletions backend/apis/portfolio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid
from flask_restx import Namespace, Resource, fields

from backend.apis.common import create_list_response_model, create_response_model
from models.account import ModelPortfolio
from models.database import (
add_portfolio,
delete_portfolio,
get_portfolio,
get_portfolios,
update_portfolio,
delete_record,
get_record,
get_record_list,
update_record,
)
from utils.response import format_response

Expand Down Expand Up @@ -54,21 +55,20 @@ def get(self):
account_id = api.payload.get("account_id")
if not account_id:
return format_response(message="缺少账户ID", code=400)
return format_response(data=get_portfolios(account_id))
return format_response(data=get_record_list(ModelPortfolio, {"account_id": account_id}))

@api.doc("创建新投资组合")
@api.expect(portfolio_input)
@api.marshal_with(portfolio_response)
def post(self):
"""创建新投资组合"""
data = api.payload
portfolio_id = add_portfolio(
data["account_id"],
data["name"],
data.get("description"),
data.get("is_default", False),
portfolio_id = str(uuid.uuid4())
update_record(ModelPortfolio, {"id": portfolio_id}, data)

return format_response(
data=get_record(ModelPortfolio, {"id": portfolio_id}), message="组合创建成功"
)
return format_response(data=get_portfolio(portfolio_id), message="组合创建成功")


@api.route("/<string:portfolio_id>")
Expand All @@ -78,7 +78,7 @@ class Portfolio(Resource):
@api.marshal_with(portfolio_response)
def get(self, portfolio_id):
"""获取指定组合的详情"""
portfolio = get_portfolio(portfolio_id)
portfolio = get_record(ModelPortfolio, {"id": portfolio_id})
if not portfolio:
return format_response(message="组合不存在", code=404)
return format_response(data=portfolio)
Expand All @@ -89,14 +89,15 @@ def get(self, portfolio_id):
def put(self, portfolio_id):
"""更新组合信息"""
data = api.payload
portfolio = update_portfolio(portfolio_id, data)
portfolio = update_record(ModelPortfolio, {"id": portfolio_id}, data)
if not portfolio:
return format_response(message="组合不存在", code=404)
return format_response(message="更新失败", code=404)
return format_response(data=portfolio, message="组合更新成功")

@api.doc("删除组合")
@api.marshal_with(portfolio_response)
def delete(self, portfolio_id):
"""删除组合"""
delete_portfolio(portfolio_id)
return format_response(message="组合删除成功")
if delete_record(ModelPortfolio, {"id": portfolio_id}):
return format_response(message="组合删除成功")
return format_response(message="组合不存在", code=404)
10 changes: 6 additions & 4 deletions data_source/implementations/eastmoney.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def get_fund_nav_history_size(self) -> int:
def get_fund_nav_history(
self,
fund_code: str,
start_date: datetime,
end_date: datetime,
start_date: str,
end_date: str,
) -> List[Dict[str, Any]]:
"""获取基金历史净值"""
try:
Expand All @@ -230,8 +230,8 @@ def get_fund_nav_history(
"code": fund_code,
"type": "lsjz", # 历史净值
"page": 1, # 页码
"sdate": start_date.strftime("%Y-%m-%d"),
"edate": end_date.strftime("%Y-%m-%d"),
"sdate": start_date,
"edate": end_date,
"per": 20, # 默认获取最近20条记录
}
logger.debug("请求基金历史净值: %s, params: %s", url, params)
Expand Down Expand Up @@ -293,6 +293,8 @@ def get_fund_nav_history(
"daily_return": daily_return,
"subscription_status": subscription_status,
"redemption_status": redemption_status,
"data_source": self.get_name(),
"data_source_version": self.get_version(),
}
if dividend:
item["dividend"] = dividend
Expand Down
8 changes: 4 additions & 4 deletions data_source/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def get_fund_nav_history_size(self) -> int:
def get_fund_nav_history(
self,
fund_code: str,
start_date: datetime,
end_date: datetime,
start_date: str,
end_date: str,
) -> List[Dict[str, Any]]:
"""
获取基金历史净值
Args:
fund_code: 基金代码
start_date: 开始日期
end_date: 结束日期
start_date: 开始日期,格式: 2004-01-01
end_date: 结束日期,格式: 2004-01-01
Returns:
List[Dict]: [{
"date": "日期",
Expand Down
14 changes: 12 additions & 2 deletions data_source/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def get_fund_detail(self, fund_code: str) -> Dict[str, Any]:
def get_fund_nav_history(
self,
fund_code: str,
start_date: datetime,
end_date: datetime,
start_date: str,
end_date: str,
) -> Dict[str, Any]:
"""获取基金历史净值"""
return self._call_api(
Expand All @@ -110,3 +110,13 @@ def get_fund_nav_history(
end_date=end_date,
is_array=True,
)

def get_fund_nav_history_size(
self,
) -> Dict[str, Any]:
"""获取基金历史净值"""
return self._call_api(
func_name="get_fund_nav_history_size",
api_func=self._data_source.get_fund_nav_history_size,
error_msg="获取基金历史净值失败",
)
40 changes: 39 additions & 1 deletion models/account.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, Optional
import uuid
from peewee import BooleanField, CharField, ForeignKeyField

from models.base import BaseModel
Expand Down Expand Up @@ -30,7 +32,7 @@ class ModelPortfolio(BaseModel):
"""投资组合模型"""

id = CharField(primary_key=True)
account = ForeignKeyField(ModelAccount, backref="portfolio")
account = ForeignKeyField(ModelAccount, backref="portfolios")
name = CharField(null=False)
description = CharField(null=True)
is_default = BooleanField(default=False)
Expand All @@ -51,3 +53,39 @@ def to_dict(self) -> dict:
}
)
return result


def update_account(account_id: Optional[str], data: Dict[str, Any]) -> bool:
"""更新账户信息"""
from models.database import update_record

if not account_id:
account_id = str(uuid.uuid4())

def on_created(result):
# 创建默认投资组合
ModelPortfolio.create(
id=str(uuid.uuid4()),
account=account_id,
name=f"{data['name']}-默认组合",
description=f"{data['name']}的默认投资组合",
is_default=True,
)
print(f"账户创建完成: {result.to_dict()}")

return update_record(ModelAccount, {"id": account_id}, data, on_created)


def delete_account(account_id: str) -> bool:
"""更新账户信息"""
from models.database import delete_record

def on_before(result):
# 删除默认投资组合
portfolio_count = (
ModelPortfolio.select().where(ModelPortfolio.account == account_id).count()
)
if portfolio_count > 0:
raise ValueError("账户下存在投资组合,无法删除")

return delete_record(ModelAccount, {"id": account_id}, on_before)
Loading

0 comments on commit dfbf87b

Please sign in to comment.