【Pytorch基础教程34】EGES召回模型(更新ing)

发布于:2022-11-01 ⋅ 阅读:(975) ⋅ 点赞:(0)

note

一、EGES图算法

在这里插入图片描述

  • 某宝在19年提出的EGES模型,是加入side information的graph embedding方法,解决冷启动问题。核心任务在于基于用户行为计算所有项目之间的成对相似性。大致步骤为基于用户历史行为构造一个图,然后利用 Node2Vec 的方法来学习 Item 的 Embedding 向量。这样便可以根据向量的内积计算节点间的相似度来生成候选集。
  • 为了解决冷启动,阿里的GNN迭代了三次:BGE、GES 和 EGES。
  • 推荐系统中存在很多的图结构,如二部图,序列图,社交关系图,知识语义图等。GNN比random walk等算法效果更好。

在这里插入图片描述

1.1 数据预处理

利用滑动窗口选取用户历史行为序列,同时也有降噪处理,如点击少于1s的大概率为无意点击(需要剔除)、过度活跃用户(短时间购买几千件商品可能为刷的,需要剔除)等等。

在这里插入图片描述

1.2 GNN with side info

在这里插入图片描述

1.3 Framework of EGES

在这里插入图片描述

二、代码实现

在这里插入图片描述

import torch as th

class EGES(th.nn.Module):
    def __init__(self, dim, num_nodes, num_brands, num_shops, num_cates):
        super(EGES, self).__init__()
        self.dim = dim
        # embeddings for nodes
        base_embeds = th.nn.Embedding(num_nodes, dim)
        brand_embeds = th.nn.Embedding(num_brands, dim)
        shop_embeds = th.nn.Embedding(num_shops, dim)
        cate_embeds = th.nn.Embedding(num_cates, dim)
        # concat four embedding
        self.embeds = [base_embeds, brand_embeds, shop_embeds, cate_embeds]
        # weights for each node's side information
        self.side_info_weights = th.nn.Embedding(num_nodes, 4)

    #
    def forward(self, srcs, dsts):
        # srcs: sku_id, brand_id, shop_id, cate_id
        srcs = self.query_node_embed(srcs)
        dsts = self.query_node_embed(dsts)
        
        return srcs, dsts
    
    def query_node_embed(self, nodes):
        """
            @nodes: tensor of shape (batch_size, num_side_info)
        """
        batch_size = nodes.shape[0]
        # query side info weights, (batch_size, 4)
        side_info_weights = th.exp(self.side_info_weights(nodes[:, 0]))
        # merge all embeddings
        side_info_weighted_embeds_sum = []
        side_info_weights_sum = []
        # four embeddings
        for i in range(4):
            # weights for i-th side info, (batch_size, ) -> (batch_size, 1)
            i_th_side_info_weights = side_info_weights[:, i].view((batch_size, 1))
            # batch of i-th side info embedding * its weight, (batch_size, dim)
            side_info_weighted_embeds_sum.append(i_th_side_info_weights * self.embeds[i](nodes[:, i]))
            side_info_weights_sum.append(i_th_side_info_weights)
        # stack: (batch_size, 4, dim), sum: (batch_size, dim)
        side_info_weighted_embeds_sum = th.sum(th.stack(side_info_weighted_embeds_sum, axis=1), axis=1)
        # stack: (batch_size, 4), sum: (batch_size, )
        side_info_weights_sum = th.sum(th.stack(side_info_weights_sum, axis=1), axis=1)
        # (batch_size, dim)
        H = side_info_weighted_embeds_sum / side_info_weights_sum

        return H       

    def loss(self, srcs, dsts, labels):
        dots = th.sigmoid(th.sum(srcs * dsts, axis=1))
        dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)

        return th.mean(- (labels * th.log(dots) + (1 - labels) * th.log(1 - dots)))

Reference

[1] EGES模型