1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
| import numpy as np import math import random
from matplotlib import pyplot as plt
class EM:
    """Expectation-maximisation fit of a two-component 1-D Gaussian mixture.

    The constructor arguments describe the mixture used to *generate* a
    synthetic sample set; ``train`` then re-estimates those parameters with
    EM and records every iteration so that ``draw`` can plot the trajectory.
    """

    def __init__(self, phi_1, phi_2, miu1, miu2, sigma1, sigma2, dataSize, max_iter):
        """Store the generating parameters and create empty history lists.

        :param phi_1: mixing weight of Gaussian 1 (probability that the latent
            variable selects component 1)
        :param phi_2: mixing weight of Gaussian 2
        :param miu1: mean of Gaussian 1
        :param miu2: mean of Gaussian 2
        :param sigma1: standard deviation of Gaussian 1 (it is passed as the
            ``scale`` of ``np.random.normal`` and used as sigma in the density)
        :param sigma2: standard deviation of Gaussian 2
        :param dataSize: total number of samples to generate
        :param max_iter: number of EM iterations performed by ``train``
        """
        self.phi_1 = phi_1
        self.phi_2 = phi_2
        self.miu1 = miu1
        self.miu2 = miu2
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.dataSize = dataSize
        self.max_iter = max_iter

        # Per-iteration parameter values, appended by train() and read by draw().
        self.phi_1set = []
        self.phi_2set = []
        self.miu1set = []
        self.miu2set = []
        self.sigma1set = []
        self.sigma2set = []

    def creat_gauss_dist(self):
        """Generate a shuffled sample set drawn from the two Gaussians.

        Each component contributes ``int(dataSize * phi_k)`` samples.

        :return: list of samples from both components in random order
        """
        data1 = np.random.normal(self.miu1, self.sigma1, int(self.dataSize * self.phi_1))
        data2 = np.random.normal(self.miu2, self.sigma2, int(self.dataSize * self.phi_2))
        dataset = []
        dataset.extend(data1)
        dataset.extend(data2)
        random.shuffle(dataset)
        return dataset

    def calculate_gauss(self, dataset, miu, sigma):
        """Evaluate the normal probability density at every sample.

        :param dataset: sample value(s); NumPy arrays are handled element-wise
        :param miu: mean of the Gaussian
        :param sigma: standard deviation of the Gaussian
        :return: density value(s) with the same shape as ``dataset``
        """
        normaliser = 1 / (math.sqrt(2 * math.pi) * sigma)
        return normaliser * np.exp(-1 * (dataset - miu) * (dataset - miu) / (2 * sigma ** 2))

    def E_step(self, dataset, phi_1, phi_2, miu1, miu2, sigma1, sigma2):
        """E-step: compute the posterior responsibility of each component.

        :return: ``(q1, q2)`` where ``q_k[i]`` is the probability that sample
            ``i`` belongs to component ``k``; ``q1 + q2`` is 1 element-wise
        """
        q1_numerator = phi_1 * self.calculate_gauss(dataset, miu1, sigma1)
        q2_numerator = phi_2 * self.calculate_gauss(dataset, miu2, sigma2)

        q_denominator = q1_numerator + q2_numerator

        q1 = q1_numerator / q_denominator
        q2 = q2_numerator / q_denominator
        return q1, q2

    def M_step(self, dataset, miu1, miu2, q1, q2):
        """M-step: maximum-likelihood update of the mixture parameters.

        :param dataset: array of samples
        :param miu1: previous mean of component 1; accepted for backward
            compatibility but not used, because the spread is measured around
            the updated mean
        :param miu2: previous mean of component 2 (unused, see ``miu1``)
        :param q1: responsibilities of component 1 from the E-step
        :param q2: responsibilities of component 2 from the E-step
        :return: ``(miu_new_1, miu_new_2, sigma_new_1, sigma_new_2,
            phi_new_1, phi_new_2)``
        """
        # Effective number of samples assigned to each component.
        nk1 = np.sum(q1)
        nk2 = np.sum(q2)

        phi_new_1 = nk1 / len(q1)
        phi_new_2 = nk2 / len(q2)

        miu_new_1 = np.dot(q1, dataset) / nk1
        miu_new_2 = np.dot(q2, dataset) / nk2

        # The weighted spread must be taken around the *updated* means;
        # using the previous means does not give the likelihood maximiser.
        sigma_new_1 = math.sqrt(np.dot(q1, (dataset - miu_new_1) ** 2) / nk1)
        sigma_new_2 = math.sqrt(np.dot(q2, (dataset - miu_new_2) ** 2) / nk2)

        return miu_new_1, miu_new_2, sigma_new_1, sigma_new_2, phi_new_1, phi_new_2

    def train(self):
        """Generate a sample set and run ``max_iter`` EM iterations on it.

        The parameter values in effect at the start of every iteration are
        appended to the ``*set`` history lists.

        :return: ``(miu1, miu2, sigma1, sigma2, phi_1, phi_2)`` after the
            final iteration
        """
        dataset = np.array(self.creat_gauss_dist())

        step = 0

        phi_1 = self.phi_1
        phi_2 = self.phi_2

        # The starting means and sigmas are taken from the *other* component's
        # generating values, while the weights are not crossed.
        # NOTE(review): presumably a deliberately poor initial guess - confirm.
        miu1 = self.miu2
        miu2 = self.miu1

        sigma1 = self.sigma2
        sigma2 = self.sigma1

        while step < self.max_iter:
            self.phi_1set.append(phi_1)
            self.phi_2set.append(phi_2)
            self.miu1set.append(miu1)
            self.miu2set.append(miu2)
            self.sigma1set.append(sigma1)
            self.sigma2set.append(sigma2)

            q1, q2 = self.E_step(dataset, phi_1=phi_1, phi_2=phi_2, miu1=miu1, miu2=miu2,
                                 sigma1=sigma1, sigma2=sigma2)
            miu1, miu2, sigma1, sigma2, phi_1, phi_2 = self.M_step(dataset, miu1, miu2, q1, q2)
            step += 1

        return miu1, miu2, sigma1, sigma2, phi_1, phi_2

    def draw(self):
        """Plot the recorded mean, sigma and weight trajectories."""
        # Size the x-axis from the recorded history so the plot stays valid
        # even if train() was not called exactly once.
        x_data = np.arange(len(self.miu1set))

        plt.figure()
        plt.plot(x_data, self.miu1set, color="r", label='miu1', linestyle="solid")
        plt.plot(x_data, self.miu2set, color="b", label='miu2', linestyle="solid")
        plt.title("miu Curve", fontsize=10)
        plt.xlabel('Iteration')
        plt.ylabel('miu')
        plt.legend()

        plt.figure()
        plt.plot(x_data, self.sigma1set, color="r", label='sigma1', linestyle="solid")
        plt.plot(x_data, self.sigma2set, color="b", label='sigma2', linestyle="solid")
        plt.title("sigma Curve", fontsize=10)
        plt.xlabel('Iteration')
        plt.ylabel('sigma')
        plt.legend()

        plt.figure()
        plt.plot(x_data, self.phi_1set, color="r", label='phi_1', linestyle="solid")
        plt.plot(x_data, self.phi_2set, color="b", label='phi_2', linestyle="solid")
        plt.title("phi Curve", fontsize=10)
        plt.xlabel('Iteration')
        plt.ylabel('phi')
        plt.legend()

        plt.show()
if __name__ == '__main__':
    # Parameters of the mixture used to synthesise the samples.
    phi_1, phi_2 = 0.25, 0.75
    miu1, miu2 = 100, 200
    sigma1, sigma2 = 8, 5

    generating = (phi_1, miu1, sigma1, phi_2, miu2, sigma2)
    print('数据集参数:phi_1:%.2f, miu1:%.1f, sigma1:%.1f, phi_2:%.2f, miu2:%.1f, sigma2:%.1f' % generating)

    # Build the model from the same values that were just reported.
    em = EM(
        phi_1=phi_1,
        phi_2=phi_2,
        miu1=miu1,
        miu2=miu2,
        sigma1=sigma1,
        sigma2=sigma2,
        dataSize=1000,
        max_iter=20,
    )

    # Fit, report the estimates in the same layout, then plot the history.
    miu1, miu2, sigma1, sigma2, phi_1, phi_2 = em.train()
    fitted = (phi_1, miu1, sigma1, phi_2, miu2, sigma2)
    print('拟合参数:phi_1:%.2f, miu1:%.1f, sigma1:%.1f, phi_2:%.2f, miu2:%.1f, sigma2:%.1f' % fitted)

    em.draw()
|