[手写系列]Go手写db — — 第三版
第一版文章地址:https://blog.csdn.net/weixin_45565886/article/details/147839627
第二版文章地址:https://blog.csdn.net/weixin_45565886/article/details/150869791
🏠整体项目Github地址:https://github.com/ziyifast/ZiyiDB
- 🚀请大家多多支持,也欢迎大家star⭐️和共同维护这个项目~
序言:只要接触过后端开发,必不可少会使用到关系型数据库,比如:MySQL、Oracle等,那么我们经常使用的字段默认值、以及聚合函数底层是如何实现的呢?本文会给大家提供一些思路,实现相关功能。
主要介绍如何在 ZiyiDB之前的基础上,实现更多新功能,给大家提供实现数据库的简单思路,以及数据库底层实现的流程,后续更多功能,大家可以参考着实现。
一、功能列表
- 默认值支持(DEFAULT 关键字)
- 聚合函数支持(COUNT, SUM, AVG, MAX, MIN)
- Group by分组能力
- Order by 排序能力
二、实现细节
1. 默认值实现
设计思路
默认值是数据库中一个重要的数据完整性特性。当插入数据时,如果没有为某列提供值,数据库会自动使用该列的默认值。
在 ZiyiDB 中,默认值的实现需要考虑以下几点:
- 语法解析:在 CREATE TABLE 语句中识别 DEFAULT 关键字和默认值
- 存储:在表结构中保存每列的默认值
- 执行:在 INSERT 语句中应用默认值
1.在lexer/token.go中新增default字符,然后在lexer/lexer.go的lookupIdentifier方法中新增对于default的case语句,用于匹配识别用户输入的SQL
token.go:
lexer.go:
2. internal/ast/ast.go抽象语法树中新增DefaultExpression,同时列定义中新增默认值字段,用于存储列的默认值
3. parser中的parseCreateTableStatement函数新增对create SQL中默认值的读取和封装,解析用户输入SQL中的字段默认值类型和value
4. internal/storage/memory.go 存储引擎处理Insert方法时,新增对默认值的处理。
代码实现
1.语法解析层(Parser)
在 internal/parser/parser.go 中,parseCreateTableStatement 方法被增强以支持默认值:
// parseCreateTableStatement 解析CREATE TABLE语句
func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) {
stmt := &ast.CreateTableStatement{Token: p.curToken}
// ... 其他代码
// 解析列定义
for !p.peekTokenIs(lexer.RPAREN) {
p.nextToken()
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
col := ast.ColumnDefinition{
Name: p.curToken.Literal,
}
if !p.expectPeek(lexer.INT) &&
!p.expectPeek(lexer.TEXT) &&
!p.expectPeek(lexer.FLOAT) &&
!p.expectPeek(lexer.DATETIME) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
col.Type = string(p.curToken.Type)
if p.peekTokenIs(lexer.PRIMARY) {
p.nextToken()
if !p.expectPeek(lexer.KEY) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
col.Primary = true
}
if p.peekTokenIs(lexer.DEFAULT) {
p.nextToken() // 消费 DEFAULT 关键字
p.nextToken() // 移动到默认值表达式开始位置
// 解析复杂默认值表达式(支持函数调用、数学表达式等)
defaultValue, err := p.parseExpression()
if err != nil {
return nil, fmt.Errorf("Invalid default value for column '%s': %v", col.Name, err)
}
// 创建 DefaultExpression 节点
col.Default = &ast.DefaultExpression{
Token: p.curToken,
Value: defaultValue,
}
}
stmt.Columns = append(stmt.Columns, col)
if p.peekTokenIs(lexer.COMMA) {
p.nextToken()
}
}
// ... 其他代码
}
2.AST 定义
在 internal/ast/ast.go 中,我们添加了 DefaultExpression 类型来表示默认值:
// DefaultExpression 表示DEFAULT表达式
type DefaultExpression struct {
Token lexer.Token
Value Expression
}
func (de *DefaultExpression) expressionNode() {}
func (de *DefaultExpression) TokenLiteral() string { return de.Token.Literal }
同时,ColumnDefinition 结构也被更新以包含默认值:
// ColumnDefinition 表示列定义
type ColumnDefinition struct {
Name string
Type string
Primary bool
Nullable bool
Default interface{} //列默认值
}
3.存储引擎实现
在 internal/storage/memory.go 中,Insert 方法被增强以支持默认值:
// Insert 插入数据
func (b *MemoryBackend) Insert(stmt *ast.InsertStatement) error {
table, exists := b.tables[stmt.TableName]
if !exists {
return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
// 构建列名到表列索引的映射
colIndexMap := make(map[string]int)
for idx, col := range table.Columns {
colIndexMap[col.Name] = idx
}
// 初始化行数据(长度为表的总列数)
row := make([]ast.Cell, len(table.Columns))
// 处理插入列列表(用户显式指定的列或隐式全列)
var insertCols []*ast.Identifier
//用户SQL需要插入的列名、值的映射
userColMap := make(map[string]ast.Expression)
if len(stmt.Columns) > 0 {
insertCols = stmt.Columns
for i, col := range stmt.Columns {
userColMap[col.Token.Literal] = stmt.Values[i]
}
} else {
// 未指定列时默认使用表的所有列
insertCols = make([]*ast.Identifier, len(table.Columns))
for i, col := range table.Columns {
insertCols[i] = &ast.Identifier{Value: col.Name}
userColMap[col.Name] = stmt.Values[i]
}
}
// 检查值数量与指定列数量是否匹配
if len(stmt.Values) != len(insertCols) {
return fmt.Errorf("Column count doesn't match value count at row 1 (got %d, want %d)", len(stmt.Values), len(insertCols))
}
// 转换值
// 填充行数据(处理用户值或默认值)
for i, tableCol := range table.Columns {
// 优先使用用户提供的值,否则使用默认值
var expr ast.Expression
expr = userColMap[tableCol.Name]
if expr == nil && tableCol.Default != nil {
expr = tableCol.Default.(*ast.DefaultExpression).Value
}
//获取当前列名
colName := table.Columns[i].Name
tableColIdx, ok := colIndexMap[colName]
if !ok {
return fmt.Errorf("Unknown column '%s' in INSERT statement", colName)
}
// 转换值类型
value, err := evaluateExpression(expr)
if err != nil {
return fmt.Errorf("invalid value for column '%s': %v", colName, err)
}
// 类型转换
switch v := value.(type) {
case string:
if tableCol.Type == "INT" {
intVal, err := strconv.ParseInt(v, 10, 32)
if err != nil {
return fmt.Errorf("Incorrect integer value: '%s' for column '%s'", v, tableCol.Name)
}
row[tableColIdx] = ast.Cell{Type: ast.CellTypeInt, IntValue: int32(intVal)}
} else {
row[tableColIdx] = ast.Cell{Type: ast.CellTypeText, TextValue: v}
}
case int32:
row[tableColIdx] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}
case float32:
row[tableColIdx] = ast.Cell{Type: ast.CellTypeFloat, FloatValue: v}
case time.Time:
row[tableColIdx] = ast.Cell{Type: ast.CellTypeDateTime, TimeValue: v.Format("2006-01-02 15:04:05")}
default:
return fmt.Errorf("Unsupported value type: %T for column '%s'", value, tableCol.Name)
}
}
// ... 其他代码
}
测试
测试SQL:
-- 创建带默认值的表
CREATE TABLE users (
id INT PRIMARY KEY,
name TEXT,
age INT DEFAULT 18,
score FLOAT,
ctime DATETIME DEFAULT '2023-07-04 12:00:00'
);
-- 插入部分列数据(未指定的列将使用默认值)
INSERT INTO users (id, name, score) VALUES (1, 'Alice', 90.0);
INSERT INTO users (id, name, age, score) VALUES (2, 'Bob', 25, 85.5);
-- 查询数据验证默认值
SELECT * FROM users;
效果:
2. 聚合函数实现
设计思路
聚合函数是 SQL 中用于对一组值执行计算并返回单个值的函数。在 ZiyiDB 中,我们实现了以下聚合函数:
- COUNT:计算行数
- SUM:计算数值列的总和
- AVG:计算数值列的平均值
- MAX:找出列中的最大值
- MIN:找出列中的最小值
聚合函数的实现需要考虑以下几点:
语法解析:在 SELECT 语句中识别函数调用
执行逻辑:在存储引擎中计算聚合结果
结果返回:以统一的格式返回结果
这里以count聚合函数为例,其他聚合函数同理
- internal/ast/ast.go中新增FunctionCall函数调用类型,用于后续执行函数调用,比如count、max等聚合函数
- internal/parser/parser.go中新增对函数类型的解析和封装
- internal/storage/memory.go存储引擎Select方法中新增对聚合函数的判断
同时memory.go中添加calculateFunctionResults方法,实现对函数的执行和底层实现
代码实现
- 语法解析层(Parser)
在 internal/parser/parser.go 中,我们增强了 parseSelectStatement 方法来支持函数调用:
// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
stmt := &ast.SelectStatement{Token: p.curToken}
// 解析选择列表
for !p.peekTokenIs(lexer.FROM) {
p.nextToken()
if p.curToken.Type == lexer.ASTERISK {
stmt.Fields = append(stmt.Fields, &ast.StarExpression{})
break
}
expr, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Fields = append(stmt.Fields, expr)
if p.peekTokenIs(lexer.COMMA) {
p.nextToken()
}
}
// ... 其他代码
}
parseExpression 方法也进行了增强,以支持函数调用的解析:
// parseExpression 解析表达式
func (p *Parser) parseExpression() (ast.Expression, error) {
switch p.curToken.Type {
// ... 其他情况
case lexer.IDENT:
if p.peekTokenIs(lexer.LPAREN) {
return p.parseFunctionCall()
}
return &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}, nil
// ...
}
}
// parseFunctionCall 解析函数调用
func (p *Parser) parseFunctionCall() (ast.Expression, error) {
fn := &ast.FunctionCall{
Token: p.curToken,
Name: p.curToken.Literal,
Params: []ast.Expression{},
}
// 检查下一个token是否为左括号
if !p.expectPeek(lexer.LPAREN) {
return nil, fmt.Errorf("expected ( after function name")
}
// 如果是右括号,说明没有参数
if p.peekTokenIs(lexer.RPAREN) {
p.nextToken()
return fn, nil
}
// 解析参数列表
for !p.peekTokenIs(lexer.RPAREN) {
p.nextToken()
param, err := p.parseExpression()
if err != nil {
return nil, err
}
fn.Params = append(fn.Params, param)
if p.peekTokenIs(lexer.COMMA) {
p.nextToken()
} else if !p.peekTokenIs(lexer.RPAREN) {
return nil, fmt.Errorf("expected comma or closing parenthesis in function call")
}
}
if !p.expectPeek(lexer.RPAREN) {
return nil, fmt.Errorf("Missing closing parenthesis for function call")
}
return fn, nil
}
- AST 定义
在 internal/ast/ast.go 中,我们添加了 FunctionCall 类型来表示函数调用:
// FunctionCall 表示函数调用
type FunctionCall struct {
Token lexer.Token
Name string
Params []Expression
}
func (fc *FunctionCall) expressionNode() {}
func (fc *FunctionCall) TokenLiteral() string { return fc.Token.Literal }
- 存储引擎实现
在 internal/storage/memory.go 中,Select 方法被增强以支持聚合函数:
// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*ast.Results, error) {
table, exists := b.tables[stmt.TableName]
if !exists {
return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
results := &ast.Results{
Columns: make([]ast.ResultColumn, 0),
Rows: make([][]ast.Cell, 0),
}
// 检查是否为聚合函数查询
isAggregation := false
var aggregateFunc *ast.FunctionCall
// 处理select列表
if len(stmt.Fields) == 1 {
// 检查是否为 SELECT *
if _, ok := stmt.Fields[0].(*ast.StarExpression); ok {
// SELECT *
for _, col := range table.Columns {
results.Columns = append(results.Columns, ast.ResultColumn{
Name: col.Name,
Type: col.Type,
})
}
} else if fn, ok := stmt.Fields[0].(*ast.FunctionCall); ok {
// 处理函数调用
isAggregation = true
aggregateFunc = fn
results.Columns = append(results.Columns, ast.ResultColumn{
Name: fn.Name,
Type: "FUNCTION",
})
}
// ... 其他情况
}
// ... 其他情况
// 如果是聚合函数查询,直接计算结果
if isAggregation {
// 处理WHERE子句
filteredRows := make([][]ast.Cell, 0)
for _, row := range table.Rows {
if stmt.Where != nil {
match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)
if err != nil {
return nil, err
}
if !match {
continue
}
}
filteredRows = append(filteredRows, row)
}
functionResult := calculateFunctionResults(aggregateFunc, table, filteredRows)
results.Rows = [][]ast.Cell{functionResult}
return results, nil
}
// ... 非聚合函数的处理
}
每个聚合函数都有对应的计算方法:
// calculateFunctionResults 计算函数结果
func calculateFunctionResults(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {
// 根据函数类型计算结果
switch strings.ToUpper(fn.Name) {
case "COUNT":
return calculateCount(fn, table, rows)
case "SUM":
return calculateSum(fn, table, rows)
case "AVG":
return calculateAvg(fn, table, rows)
case "MAX":
return calculateMax(fn, table, rows)
case "MIN":
return calculateMin(fn, table, rows)
default:
return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: Unknown function '%s'", fn.Name)}}
}
}
// calculateCount 计算COUNT函数结果
func calculateCount(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {
return []ast.Cell{{Type: ast.CellTypeInt, IntValue: int32(len(rows))}}
}
// calculateSum 计算SUM函数结果
func calculateSum(fn *ast.FunctionCall, table *Table, rows [][]ast.Cell) []ast.Cell {
// 处理 SUM(column) 情况
if len(fn.Params) != 1 {
return []ast.Cell{{Type: ast.CellTypeText, TextValue: "ERROR: SUM function requires exactly one parameter"}}
}
var columnName string
// 检查参数类型
switch param := fn.Params[0].(type) {
case *ast.Identifier:
columnName = param.Value
default:
return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: SUM function requires a column name, got %T", param)}}
}
// 查找列索引
colIndex := -1
for i, col := range table.Columns {
if col.Name == columnName {
colIndex = i
break
}
}
if colIndex == -1 {
return []ast.Cell{{Type: ast.CellTypeText, TextValue: fmt.Sprintf("ERROR: Unknown column '%s'", columnName)}}
}
// 计算SUM值
var sumInt int32 = 0
var sumFloat float32 = 0.0
hasFloat := false
for _, row := range rows {
cell := row[colIndex]
switch cell.Type {
case ast.CellTypeInt:
sumInt += cell.IntValue
case ast.CellTypeFloat:
// 如果之前有整数,需要转换为浮点数
if !hasFloat {
sumFloat = float32(sumInt)
hasFloat = true
}
sumFloat += cell.FloatValue
}
}
// 返回结果
if hasFloat {
return []ast.Cell{{Type: ast.CellTypeFloat, FloatValue: sumFloat}}
}
return []ast.Cell{{Type: ast.CellTypeInt, IntValue: sumInt}}
}
// ... 其他聚合函数的实现
测试
测试SQL:
-- 创建测试表
CREATE TABLE users (id INT PRIMARY KEY, name TEXT, age INT);
-- 插入测试数据
INSERT INTO users VALUES (1, 'Alice', 20);
INSERT INTO users VALUES (2, 'Bob', 25);
INSERT INTO users VALUES (3, 'Charlie', 30);
-- 使用聚合函数
SELECT COUNT(*) FROM users;
SELECT SUM(age) FROM users;
SELECT AVG(age) FROM users;
SELECT MAX(age) FROM users;
SELECT MIN(age) FROM users;
-- 带WHERE条件的聚合函数
SELECT COUNT(*) FROM users WHERE age > 25;
SELECT SUM(age) FROM users WHERE age >= 25;
效果:
3. group by 实现
设计思路
1.语法解析:
首先在internal/lexer/token.go中新增group by关键字
然后在internal/lexer/lexer.go词法分析器的lookupIdentifier方法中新增对group by关键字的识别
接下来在internal/parser/parser.go词法分析器中的parseSelectStatement方法中添加 GROUP 和 BY 关键字的解析,将其解析并封装为ast的一部分
在 internal/ast/ast.go 中添加 GroupBy 字段到 SelectStatement 结构体
2. 执行引擎:
首先在internal/storage/memory.go存储引擎中的Select方法实现对分组逻辑的调用
接着selectWithGroupBy方法,实现底层分组原理,按指定列对数据进行分组
3. internal/storage/memory.go中的selectWithGroupBy对聚合函数进行处理,确保查询结果列是聚合函数列或者分组列
代码实现
- 在词法分析器中添加新的关键字
// internal/lexer/token.go
const (
// ... 其他关键字
GROUP TokenType = "GROUP"
BY TokenType = "BY"
)
// internal/lexer/lexer.go
func (l *Lexer) lookupIdentifier(ident string) TokenType {
switch strings.ToUpper(ident) {
// ... 其他关键字
case "GROUP":
return GROUP
case "BY":
return BY
default:
return IDENT
}
}
- 在 AST 中添加新的结构体以支持 GROUP BY
// internal/ast/ast.go
// SelectStatement 表示SELECT语句
type SelectStatement struct {
Token lexer.Token
Fields []Expression
TableName string
Where Expression
GroupBy []Expression // 添加 GroupBy 字段
}
- 在语法分析器中添加对 GROUP BY 子句的解析
// internal/parser/parser.go
// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
stmt := &ast.SelectStatement{Token: p.curToken}
// ... 解析选择列表和 FROM 子句 ...
// 解析WHERE子句
if p.peekTokenIs(lexer.WHERE) {
p.nextToken()
whereExpr, err := p.parseWhereClause()
if err != nil {
return nil, err
}
stmt.Where = whereExpr
}
// 解析GROUP BY子句
if p.peekTokenIs(lexer.GROUP) {
p.nextToken() // 跳过 GROUP
if !p.expectPeek(lexer.BY) {
return nil, fmt.Errorf("expected BY after GROUP")
}
// 解析GROUP BY字段列表
for {
p.nextToken()
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("expected identifier in GROUP BY clause")
}
expr := &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}
stmt.GroupBy = append(stmt.GroupBy, expr)
if !p.peekTokenIs(lexer.COMMA) {
break
}
p.nextToken() // 跳过逗号
}
}
return stmt, nil
}
- 在存储引擎中实现 GROUP BY 的执行逻辑
// internal/storage/memory.go
// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*Results, error) {
table, exists := b.tables[stmt.TableName]
if !exists {
return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
// 如果有 GROUP BY 子句
if len(stmt.GroupBy) > 0 {
return b.selectWithGroupBy(stmt, table)
}
// ... 原有的查询逻辑 ...
}
// selectWithGroupBy 处理带有 GROUP BY 的查询
func (b *MemoryBackend) selectWithGroupBy(stmt *ast.SelectStatement, table *Table) (*Results, error) {
results := &Results{
Columns: make([]ResultColumn, 0),
Rows: make([][]Cell, 0),
}
// 验证 GROUP BY 字段存在于表中
groupByIndices := make([]int, len(stmt.GroupBy))
for i, expr := range stmt.GroupBy {
if identifier, ok := expr.(*ast.Identifier); ok {
found := false
for j, col := range table.Columns {
if col.Name == identifier.Value {
groupByIndices[i] = j
found = true
break
}
}
if !found {
return nil, fmt.Errorf("Unknown column '%s' in 'group statement'", identifier.Value)
}
} else {
return nil, fmt.Errorf("GROUP BY only supports column names")
}
}
// 构建结果列
for _, expr := range stmt.Fields {
switch e := expr.(type) {
case *ast.Identifier:
found := false
for _, col := range table.Columns {
if col.Name == e.Value {
results.Columns = append(results.Columns, ResultColumn{
Name: col.Name,
Type: col.Type,
})
found = true
break
}
}
if !found {
return nil, fmt.Errorf("Unknown column '%s' in 'field list'", e.Value)
}
case *ast.FunctionCall:
results.Columns = append(results.Columns, ResultColumn{
Name: e.Name,
Type: "FUNCTION",
})
case *ast.StarExpression:
for _, col := range table.Columns {
results.Columns = append(results.Columns, ResultColumn{
Name: col.Name,
Type: col.Type,
})
}
default:
return nil, fmt.Errorf("Unsupported select expression type")
}
}
// 处理WHERE子句
filteredRows := make([][]Cell, 0)
for _, row := range table.Rows {
if stmt.Where != nil {
match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)
if err != nil {
return nil, err
}
if !match {
continue
}
}
filteredRows = append(filteredRows, row)
}
// 按 GROUP BY 字段分组
groups := make(map[string][][]Cell)
for _, row := range filteredRows {
// 构建分组键
groupKey := ""
for _, idx := range groupByIndices {
groupKey += row[idx].String() + "|"
}
// 将行添加到对应的组中
groups[groupKey] = append(groups[groupKey], row)
}
// 为每个组计算结果
for _, groupRows := range groups {
if len(groupRows) == 0 {
continue
}
resultRow := make([]Cell, len(results.Columns))
colIndex := 0
// 处理非聚合字段(GROUP BY 字段)
for _, expr := range stmt.Fields {
if identifier, ok := expr.(*ast.Identifier); ok {
// 检查是否为 GROUP BY 字段
isGroupByField := false
for _, groupByExpr := range stmt.GroupBy {
if groupByIdent, ok := groupByExpr.(*ast.Identifier); ok {
if groupByIdent.Value == identifier.Value {
isGroupByField = true
break
}
}
}
if isGroupByField {
// 对于 GROUP BY 字段,取第一个值(所有行应该相同)
for k, tableCol := range table.Columns {
if tableCol.Name == identifier.Value {
resultRow[colIndex] = groupRows[0][k]
break
}
}
}
colIndex++
}
}
// 处理聚合函数
for i, expr := range stmt.Fields {
if fn, ok := expr.(*ast.FunctionCall); ok {
functionResult := calculateFunctionResults(fn, table, groupRows)
resultRow[i] = functionResult[0]
}
}
results.Rows = append(results.Rows, resultRow)
}
return results, nil
}
测试
测试SQL:
CREATE TABLE sales (id INT PRIMARY KEY, product TEXT, category TEXT, amount FLOAT);
INSERT INTO sales VALUES (1, 'Apple', 'Fruit', 10.5);
INSERT INTO sales VALUES (2, 'Banana', 'Fruit', 8.0);
INSERT INTO sales VALUES (3, 'Carrot', 'Vegetable', 5.2);
INSERT INTO sales VALUES (4, 'Broccoli', 'Vegetable', 7.3);
INSERT INTO sales VALUES (5, 'Orange', 'Fruit', 9.8);
SELECT category, COUNT(*) FROM sales GROUP BY category;
SELECT category, SUM(amount) FROM sales GROUP BY category;
SELECT category, AVG(amount) FROM sales GROUP BY category;
效果:
4. order by 实现
设计思路
与group by实现基本一致
1.语法解析:
在词法分析器中添加 ORDER、BY、ASC 和 DESC 关键字
- internal/lexer/token.go:
- internal/lexer/lexer.go的lookupIdentifier方法:
在语法分析器中解析 ORDER BY 子句:
在 internal/ast/ast.go中添加 OrderBy 字段到 SelectStatement 结构体
2.执行引擎:
在internal/storage/memory.go存储引擎的Select方法中实现对order by的解析调用:
同时实现排序逻辑,使用 Go 标准库的 sort.Slice 进行排序同时实现自定义比较函数以支持不同数据类型的比较:
代码实现
- 在词法分析器中添加新的关键字
// internal/lexer/token.go
const (
// ... 其他关键字
ORDER TokenType = "ORDER"
ASC TokenType = "ASC"
DESC TokenType = "DESC"
)
// internal/lexer/lexer.go
func (l *Lexer) lookupIdentifier(ident string) TokenType {
switch strings.ToUpper(ident) {
// ... 其他关键字
case "ORDER":
return ORDER
case "ASC":
return ASC
case "DESC":
return DESC
default:
return IDENT
}
}
- 在 AST 中添加新的结构体以支持 ORDER BY
// internal/ast/ast.go
// SelectStatement 表示SELECT语句
type SelectStatement struct {
Token lexer.Token
Fields []Expression
TableName string
Where Expression
OrderBy []OrderByClause // 添加 OrderBy 字段
}
// OrderByClause 表示 ORDER BY 子句中的排序项
type OrderByClause struct {
Expression Expression
Direction string // "ASC" 或 "DESC"
}
- 在语法分析器中添加对 ORDER BY 子句的解析
// internal/parser/parser.go
// parseSelectStatement 解析SELECT语句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
stmt := &ast.SelectStatement{Token: p.curToken}
// ... 解析选择列表、FROM 子句和 WHERE 子句 ...
// 解析GROUP BY子句(如果有的话)
if p.peekTokenIs(lexer.GROUP) {
// ... GROUP BY 解析逻辑 ...
}
// 解析ORDER BY子句
if p.peekTokenIs(lexer.ORDER) {
orderExprs, err := p.parseOrderByClause()
if err != nil {
return nil, err
}
stmt.OrderBy = orderExprs
}
return stmt, nil
}
// parseOrderByClause 解析ORDER BY子句
func (p *Parser) parseOrderByClause() ([]ast.OrderByClause, error) {
// 跳过 ORDER 关键字
if !p.expectPeek(lexer.ORDER) {
return nil, fmt.Errorf("expected ORDER keyword")
}
// 跳过 BY 关键字
if !p.expectPeek(lexer.BY) {
return nil, fmt.Errorf("expected BY keyword")
}
var orderExprs []ast.OrderByClause
for {
p.nextToken()
// 解析表达式(列名)
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("expected identifier in ORDER BY clause")
}
expr := &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}
orderClause := ast.OrderByClause{
Expression: expr,
Direction: "ASC", // 默认升序
}
// 检查是否有 ASC 或 DESC
if p.peekTokenIs(lexer.ASC) || p.peekTokenIs(lexer.DESC) {
p.nextToken()
orderClause.Direction = p.curToken.Literal
}
orderExprs = append(orderExprs, orderClause)
// 如果没有逗号,说明结束了
if !p.peekTokenIs(lexer.COMMA) {
break
}
p.nextToken() // 跳过逗号
}
return orderExprs, nil
}
- 在存储引擎中实现 ORDER BY 的执行逻辑
// internal/storage/memory.go
// Select 查询数据
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*Results, error) {
// ... 原有的查询逻辑 ...
// 处理 ORDER BY
if len(stmt.OrderBy) > 0 {
var err error
results.Rows, err = b.orderBy(results.Rows, results.Columns, stmt.OrderBy, table.Columns)
if err != nil {
return nil, err
}
}
return results, nil
}
// orderBy 根据 ORDER BY 子句对结果进行排序
func (b *MemoryBackend) orderBy(rows [][]Cell, resultCols []ResultColumn, orderBy []ast.OrderByClause, tableCols []ast.ColumnDefinition) ([][]Cell, error) {
// 创建列名到索引的映射
colIndexMap := make(map[string]int)
for i, col := range resultCols {
colIndexMap[col.Name] = i
}
// 创建排序键的索引和方向
type sortKey struct {
index int
direction string
}
var sortKeys []sortKey
for _, ob := range orderBy {
identifier, ok := ob.Expression.(*ast.Identifier)
if !ok {
return nil, fmt.Errorf("ORDER BY only supports column names")
}
index, exists := colIndexMap[identifier.Value]
if !exists {
return nil, fmt.Errorf("Unknown column '%s' in 'order clause'", identifier.Value)
}
sortKeys = append(sortKeys, sortKey{
index: index,
direction: ob.Direction,
})
}
// 使用 sort.Slice 进行排序
sort.Slice(rows, func(i, j int) bool {
for _, key := range sortKeys {
left := rows[i][key.index]
right := rows[j][key.index]
// 比较两个值
result, err := compareValues(left, right, "<")
if err != nil {
// 如果比较出错,保持原有顺序
return false
}
if result {
// 如果是升序,返回 true
// 如果是降序,返回 false
return key.direction == "ASC"
} else {
// 检查是否相等
equal, _ := compareValues(left, right, "=")
if !equal {
// 如果是降序,返回 true
// 如果是升序,返回 false
return key.direction == "DESC"
}
// 如果相等,继续比较下一个排序键
}
}
// 所有键都相等,保持原有顺序
return false
})
return rows, nil
}
测试
测试SQL:
CREATE TABLE sales (id INT PRIMARY KEY, product TEXT, category TEXT, amount FLOAT);
INSERT INTO sales VALUES (1, 'Apple', 'Fruit', 10.5);
INSERT INTO sales VALUES (2, 'Banana', 'Fruit', 8.0);
INSERT INTO sales VALUES (3, 'Carrot', 'Vegetable', 5.2);
INSERT INTO sales VALUES (4, 'Broccoli', 'Vegetable', 7.3);
INSERT INTO sales VALUES (5, 'Orange', 'Fruit', 9.8);
SELECT * FROM sales ORDER BY amount;
SELECT * FROM sales ORDER BY amount DESC;
SELECT * FROM sales ORDER BY category, amount DESC;
效果: