The task in this lab is text classification, using the 20news_bydate dataset.
The experiment proceeds roughly as follows:
1. Algorithm Introduction
· The Naive Bayes algorithm is one of the most widely used classification algorithms.
· Given a training set X with class labels Y, Bayes' theorem gives:
P(Y | X) = P(X | Y) × P(Y) / P(X)
Here, P(Y|X) is the posterior probability, i.e. the probability that a test document X belongs to class Y. For a document Xi, this probability is computed for every class, giving P(Y1|Xi), P(Y2|Xi), ..., P(Yn|Xi); the class with the largest probability is taken as the prediction.

In the formula above:
· P(Y) is the prior probability of each class, computed as:
total number of words in the class / total number of words in the entire training set.

· P(X | Y) = P(x1, x2, ..., xn | Y), where the xi are the features of the document.
Naive Bayes "naively" assumes the features are mutually independent, so:
P(X | Y) = P(x1 | Y) × P(x2 | Y) × P(x3 | Y) × ... × P(xn | Y).
Here the features xi are simply all the words obtained after preprocessing and tokenizing the document.
Computation (taking P(xi | Yj) as an example):
number of occurrences of word xi across all documents of class Yj in the training set / total number of words contained in all documents of class Yj.

· P(X) is the same for every class, so it can be ignored.

To summarize, the main steps of the experiment are:
obtain the dataset -> preprocess it to get the tokenized corpus and bag of words -> compute P(Y) and P(X|Y) for each class ->
the class that maximizes P(Y) × P(X|Y) is the predicted class (a small toy sketch follows below).
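
Before moving to the real dataset, here is a minimal, self-contained sketch on a hypothetical two-class toy corpus (the words and classes below are made up purely for illustration). It shows how the word-count prior, the Laplace-smoothed likelihoods and the final argmax fit together:

from collections import Counter

train = [
    (["ball", "game", "team"], "sports"),
    (["game", "score", "team"], "sports"),
    (["vote", "party", "election"], "politics"),
]

classes = {"sports", "politics"}
vocab = {w for doc, _ in train for w in doc}

# per-class word counts and word totals
counts = {c: Counter() for c in classes}
totals = {c: 0 for c in classes}
for doc, c in train:
    counts[c].update(doc)
    totals[c] += len(doc)

# prior P(Y): words in the class / words in the whole training set
all_words = sum(totals.values())
prior = {c: totals[c] / all_words for c in classes}

def predict_toy(doc):
    scores = {}
    for c in classes:
        p = prior[c]
        for w in doc:
            # Laplace-smoothed P(w | c): (count + 1) / (total + |vocab|)
            p *= (counts[c][w] + 1) / (totals[c] + len(vocab))
        scores[c] = p
    return max(scores, key=scores.get)

print(predict_toy(["team", "game"]))   # expected: sports

The +1 in the numerator (Laplace smoothing) keeps a word that never occurs in a class from zeroing out the whole product; the implementation in Section 4 uses the same add-one idea.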

2. Obtaining the Dataset
The 20newsgroups dataset can be imported directly from the sklearn module:

from sklearn.datasets import fetch_20newsgroups  # import the loader
news_data = fetch_20newsgroups(subset="all")     # load the full dataset
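
For reference, the returned object is a scikit-learn Bunch; a quick way to inspect what it contains (these attribute names are part of the fetch_20newsgroups API):

print(len(news_data.data))      # number of raw article texts
print(news_data.target_names)   # the 20 category names
print(news_data.target[:5])     # integer class labels of the first five articles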

It can also be downloaded manually:
Link: https://pan.baidu.com/s/1YO-Je1lT_y-MbGRpSwHQSA
Extraction code: qwer

3. Data Preprocessing

import os
# string and nltk are used for text preprocessing
import string
import nltk
from nltk.corpus import stopwords
import pickle

class Textprocess():
    def __init__(self):
        # path of the raw corpus
        self.corpus = ''
        # path of the tokenized (segmented) corpus
        self.segment = ''
        # where the tokenized, deduplicated results will be stored
        self.word_list = ''
        self.label_list = ''
        # stores the raw tokenized training documents; comment the related
        # lines out when processing the test set
        self.ori_words_list = ''

    # Create the train_segment / test_segment folders under the original path
    # and store the preprocessed, tokenized results there
    def preprocess(self):
        mydir = os.listdir(self.corpus)
        for dir in mydir:
            create_dir = self.corpus + '/' + dir + '_segment'
            os.makedirs(create_dir)
            dir_path = self.corpus + '/' + dir
            news_list = os.listdir(dir_path)
            # one folder of documents per newsgroup category
            for news in news_list:
                path = create_dir + '/' + news
                os.makedirs(path)
                news_path = dir_path + '/' + news
                files = os.listdir(news_path)
                # each text file
                for file in files:
                    file_path = news_path + '/' + file
                    with open(file_path,'r',encoding='utf-8', errors='ignore') as f1:
                        content = f1.read()
                        clean_content = self.data_clean(content)
                        new_file_path = path + '/' + file
                        with open(new_file_path, 'w', encoding = 'utf-8', errors='ignore') as f2:
                            f2.write(clean_content)

    def data_clean(self, data):
        # convert to lowercase
        data1 = data.lower()
        # strip punctuation
        remove = str.maketrans('', '', string.punctuation)
        data2 = data1.translate(remove)
        # tokenize
        data3 = nltk.word_tokenize(data2)
        # drop stopwords, non-alphabetic tokens and overly long tokens
        # (the stopword list is cached as a set so the lookup is cheap per word)
        stop_words = set(stopwords.words('english'))
        data4 = [w for w in data3 if (w not in stop_words) and w.isalpha() and (len(w) < 15)]
        data_str = ' '.join(data4)
        return data_str

    def create_non_repeated_words(self):
        self.content_list = []
        self.labels_list = []
        # self.ori_list = []
        mydir = sorted(os.listdir(self.segment))
        label = 0
        for dir in mydir:
            dir_path = self.segment + '/' + dir
            files = sorted(os.listdir(dir_path))
            for file in files:
                file_path = dir_path + '/' + file
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    line = f.read()
                    line1 = line.strip('\n')
                    line2 = line1.split()
                    # store the raw tokenized document (comment out for the test set)
                    # self.ori_list.append(line2)
                    # deduplicate and drop words that occur only once in the document
                    line3 = []
                    once_word = []
                    for i in line2:
                        if i not in once_word:
                            once_word.append(i)
                        else:
                            if i not in line3:
                                line3.append(i)
                    self.content_list.append(line3)
                    self.labels_list.append(label)
            label += 1

        self.data_dump(self.word_list, self.content_list)
        self.data_dump(self.label_list, self.labels_list)
        # self.data_dump(self.ori_words_list, self.ori_list)

    def data_dump(self, path, data):
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    def data_load(self, path):
        with open(path, 'rb') as f:
            data = pickle.load(f)
        return data
    
text = Textprocess()
text.corpus = r'.\20news-bydate'
text.segment = r'.\20news-bydate\20news-bydate-train_segment'
text.word_list = 'train_words'
text.label_list = 'train_labels'
text.ori_words_list = 'original_bag'
text.preprocess()
text.create_non_repeated_words()


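# preprocess() is not called again for the test split: the call above already
# created both 20news-bydate-train_segment and 20news-bydate-test_segment
# under the corpus directory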
test = Textprocess()
test.corpus = r'.\20news-bydate'
test.segment = r'.\20news-bydate\20news-bydate-test_segment'
test.word_list = 'test_words'
test.label_list = 'test_labels'
test.create_non_repeated_words()
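
Note that data_clean relies on nltk.word_tokenize and the English stopword list; if the corresponding NLTK resources have not been downloaded yet, those calls raise a LookupError. A one-time download (standard NLTK resource names):

import nltk
nltk.download('punkt')      # tokenizer models used by nltk.word_tokenize
nltk.download('stopwords')  # stopword lists used by nltk.corpus.stopwords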

4. Implementing the Naive Bayes Algorithm

#encoding=utf-8
import textprocess_detail as tp   # the Textprocess class above, saved as textprocess_detail.py
import numpy as np
from sklearn import metrics

# load the pickles produced by the preprocessing step
ori_words = tp.Textprocess().data_load('original_bag')
train_labels = tp.Textprocess().data_load('train_labels')
test_words = tp.Textprocess().data_load('test_words')
test_labels = tp.Textprocess().data_load('test_labels')

# total number of words contained in each class
def words_sum():
    word_counts = [0 for i in range(20)]
    for i in range(len(ori_words)):
        count = len(ori_words[i])
        word_counts[train_labels[i]] += count
    return word_counts

# prior probability P(Y) of each class
def category_probability(word_counts):
    total = 0
    cp = []
    for i in word_counts:
        total += i
    for j in word_counts:
        cp.append(j / total)
    return cp

# p(x1|y) * p(x2|y) * ... * p(xn|y) * p(y)
def predict(word_counts, cp):
    scores = []
    for doc in range(len(test_words)):
        p_list = []
        for predict_label in range(20):
            p = 1
            word_sum = word_counts[predict_label]
            for word in test_words[doc]:
                count = 0
                # occurrences of the word in all training documents of this class
                for i in range(len(ori_words)):
                    if train_labels[i] == predict_label:
                        count += ori_words[i].count(word)
                # Laplace-smoothed estimate of p(word | class)
                p *= (count + 1) / (word_sum + 20)
            p *= cp[predict_label]
            p_list.append(p)
        scores.append(p_list)
    tp.Textprocess().data_dump('precision', scores)
    print(scores)
    return scores

count_list = words_sum()
cp = category_probability(count_list)
scores = predict(count_list, cp)
probability = tp.Textprocess().data_load('precision')
a = np.array(probability)
predictions = np.argmax(a, axis=1)

true_label = tp.Textprocess().data_load('test_labels')
accuracy = metrics.accuracy_score(true_label, predictions)
print("%.2f" % accuracy)

The final accuracy is around 0.7.
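
One caveat: multiplying hundreds of probabilities that are all far below 1 can underflow to 0.0 in floating point for long documents, which makes the argmax uninformative for those documents. A common remedy is to sum log-probabilities instead of multiplying; below is a sketch (not part of the original code) of how the scoring loop in predict() could be rewritten this way, assuming the same ori_words, train_labels and test_words as above and keeping the same smoothed estimate:

import math

# Sketch: log-space version of the scoring loop in predict()
def predict_log(word_counts, cp):
    scores = []
    for doc in range(len(test_words)):
        p_list = []
        for predict_label in range(20):
            log_p = math.log(cp[predict_label])
            word_sum = word_counts[predict_label]
            for word in test_words[doc]:
                count = 0
                for i in range(len(ori_words)):
                    if train_labels[i] == predict_label:
                        count += ori_words[i].count(word)
                # same Laplace-smoothed estimate, accumulated in log space
                log_p += math.log((count + 1) / (word_sum + 20))
            p_list.append(log_p)
        scores.append(p_list)
    return scores

Since the logarithm is monotonic, np.argmax over these scores selects the same classes wherever the original products did not underflow.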
