Hadoop 实现朴素贝叶斯 Naive Bayes 文本分类

实验简介

Hadoop 是什么

Hadoop 是一个由 Apache 基金会所开发的分布式系统基础架构。主要解决海量数据的存储和海量数据的分析计算问题。广义上来说 HADOOP 通常是指一个更广泛的概念—— HADOOP 生态圈

Hadoop 的优势

  1. 高可靠性:因为 Hadoop 假设计算元素和存储会出现故障,因为它维护多个工作数据副本,在出现故障时可以对失败的节点重新分布处理。
  2. 高扩展性:在集群间分配任务数据,可方便的扩展数以千计的节点。
  3. 高效性:在 MapReduce 的思想下, Hadoop 是并行工作的,以加快任务处理速度。
  4. 高容错性:自动保存多份副本数据,并且能够自动将失败的任务重新分配。

Hadoop 组成

  1. Hadoop HDFS:一个高可靠、高吞吐量的分布式文件系统。
  2. Hadoop MapReduce: 一个分布式的离线并行计算框架。
  3. Hadoop YARN: 作业调度与集群资源管理的框架。
  4. Hadoop Common:支持其他模块的工具模块(Configuration、RPC、序列化机制、日志操作)。

实验任务说明

本文意在通过 hadoop 集群,完成贝叶斯文本分类的任务。

有关利用 朴素贝叶斯算法 进行文本分类的原理参见上节内容,这里不做过多说明。

本实验是在 Windown 8 操作系统下,通过搭建 VMware Workstation 虚拟机的方式进行 hadoop 集群的部署。本次实验部署了 3hadoop 虚拟机的节点,集群部署规划如下表:

hadoop1 hadoop2 hadoop3
HDFS NameNode
DataNode
DataNode SecondaryNameNode
DataNode
YARN NodeManager ResourceManager
NodeManager
NodeManager

有关于虚拟机安装与 hadoop 集群环境搭建这里略。

贝叶斯分类器任务框架

数据集简介

数据集文件夹为 NBCorpus,里面一共两个子文件夹 CountryIndustry,实验要求从中选取一个完成即可。

CountryIndustry 下每个子目录就是一个文档类别,但有的子目录下文件非常少,因此要选择文件比较多的目录(至少二个)进行训练和测试。

每个文件已经分好词,一行一个单词。

训练集和测试集的选取

我在数据集中选择了 Country 文件夹下的 CHINACANA 作为本次实验的样本,其中 CHINA 类中包含 255 个文本,CANA 类中包含 263 个文本。按照 70%30% 的比例选取训练集和测试集。表格如下:

CHINA CANA
文档总数 255 263
训练集数 178 184
测试集数 77 79

手动随机抽取相应数量的文档放入相应的数据输入路径,训练集与测试集相应的路径如下所示:

  • e:/INPUT/TRAIN/CHINA/,训练集类 CHINA 文档路径
  • e:/INPUT/TRAIN/CANA/,训练集类 CANA 文档路径
  • e:/INPUT/TEST/CHINA/,测试集类 CHINA 文档路径
  • e:/INPUT/TEST/CANA/,测试集类 CANA 文档路径

实验任务分解

Bayes 分类器的 MapReduce 实现分为 训练测试 两个阶段。

其中 训练 阶段需要编写两个 MapReduce 任务,MapReduce 任务一 完成计算文档 d 出现在类 c 中的先验概率的所需数据,我将该阶段称为 训练先验概率 阶段;MapReduce 任务二 完成计算词项 t 出现在类 c 中的条件概率的所需数据,我将该阶段称为 训练条件概率 阶段。

其中 测试 阶段需要编写一个 MapReduce 任务三,用于完成 预测 测试集中文档所属的类别,我将其称为 预测 阶段。另外还要编写一个java程序,用于对 任务三 中预测的结果进行 评估,我将其称为 评估 阶段。

实验任务所有代码均在 eclipse 编辑器中完成。我在 eclipse 创建了一个名为 Naive Bayesprojectsrc 文件夹下创建三个 package,分别取名为 DocCountWordCount 以及 Predition,分别用于完成 MapReduce 任务一MapReduce 任务二 以及 MapReduce 任务三。最后的 评估 阶段代码写在了 Predition 这个包中。项目中的代码结构如下图所示:

01_Project_Architecture

MapReduce 任务一 训练先验概率

任务说明

  1. 需要编写一个单独的 MapReduce Job ,计算结果写入文件;
  2. 实现一个自定义的 InputFormatRecordReader ,每读取一个文件(实际上不需要读取文件内容),输出 <ClassName,1> ,其中 ClassName 为读取的文件所在的类别目录名,<ClassName,1>Map 的输入,Map 不做任何处理,直接输出 <ClassName,1>
  3. Map 的输出交给 Combine 处理,Combine 的输入为<ClassName,{1,1,...,1}>,在 Combine 中计算 1 的个数,所以 Combine 的输出为 <ClassName,Count>Count为属于 ClassName 类别的文档个数,但是局部的;
  4. Combine 的输出交给 ReducerReducer 的输入为 <ClassName,{count1,count2, ..., countn}>,在 Reduce 里对 count1count2,…,countn 求和,就得到了 ClassName的总数 TotalCountReducer 的输出为 <ClassName,TotalCount> 并写到文件;
  5. 该作业主要统计了每种类别文档的总数目,具体概率的计算放在了后面。作业的输出会产生多个文件,取决于 Reducer 的个数,每个文件里一行的格式为: 类名 文档总数

代码目录

DocCount

  • DocCountDriver.java: 主程序入口
  • DocCountMapper.java: 实现 Map 阶段
  • DocCountReducer.java: 实现 Reduce 阶段
  • WholeFileInputFormat.java: 重写的 InputFormat
  • WholeRecordReader.java: 重写的 RecordReader

重写 InputFormat 与 RecordReader

之所以要重写这个类,主要是因为 hadoop 中默认的 MapReduce 程序,每一个 Map 任务的调用输入的 key 为文档中每一行行号,数据类型为 LongWritablevalue 为文档中的一行内容,数据类型为 Text。但是第一任务,我们要求每一个 Map 任务处理一个文档,而不是文档中的每一行记录,所以需要对 InputFormatRecordReader 这两个类进行重写,以符合任务的需求。

重写的 InputFormat 的类文件 WholeFileInputFormat.java 内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
package DocCount;

import java.io.IOException;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;

public class WholeFileInputFormat extends FileInputFormat<NullWritable, BytesWritable>{

@Override
protected boolean isSplitable(JobContext context, Path filename) {
return false;
}

@Override
public RecordReader<NullWritable, BytesWritable> createRecordReader(InputSplit split, TaskAttemptContext context)
throws IOException, InterruptedException {
// 创建对象
WholeRecordReader recordReader = new WholeRecordReader();
// 初始化
recordReader.initialize(split, context);
// 返回对象
return recordReader;
}
}

重写的 RecordReader 的类文件 WholeRecordReader.java 内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package DocCount;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

public class WholeRecordReader extends RecordReader<NullWritable, BytesWritable>{

BytesWritable value = new BytesWritable();
boolean isProcess = false;
FileSplit split;
Configuration configuration;

@Override
public void initialize(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException {
// 初始化
this.split = (FileSplit) split;
configuration = context.getConfiguration();
}

@Override
public boolean nextKeyValue() throws IOException, InterruptedException {
// 读取一个一个的文件
if (!isProcess) {

// 0.缓存区
byte[] buf = new byte[(int) split.getLength()];

FileSystem fs = null;
FSDataInputStream fis = null;
try {
// 1.获取文件系统
Path path = split.getPath();
fs = path.getFileSystem(configuration);
// 2.打开文件输入流
fis = fs.open(path);
// 3.流的拷贝
IOUtils.readFully(fis, buf, 0, buf.length);
// 4.拷贝缓存区的数据到最终输出
value.set(buf, 0, buf.length);
} catch (Exception e) {
}finally {
IOUtils.closeStream(fis);
IOUtils.closeStream(fs);
}
isProcess = true;
return true;
}
return false;
}

@Override
public NullWritable getCurrentKey() throws IOException, InterruptedException {
// 获取当前键
return NullWritable.get();
}

@Override
public BytesWritable getCurrentValue() throws IOException, InterruptedException {
// 获取当前值
return value;
}

@Override
public float getProgress() throws IOException, InterruptedException {
// 获取当前进度
return isProcess? 1:0;
}

@Override
public void close() throws IOException {
}
}

MapReduce 任务

设计的 Map 任务程序 DocCountMapper.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
package DocCount;

import java.io.IOException;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

public class DocCountMapper extends Mapper<NullWritable, BytesWritable, Text, IntWritable>{

Text k = new Text();
IntWritable v = new IntWritable(1);

@Override
protected void setup(Mapper<NullWritable, BytesWritable, Text, IntWritable>.Context context)
throws IOException, InterruptedException {

// 获取文件的路径和名称(类名)
FileSplit split = (FileSplit) context.getInputSplit();

Path path = split.getPath();
k.set(path.getParent().getName());
}

@Override
protected void map(NullWritable key, BytesWritable value, Context context)
throws IOException, InterruptedException {

context.write(k, v);
}
}

设计的 Reduce 任务程序 DocCountReducer.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
package DocCount;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class DocCountReducer extends Reducer<Text, IntWritable, Text, IntWritable>{

IntWritable value = new IntWritable();

@Override
protected void reduce(Text key, Iterable<IntWritable> values,
Context context) throws IOException, InterruptedException {

// 1. 统计文档总个数
int sum = 0;
for (IntWritable count : values) {
sum += count.get();
}

// 2 输出单词总个数
value.set(sum);
context.write(key, value);
}
}

main 函数入口

对于程序的入口,我们专门写了一个类进行 main() 封装,命名为 Driver,程序 DocCountDriver.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package DocCount;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class DocCountDriver {

public static void main(String[] args) throws IllegalArgumentException, IOException, ClassNotFoundException, InterruptedException {

args = new String[] {"e:/INPUT/TRAIN", "e:/z_output_doc"};

// 1 获取job信息
Configuration conf = new Configuration();
Job job = Job.getInstance(conf);

// 2 获取jar包位置
job.setJarByClass(DocCountDriver.class);

// 3 关联自定义的mapper和reducer
job.setMapperClass(DocCountMapper.class);
job.setReducerClass(DocCountReducer.class);

// 4 设置自定义的InputFormat类
job.setInputFormatClass(WholeFileInputFormat.class);

// 5 设置map输出数据类型
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(IntWritable.class);

// 6 设置最终输出数据类型
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(IntWritable.class);

// 7 设置输入和输出文件路径
ArrayList<Path> paths = GetPaths(args[0]);
for(int i=0; i < paths.size(); i++) {
FileInputFormat.addInputPath(job, paths.get(i));
}
FileOutputFormat.setOutputPath(job, new Path(args[1]));

// 8 提交代码
boolean result = job.waitForCompletion(true);
System.exit(result?0:1);
}

private static ArrayList<Path> GetPaths(String path) {
// 获取path路径下所有子文件夹路径
ArrayList<Path> paths = new ArrayList<Path>();
File file = new File(path);
// 如果这个路径是文件夹
if (file.isDirectory()) {
// 获取路径下的所有文件
File[] files = file.listFiles();
for (int i=0; i<files.length; i++) {
// 如果还是文件夹
if (files[i].isDirectory()) {
// 将其加入路径列表
paths.add(new Path(files[i].getPath()));
}
else {continue;}
}
}
return paths;
}
}

程序运行结果

程序运行过程如图所示:

02_Doc_1

程序运行结束如图所示:

03_Doc_2

程序输出文件如图所示:

04_Doc_3

程序输出数据内容如下:

05_Doc_4

MapReduce 任务二 训练条件概率

任务说明

  1. 需要编写一个单独的MapReduce Job ,计算结果写入文件;
  2. 实现一个自定义的 Writable 类型,要求 Map 每读取一个文件中的一行(一个单词),输出 <<ClassName,Term>,1>,其中 key<ClassName,Term>ClassName 为读取的文件所在的类别目录名,Term 为单词,1 表示 TermClassName 的类里出现一次;
  3. Map 的输出交给 Combine 处理,Combine 的输入 <<ClassName,Term>,{1,1,...,1}>,在 Combine 中计算 1 的个数,所以 Combine 的输出为 <<ClassName,Term>,Count>CountTermClassName 的类里出现的次数,但是局部的;
  4. Combine 的输出交给 ReducerReducer 的输入为 <<ClassName,Term>,{count1,count2,...,countn}>,在 Reduce 里把 count1count2countn 求和,就得到了 TermClassName 的类里出现的总次数 TotalCount
  5. Reduce 输出 <<ClassName,Term>,TotalCount>
  6. 该作业只统计了每个 <ClassName,Term> 对出现的总次数,具体条件概率计算放在了后面。作业的输出会产生多个文件,取决于 Reducer 的个数,每个文件里一行的格式为:类名 单词 出现次数

代码目录

WordCount

  • WordCountDriver.java: 主程序入口
  • WordCountMapper.java: 实现 Map 阶段
  • WordCountReducer.java: 实现 Reduce 阶段
  • TextPair.java:重写 Wtrtable 类,自定义一个 <Text,Text> 的数据类型

定义新的数据类型

Map 任务需要输出的键值为 <ClassName,Term>,然而我查了下 hadoop 的数据类型,如下表:

Java Hadoop Writable
boolean BooleanWritalbe
byte ByteWritable
int IntWritable
float FloatWritable
long LongWritable
double DoubleWritable
string Text
map MapWritable
array ArrayWritable

上表中,左侧为 java 的数据类型,右侧为 hadoop 默认的数据类型,每行存储的数据是一样的,只不过一个是 java 的类,一个是 hadoop 的类。

看了下数据类型,只有 MapWritable 满足要求,因为要求输出为键值对形式。无奈这个类不太会用,程序各种编译不通过,所幸重新写了一个自定义的数据类型,取名为 TextPair,文件 TextPair.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package WordCount;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;

public class TextPair implements WritableComparable<TextPair>{

private Text dirName;
private Text word;

public TextPair() {
set(new Text(), new Text());
}

public TextPair(String dirName, String word) {
set(new Text(dirName), new Text(word));
}

public TextPair(Text dirName, Text word) {
set(dirName, word);
}

public void set(Text dirName, Text word) {
this.dirName = dirName;
this.word = word;
}

public Text getFirst() {
return dirName;
}

public Text getSecond() {
return word;
}

@Override
public void write(DataOutput out) throws IOException {
dirName.write(out);
word.write(out);
}

@Override
public void readFields(DataInput in) throws IOException {
dirName.readFields(in);
word.readFields(in);
}

@Override
public int hashCode() {
return dirName.hashCode() * 163 + word.hashCode();
}

@Override
public boolean equals(Object o) {
if(o instanceof TextPair) {
TextPair tp = (TextPair) o;
return dirName.equals(tp.dirName) && word.equals(tp.word);
}
return false;
}

@Override
public String toString() {
return dirName + "\t" + word;
}

@Override
public int compareTo(TextPair tp) {
int cmp = dirName.compareTo(tp.dirName);
if (cmp !=0) {
return cmp;
}
return word.compareTo(tp.word);
}
}

MapReduce 任务

设计的 Map 任务程序 WordCountMapper.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package WordCount;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

/*
* 输入的key LongWritable 行号
* 输入的value Text 一行内容
* 输出的key Text 单词
* 输入的value IntWritable 单词的个数
*/
public class WordCountMapper extends Mapper<LongWritable, Text, TextPair, IntWritable>{

Text className = new Text();
Text wordName = new Text();
TextPair k = new TextPair();
IntWritable v = new IntWritable(1);

@Override
protected void map(LongWritable key, Text value, Context context)
throws IOException, InterruptedException {

// 1.获取类名
InputSplit inputSplit = context.getInputSplit();
String dirName = ((FileSplit) inputSplit).getPath().getParent().getName();

// 2. 一行内容转换成string
String line = value.toString();

// 3. 切割
String[] words = line.split(" ");

// 3 循环写出当下一个截断
for (String word : words) {
className.set(dirName);
wordName.set(word);
k.set(className, wordName);
context.write(k, v);
}
}
}

设计的 Reduce 任务程序 WordCountReducer.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
package WordCount;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;

public class WordCountReducer extends Reducer<TextPair, IntWritable, TextPair, IntWritable>{

@Override
protected void reduce(TextPair key, Iterable<IntWritable> values,
Context context) throws IOException, InterruptedException {

// 1 统计单词总个数
int sum = 0;
for (IntWritable count : values) {
sum += count.get();
}

// 2 输出单词总个数
context.write(key, new IntWritable(sum));
}
}

main 函数入口

main() 写在 WordCountDriver.java 中,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package WordCount;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class WordCountDriver {

public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {

args = new String[] {"e:/INPUT/TRAIN", "e:/z_output_word"};

// 1 获取job信息
Configuration configuration = new Configuration();
Job job = Job.getInstance(configuration);

// 2 获取jar包位置
job.setJarByClass(WordCountDriver.class);

// 3 关联自定义的mapper和reducer
job.setMapperClass(WordCountMapper.class);
job.setReducerClass(WordCountReducer.class);

// 4 设置map输出数据类型
job.setMapOutputKeyClass(TextPair.class);
job.setMapOutputValueClass(IntWritable.class);

// 5 设置最终输出数据类型
job.setOutputKeyClass(TextPair.class);
job.setOutputValueClass(IntWritable.class);

// 6 设置输入和输出文件路径
ArrayList<Path> paths = GetPaths(args[0]);
for(int i=0; i < paths.size(); i++) {
FileInputFormat.addInputPath(job, paths.get(i));
}
FileOutputFormat.setOutputPath(job, new Path(args[1]));

// 7 提交代码
boolean result = job.waitForCompletion(true);
System.exit(result?0:1);
}

private static ArrayList<Path> GetPaths(String path) {
// 获取path路径下所有子文件夹路径
ArrayList<Path> paths = new ArrayList<Path>();
File file = new File(path);
// 如果这个路径是文件夹
if (file.isDirectory()) {
// 获取路径下的所有文件
File[] files = file.listFiles();
for (int i=0; i<files.length; i++) {
// 如果还是文件夹
if (files[i].isDirectory()) {
// 将其加入路径列表
paths.add(new Path(files[i].getPath()));
}
else {continue;}
}
}
return paths;
}
}

程序运行结果

程序运行过程如图所示:

06_Word_1

程序运行结束如图所示:

07_Word_2

程序输出文件如图所示:

08_Word_3

程序输出数据内容如下:

09_Word_4

MapReduce 任务三 预测

任务说明

  1. 预测前将训练得到文件加载到内存里,计算先验概率和每个类别里单词出现的条件概率,可以交给自定义 Mapper 类和自定义 Reducer 类的包装类 Prediction 来处理,在 Prediction 类里定义成类变量来保存这些学习到的概率,这样 Mapper 类和 Reducer 类都可以访问到这些概率。保存这些概率的数据结构应该用 HashTable,这样可以高效地读取所需的概率值;
  2. Prediction 类实现一个静态方法,计算一个文档属于某类的条件概率 P(class|doc),该方法无需用MapReduce实现,需要计算其中每个单词出现的频率,该方法命名为 conditionalProbabilityForClass
  3. 每读取一个文件,这里需要把文件内容作为一个整体读取成为一个 String,产生 <docId,content> 作为 Map 的输入;
  4. Map 里写一个 for 循环,对于每一个类别 c 在循环中调用 conditionalProbabilityForClass 函数,得到 <docId,<ClassName,Prob>>,作为 Map 的输出。因此 Map 的输入为 <docId,content>Map 的输出为list<docId,<ClassName,Prob>>
  5. Reduce 任务输入为 <docId,list<ClassName,Prob>>,找到最大的 Prob,输出 <docId,最大Prob对应的ClassName>

代码目录

Predition

  • PredictDriver.java: 主程序入口
  • PredictMapper.java: 实现 Map 阶段
  • PredictReducer.java: 实现 Reduce 阶段
  • PredictTestInputFormat.java: 重写的 InputFormat
  • PredictTestRecordReader.java: 重写的 RecordReader
  • Prediction.java: 封装类,用于保存学习到的先验概率和条件概率,并定义一个方法实现 P(c|d) 的计算,并给 Map 调用

封装类说明

Prediction.java 用于在 MapReduce 运行之前,对之前的任务的运行结果进行预处理,之前的运行结果主要存放在 "e:/z_output_doc/""e:/z_output_word/" 下。将其数据取出,分别用于计算该文档 d 属于类 c 的先验概率 P(c),以及在类 c 中单词 t 出现的条件概率 p(t|c),并将计算结果存入相应的哈希表 class_probclass_term_prob 中。并定义一个方法实现 P(c|d) 的计算,传给 Map 调用。封装类 Prediction.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package Predition;

import java.io.*;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;

public class Prediction {

private static Hashtable<String,Double> class_prob = new Hashtable<String, Double>();
private static Hashtable<Map<String,String>,Double> class_term_prob = new Hashtable<Map<String, String>, Double>();
private static Hashtable<String,Double> class_term_total = new Hashtable<String, Double>();
private static Hashtable<String,Double> class_term_num = new Hashtable<String, Double>();
Prediction() throws NumberFormatException, IOException{
// 统计文档总数
BufferedReader reader = new BufferedReader(new FileReader(new File("e://z_output_doc//"+"part-r-00000")));
double file_total = 0;
while(reader.ready()){
String line = reader.readLine();
String[] args = line.split("\t");
file_total += Double.valueOf(args[1]);
}
// 计算先验概率class_prob
reader = new BufferedReader(new FileReader(new File("e://z_output_doc//"+"part-r-00000")));
while(reader.ready()){
String line = reader.readLine();
String[] args = line.split("\t");
class_prob.put(args[0],Double.valueOf(args[1])/file_total);
System.out.println(String.format(("%s:%f"),args[0],Double.valueOf(args[1])/file_total));
}

//计算单词总数
reader = new BufferedReader(new FileReader(new File("e://z_output_word//"+"part-r-00000")));
while(reader.ready()){
String line = reader.readLine();
String[] args = line.split("\t");// 0:类,1:词条,2:词频
double count = Double.valueOf(args[2]);
String classname = args[0];
class_term_total.put(classname,class_term_total.getOrDefault(classname,0.0)+count);
}
//计算单词集合大小
reader = new BufferedReader(new FileReader(new File("e://z_output_word//"+"part-r-00000")));
while(reader.ready()){
String line = reader.readLine();
String[] args = line.split("\t");// 0:类,1:词条,2:词频
String classname = args[0];
class_term_num.put(classname,class_term_num.getOrDefault(classname,0.0)+1.0);
}
System.out.println(String.format(("%f:%f"),class_term_num.get("CANA"),class_term_num.get("CHINA")));
//计算每个类别里面出现的词条概率class-term prob
reader = new BufferedReader(new FileReader(new File("e://z_output_word//"+"part-r-00000")));
while(reader.ready()){
String line = reader.readLine();
String[] args = line.split("\t");// 0:类,1:词条,2:词频
double count = Double.valueOf(args[2]);
String classname = args[0];
String term = args[1];
Map<String,String> map = new HashMap<String, String>();
map.put(classname,term);
class_term_prob.put(map, (count+1)/(class_term_total.get(classname)+class_term_num.get(classname)));
}
}

public static double conditionalProbabilityForClass(String content,String classname){
// 计算一个文档属于某类的条件概率
double result = 0;
String[] words = content.split("\n");
for(String word:words){
Map<String,String> map = new HashMap<String, String>();
map.put(classname, word);
result += Math.log(class_term_prob.getOrDefault(map,1.0/(class_term_total.get(classname)+class_term_num.get(classname))));
}
result += Math.abs(Math.log(class_prob.get(classname)));
return result;
}
}

重写 InputFormat 与 RecordReader

之所以要重写这个类,原因前面做 MapReduce 任务一 说过,默认的类不满足任务需求,我们要求给 Map 的出入为整个文档的内容,并以字符串的形式传递,所以要进行改写。

重写的 InputFormat 的类文件 PredictTestInputFormat.java 内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package Predition;

import java.io.IOException;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;

public class PredictTestInputFormat extends FileInputFormat<NullWritable,Text>{

@Override
protected boolean isSplitable(JobContext context, Path filename) {
return false;
}

@Override
public RecordReader<NullWritable, Text> createRecordReader(InputSplit split, TaskAttemptContext context)
throws IOException, InterruptedException{
// 创建对象
PredictTestRecordReader recordReader = new PredictTestRecordReader();
// 返回对象
return recordReader;
}
}

重写的 RecordReader 的类文件 PredictTestRecordReader.java 内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package Predition;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.input.LineRecordReader;

public class PredictTestRecordReader extends RecordReader<NullWritable, Text>{

FileSplit split;
Configuration conf;
Text value = new Text();
boolean isProcess = false;
LineRecordReader reader=new LineRecordReader();

@Override
public void initialize(InputSplit split, TaskAttemptContext context)
throws IOException, InterruptedException {
// 初始化
this.split = (FileSplit) split;
conf = context.getConfiguration();
reader.initialize(split, context);
}

@Override
public boolean nextKeyValue() throws IOException, InterruptedException {
// 读取一个一个的文件
if (!isProcess) {
String result = "";
while(reader.nextKeyValue()){
result += reader.getCurrentValue() + "\n";
}
value.set(result);
isProcess = true;
return true;
}
return false;
}

@Override
public NullWritable getCurrentKey() throws IOException, InterruptedException {
return NullWritable.get();
}

@Override
public Text getCurrentValue() throws IOException, InterruptedException {
return value;
}

@Override
public float getProgress() throws IOException, InterruptedException {
// 获取当前进度
return reader.getProgress();
}

@Override
public void close() throws IOException {
}
}

MapReduce 任务

设计的 Map 任务程序 PredictMapper.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
package Predition;

import java.io.IOException;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

public class PredictMapper extends Mapper<NullWritable, Text, Text, Text>{
Text k = new Text();
@Override
protected void setup(Mapper<NullWritable, Text, Text, Text>.Context context)
throws IOException, InterruptedException {
// 获取文件的路径和名称(类名)
FileSplit split = (FileSplit) context.getInputSplit();

Path path = split.getPath();
k.set(path.getName()+"&"+path.getParent().getName());
}

@Override
protected void map(NullWritable key, Text value, Context context)
throws IOException, InterruptedException {
Text result = new Text();
String[] CLASS_NAMES = {"CHINA","CANA"};
for(String classname:CLASS_NAMES) {
result.set(classname+"&"+Double.toString(Prediction.conditionalProbabilityForClass(value.toString(),classname)));
context.write(k,result);
}
}
}

设计的 Reduce 任务程序 PredictReducer.java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import java.io.IOException;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class PredictReducer extends Reducer<Text, Text, Text, Text>{

@Override
protected void reduce(Text key, Iterable<Text> values, Context context)
throws IOException, InterruptedException {
Text value = new Text();
String max_classname = "";
double max_prob = Double.NEGATIVE_INFINITY;
for(Text text:values) {
String[] args = text.toString().split("&");
if(Double.valueOf(args[1]) > max_prob){
max_prob = Double.valueOf(args[1]);
max_classname = args[0];
}
}
value.set(max_classname);
context.write(key, value);
}
}

main 函数入口

main() 写在 PredictDriver.java 中,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package Predition;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class PredictDriver {

public static void main(String[] args) throws NumberFormatException, IOException, ClassNotFoundException, InterruptedException {

args = new String[] {"e:/INPUT/TEST", "e:/z_output_class"};

// 1 获取job信息
Prediction prediction = new Prediction();
Configuration conf = new Configuration();
Job job = Job.getInstance(conf, "prediction");

// 2 获取jar包位置
job.setJarByClass(Prediction.class);

// 3 关联自定义的mapper和reducer
job.setMapperClass(PredictMapper.class);
job.setReducerClass(PredictReducer.class);

// 4 设置自定义的InputFormat类
job.setInputFormatClass(PredictTestInputFormat.class);

// 5 设置map输出数据类型
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(Text.class);

// 6 设置最终输出数据类型
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Text.class);

// 7 设置输入和输出文件路径
ArrayList<Path> paths = GetPaths(args[0]);
for(int i=0; i < paths.size(); i++) {
FileInputFormat.addInputPath(job, paths.get(i));
}
FileOutputFormat.setOutputPath(job, new Path(args[1]));

// 8 提交代码
boolean result = job.waitForCompletion(true);
System.exit(result?0:1);
}

private static ArrayList<Path> GetPaths(String path) {
// 获取path路径下所有子文件夹路径
ArrayList<Path> paths = new ArrayList<Path>();
File file = new File(path);
// 如果这个路径是文件夹
if (file.isDirectory()) {
// 获取路径下的所有文件
File[] files = file.listFiles();
for (int i=0; i<files.length; i++) {
// 如果还是文件夹
if (files[i].isDirectory()) {
// 将其加入路径列表
paths.add(new Path(files[i].getPath()));
}
else {continue;}
}
}
return paths;
}
}

程序运行结果

程序运行过程如图所示:

10_Predict_1

程序运行结束如图所示:

11_Predict_2

程序输出文件如图所示:

12_Predict_3

程序输出数据内容如下:

13_Predict_4

评估

任务说明

  1. 根据 MapReduce 任务三 的预测结果,计算每一个类别的精确率与召回率;
  2. 类别超过一个,评估算法对于所有类别的精确率与召回率的宏平均与微平均指标。

代码目录

Predition

  • Evaluation.java: 该类主要计算预测的评估指标

预测代码实现

Evaluation.java 代码内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package Predition;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.Hashtable;

public class Evaluation {

public static String[] CLASS_NAMES = {"CHINA","CANA"};
//tp,tn,fp,fn,p,r,f1 7个数据结构
public static Hashtable<String,Integer> TP = new Hashtable<String,Integer>();
public static Hashtable<String,Integer> TN = new Hashtable<String,Integer>();
public static Hashtable<String,Integer> FP = new Hashtable<String,Integer>();
public static Hashtable<String,Integer> FN = new Hashtable<String,Integer>();
public static Hashtable<String,Double> P = new Hashtable<String,Double>();
public static Hashtable<String,Double> R = new Hashtable<String,Double>();
public static Hashtable<String,Double> F1 = new Hashtable<String,Double>();

public static void main(String[] args) throws IOException {
calculatePrecision();
for(String classname:CLASS_NAMES) {
double p=0,r=0,f1=0,tp=0,fp=0,fn=0;
tp = TP.getOrDefault(classname,0);
fp = FP.getOrDefault(classname,0);
fn = FN.getOrDefault(classname,0);
System.out.println(tp);
System.out.println(fp);
System.out.println(fn);
p=tp/(tp+fp);
r=tp/(tp+fn);
f1=2*p*r/(p+r);
P.put(classname,p);
R.put(classname,r);
F1.put(classname,f1);
System.out.println(String.format("%s precision: %f----recall: %f----f1:%f "
,classname,p,r,f1));
}
printMicroAverage();
printMacroAverage();
}

private static void printMicroAverage() {
// 计算微平均
double sumP=0,sumR=0,sumF1=0,length = CLASS_NAMES.length;
for(String classname:CLASS_NAMES){
sumP += P.get(classname);
sumR += R.get(classname);
sumF1 += F1.get(classname);
}
System.out.println(String.format(
"all classes micro average P: %f",sumP/length));
System.out.println(String.format(
"all classes micro average R: %f",sumR/length));
System.out.println(String.format(
"all classes micro average F1: %f",sumF1/length));
}

private static void printMacroAverage() {
// 计算宏平均
double tp=0,fp=0,fn=0;
double p=0,r=0,f1=0;
for(String classname:CLASS_NAMES){
tp += TP.get(classname);
fp += FP.getOrDefault(classname,0);
fn += FN.getOrDefault(classname,0);
}
p = tp/(tp+fp);
r = tp/(tp+fn);
f1 = 2*p*r/(p+r);
System.out.println(String.format(
"all classes macro average P: %f",p));
System.out.println(String.format(
"all classes macro average R: %f",r));
System.out.println(String.format(
"all classes macro average F1: %f",f1));
}

public static void calculatePrecision() throws IOException {
// 读取预测结果,计算TP,FP,FN,TN值并存入hash表
BufferedReader reader=new BufferedReader(
new FileReader("e://z_output_class//"+"part-r-00000"));
while(reader.ready()){
String line=reader.readLine();
String[] args = line.split("\t");
String[] args1 = args[0].split("&");
String truth = args1[1];
String predict = args[1];
for(String classname:CLASS_NAMES) {
if(truth.equals(classname) && predict.equals(classname)) {
TP.put(classname,TP.getOrDefault(classname,0)+1);
}else if(truth.equals(classname) && !predict.equals(classname)) {
FN.put(classname,FN.getOrDefault(classname,0)+1);
}else if(!truth.equals(classname) && predict.equals(classname)) {
FP.put(classname,FP.getOrDefault(classname,0)+1);
}else if(!truth.equals(classname) && !predict.equals(classname)) {
TN.put(classname,TN.getOrDefault(classname,0)+1);
}
}
}
}
}

运行结果

运行结果如下图所示:

14_Evalution

对上表进行整理,结果如下:

CHINA Yes(Ground Truth) No(Ground Truth)
Yes(Classified) 77 14
No(Classified) 0 65

CHINA 类的精确率为 0.8461CHINA 类的召回率为 1.00F1值为 0.9167

CANA Yes(Ground Truth) No(Ground Truth)
Yes(Classified) 65 0
No(Classified) 14 77

CANA 类的精确率为 1.00CANA 类的召回率为 0.8228F1值为 0.9028

微平均的计算结果为

P: 0.923077
R: 0.911392
F1: 0.909722

宏平均的计算结果为

P: 0.910256
R: 0.910256
F1: 0.910256

坚持原创技术分享,您的支持将鼓励我继续创作!