【OpenAI】Api接口限额方案实现

发布于:2025-06-29 ⋅ 阅读:(19) ⋅ 点赞:(0)

需求

目前对接多个平台的收费接口,需要对每个平台进行每天的限制额度,由于某些平台的计费不是实时的,所以允许有部分出入

问题

  1. 聊天接口基本上都用openAI的包,但查询费用接口不一致
  2. 查询费用接口只能返回当前账号余额,需要记录每天使用的金额数
  3. 每天记录需要定时生成
  4. 聊天接口要根据金额是否达标而判断是否需要阻止访问
  5. 聊天需要流式返回,且判断余额不可以过度影响返回的响应速度
  6. 需要考虑充值问题

实现代码

demo链接

1. 表设计

-- auto-generated definition
create table t_platform
(
    id           int auto_increment
        primary key,
    name         varchar(255)  null comment '平台名称',
    limit_amount double        null comment '当日限制使用金额',
    balance      double        null comment '余额',
    api_key      varchar(255)  null comment '平台授权访问的key',
    uuid         varchar(255)  null comment '平台授权访问的token',
    url          varchar(255)  null comment '平台提供的url',
    is_default   int default 0 null comment '0: 常规  1:默认'
);
create table t_platform_daily_use
(
    id             int auto_increment
        primary key,
    use_date       date   not null comment '使用日期',
    init_amount    double not null comment '当日初始化时账户剩余金额',
    current_amount double not null comment '当前查询接口余额',
    diff_amount    double not null comment '差额',
    platform_id    int    not null comment '平台id'
);

2.技术选型
框架: FastApi
ORM框架:tortoise
缓存: redis

3.项目结构

├── app/
│ ├── api/
│ │ └── plaform_api.py # API 路由定义
│ ├── config/
│ │ └──config.py # 读取配置文件
│ ├── handler/
│ │ ├── init.py # 初始化文件(包含模型扫描逻辑)
│ │ └── platform_handler.py # 平台处理逻辑
│ ├── model/
│ │ └── db/ # 数据库模型目录
│ │ └── models.py # 定义数据库表对应的模型类
│ └── util/ # 工具模块目录
├── config/ # 配置文件目录
└── main.py # 项目入口文件

4.核心具体实现
数据库表模型

class Platform(Model):
    id = fields.IntField(pk=True)
    name = fields.CharField(max_length=255, description='平台名称')
    limit_amount = fields.FloatField(description='当日限制金额')
    balance = fields.FloatField(description='余额')
    api_key = fields.CharField(max_length=255, description='api key')
    uuid = fields.CharField(max_length=255, description='uuid')
    url=fields.CharField(max_length=255, description='api url')
    is_default=fields.BooleanField(description='是否默认 0: 常规 1:默认')

    class Meta:
        table = 't_platform'

    def to_dict(self):
        return {
            'id': self.id,
            'name': self.name,
            'limit_amount': self.limit_amount,
            'balance': self.balance,
            'api_key': self.api_key,
            'uuid': self.uuid,
            'url': self.url,
            'is_default': self.is_default
        }


class PlatformDailyUse(Model):
    id = fields.IntField(pk=True)
    use_date = fields.DateField(description='日期')
    init_amount = fields.FloatField(description='当日重置时账户剩余金额')
    current_amount=fields.FloatField(description='当前余额')
    diff_amount = fields.FloatField(description='差额')
    platform_id = fields.IntField( description='平台id')

    class Meta:
        table = 't_platform_daily_use'

    async def init_daily_use(agi_platform: Platform = None):
        return await PlatformDailyUse.create(platform_id=agi_platform.id,
                                             use_date=datetime.now().date(),
                                             init_amount=agi_platform.balance,
                                             current_amount=agi_platform.balance,
                                             diff_amount=0)

定义对外服务的平台基类

 # 定义基类
 class ProviderPlatform(ABC):
    def __init__(self, platform: Platform):
        self.platform = platform

	# 查询余额 子类可以重载
    async def query_account(self):
        pass

	# 检查余额任务实现
    async def check_balance_task(self):
        await self.query_account()
        agi_platform = self.platform
        daily_use = await PlatformDailyUse.filter(platform_id=agi_platform.id, use_date=datetime.now().date()).first()
        if daily_use is None:
            daily_use = await PlatformDailyUse.init_daily_use(agi_platform)
        current_amount=float(daily_use.current_amount)
        balance=float(agi_platform.balance)
        if current_amount < balance:
            # 充值了
            charge_amount = balance - current_amount
            daily_use.init_amount += charge_amount
        current_amount = balance
        daily_use.diff_amount = float(daily_use.init_amount) - current_amount
        await daily_use.save(update_fields=['current_amount', 'diff_amount', 'init_amount'])
        if float(daily_use.diff_amount) > float(agi_platform.limit_amount)  or float(self.platform.balance) <=0:
            REDIS_CONN.set(get_plat_day_forbid_key(self.platform.name), "false")
	
	# 添加异步任务
    def add_check_balance_task(self,background_tasks: BackgroundTasks):
        background_tasks.add_task(self.check_balance_task)
        pass

	# 聊天
    def chat_completions(self,req: ChatRequest, background_tasks: BackgroundTasks):
    	# 查询redis
        if REDIS_CONN.exist(get_plat_day_forbid_key(self.platform.name)):
            return get_json_result(500, "当日付费金额已超额度,请联系管理员!")
        # 添加任务去校验今天使用的钱是否
        self.add_check_balance_task(background_tasks)
        async def generate_stream():
            try:
                client = AsyncOpenAI(
                    api_key=self.platform.api_key,
                    base_url=f"https://{self.platform.url}/v1"
                )
                stream = await client.chat.completions.create(
                    model=req.model,
                    messages=req.messages,
                    stream=True,
                    temperature=req.temperature,
                    max_tokens=req.max_tokens
                )
                async for chunk in stream:
                    chunk_json = chunk.model_dump_json()
                    yield f"{chunk_json}\n\n"
                yield "data: [DONE]\n\n"
            except Exception as e:
                yield f"data: {{\"error\": \"{str(e)}\"}}\n\n"
        return EventSourceResponse(generate_stream(), media_type="text/event-stream")

Agicto的平台服务类


class AgictoPlatform(ProviderPlatform):
    async def query_account(self):
        conn = http.client.HTTPSConnection(self.platform.url)
        payload=json.dumps({"uuid": self.platform.uuid})
        conn.request("POST", "/v1/enterprise/account", payload, {'Content-Type': 'application/json'})
        res = conn.getresponse()
        data = res.read()
        await ApiRecord.create(api_url="/v1/enterprise/account", params=payload, response_content=data,
                               create_time=datetime.now())
        if res.status != 200:
            return res.status, "查询余额失败"
        result = json.loads(data)
        if result["code"] == 0:
            self.platform.balance = result["data"]["account"]
            await self.platform.save(update_fields=["balance"])
        return get_json_result(result["code"],result["data"]["account"])

项目启动时从平台表里加载该平台的服务类

# 平台注册中心
_PLATFORM_REGISTRY: Dict[str, ProviderPlatform] = {}

# 默认平台实例对象
default_platform = None

# 注册平台
def register_platform(name: str, handler: ProviderPlatform):
    _PLATFORM_REGISTRY[name.lower()] = handler

# 根据名称获取平台实例对象
def get_platform_class(name: str) -> ProviderPlatform:
    if name is None:
        name= "default"
    return _PLATFORM_REGISTRY.get(name.lower(), default_platform)

# 根据基类的类型 读取某个包下继承该类的的所有类
def get_all_subclasses_from_package(package_name: str, base_class: Type) -> Dict[str, Type]:
    discovered_classes = {}

    # 获取包模块
    package: ModuleType = importlib.import_module(package_name)
    package_path = package.__path__

    # 遍历包中的所有模块
    for _, module_name, _ in pkgutil.iter_modules(package_path):
        full_module_name = f"{package_name}.{module_name}"
        module = importlib.import_module(full_module_name)

        # 遍历模块中的所有属性
        for attr_name in dir(module):
            cls = getattr(module, attr_name)
            # 判断是否是类、是否继承自 base_class、不是 base_class 自身
            if isinstance(cls, type) and issubclass(cls, base_class) and cls != base_class:
                discovered_classes[cls.__name__] = cls

    return discovered_classes

# 从数据库表中获取平台 并找到对应的平台服务类
async def initialize_platforms_from_db():
    classes = get_all_subclasses_from_package('app.handler', ProviderPlatform)
    db_platforms = await Platform.all()
    for db_platform in db_platforms:
        class_name = f"{db_platform.name.capitalize()}Platform"
        cls = classes.get(class_name)
        # cls = globals().get(class_name)
        if cls:
            instance = cls(db_platform)
            register_platform(db_platform.name, instance)
            if db_platform.is_default:
                register_platform("default", instance)

main.py

# 生命周期管理器

@asynccontextmanager
async def lifespan(fast_app: FastAPI):
	# 初始化平台
    await initialize_platforms_from_db()
    # 启动时执行
    scheduler = AsyncIOScheduler()
    scheduler.add_job(func=init_daily_use, trigger='cron', hour=0, minute=0, timezone="Asia/Shanghai")
    scheduler.start()
    print("⏰ 定时任务已启动")

    yield  # 应用运行期间保持

    # 关闭时执行
    scheduler.shutdown()
    print("⏹️ 定时任务已关闭")

app = FastAPI(lifespan=lifespan)

app.include_router(plaform_api.platform_cto, prefix="/platform", tags=["platform", ])



目前未实现

  1. 未考虑余额为0 然后充值的问题
  2. 未考虑增加每日限额的逻辑