朴素贝叶斯的研究

参考

https://blog.csdn.net/fisherming/article/details/79509025
http://www.qlcoder.com/task/760e

公式

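贝叶斯定理：$P(A \mid B) = \dfrac{P(A)\,P(B \mid A)}{P(B)}$
条件概率：$P(A \mid B) = \dfrac{P(AB)}{P(B)}$

“朴素”是指假设在给定类别 $C$ 时各特征相互独立，即 $P(x_1,\dots,x_n \mid C) = \prod_i P(x_i \mid C)$。因此分类时只需比较各类别的 $P(C)\prod_i P(x_i \mid C)$，分母 $P(x_1,\dots,x_n)$ 对所有类别相同，可以省略。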

这两个公式我理解了好久,数学基本功太差

下面的代码在样本上能达到 81.9% 的正确率，没有做任何数据修正。很开心能通过考试，希望以后能慢慢把它调得更好。

import pandas as pd
import numpy as np
import re
# Full training data; the last column ("rich") is the class label.
# skipinitialspace=True strips a blank after each comma, in case the file uses
# ", " separators as the UCI release of this dataset does; it is a no-op otherwise.
df = pd.read_csv('adult.txt', header=None, skipinitialspace=True)
# Rows to predict.
df_test = pd.read_csv('adult_test.txt', header=None, skipinitialspace=True)

# Number of training rows.
data_count = len(df)

columns_en = ["age","work_type","college","marry","job","family_role","color","sex","in",
"out","work_time","country","rich"]
columns_cn = ["年龄","工作类型","教育程度","婚姻状态","职业","家庭角色","种族","性别","资本收益",
"资本损失","工作时长","原国家","年收入"]
df.columns = columns_en
# Name the test columns as well. The test file presumably lacks the label
# column, so use only as many names as it actually has columns.
df_test.columns = columns_en[:len(df_test.columns)]

# Possible values of every column. Numeric columns (see number_range_list) are
# listed as bucket labels: 'lo-hi' (inclusive), '==n', '>n' or '<n'.
# NOTE(review): the age buckets '24-25' and '25-26' overlap, so age 25 is counted
# in both while building the tables but is only matched to '24-25' when
# predicting. Confirm the intended boundaries.
columns = {
"age":['0-18','19-20','21-22','23-23','24-25','25-26','27-29','30-32','33-35','36-39','40-43','44-47','48-51','52-60','61-100'],
"work_type":['Private','Self-emp-not-inc','Self-emp-inc','Federal-gov','Local-gov','State-gov','Without-pay','Never-worked'],
"college":['Bachelors','Some-college','11th','HS-grad','Prof-school','Assoc-acdm','Assoc-voc','9th','7th-8th','12th','Masters','1st-4th','10th','Doctorate','5th-6th','Preschool'],
"marry":['Married-civ-spouse','Divorced','Never-married','Separated','Widowed','Married-spouse-absent','Married-AF-spouse'],
"job":['Tech-support','Craft-repair','Other-service','Sales','Exec-managerial','Prof-specialty','Handlers-cleaners','Machine-op-inspct','Adm-clerical','Farming-fishing','Transport-moving','Priv-house-serv','Protective-serv','Armed-Forces'],
"family_role":['Wife','Own-child','Husband','Not-in-family','Other-relative','Unmarried'],
"color":['White','Asian-Pac-Islander','Amer-Indian-Eskimo','Other','Black'],
"sex":['Female','Male'],
"in":['==0','1-1000','1001-2000','2001-3000','3001-4000','4001-5000','>5000'],
"out":['==0','1-1000','1001-2000','2001-3000','3001-4000','4001-5000','>5000'],
"work_time":['==0','1-10','11-20','21-30','31-39','==40','41-50','51-60','61-99'],
"country":['United-States','Cambodia','England','Puerto-Rico','Canada','Germany','Outlying-US','Guam-USVI-etc','India','Japan','Greece','South','China','Cuba','Iran','Honduras','Philippines','Italy','Poland','Jamaica','Vietnam','Mexico','Portugal','Ireland','France','Dominican-Republic','Laos','Ecuador','Taiwan','Haiti','Columbia','Hungary','Guatemala','Nicaragua','Scotland','Thailand','Yugoslavia','El-Salvador','Trinadad&Tobago','Peru','Hong','Holand-Netherlands'],
"rich":[1,0],
}

# columns_sum[flag][feature_key] = P(feature_key | rich == flag).
# feature_key is the category string for enumerated columns and
# "<column>_<bucket label>" for numeric columns.
columns_sum = {}

# Columns whose values are numbers and are bucketed into the ranges in `columns`.
number_range_list = ["age","in","out","work_time"]

# Parses a bucket label: group 1 is '>', '<', '==' or the lower bound,
# group 2 is the (upper) bound. Raw strings avoid invalid-escape warnings.
reg_renge = r'(^>|^<|^==|^\d+)-?(\d+)'

reg_renge_1 = r'^(\d+)-(\d+)$'
reg_renge_2 = r'^==(\d+)$'
reg_renge_3 = r'^>(\d+)$'
reg_renge_4 = r'^<(\d+)$'

# Count every bucket / category per class.
for flag in columns['rich']:
    columns_sum[flag] = {}
    df_flag = df[df['rich'] == flag]  # the class subset, filtered once
    fenmu = len(df_flag)
    if fenmu == 0:
        raise ValueError('no training rows with rich == {!r}'.format(flag))
    for key, values in columns.items():
        if key == 'rich':
            # The label itself is not a feature; tabulating it would make the
            # label leak into predictions if a full row were classified.
            continue
        if key in number_range_list:
            # Numeric column: '?' or other non-numbers become NaN and match no bucket.
            series = pd.to_numeric(df_flag[key], errors='coerce')
            for range_text in values:
                m = re.match(reg_renge, range_text)
                if m is None:
                    continue
                op, bound = m.group(1), int(m.group(2))
                if op == '==':
                    mask = series == bound
                elif op == '>':
                    mask = series > bound
                elif op == '<':
                    mask = series < bound
                else:
                    # 'lo-hi' bucket, both ends inclusive.
                    mask = (int(op) <= series) & (series <= bound)
                columns_sum[flag][key + '_' + range_text] = int(mask.sum()) / fenmu
        else:
            # Enumerated column: relative frequency of each category in this class.
            counts = df_flag[key].value_counts()
            for item in values:
                columns_sum[flag][item] = int(counts.get(item, 0)) / fenmu

# Class priors P(rich == 1) and P(rich == 0).
count_a = len(df[(df['rich'] == 1)])
count_b = len(df[(df['rich'] == 0)])
pa_value = count_a/(count_a + count_b)
pb_value = count_b/(count_a + count_b)


print(columns_sum)

print(pa_value,pb_value)

def _numeric_bin_key(key, value):
    """Return the bucket key ``key + '_' + range_text`` that *value* falls into.

    Buckets are the labels in ``columns[key]``, tried in order; the first match
    wins. Returns None when *value* is not an integer-like number or matches
    no bucket.
    """
    try:
        number = int(value)
    except (TypeError, ValueError):
        return None
    for range_text in columns[key]:
        m = re.match(reg_renge, range_text)
        if m is None:
            continue
        op, bound = m.group(1), int(m.group(2))
        if op == '==':
            hit = number == bound
        elif op == '>':
            hit = number > bound
        elif op == '<':
            hit = number < bound
        else:
            # 'lo-hi' bucket, both ends inclusive.
            hit = int(op) <= number <= bound
        if hit:
            return key + '_' + range_text
    return None


# Classify one row.
def calc_papb(columns_sum, row_data):
    """Predict the label of one row with the naive Bayes tables.

    Args:
        columns_sum: per-class tables keyed by 1 and 0, each mapping a feature
            key to P(feature | class).
        row_data: feature values in ``columns_en`` order, without the label.
            '?' marks a missing value. Missing values, unknown categories and
            numbers outside every bucket are skipped.

    Returns:
        '1' if P(rich=1) * prod P(x|rich=1) exceeds the same product for
        rich=0, else '0'. A tie, including both products being 0, yields '0'.
    """
    pa_val = 1
    pb_val = 0

    # Positions of the numeric columns in row_data.
    col_dict = {
        0: 'age',
        8: 'in',
        9: 'out',
        10: 'work_time',
    }

    res_a = 1
    res_b = 1
    for i, col in enumerate(row_data):
        if col == '?':
            continue
        if i in col_dict:
            col = _numeric_bin_key(col_dict[i], col)
            if col is None:
                continue
        elif col not in columns_sum[pa_val]:
            continue

        res_a *= columns_sum[pa_val][col]
        res_b *= columns_sum[pb_val][col]

    res_a *= pa_value
    res_b *= pb_value

    # Comparing the unnormalised scores is equivalent to comparing the
    # posteriors, and avoids dividing by zero when both products are 0.
    return '1' if res_a > res_b else '0'

# Optional checks, disabled. Classify the first test row:
#
# df_list = np.array(df_test[:1]).tolist()[0]
# print(df_list)
# print(calc_papb(columns_sum, df_list))
#
# Self-check on the first 1000 training rows; each row is tagged 成功 (correct)
# or 失败 (wrong), and correct rows are written first:
#
# ok_rows, bad_rows = [], []
# for row in np.array(df[:1000]).tolist():
#     line = ','.join(str(x) for x in row)
#     if calc_papb(columns_sum, row[:-1]) == str(row[-1]):
#         ok_rows.append(line + ',成功\n')
#     else:
#         bad_rows.append(line + ',失败\n')
# with open('d:/res.txt', 'w', encoding='utf-8') as fp:
#     fp.write(''.join(ok_rows + bad_rows))

# Predict every test row. The answer is one '0'/'1' character per row, with no
# separator; join() avoids the quadratic cost of repeated string +=.
result = ''.join(calc_papb(columns_sum, row) for row in np.array(df_test).tolist())

# print(result)

with open('d:/res.txt', 'w') as fp:
    fp.write(result)
最后修改:2019 年 01 月 10 日 04 : 33 PM

发表评论