需求
目前对接多个平台的收费接口,需要对每个平台进行每天的限制额度,由于某些平台的计费不是实时的,所以允许有部分出入
问题
- 聊天接口基本上都用openAI的包,但查询费用接口不一致
- 查询费用接口只能返回当前账号余额,需要记录每天使用的金额数
- 每天记录需要定时生成
- 聊天接口要根据金额是否达标而判断是否需要阻止访问
- 聊天需要流式返回,且判断余额不可以过度影响返回的响应速度
- 需要考虑充值问题
实现代码
1. 表设计
-- auto-generated definition
create table t_platform
(
id int auto_increment
primary key,
name varchar(255) null comment '平台名称',
limit_amount double null comment '当日限制使用金额',
balance double null comment '余额',
api_key varchar(255) null comment '平台授权访问的key',
uuid varchar(255) null comment '平台授权访问的token',
url varchar(255) null comment '平台提供的url',
is_default int default 0 null comment '0: 常规 1:默认'
);
create table t_platform_daily_use
(
id int auto_increment
primary key,
use_date date not null comment '使用日期',
init_amount double not null comment '当日初始化时账户剩余金额',
current_amount double not null comment '当前查询接口余额',
diff_amount double not null comment '差额',
platform_id int not null comment '平台id'
);
2.技术选型
框架: FastApi
ORM框架:tortoise
缓存: redis
3.项目结构
├── app/
│ ├── api/
│ │ └── plaform_api.py # API 路由定义
│ ├── config/
│ │ └──config.py # 读取配置文件
│ ├── handler/
│ │ ├── init.py # 初始化文件(包含模型扫描逻辑)
│ │ └── platform_handler.py # 平台处理逻辑
│ ├── model/
│ │ └── db/ # 数据库模型目录
│ │ └── models.py # 定义数据库表对应的模型类
│ └── util/ # 工具模块目录
├── config/ # 配置文件目录
└── main.py # 项目入口文件
4.核心具体实现
数据库表模型
class Platform(Model):
id = fields.IntField(pk=True)
name = fields.CharField(max_length=255, description='平台名称')
limit_amount = fields.FloatField(description='当日限制金额')
balance = fields.FloatField(description='余额')
api_key = fields.CharField(max_length=255, description='api key')
uuid = fields.CharField(max_length=255, description='uuid')
url=fields.CharField(max_length=255, description='api url')
is_default=fields.BooleanField(description='是否默认 0: 常规 1:默认')
class Meta:
table = 't_platform'
def to_dict(self):
return {
'id': self.id,
'name': self.name,
'limit_amount': self.limit_amount,
'balance': self.balance,
'api_key': self.api_key,
'uuid': self.uuid,
'url': self.url,
'is_default': self.is_default
}
class PlatformDailyUse(Model):
id = fields.IntField(pk=True)
use_date = fields.DateField(description='日期')
init_amount = fields.FloatField(description='当日重置时账户剩余金额')
current_amount=fields.FloatField(description='当前余额')
diff_amount = fields.FloatField(description='差额')
platform_id = fields.IntField( description='平台id')
class Meta:
table = 't_platform_daily_use'
async def init_daily_use(agi_platform: Platform = None):
return await PlatformDailyUse.create(platform_id=agi_platform.id,
use_date=datetime.now().date(),
init_amount=agi_platform.balance,
current_amount=agi_platform.balance,
diff_amount=0)
定义对外服务的平台基类
# 定义基类
class ProviderPlatform(ABC):
def __init__(self, platform: Platform):
self.platform = platform
# 查询余额 子类可以重载
async def query_account(self):
pass
# 检查余额任务实现
async def check_balance_task(self):
await self.query_account()
agi_platform = self.platform
daily_use = await PlatformDailyUse.filter(platform_id=agi_platform.id, use_date=datetime.now().date()).first()
if daily_use is None:
daily_use = await PlatformDailyUse.init_daily_use(agi_platform)
current_amount=float(daily_use.current_amount)
balance=float(agi_platform.balance)
if current_amount < balance:
# 充值了
charge_amount = balance - current_amount
daily_use.init_amount += charge_amount
current_amount = balance
daily_use.diff_amount = float(daily_use.init_amount) - current_amount
await daily_use.save(update_fields=['current_amount', 'diff_amount', 'init_amount'])
if float(daily_use.diff_amount) > float(agi_platform.limit_amount) or float(self.platform.balance) <=0:
REDIS_CONN.set(get_plat_day_forbid_key(self.platform.name), "false")
# 添加异步任务
def add_check_balance_task(self,background_tasks: BackgroundTasks):
background_tasks.add_task(self.check_balance_task)
pass
# 聊天
def chat_completions(self,req: ChatRequest, background_tasks: BackgroundTasks):
# 查询redis
if REDIS_CONN.exist(get_plat_day_forbid_key(self.platform.name)):
return get_json_result(500, "当日付费金额已超额度,请联系管理员!")
# 添加任务去校验今天使用的钱是否
self.add_check_balance_task(background_tasks)
async def generate_stream():
try:
client = AsyncOpenAI(
api_key=self.platform.api_key,
base_url=f"https://{self.platform.url}/v1"
)
stream = await client.chat.completions.create(
model=req.model,
messages=req.messages,
stream=True,
temperature=req.temperature,
max_tokens=req.max_tokens
)
async for chunk in stream:
chunk_json = chunk.model_dump_json()
yield f"{chunk_json}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
yield f"data: {{\"error\": \"{str(e)}\"}}\n\n"
return EventSourceResponse(generate_stream(), media_type="text/event-stream")
Agicto的平台服务类
class AgictoPlatform(ProviderPlatform):
async def query_account(self):
conn = http.client.HTTPSConnection(self.platform.url)
payload=json.dumps({"uuid": self.platform.uuid})
conn.request("POST", "/v1/enterprise/account", payload, {'Content-Type': 'application/json'})
res = conn.getresponse()
data = res.read()
await ApiRecord.create(api_url="/v1/enterprise/account", params=payload, response_content=data,
create_time=datetime.now())
if res.status != 200:
return res.status, "查询余额失败"
result = json.loads(data)
if result["code"] == 0:
self.platform.balance = result["data"]["account"]
await self.platform.save(update_fields=["balance"])
return get_json_result(result["code"],result["data"]["account"])
项目启动时从平台表里加载该平台的服务类
# 平台注册中心
_PLATFORM_REGISTRY: Dict[str, ProviderPlatform] = {}
# 默认平台实例对象
default_platform = None
# 注册平台
def register_platform(name: str, handler: ProviderPlatform):
_PLATFORM_REGISTRY[name.lower()] = handler
# 根据名称获取平台实例对象
def get_platform_class(name: str) -> ProviderPlatform:
if name is None:
name= "default"
return _PLATFORM_REGISTRY.get(name.lower(), default_platform)
# 根据基类的类型 读取某个包下继承该类的的所有类
def get_all_subclasses_from_package(package_name: str, base_class: Type) -> Dict[str, Type]:
discovered_classes = {}
# 获取包模块
package: ModuleType = importlib.import_module(package_name)
package_path = package.__path__
# 遍历包中的所有模块
for _, module_name, _ in pkgutil.iter_modules(package_path):
full_module_name = f"{package_name}.{module_name}"
module = importlib.import_module(full_module_name)
# 遍历模块中的所有属性
for attr_name in dir(module):
cls = getattr(module, attr_name)
# 判断是否是类、是否继承自 base_class、不是 base_class 自身
if isinstance(cls, type) and issubclass(cls, base_class) and cls != base_class:
discovered_classes[cls.__name__] = cls
return discovered_classes
# 从数据库表中获取平台 并找到对应的平台服务类
async def initialize_platforms_from_db():
classes = get_all_subclasses_from_package('app.handler', ProviderPlatform)
db_platforms = await Platform.all()
for db_platform in db_platforms:
class_name = f"{db_platform.name.capitalize()}Platform"
cls = classes.get(class_name)
# cls = globals().get(class_name)
if cls:
instance = cls(db_platform)
register_platform(db_platform.name, instance)
if db_platform.is_default:
register_platform("default", instance)
main.py
# 生命周期管理器
@asynccontextmanager
async def lifespan(fast_app: FastAPI):
# 初始化平台
await initialize_platforms_from_db()
# 启动时执行
scheduler = AsyncIOScheduler()
scheduler.add_job(func=init_daily_use, trigger='cron', hour=0, minute=0, timezone="Asia/Shanghai")
scheduler.start()
print("⏰ 定时任务已启动")
yield # 应用运行期间保持
# 关闭时执行
scheduler.shutdown()
print("⏹️ 定时任务已关闭")
app = FastAPI(lifespan=lifespan)
app.include_router(plaform_api.platform_cto, prefix="/platform", tags=["platform", ])
目前未实现
- 未考虑余额为0 然后充值的问题
- 未考虑增加每日限额的逻辑