K-Means 算法的 Hadoop 实现
K-Means 算法简介
k-Means是一种聚类分析算法,它是一种无监督学习算法。它主要用来计算数据的聚集,将数据相近的点归到同一数据蔟。学习聚类时我们需要了解聚类与分类的区别,分类的类别是我们实现设定好的,而聚类的类别是通过计算得到的。
算法原理
维基百科的算法描述如下:
已知观测集 (x1,x2,x3,...,xn) ,其中每个观测都是一个d-维实向量,k-平均聚类要把这n个观测划分到k个集合中 (k≤n) ,使得组内平方和(WCSS within-cluster sum of squares)最小。换句话说,它的目标是找到使得下式满足的聚类 Si
argminS=∑i=1k∑x∈Si||x?μi||2
其中 μi 是 Si 中所有点的均值。
简单描述就是:不断迭代计算各个数据簇的中心点,直到该中心点趋于稳定。
该算法的优点是实现非常简单,主要缺点有如下:
- 对异常数据敏感。当单独几个数据远离数据簇时会影响聚类效果。
- 由于 K 值是事先给定的,所以 K 值的选择难以估计。也就是我们事先并不知道需要分多少个类别。
- ISODATA 算法可用于解决此问题,得到较为合理的类型数目K
- 初始的数据簇的中心点需要事先给定,初始种子点很大程度上会影响聚类的结果。
- K-Means++ 算法可以用来解决这个问题,其可以有效地选择初始点
步骤
- 创建k个数据簇的中心点。
- 计算所有数据点到这 k 个中心点的距离,将其划归到距离自己最近的中心点。
- 根据上次聚类结果,计算各个数据簇的算数平均值作为新的数据簇中心点。
- 将所有数据在新的中心点上重新聚类。
- 重复第4步,直到中心点趋于稳定。
中心点距离算法
求某一数据点到中心点的距离可以采用欧几里得距离公式:
distance=∑k=1n(xik?xjk)2 ̄ ̄ ̄ ̄ ̄ ̄ ̄ ̄ ̄ ̄ ̄ ̄ ̄?
可以参考 K-Means 算法(CoolShell) 里面的
求点群中心的算法
这一节,有三种距离公式。
Hadoop 环境简介
参考该篇文章(基于Docker搭建Hadoop集群之升级版)搭建了基于 Docker 的 Hadoop 环境。
Hadoop Version
[email protected]:/myjob/kmeans# hadoop version
Hadoop 2.7.2
Subversion Unknown -r Unknown
Compiled by root on 2016-05-27T18:05Z
Compiled with protoc 2.5.0
From source with checksum d0fda26633fa762bff87ec759ebe689c
This command was run using /usr/local/hadoop/share/hadoop/common/hadoop-common-2.7.2.jar
Hadoop 参数设置
core-site.xml
<?xml version="1.0"?>
<configuration>
<property>
<name>fs.defaultFS</name>
<value>hdfs://hadoop-master:9000/</value>
</property>
</configuration>
hdfs-site.xml
<?xml version="1.0"?>
<configuration>
<property>
<name>dfs.namenode.name.dir</name>
<value>file:///root/hdfs/namenode</value>
<description>NameNode directory for namespace and transaction logs storage.</description>
</property>
<property>
<name>dfs.datanode.data.dir</name>
<value>file:///root/hdfs/datanode</value>
<description>DataNode directory</description>
</property>
<property>
<name>dfs.replication</name>
<value>3</value>
</property>
</configuration>
mapred-site.xml
<?xml version="1.0"?>
<configuration>
<property>
<name>mapreduce.framework.name</name>
<value>yarn</value>
</property>
</configuration>
yarn-site.xml
<?xml version="1.0"?>
<configuration>
<property>
<name>yarn.nodemanager.aux-services</name>
<value>mapreduce_shuffle</value>
</property>
<property>
<name>yarn.nodemanager.aux-services.mapreduce_shuffle.class</name>
<value>org.apache.hadoop.mapred.ShuffleHandler</value>
</property>
<property>
<name>yarn.resourcemanager.hostname</name>
<value>hadoop-master</value>
</property>
</configuration>
NameNode Information
在本机通过 Docker 启动 Hadoop 集群后共有一个 master 节点和四个 slave 节点,如下图:
从管理后台截得下图:
Datanode Information
共有四个计算节点,在 hdfs-site.xml 文件中设定每个数据都备份三份,如下图:
MapReduce 编程
整体思路
整体设计思路如下图:
每次迭代过程都是一个 Hadoop Job ,通过不断迭代计算得到新的中心点文件,然后跟旧的中心点文件进行比较,直到新的中心点与旧的中心点误差小于给定的阙值,此时迭代结束,最后一次得到的中心点为计算结果。
项目的 GitHub 地址: https://github.com/CHAAAAA/hadoop-kmeans
KMeansData (辅助 K-Means 算法的单个数据点类)
为辅助计算,设计了类 KMeansData ,其用于保存 K-Means 计算过程中的数据,并实现了 WritableComparable
接口使其支持在 Hadoop 计算过程中向下传递。
成员变量
Text kMeansData
- 格式:
1 2 ... 6 7
,多维数据用空格隔开, 使用Double
解析 - 含义: 一行多维数据;或者在
Combiner
和Reducer
过程中累加的多维数据
- 格式:
IntWritable dataSize
- 含义:该对象的
kMeansData
字段是有几行数据累加的
- 含义:该对象的
主要成员函数
public void add(KMeansData data, int dimension)
在当前 KMeansData 对象上累加一个 KMeansData 对象,更新 kMeansData 和 dataSize 的值。该函数主要在
Combiner
和Reducer
计算过程中用到。@param data 一个 KMeansData 数据对象
@param dimension 数据对象的维度
@throws KMeansCentroidFormatException
public String getNewCentroids()
根据当前 KMeansData 对象生成新的中心点,也就是将 kMeansData 中各维的数据除以 dataSize。该函数主要用于
Reducer
过程中的最后一步,生成新中心点。@return 新中心点,数据以空格隔开
代码
public class KMeansData implements WritableComparable<KMeansData> {
private Text kMeansData;
private IntWritable dataSize;
public KMeansData() {
set(new Text(), new IntWritable());
}
public KMeansData(Text text, IntWritable intWritable) {
this.kMeansData = text;
this.dataSize = intWritable;
}
private void set(Text textWritable, IntWritable intWritable) {
this.kMeansData = textWritable;
this.dataSize = intWritable;
}
public Text getkMeansData() {
return kMeansData;
}
public IntWritable getDataSize() {
return dataSize;
}
/**
* isEqual
*
* @param o
* @return
*/
public int compareTo(KMeansData o) {
int flag = 0;
if (kMeansData.compareTo(o.kMeansData) != 0 || dataSize.compareTo(o.dataSize) != 0) {
flag = 1;
}
return flag;
}
public void write(DataOutput dataOutput) throws IOException {
kMeansData.write(dataOutput);
dataSize.write(dataOutput);
}
public void readFields(DataInput dataInput) throws IOException {
kMeansData.readFields(dataInput);
dataSize.readFields(dataInput);
}
/**
* add a KMeansData to this
* @param data a KMeansData
* @param dimension dimension
* @throws KMeansCentroidFormatException
*/
public void add(KMeansData data, int dimension) throws KMeansCentroidFormatException {
Text newData = data.kMeansData;
String[] newStrings = newData.toString().trim().split(" ");
String[] strings = kMeansData.toString().trim().split(" ");
if (newStrings.length != dimension || strings.length != dimension) {
throw new KMeansCentroidFormatException("Dimension Error");
}
StringBuffer result = new StringBuffer();
for (int i = 0; i < dimension; i++) {
double a = Double.parseDouble(newStrings[i]) + Double.parseDouble(strings[i]);
DecimalFormat df = new DecimalFormat("0.0");
result.append(df.format(a)).append(" ");
}
String r = result.toString().trim();
this.kMeansData.set(r.substring(0, r.length() - 1));
this.dataSize.set(this.dataSize.get() + data.dataSize.get());
}
/**
* get the new Centroids
* callback by Reducer
* @return data
*/
public String getNewCentroids() {
StringBuffer r = new StringBuffer();
String[] strings = kMeansData.toString().trim().split(" ");
for (String s : strings) {
double d = Double.parseDouble(s) / dataSize.get();
DecimalFormat df = new DecimalFormat("0.0");
r.append(df.format(d)).append(" ");
}
return r.toString().trim();
}
/**
* return an ArrayList<Double> about this data
* @return arrayList
*/
public ArrayList<Double> getArray() {
ArrayList<Double> arrayList = new ArrayList<Double>();
String[] data = this.kMeansData.toString().trim().split(" ");
for (String s : data) {
arrayList.add(Double.parseDouble(s));
}
return arrayList;
}
}
KMeansCentroids (辅助 K-Means 算法的中心点类)
该类用于实例化中心点文件,将中心点文件中的所有中心点保存在该类中用于计算和比较。
成员变量
Map<Integer, ArrayList<Double>> centroids
- 格式:
centroids
为一个 Map 。 Map 的 Key 为一个中心点的行号;Value 为一个 ArrayList ,存储了该中心点的数据。
- 格式:
int centroidDimension
- 含义:该字段保存了中心点文件的数据维度
主要成员函数
private void initCentroid(String centroidPath)
在初始化一个 KMeansCentroids 对象时会调用此函数,函数根据传入的文件路径,读取该文件并将数据存储在
centroids
和centroidDimension
中。@param centroidPath 中心点文件的路径
@throws KMeansCentroidFormatException
public int getCentroid(KMeansData point)
根据传入的一个 KMeansData 对象,找到距离该数据最近的中心点并返回。使用欧几里得距离公式求距离。
@param point 一个 KMeansData 数据对象
@return 距离该数据最近的中心点的行号
@throws KMeansCentroidFormatException
public boolean isEquals(KMeansCentroids o, Double error)
计算该对象是否与传入的 KMeansCentroids 对象相等,当另个对象之间的数据差小于 error 时,我们也认为其相等
@param o 一个 KMeansCentroids 数据对象
@param error 两个 KMeansCentroids 对象之间允许的误差
@return 距离该数据最近的中心点的行号
@return 另个中心点文件是否相等
代码
public class KMeansCentroids {
private static final Logger LOG = LogManager.getLogger(KMeansCentroids.class);
private Map<Integer, ArrayList<Double>> centroids = new HashMap<Integer, ArrayList<Double>>();
// Dimension
private int centroidDimension = 3;
public KMeansCentroids(String centroidPath, int centroidDimension) {
this.centroidDimension = centroidDimension;
initCentroid(centroidPath);
}
public KMeansCentroids(String centroidPath) {
initCentroid(centroidPath);
}
/**
* init centroids
*
* @param centroidPath centroids file uri
*/
private void initCentroid(String centroidPath) {
try {
Configuration configuration = new Configuration();
FileSystem fileSystem = FileSystem.get(URI.create(centroidPath), configuration);
FSDataInputStream inputStream = null;
try {
LOG.debug("Start read centroids file,URI: " + centroidPath);
inputStream = fileSystem.open(new Path(centroidPath));
BufferedReader d = new BufferedReader(new InputStreamReader(inputStream));
String line;
while ((line = d.readLine()) != null) {
line = line.replace("\t", " ").trim();
String[] points = line.split(" ");
if (points.length != centroidDimension + 1) {
throw new KMeansCentroidFormatException("Centroid Dimension Error");
}
int index = Integer.valueOf(points[0]);
ArrayList<Double> oneCentroid = new ArrayList<Double>();
for (int i = 1; i <= centroidDimension; i++) {
oneCentroid.add(Double.valueOf(points[i]));
}
centroids.put(index, oneCentroid);
}
LOG.debug("Read centroids file success. Centroids: \n" + readCentroids());
} catch (Exception e) {
e.printStackTrace();
} finally {
IOUtils.closeStream(inputStream);
}
} catch (Exception e1) {
e1.printStackTrace();
}
}
public Map<Integer, ArrayList<Double>> getCentroids() {
return centroids;
}
public int getCentroidDimension() {
return centroidDimension;
}
/**
* Get the Shortest Path Centroid index
*
* @param point data point
* @return centroid
* @throws KMeansCentroidFormatException
*/
public int getCentroid(KMeansData point) throws KMeansCentroidFormatException {
double distance = Double.MAX_VALUE;
int r = 0;
ArrayList<Double> pointData = point.getArray();
for (Integer i : centroids.keySet()) {
double temp = getEnumDistance(centroids.get(i), pointData);
if (temp < distance) {
distance = temp;
r = i;
}
}
return r;
}
/**
* Get the Enum Distance
*
* @param centroid centroid
* @param point data point
* @return Enum Distance
*/
private double getEnumDistance(ArrayList<Double> centroid, ArrayList<Double> point) {
double distance = 0.0;
for (int i = 0; i < centroidDimension; i++) {
distance += ((centroid.get(i) - point.get(i)) * (centroid.get(i) - point.get(i)));
}
distance = Math.sqrt(distance);
return distance;
}
/**
* Show the Centroids
*/
public String readCentroids() {
StringBuffer sb = new StringBuffer();
for (Integer i : centroids.keySet()) {
sb.append(i + "");
for (Double j : centroids.get(i)) {
sb.append(j + " ");
}
sb.append("\n");
}
return sb.toString();
}
/**
* two KMeansCentroids is equal
*
* @param o compare object
* @param error allowable error
* @return flag
*/
public boolean isEquals(KMeansCentroids o, Double error) {
boolean flag = true;
for (Integer i : centroids.keySet()) {
if (!arrayEquals(centroids.get(i), o.getCentroids().get(i), error)) {
flag = false;
break;
}
}
return flag;
}
private boolean arrayEquals(ArrayList<Double> a, ArrayList<Double> b, Double error) {
boolean flag = true;
for (int i = 0; i < centroidDimension; i++) {
if (Math.abs(a.get(i) - b.get(i)) > error) {
flag = false;
break;
}
}
return flag;
}
}
KMeansMapper
Mapper 计算过程:
<Object key, Text value> -> <IntWritable index, KMeansData data>
代码
protected void map(Object key, Text value, Context context) throws IOException, InterruptedException {
//get data
String s = value.toString().trim();
String[] fields = s.split(" ");
if (fields.length == dimension) {
KMeansData kMeansData = new KMeansData(new Text(s), new IntWritable(1));
try {
int index = centroids.getCentroid(kMeansData);
context.write(new IntWritable(index), kMeansData);
} catch (KMeansCentroidFormatException e) {
throw new IOException();
}
}
}
KMeansCombiner
Combiner 计算过程:
<IntWritable index, KMeansData data> -> <IntWritable index, KMeansData data>
代码
protected void reduce(IntWritable key, Iterable<KMeansData> values, Context context) throws IOException, InterruptedException {
StringBuilder stringBuilder = new StringBuilder();
for (int i = 0; i < dimension; i++) {
stringBuilder.append(0).append(" ");
}
KMeansData data = new KMeansData(new Text(stringBuilder.toString().trim()), new IntWritable(dimension));
for (KMeansData val : values) {
try {
data.add(val, dimension);
} catch (KMeansCentroidFormatException e) {
throw new IOException();
}
}
context.write(key, data);
}
KMeansReducer
Reducer 计算过程:
<IntWritable index, KMeansData data> -> <IntWritable index, Text centroid>
代码
protected void reduce(IntWritable key, Iterable<KMeansData> values, Context context) throws IOException, InterruptedException {
StringBuilder stringBuilder = new StringBuilder();
for (int i = 0; i < dimension; i++) {
stringBuilder.append(0).append(" ");
}
KMeansData data = new KMeansData(new Text(stringBuilder.toString().trim()), new IntWritable(dimension));
for (KMeansData val : values) {
try {
data.add(val, dimension);
} catch (KMeansCentroidFormatException e) {
throw new IOException();
}
}
KMeansRun
这是 Hadoop 作业的执行入口,每次迭代计算后都判断一下是否进行下次迭代。程序启动时需要 1 个或者 3 个参数,说明如下:
- 1 个参数:参数需要表示本次计算的数据维度
- 3 个参数:
数据维度 输入文件夹 初始数据中心点路径
其中初始数据中心点的文件名称必须为 centroid。
代码
public class KMeansRun {
private static final Logger LOG = LogManager.getLogger(KMeansRun.class);
//default centroid path /centroid$times
private static String centroidPath = "kmeans/";
private static String inputPath = "input/kmeans/";
private static Integer dimension = 3;
public static void main(String[] args) throws Exception {
//iteration times
int iterations = 0;
Configuration conf = new Configuration();
GenericOptionsParser optionParser = new GenericOptionsParser(conf, args);
String[] remainingArgs = optionParser.getRemainingArgs();
if (remainingArgs.length != 1 && remainingArgs.length != 3) {
System.err.println("Usage: K-Means <dimension> [in] [centroidRootPath(without filename , default filename is centroid)]");
System.exit(2);
}
dimension = Integer.valueOf(args[0]);
if (remainingArgs.length == 3) {
inputPath = args[1];
centroidPath = args[2];
if (!args[2].endsWith(File.separator)) {
centroidPath = args[2] + File.separator;
}
}
//set dimension
conf.set("dimension", dimension.toString());
String oldCentroidPath = centroidPath + "centroid";
String currentCentroidPath = centroidPath + "centroid";
do {
conf.set("centroid.path", currentCentroidPath);
Job job = Job.getInstance(conf, "K-Means" + iterations);
job.setJarByClass(KMeansRun.class);
job.setMapperClass(KMeansMapper.class);
job.setCombinerClass(KMeansCombiner.class);
job.setReducerClass(KMeansReducer.class);
job.setMapOutputKeyClass(IntWritable.class);
job.setMapOutputValueClass(KMeansData.class);
job.setOutputKeyClass(IntWritable.class);
job.setOutputValueClass(Text.class);
FileInputFormat.addInputPath(job, new Path(inputPath));
FileOutputFormat.setOutputPath(job, new Path(centroidPath + iterations + "/"));
if (!job.waitForCompletion(true)) {
System.exit(1);
}
oldCentroidPath = currentCentroidPath;
currentCentroidPath = centroidPath + iterations + "/part-r-00000";
iterations++;
} while (isContinue(oldCentroidPath, currentCentroidPath));
}
private static boolean isContinue(String oldPath, String newPath) {
boolean flag = false;
KMeansCentroids oldCentroids = new KMeansCentroids(oldPath, dimension);
KMeansCentroids newCentroids = new KMeansCentroids(newPath, dimension);
if (!oldCentroids.isEquals(newCentroids, 0.1)) {
flag = true;
}
return flag;
}
}