添加 k-means2
This commit is contained in:
parent
fc676b494d
commit
1eaa6c38b0
48
k-means2
Normal file
48
k-means2
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
plt.rcParams['figure.figsize'] = (16,9)
|
||||||
|
plt.style.use('ggplot')
|
||||||
|
data = pd.read_csv('xclara.csv')
|
||||||
|
f1 = data['V1'].values
|
||||||
|
f2 = data['V2'].values
|
||||||
|
X = np.array(list(zip(f1, f2)))
|
||||||
|
# 距离计算函数
|
||||||
|
def dist(a, b, ax=1):
|
||||||
|
return np.linalg.norm(a - b, axis=ax)
|
||||||
|
|
||||||
|
# 设置聚类数
|
||||||
|
k = 3
|
||||||
|
|
||||||
|
# 随机初始化质心(修正:使用数据范围)
|
||||||
|
C_x = np.random.randint(0,np.max(X)-20, size=k)
|
||||||
|
C_y = np.random.randint(0,np.max(X)-20, size=k)
|
||||||
|
C = np.array(list(zip(C_x, C_y)), dtype=np.float32)
|
||||||
|
|
||||||
|
C_old = np.zeros(C.shape)
|
||||||
|
print(C)
|
||||||
|
clusters = np.zeros(len(X))
|
||||||
|
iteration_flag = dist(C,C_old,1)
|
||||||
|
tmp = 1
|
||||||
|
while iteration_flag.any() != 0 and tmp<20:
|
||||||
|
for i in range(len(X)):
|
||||||
|
distances = dist(X[i],C,1)
|
||||||
|
clusters[i] = clusters
|
||||||
|
C_old = deepcopy(C)
|
||||||
|
for i in range(C):
|
||||||
|
points = [X[j] for j in range(len(X)) if clusters[j] == i]
|
||||||
|
C[i] = np.mean(points,axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
print('%d'%tmp)
|
||||||
|
tmp = tmp + 1
|
||||||
|
iteraction_flag = dist(C,C_old,1)
|
||||||
|
print('distance:',iteraction_flag)
|
||||||
|
colors = ['r','g','b','y','c','m']
|
||||||
|
fig,ax = plt.subplots()
|
||||||
|
for i in range(k):
|
||||||
|
points = np.array([X[j] for j in range(len(X) if clusters[j] == i)])
|
||||||
|
ax.scatter(points[:,0],points[:,1],s=7,c=colors[i])
|
||||||
|
ax.scatter(C[:,0],C[:,1],marker="*",s=200,c='black')
|
||||||
|
plt.show()
|
Loading…
x
Reference in New Issue
Block a user