实体嵌入(embedding)目的将表格数据中的分类属性(一个至多个)向量化。
1.实体嵌入简介:
- 实体嵌入是主要应用于深度学习中处理表格分类数据的一种技术,或者更确切地说NLP领域最为火爆,word2vec就是在做word的embedding。
- 神经网络相比于当下的流行的xgboost、LGBM等树模型并不能很好地直接处理大量分类水平的分类特征。因为神经网络要求输入的分类数据进行one-hot处理。当分类特征的水平很高的时候,one-hot经常带来维度爆炸问题,紧接着就是参数爆炸,局部极小值点更多,更容易产生过拟合等等一系列问题。
- 实体嵌入是从降维这一角度来考虑改善这些问题的。它通过将分类属性放入网络的全连接层的输入单元中,后接几个单元数较输入层更少的隐藏层(连续型变量直接接入第一个隐藏层),经过神经网络训练后输出第一个隐藏层中分类变量关联的隐层单元,作为提取的特征,用于各种模型的输入。
2.用实体嵌入做特征提取的网络结构(示例)
(暂略)
3.实体嵌入代码(示例)
1 2 3 import math 4 import random as rn 5 import numpy as np 6 from keras.models import Model 7 from keras.layers import Input, Dense, Concatenate, Reshape, Dropout 8 from keras.layers.embeddings import Embedding 9 10 11 #传入分类变量数据(没有one-hot)和连续变量数据,返回一个model对象(它根据数据的组成定义好了网络框架类,这个类之后可以调用,喂入数据训练) 12 def build_embedding_network(category_data, continus_data): 13 14 #获取分类变量的名称 15 cat_cols = [x for x in category_data.columns] 16 17 #以字典形式存储分类变量的{名称:维度} 18 category_origin_dimension = [math.ceil(category_data[cat_col].drop_duplicates().size) for cat_col in cat_cols] 19 20 #以字典形式存储分类变量 计划 在embed后的{名称:维度} 21 category_embedding_dimension = [math.ceil(math.sqrt(category_data[cat_col].drop_duplicates().size)) for cat_col in cat_cols] 22 23 # 以网络结构embeddding层在前,dense层(全连接)在后;2.训练集的X必须以分类特征在前,连续特征在后。 24 inputs = [] 25 embeddings = [] #嵌入层用list组织数据 26 27 #对于每一个分类变量,执行以下操作 28 for cat_val, cat_origin_dim, cat_embed_dim in list(zip(cat_cols, category_origin_dimension, category_embedding_dimension)) : 29 30 # 实例化输入数据的张量对象 31 input_cate_feature = Input(shape=(1,)) #batch_size, 空(维度) 32 33 #实例化embed层的节点:定义输入的维度(某个分类变量的onehot维度),输出的维度(其embed后的维度),同时传入input数据对象 34 embedding = Embedding(input_dim = cat_origin_dim, 35 output_dim = cat_embed_dim, 36 input_length=1)(input_cate_feature) 37 38 # 将embed层输出的形状调整 39 embedding = Reshape(target_shape=(cat_embed_dim,))(embedding) 40 41 #每遍历一个分类变量就将输入数据添加在inputs对象中 42 inputs.append(input_cate_feature) 43 44 #每遍历一个分类变量就扩展一个embedding层的节点 45 embeddings.append(embedding) 46 47 #dda于连续特征执行以下操作: 48 49 #获取连续特征的变量个数 50 cnt_val_num = continus_data.shape[1] 51 #连续特征原封原样迭代地加入与分类变量embedding后的输出组合在一起(没有embedding操作) 52 for cnt_val_num in range(cnt_val_num) : 53 input_numeric_features = Input(shape=(1,)) 54 55 #输入的连续型数据不经过任何一种激活函数处理直接 与embed + reshape后的分类变量进行组合 56 embedding_numeric_features = Dense(units = 16)(input_numeric_features) 57 inputs.append(input_numeric_features) 58 embeddings.append(embedding_numeric_features) 59 60 #把有不同输入的embedding 层的数据链接在一起 61 x = Concatenate()(embeddings) #这种写法表明concatenate返回的是一个函数,第二个括号是传入这个返回的函数的参数 62 63 x = Dense(units = 16, activation=‘relu‘)(x) 64 65 #丢弃15%的节点 66 x = Dropout( 0.15 )(x) 67 68 #输出层包含一个节点,激活函数是relu 69 output = Dense(1, activation=‘relu‘)(x) 70 model = Model(inputs, output) 71 model.compile(loss=‘mean_squared_error‘, optimizer=‘adam‘) 72 73 return model 74 75 76 # 训练 77 NN = build_embedding_network(category_features_in_trainset, continus_features_in_trainset ) 78 NN.fit(X, y_train, epochs=3, batch_size=40, verbose=0) 79 80 #读取embedding层数据 81 cate_feature_num = category_features_in_trainset.columns.size 82 83 model = NN # 创建原始模型 84 for i in range(cate_feature_num): 85 # 如果把类别特征放前,连续特征放后,cate_feature_num+i就是所有embedding层 86 layer_name = NN.get_config()[‘layers‘][cate_feature_num+i][‘name‘] 87 88 intermediate_layer_model = Model(inputs=NN.input, 89 outputs=model.get_layer(layer_name).output) 90 91 # numpy.array 92 intermediate_output = intermediate_layer_model.predict(X) 93 94 intermediate_output.resize([train_data.shape[0],cate_embedding_dimension[i][1]]) 95 96 if i == 0: #将第一个输出赋值给X_embedding_trans,后续叠加在该对象后面 97 X_embedding_trans = intermediate_output 98 else: 99 X_embedding_trans = np.hstack((X_embedding_trans,intermediate_output)) #水平拼接 100 101 102 #显示分类变量向量化后的数据 103 X_embedding_trans
4.个人观点
高levels的分类变量而言,多是用户ID、地址、名字等。变量如‘地址’可以通过概念分层的办法获取省、市之类的维度较少的特征;‘商品’也可以通过概念分层获取其更泛化的类别,但是这种操作实质是在平滑数据,很可能淹没掉有用的信息。我可能不会选择这种方式来建立预测模型(做细分统计、透视图的时候会经常使用到)。也有的人会不用这些分类变量,但是我觉得会损失一些有用的影响因素,所以最好物尽其用。
实体嵌入跟word2vector一样会让output相似的对象在隐藏层中的数值也相近。使用低纬度、连续型变量的数据训练模型,效果也会比用高维、稀疏的数据训练好得多。
这种针对分类变量特征提取的方法对于提升预测准确率来说是很有效的,跟模型堆叠一样有效。
原文地址:https://www.cnblogs.com/mx0813/p/12635378.html
时间: 2024-10-10 09:06:02