k-means/K-means2
2025-06-14 19:55:33 +08:00

86 lines
2.4 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from copy import deepcopy
# Plot appearance: 16x9 default figure size, drawn with the 'ggplot' style sheet
plt.rcParams.update({'figure.figsize': (16, 9)})
plt.style.use('ggplot')
# Create sample data and save it as a CSV file
def create_sample_data(path='xclara.csv', seed=42):
    """Generate three 2-D Gaussian clusters and write them to a CSV file.

    Each cluster holds 100 points. The cluster centres are (0, 0), (10, 5)
    and (5, 10) with standard deviations 1, 1.5 and 1.2 respectively. The
    300 rows are written in cluster order under the columns ``V1`` and
    ``V2``, without an index column.

    Args:
        path: Destination of the CSV file. Defaults to ``'xclara.csv'``.
        seed: Value handed to ``np.random.seed``. NumPy's global generator
            is seeded (rather than a private one), so random draws made
            after this call are reproducible as well.

    Returns:
        None. The data is only written to ``path``.
    """
    np.random.seed(seed)
    # (centre, standard deviation) of each cluster, in output order
    cluster_specs = (
        ([0, 0], 1),
        ([10, 5], 1.5),
        ([5, 10], 1.2),
    )
    data = np.vstack([
        np.random.normal(loc=centre, scale=spread, size=(100, 2))
        for centre, spread in cluster_specs
    ])
    # Build the DataFrame and persist it as CSV
    df = pd.DataFrame(data, columns=['V1', 'V2'])
    df.to_csv(path, index=False)
# Write the sample CSV file, then read the data back from it
create_sample_data()
data = pd.read_csv('xclara.csv')
# Keep each feature column as its own array and pair them row by row
f1 = data['V1'].to_numpy()
f2 = data['V2'].to_numpy()
X = np.column_stack((f1, f2))
# Distance function
def dist(a, b, ax=1):
    """Return the Euclidean norm of ``a - b`` computed along axis ``ax``.

    Args:
        a: Array-like of coordinates; must broadcast against ``b``.
        b: Array-like of coordinates; must broadcast against ``a``.
        ax: Axis of the difference to reduce over. ``None`` returns a
            single scalar, the 2-norm of the flattened difference.

    Returns:
        A NumPy array of norms, or a scalar when ``ax`` is ``None``.
    """
    # asarray lets plain lists/tuples through; `list - list` would raise.
    diff = np.asarray(a) - np.asarray(b)
    return np.linalg.norm(diff, axis=ax)
# Number of clusters
k = 3
# Random starting centroids, drawn uniformly between the min and max of each feature
C_x = np.random.uniform(low=f1.min(), high=f1.max(), size=k)
C_y = np.random.uniform(low=f2.min(), high=f2.max(), size=k)
C = np.column_stack((C_x, C_y)).astype(np.float32)
# Show the raw data points (black dots) together with the starting centroids (red stars)
plt.scatter(f1, f2, s=7, c='black')
plt.scatter(C_x, C_y, s=200, c='red', marker='*')
plt.title("Initial Data Points and Centroids")
plt.show()
# ---- Full K-Means implementation ----
# Centroids of the previous iteration; starts as zeros, a placeholder for
# the first convergence check
C_old = np.zeros(C.shape)
# Index of the centroid each point is currently assigned to
clusters = np.zeros(len(X), dtype=int)
# Norm of the centroid movement since the previous iteration
error = dist(C, C_old, None)
# Upper bound on iterations: `error != 0` is an exact float comparison, so
# the cap guarantees termination even if the centroids never settle exactly
max_iter = 300
n_iter = 0
# K-Means iteration
while error != 0 and n_iter < max_iter:
    # Assign every point to its nearest centroid in one broadcast step;
    # assumes X is (n, d) and C is (k, d), giving an (n, k) distance table
    clusters = np.argmin(dist(X[:, np.newaxis, :], C, ax=2), axis=1)
    # Keep a copy of the centroids before they are moved
    C_old = deepcopy(C)
    # Move each centroid to the mean of the points assigned to it
    for i in range(k):
        members = X[clusters == i]
        # A centroid with no members stays where it is
        if len(members):
            C[i] = members.mean(axis=0)
    # Measure how far the centroids moved in this iteration
    error = dist(C, C_old, None)
    n_iter += 1
# Plot the final clustering result
colors = ['r', 'g', 'b', 'c', 'm', 'y']
fig, ax = plt.subplots()
for i in range(k):
    # Boolean-mask selection stays 2-D even when a cluster has no members,
    # so the column slices below cannot raise IndexError on an empty cluster
    points = X[clusters == i]
    # Wrap around the palette so more than len(colors) clusters still plot
    ax.scatter(points[:, 0], points[:, 1], s=7, c=colors[i % len(colors)])
# Final centroids drawn as black stars on top of the coloured points
ax.scatter(C[:, 0], C[:, 1], marker='*', s=200, c='black')
plt.title("Final Clustering Result")
plt.show()