563 lines
24 KiB
Python
563 lines
24 KiB
Python
import socket
|
||
import time
|
||
import pickle
|
||
import wx # pylint: disable=e0401 # type: ignore
|
||
import CoreAlgorithm
|
||
import threading
|
||
|
||
import DataSearch
|
||
import server_demo
|
||
|
||
sleep_time = 0.2
|
||
|
||
|
||
# 审批回执类消息head:01001001
|
||
# sql密文类消息head:01001010
|
||
# 元数据类消息head:01001100
|
||
|
||
|
||
# 用于生成本地公钥和私钥信息
|
||
def generate_key_pair_data():
|
||
public_key_data, private_key_data = CoreAlgorithm.generate_paillier_keypair()
|
||
return private_key_data, public_key_data
|
||
|
||
|
||
# 从公钥证书中提取公钥信息
|
||
def get_public_key_data(message):
|
||
message = eval(message.replace("\n", ""))
|
||
public_key = (
|
||
message["public_key"]
|
||
.replace("-----BEGIN PUBLIC KEY-----", "")
|
||
.replace("-----END PUBLIC KEY-----", "")
|
||
) # 分割得到公钥
|
||
public_key_bytes = bytes.fromhex(public_key)
|
||
public_key_data = pickle.loads(public_key_bytes)
|
||
return public_key_data
|
||
|
||
|
||
# 用公钥为本地生成数字证书
|
||
def get_certificate(temp_public_key, cert_name):
|
||
public_key_str = "\n".join(
|
||
temp_public_key[i : i + 60] for i in range(0, len(temp_public_key), 60)
|
||
)
|
||
pack_public_key = (
|
||
"-----BEGIN PUBLIC KEY-----\n" + public_key_str + "\n-----END PUBLIC KEY-----\n"
|
||
)
|
||
cert = {"public_key": pack_public_key, "name": cert_name}
|
||
return cert
|
||
|
||
|
||
# 加密字符串
|
||
def str_to_encrypt(message, public_data):
|
||
# str 转 int
|
||
if message.isdigit():
|
||
int_message = int(message)
|
||
else:
|
||
int_message = int.from_bytes(message.encode(), "big")
|
||
enc_message = public_data.encrypt(int_message)
|
||
print("int_message", int_message)
|
||
return enc_message
|
||
|
||
|
||
class MyServer(server_demo.MyFrame):
|
||
def __init__(self, parent):
|
||
server_demo.MyFrame.__init__(self, parent)
|
||
# 生成私钥和公钥信息
|
||
self.private_key_data, self.public_key_data = generate_key_pair_data()
|
||
# 生成私钥和公钥字符串
|
||
self.private_key, self.public_key = self.generate_key_pair()
|
||
# 获取数字证书
|
||
self.certificate = get_certificate(self.public_key, "安全多方服务器")
|
||
# 初始化当前sql
|
||
self.sql = ""
|
||
# 初始化sql拆分对象
|
||
self.divide_sqls = []
|
||
# 初始化sql拆分数据源
|
||
self.divide_providers = []
|
||
# 记录属于同一个请求的元数据密文
|
||
self.datas = []
|
||
# 初始化数据查询方的公钥证书为str,在发来过后为其赋值
|
||
self.search_cert = ""
|
||
# 初始化socket
|
||
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||
# 绑定IP地址和端口
|
||
self.server.bind(("localhost", 1111))
|
||
# 设置最大监听数
|
||
self.server.listen(5)
|
||
# 设置一个字典,用来保存每一个客户端的连接和身份信息
|
||
self.socket_mapping = {} # temp_socket: [addr, 公钥信息]
|
||
# 设置接收的最大字节数
|
||
self.maxSize = 4096
|
||
# 记录调用方地址
|
||
self.source = None
|
||
# 记录收集信息的数量
|
||
self.flag = 0
|
||
# 记录需要收集的信息总量
|
||
self.total = 0
|
||
# 保存安全多方计算结果
|
||
self.result = 0
|
||
self.out_look()
|
||
# 等待客户端连接
|
||
message_p = "等待客户端连接..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
# 添加消息记录列
|
||
def out_look(self):
|
||
self.m_listCtrl1.InsertColumn(0, "消息记录") # 为聊天框添加‘消息记录’列
|
||
self.m_listCtrl1.SetColumnWidth(0, 1000)
|
||
self.m_listCtrl2.InsertColumn(0, "消息记录") # 为聊天框添加‘消息记录’列
|
||
self.m_listCtrl2.SetColumnWidth(0, 1000)
|
||
self.m_listCtrl3.InsertColumn(0, "消息记录") # 为聊天框添加‘消息记录’列
|
||
self.m_listCtrl3.SetColumnWidth(0, 1000)
|
||
|
||
# 建立连接,设置线程
|
||
def run(self):
|
||
while True:
|
||
client_socket, addr = self.server.accept()
|
||
# 发送信息,提示客户端已成功连接
|
||
message_p = "与{0}连接成功!".format(addr)
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
# 将客户端socket等信息存入字典
|
||
self.socket_mapping[client_socket] = []
|
||
self.socket_mapping[client_socket].append(addr)
|
||
message_p = "安全多方计算平台:单向授权成功!"
|
||
client_socket.send(message_p.encode())
|
||
self.m_listCtrl2.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
# 创建线程,负责接收客户端信息并转发给其他客户端
|
||
threading.Thread(
|
||
target=self.recv_from_client, args=(client_socket,)
|
||
).start()
|
||
|
||
# 产生公私钥字符串
|
||
def generate_key_pair(self):
|
||
# Paillier 转 bytes
|
||
public_key_bytes = pickle.dumps(self.public_key_data)
|
||
private_key_bytes = pickle.dumps(self.private_key_data)
|
||
# bytes 转 str
|
||
public_key = public_key_bytes.hex()
|
||
private_key = private_key_bytes.hex()
|
||
return private_key, public_key
|
||
|
||
# 接收客户端消息并转发
|
||
def recv_from_client(self, client_socket): # client_socket指的是连接到的端口socket
|
||
while True:
|
||
message = client_socket.recv(self.maxSize).decode("utf-8")
|
||
message_p = "接收到来自{0}的message:\n{1}".format(
|
||
self.socket_mapping[client_socket][0], message
|
||
)
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p]) # 准备文件传输
|
||
if message.startswith("01001001"):
|
||
message_p = "正在解析消息内容..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
# 去掉审批回执的头部
|
||
message = message.split("01001001", 1)[
|
||
1
|
||
] # [0]是空字符串,已测试。只切一次是防止密文出现和头部一样的信息被误切。
|
||
# 认证发送者的合法性
|
||
sender = message.split("||")[0]
|
||
context = message.split("||")[1]
|
||
message_p = "接收到来自{0}的message:\n{1}".format(sender, context)
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
# time.sleep(sleep_time)
|
||
self.flag += 1
|
||
if context == "想得美哦!不给(*^▽^*)":
|
||
self.flag = -9999
|
||
elif message.startswith("01001010"):
|
||
message_p = "正在解析消息内容..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
# time.sleep(sleep_time)
|
||
# 认证发送者的合法性
|
||
if len(self.socket_mapping[client_socket]) > 1: # 如果发送者的公钥信息已被收集
|
||
# 识别明文中一起发送的发送目标 明文应是发送者||发送内容(||时间戳等),对象用socket表示吧...
|
||
message = message.split("01001010", 1)[
|
||
1
|
||
] # [0]是空字符串,已测试。只切一次是防止密文出现和头部一样的信息被误切。
|
||
# 用发送目标之前给的公钥加密明文,得到密文。
|
||
# 去掉sql密文的头部
|
||
# 使用平台私钥解密消息,获得sql明文
|
||
message_p = "接收到的sql密文:" + message
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
int_message = int(message)
|
||
# Ciphertext转EncryptedNumber
|
||
Encrpt_message = CoreAlgorithm.EncryptedNumber(
|
||
self.public_key_data, int_message, exponent=0
|
||
)
|
||
dec_int_message = self.private_key_data.decrypt(Encrpt_message)
|
||
dec_message = dec_int_message.to_bytes(
|
||
(dec_int_message.bit_length() + 7) // 8, "big"
|
||
).decode()
|
||
message_p = "解密后的消息为:" + dec_message
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
self.source = dec_message.split("||")[0]
|
||
self.sql = dec_message.split("||")[1]
|
||
message_p = "收到已授权方的sql:" + self.sql
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
message_p = "待处理的sql语句为:" + self.sql
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
message_p = "正在拆分sql..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
self.divide_providers = DataSearch.extract_tables(self.sql)
|
||
message_p = "涉及到的数据持有对象有:" + str(self.divide_providers)
|
||
print(message_p)
|
||
self.m_listCtrl3.Append(([message_p]))
|
||
self.divide_sqls = DataSearch.divide_sql(self.sql)
|
||
|
||
message_p = "正在分别加密和封装sql..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
self.pack_sqls()
|
||
message_p = "发送成功!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
else:
|
||
message_p = "非授权对象,禁止访问!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
# show_list()
|
||
|
||
elif message.startswith("01001100"):
|
||
message_p = "正在解析消息内容..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
# 去掉元数据头部
|
||
message = message.split("01001100", 1)[
|
||
1
|
||
] # [0]是空字符串,已测试。只切一次是防止密文出现和头部一样的信息被误切。
|
||
# 把元数据存入本地数据库的临时表里,格式:provider + encode_data
|
||
# print("message:", message)
|
||
int_message = int(message)
|
||
message_p = "收到元数据为:{}".format(int_message)
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
# 根据证书找公钥信息
|
||
search_public_key = (
|
||
self.search_cert["public_key"]
|
||
.split("-----END PUBLIC KEY-----")[0]
|
||
.split("-----BEGIN PUBLIC KEY-----")[1]
|
||
)
|
||
search_public_key = search_public_key.replace("\n", "")
|
||
print("提取到的search_public_key:")
|
||
print(search_public_key)
|
||
|
||
# str转bytes
|
||
byte = bytes.fromhex(search_public_key)
|
||
# bytes转PaillierPublicKey
|
||
search_public_key_data = pickle.loads(byte)
|
||
print("对应的search_public_key_data:")
|
||
print(search_public_key_data)
|
||
# int密 -- EncryptedNumber密
|
||
Encrpt_message = CoreAlgorithm.EncryptedNumber(
|
||
search_public_key_data, int_message, exponent=0
|
||
)
|
||
self.datas.append(Encrpt_message)
|
||
# Ciphertext转EncryptedNumber
|
||
# print("int_message:", int_message)
|
||
# Encrpt_message = CoreAlgorithm.EncryptedNumber(self.public_key_data, int_message, exponent=0)
|
||
# print("Enc:", Encrpt_message)
|
||
# dec_int_message = self.private_key_data.decrypt(Encrpt_message)
|
||
# print("dec:", dec_int_message)
|
||
# dec_message = dec_int_message.to_bytes((dec_int_message.bit_length() + 7) // 8, 'big').decode()
|
||
# print("解密后的消息为:", dec_message) # 已测试,说明平台不可解密元数据
|
||
self.safe_calculate()
|
||
|
||
elif message == "请发送平台的完整公钥证书":
|
||
message_p = "正在发送证书..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = str(self.certificate)
|
||
client_socket.send(message_p.encode()) # 二进制传输
|
||
self.m_listCtrl2.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = "发送完成!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
elif message.startswith("{'public_key':"):
|
||
print(message)
|
||
print(self.socket_mapping)
|
||
print(get_public_key_data(message))
|
||
self.socket_mapping[client_socket].append(
|
||
get_public_key_data(message)
|
||
) # 绑定端口与公钥信息的关系
|
||
print(self.socket_mapping)
|
||
cert = eval(message)
|
||
print(cert, type(cert)) # 字典型
|
||
if cert["name"].startswith("<数据查询方>"):
|
||
self.search_cert = cert # 字典型
|
||
self.socket_mapping[client_socket].append(cert["name"]) # 绑定端口与用户身份的关系
|
||
message_p = "接收到一则公钥证书:"
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
message_p = ""
|
||
for key, value in cert.items():
|
||
if isinstance(value, bytes):
|
||
value = value.decode()
|
||
print(key, ":", value)
|
||
message_p = key + ":\n" + value
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = "发送完成!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = "安全多方计算平台:双向授权成功!"
|
||
print(message_p)
|
||
self.m_listCtrl2.Append([message_p])
|
||
client_socket.send(message_p.encode("utf-8")) # 二进制传输
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = "使用对象表已更新:"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
print(self.socket_mapping)
|
||
self.m_listCtrl3.Append([self.socket_mapping])
|
||
elif message == "请求查询方信息!":
|
||
if str(self.search_cert) != "":
|
||
message_p = "正在发送提供方的证书..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = str(self.search_cert)
|
||
client_socket.send(message_p.encode()) # 二进制传输
|
||
self.m_listCtrl2.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
else:
|
||
client_socket.send(" ".encode()) # 二进制传输
|
||
|
||
for key, value in self.socket_mapping.items():
|
||
message_p = str(key) + ":" + str(value) + "\n"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
elif message == "":
|
||
pass
|
||
elif message == "认证成功!":
|
||
pass
|
||
else:
|
||
message_p = "Message:" + message
|
||
print(message_p)
|
||
self.m_listCtrl1.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
message_p = "错误的消息格式,丢弃!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
# 打印公钥
|
||
def print_public_key(self, event):
|
||
message_p = "本地公钥信息为:"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
message_p = self.public_key_data
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
message_p = "打印公钥如下:"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
public_key_str = "\n".join(
|
||
self.public_key[i : i + 60] for i in range(0, len(self.public_key), 60)
|
||
)
|
||
pack_public_key = (
|
||
"-----BEGIN PUBLIC KEY-----\n"
|
||
+ public_key_str
|
||
+ "\n-----END PUBLIC KEY-----\n"
|
||
)
|
||
message_p = pack_public_key
|
||
print(message_p)
|
||
message = message_p.split("\n") # 设置打印格式,因为显示窗打印不了\n
|
||
for i in range(len(message)):
|
||
self.m_listCtrl3.Append([message[i]])
|
||
# show_list()
|
||
|
||
# 打印私钥
|
||
def print_private_key(self, event):
|
||
message_p = "本地私钥信息为:"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
message_p = self.private_key_data
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
private_key_str = "\n".join(
|
||
self.private_key[i : i + 60] for i in range(0, len(self.public_key), 60)
|
||
)
|
||
pack_private_key = (
|
||
"-----BEGIN PRIVATE KEY-----\n"
|
||
+ private_key_str
|
||
+ "\n-----END PRIVATE KEY-----\n"
|
||
)
|
||
|
||
message_p = "打印私钥如下:"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
message_p = pack_private_key
|
||
print(message_p)
|
||
message = message_p.split("\n") # 设置打印格式,因为显示窗打印不了\n
|
||
for i in range(len(message)):
|
||
self.m_listCtrl3.Append([message[i]])
|
||
|
||
# 打印证书
|
||
def print_certificate(self, event):
|
||
message_p = "本地公钥证书为:"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
for key, value in self.certificate.items():
|
||
message_p = key + ":"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
if key == "public_key":
|
||
value = value.split("\n")
|
||
for i in range(len(value)):
|
||
self.m_listCtrl3.Append([value[i]])
|
||
else:
|
||
self.m_listCtrl3.Append([value])
|
||
|
||
# show_list()
|
||
|
||
# 加密封装sqls并发送
|
||
def pack_sqls(self):
|
||
for key, value in self.socket_mapping.items():
|
||
for i in range(len(self.divide_providers)):
|
||
if (
|
||
self.divide_providers[i] in value[2]
|
||
): # eg: value[2] == "<数据提供方>风舱医院"
|
||
for j in range(len(self.divide_sqls)):
|
||
if (
|
||
self.divide_providers[i] in self.divide_sqls[j]
|
||
): # 如果发送目标和信息匹配)
|
||
sql = (
|
||
str(self.source)
|
||
+ "||"
|
||
+ str(key.getsockname())
|
||
+ "||"
|
||
+ self.divide_sqls[i]
|
||
)
|
||
print(sql)
|
||
int_enc_sql = str_to_encrypt(
|
||
sql, value[1]
|
||
).ciphertext() # 用接收者的公钥加密消息
|
||
message_p = "01001010" + str(int_enc_sql)
|
||
key.send(message_p.encode())
|
||
self.m_listCtrl2.Append([message_p])
|
||
message_p = "已将消息{0}发送给{1},其地址为{2}".format(
|
||
self.divide_sqls[j],
|
||
self.divide_providers[i],
|
||
key.getsockname(),
|
||
)
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
|
||
# 安全算法
|
||
def safe_calculate(self):
|
||
self.total = len(self.divide_providers)
|
||
if self.flag == self.total:
|
||
message_p = "正在进行安全多方计算分析..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
# #################安全多方计算分析过程##############
|
||
if "select count(*) from xx阳光社区诊所, xx大学附属医院" in self.sql:
|
||
print(self.datas)
|
||
for x in self.datas:
|
||
self.result = self.result + x
|
||
# EncryptedNumber密 -- int密
|
||
self.result = self.result.ciphertext()
|
||
# int密 -- str密
|
||
self.result = str(self.result)
|
||
message = "分析成功!"
|
||
print(message)
|
||
self.m_listCtrl3.Append([message])
|
||
# EncryptedNumber 转 int
|
||
# self.result = self.result.ciphertext()
|
||
message = "结果:" + str(self.result)
|
||
print(message)
|
||
self.m_listCtrl3.Append([message])
|
||
time.sleep(sleep_time)
|
||
|
||
message_p = "正在发送结果给申请人..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
for key, value in self.socket_mapping.items():
|
||
if value[2].startswith("<数据查询方>"):
|
||
key.send(message.encode())
|
||
# 重置参数
|
||
self.total = 0
|
||
self.flag = 0
|
||
self.result = 0
|
||
self.datas = []
|
||
elif self.flag < 0:
|
||
message_p = "结果:已有数据持有方拒绝了提供消息的请求,安全分析无法进行,分析失败!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
for key, value in self.socket_mapping.items():
|
||
if value[2].startswith("<数据查询方>"):
|
||
key.send(message_p.encode())
|
||
# 重置参数
|
||
self.total = 0
|
||
self.flag = 0
|
||
else:
|
||
message_p = "结果:正在等待接收其他信息..."
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
for key, value in self.socket_mapping.items():
|
||
if value[2].startswith("<数据查询方>"):
|
||
key.send(message_p.encode())
|
||
message_p = "发送完成!"
|
||
print(message_p)
|
||
self.m_listCtrl3.Append([message_p])
|
||
time.sleep(sleep_time)
|
||
|
||
|
||
# ----------------------------------------------------主程序----------------------------------------------------
|
||
app = wx.App()
|
||
frame = MyServer(None)
|
||
frame.Show(True) # 展示登录页面
|
||
threading.Thread(target=frame.run).start() # 在新线程中运行服务器
|
||
app.MainLoop()
|