基于跳跃表的zset实现解析(lua版)

发布于:2025-09-14 ⋅ 阅读:(18) ⋅ 点赞:(0)

zset 即有序集合(sorted set),是 Redis 中的一种数据结构:它存储成员及其对应的分数,并按分数排序,底层基于跳跃表(skiplist)实现。本文的 Lua 版本同样用跳跃表维护顺序,另用一张 Lua 表保存 member -> score 的映射。

核心源码

skiplist.h

#ifndef SKIPLIST_H
#define SKIPLIST_H

/* Public interface of the skiplist that backs the Lua zset module.
 * The include guard lets this header be included from several places in one
 * translation unit without redefining the structs below. */

#include <stdlib.h>

/* Maximum height of any node; the header node is allocated with this many levels. */
#define SKIPLIST_MAXLEVEL 32
/* Probability used by slRandomLevel() to grow a new node by one more level. */
#define SKIPLIST_P 0.25

/* Length-delimited member string. slCreateObj() NUL-terminates ptr, but
 * comparisons use `length`, not strlen(). */
typedef struct slobj {
    char *ptr;
    size_t length;
} slobj;

/* One element of the list. level[] is a flexible array whose size is chosen
 * at allocation time; level[i].span is the number of elements that
 * level[i].forward skips over. */
typedef struct skiplistNode {
    slobj* obj;
    double score;
    struct skiplistNode *backward;  /* previous node on level 0; NULL for the first node */
    struct skiplistLevel {
        struct skiplistNode *forward;
        unsigned int span;
    }level[];
} skiplistNode;

typedef struct skiplist {
    struct skiplistNode *header, *tail;  /* header is a sentinel whose obj is NULL */
    unsigned long length;                /* number of elements, header excluded */
    int level;                           /* number of levels currently in use */
} skiplist;

/* Invoked by slDeleteByRank() for each removed member, before that member
 * is freed; `ud` is passed through unchanged. */
typedef void (*slDeleteCb) (void *ud, slobj *obj);
slobj* slCreateObj(const char* ptr, size_t length);
void slFreeObj(slobj *obj);

skiplist *slCreate(void);
void slFree(skiplist *sl);
void slDump(skiplist *sl);

void slInsert(skiplist *sl, double score, slobj *obj);
int slDelete(skiplist *sl, double score, slobj *obj);
unsigned long slDeleteByRank(skiplist *sl, unsigned int start, unsigned int end, slDeleteCb cb, void* ud);

unsigned long slGetRank(skiplist *sl, double score, slobj *o);
skiplistNode* slGetNodeByRank(skiplist *sl, unsigned long rank);

skiplistNode *slFirstInRange(skiplist *sl, double min, double max);
skiplistNode *slLastInRange(skiplist *sl, double min, double max);

#endif /* SKIPLIST_H */

skiplist.c

/*
 *  author: xjdrew
 *  date: 2014-06-03 20:38
 */

// skiplist similar with the version in redis
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#include "skiplist.h"

/* Allocate a node with room for `level` forward/span pairs and set its score
 * and member. forward, span and backward are left for the caller to fill. */
skiplistNode *slCreateNode(int level, double score, slobj *obj) {
    size_t bytes = sizeof(skiplistNode) + (size_t)level * sizeof(struct skiplistLevel);
    skiplistNode *node = malloc(bytes);

    node->obj = obj;
    node->score = score;
    return node;
}

skiplist *slCreate(void) {
    int j;
    skiplist *sl;

    sl = malloc(sizeof(*sl));
    sl->level = 1;
    sl->length = 0;
    sl->header = slCreateNode(SKIPLIST_MAXLEVEL, 0, NULL);
    for (j=0; j < SKIPLIST_MAXLEVEL; j++) {
        sl->header->level[j].forward = NULL;
        sl->header->level[j].span = 0;
    }
    sl->header->backward = NULL;
    sl->tail = NULL;
    return sl;
}

/* Build a member object holding a private, NUL-terminated copy of `length`
 * bytes from `ptr`. When ptr is NULL the buffer is allocated and terminated
 * but its contents are left uninitialised. */
slobj* slCreateObj(const char* ptr, size_t length) {
    slobj *o = malloc(sizeof(*o));
    char *buf = malloc(length + 1);

    if (ptr != NULL)
        memcpy(buf, ptr, length);
    buf[length] = '\0';

    o->ptr = buf;
    o->length = length;
    return o;
}

/* Release a member object and the byte buffer it owns. */
void slFreeObj(slobj *obj) {
    char *buf = obj->ptr;

    free(buf);
    free(obj);
}

/* Release a node together with the member object stored in it. */
void slFreeNode(skiplistNode *node) {
    slobj *payload = node->obj;

    slFreeObj(payload);
    free(node);
}

/* Destroy the whole list: the header sentinel, every element node (with its
 * member) and the list struct itself. */
void slFree(skiplist *sl) {
    skiplistNode *cur = sl->header->level[0].forward;
    skiplistNode *nxt;

    /* The first element was read above, so the header can go first. */
    free(sl->header);
    for (; cur != NULL; cur = nxt) {
        nxt = cur->level[0].forward;
        slFreeNode(cur);
    }
    free(sl);
}

/* Pick a height for a new node: start at 1 and keep adding a level while a
 * random draw falls below SKIPLIST_P, capped at SKIPLIST_MAXLEVEL. */
int slRandomLevel(void) {
    int lvl;

    for (lvl = 1; (random() & 0xffff) < (SKIPLIST_P * 0xffff); ++lvl)
        ;
    if (lvl > SKIPLIST_MAXLEVEL)
        lvl = SKIPLIST_MAXLEVEL;
    return lvl;
}

/* Byte-wise (memcmp) ordering of two members. When one is a prefix of the
 * other, the shorter member sorts first. Returns <0, 0 or >0; callers only
 * rely on the sign. */
int compareslObj(slobj *a, slobj *b) {
    size_t common = a->length <= b->length ? a->length : b->length;
    int cmp = memcmp(a->ptr, b->ptr, common);

    if (cmp != 0)
        return cmp;
    /* The lengths are size_t. Subtracting them and narrowing the result to
     * int wraps or truncates (a difference of 2^32 would read as "equal"),
     * so compare them and return only the sign. */
    return (a->length > b->length) - (a->length < b->length);
}

/* Return 1 when both members hold exactly the same bytes, 0 otherwise. */
int equalslObj(slobj *a, slobj *b) {
    /* Members of different lengths can never be equal. Checking this first
     * skips the memcmp and avoids relying on a length tiebreak. */
    if (a->length != b->length)
        return 0;
    return memcmp(a->ptr, b->ptr, a->length) == 0;
}

/* Insert `obj` with `score`. The new node stores `obj` itself (no copy); it
 * is freed together with the node by slFreeNode().
 * Nodes are ordered by score, then by compareslObj() on the member.
 * The (score, obj) pair must not already be present (see the note below). */
void slInsert(skiplist *sl, double score, slobj *obj) {
    skiplistNode *update[SKIPLIST_MAXLEVEL], *x;
    /* NOTE(review): rank[] and span are unsigned int while sl->length is
     * unsigned long; lists longer than UINT_MAX would overflow them. */
    unsigned int rank[SKIPLIST_MAXLEVEL];
    int i, level;

    /* Walk down from the highest level. update[i] becomes the last node at
     * level i that sorts before (score, obj); rank[i] is the number of
     * elements passed to reach update[i]. */
    x = sl->header;
    for (i = sl->level-1; i >= 0; i--) {
        /* store rank that is crossed to reach the insert position */
        rank[i] = i == (sl->level-1) ? 0 : rank[i+1];
        while (x->level[i].forward &&
            (x->level[i].forward->score < score ||
                (x->level[i].forward->score == score &&
                compareslObj(x->level[i].forward->obj,obj) < 0))) {
            rank[i] += x->level[i].span;
            x = x->level[i].forward;
        }
        update[i] = x;
    }
    /* We assume the key is not already inside: duplicated scores are allowed,
     * and re-inserting the same score/member pair should never happen
     * because the caller is expected to check its own member index first
     * (in this project, mt:add in zset.lua consults self.tbl). */
    level = slRandomLevel();
    if (level > sl->level) {
        /* New top levels start at the header, whose span there covers the
         * whole current list. */
        for (i = sl->level; i < level; i++) {
            rank[i] = 0;
            update[i] = sl->header;
            update[i]->level[i].span = sl->length;
        }
        sl->level = level;
    }
    x = slCreateNode(level,score,obj);
    for (i = 0; i < level; i++) {
        /* Splice x in after update[i] on level i. */
        x->level[i].forward = update[i]->level[i].forward;
        update[i]->level[i].forward = x;

        /* update span covered by update[i] as x is inserted here;
         * rank[0] - rank[i] is the distance from update[i] to x's predecessor */
        x->level[i].span = update[i]->level[i].span - (rank[0] - rank[i]);
        update[i]->level[i].span = (rank[0] - rank[i]) + 1;
    }

    /* increment span for untouched levels: they now skip over one more element */
    for (i = level; i < sl->level; i++) {
        update[i]->level[i].span++;
    }

    /* Fix level-0 back links; the header is never a backward target. */
    x->backward = (update[0] == sl->header) ? NULL : update[0];
    if (x->level[0].forward)
        x->level[0].forward->backward = x;
    else
        sl->tail = x;
    sl->length++;
}

/* Internal helper used by slDelete and slDeleteByRank: unlink node x given
 * update[], the per-level predecessors found by the caller's search.
 * It fixes spans on every active level and the backward link of x's
 * successor (or sl->tail). It also drops empty top levels and decrements
 * sl->length. It does NOT free x; the caller does that. */
void slDeleteNode(skiplist *sl, skiplistNode *x, skiplistNode **update) {
    int i;
    for (i = 0; i < sl->level; i++) {
        if (update[i]->level[i].forward == x) {
            /* x is linked on this level: bypass it and absorb its span. */
            update[i]->level[i].span += x->level[i].span - 1;
            update[i]->level[i].forward = x->level[i].forward;
        } else {
            /* x is only skipped over on this level: one fewer element. */
            update[i]->level[i].span -= 1;
        }
    }
    if (x->level[0].forward) {
        x->level[0].forward->backward = x->backward;
    } else {
        sl->tail = x->backward;
    }
    /* Shrink the active height while the top level is empty. */
    while(sl->level > 1 && sl->header->level[sl->level-1].forward == NULL)
        sl->level--;
    sl->length--;
}

/* Delete the element matching both `score` and `obj`.
 * `obj` is used only as a lookup key; the node's own member object is the
 * one released, via slFreeNode(). Returns 1 if an element was removed and
 * 0 if nothing matched. */
int slDelete(skiplist *sl, double score, slobj *obj) {
    skiplistNode *update[SKIPLIST_MAXLEVEL], *x;
    int i;

    x = sl->header;
    for (i = sl->level-1; i >= 0; i--) {
        while (x->level[i].forward &&
            (x->level[i].forward->score < score ||
                (x->level[i].forward->score == score &&
                compareslObj(x->level[i].forward->obj,obj) < 0)))
            x = x->level[i].forward;
        update[i] = x;
    }
    /* Several elements may share a score. The only candidate is the first
     * node that does not sort before (score, obj), and it must match both. */
    x = x->level[0].forward;
    if (x == NULL || x->score != score || !equalslObj(x->obj, obj))
        return 0; /* not found */

    slDeleteNode(sl, x, update);
    slFreeNode(x);
    return 1;
}

/* Delete all elements whose 1-based rank lies in [start, end] (inclusive).
 * For each removed element, cb(ud, obj) runs after the node is unlinked and
 * before its member object is freed. The callback must therefore not keep
 * `obj` after returning. `cb` may be NULL, in which case no callback runs.
 * Returns the number of elements removed. */
unsigned long slDeleteByRank(skiplist *sl, unsigned int start, unsigned int end, slDeleteCb cb, void* ud) {
    skiplistNode *update[SKIPLIST_MAXLEVEL], *x;
    unsigned long traversed = 0, removed = 0;
    int i;

    /* Find, on every level, the last node whose rank is below `start`. */
    x = sl->header;
    for (i = sl->level-1; i >= 0; i--) {
        while (x->level[i].forward && (traversed + x->level[i].span) < start) {
            traversed += x->level[i].span;
            x = x->level[i].forward;
        }
        update[i] = x;
    }

    /* Walk level 0 from the first in-range node, unlinking as we go.
     * update[] stays valid because each removed node was the successor of
     * update[0]. */
    traversed++;
    x = x->level[0].forward;
    while (x && traversed <= end) {
        skiplistNode *next = x->level[0].forward;
        slDeleteNode(sl,x,update);
        if (cb != NULL)
            cb(ud, x->obj);
        slFreeNode(x);
        removed++;
        traversed++;
        x = next;
    }
    return removed;
}

/* Return the 1-based rank of the element with this exact (score, o) pair,
 * or 0 if it is not in the list. Ranks start at 1 because the header's span
 * to the first element is counted. */
unsigned long slGetRank(skiplist *sl, double score, slobj *o) {
    skiplistNode *cur = sl->header;
    unsigned long passed = 0;
    int lvl;

    for (lvl = sl->level - 1; lvl >= 0; --lvl) {
        for (;;) {
            skiplistNode *nxt = cur->level[lvl].forward;
            int advance;

            if (nxt == NULL)
                break;
            /* Move onto nodes that sort at or before (score, o). */
            advance = nxt->score < score ||
                      (nxt->score == score && compareslObj(nxt->obj, o) <= 0);
            if (!advance)
                break;
            passed += cur->level[lvl].span;
            cur = nxt;
        }

        /* The header carries a NULL obj, so guard before comparing. */
        if (cur->obj != NULL && equalslObj(cur->obj, o))
            return passed;
    }
    return 0;
}

/* Return the node at 1-based position `rank`, or NULL when rank is 0 or
 * larger than the number of elements. */
skiplistNode* slGetNodeByRank(skiplist *sl, unsigned long rank) {
    skiplistNode *cur;
    unsigned long seen;
    int lvl;

    if (rank < 1 || rank > sl->length)
        return NULL;

    cur = sl->header;
    seen = 0;
    for (lvl = sl->level - 1; lvl >= 0; --lvl) {
        /* Hop right as long as the hop does not overshoot the target rank. */
        while (cur->level[lvl].forward != NULL &&
               seen + cur->level[lvl].span <= rank) {
            seen += cur->level[lvl].span;
            cur = cur->level[lvl].forward;
        }
        if (seen == rank)
            return cur;
    }
    return NULL;
}

/* Quick overlap test against the closed range [min, max]. Returns 0 in four
 * cases: the range is inverted, the list is empty, the largest score is below
 * min, or the smallest score is above max. Otherwise it returns 1. A result of
 * 1 does not by itself prove that some node lies inside the range. */
int slIsInRange(skiplist *sl, double min, double max) {
    skiplistNode *last, *first;

    if (min > max)
        return 0;   /* an inverted range is always empty */

    last = sl->tail;
    if (last == NULL || last->score < min)
        return 0;

    first = sl->header->level[0].forward;
    if (first != NULL && !(first->score > max))
        return 1;
    return 0;
}

/* Return the first node whose score lies in [min, max] (both inclusive), or
 * NULL when no element is in that range. */
skiplistNode *slFirstInRange(skiplist *sl, double min, double max) {
    skiplistNode *x;
    int i;

    /* If everything is out of range, return early. */
    if (!slIsInRange(sl,min, max)) return NULL;

    x = sl->header;
    for (i = sl->level-1; i >= 0; i--) {
        /* Go forward while *OUT* of range. */
        while (x->level[i].forward && x->level[i].forward->score < min)
                x = x->level[i].forward;
    }

    /* x is the last node scoring below min, so its successor is the first
     * node with score >= min. slIsInRange() only checks the list's two ends.
     * The range can still fall into a gap between neighbouring scores
     * (e.g. scores {1, 10} and range [5, 6]), so check the upper bound too. */
    x = x->level[0].forward;
    if (x == NULL || x->score > max)
        return NULL;
    return x;
}

/* Return the last node whose score lies in [min, max] (both inclusive), or
 * NULL when no element is in that range. */
skiplistNode *slLastInRange(skiplist *sl, double min, double max) {
    skiplistNode *x;
    int i;

    /* If everything is out of range, return early. */
    if (!slIsInRange(sl, min, max)) return NULL;

    x = sl->header;
    for (i = sl->level-1; i >= 0; i--) {
        /* Go forward while *IN* range. */
        while (x->level[i].forward &&
            x->level[i].forward->score <= max)
                x = x->level[i].forward;
    }

    /* x is the last node scoring <= max. slIsInRange() only checks the list's
     * two ends, so the range may sit in a gap between neighbouring scores
     * (e.g. scores {1, 10} and range [5, 6]). Reject x if it is below min,
     * and never return the header sentinel. */
    if (x == sl->header || x->score < min)
        return NULL;
    return x;
}

/* Debug helper: print every element to stdout in list order (ascending),
 * one line per node with its 1-based position, score and member.
 * The member is written using its stored length, so an embedded NUL byte
 * does not cut the output short (printf's %s would stop at it). */
void slDump(skiplist *sl) {
    skiplistNode *x;
    int i;

    x = sl->header;
    i = 0;
    while(x->level[0].forward) {
        x = x->level[0].forward;
        i++;
        printf("node %d: score:%f, member:", i, x->score);
        fwrite(x->obj->ptr, 1, x->obj->length, stdout);
        putchar('\n');
    }
}

zset.lua

local skiplist = require "skiplist.c"
-- Method table shared by every zset instance. __index points back at mt, so
-- objects created by M.new() (which sets mt as their metatable) resolve
-- their methods here.
local mt = {}
mt.__index = mt

--- Insert `member` with `score`, or move an existing member to a new score.
-- Does nothing when the member already has exactly this score.
-- @param score the member's score
-- @param member the member key (also the key in self.tbl)
function mt:add(score, member)
    local tbl, sl = self.tbl, self.sl
    local previous = tbl[member]

    if previous then
        if previous == score then
            return
        end
        -- Drop the stale (score, member) entry before re-inserting.
        sl:delete(previous, member)
    end

    sl:insert(score, member)
    tbl[member] = score
end

--- Remove `member` from the set; unknown members are ignored.
-- @param member the member key
function mt:rem(member)
    local score = self.tbl[member]
    if not score then
        return
    end
    self.sl:delete(score, member)
    self.tbl[member] = nil
end

--- Number of members, as reported by the underlying skiplist.
function mt:count()
    local sl = self.sl
    return sl:get_count()
end

--- Convert between forward and reverse 1-based ranks: count - r + 1.
-- The mapping is its own inverse.
-- @tparam number r a rank
function mt:_reverse_rank(r)
    local total = self.sl:get_count()
    return total - r + 1
end

--- Trim the set to its `count` lowest-ranked members (forward ranks 1..count).
-- Each removed member is dropped from self.tbl and, when provided, passed
-- to `delete_handler`.
-- @tparam number count how many members to keep; negative values keep none
-- @tparam[opt] function delete_handler called as delete_handler(member)
-- @return 0 when nothing needs removing, otherwise the result of
--   sl:delete_by_rank
function mt:limit(count, delete_handler)
    -- Without this clamp a negative count makes the start rank count + 1 <= 0.
    -- Ranks are 1-based and slDeleteByRank takes them as unsigned, so treat
    -- "keep fewer than zero" exactly like count == 0 (remove everything).
    if count < 0 then
        count = 0
    end

    local total = self.sl:get_count()
    if total <= count then
        return 0
    end

    local function on_delete(member)
        self.tbl[member] = nil
        if delete_handler then
            delete_handler(member)
        end
    end

    return self.sl:delete_by_rank(count + 1, total, on_delete)
end

--- Trim the set to its `count` highest-ranked members (reverse ranks 1..count).
-- Each removed member is dropped from self.tbl and, when provided, passed
-- to `delete_handler`.
-- @tparam number count how many members to keep
-- @tparam[opt] function delete_handler called as delete_handler(member)
function mt:rev_limit(count, delete_handler)
    local total = self.sl:get_count()
    if total <= count then
        return 0
    end

    -- Reverse ranks count+1..total map to forward ranks (total-count)..1.
    local from, to = self:_reverse_rank(count + 1), self:_reverse_rank(total)

    return self.sl:delete_by_rank(from, to, function(member)
        self.tbl[member] = nil
        if delete_handler then
            delete_handler(member)
        end
    end)
end

--- Members whose reverse ranks span r1..r2 (converted to forward ranks).
function mt:rev_range(r1, r2)
    local from = self:_reverse_rank(r1)
    local to = self:_reverse_rank(r2)
    return self:range(from, to)
end

--- Members whose forward ranks span r1..r2. Bounds below 1 are raised to 1.
function mt:range(r1, r2)
    local from = r1 < 1 and 1 or r1
    local to = r2 < 1 and 1 or r2
    return self.sl:get_rank_range(from, to)
end

--- Reverse (highest-score-first) 1-based rank of `member`, or nil if absent.
function mt:rev_rank(member)
    local r = self:rank(member)
    if not r then
        return r
    end
    return self:_reverse_rank(r)
end

--- Forward (lowest-score-first) 1-based rank of `member`, or nil if absent.
function mt:rank(member)
    local score = self.tbl[member]
    if score then
        return self.sl:get_rank(score, member)
    end
    return nil
end

--- Members whose scores fall between s1 and s2, as returned by the skiplist.
function mt:range_by_score(s1, s2)
    local sl = self.sl
    return sl:get_score_range(s1, s2)
end

--- Score of `member`, or nil when it is not in the set.
function mt:score(member)
    local scores = self.tbl
    return scores[member]
end

--- Member at forward (1-based) rank `r`, as returned by the skiplist.
function mt:member_by_rank(r)
    local sl = self.sl
    return sl:get_member_by_rank(r)
end

--- Member at reverse (1-based) rank `r`; returns nothing when the converted
-- forward rank is not positive.
function mt:member_by_rev_rank(r)
    local forward = self:_reverse_rank(r)
    if not (forward > 0) then
        return
    end
    return self.sl:get_member_by_rank(forward)
end

--- Debug print of the underlying skiplist.
function mt:dump()
    local sl = self.sl
    sl:dump()
end

local M = {}

--- Create an empty sorted set. It wraps a new skiplist (ordering) and a
-- member -> score table (direct score lookup).
function M.new()
    local obj = setmetatable({}, mt)
    obj.sl = skiplist()
    obj.tbl = {}
    return obj
end
return M

原理图示

(原文此处为跳跃表结构示意图,图片未能保留。)

一、数据结构概览

1. slobj(字符串对象)

typedef struct slobj {
    char *ptr;
    size_t length;
} slobj;
  • 用于存储字符串类型的成员(member)
  • ptr 指向字符串内容,length 是字符串长度(不含结尾 \0)

2. skiplistNode(跳跃表节点)

typedef struct skiplistNode {
    slobj* obj;
    double score;
    struct skiplistNode *backward;
    struct skiplistLevel {
        struct skiplistNode *forward;
        unsigned int span;
    } level[];
} skiplistNode;
  • obj:成员对象
  • score:分数,用于排序
  • backward:后退指针,用于反向遍历
  • level[]:柔性数组,表示多层 forward 指针和跨度(span)

3. skiplist(跳跃表)

typedef struct skiplist {
    struct skiplistNode *header, *tail;
    unsigned long length;
    int level;
} skiplist;
  • header:头节点,不存储数据,用于管理多层链表
  • tail:尾节点,用于快速反向遍历
  • length:节点数量
  • level:当前最大层数

二、C 语言层接口(skiplist.c / skiplist.h)

1. 创建与销毁

skiplist *slCreate(void);
void slFree(skiplist *sl);
slobj* slCreateObj(const char* ptr, size_t length);
void slFreeObj(slobj *obj);
  • slCreate:创建跳跃表
  • slFree:释放整个跳跃表及其节点
  • slCreateObj / slFreeObj:创建和释放字符串对象

2. 插入与删除

void slInsert(skiplist *sl, double score, slobj *obj);
int slDelete(skiplist *sl, double score, slobj *obj);
unsigned long slDeleteByRank(skiplist *sl, unsigned int start, unsigned int end, slDeleteCb cb, void* ud);
  • slInsert:插入一个带分数的成员
  • slDelete:删除指定分数和成员的节点
  • slDeleteByRank:按排名范围删除节点,并回调处理每个被删除的成员

3. 查询

unsigned long slGetRank(skiplist *sl, double score, slobj *o);
skiplistNode* slGetNodeByRank(skiplist *sl, unsigned long rank);
skiplistNode *slFirstInRange(skiplist *sl, double min, double max);
skiplistNode *slLastInRange(skiplist *sl, double min, double max);
  • slGetRank:获取成员的排名(1-based)
  • slGetNodeByRank:根据排名获取节点
  • slFirstInRange / slLastInRange:返回分数在 [min, max] 范围内的第一个/最后一个节点

三、Lua 层接口(zset.lua)

1. 创建 ZSet

local zset = require "zset"
local zs = zset.new()

2. 增删改查

方法 说明
zs:add(score, member) 添加或更新成员
zs:rem(member) 删除成员
zs:score(member) 获取成员分数
zs:count() 返回成员总数

3. 排名相关

方法 说明
zs:rank(member) 获取正向排名(1-based)
zs:rev_rank(member) 获取反向排名(从高到低)
zs:range(r1, r2) 获取排名在 [r1, r2] 的成员
zs:rev_range(r1, r2) 获取反向排名在 [r1, r2] 的成员

4. 分数范围查询

zs:range_by_score(s1, s2)
  • 返回分数在 [s1, s2] 范围内的所有成员

5. 限制大小(常用于排行榜)

zs:limit(count, delete_handler)
zs:rev_limit(count, delete_handler)
  • limit 按正向排名(分数从低到高)保留前 count 名,rev_limit 按反向排名(分数从高到低)保留前 count 名,其余成员都会被删除;成员总数不超过 count 时不做任何删除并返回 0
  • delete_handler 是可选的删除回调函数

6. 按排名获取成员

zs:member_by_rank(r)
zs:member_by_rev_rank(r)
  • 根据正向/反向排名获取成员

7. 调试输出

zs:dump()
  • 打印所有成员及其分数(按分数升序)

四、实现特点与注意事项

特点:

  • 支持相同分数的不同成员:分数相同时按成员字符串的字节序排序(memcmp 比较,公共前缀相同时较短者在前)
  • 支持正向和反向排名查询
  • 支持按排名范围删除(slDeleteByRank / limit / rev_limit),以及按分数范围查询(slFirstInRange / slLastInRange / range_by_score)
  • 使用跳跃表,插入、删除、查询的平均时间复杂度为 O(log N)

注意事项:

  • 排名是从 1 开始(1-based)
  • 反向排名 = 成员总数 - 正向排名 + 1
  • slDeleteByRank 中的 start 和 end 是包含的(inclusive)
  • Lua 层使用了一个 tbl 表来存储 member -> score 的映射,用于快速查找分数

五、使用示例(Lua)

local zset = require "zset"
local zs = zset.new()

-- After these three adds the ascending order is AAA(100), CCC(150), BBB(200).
zs:add(100, "AAA")
zs:add(200, "BBB")
zs:add(150, "CCC")

print(zs:rank("BBB"))          --> 3
print(zs:rev_rank("BBB"))      --> 1
print(zs:score("BBB"))         --> 200

-- limit(2) keeps forward ranks 1..2 (the two lowest scores) and removes the
-- rest, calling the handler once per removed member.
zs:limit(2, function(member)
    print("Deleted:", member)  --> Deleted: BBB
end)

-- Members at forward ranks 1..2 after trimming.
zs:range(1, 2)                 --> {"AAA", "CCC"}