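While doing high-precision arithmetic, I found that DuckDB's DECIMAL type offers far less precision than PostgreSQL. DuckDB caps DECIMAL at 38 digits, whereas PostgreSQL's NUMERIC allows up to 131072 digits before the decimal point and 16383 after it. By registering DuckDB custom scalar and aggregate functions that call the GMP library, you can get arbitrary-precision decimals.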
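To make the idea concrete: the functions in this post exchange numbers with DuckDB as decimal strings (VARCHAR) and leave the arithmetic to GMP. The standard-library-only sketch below uses schoolbook addition as a stand-in for `mpz_add` under the same string-in, string-out contract. `add_decimal_strings` is a hypothetical helper, not part of the final code, and it handles only non-negative integers.

```cpp
#include <algorithm>
#include <stdexcept>
#include <string>

// Hypothetical helper: adds two non-negative integers given as decimal strings.
// Schoolbook addition, O(n) in the number of digits; an illustration, not a GMP replacement.
std::string add_decimal_strings(const std::string& a, const std::string& b) {
    for (const std::string* s : {&a, &b}) {
        bool digits_only = std::all_of(s->begin(), s->end(),
                                       [](char c) { return c >= '0' && c <= '9'; });
        if (s->empty() || !digits_only) {
            throw std::invalid_argument("expected a non-negative decimal integer");
        }
    }
    std::string out;  // digits collected least significant first
    int carry = 0;
    for (std::size_t i = 0; i < std::max(a.size(), b.size()) || carry; ++i) {
        int da = i < a.size() ? a[a.size() - 1 - i] - '0' : 0;
        int db = i < b.size() ? b[b.size() - 1 - i] - '0' : 0;
        int sum = da + db + carry;
        out.push_back(static_cast<char>('0' + sum % 10));
        carry = sum / 10;
    }
    // Leading zeros of the result sit at the back of the reversed buffer.
    while (out.size() > 1 && out.back() == '0') out.pop_back();
    std::reverse(out.begin(), out.end());
    return out;
}
```

For example, adding "1" to 38 nines yields a 39-digit result, one digit beyond what DuckDB's DECIMAL can hold. One way to carry fixed-point decimals on top of this would be to store, say, 12.34 as the integer 1234 with a separate scale of 2. That is an assumption about a possible extension; the final code below works on integers only.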
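The scalar UDF was easy, and DeepSeek wrote it straight away. The only trouble was reading and writing DuckDB's own string type. After I showed it the data structure from the documentation and some example source code, it wrote the conversion functions itself and solved the problem. I also had it add a feature that opens a specified DuckDB database file and executes the multiple SQL statements in a SQL script file.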
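The string trouble comes from DuckDB's `string_t`, which is not a plain `char*`. As I understand its layout in DuckDB's headers, it is a 16-byte value:

- A string of up to 12 bytes is stored inline after a 4-byte length.
- A longer string keeps the length, a 4-byte prefix, and a pointer to the data stored elsewhere.

The hypothetical `MiniString` below mirrors that layout using only the standard library. It shows why a conversion function has to branch on the length. With the real type, the `GetData()`, `GetSize()` and `GetString()` members and the `string_t(const char*, uint32_t)` constructor do this branching for you.

```cpp
#include <cstdint>
#include <cstring>
#include <string>

// Hypothetical stand-in that mirrors the 16-byte string_t layout (not DuckDB's class).
struct MiniString {
    static constexpr uint32_t kInlineLength = 12;
    union {
        struct { uint32_t length; char prefix[4]; const char* ptr; } pointer;
        struct { uint32_t length; char inlined[12]; } inlined;
    } value;

    // `data` must be non-null and, for strings longer than 12 bytes, must outlive this object.
    MiniString(const char* data, uint32_t len) {
        std::memset(&value, 0, sizeof(value));  // unused inline bytes stay zero
        if (len <= kInlineLength) {
            value.inlined.length = len;
            std::memcpy(value.inlined.inlined, data, len);
        } else {
            value.pointer.length = len;
            std::memcpy(value.pointer.prefix, data, 4);  // first 4 bytes kept for fast comparisons
            value.pointer.ptr = data;                    // the rest lives elsewhere
        }
    }
    // Both variants start with the length, so it can be read through either one.
    bool IsInlined() const { return value.inlined.length <= kInlineLength; }
    std::string ToString() const {
        return IsInlined() ? std::string(value.inlined.inlined, value.inlined.length)
                           : std::string(value.pointer.ptr, value.pointer.length);
    }
};
```

A reader that always dereferences the pointer, or a writer that leaves garbage in the unused inline bytes, will misbehave on one of the two cases. That is the kind of bug the conversion functions had to avoid.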
For the code, see my post "A program, written by guiding DeepSeek, that calls the GMP library for high-precision arithmetic in DuckDB".
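The aggregate UDAF was more troublesome. I first found a GitHub repository by a developer abroad that implements an extension computing various statistics. I sent DeepSeek the source of one of its functions, sum_count (which sums and counts at the same time), as a reference. The resulting program always failed with a "must be trivially movable" error that I could not resolve:

/par/duck/src/include/duckdb/function/aggregate_function.hpp:263:82: error: static assertion failed: Aggregate state must be trivially move constructible

In the end I started from that sum_count source myself. Following the extension's function-registration steps, I first got it to work when called from main. Then I added some GMP computation, and only once that compiled without errors did I send it to DeepSeek as a reference. This time, remarkably, it produced a working program. As with the scalar UDF, there were string read/write problems, which the same conversion functions solved.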
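The static assertion states what DuckDB expects. It appears to handle aggregate states as raw memory, so the state type must be trivially move constructible. Any resources must be set up and released through the Initialize and Destroy callbacks rather than through constructors and destructors.

GMP's `mpz_t` is a one-element array of a plain C struct. That is why a state like the final code's `MPZSumState` passes the check, while a state holding, say, a `std::string` does not.

The sketch below checks both cases with `<type_traits>`. `FakeMpzStruct` is a hypothetical stand-in for GMP's struct so that it compiles without GMP.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-in for GMP's __mpz_struct (allocated limbs, used limbs and sign, limb pointer).
struct FakeMpzStruct { int alloc; int size; unsigned long* limbs; };
typedef FakeMpzStruct fake_mpz_t[1];  // GMP defines mpz_t the same way: an array of one struct

struct GoodState { fake_mpz_t sum; };   // shaped like MPZSumState in the final code
struct BadState  { std::string sum; };  // a member with a non-trivial move constructor

// Checked at compile time, the same property DuckDB's static assertion checks.
static_assert(std::is_trivially_move_constructible<GoodState>::value,
              "a plain C struct state can be moved as raw bytes");
static_assert(!std::is_trivially_move_constructible<BadState>::value,
              "a std::string member makes the state non-trivial");
```

Because the state is only raw bytes, `mpz_init` and `mpz_clear` have to be called explicitly. The final code does this in `MPZSumFunction::Initialize` and `MPZSumFunction::Destroy`.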
For the code, see "A DuckDB high-precision aggregate function calling GMP, written by DeepSeek with SumCount as a reference".
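While debugging, g++ compiled the source painfully slowly, because DuckDB pulls in a great many headers and uses templates heavily. To save compile time, I had DeepSeek also write a DuckDB SQL shell ("A DuckDB SQL Shell in only 49 lines, written by DeepSeek"). With it I could try out different SQL statements freely instead of hard-coding them in the source and recompiling again and again.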
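Although I now had both a UDF and a UDAF, they lived in two separate programs. I tried merging the code by copy-and-paste, but reworking the program logic was too much effort, so I handed the integration to DeepSeek as well. The prompt was:

Please integrate the attached duckgmp.cpp program.
Requirement 1: no error whether the command line has 1, 2 or 3 arguments (counting the program name).
- With 1, open an in-memory database and enter interactive mode.
- With 2, open the specified database file and enter interactive mode, creating the file if it does not exist.
- With 3, open the specified database file and execute the SQL script file, reporting an error if either file does not exist.
- More than 3 arguments is an error.
Requirement 2: every mode can call the three functions mpz_add, mpz_mul and mpz_sum. mpz_add and mpz_mul share a lot of duplicated code, so consolidate them further, and add mpz_sub, mpz_div (integer division) and mpz_root (n-th root).
Requirement 3: in interactive mode, add a command that reads a SQL script: read <file path>.
Requirement 4: a timing switch, timi on/off, that prints the elapsed seconds with 3 decimal places.

After a few rounds of prompting, the rewritten program finally implemented all the requested features. One not-so-serious bug remained. When you type a multi-line SQL statement in the interactive shell, the line breaks are dropped, so the last character of one line runs into the first character of the next and causes a syntax error. The workaround is to add a space at the start or end of each line you type. I asked DeepSeek to keep fixing it, but it seemed to have run out of steam too: each change introduced more problems, so I gave up.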
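Requirements 1 and 4 of the prompt can be pinned down without DuckDB. The sketch below uses the hypothetical names `Mode`, `choose_mode`, `file_exists` and `format_seconds`. It maps `argc` (which counts the program name) to a mode, and it formats an elapsed time with three decimals via `std::fixed` and `std::setprecision(3)`, the same formatting the final code uses.

```cpp
#include <chrono>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical names; the mapping follows requirement 1 of the prompt.
enum class Mode { MemoryShell, FileShell, RunScript, Invalid };

Mode choose_mode(int argc) {
    switch (argc) {
        case 1: return Mode::MemoryShell;  // no arguments: in-memory database
        case 2: return Mode::FileShell;    // database file, created if missing
        case 3: return Mode::RunScript;    // database file + SQL script, both must exist
        default: return Mode::Invalid;     // more than two arguments
    }
}

// True if the path can be opened for reading.
bool file_exists(const std::string& path) { return std::ifstream(path).good(); }

// Requirement 4: elapsed seconds with exactly three decimals, e.g. "1.500s".
std::string format_seconds(std::chrono::duration<double> elapsed) {
    std::ostringstream os;
    os << std::fixed << std::setprecision(3) << elapsed.count() << "s";
    return os.str();
}
```

Checking `file_exists` before opening the database matters in script mode. Opening a DuckDB database path that does not exist silently creates it, so without the check the "either file does not exist" error would never fire for the database file.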
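Two things in this process surprised me.

First, it is no surprise that a large AI model like DeepSeek has mastered a general-purpose language like C++. But being so familiar with the functions inside one particular database project is a little startling: it even knew that RunFunctionInTransaction() can be used in place of BeginTransaction(). I wonder how much material it was trained on.

Second, DuckDB's shared library is remarkably complete. Apart from formatted result output (there may be some, but I haven't found it), everything the CLI can do is available, including built-in CSV and Parquet reading and writing. An aggregate built on the UDAF template automatically works as an analytic (window) function. Once a function is registered dynamically, the help messages automatically show its parameters. The COPY command is supported too. With very little effort you can implement custom database operations.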
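Finally, here is the final code. I still don't know how to build it as a DuckDB extension without getting the error "The metadata at the end of the file is invalid" when loading it. If anyone knows, please tell me.

The article "A study of the DuckDB extension mechanism" says that, for security, any out-of-tree extension must be signed by DuckDB before it can be loaded; I'm not sure whether that is true. As far as I know, DuckDB does reject unsigned extensions by default, but the check can be turned off with the allow_unsigned_extensions setting, for example by starting the CLI with -unsigned.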
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <gmp.h>
#include "duckdb.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/pair.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/common/types/vector.hpp"
using duckdb::string_t;
using duckdb::LogicalType;
using duckdb::vector;
// Builds the mpz_sum aggregate; defined after the class.
duckdb::AggregateFunction GetMPZSumFunction();
class DuckGMPExecutor {
public:
    DuckGMPExecutor(const std::string& db_path = ":memory:", bool timing = false)
        : db_path_(db_path), timing_enabled_(timing) {
        db_ = std::make_unique<duckdb::DuckDB>(db_path);
        conn_ = std::make_unique<duckdb::Connection>(*db_);
        registerGMPFunctions();
    }
    void enableTiming(bool enable) { timing_enabled_ = enable; }
    // Shared implementation of the binary GMP operations; Operation selects the op:
    // 0 = add, 1 = subtract, 2 = multiply, 3 = truncating division, 4 = n-th root.
    template<int Operation>
    static string_t gmp_operation_impl(string_t a, string_t b) {
        std::string a_str, b_str;
        try {
            a_str = GetNumericString(a);
            b_str = GetNumericString(b);
        } catch (const std::exception& e) {
            throw std::runtime_error(std::string("Invalid input: ") + e.what());
        }
        mpz_t num1, num2, result;
        mpz_init(num1);
        mpz_init(num2);
        mpz_init(result);
        // Frees the GMP integers, then reports the error.
        auto fail = [&](const char* message) {
            mpz_clear(num1);
            mpz_clear(num2);
            mpz_clear(result);
            throw std::runtime_error(message);
        };
        if (mpz_set_str(num1, a_str.c_str(), 10) != 0 ||
            mpz_set_str(num2, b_str.c_str(), 10) != 0) {
            fail("Failed to parse number");
        }
        // Run the GMP operation selected by the template parameter.
        switch (Operation) {
        case 0: mpz_add(result, num1, num2); break;
        case 1: mpz_sub(result, num1, num2); break;
        case 2: mpz_mul(result, num1, num2); break;
        case 3: // integer division, truncated toward zero
            if (mpz_sgn(num2) == 0) {
                fail("Division by zero");  // GMP would otherwise crash the whole process
            }
            mpz_tdiv_q(result, num1, num2);
            break;
        case 4: { // n-th root, truncated
            if (mpz_sgn(num2) <= 0 || !mpz_fits_ulong_p(num2)) {
                fail("Root degree must be a positive integer");
            }
            unsigned long n = mpz_get_ui(num2);
            if (mpz_sgn(num1) < 0 && n % 2 == 0) {
                fail("Even root of a negative number");  // also a fatal error inside GMP
            }
            mpz_root(result, num1, n);
            break;
        }
        default:
            fail("Unknown operation");
        }
        char* res_str = mpz_get_str(nullptr, 10, result);
        string_t ret = StoreString(std::string(res_str));
        mpz_clear(num1);
        mpz_clear(num2);
        mpz_clear(result);
        free(res_str);  // assumes GMP's default malloc-based allocator
        return ret;
    }
    void registerGMPFunctions() {
        instance = this;
        // Register the scalar GMP functions (VARCHAR in, VARCHAR out).
        conn_->CreateScalarFunction("mpz_add", &gmp_operation_impl<0>);
        conn_->CreateScalarFunction("mpz_sub", &gmp_operation_impl<1>);
        conn_->CreateScalarFunction("mpz_mul", &gmp_operation_impl<2>);
        conn_->CreateScalarFunction("mpz_div", &gmp_operation_impl<3>);
        conn_->CreateScalarFunction("mpz_root", &gmp_operation_impl<4>);
        // Register the aggregate function (catalog changes must run inside a transaction).
        conn_->context->RunFunctionInTransaction([&]() {
            auto &context = *conn_->context;
            auto &catalog = duckdb::Catalog::GetSystemCatalog(context);
            duckdb::AggregateFunctionSet mpz_sum("mpz_sum");
            mpz_sum.AddFunction(GetMPZSumFunction());
            duckdb::CreateAggregateFunctionInfo info(mpz_sum);
            catalog.CreateFunction(context, info);
        });
    }
    // Executes an in-memory script statement by statement (split naively on ';').
    // Supports the `read <file>` command; text after the last ';' is ignored.
    // Not called from main in this version.
    bool executeScript(const std::string& script) {
        bool all_success = true;
        size_t start_pos = 0;
        size_t semicolon_pos;
        while ((semicolon_pos = script.find(';', start_pos)) != std::string::npos) {
            std::string single_sql = script.substr(start_pos, semicolon_pos - start_pos + 1);
            single_sql.erase(0, single_sql.find_first_not_of(" \n\r\t"));
            single_sql.erase(single_sql.find_last_not_of(" \n\r\t") + 1);
            start_pos = semicolon_pos + 1;
            if (single_sql.empty()) {
                continue;
            }
            // Handle the read command (quoted or unquoted file name).
            if (single_sql.rfind("read ", 0) == 0) {
                std::string file_path = single_sql.substr(5);
                file_path.erase(0, file_path.find_first_not_of(" \t"));
                // Drop the trailing semicolon first, then any surrounding quotes.
                if (!file_path.empty() && file_path.back() == ';') {
                    file_path.pop_back();
                }
                file_path.erase(file_path.find_last_not_of(" \t") + 1);
                if (!file_path.empty() && (file_path.front() == '\'' || file_path.front() == '"')) {
                    file_path.erase(0, 1);
                    if (!file_path.empty() && (file_path.back() == '\'' || file_path.back() == '"')) {
                        file_path.pop_back();
                    }
                }
                if (file_path.empty()) {
                    std::cerr << "Error: No file specified for read command" << std::endl;
                    all_success = false;
                } else if (!executeFile(file_path)) {
                    all_success = false;
                }
                continue;
            }
            // Any other statement is passed to DuckDB as-is.
            auto result = conn_->Query(single_sql);
            if (result->HasError()) {
                std::cerr << "Error: " << result->GetError() << std::endl;
                all_success = false;
            } else if (result->RowCount() > 0) {
                result->Print();
            }
        }
        return all_success;
    }
    // Runs a SQL script file. Non-comment lines are joined with spaces, and the
    // accumulated text is sent to DuckDB whenever a line contains ';'.
    bool executeFile(const std::string& file_path) {
        std::ifstream file(file_path);
        if (!file.is_open()) {
            std::cerr << "Error opening file: " << file_path << std::endl;
            return false;
        }
        std::string script;
        std::string line;
        bool all_success = true;
        int stmt_num = 0;  // counts executed statements, not file lines
        auto executeCurrentScript = [&]() {
            if (!script.empty()) {
                stmt_num++;
                // Echo the statement about to run, with its sequence number.
                std::cout << "-- [Statement " << stmt_num << "] Executing:\n"
                          << script << "\n-- [End Statement " << stmt_num << "]\n";
                try {
                    auto start = std::chrono::high_resolution_clock::now();
                    auto result = conn_->Query(script);
                    auto end = std::chrono::high_resolution_clock::now();
                    if (result->HasError()) {
                        std::cerr << "Error: " << result->GetError() << std::endl;
                        all_success = false;
                    } else if (result->RowCount() > 0) {
                        result->Print();
                    }
                    if (timing_enabled_) {
                        std::chrono::duration<double> elapsed = end - start;
                        std::cout << "-- [Time: " << std::fixed << std::setprecision(3)
                                  << elapsed.count() << "s]\n";
                    }
                } catch (const std::exception& e) {
                    std::cerr << "Error in statement " << stmt_num << ": " << e.what() << std::endl;
                    all_success = false;
                }
                script.clear();
            }
        };
        while (std::getline(file, line)) {
            // Strip trailing CR/LF characters (e.g. Windows line endings).
            line.erase(line.find_last_not_of("\r\n") + 1);
            // Skip empty lines and lines starting with a "--" comment.
            if (line.empty() || line.find("--") == 0) {
                continue;
            }
            script += line + " ";
            // A ';' on this line completes the statement.
            if (line.find(';') != std::string::npos) {
                executeCurrentScript();
            }
        }
        // Run a final statement that lacks a terminating ';'.
        executeCurrentScript();
        return all_success;
    }
    void interactiveShell() {
        std::cout << "DuckDB GMP Shell (enter 'exit;' to quit, 'read file.sql;' to execute a script)\n";
        std::cout << "Type 'timi on;' or 'timi off;' to enable/disable timing\n";
        std::cout << "=============================================\n";
        std::string query;
        bool continuing = false;  // true while a statement spans several input lines
        while (true) {
            std::cout << (continuing ? " ...> " : "duckdb> ");
            std::string line;
            if (!std::getline(std::cin, line)) {
                break;  // end of input (Ctrl-D or end of a piped file)
            }
            // Ignore blank lines between statements.
            if (line.empty() && query.empty()) {
                continue;
            }
            // Keep a line break between lines so the last token of one line does not
            // fuse with the first token of the next.
            if (!query.empty()) {
                query += '\n';
            }
            query += line;
            // Keep reading until the statement or command contains ';'.
            if (query.find(';') == std::string::npos) {
                continuing = true;
                continue;
            }
            continuing = false;
            // Shell commands.
            if (query.find("timi on") == 0) {
                timing_enabled_ = true;
                std::cout << "Timing enabled\n";
                query.clear();
                continue;
            } else if (query.find("timi off") == 0) {
                timing_enabled_ = false;
                std::cout << "Timing disabled\n";
                query.clear();
                continue;
            } else if (query.find("exit;") == 0) {
                break;
            } else if (query.find("read ") == 0) {
                // Extract the file name between "read " and the last ';'.
                size_t start = query.find_first_of(" \t") + 1;
                size_t end = query.find_last_of(';');
                std::string file_path = query.substr(start, end - start);
                // Trim surrounding whitespace, line breaks and quotes.
                file_path.erase(0, file_path.find_first_not_of(" \t\r\n\"\'"));
                file_path.erase(file_path.find_last_not_of(" \t\r\n\"\';") + 1);
                if (file_path.empty()) {
                    std::cerr << "Error: No file specified" << std::endl;
                } else {
                    std::cout << "-- Reading file: " << file_path << "\n";
                    executeFile(file_path);
                    std::cout << "-- File execution completed\n";
                }
                query.clear();
                continue;
            }
            // Ordinary SQL.
            try {
                auto start = std::chrono::high_resolution_clock::now();
                auto result = conn_->Query(query);
                auto end = std::chrono::high_resolution_clock::now();
                if (result->HasError()) {
                    std::cerr << "Error: " << result->GetError() << std::endl;
                } else if (result->RowCount() > 0) {
                    result->Print();
                }
                if (timing_enabled_) {
                    std::chrono::duration<double> elapsed = end - start;
                    std::cout << "[Time: " << std::fixed << std::setprecision(3)
                              << elapsed.count() << "s]" << std::endl;
                }
            } catch (const std::exception& e) {
                std::cerr << "Error: " << e.what() << std::endl;
            }
            query.clear();
        }
    }
    // String helpers.
    // Returns the text of a DuckDB string after checking that it is an integer:
    // an optional leading '-' (the only sign mpz_set_str accepts) followed by digits.
    static std::string GetNumericString(const duckdb::string_t& input) {
        std::string s(input.GetData(), input.GetSize());
        size_t first_digit = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (first_digit == s.size()) {
            throw std::runtime_error("Empty number");
        }
        for (size_t i = first_digit; i < s.size(); i++) {
            if (s[i] < '0' || s[i] > '9') {
                throw std::runtime_error("Invalid character in number");
            }
        }
        return s;
    }
    // Wraps a std::string as a DuckDB string_t. The string_t(const char*, uint32_t)
    // constructor inlines strings of up to 12 bytes (zero-padded) and otherwise only
    // stores a pointer, so longer strings are copied into a malloc'd buffer. That buffer
    // is never freed, because the simple CreateScalarFunction API gives no access to the
    // result vector's string heap, so each long scalar result leaks a little memory.
    static duckdb::string_t StoreString(const std::string& input) {
        auto size = static_cast<uint32_t>(input.size());
        if (size <= duckdb::string_t::INLINE_LENGTH) {
            return duckdb::string_t(input.data(), size);
        }
        char* buffer = static_cast<char*>(malloc(size));
        memcpy(buffer, input.data(), size);
        return duckdb::string_t(buffer, size);
    }
private:
    std::string db_path_;
    bool timing_enabled_;
    std::unique_ptr<duckdb::DuckDB> db_;
    std::unique_ptr<duckdb::Connection> conn_;
    static DuckGMPExecutor* instance;  // set during registration, not otherwise used
};
DuckGMPExecutor* DuckGMPExecutor::instance = nullptr;
// mpz_sum aggregate: the state holds a GMP integer that is created in Initialize
// and freed in Destroy, so the state itself stays trivially movable.
struct MPZSumState {
    mpz_t sum;
};
struct MPZSumFunction {
    template <class STATE>
    static void Initialize(STATE &state) { mpz_init(state.sum); }
    template <class STATE>
    static void Destroy(STATE &state, duckdb::AggregateInputData &) { mpz_clear(state.sum); }
    static bool IgnoreNull() { return true; }
};
// Adds each non-NULL input string to the state of its group.
static void MPZSumUpdate(duckdb::Vector inputs[], duckdb::AggregateInputData &, duckdb::idx_t,
                         duckdb::Vector &state_vector, duckdb::idx_t count) {
    auto &input = inputs[0];
    duckdb::UnifiedVectorFormat sdata, input_data;
    state_vector.ToUnifiedFormat(count, sdata);
    input.ToUnifiedFormat(count, input_data);
    auto states = (MPZSumState **)sdata.data;
    auto str_values = duckdb::UnifiedVectorFormat::GetData<duckdb::string_t>(input_data);
    for (duckdb::idx_t i = 0; i < count; i++) {
        auto input_idx = input_data.sel->get_index(i);
        if (input_data.validity.RowIsValid(input_idx)) {
            auto &state = *states[sdata.sel->get_index(i)];
            std::string num_str = DuckGMPExecutor::GetNumericString(str_values[input_idx]);
            mpz_t tmp;
            mpz_init(tmp);
            if (mpz_set_str(tmp, num_str.c_str(), 10) != 0) {
                mpz_clear(tmp);
                throw std::runtime_error("Failed to convert string to GMP number");
            }
            mpz_add(state.sum, state.sum, tmp);
            mpz_clear(tmp);
        }
    }
}
// Writes each group's sum to the result vector as a decimal string.
static void MPZSumFinalize(duckdb::Vector &state_vector, duckdb::AggregateInputData &,
                           duckdb::Vector &result, duckdb::idx_t count, duckdb::idx_t offset) {
    duckdb::UnifiedVectorFormat sdata;
    state_vector.ToUnifiedFormat(count, sdata);
    auto states = (MPZSumState **)sdata.data;
    auto result_data = duckdb::FlatVector::GetData<duckdb::string_t>(result);
    for (duckdb::idx_t i = 0; i < count; i++) {
        const auto rid = i + offset;
        auto &state = *states[sdata.sel->get_index(i)];
        char *str = mpz_get_str(nullptr, 10, state.sum);
        if (!str) {
            throw std::runtime_error("Failed to convert GMP number to string");
        }
        // AddString copies the text into the result vector's own string heap,
        // so the GMP buffer can be released immediately.
        result_data[rid] = duckdb::StringVector::AddString(result, str);
        free(str);  // assumes GMP's default malloc-based allocator
    }
}
// Adds each source state's sum into the corresponding target state.
static void MPZSumCombine(duckdb::Vector &state, duckdb::Vector &combined,
                          duckdb::AggregateInputData &, duckdb::idx_t count) {
    duckdb::UnifiedVectorFormat sdata;
    state.ToUnifiedFormat(count, sdata);
    auto states_ptr = (MPZSumState **)sdata.data;
    auto combined_ptr = duckdb::FlatVector::GetData<MPZSumState *>(combined);
    for (duckdb::idx_t i = 0; i < count; i++) {
        mpz_add(combined_ptr[i]->sum, combined_ptr[i]->sum, states_ptr[sdata.sel->get_index(i)]->sum);
    }
}
duckdb::unique_ptr<duckdb::FunctionData> MPZSumBind(duckdb::ClientContext &context,
                                                    duckdb::AggregateFunction &function,
                                                    duckdb::vector<duckdb::unique_ptr<duckdb::Expression>> &arguments) {
    function.return_type = duckdb::LogicalType::VARCHAR;  // the result is a string
    return nullptr;
}
duckdb::AggregateFunction GetMPZSumFunction() {
    using STATE_TYPE = MPZSumState;
    return duckdb::AggregateFunction(
        "mpz_sum",                                                               // function name
        {duckdb::LogicalType::VARCHAR},                                          // argument: a string
        duckdb::LogicalType::VARCHAR,                                            // return type: a string
        duckdb::AggregateFunction::StateSize<STATE_TYPE>,                        // state size
        duckdb::AggregateFunction::StateInitialize<STATE_TYPE, MPZSumFunction>,  // initialize
        MPZSumUpdate,                                                            // update
        MPZSumCombine,                                                           // combine
        MPZSumFinalize,                                                          // finalize
        nullptr,                                                                 // simple update (not provided)
        MPZSumBind,                                                              // bind
        duckdb::AggregateFunction::StateDestroy<STATE_TYPE, MPZSumFunction>      // destroy
    );
}
int main(int argc, char** argv) {
    if (argc > 3) {
        std::cerr << "Usage: " << argv[0] << " [database_file] [sql_file]\n"
                  << "  No args: memory database + interactive mode\n"
                  << "  1 arg:   specified database + interactive mode\n"
                  << "  2 args:  specified database + execute SQL file\n";
        return 1;
    }
    try {
        bool timing = false;
        if (argc == 3) {
            // Script mode: both the database file and the SQL file must already exist.
            for (int i = 1; i <= 2; i++) {
                if (!std::ifstream(argv[i]).good()) {
                    std::cerr << "Error: file not found: " << argv[i] << std::endl;
                    return 1;
                }
            }
        }
        // No arguments: in-memory database; otherwise open (or create) argv[1].
        DuckGMPExecutor executor(argc >= 2 ? argv[1] : ":memory:", timing);
        if (argc <= 2) {
            executor.interactiveShell();
        } else if (!executor.executeFile(argv[2])) {
            return 1;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}