Apriori算法实现
前言:出自于学校课程数据挖掘与分析布置的实验小作业,案例经典,代码注释较全,供大家参考。题目:文件dataset.txt 中包含某超市的购物篮数据,编程实现关联规则,发现其中的主要规则,并提出提高销售额的方法。实验数据如下:要求:1、自行采用一种语言编程实现算法(注意:生成候选项集、生成频繁项集、产生关联规则等核心算法需自己编程实现)2、用课堂例子进行正确性检验3、用户界面友好,要考虑到输入输出4
·
前言:出自于学校课程数据挖掘与分析布置的实验小作业,案例经典,代码注释较全,供大家参考。
题目:
文件dataset.txt 中包含某超市的购物篮数据,编程实现关联规则,发现其中的主要规则,并提出提高销售额的方
法。实验数据如下:
要求:
1、自行采用一种语言编程实现算法(注意:生成候选项集、生成频繁项集、产生关联规则等核心算法需自己编程实现)
2、用课堂例子进行正确性检验
3、用户界面友好,要考虑到输入输出
4、分析结果,给出合理解释或建议
python实现
import numpy as np
import pandas as pd
# DMA.apriori DMA文件下的apriori是自己写的apriori方法
from DMA.apriori import get_sup, get_conf, str_to_list
import tkinter as tk
from tkinter import filedialog
import easygui as gui
# 开启选择文件窗口
# root = tk.Tk()
# root.withdraw()
# 获取选择好的文件
# select_path = filedialog.askopenfilename()
# print(select_path)
# 第一次扫描
df = pd.read_csv("dataset.txt", encoding="gbk", sep="\t")
# 获取数据列名
columns = df.columns[1:df.shape[1]]
# 处理后的数据,只有数据“项”
goods_set = np.array(df[columns])
# 清理“项”中的数据
item_list = []
for item in goods_set:
# 除去空格,并且将字符串以逗号将各个字符隔开顺序存入数组中
item = sorted(item[0].replace(" ", "").split(","))
# 将处理后的每行数据添加到项集数组中
item_list.append(item)
sup_result_dic = get_sup(item_list, 2)
print(sup_result_dic)
# gui.msgbox(msg=sup_result_dic, title="频繁候选集", ok_button="确认")
sup_result_list = str_to_list(list(sup_result_dic))
conf_result_dic = get_conf(sup_result_list, 0.2)
conf_result = []
if len(conf_result_dic) != 0:
conf_result_list = str_to_list(list(conf_result_dic))
for i in sup_result_list:
for j in conf_result_list:
str_ = str(j) + " => " + \
str(sorted(set(i) - {j})) + \
" 置信度confidence -> " + \
str(conf_result_dic[str(j)])
conf_result.append(str_)
else:
print("出错了")
for res in conf_result:
print(res)
apriori.py
# Author: KamTang
# Date: December 15, 2021
"""
关联规则:Apriori算法
从符合要求的项集中又递归生成符合要求的多项集,当生成的多项集为空,返回上一多项集,
并将上一多项集作为最后剪枝的结果。
核心函数包括disposition_data、get_sup和scanning_item
其作用分别是:
disposition_data():处理数据,并且将有用的数据提取作为全局变量
get_sup():获取符合支持率不低于用户输入的最小支持率的项集集合,结果以字典形式返回;
例如:{'a':0.3, 'b':0.5, 'c':0.3}
scanning_item():扫描项集,从而由get_sup筛选出的项集生成新的项集。
例如:由{'a':0.3, 'b':0.5, 'c':0.3}生成[['a', 'b'], ['a', 'c'], ['b', 'c']]
"""
import ast
# 当前筛选项集下标——0-单项集、1-二项集、2-三项集...
temp_index = 0
# 数据总行数
total_num_g = 0
# 项集临时存放;格式[[单项集], [二项集], [三项集], [四项集], ...]
temp_g = []
# 源数据中包含元素最多的项集长度
max_length_g = 0
# 单项集,第一次调用get_sup函数时需要
single_list_g = []
# 筛选后的项集
result_sup = {}
# 置信率
# min_conf = 0.5
def disposition_data(data_list):
"""
将用户选择的数据经过处理后获取相关数据,并将需要的数据保存为全局变量以便后面使用
Parameters
----------
data_list: list
源数据
Returns
-------
item_set: list
返回下一项集
"""
global total_num_g, temp_g, max_length_g, single_list_g
for it in data_list:
# 获取每行数据中最长长度max_length
max_length_g = max_length_g if len(it) < max_length_g else len(it)
# 数据总行数 10
total_num_g = len(data_list)
i = 0
# 数据中长度最长为4,我们创建4个临时项集格式[[], [], [], []]
while i < max_length_g:
temp_ = []
temp_g.append(temp_)
i = i + 1
for i in data_list:
for j in range(len(i)):
temp_g[j].append(i)
for data in temp_g[0]:
for element in data:
# 将数据中每行的项存入单项集singleSet中
single_list_g.append(element)
"""
set函数详情见:https://docs.python.org/2/tutorial/datastructures.html#sets
"""
# 去除重复元素,并且按照字母顺序排序
single_list_g = sorted(set(single_list_g))
def get_sup(data_list, min_sup=None, item_list=None):
"""
获取各个项的支持率
Parameters
----------
data_list: list
需要进行关联规则的数据,数组格式。
min_sup: int
最低支持率,默认值为0
item_list: list
项集
Returns
-------
以项集为key,支持率为value的字典形式返回
"""
# 存放单项和对应的支持率
item_sup_dic = {}
# 源数据最大项集
global temp_index
# 筛选后的项集
global result_sup
# 这里只需执行一次
if temp_index == 0 and item_list is None:
# 处理数据
disposition_data(data_list)
item_list = single_list_g
data_list = temp_g
for single in item_list:
# 单项在每个项集的个数初始为0
num = 0
for item in data_list[temp_index]:
# 单项集中的元素是字符串格式,直接使用in判断
if isinstance(single, str) and single in item:
num = 1 + num
# 针对多项集(元素是集合)的生成,判断是否具有包含关系
elif set(item).issuperset(set(single)):
num = 1 + num
# 排除支持率在low_sup下的
if min_sup is None or min_sup == 0:
item_sup_dic[str(single)] = num
elif num >= min_sup:
item_sup_dic[str(single)] = num
elif isinstance(min_sup, int) is False or min_sup < 0:
raise ValueError("你应该使用正整数")
# TODO 删除previous_item_sup
# 将有效结果结果保存到previous_item_sup
previous_item_sup = item_sup_dic
print("item_sup_dic => ", item_sup_dic)
print("previous_item_sup => ", previous_item_sup)
# 获取下一项集
if len(item_sup_dic.keys()) != 0:
# 从符合支持率在min_sup之上的项集中递归获取下一项集
item_list = scanning_item(list(item_sup_dic))
# min_sup=0 => IndexError: list index out of range
if len(item_list) > 1 and len(previous_item_sup.keys()) != 0 and temp_index < max_length_g - 1:
# 排除n项集时不需一直对原数据扫描,n项集只需与n项集比较
temp_index += 1
# 递归获取筛选出符合支持率在min_sup之上的项集
item_sup_dic = get_sup(data_list, min_sup, item_list)
else:
result_sup = previous_item_sup
return previous_item_sup if len(item_sup_dic.keys()) == 0 else item_sup_dic
def scanning_item(item_):
"""
根据传入的项集,生成下一项集
Parameters
----------
item_: list
项集,list格式
Returns
-------
item_set: list
返回下一项集
"""
item_from_sup = str_to_list(item_)
# 项集,最后作为结果返回
item_set = []
# 外层循环次数
outside_index = 0
# 获取单项集的笛卡尔:ABC*ABC形式
for single1 in item_from_sup:
# 控制while循环,这里的AB = BA,保留AB
single2 = outside_index
# 由于单项集中的元素格式是字符串格式,所以这里需要与多项集处理方法分开
if isinstance(single1, str):
while single2 < len(item_from_sup):
# 去除AA的情况
if single1 != item_from_sup[single2]:
item_set.append([single1] + [item_from_sup[single2]])
single2 = single2 + 1
# 多项集 -> 多项集,多项集里面的元素是以list形式存放
else:
while single2 < len(item_from_sup) - 1:
# 如果是['a', 'b']遇到['a', 'c']的情况,并且对称差集symmetric_difference要属于上一级项集
# 如果遇到['a', 'b', 'c', 'd']与['a', 'b', 'e', 'f']的情况下面的方法失效
# symmetric_difference = sorted(set(single1).symmetric_difference(item_from_sup[single2 + 1]))
symmetric_difference = sorted(set(single1[1:]).union(set(item_from_sup[single2 + 1][1:])))
# 根据项集中的第一个项是否相同并且两个项集的对称差集属于原项集(item_list)
if single1[0] == item_from_sup[single2 + 1][0] and symmetric_difference in item_from_sup:
# 取两个集合的并集并且按照字母先后排序,如果这里不排序,生成多项集时会出现空或者少项
item_set.append(sorted(set(single1).union(item_from_sup[single2 + 1])))
single2 = single2 + 1
else:
break
outside_index = outside_index + 1
return item_set
def str_to_list(dic):
# 将处理item_数据后的结果放入item_from_sup
item_from_sup = []
# 使用捕获异常来处理单项集中字符串无法转换成list的问题
try:
if isinstance(dic[0], str):
for i in dic:
# 没有异常,将字典中str类型的keys转化成list形式
item_from_sup.append(ast.literal_eval(i))
except ValueError:
# 如果有异常,直接使用传入的参数
item_from_sup = dic
return item_from_sup
def get_conf(lists_, min_conf=0.5):
sup_list = list(result_sup.values())
for l_index in range(len(lists_)):
return get_conf_main(lists_[l_index], min_conf, sup_list[l_index])
def get_conf_main(list_, min_conf, sup):
"""
计算置信度
:param list_:
:param min_conf:
:param sup:
:return:
"""
global temp_index
conf_dic = {}
list_01 = scanning_conf_item(list_)
for t in list_01:
num_c = 0
conf = 0
for tt in temp_g[temp_index - 1]:
if isinstance(t, str) and t in tt:
num_c = num_c + 1
conf = sup / num_c
elif set(tt).issuperset(set(t)):
num_c = num_c + 1
conf = sup / num_c
if conf >= min_conf:
conf_dic[str(t)] = conf
previous_item_conf = conf_dic
if temp_index > 1 and len(conf_dic) > 1:
temp_index = temp_index - 1
conf_dic = get_conf_main(list(conf_dic.keys()), min_conf, sup)
return previous_item_conf if len(conf_dic.keys()) == 0 else conf_dic
def scanning_conf_item(list_from_conf):
"""
获取满足置信度要求的项集
:param list_from_conf:
:return:
"""
list_01 = []
if len([list_from_conf]) > 1:
list_conf = str_to_list(list_from_conf)
else:
list_conf = list_from_conf
if isinstance(list_conf[0], str) and len([list_conf]) > 1:
for i in range(len(list_conf)):
# 筛选出符合置信度的项集
tri = sorted(set(list_conf) - {list_conf[i]})
list_01.append(tri)
elif isinstance(list_conf[0], list) and len([list_conf]) > 1:
for j in list_conf:
for k in range(len(j)):
# 筛选出符合置信度的项集
two = sorted(set(j) - {j[k]})
if two not in list_01 and len(two) != 0:
# 存入list_02数组中
list_01.append(two)
else:
list_01 = list_from_conf
return list_01
执行结果:
↓
↓
↓
↓
↓
↓
↓
↓
↓
↓
↓
↓
↓
↓
↓
别往下滑了,没有界面实现,太懒了。 =.=!
更多推荐
已为社区贡献3条内容
所有评论(0)