fm_model 是 libFM 训练后保存的文本格式模型文件。
model.ckpt 是导出的、可供 TensorFlow Serving 加载的 SavedModel 模型目录。
代码:
import numpy as np


def load_fm_model(file_name):
    """Parse a libFM text model file.

    The file is read section by section. A line containing ``global bias W0``,
    ``unary interactions Wj`` or ``pairwise interactions Vj,f`` starts a new
    section; every following line holds one value (W0, Wj) or one
    whitespace-separated row of factors (Vj,f). Lines before the first header
    and blank lines are ignored.

    Args:
        file_name: path of the libFM model file.

    Returns:
        A tuple ``(w0, wj, v, k, max_fid)``:
          * w0 -- global bias (float);
          * wj -- dict feature id -> linear weight, non-zero weights only;
          * v -- dict feature id -> list of factors, rows that are not all
            zero only;
          * k -- length of the last factor row read (0 if there is none);
          * max_fid -- number of rows in the longer of the Wj / Vj,f sections,
            i.e. the feature-space size used for the dense weights.
    """
    state = ''
    fid = 0
    max_fid = 0
    w0 = 0.0
    wj = {}
    v = {}
    k = 0
    with open(file_name) as f:
        for line in f:
            line = line.strip()
            if 'global bias W0' in line:
                state, fid = 'w0', 0
                continue
            if 'unary interactions Wj' in line:
                state, fid = 'wj', 0
                continue
            if 'pairwise interactions Vj,f' in line:
                state, fid = 'v', 0
                continue
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # crashing in float('').
                continue
            if state == 'w0':
                w0 = float(line)
            elif state == 'wj':
                fv = float(line)
                if fv != 0:
                    wj[fid] = fv
                fid += 1
                max_fid = max(max_fid, fid)
            elif state == 'v':
                # split() without an argument also copes with repeated spaces.
                fv = [float(x) for x in line.split()]
                k = len(fv)
                if any(x != 0 for x in fv):
                    v[fid] = fv
                fid += 1
                max_fid = max(max_fid, fid)
    return w0, wj, v, k, max_fid


def _dense_fm_weights(wj, v, n, k):
    """Expand the sparse dicts from load_fm_model into dense float32 arrays.

    Returns ``(w1, w2)`` with shapes ``(n,)`` and ``(n, k)``; feature ids
    missing from the dicts get zero weights.
    """
    w1 = np.zeros([n], dtype=np.float32)
    for fid, weight in wj.items():
        w1[fid] = weight
    w2 = np.zeros([n, k], dtype=np.float32)
    for fid, factors in v.items():
        w2[fid] = factors
    return w1, w2


def export_fm_model(model_file='fm_model', export_dir='./model.ckpt'):
    """Convert a libFM model file into a TensorFlow SavedModel for serving.

    Builds the factorization-machine prediction graph
        y = w0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j
    with weights loaded from ``model_file``, and exports it with
    ``tf.saved_model.simple_save`` (TensorFlow 1.x graph-mode API).

    Signature inputs: "x" (float32, [batch, n]) and "batch" (int32; ignored,
    kept so existing clients that send it keep working).
    Signature output: "y_fm" (float32, [batch]).

    Args:
        model_file: path of the libFM model file.
        export_dir: SavedModel output directory; it must not exist yet.
    """
    # Imported here so the parsing helpers stay usable without TensorFlow.
    import tensorflow as tf

    w0_val, wj, v, k, n = load_fm_model(model_file)
    print('n', n)
    print('k', k)
    w1_val, w2_val = _dense_fm_weights(wj, v, n, k)

    # Initialise the variables directly from the loaded weights: a tf.assign
    # op has no effect unless it is explicitly run in a session.
    w0 = tf.Variable(np.float32(w0_val), name='w0')
    w1 = tf.Variable(w1_val, name='w1')
    w2 = tf.Variable(w2_val, name='w2')

    x_ = tf.placeholder(tf.float32, [None, n])
    # The batch size is inferred from x_; this input is only kept in the
    # serving signature for backward compatibility.
    batch = tf.placeholder(tf.int32)

    linear = tf.reduce_sum(tf.multiply(x_, w1), axis=1)
    # O(n*k) pairwise identity:
    #   sum_{i<j} <v_i, v_j> x_i x_j
    #     = 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2]
    # 0.5 (not 1/2) so the term is not truncated to 0 under Python 2.
    xv = tf.matmul(x_, w2)
    x2v2 = tf.matmul(tf.square(x_), tf.square(w2))
    pairwise = 0.5 * tf.reduce_sum(tf.square(xv) - x2v2, axis=1)
    y_fm = w0 + linear + pairwise

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # simple_save writes a SavedModel directory (despite the .ckpt name)
        # and refuses to overwrite an existing one.
        tf.saved_model.simple_save(
            sess,
            export_dir,
            inputs={"x": x_, "batch": batch},
            outputs={"y_fm": y_fm},
        )


if __name__ == '__main__':
    export_fm_model()
参考:
https://www.tensorflow.org/guide/saved_model
原文地址:https://www.cnblogs.com/yaoyaohust/p/10472780.html
时间: 2024-11-13 10:05:18