mimajingsai_2/server.py

563 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import socket
import time
import pickle
import wx # pylint: disable=e0401 # type: ignore
import CoreAlgorithm
import threading
import DataSearch
import server_demo
sleep_time = 0.2
# 审批回执类消息head01001001
# sql密文类消息head01001010
# 元数据类消息head01001100
# 用于生成本地公钥和私钥信息
def generate_key_pair_data():
public_key_data, private_key_data = CoreAlgorithm.generate_paillier_keypair()
return private_key_data, public_key_data
# 从公钥证书中提取公钥信息
def get_public_key_data(message):
message = eval(message.replace("\n", ""))
public_key = (
message["public_key"]
.replace("-----BEGIN PUBLIC KEY-----", "")
.replace("-----END PUBLIC KEY-----", "")
) # 分割得到公钥
public_key_bytes = bytes.fromhex(public_key)
public_key_data = pickle.loads(public_key_bytes)
return public_key_data
# 用公钥为本地生成数字证书
def get_certificate(temp_public_key, cert_name):
public_key_str = "\n".join(
temp_public_key[i : i + 60] for i in range(0, len(temp_public_key), 60)
)
pack_public_key = (
"-----BEGIN PUBLIC KEY-----\n" + public_key_str + "\n-----END PUBLIC KEY-----\n"
)
cert = {"public_key": pack_public_key, "name": cert_name}
return cert
# 加密字符串
def str_to_encrypt(message, public_data):
# str 转 int
if message.isdigit():
int_message = int(message)
else:
int_message = int.from_bytes(message.encode(), "big")
enc_message = public_data.encrypt(int_message)
print("int_message", int_message)
return enc_message
class MyServer(server_demo.MyFrame):
def __init__(self, parent):
server_demo.MyFrame.__init__(self, parent)
# 生成私钥和公钥信息
self.private_key_data, self.public_key_data = generate_key_pair_data()
# 生成私钥和公钥字符串
self.private_key, self.public_key = self.generate_key_pair()
# 获取数字证书
self.certificate = get_certificate(self.public_key, "安全多方服务器")
# 初始化当前sql
self.sql = ""
# 初始化sql拆分对象
self.divide_sqls = []
# 初始化sql拆分数据源
self.divide_providers = []
# 记录属于同一个请求的元数据密文
self.datas = []
# 初始化数据查询方的公钥证书为str在发来过后为其赋值
self.search_cert = ""
# 初始化socket
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 绑定IP地址和端口
self.server.bind(("localhost", 1111))
# 设置最大监听数
self.server.listen(5)
# 设置一个字典,用来保存每一个客户端的连接和身份信息
self.socket_mapping = {} # temp_socket: [addr, 公钥信息]
# 设置接收的最大字节数
self.maxSize = 4096
# 记录调用方地址
self.source = None
# 记录收集信息的数量
self.flag = 0
# 记录需要收集的信息总量
self.total = 0
# 保存安全多方计算结果
self.result = 0
self.out_look()
# 等待客户端连接
message_p = "等待客户端连接..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# 添加消息记录列
def out_look(self):
self.m_listCtrl1.InsertColumn(0, "消息记录") # 为聊天框添加‘消息记录’列
self.m_listCtrl1.SetColumnWidth(0, 1000)
self.m_listCtrl2.InsertColumn(0, "消息记录") # 为聊天框添加‘消息记录’列
self.m_listCtrl2.SetColumnWidth(0, 1000)
self.m_listCtrl3.InsertColumn(0, "消息记录") # 为聊天框添加‘消息记录’列
self.m_listCtrl3.SetColumnWidth(0, 1000)
# 建立连接,设置线程
def run(self):
while True:
client_socket, addr = self.server.accept()
# 发送信息,提示客户端已成功连接
message_p = "{0}连接成功!".format(addr)
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# 将客户端socket等信息存入字典
self.socket_mapping[client_socket] = []
self.socket_mapping[client_socket].append(addr)
message_p = "安全多方计算平台:单向授权成功!"
client_socket.send(message_p.encode())
self.m_listCtrl2.Append([message_p])
time.sleep(sleep_time)
# 创建线程,负责接收客户端信息并转发给其他客户端
threading.Thread(
target=self.recv_from_client, args=(client_socket,)
).start()
# 产生公私钥字符串
def generate_key_pair(self):
# Paillier 转 bytes
public_key_bytes = pickle.dumps(self.public_key_data)
private_key_bytes = pickle.dumps(self.private_key_data)
# bytes 转 str
public_key = public_key_bytes.hex()
private_key = private_key_bytes.hex()
return private_key, public_key
# 接收客户端消息并转发
def recv_from_client(self, client_socket): # client_socket指的是连接到的端口socket
while True:
message = client_socket.recv(self.maxSize).decode("utf-8")
message_p = "接收到来自{0}的message:\n{1}".format(
self.socket_mapping[client_socket][0], message
)
print(message_p)
self.m_listCtrl1.Append([message_p]) # 准备文件传输
if message.startswith("01001001"):
message_p = "正在解析消息内容..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# 去掉审批回执的头部
message = message.split("01001001", 1)[
1
] # [0]是空字符串,已测试。只切一次是防止密文出现和头部一样的信息被误切。
# 认证发送者的合法性
sender = message.split("||")[0]
context = message.split("||")[1]
message_p = "接收到来自{0}的message:\n{1}".format(sender, context)
print(message_p)
self.m_listCtrl1.Append([message_p])
# time.sleep(sleep_time)
self.flag += 1
if context == "想得美哦!不给(*^▽^*)":
self.flag = -9999
elif message.startswith("01001010"):
message_p = "正在解析消息内容..."
print(message_p)
self.m_listCtrl3.Append([message_p])
# time.sleep(sleep_time)
# 认证发送者的合法性
if len(self.socket_mapping[client_socket]) > 1: # 如果发送者的公钥信息已被收集
# 识别明文中一起发送的发送目标 明文应是发送者||发送内容(||时间戳等),对象用socket表示吧...
message = message.split("01001010", 1)[
1
] # [0]是空字符串,已测试。只切一次是防止密文出现和头部一样的信息被误切。
# 用发送目标之前给的公钥加密明文,得到密文。
# 去掉sql密文的头部
# 使用平台私钥解密消息,获得sql明文
message_p = "接收到的sql密文:" + message
print(message_p)
self.m_listCtrl1.Append([message_p])
time.sleep(sleep_time)
int_message = int(message)
# Ciphertext转EncryptedNumber
Encrpt_message = CoreAlgorithm.EncryptedNumber(
self.public_key_data, int_message, exponent=0
)
dec_int_message = self.private_key_data.decrypt(Encrpt_message)
dec_message = dec_int_message.to_bytes(
(dec_int_message.bit_length() + 7) // 8, "big"
).decode()
message_p = "解密后的消息为:" + dec_message
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
self.source = dec_message.split("||")[0]
self.sql = dec_message.split("||")[1]
message_p = "收到已授权方的sql:" + self.sql
print(message_p)
self.m_listCtrl1.Append([message_p])
time.sleep(sleep_time)
message_p = "待处理的sql语句为:" + self.sql
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
message_p = "正在拆分sql..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
self.divide_providers = DataSearch.extract_tables(self.sql)
message_p = "涉及到的数据持有对象有:" + str(self.divide_providers)
print(message_p)
self.m_listCtrl3.Append(([message_p]))
self.divide_sqls = DataSearch.divide_sql(self.sql)
message_p = "正在分别加密和封装sql..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
self.pack_sqls()
message_p = "发送成功!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
else:
message_p = "非授权对象,禁止访问!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# show_list()
elif message.startswith("01001100"):
message_p = "正在解析消息内容..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# 去掉元数据头部
message = message.split("01001100", 1)[
1
] # [0]是空字符串,已测试。只切一次是防止密文出现和头部一样的信息被误切。
# 把元数据存入本地数据库的临时表里格式provider + encode_data
# print("message:", message)
int_message = int(message)
message_p = "收到元数据为:{}".format(int_message)
print(message_p)
self.m_listCtrl1.Append([message_p])
# 根据证书找公钥信息
search_public_key = (
self.search_cert["public_key"]
.split("-----END PUBLIC KEY-----")[0]
.split("-----BEGIN PUBLIC KEY-----")[1]
)
search_public_key = search_public_key.replace("\n", "")
print("提取到的search_public_key:")
print(search_public_key)
# str转bytes
byte = bytes.fromhex(search_public_key)
# bytes转PaillierPublicKey
search_public_key_data = pickle.loads(byte)
print("对应的search_public_key_data:")
print(search_public_key_data)
# int密 -- EncryptedNumber密
Encrpt_message = CoreAlgorithm.EncryptedNumber(
search_public_key_data, int_message, exponent=0
)
self.datas.append(Encrpt_message)
# Ciphertext转EncryptedNumber
# print("int_message:", int_message)
# Encrpt_message = CoreAlgorithm.EncryptedNumber(self.public_key_data, int_message, exponent=0)
# print("Enc:", Encrpt_message)
# dec_int_message = self.private_key_data.decrypt(Encrpt_message)
# print("dec:", dec_int_message)
# dec_message = dec_int_message.to_bytes((dec_int_message.bit_length() + 7) // 8, 'big').decode()
# print("解密后的消息为:", dec_message) # 已测试,说明平台不可解密元数据
self.safe_calculate()
elif message == "请发送平台的完整公钥证书":
message_p = "正在发送证书..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
message_p = str(self.certificate)
client_socket.send(message_p.encode()) # 二进制传输
self.m_listCtrl2.Append([message_p])
time.sleep(sleep_time)
message_p = "发送完成!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
elif message.startswith("{'public_key':"):
print(message)
print(self.socket_mapping)
print(get_public_key_data(message))
self.socket_mapping[client_socket].append(
get_public_key_data(message)
) # 绑定端口与公钥信息的关系
print(self.socket_mapping)
cert = eval(message)
print(cert, type(cert)) # 字典型
if cert["name"].startswith("<数据查询方>"):
self.search_cert = cert # 字典型
self.socket_mapping[client_socket].append(cert["name"]) # 绑定端口与用户身份的关系
message_p = "接收到一则公钥证书:"
print(message_p)
self.m_listCtrl1.Append([message_p])
time.sleep(sleep_time)
message_p = ""
for key, value in cert.items():
if isinstance(value, bytes):
value = value.decode()
print(key, ":", value)
message_p = key + ":\n" + value
print(message_p)
self.m_listCtrl1.Append([message_p])
time.sleep(sleep_time)
message_p = "发送完成!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
message_p = "安全多方计算平台:双向授权成功!"
print(message_p)
self.m_listCtrl2.Append([message_p])
client_socket.send(message_p.encode("utf-8")) # 二进制传输
time.sleep(sleep_time)
message_p = "使用对象表已更新:"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
print(self.socket_mapping)
self.m_listCtrl3.Append([self.socket_mapping])
elif message == "请求查询方信息!":
if str(self.search_cert) != "":
message_p = "正在发送提供方的证书..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
message_p = str(self.search_cert)
client_socket.send(message_p.encode()) # 二进制传输
self.m_listCtrl2.Append([message_p])
time.sleep(sleep_time)
else:
client_socket.send(" ".encode()) # 二进制传输
for key, value in self.socket_mapping.items():
message_p = str(key) + ":" + str(value) + "\n"
print(message_p)
self.m_listCtrl3.Append([message_p])
elif message == "":
pass
elif message == "认证成功!":
pass
else:
message_p = "Message:" + message
print(message_p)
self.m_listCtrl1.Append([message_p])
time.sleep(sleep_time)
message_p = "错误的消息格式,丢弃!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# 打印公钥
def print_public_key(self, event):
message_p = "本地公钥信息为:"
print(message_p)
self.m_listCtrl3.Append([message_p])
message_p = self.public_key_data
print(message_p)
self.m_listCtrl3.Append([message_p])
message_p = "打印公钥如下:"
print(message_p)
self.m_listCtrl3.Append([message_p])
public_key_str = "\n".join(
self.public_key[i : i + 60] for i in range(0, len(self.public_key), 60)
)
pack_public_key = (
"-----BEGIN PUBLIC KEY-----\n"
+ public_key_str
+ "\n-----END PUBLIC KEY-----\n"
)
message_p = pack_public_key
print(message_p)
message = message_p.split("\n") # 设置打印格式,因为显示窗打印不了\n
for i in range(len(message)):
self.m_listCtrl3.Append([message[i]])
# show_list()
# 打印私钥
def print_private_key(self, event):
message_p = "本地私钥信息为:"
print(message_p)
self.m_listCtrl3.Append([message_p])
message_p = self.private_key_data
print(message_p)
self.m_listCtrl3.Append([message_p])
private_key_str = "\n".join(
self.private_key[i : i + 60] for i in range(0, len(self.public_key), 60)
)
pack_private_key = (
"-----BEGIN PRIVATE KEY-----\n"
+ private_key_str
+ "\n-----END PRIVATE KEY-----\n"
)
message_p = "打印私钥如下:"
print(message_p)
self.m_listCtrl3.Append([message_p])
message_p = pack_private_key
print(message_p)
message = message_p.split("\n") # 设置打印格式,因为显示窗打印不了\n
for i in range(len(message)):
self.m_listCtrl3.Append([message[i]])
# 打印证书
def print_certificate(self, event):
message_p = "本地公钥证书为:"
print(message_p)
self.m_listCtrl3.Append([message_p])
for key, value in self.certificate.items():
message_p = key + ":"
print(message_p)
self.m_listCtrl3.Append([message_p])
if key == "public_key":
value = value.split("\n")
for i in range(len(value)):
self.m_listCtrl3.Append([value[i]])
else:
self.m_listCtrl3.Append([value])
# show_list()
# 加密封装sqls并发送
def pack_sqls(self):
for key, value in self.socket_mapping.items():
for i in range(len(self.divide_providers)):
if (
self.divide_providers[i] in value[2]
): # eg: value[2] == "<数据提供方>风舱医院"
for j in range(len(self.divide_sqls)):
if (
self.divide_providers[i] in self.divide_sqls[j]
): # 如果发送目标和信息匹配)
sql = (
str(self.source)
+ "||"
+ str(key.getsockname())
+ "||"
+ self.divide_sqls[i]
)
print(sql)
int_enc_sql = str_to_encrypt(
sql, value[1]
).ciphertext() # 用接收者的公钥加密消息
message_p = "01001010" + str(int_enc_sql)
key.send(message_p.encode())
self.m_listCtrl2.Append([message_p])
message_p = "已将消息{0}发送给{1},其地址为{2}".format(
self.divide_sqls[j],
self.divide_providers[i],
key.getsockname(),
)
print(message_p)
self.m_listCtrl3.Append([message_p])
# 安全算法
def safe_calculate(self):
self.total = len(self.divide_providers)
if self.flag == self.total:
message_p = "正在进行安全多方计算分析..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# #################安全多方计算分析过程##############
if "select count(*) from xx阳光社区诊所, xx大学附属医院" in self.sql:
print(self.datas)
for x in self.datas:
self.result = self.result + x
# EncryptedNumber密 -- int密
self.result = self.result.ciphertext()
# int密 -- str密
self.result = str(self.result)
message = "分析成功!"
print(message)
self.m_listCtrl3.Append([message])
# EncryptedNumber 转 int
# self.result = self.result.ciphertext()
message = "结果:" + str(self.result)
print(message)
self.m_listCtrl3.Append([message])
time.sleep(sleep_time)
message_p = "正在发送结果给申请人..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
for key, value in self.socket_mapping.items():
if value[2].startswith("<数据查询方>"):
key.send(message.encode())
# 重置参数
self.total = 0
self.flag = 0
self.result = 0
self.datas = []
elif self.flag < 0:
message_p = "结果:已有数据持有方拒绝了提供消息的请求,安全分析无法进行,分析失败!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
for key, value in self.socket_mapping.items():
if value[2].startswith("<数据查询方>"):
key.send(message_p.encode())
# 重置参数
self.total = 0
self.flag = 0
else:
message_p = "结果:正在等待接收其他信息..."
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
for key, value in self.socket_mapping.items():
if value[2].startswith("<数据查询方>"):
key.send(message_p.encode())
message_p = "发送完成!"
print(message_p)
self.m_listCtrl3.Append([message_p])
time.sleep(sleep_time)
# ----------------------------------------------------主程序----------------------------------------------------
app = wx.App()
frame = MyServer(None)
frame.Show(True) # 展示登录页面
threading.Thread(target=frame.run).start() # 在新线程中运行服务器
app.MainLoop()