Compare commits

5 Commits

Author SHA1 Message Date
af81cb6c00 new file: .gitignore 2023-06-11 17:18:54 +08:00
c09c326bde modified: README.md 2023-06-11 16:57:41 +08:00
b5a88fa9b5 modified: README.md
new file:   api_delete.png
	new file:   database_delete.png
2023-06-11 16:52:54 +08:00
6cd279e88c modified: README.md
new file:   api_add.png
	modified:   database/initdb.py
	new file:   database_add.png
	new file:   dns1.png
	new file:   dns2.png
	new file:   requirements.txt
	modified:   server/main.py
	modified:   server/serverconf.yaml
	modified:   server/xiaomiandns.py
2023-06-09 17:44:51 +08:00
647c92c8dc new file: .vscode/settings.json
deleted:    client/clientconf.yaml
	deleted:    client/dnssender.py
	deleted:    client/main.py
	modified:   database/initdb.py
	deleted:    database/test.py
	deleted:    node/main.py
	deleted:    node/nodeconf.yaml
	deleted:    node/proxy.py
	modified:   server/xiaomiandns.py
2023-06-07 17:06:37 +08:00
53 changed files with 438 additions and 3126 deletions

View File

@@ -1 +0,0 @@
src/__pycache

View File

@@ -1,23 +0,0 @@
name: Test CI
on:
  push:
    paths:
      - "src/**"
jobs:
  test:
    name: test speed
    runs-on: ubuntu-latest
    container:
      image: catthehacker/ubuntu:act-latest
    steps:
      - name: Checkout repository
        uses: https://git.mamahaha.work/actions/checkout@v3
      # - name: Run script in Docker container
      #   run: |
      #     ls $PWD/src
      #     docker run --rm -v .:/app git.mamahaha.work/sangge/tpre:base ls

21
.gitignore vendored
View File

@@ -1,19 +1,2 @@
.devcontainer
__pycache__
test.py
example.py
ReEncrypt.py
src/temp_message_file
src/temp_key_file
src/client.db
src/server.db
build
src/tpre.cpython-311-x86_64-linux-gnu.so
.vscode
venv
lib
include
database/dns.db
.vscode

6
.gitmodules vendored
View File

@@ -1,6 +0,0 @@
[submodule "gmssl"]
path = gmssl
url = https://github.com/guanzhi/GmSSL.git
[submodule "ecc_rs"]
path = ecc_rs
url = https://git.mamahaha.work/sangge/ecc_rs.git

6
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,6 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

62
FQA.md
View File

@@ -1,62 +0,0 @@
# Defense Q&A preparation
1. **Q: Data security and privacy**: How does your system ensure data security and privacy in a distributed environment, especially regarding the challenges of key management?
**A:** A proxy transforms the ciphertext so that it goes from being encrypted under one key to being encrypted under another, without ever decrypting it to plaintext. The client generates a fresh public/private key pair on every startup, and a one-time symmetric key is chosen at random for each subsequent exchange.
2. **Q: Compliance with national security standards**: Can you explain how your system complies with China's national security standards, especially for sensitive sectors such as government and finance?
**A:** Using the domestic SM algorithms reduces dependence on foreign algorithms and strengthens autonomous control, which helps guard against potential backdoors and security vulnerabilities.
3. **Q: System architecture**: How does your distributed architecture improve performance and scalability?
**A:** Performance: we use a threshold scheme, so decryption does not require receiving every ciphertext fragment; re-encryption and decryption of the payload use a symmetric key, which speeds things up. In effect, higher-performance nodes can be "picked" for each computation. Scalability: nodes can join and leave dynamically, and clients can refresh their node lists dynamically as well.
4. **Q: Algorithm efficiency**: Can you discuss the efficiency and computational requirements of the threshold proxy re-encryption used in your system?
**A:** The most expensive steps are re-encryption key generation and decryption; the remaining steps cost very little. Hardware requirements: a 64-bit CPU with 512 MB of RAM that can run Docker is enough.
5. **Q: Compatibility and robustness**: How does your system ensure compatibility with various environments (such as distributed privacy-preserving computation), and how does it handle node failures or malicious nodes?
**A:** Compatibility: containerized deployment guarantees a consistent runtime environment across multiple operating systems and CPU architectures. Robustness is the system's ability to survive abnormal and hostile conditions, for example whether software keeps running without hanging or crashing under bad input, disk failures, network overload, or deliberate attacks. Concretely: malicious data is filtered; Python exception handling catches errors so the program does not crash and exit; heartbeat packets are used to detect whether nodes are alive. Failed nodes are detected through those heartbeats. Malicious-node handling is not yet implemented, but we plan a reputation system: clients report malicious nodes to the central server, the server lowers the scheduling weight of suspected nodes, and confirmed malicious nodes are blacklisted.
6. **Q: Practical application scenarios**: In which real-world scenarios can your system be used most effectively, and what are the possible limitations or challenges?
**A:** Distributed computation and secure data sharing on blockchains, secure data authorization, and distributed key management. Limitations: constrained computing resources, and the cryptographic algorithms lack deep software/hardware-level optimization.
7. **Q: Code structure and modularity**: How is the project code organized? How do you ensure modularity and readability?
**A:** We first implemented the algorithm itself, with each step exposed as a function call. We then wrote the client, node, and central-server code according to the system roles, with the front end and back end separated. For readability we use snake_case naming, a REST style, and type hints throughout development.
8. **Q: Algorithm implementation**: How is threshold proxy re-encryption implemented in the code? What are the key optimizations or innovations?
**A:** It combines proxy re-encryption with Shamir secret sharing and a hybrid-encryption scheme, and uses Jacobian coordinates to speed up elliptic-curve point arithmetic. FastAPI provides a native, high-performance API with asynchronous programming, and SQLite keeps resource usage low and avoids complex database configuration.
9. **Q: Performance testing**: Have you performance-tested the code? What are the main bottlenecks and optimization opportunities?
**A:** Yes, we profiled with line_profiler. The main bottlenecks are re-encryption key generation and decryption, both of which involve elliptic-curve computation; Jacobian coordinates speed this up. (Jacobian coordinates reduce the number of expensive modular inversions in the finite field and optimize the number of multiplications, making point addition and doubling more efficient, which is critical for encryption and decryption speed, especially in resource-constrained environments.) Also, SQLite does not handle highly concurrent writes well (writers contend for the database), so we use a message queue to turn concurrent writes into a serial stream, improving write efficiency (see the sketch after this list).
10. **Q: Security considerations**: How does the implementation handle potential vulnerabilities, especially around encryption and data transmission?
**A:** We use parameterized queries, which avoid concatenating user input directly into SQL statements; HTTPS can be used to encrypt data in transit.
11. **Q: Error handling and logging**: Does the code include error handling and logging? How do these help monitor and debug the system?
**A:** We use Python's built-in exception handling; logging currently uses the print function. Different return codes are attached to the code paths, so errors can be located quickly during debugging, and exception handling keeps the program from exiting abnormally.
12. **Q: How is the project applied in a blockchain?**
**A:** Encrypting data: ciphertext encrypted with the owner's key is uploaded to the blockchain. Authorized access: the data requester sends an access request to the data holder. Proxy re-encryption: proxy nodes re-encrypt the ciphertext and the requester decrypts it with its own private key. Every access request and re-encryption operation is recorded on-chain, providing a complete audit trail.
13. **Q: Briefly explain the mathematics behind proxy re-encryption.**
**A:** Proxy re-encryption is a three-party protocol, which can be roughly understood as a three-party DHKE. A splits a secret into two parts, one for B and one for the proxy node. During re-encryption the proxy folds its part into the ciphertext, so that when B receives it, B can recover the full secret with the other part.
14. **Q: Chosen-ciphertext attack, chosen-plaintext attack**
**A:**
15. **Q: SM2, SM3, SM4 (Chinese national standard algorithms)**
**A:**
----------------------------------------------------------------------
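To make answers 9 and 10 concrete, here is a minimal sketch (not taken from the project code; the `records` table and file name are hypothetical) of a single SQLite writer thread fed by a queue, using parameterized queries so user input is never concatenated into SQL:

```python
import queue
import sqlite3
import threading

write_q: queue.Queue = queue.Queue()

def writer(db_file: str) -> None:
    # Single writer thread: concurrent write requests become a serial stream,
    # which sidesteps SQLite's contention problems under many writers.
    conn = sqlite3.connect(db_file)
    conn.execute("CREATE TABLE IF NOT EXISTS records (domain TEXT, ip TEXT)")
    while True:
        domain, ip = write_q.get()
        # Parameterized query: values are bound, never spliced into the SQL string.
        conn.execute("INSERT INTO records (domain, ip) VALUES (?, ?)", (domain, ip))
        conn.commit()
        write_q.task_done()

threading.Thread(target=writer, args=("demo.db",), daemon=True).start()
write_q.put(("example.test", "192.168.1.1"))  # producers only enqueue
write_q.join()
```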

128
README.md
View File

@@ -1,99 +1,67 @@
# tpre-python
# xiaomian DNS
This project was designed for the national cryptography competition; it implements the TPRE algorithm in Python.
This project is a course project for Computer Networks; it is a simple implementation of a private DNS.
## Project principle
Distributed proxy re-encryption (TPRE) implemented with the Chinese national (SM) cryptographic algorithms.
## Project structure
.
├── basedockerfile (builds the base image)
├── dockerfile (builds the application image)
├── doc (development documentation)
├── include (gmssl header files)
├── lib (gmssl shared libraries)
├── LICENSE
├── README_en.md
├── README.md
├── requirements.txt
└── src (program source code)
server/xiaomiandns.py implements the DNS server and the API server as two classes; server/main.py starts the instantiated servers. The configuration file is server/serverconf.yaml.
The project uses SQLite as the database for node information and other data.
## Environment dependencies
### Installing directly on a physical machine (untested)
System requirements:
- Linux
- Windows (you need to install the gmssl shared library yourself)
The project depends on the following software:
python 3.12
gmssl v3.1.1
gmssl-python 2.2.2
### Docker-based installation
```bash
apt update && apt install mosh -y
chmod +x install_docker.sh
./install_docker.sh
```
### Docker version used in the development environment
docker version:
- Version: 24.0.5
- API version: 1.43
- Go version: go1.20.6
python 3
## Installation steps
```console
# install dependencies
pip install -r requirements.txt
### Preparation before installation
This project depends on gmssl, so please install it beforehand. See [GmSSL](https://github.com/guanzhi/GmSSL) for installation instructions.
The project also references it as a git submodule, which can be used directly:
```bash
git clone --recurse-submodules https://git.mamahaha.work/sangge/tpre-python.git
chmod +x install_gmssl.sh
./install_gmssl.sh
```
Then install the required Python libraries:
```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## Docker installation
### Use the prepared base image and deploy the application yourself
```bash
docker build . -f basedockerfile -t git.mamahaha.work/sangge/tpre:base
(or: docker pull git.mamahaha.work/sangge/tpre:base)
docker build . -t your_image_name
```
### Use the full Docker image
```bash
docker pull git.mamahaha.work/sangge/tpre:latest
# initialize the database
python3 database/initdb.py
```
## Usage
```python
python3 server/main.py
```
The DNS service listens on port 53 by default; the domain API listens on port 81 by default.
To add records, submit data to the API with the POST method:
usage: use POST method
/add
data: domain=xxxx&ip=xx.xx.xx.xx
/delete
data: domain=xxxx&ip=xx.xx.xx.xx
Note: domains must use xiaomian as the root domain.
See the development documentation for details: [docs](doc/README_app_en.md)
# Test the DNS functionality
```console
# Linux
sudo apt install dnsutils
dig @DNS_SERVER -p PORT DOMAIN
## References
# Windows
nslookup DOMAIN DNS_SERVER
```
![DNS1](./dns1.png)
![DNS2](./dns2.png)
# Test the API functionality
```console
# Linux
# add a record
curl -d "domain=qqqwwweee.xiaomian&ip=123.12.23.34" -X POST http://10.20.117.208:81/add -i
[TPRE Algorithm Blog Post](https://www.cnblogs.com/pam-sh/p/17364656.html#tprelib%E7%AE%97%E6%B3%95)
[Gmssl-python library](https://github.com/GmSSL/GmSSL-Python)
# delete a record
curl -d "domain=qqqwwweee.xiaomian&ip=123.12.23.34" -X POST http://10.20.117.208:81/delete -i
# Windows
Invoke-WebRequest never seems to deliver the POST body; the cause is unclear (a Python alternative is sketched below)
```
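If Invoke-WebRequest is unreliable on Windows, the same requests can be sent from Python. A minimal sketch, assuming the same host and port as the curl examples above (this snippet is illustrative and not part of the repository):

```python
import requests

api = "http://10.20.117.208:81"  # same server as the curl examples; adjust as needed
payload = {"domain": "qqqwwweee.xiaomian", "ip": "123.12.23.34"}

# requests encodes the dict as the form body "domain=...&ip=..." expected by the API
r = requests.post(f"{api}/add", data=payload, timeout=5)
print(r.status_code, r.reason)

r = requests.post(f"{api}/delete", data=payload, timeout=5)
print(r.status_code, r.reason)
```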
![API add](./api_add.png)
![database add](./database_add.png)
![API delete](./api_delete.png)
![database delete](./database_delete.png)
## Not yet implemented
Resolution results can currently only be returned as A records; adding other record types is not implemented yet.
## License

View File

@@ -1,101 +0,0 @@
# tpre-python
This project is designed for the National Cryptography Competition and is implemented in Python to execute the TPRE algorithm.
## Project Principle
The project uses the Chinese national standard cryptography algorithm to implement distributed proxy re-encryption (TPRE).
## Project Structure
.
├── basedockerfile (used to build the base image)
├── dockerfile (used to build the application image)
├── doc (development documents)
├── gmssl (gmssl source code)
├── LICENSE
├── README_en.md
├── README.md
├── requirements.txt
└── src (application source code)
## Environment Dependencies
### Bare metal version (UNTESTED)
System requirements:
- Linux
- Windows (you may need to compile and install gmssl yourself)
The project relies on the following software:
- Python 3.12
- gmssl v3.1.1
- gmssl-python 2.2.2
### Docker installer
```bash
apt update && apt install mosh -y
chmod +x install_docker.sh
./install_docker.sh
```
### Docker version
docker version:
- Version: 24.0.5
- API version: 1.43
- Go version: go1.20.6
## Installation Steps
### Pre-installation
This project depends on gmssl, so you need to compile it from source first.
Visit [GmSSL](https://github.com/guanzhi/GmSSL) to learn how to install.
Alternatively, the repository references GmSSL as a git submodule:
```bash
git clone --recursive https://git.mamahaha.work/sangge/tpre-python.git
chmod +x install_gmssl.sh
./install_gmssl.sh
```
Then install essential python libs
```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## Docker Installation
### Use base image and build yourself
```bash
docker build . -f basedockerfile -t git.mamahaha.work/sangge/tpre:base
(or docker pull git.mamahaha.work/sangge/tpre:base)
docker build . -t your_image_name
```
### Use the pre-built image
```bash
docker pull git.mamahaha.work/sangge/tpre:latest
```
## Usage Instructions
details in [docs](doc/README_app_en.md)
## References
[TPRE Algorithm Blog Post](https://www.cnblogs.com/pam-sh/p/17364656.html#tprelib%E7%AE%97%E6%B3%95)
[Gmssl-python library](https://github.com/GmSSL/GmSSL-Python)
## License
GNU GENERAL PUBLIC LICENSE v3

api_add.png — new binary file, 85 KiB (not shown)

api_delete.png — new binary file, 72 KiB (not shown)

View File

@@ -1,18 +0,0 @@
FROM python:3.12-slim
COPY requirements.txt /app/
# build argument for the target platform
#ARG TARGETPLATFORM
# copy the library files matching the target architecture
#COPY lib/${TARGETPLATFORM}/* /lib/
COPY lib/* /lib/
WORKDIR /app
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip install --index-url https://git.mamahaha.work/api/packages/sangge/pypi/simple/ ecc-rs


33
database/initdb.py Normal file
View File

@@ -0,0 +1,33 @@
import sqlite3
import argparse

# create the DNS database and, optionally, insert a few test records
db_file = 'database/dns.db'

if __name__ == '__main__':
    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()
    parser = argparse.ArgumentParser()
    parser.add_argument('--add', action='store_true', help='add test data')
    args = parser.parse_args()
    try:
        cursor.execute(
            '''CREATE TABLE xiaomiandns(domain TEXT PRIMARY KEY, ip TEXT,timestamp DATETIME)''')
    except sqlite3.OperationalError:
        print("table xiaomiandns already exists")
    conn.commit()
    if args.add:
        test_data = [
            ('example.xiaomian', '192.168.1.1'),
            ('google.xiaomian', '8.8.8.8'),
            ('yahoo.xiaomian', '98.138.219.231')
        ]
        for data in test_data:
            domain, ip = data
            cursor.execute("INSERT INTO xiaomiandns (domain, ip, timestamp) VALUES (?, ?, DATETIME('now'))", (domain, ip))
    # commit the changes to the database
    conn.commit()
    cursor.close()
    conn.close()
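As a quick sanity check after running the script above (for example with the --add flag), the table can be inspected from Python. This snippet is illustrative and not part of the repository:

```python
import sqlite3

conn = sqlite3.connect("database/dns.db")
for row in conn.execute("SELECT domain, ip, timestamp FROM xiaomiandns"):
    print(row)  # e.g. ('example.xiaomian', '192.168.1.1', '2023-06-09 09:44:51')
conn.close()
```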

database_add.png — new binary file, 11 KiB (not shown)

database_delete.png — new binary file, 8.7 KiB (not shown)

dns1.png — new binary file, 365 KiB (not shown)

dns2.png — new binary file, 68 KiB (not shown)


@@ -1,78 +0,0 @@
# APP Doc
## Run docker
```bash
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=60.204.193.58 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=119.3.125.234 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=124.70.165.73 git.mamahaha.work/sangge/tpre:base bash
```
```bash
tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=60.204.233.103 git.mamahaha.work/sangge/tpre:base bash
```
## Deploy contract
You should deploy the contract yourself in src/logger.sol using remix or any CLI-tools and replace the contract address in src/node.py with your actual address.
[Deployment document](https://remix-ide.readthedocs.io/zh-cn/latest/create_deploy.html)
## Start application
You should replace the wallet address/privateKey in src/node.py with your own wallet address/privateKey.
```bash
nohup python server.py &
nohup python node.py &
nohup python client.py &
cat nohup.out
```
## Cloud server ip
**tpre1**: 110.41.155.96
**tpre2**: 110.41.130.197
**tpre3**: 110.41.21.35
## Proxy re-encryption process
### Client request message
```bash
python client_cli.py 124.70.165.73 name
python client_cli.py 124.70.165.73 environment
```
## Client router
**/receive_messages**
post method
**/request_message**
post method
**/receive_request**
post method
**/recieve_pk**
post method
## Central server router
**/server/show_nodes**
get method
**/server/get_node**
get method
**/server/delete_node**
get method
**/server/heartbeat**
get method
**/server/send_nodes_list**
get method
## Node router
**/user_src**
post method
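For reference, the central-server endpoints listed above are plain GET routes and can be exercised directly. A minimal, illustrative sketch (the base URL is an assumption; substitute one of the cloud-server IPs above with port 8000):

```python
import requests

base = "http://110.41.155.96:8000"  # tpre1 from the list above; adjust as needed

# list every registered node (id, ip, last_heartbeat)
print(requests.get(f"{base}/server/show_nodes", timeout=3).json())

# ask the central server for two node IPs, as the client does on startup
print(requests.get(f"{base}/server/send_nodes_list", params={"count": 2}, timeout=3).json())
```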


@@ -1,31 +0,0 @@
version: "3"
services:
server:
image: git.mamahaha.work/sangge/tpre:base
volumes:
- ./src:/app
environment:
- server_address=http://server:8000
entrypoint:
- python
- server.py
node:
image: git.mamahaha.work/sangge/tpre:base
volumes:
- ./src:/app
environment:
- server_address=http://server:8000
entrypoint:
- python
- node.py
client:
image: git.mamahaha.work/sangge/tpre:base
volumes:
- ./src:/app
environment:
- server_address=http://server:8000
entrypoint:
- python
- client.py

View File

@@ -1,10 +0,0 @@
FROM git.mamahaha.work/sangge/tpre:base
COPY src /app
COPY requirements.txt /app/requirements.txt
WORKDIR /app
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

1
ecc_rs

Submodule ecc_rs deleted from 880c34ce03

1
gmssl

Submodule gmssl deleted from d655c06b3a

View File

@@ -1,8 +0,0 @@
for pkg in docker.io docker-doc docker-compose podman-docker containerd runc; do apt-get remove $pkg; done
apt update
apt install apt-transport-https ca-certificates curl gnupg lsb-release
curl -fsSL https://mirrors.tuna.tsinghua.edu.cn/docker-ce/linux/debian/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://mirrors.tuna.tsinghua.edu.cn/docker-ce/linux/debian $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null
apt update
apt install docker-ce docker-ce-cli containerd.io

View File

@@ -1,15 +0,0 @@
#!/bin/bash
mkdir lib
mkdir include
cp -r gmssl/include include
mkdir gmssl/build
cd gmssl/build || exit
cmake ..
cd bin || exit
cp libgmssl.so libgmssl.so.3 libgmssl.so.3.1 ../../../lib
cp libsdf_dummy.so libsdf_dummy.so.3 libsdf_dummy.so.3.1 ../../../lib
cp libskf_dummy.so libskf_dummy.so.3 libskf_dummy.so.3.1 ../../../lib
sudo make install

View File

@@ -1,5 +1,2 @@
gmssl-python>=2.2.2,<3.0.0
fastapi
uvicorn
requests
web3
dnspython==2.3.0
PyYAML==6.0

20
server/main.py Normal file
View File

@@ -0,0 +1,20 @@
import xiaomiandns
import yaml

if __name__ == '__main__':
    with open('server/serverconf.yaml', 'r') as f:
        config = yaml.safe_load(f)
    db_file = config['database']['db_file']
    DNS_port = config['DNS']['port']
    DNS_listen_host = config['DNS']['listen_host']
    API_port = config['API']['port']
    API_listen_host = config['API']['listen_host']
    # start dns server
    server = xiaomiandns.DNSServer(DNS_listen_host, DNS_port, db_file)
    server.run()
    # start dns api server
    APIserver = xiaomiandns.DNSAPI(API_listen_host, API_port, db_file)
    APIserver.run()

8
server/serverconf.yaml Normal file
View File

@@ -0,0 +1,8 @@
database:
  db_file: 'database/dns.db'
DNS:
  port: 53
  listen_host: "0.0.0.0"
API:
  port: 81
  listen_host: "0.0.0.0"

319
server/xiaomiandns.py Normal file
View File

@@ -0,0 +1,319 @@
import socket
import threading
import dns.resolver
import dns.message
import dns.rdataclass
import dns.rdatatype
import dns.flags
import dns.rcode
import dns.rrset
import sqlite3
import re
import yaml
class DNSServer:
def __init__(self, hostname, port, db_file):
self.hostname = hostname
self.port = port
self.db_file = db_file
def run(self):
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.udp_socket.bind((self.hostname, self.port))
self.tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.tcp_socket.bind((self.hostname, self.port))
self.tcp_socket.listen(1)
print(f"DNS server running on {self.hostname}:{self.port}")
for i in range(3):
udp_thread = threading.Thread(target=self.handle_udp_request)
udp_thread.start()
tcp_thread = threading.Thread(target=self.handle_tcp_request)
tcp_thread.start()
def handle_udp_request(self):
data, address = self.udp_socket.recvfrom(1024)
response = self.handle_request(data)
self.udp_socket.sendto(response, address)
udp_thread = threading.Thread(target=self.handle_udp_request)
udp_thread.start()
def handle_tcp_request(self):
connection, address = self.tcp_socket.accept()
data = connection.recv(1024)
response = self.handle_request(data)
connection.send(response)
connection.close()
tcp_thread = threading.Thread(target=self.handle_tcp_request)
tcp_thread.start()
def handle_request(self, data):
conn = sqlite3.connect(self.db_file)
cur = conn.cursor()
question = dns.message.from_wire(data)
response = self.build_response(question, cur)
return response
def build_response(self, question, dbcursor, rcode=dns.rcode.NOERROR, answer=None):
# Create a new DNS message object
response = dns.message.Message()
# Set the message header fields
response.id = question.id
response.flags = dns.flags.QR | dns.flags.RA
# Add the question to the message
response.question = question.question
name = question.question[0].name
# search domain in database
dbcursor.execute(
"SELECT ip FROM xiaomiandns WHERE domain = ?", (str(name)[:-1],))
result = dbcursor.fetchone()
# Create a new RRset for the answer
if result is not None:
answer = dns.rrset.RRset(name, dns.rdataclass.IN, dns.rdatatype.A)
rdata = dns.rdata.from_text(
dns.rdataclass.IN, dns.rdatatype.A, result[0])
answer.add(rdata)
response.answer.append(answer)
# Set the response code
response.set_rcode(rcode)
else:
response.set_rcode(dns.rcode.NXDOMAIN)
return response.to_wire()
class DNSAPI:
"""
usage: use POST method
/add
data: domain=xxxx&ip=xx.xx.xx.xx
/delete
data: domain=xxxx&ip=xx.xx.xx.xx
"""
def __init__(self, hostname, port, db_file):
self.hostname = hostname
self.port = port
self.db_file = db_file
def run(self):
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# bind the IP address and port
server_socket.bind((self.hostname, self.port))
# listen for incoming connections
server_socket.listen(5)
print(f"API server running on {self.hostname}:{self.port}")
while True:
# accept a connection
conn, addr = server_socket.accept()
# handle the request in a new thread
t = threading.Thread(target=self.handle_tcp_request, args=(conn,))
t.start()
def handle_tcp_request(self, conn):
request = conn.recv(1024).decode('utf-8')
response = self.handle_http_request(request)
conn.send(response)
conn.close()
def handle_http_request(self, request):
request_line, headers = request.split('\r\n\r\n', 2)
method, url, version = request_line.split(' ', 2)
if method == 'GET':
response = self.handle_get_request(url)
elif method == 'POST':
data = request.split('\r\n')[-1]
response = self.handle_post_request(url, data)
else:
response = self.handle_error_request()
return response
def handle_get_request(self, url):
status_code = 400
reason_phrase = 'unsupported method, please use POST'
response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase)
return response.encode("utf-8")
def handle_post_request(self, url:str, data:str)->str:
"""处理 POST 请求
Args:
url (str): url前缀
data (str): POST 方法提交的数据
Returns:
str: http response
"""
# check url start with /add
if re.match(r'^/add', url):
status_code = self.add_data(data)
if status_code == 200:
reason_phrase = 'Add data successful'
else:
reason_phrase = 'Add data unsuccessful'
# check url start with /delete
elif re.match(r'^/delete', url):
status_code = self.delete_data(data)
if status_code == 200:
reason_phrase = 'Delete data successful'
else:
reason_phrase = 'Delete data unsuccessful'
else:
status_code = 400
reason_phrase = 'unsupported api'
response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase)
return response.encode("utf-8")
def handle_error_request(self):
status_code = 400
reason_phrase = "unsupport method"
headers = {
'Content-Type': 'text/html',
'Connection': 'close',
}
response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase)
for key, value in headers.items():
response += '{}: {}\r\n'.format(key, value)
response += '\r\n'
return response.encode("utf-8")
def add_data(self, data:str)->int:
"""add data to database
Args:
data (str): domain and ip
Returns:
int: status code
"""
# parse and check validation
domain, ip = self.parse_data(data)
if not self.check_data(domain,ip):
return 400
# connect db
conn = sqlite3.connect(self.db_file)
c = conn.cursor()
# Check if the domain already exists
c.execute(
"SELECT * FROM xiaomiandns WHERE domain = ?", [domain])
existing_domain = c.fetchall()
if existing_domain:
c.execute("UPDATE xiaomiandns SET ip = ?, timestamp = DATETIME('now') WHERE domain = ?",(ip,domain))
else:
# Insert the new data
c.execute(
"INSERT INTO xiaomiandns (domain,ip,timestamp) VALUES (?,?,DATETIME('now'))", (domain, ip))
conn.commit()
c.close()
conn.close()
status_code = 200
return status_code
def delete_data(self, data:str) -> int:
"""delete record in database
Args:
data (str): domain and ip
Returns:
int: status code
"""
# parse and check validation
domain, ip = self.parse_data(data)
if not self.check_data(domain, ip):
return 400
# connect db
conn = sqlite3.connect(self.db_file)
c = conn.cursor()
c.execute(
"DELETE FROM xiaomiandns WHERE domain = ? ", [domain])
deleted_rows = c.rowcount
if deleted_rows == 0:
# unmatch
status_code = 400
else:
# deleted
status_code = 200
conn.commit()
c.close()
conn.close()
return status_code
def parse_data(self, data:str) -> tuple:
"""parse domain and ip from the post data
Args:
data (str): post data
Returns:
domain: domain name
ip: ip address
"""
domain = re.search(r'domain=([^&]+)', data)
ip = re.search(r'ip=([^&]+)', data)
if domain and ip:
domain = domain.group(1)
ip = ip.group(1)
return domain, ip
# no match: return empty strings so check_data() simply rejects the request
return '', ''
def check_data(self, domain, ip):
"""Validate a domain/ip pair submitted through the API.
Args:
domain (str): domain name; must be a single label under the xiaomian root
ip (str): dotted-quad IPv4 address
Returns:
bool: True if both values are valid, otherwise False
"""
# check domain
domain_pattern = r'^[a-z0-9]+\.xiaomian$'
# check ip
ip_pattern = r'^(\d{1,3}\.){3}\d{1,3}$'
if re.match(domain_pattern, domain) and re.match(ip_pattern, ip) :
octets = ip.split('.')
if all(int(octet) < 256 for octet in octets):
return True
else:
return False
return False
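# Illustrative behaviour of check_data (hypothetical instance `api`):
#   api.check_data("example.xiaomian", "192.168.1.1")  -> True
#   api.check_data("example.com", "192.168.1.1")       -> False (root domain must be xiaomian)
#   api.check_data("example.xiaomian", "999.1.1.1")    -> False (octet out of range)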
if __name__ == '__main__':
with open('serverconf.yaml', 'r') as f:
config = yaml.safe_load(f)
db_file = config['database']['db_file']
DNS_port = config['DNS']['port']
DNS_listen_host = config['DNS']['listen_host']
API_port = config['API']['port']
API_listen_host = config['API']['listen_host']
# start dns server
server = DNSServer(DNS_listen_host, DNS_port, db_file)
server.run()
# start dns api server
APIserver = DNSAPI(API_listen_host, API_port, db_file)
APIserver.run()
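To verify the DNS side end to end, a record added through the API can be queried with dnspython (already listed in requirements.txt). A minimal sketch, assuming the server runs locally on the default port (not part of the repository):

```python
import dns.message
import dns.query

# Build an A query for a test record and send it to the DNS server over UDP.
query = dns.message.make_query("example.xiaomian", "A")
response = dns.query.udp(query, "127.0.0.1", port=53, timeout=3)
print(response.answer)  # expected: an A record pointing at the registered IP
```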

View File

@@ -1,3 +0,0 @@
[settings]
server_address = 60.204.236.38:8000
version = 1.0

View File

@@ -1,475 +0,0 @@
from fastapi import FastAPI, HTTPException
import requests
import os
from typing import Tuple
from tpre import (
GenerateKeyPair,
Encrypt,
DecryptFrags,
GenerateReKey,
MergeCFrag,
point,
)
import sqlite3
from contextlib import asynccontextmanager
from pydantic import BaseModel
import socket
import random
import json
import pickle
from fastapi.responses import JSONResponse
import asyncio
# test messages
test_msessgaes = {
"name": b"proxy re-encryption",
"environment": b"distributed environment",
}
@asynccontextmanager
async def lifespan(_: FastAPI):
init()
yield
clean_env()
app = FastAPI(lifespan=lifespan)
def init():
global pk, sk, server_address
init_db()
pk, sk = GenerateKeyPair()
# load config from config file
init_config()
get_node_list(2, server_address) # type: ignore
def init_db():
with sqlite3.connect("client.db") as db:
# message table
db.execute(
"""
CREATE TABLE IF NOT EXISTS message (
id INTEGER PRIMARY KEY,
capsule BLOB,
ct TEXT,
senderip TEXT
);
"""
)
# node ip table
db.execute(
"""
CREATE TABLE IF NOT EXISTS node (
id INTEGER PRIMARY KEY,
nodeip TEXT
);
"""
)
# sender info table
db.execute(
"""
CREATE TABLE IF NOT EXISTS senderinfo (
id INTEGER PRIMARY KEY,
ip TEXT,
pkx TEXT,
pky TEXT,
threshold INTEGER
)
"""
)
db.commit()
print("Init Database Successful")
# load config from config file
def init_config():
global server_address
server_address = os.environ.get("server_address")
# execute on exit
def clean_env():
global message, node_response
message = b""
node_response = False
with sqlite3.connect("client.db") as db:
db.execute("DELETE FROM node")
db.execute("DELETE FROM message")
db.execute("DELETE FROM senderinfo")
db.commit()
print("Exit app")
# main page
@app.get("/")
async def read_root():
return {"message": "Hello, World!"}
class C(BaseModel):
Tuple: Tuple[Tuple[Tuple[int, int], Tuple[int, int], int, Tuple[int, int]], int]
ip: str
# receive messages from nodes
@app.post("/receive_messages")
async def receive_messages(message: C):
"""
receive capsule and ip from nodes
params:
Tuple: capsule and ct
ip: sender ip
return:
status_code
"""
print(f"Received message: {message}")
if not message.Tuple or not message.ip:
print("Invalid input data received.")
raise HTTPException(status_code=400, detail="Invalid input data")
C_capsule, C_ct = message.Tuple
ip = message.ip
# Serialization
bin_C_capsule = pickle.dumps(C_capsule)
# insert record into database
with sqlite3.connect("client.db") as db:
try:
db.execute(
"""
INSERT INTO message
(capsule, ct, senderip)
VALUES
(?, ?, ?)
""",
(bin_C_capsule, str(C_ct), ip),
)
db.commit()
print("Data inserted successfully into database.")
check_merge(C_ct, ip)
return HTTPException(status_code=200, detail="Message received")
except Exception as e:
print(f"Error occurred: {e}")
db.rollback()
return HTTPException(status_code=400, detail="Database error")
# check record count
def check_merge(ct: int, ip: str):
global sk, pk, node_response, message
"""
CREATE TABLE IF NOT EXISTS senderinfo (
id INTEGER PRIMARY KEY,
ip TEXT,
pkx TEXT,
pky TEXT,
threshold INTEGER
)
"""
with sqlite3.connect("client.db") as db:
# Check if the combination of ct_column and ip_column appears more than once.
cursor = db.execute(
"""
SELECT capsule, ct
FROM message
WHERE ct = ? AND senderip = ?
""",
(str(ct), ip),
)
# [(capsule, ct), ...]
cfrag_cts = cursor.fetchall()
# get _sender_pk
cursor = db.execute(
"""
SELECT pkx, pky
FROM senderinfo
WHERE ip = ?
""",
(ip,),
)
result = cursor.fetchall()
try:
pkx, pky = result[0] # result[0] = (pkx, pky)
pk_sender = (int(pkx), int(pky))
except IndexError:
pk_sender, T = 0, -1
T = 2
if len(cfrag_cts) >= T:
# Deserialization
temp_cfrag_cts = []
for i in cfrag_cts:
capsule = pickle.loads(i[0])
byte_length = (ct.bit_length() + 7) // 8
temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(byte_length)))
cfrags = MergeCFrag(temp_cfrag_cts)
print("sk:", type(sk))
print("pk:", type(pk))
print("pk_sender:", type(pk_sender))
print("cfrags:", type(cfrags))
message = DecryptFrags(sk, pk, pk_sender, cfrags) # type: ignore
print("merge success", message)
node_response = True
print("merge:", node_response)
# send message to node
async def send_messages(
node_ips: tuple[str, ...], message: bytes, dest_ip: str, pk_B: point, shreshold: int
):
global pk, sk
id_list = []
# calculate id of nodes
for node_ip in node_ips:
node_ip = node_ip[0]
ip_parts = node_ip.split(".")
id = 0
for i in range(4):
id += int(ip_parts[i]) << (24 - (8 * i))
id_list.append(id)
print(f"Calculated IDs: {id_list}")
# generate rk
rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore
print(f"Generated ReKey list: {rk_list}")
capsule, ct = Encrypt(pk, message) # type: ignore
# capsule_ct = (capsule, int.from_bytes(ct))
print(f"Encrypted message to capsule={capsule}, ct={ct}")
for i in range(len(node_ips)):
url = "http://" + node_ips[i][0] + ":8001" + "/user_src"
payload = {
"source_ip": local_ip,
"dest_ip": dest_ip,
"capsule": capsule,
"ct": int.from_bytes(ct),
"rk": rk_list[i],
}
print(f"Sending payload to {url}: {json.dumps(payload)}")
response = requests.post(url, json=payload)
if response.status_code == 200:
print(f"send to {node_ips[i]} successful")
else:
print(
f"Failed to send to {node_ips[i]}. Response code: {response.status_code}, Response text: {response.text}"
)
return 0
class IP_Message(BaseModel):
dest_ip: str
message_name: str
source_ip: str
pk: Tuple[int, int]
class Request_Message(BaseModel):
dest_ip: str
message_name: str
# request message from others
@app.post("/request_message")
async def request_message(i_m: Request_Message):
global message, node_response, pk
print(
f"Function 'request_message' called with: dest_ip={i_m.dest_ip}, message_name={i_m.message_name}"
)
dest_ip = i_m.dest_ip
# dest_ip = dest_ip.split(":")[0]
message_name = i_m.message_name
source_ip = get_own_ip()
dest_port = "8002"
url = "http://" + dest_ip + ":" + dest_port + "/receive_request"
payload = {
"dest_ip": dest_ip,
"message_name": message_name,
"source_ip": source_ip,
"pk": pk,
}
print(f"Sending request to {url} with payload: {payload}")
try:
response = requests.post(url, json=payload, timeout=1)
print(f"Response received from {url}: {response.text}")
# print("menxian and pk", response.text)
except requests.Timeout:
print("Timeout error: can't post to the destination.")
# print("can't post")
# content = {"message": "post timeout", "error": str(e)}
# return JSONResponse(content, status_code=400)
# wait 3s to receive message from nodes
for _ in range(10):
print(f"Waiting for node_response... Current value: {node_response}")
# print("wait:", node_response)
if node_response:
data = message
print(f"Node response received with message: {data}")
# reset message and node_response
message = b""
node_response = False
# return message to frontend
return {"message": str(data)}
await asyncio.sleep(0.2)
print("Timeout while waiting for node_response.")
content = {"message": "receive timeout"}
return JSONResponse(content, status_code=400)
# receive request from others
@app.post("/receive_request")
async def receive_request(i_m: IP_Message):
global pk
print(
f"Function 'receive_request' called with: dest_ip={i_m.dest_ip}, source_ip={i_m.source_ip}, pk={i_m.pk}"
)
source_ip = get_own_ip()
print(f"Own IP: {source_ip}")
if source_ip != i_m.dest_ip:
print("Mismatch in destination IP.")
return HTTPException(status_code=400, detail="Wrong ip")
dest_ip = i_m.source_ip
# threshold = random.randrange(1, 2)
threshold = 2
own_public_key = pk
pk_B = i_m.pk
print(f"Using own public key: {own_public_key} and received public key: {pk_B}")
with sqlite3.connect("client.db") as db:
cursor = db.execute(
"""
SELECT nodeip
FROM node
LIMIT ?
""",
(threshold,),
)
node_ips = cursor.fetchall()
print(f"Selected node IPs from database: {node_ips}")
# message name
# message_name = i_m.message_name
# message = xxxxx
# look up the corresponding value in the test messages by message name
try:
message = test_msessgaes[i_m.message_name]
except IndexError:
message = b"hello world" + random.randbytes(8)
print(f"Message to be send: {message}")
# send message to nodes
await send_messages(tuple(node_ips), message, dest_ip, pk_B, threshold)
response = {"threshold": threshold, "public_key": own_public_key}
print(f"Sending response: {response}")
return response
def get_own_ip() -> str:
ip = os.environ.get("HOST_IP")
if not ip:  # no IP in the environment variables
try:
# get the IP from the network interface
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))  # connect to Google DNS to learn the local IP
ip = s.getsockname()[0]
s.close()
except IndexError:
raise ValueError("Unable to get IP")
return str(ip)
# get node list from central server
def get_node_list(count: int, server_addr: str):
url = "http://" + server_addr + "/server/send_nodes_list?count=" + str(count)
response = requests.get(url, timeout=3)
# Checking the response
if response.status_code == 200:
print("Success get node list")
node_ip = response.text
node_ip = eval(node_ip)
print(node_ip)
# insert node ip to database
with sqlite3.connect("client.db") as db:
db.executemany(
"""
INSERT INTO node
(nodeip)
VALUES (?)
""",
[(ip,) for ip in node_ip],
)
db.commit()
print("Success add node ip")
else:
print("Failed:", response.status_code, response.text)
# send pk to others
@app.get("/get_pk")
async def get_pk():
global pk, sk
print(sk)
return {"pkx": str(pk[0]), "pky": str(pk[1])}
class pk_model(BaseModel):
pkx: str
pky: str
ip: str
# recieve pk from frontend
@app.post("/recieve_pk")
async def recieve_pk(pk: pk_model):
pkx = pk.pkx
pky = pk.pky
dest_ip = pk.ip
try:
threshold = 2
with sqlite3.connect("client.db") as db:
db.execute(
"""
INSERT INTO senderinfo
(ip, pkx, pky, threshold)
VALUES
(?, ?, ?, ?)
""",
(str(dest_ip), pkx, pky, threshold),
)
except Exception as e:
# raise error
print("Database error")
content = {"message": "Database Error", "error": str(e)}
return JSONResponse(content, status_code=400)
return {"message": "save pk in database"}
pk = (0, 0)
sk = 0
server_address = os.environ.get("server_address")
node_response = False
message = bytes
local_ip = get_own_ip()
if __name__ == "__main__":
import uvicorn # pylint: disable=e0401
uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True, log_level="debug")

View File

@@ -1,40 +0,0 @@
import argparse
import requests
import json
def send_post_request(ip_addr, message_name):
url = "http://localhost:8002/request_message"
data = {"dest_ip": ip_addr, "message_name": message_name}
response = requests.post(url, json=data)
return response.text
def get_pk(ip_addr):
url = "http://" + ip_addr + ":8002/get_pk"
response = requests.get(url, timeout=1)
print(response.text)
json_pk = json.loads(response.text)
payload = {"pkx": json_pk["pkx"], "pky": json_pk["pky"], "ip": ip_addr}
response = requests.post("http://localhost:8002/recieve_pk", json=payload)
return response.text
def main():
parser = argparse.ArgumentParser(description="Send POST request to a specified IP.")
parser.add_argument("ip_addr", help="IP address to send request to.")
parser.add_argument("message_name", help="Message name to send.")
args = parser.parse_args()
response = get_pk(args.ip_addr)
print(response)
response = send_post_request(args.ip_addr, args.message_name)
print(response)
if __name__ == "__main__":
main()

View File

@@ -1,93 +0,0 @@
from tpre import *
import time
import openpyxl
# initialize the Excel workbook and worksheet
wb = openpyxl.Workbook()
ws = wb.active
ws.title = "Algorithm performance results"
headers = [
"Threshold N",
"Threshold T",
"Key generation time",
"Encryption time",
"Re-encryption key generation time",
"Re-encryption time",
"Decryption time",
"Total algorithm time",
]
ws.append(headers)
for N in range(4, 21, 2):
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
start_total_time = time.time()
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world"
end_time = time.time()
elapsed_time_key_gen = end_time - start_time
print(f"密钥生成运行时间:{elapsed_time_key_gen}")
# ... [中间代码不变]
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time_enc = end_time - start_time
print(f"加密算法运行时间:{elapsed_time_enc}")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time_rekey_gen = end_time - start_time
print(f"重加密密钥生成算法运行时间:{elapsed_time_rekey_gen}")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
re_elapsed_time = (end_time - start_time) / len(rekeys)
print(f"重加密算法运行时间:{re_elapsed_time}")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time_dec = end_time - start_time
end_total_time = time.time()
total_time = end_total_time - start_total_time - re_elapsed_time * len(rekeys)
print(f"解密算法运行时间:{elapsed_time_dec}")
print("成功解密:", m)
print(f"算法总运行时间:{total_time}")
print()
# 将结果保存到Excel
ws.append(
[
N,
T,
elapsed_time_key_gen,
elapsed_time_enc,
elapsed_time_rekey_gen,
re_elapsed_time,
elapsed_time_dec,
total_time,
]
)
# save the Excel file
wb.save("结果.xlsx")

View File

@@ -1,63 +0,0 @@
from web3 import Web3
import json
rpc_url = "https://ethereum-holesky-rpc.publicnode.com"
chain = Web3(Web3.HTTPProvider(rpc_url))
contract_address = "0x642C23F91bf8339044A00251BC09d1D98110C433"
contract_abi = json.loads(
"""[
{
"anonymous": false,
"inputs": [
{
"indexed": false,
"internalType": "string",
"name": "msg",
"type": "string"
}
],
"name": "messageLog",
"type": "event"
},
{
"inputs": [
{
"internalType": "string",
"name": "text",
"type": "string"
}
],
"name": "logmessage",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "nonpayable",
"type": "function"
}
]"""
)
contract = chain.eth.contract(address=contract_address, abi=contract_abi)
wallet_address = "0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6"
pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
def call_eth_logger(wallet_address, pk, message: str):
transaction = contract.functions.logmessage(message).build_transaction(
{
"chainId": 17000,
"gas": 30000,
"gasPrice": chain.to_wei("10", "gwei"),
"nonce": chain.eth.get_transaction_count(wallet_address, "pending"),
}
)
signed_tx = chain.eth.account.sign_transaction(transaction, private_key=pk)
tx_hash = chain.eth.send_raw_transaction(signed_tx.raw_transaction)
print(tx_hash)
receipt = chain.eth.wait_for_transaction_receipt(tx_hash)
transfer_event = contract.events.messageLog().process_receipt(receipt)
for event in transfer_event:
print(event["args"]["msg"])

View File

@@ -1,8 +0,0 @@
from eth_logger import call_eth_logger
wallet_address = (
"0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6"  # replace with the wallet address/private key you want to use
)
wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
call_eth_logger(wallet_address, wallet_pk, "hello World")

View File

@@ -1,10 +0,0 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract logger{
event messageLog(string msg);
function logmessage(string memory text) public returns (bool){
emit messageLog(text);
return true;
}
}

View File

@@ -1,196 +0,0 @@
import asyncio
import json
import logging
import os
import socket
import threading
import time
from contextlib import asynccontextmanager
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from eth_logger import call_eth_logger
from tpre import ReEncrypt, capsule
@asynccontextmanager
async def lifespan(_: FastAPI):
init()
yield
clear()
message_list = []
app = FastAPI(lifespan=lifespan)
server_address = os.environ.get("server_address")
id = 0
ip = ""
client_ip_src = "" # 发送信息用户的ip
client_ip_des = "" # 接收信息用户的ip
processed_message = () # 重加密后的数据
logger = logging.getLogger("uvicorn")
# class C(BaseModel):
# Tuple: Tuple[capsule, int]
# ip_src: str
# send our own IP address to the central server and obtain our node id
def send_ip():
url = server_address + "/get_node?ip=" + ip # type: ignore
# ip = get_local_ip() # type: ignore
global id
id = requests.get(url, timeout=3)
logger.info(f"Node ID returned by the central server: {id}")
print("Node ID returned by the central server: ", id)
# get the local IP, preferring the environment variable
def get_local_ip() -> str | None:
ip = os.environ.get("HOST_IP")
if not ip:  # no IP in the environment variables
try:
# get the IP from the network interface
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))  # connect to Google DNS to learn the local IP
ip = str(s.getsockname()[0])
s.close()
return ip
except IndexError:
raise ValueError("Unable to get IP")
else:
return ip
def init():
global ip
ip = get_local_ip()
send_ip()
asyncio.create_task(send_heartbeat_internal())
print("Finish init")
def clear():
print("exit node")
# receive a message from one user, process it, then forward it to the other user
async def send_heartbeat_internal() -> None:
timeout = 30
global ip
url = server_address + "/heartbeat?ip=" + ip # type: ignore
while True:
# print('successful send my_heart')
try:
requests.get(url, timeout=3)
except requests.exceptions.RequestException:
logger.error("Central server error")
print("Central server error")
# nodes that time out are removed by the server
await asyncio.sleep(timeout)
class Req(BaseModel):
source_ip: str
dest_ip: str
capsule: capsule
ct: int
rk: list
@app.post("/user_src") # 接收用户1发送的信息
async def user_src(message: Req):
global client_ip_src, client_ip_des
print(
f"Function 'user_src' called with: source_ip={message.source_ip}, dest_ip={message.dest_ip}, capsule={message.capsule}, ct={message.ct}, rk={message.rk}"
)
# kfrag , capsule_ct ,client_ip_src , client_ip_des = json_data[]
"""
payload = {
"source_ip": local_ip,
"dest_ip": dest_ip,
"capsule_ct": capsule_ct,
"rk": rk_list[i],
}
"""
logger.info(f"node: {message}")
print("node: ", message)
source_ip = message.source_ip
dest_ip = message.dest_ip
capsule = message.capsule
ct = message.ct
payload = {
"source_ip": source_ip,
"dest_ip": dest_ip,
"capsule": capsule,
"ct": ct,
"rk": message.rk,
}
# record the message details on the blockchain
global message_list
message_list.append(payload)
byte_length = (ct.bit_length() + 7) // 8
capsule_ct = (capsule, ct.to_bytes(byte_length))
rk = message.rk
logger.info(f"Computed capsule_ct: {capsule_ct}")
print(f"Computed capsule_ct: {capsule_ct}")
a, b = ReEncrypt(rk, capsule_ct) # type: ignore
processed_message = (a, int.from_bytes(b))
logger.info(f"Re-encrypted message: {processed_message}")
print(f"Re-encrypted message: {processed_message}")
await send_user_des_message(source_ip, dest_ip, processed_message)
logger.info("Message sent to destination user.")
print("Message sent to destination user.")
return HTTPException(status_code=200, detail="message recieved")
async def send_user_des_message(
source_ip: str, dest_ip: str, re_message
):  # send the message to user 2
data = {"Tuple": re_message, "ip": source_ip}  # types do not quite match
# send the HTTP POST request
response = requests.post(
"http://" + dest_ip + ":8002" + "/receive_messages", json=data
)
logger.info(f"send status: {response.text}")
print("send status:", response.text)
def log_message():
while True:
global message_list
payload = json.dumps(message_list)
message_list = []
call_eth_logger(wallet_address, wallet_pk, payload)
time.sleep(2)
wallet_address = (
"0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6"  # replace with the wallet address/private key you want to use
)
wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
if __name__ == "__main__":
import uvicorn
threading.Thread(target=log_message).start()
uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True, log_level="debug")

View File

@@ -1,148 +0,0 @@
12th Gen Intel(R) Core(TM) i5-12490F

0.01 cores (N = 20, T = 10): key generation 0.9953160285949707 s; encryption 0.006381511688232422 s; re-encryption key generation 12.903082609176636 s; re-encryption 0.989858603477478 s; decryption 19.69758915901184 s; decrypted successfully: b'hello world'; total 34.59222791194916 s

0.1 cores (N = 20, T = 10): key generation 0.004191160202026367 s; encryption 0.09498381614685059 s; re-encryption key generation 0.7050104141235352 s; re-encryption 0.09499671459197997 s; decryption 1.5005874633789062 s; decrypted successfully: b'hello world'; total 2.3997695684432983 s

1 core (N = 20, T = 10): key generation 0.0030488967895507812 s; encryption 0.005570888519287109 s; re-encryption key generation 0.07791781425476074 s; re-encryption 0.006881630420684815 s; decryption 0.08786344528198242 s; decrypted successfully: b'hello world'; total 0.18128267526626587 s

4 cores (N = 20, T = 10): key generation 0.0026373863220214844 s; encryption 0.004965305328369141 s; re-encryption key generation 0.07313323020935059 s; re-encryption 0.006896591186523438 s; decryption 0.08880448341369629 s; decrypted successfully: b'hello world'; total 0.17643699645996094 s

rk3399

0.01 cores (N = 20, T = 10): key generation 3.9984750747680664 s; encryption 9.599598169326782 s; re-encryption key generation 132.99906015396118 s; re-encryption 12.120013177394867 s; decryption 153.29800581932068 s; decrypted successfully: b'hello world'; total 312.0151523947716 s

0.1 cores (N = 20, T = 10): key generation 0.09907650947570801 s; encryption 0.205247163772583 s; re-encryption key generation 7.498294830322266 s; re-encryption 0.7300507187843323 s; decryption 8.998314619064331 s; decrypted successfully: b'hello world'; total 17.53098384141922 s

1 core (N = 20, T = 10): key generation 0.008650541305541992 s; encryption 0.02130866050720215 s; re-encryption key generation 0.30187034606933594 s; re-encryption 0.0274674654006958 s; decryption 0.3521096706390381 s; decrypted successfully: b'hello world'; total 0.7114066839218139 s

4 cores (N = 20, T = 10): key generation 0.00883340835571289 s; encryption 0.021309614181518555 s; re-encryption key generation 0.3036940097808838 s; re-encryption 0.0277299165725708 s; decryption 0.3464491367340088 s; decrypted successfully: b'hello world'; total 0.7080160856246949 s

12th Gen Intel(R) Core(TM) i5-12490F (N = 20, T = 10, varying plaintext length; decryption succeeded in every run)

plaintext length 110: key generation 0.0033216476440429688 s; encryption 0.00811624526977539 s; re-encryption key generation 0.11786699295043945 s; re-encryption 0.009650790691375732 s; decryption 0.12125396728515625 s; total 0.2602096438407898 s

plaintext length 1100: key generation 0.0034990310668945312 s; encryption 0.008537054061889648 s; re-encryption key generation 0.10165071487426758 s; re-encryption 0.009756004810333252 s; decryption 0.13438773155212402 s; total 0.257830536365509 s

plaintext length 11000: key generation 0.0035142898559570312 s; encryption 0.005819797515869141 s; re-encryption key generation 0.1130058765411377 s; re-encryption 0.010429942607879638 s; decryption 0.1242990493774414 s; total 0.25706895589828493 s

plaintext length 110000: key generation 0.002706289291381836 s; encryption 0.00833749771118164 s; re-encryption key generation 0.11022734642028809 s; re-encryption 0.010864639282226562 s; decryption 0.13867974281311035 s; total 0.2708155155181885 s

plaintext length 1100000: key generation 0.003710031509399414 s; encryption 0.04558920860290527 s; re-encryption key generation 0.10261368751525879 s; re-encryption 0.009720635414123536 s; decryption 0.15311646461486816 s; total 0.3147500276565552 s

plaintext length 11000000: key generation 0.008045673370361328 s; encryption 0.3575568199157715 s; re-encryption key generation 0.09267783164978027 s; re-encryption 0.009347784519195556 s; decryption 0.4754812717437744 s; total 0.9431093811988831 s

N = 94, T = 47: total 0.967951292687274 s
N = 95, T = 47: total 0.9765587304767809 s
N = 96, T = 48: total 1.019304744899273 s

View File

@@ -1,2 +0,0 @@
[pytest]
asyncio_default_fixture_loop_scope = function

View File

@@ -1,228 +0,0 @@
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
import sqlite3
import asyncio
import time
import ipaddress
import logging
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(_: FastAPI):
init()
yield
clean_env()
app = FastAPI(lifespan=lifespan)
def init():
asyncio.create_task(receive_heartbeat_internal())
init_db()
def init_db():
conn = sqlite3.connect("server.db")
cursor = conn.cursor()
# init table: id: int; ip: TEXT
cursor.execute(
"""CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
last_heartbeat INTEGER
)"""
)
def clean_env():
clear_database()
# -----------------------------------------------------------------------------------------------
@app.get("/")
async def home():
return {"message": "Hello, World!"}
@app.get("/server/show_nodes")
async def show_nodes() -> list:
nodes_list = []
with sqlite3.connect("server.db") as db:
# query all node rows
cursor = db.execute("SELECT * FROM nodes")
rows = cursor.fetchall()
for row in rows:
nodes_list.append(row)
# TODO: use JSONResponse
return nodes_list
def validate_ip(ip: str) -> bool:
"""
Validate an IP address.
This function checks if the provided string is a valid IP address.
Both IPv4 and IPv6 are considered valid.
Args:
ip (str): The IP address to validate.
Returns:
bool: True if the IP address is valid, False otherwise.
"""
try:
ipaddress.ip_address(ip)
return True
except ValueError:
return False
@app.get("/server/get_node")
async def get_node(ip: str) -> int:
"""
Node-registration endpoint: the node sends its IP, the central server stores it in the database, converts the IP to an integer, and returns that integer to the node as its ID.
params:
ip: node ip
return:
id: the IP is split on dots; each part is converted to binary, the parts are concatenated, and the result is converted back to decimal to form the node ID
"""
if not validate_ip(ip):
content = {"message": "invalid ip "}
return JSONResponse(content, status_code=400) # type: ignore
ip_parts = ip.split(".")
ip_int = 0
for i in range(4):
ip_int += int(ip_parts[i]) << (24 - (8 * i))
# TODO: replace print with logger
print("IP", ip, "maps to node ID", ip_int)
# get the current time
current_time = int(time.time())
# TODO: replace print with logger
print("current time: ", current_time)
with sqlite3.connect("server.db") as db:
# insert the node record
db.execute(
"INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)",
(ip_int, ip, current_time),
)
db.commit()
# TODO: use JSONResponse
return ip_int
# TODO: try to use @app.delete("/node")
@app.get("/server/delete_node")
async def delete_node(ip: str):
"""
Delete a node by ip.
Args:
ip (str): The ip of the node to be deleted.
"""
with sqlite3.connect("server.db") as db:
# 查询要删除的节点
cursor = db.execute("SELECT * FROM nodes WHERE ip=?", (ip,))
row = cursor.fetchone()
if row is not None:
with sqlite3.connect("server.db") as db:
# 执行删除操作
db.execute("DELETE FROM nodes WHERE ip=?", (ip,))
db.commit()
# TODO: replace print with logger
print(f"Node with IP {ip} deleted successfully.")
return {"message", f"Node with IP {ip} deleted successfully."}
else:
print(f"Node with IP {ip} not found.")
raise HTTPException(status_code=404, detail=f"Node with IP {ip} not found.")
# receive heartbeat packets from nodes
@app.get("/server/heartbeat")
async def receive_heartbeat(ip: str):
"""
Receive a heartbeat from a node.
Args:
ip (str): The IP address of the node.
Returns:
JSONResponse: A message indicating the result of the operation.
"""
if not validate_ip(ip):
content = {"message": "invalid ip format"}
return JSONResponse(content, status_code=400)
print("收到来自", ip, "的心跳包")
logger.info("收到来自", ip, "的心跳包")
with sqlite3.connect("server.db") as db:
db.execute(
"UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)
)
content = {"status": "received"}
return JSONResponse(content, status_code=200)
async def receive_heartbeat_internal():
timeout = 70
while 1:
with sqlite3.connect("server.db") as db:
# remove nodes whose heartbeat has timed out
db.execute(
"DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,)
)
db.commit()
await asyncio.sleep(timeout)
@app.get("/server/send_nodes_list")
async def send_nodes_list(count: int) -> list:
"""
Client-facing endpoint: the client sends the number of nodes it needs, and the central server takes that many nodes from the database in order and returns them as a list.
params:
count: number of nodes requested
return:
nodes_list: list
"""
nodes_list = []
with sqlite3.connect("server.db") as db:
# query the node records from the database
cursor = db.execute("SELECT * FROM nodes LIMIT ?", (count,))
rows = cursor.fetchall()
for row in rows:
# id, ip, last_heartbeat = row
_, ip, _ = row
nodes_list.append(ip)
print("收到来自客户端的节点列表请求...")
print(nodes_list)
# TODO: use JSONResponse
return nodes_list
def clear_database() -> None:
with sqlite3.connect("server.db") as db:
db.execute("DELETE FROM nodes")
db.commit()
if __name__ == "__main__":
import uvicorn
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)

View File

@@ -1,464 +0,0 @@
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
from typing import Tuple
import random
import ecc_rs
point = Tuple[int, int]
capsule = Tuple[point, point, int]
# key-pair generation
class CurveFp:
def __init__(self, A, B, P, N, Gx, Gy, name):
self.A = A
self.B = B
self.P = P
self.N = N
self.Gx = Gx
self.Gy = Gy
self.name = name
sm2p256v1 = CurveFp(
name="sm2p256v1",
A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
# generator point
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def multiply(a: point, n: int, flag: int = 0) -> point:
if flag == 1:
result = ecc_rs.multiply(a, n)
return result
else:
N = sm2p256v1.N
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
def add(a: point, b: point, flag: int = 0) -> point:
if flag == 1:
result = ecc_rs.add(a, b)
return result
else:
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
def inv(a: int, n: int) -> int:
if a == 0:
return 0
lm, hm = 1, 0
low, high = a % n, n
while low > 1:
r = high // low
nm, new = hm - lm * r, high - low * r
lm, low, hm, high = nm, new, lm, low
return lm % n
def toJacobian(Xp_Yp: point) -> Tuple[int, int, int]:
Xp, Yp = Xp_Yp
return (Xp, Yp, 1)
def fromJacobian(Xp_Yp_Zp: Tuple[int, int, int], P: int) -> point:
Xp, Yp, Zp = Xp_Yp_Zp
z = inv(Zp, P)
return ((Xp * z**2) % P, (Yp * z**3) % P)
def jacobianDouble(
Xp_Yp_Zp: Tuple[int, int, int], A: int, P: int
) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp
if not Yp:
return (0, 0, 0)
ysq = (Yp**2) % P
S = (4 * Xp * ysq) % P
M = (3 * Xp**2 + A * Zp**4) % P
nx = (M**2 - 2 * S) % P
ny = (M * (S - nx) - 8 * ysq**2) % P
nz = (2 * Yp * Zp) % P
return (nx, ny, nz)
def jacobianAdd(
Xp_Yp_Zp: Tuple[int, int, int], Xq_Yq_Zq: Tuple[int, int, int], A: int, P: int
) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp
Xq, Yq, Zq = Xq_Yq_Zq
if not Yp:
return (Xq, Yq, Zq)
if not Yq:
return (Xp, Yp, Zp)
U1 = (Xp * Zq**2) % P
U2 = (Xq * Zp**2) % P
S1 = (Yp * Zq**3) % P
S2 = (Yq * Zp**3) % P
if U1 == U2:
if S1 != S2:
return (0, 0, 1)
return jacobianDouble((Xp, Yp, Zp), A, P)
H = U2 - U1
R = S2 - S1
H2 = (H * H) % P
H3 = (H * H2) % P
U1H2 = (U1 * H2) % P
nx = (R**2 - H3 - 2 * U1H2) % P
ny = (R * (U1H2 - nx) - S1 * H3) % P
nz = (H * Zp * Zq) % P
return (nx, ny, nz)
def jacobianMultiply(
Xp_Yp_Zp: Tuple[int, int, int], n: int, N: int, A: int, P: int
) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp
if Yp == 0 or n == 0:
return (0, 0, 1)
if n == 1:
return (Xp, Yp, Zp)
if n < 0 or n >= N:
return jacobianMultiply((Xp, Yp, Zp), n % N, N, A, P)
if (n % 2) == 0:
return jacobianDouble(jacobianMultiply((Xp, Yp, Zp), n // 2, N, A, P), A, P)
if (n % 2) == 1:
return jacobianAdd(
jacobianDouble(jacobianMultiply((Xp, Yp, Zp), n // 2, N, A, P), A, P),
(Xp, Yp, Zp),
A,
P,
)
raise ValueError("jacobian Multiply error")
# 生成元
U = multiply(g, random.randint(0, sm2p256v1.N - 1))
def hash2(double_G: Tuple[point, point]) -> int:
sm3 = Sm3() # pylint: disable=e0602
for i in double_G:
for j in i:
sm3.update(j.to_bytes(32))
digest = sm3.digest()
digest = int.from_bytes(digest, "big") % sm2p256v1.N
return digest
def hash3(triple_G: Tuple[point, point, point]) -> int:
sm3 = Sm3() # pylint: disable=e0602
for i in triple_G:
for j in i:
sm3.update(j.to_bytes(32))
digest = sm3.digest()
digest = int.from_bytes(digest, "big") % sm2p256v1.N
return digest
def hash4(triple_G: Tuple[point, point, point], Zp: int) -> int:
sm3 = Sm3() # pylint: disable=e0602
for i in triple_G:
for j in i:
sm3.update(j.to_bytes(32))
sm3.update(Zp.to_bytes(32))
digest = sm3.digest()
digest = int.from_bytes(digest, "big") % sm2p256v1.N
return digest
def KDF(G: point) -> int:
sm3 = Sm3() # pylint: disable=e0602
for i in G:
sm3.update(i.to_bytes(32))
digest = sm3.digest()
digest = int.from_bytes(digest, "big") % sm2p256v1.N
mask_128bit = (1 << 128) - 1
digest = digest & mask_128bit
return digest
def GenerateKeyPair() -> Tuple[point, int]:
"""
return:
public_key, secret_key
"""
sm2 = Sm2Key() # pylint: disable=e0602
sm2.generate_key()
public_key_x = int.from_bytes(bytes(sm2.public_key.x), "big")
public_key_y = int.from_bytes(bytes(sm2.public_key.y), "big")
public_key = (public_key_x, public_key_y)
secret_key = int.from_bytes(bytes(sm2.private_key), "big")
return public_key, secret_key
def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
enca = Encapsulate(pk)
K = enca[0].to_bytes(16)
capsule = enca[1]
if len(K) != 16:
raise ValueError("invalid key length")
iv = b"tpretpretpretpre"
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
enc_Data = sm4_enc.update(m)
enc_Data += sm4_enc.finish()
enc_message = (capsule, bytes(enc_Data))
return enc_message
def Decapsulate(ska: int, capsule: capsule) -> int:
# E, V, s = capsule
E, V, _ = capsule
EVa = multiply(add(E, V), ska) # (E*V)^ska
K = KDF(EVa)
return K
def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
"""
params:
sk_A: secret key
C: (capsule, enc_data)
"""
capsule, enc_Data = C
K = Decapsulate(sk_A, capsule)
iv = b"tpretpretpretpre"
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish()
return bytes(dec_Data)
# GenerateRekey
def hash5(id: int, D: int) -> int:
sm3 = Sm3() # pylint: disable=e0602
sm3.update(id.to_bytes(32))
sm3.update(D.to_bytes(32))
hash = sm3.digest()
hash = int.from_bytes(hash, "big") % sm2p256v1.N
return hash
def hash6(triple_G: Tuple[point, point, point]) -> int:
sm3 = Sm3() # pylint: disable=e0602
for i in triple_G:
for j in i:
sm3.update(j.to_bytes(32))
hash = sm3.digest()
hash = int.from_bytes(hash, "big") % sm2p256v1.N
return hash
def f(x: int, f_modulus: list, T: int) -> int:
"""
Purpose: disperse and reconstruct a secret via polynomial interpolation.
Example: take the random polynomial f(x) = 4x + 5 with prime P = 11, so f(0) = 5. Give the shares to two people, e.g. (1, 9) to the first and (2, 2) to the second. Once both points are collected, Lagrange interpolation recovers the polynomial and hence the secret "5".
param:
x, f_modulus (list of polynomial coefficients), T (threshold)
return:
res
"""
res = 0
for i in range(T):
res += f_modulus[i] * pow(x, i)
res = res % sm2p256v1.N
return res
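# Worked example of the docstring above (illustrative only):
# f(x) = 4x + 5 over P = 11 yields the shares (1, 9) and (2, 2).
# Lagrange interpolation at x = 0:
#   lambda_1 = x2 / (x2 - x1) = 2 / (2 - 1) = 2
#   lambda_2 = x1 / (x1 - x2) = 1 / (1 - 2) = -1 = 10 (mod 11)
#   f(0) = 2 * 9 + 10 * 2 = 38 = 5 (mod 11), recovering the secret.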
def GenerateReKey(
sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...]
) -> list:
"""
param:
skA, pkB, N (total number of nodes), T (threshold)
return:
rki(0 <= i <= N-1)
"""
    # Compute the ephemeral key pair (x_A, X_A)
x_A = random.randint(0, sm2p256v1.N - 1)
X_A = multiply(g, x_A)
pk_A = multiply(g, sk_A)
    # d is the result of a non-interactive Diffie-Hellman key exchange between Bob's key pair and the ephemeral key pair
d = hash3((X_A, pk_B, multiply(pk_B, x_A)))
    # Compute the polynomial coefficients
f_modulus = []
    # Compute f0
# f0 = (sk_A * inv(d, G.P)) % G.P
f0 = (sk_A * inv(d, sm2p256v1.N)) % sm2p256v1.N
f_modulus.append(f0)
    # Compute fi (1 <= i <= T - 1)
for i in range(1, T):
f_modulus.append(random.randint(0, sm2p256v1.N - 1))
    # Compute D
D = hash6((pk_A, pk_B, multiply(pk_B, sk_A)))
    # Compute KF
KF = []
for i in range(N):
# seems unused?
# y = random.randint(0, sm2p256v1.N - 1)
# Y = multiply(g, y)
        s_x = hash5(id_tuple[i], D)  # the node ids need to be set (supplied via id_tuple)
r_k = f(s_x, f_modulus, T)
U1 = multiply(U, r_k)
kFrag = (id_tuple[i], r_k, X_A, U1)
KF.append(kFrag)
return KF
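
# Illustrative sketch (added, not part of the original module): the
# non-interactive Diffie-Hellman symmetry GenerateReKey relies on. Alice derives
# d = hash3((X_A, pk_B, pk_B^x_A)) with her ephemeral secret x_A; Bob later
# recomputes the same d as hash3((X_A, pk_B, X_A^sk_B)) in DecapsulateFrags,
# because pk_B^x_A and X_A^sk_B are the same point g^(x_A * sk_B).
def _rekey_dh_symmetry_example() -> bool:
    pk_b, sk_b = GenerateKeyPair()
    x_a = random.randint(0, sm2p256v1.N - 1)
    X_a = multiply(g, x_a)
    d_alice = hash3((X_a, pk_b, multiply(pk_b, x_a)))
    d_bob = hash3((X_a, pk_b, multiply(X_a, sk_b)))
    return d_alice == d_bob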
def Encapsulate(pk_A: point) -> Tuple[int, capsule]:
r = random.randint(0, sm2p256v1.N - 1)
u = random.randint(0, sm2p256v1.N - 1)
E = multiply(g, r)
V = multiply(g, u)
s = (u + r * hash2((E, V))) % sm2p256v1.N
pk_A_ru = multiply(pk_A, r + u)
K = KDF(pk_A_ru)
capsule = (E, V, s)
return (K, capsule)
def Checkcapsule(capsule: capsule) -> bool:  # verify the validity of a capsule
E, V, s = capsule
h2 = hash2((E, V))
g = (sm2p256v1.Gx, sm2p256v1.Gy)
result1 = multiply(g, s)
    temp = multiply(E, h2)  # intermediate value
result2 = add(V, temp) # result2=V*E^H2(E,V)
if result1 == result2:
flag = True
else:
flag = False
return flag
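
# Illustrative sketch (added, not part of the original module): any capsule
# produced by Encapsulate passes Checkcapsule, since s = u + r * hash2((E, V))
# with E = g^r and V = g^u, so g^s equals V + E * hash2(E, V) on the curve.
def _capsule_check_example() -> bool:
    pk_a, _ = GenerateKeyPair()
    _, cap = Encapsulate(pk_a)
    return Checkcapsule(cap)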
def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]:
# id, rk, Xa, U1 = kFrag
id, rk, Xa, _ = kFrag
# E, V, s = capsule
E, V, _ = capsule
if not Checkcapsule(capsule):
raise ValueError("Invalid capsule")
E1 = multiply(E, rk)
V1 = multiply(V, rk)
cfrag = E1, V1, id, Xa
return cfrag # cfrag=(E1,V1,id,Xa) E1= E^rk V1=V^rk
# Re-encryption function
def ReEncrypt(
kFrag: tuple, C: Tuple[capsule, bytes]
) -> Tuple[Tuple[point, point, int, point], bytes]:
capsule, enc_Data = C
cFrag = ReEncapsulate(kFrag, capsule)
    return (cFrag, enc_Data)  # return the re-encrypted ciphertext
    # capsule, enc_Data = C
# Merge the t (capsule, ct) pairs produced by the re-encryption nodes into cfrags = {{capsule1, capsule2, ...}, ct}
def MergeCFrag(cfrag_cts: list) -> list:
ct_list = []
cfrags_list = []
cfrags = []
for cfrag_ct in cfrag_cts:
cfrags_list.append(cfrag_ct[0])
ct_list.append(cfrag_ct[1])
cfrags.append(cfrags_list)
cfrags.append(ct_list[0])
return cfrags
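
# Note (added for clarity, not in the original): MergeCFrag takes the list of
# (cfrag, ciphertext) pairs returned by the re-encryption nodes, e.g.
#   [(cfrag_1, ct), (cfrag_2, ct), ...]
# and returns [[cfrag_1, cfrag_2, ...], ct] -- all cfrags plus one copy of the
# (identical) ciphertext, which is the shape DecryptFrags expects below.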
def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
"""
return:
K: sm4 key
"""
Elist = []
Vlist = []
idlist = []
X_Alist = []
for cfrag in cFrags: # Ei,Vi,id,Xa = cFrag
Elist.append(cfrag[0])
Vlist.append(cfrag[1])
idlist.append(cfrag[2])
X_Alist.append(cfrag[3])
pkab = multiply(pk_A, sk_B) # pka^b
D = hash6((pk_A, pk_B, pkab))
Sx = []
    for id in idlist:  # from 1 to t
        sxi = hash5(id, D)  # id is the node's identifier
Sx.append(sxi)
    bis = []  # b ==> λ (Lagrange coefficients)
bi = 1
for i in range(len(cFrags)):
bi = 1
for j in range(len(cFrags)):
if j != i:
Sxj_sub_Sxi = (Sx[j] - Sx[i]) % sm2p256v1.N
Sxj_sub_Sxi_inv = inv(Sxj_sub_Sxi, sm2p256v1.N)
bi = (bi * Sx[j] * Sxj_sub_Sxi_inv) % sm2p256v1.N
bis.append(bi)
    E2 = multiply(Elist[0], bis[0])  # running value of E^(λ), kept for convenience
    V2 = multiply(Vlist[0], bis[0])  # running value of V^(λ)
for k in range(1, len(cFrags)):
        Ek = multiply(Elist[k], bis[k])  # Elist/Vlist are lists
Vk = multiply(Vlist[k], bis[k])
E2 = add(Ek, E2)
V2 = add(Vk, V2)
    X_Ab = multiply(
        X_Alist[0], sk_B
    )  # X_A^b; X_A is the fixed point obtained from the randomly generated x_A by scalar multiplication on the curve
d = hash3((X_Alist[0], pk_B, X_Ab))
EV = add(E2, V2) # E2 + V2
EVd = multiply(EV, d) # (E2 + V2)^d
K = KDF(EVd)
return K
# M = IAEAM(K,enc_Data)
# cfrags = {{capsule1, capsule2, ...}, ct}, where ct is enc_Data
def DecryptFrags(sk_B: int, pk_B: point, pk_A: point, cfrags: list) -> bytes:
    capsules, enc_Data = cfrags  # the encrypted ciphertext
K = DecapsulateFrags(sk_B, pk_B, pk_A, capsules)
K = K.to_bytes(16)
iv = b"tpretpretpretpre"
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
try:
dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish()
except Exception as e:
print(e)
print("key error")
dec_Data = b""
return bytes(dec_Data)

View File

@@ -1,37 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import add, multiply, sm2p256v1
import time
# generator
g = (sm2p256v1.Gx, sm2p256v1.Gy)

start_time = time.time()  # record the start time
for i in range(10):
    result = multiply(g, 10000, 1)  # run the function (rust backend)
end_time = time.time()  # record the end time
elapsed_time = end_time - start_time  # compute the elapsed time
print(f"rust multiply execution time: {elapsed_time:.6f}")

start_time = time.time()  # record the start time
for i in range(10):
    result = multiply(g, 10000, 0)  # run the function (python backend)
end_time = time.time()  # record the end time
elapsed_time = end_time - start_time  # compute the elapsed time
print(f"python multiply execution time: {elapsed_time:.6f}")

start_time = time.time()  # record the start time
for i in range(10):
    result = add(g, g, 1)  # run the function (rust backend)
end_time = time.time()  # record the end time
elapsed_time = end_time - start_time  # compute the elapsed time
print(f"rust add execution time: {elapsed_time:.6f}")

start_time = time.time()  # record the start time
for i in range(10):
    result = add(g, g, 0)  # run the function (python backend)
end_time = time.time()  # record the end time
elapsed_time = end_time - start_time  # compute the elapsed time
print(f"python add execution time: {elapsed_time:.6f}")

View File

@@ -1,74 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
for i in range(1, 10):
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world" * pow(10, i)
print(f"明文长度:{len(m)}")
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
print("成功解密:")
print(f"算法总运行时间:{total_time}")
print()

View File

@@ -1,70 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import *
import time
N = 80
total_time = 0
while total_time < 1:
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world"
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"密钥生成运行时间:{elapsed_time}秒")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"加密算法运行时间:{elapsed_time}秒")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"重加密密钥生成算法运行时间:{elapsed_time}秒")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
# print(f"重加密算法运行时间:{elapsed_time}秒")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"解密算法运行时间:{elapsed_time}秒")
# print("成功解密:", m)
print(f"算法总运行时间:{total_time}")
print()
N += 1

View File

@@ -1,88 +0,0 @@
# Tests for the functions in node.py
import os
import unittest
from unittest.mock import patch, MagicMock, Mock
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
import node
class TestGetLocalIP(unittest.TestCase):
    @patch.dict("os.environ", {"HOST_IP": "60.204.193.58"})  # simulate setting the HOST_IP environment variable
    def test_get_ip_from_env(self):
        # Call the function under test
        node.get_local_ip()
        # Check that the function correctly picked up HOST_IP
        self.assertEqual(node.ip, "60.204.193.58")
@patch("socket.socket") # Mock socket 连接行为
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
def test_get_ip_from_socket(self, mock_socket):
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
node.get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(node.ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
    @patch.dict(os.environ, {"HOST_IP": "60.204.193.58"})  # set the HOST_IP environment variable
    @patch("requests.get")  # mock the requests.get call
    def test_send_ip(self, mock_get):
        # Prepare the simulated HTTP response
        mock_response = Mock()
        mock_response.text = "node123"  # simulated node id returned by the server
        mock_response.status_code = 200
        mock_get.return_value = (
            mock_response  # make requests.get() return mock_response
        )
        # Save the original global id value
        original_id = node.id
        # Call the function under test
        node.send_ip()
        # Make sure requests.get was called correctly
        expected_url = f"{node.server_address}/get_node?ip={node.ip}"
        mock_get.assert_called_once_with(expected_url, timeout=3)
        # Check that id was updated correctly
        self.assertIs(node.id, mock_response)  # check that id was modified
        self.assertEqual(
            node.id.text, "node123"
        )  # check that the updated id matches mock_response.text
class TestNode(unittest.TestCase):
    @patch("node.send_ip")
    @patch("node.get_local_ip")
    @patch("node.asyncio.create_task")
    def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
        # Call the init function
        node.init()
        # Verify that get_local_ip and send_ip were called
        mock_get_local_ip.assert_called_once()
        mock_send_ip.assert_called_once()
        # Make sure create_task was called to start the heartbeat
        mock_create_task.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -1,195 +0,0 @@
# Remaining part of node_test (known issues)
import os
import unittest
import pytest
from unittest.mock import patch, MagicMock, Mock, AsyncMock
import requests
import asyncio
import httpx
import respx
from fastapi.testclient import TestClient
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from node import (
app,
send_heartbeat_internal,
Req,
send_ip,
get_local_ip,
init,
clear,
send_user_des_message,
id,
)
client = TestClient(app)
server_address = "http://60.204.236.38:8000/server"
# ip = None  # initialize the global variable ip
# id = None  # initialize the global variable id
class TestGetLocalIP(unittest.TestCase):
    def test_get_ip_from_env(self):
        os.environ["HOST_IP"] = "60.204.193.58"  # simulate setting the HOST_IP environment variable
        ip = get_local_ip()
        # Check that the function correctly picked up HOST_IP
        self.assertEqual(ip, "60.204.193.58")

    @patch("socket.socket")  # mock the socket connection behaviour
    def test_get_ip_from_socket(self, mock_socket):
        os.environ.pop("HOST_IP", None)
        # Simulate the IP address returned by the socket
        mock_socket_instance = MagicMock()
        mock_socket.return_value = mock_socket_instance
        mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
        # Call the function under test
        ip = get_local_ip()
        # Confirm that the socket was used
        mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
        mock_socket_instance.close.assert_called_once()
        # Check that the correct IP address was obtained via the socket
        self.assertEqual(ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
    @patch.dict(os.environ, {"HOST_IP": "60.204.193.58"})  # set the HOST_IP environment variable
    @respx.mock
    def test_send_ip(self):
        global ip, id
        ip = "60.204.193.58"
        mock_url = f"{server_address}/get_node?ip={ip}"
        respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123"))
        # Call the function under test
        send_ip()
        # Make sure requests.get was called correctly
        self.assertEqual(
            id, "node123"
        )  # check that the updated id matches the mocked response text
class TestNode(unittest.TestCase):
    @patch("node.send_ip")
    @patch("node.get_local_ip")
    @patch("node.asyncio.create_task")
    def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
        # Call the init function
        init()
        # Verify that get_local_ip and send_ip were called
        mock_get_local_ip.assert_called_once()
        mock_send_ip.assert_called_once()
        # Make sure create_task was called to start the heartbeat
        mock_create_task.assert_called_once()

    def test_clear(self):
        # Call the clear function
        clear()
        # Check the result
        self.assertTrue(True)  # only verifies the call succeeds; there is no real logic to test here
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_success():
    global ip
    ip = "60.204.193.58"
    # Mock the heartbeat request
    heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
        return_value=httpx.Response(200)
    )
    # Mock requests.get to avoid a real request
    with patch("requests.get", return_value=httpx.Response(200)) as mock_get:
        # Mock asyncio.sleep to avoid real delays
        with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
            task = asyncio.create_task(send_heartbeat_internal())
            await asyncio.sleep(0.1)  # let the task run for a while
            task.cancel()  # cancel the task to stop the infinite loop
            try:
                await task  # make sure the task is awaited
            except asyncio.CancelledError:
                pass  # swallow the cancellation error
    assert mock_get.called
    assert mock_get.call_count > 0
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_failure():
    global ip
    ip = "60.204.193.58"
    # Mock the heartbeat request so that it raises an exception
    heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
        side_effect=httpx.RequestError("Central server error")
    )
    # Mock requests.get to avoid a real request
    with patch(
        "requests.get", side_effect=httpx.RequestError("Central server error")
    ) as mock_get:
        # Mock asyncio.sleep to avoid real delays
        with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
            task = asyncio.create_task(send_heartbeat_internal())
            await asyncio.sleep(0.1)  # let the task run for a while
            task.cancel()  # cancel the task to stop the infinite loop
            try:
                await task  # make sure the task is awaited
            except asyncio.CancelledError:
                pass  # swallow the cancellation error
    assert mock_get.called
    assert mock_get.call_count > 0
def test_user_src():
    # Mock the ReEncrypt function
    with patch(
        "node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")
    ):
        # Mock the send_user_des_message function
        with patch(
            "node.send_user_des_message", new_callable=AsyncMock
        ) as mock_send_user_des_message:
message = {
"source_ip": "60.204.193.58",
"dest_ip": "60.204.193.59",
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
"ct": 456,
"rk": ["rk1", "rk2"],
}
response = client.post("/user_src", json=message)
assert response.status_code == 200
assert response.json() == {"detail": "message received"}
mock_send_user_des_message.assert_called_once()
def test_send_user_des_message():
with respx.mock:
dest_ip = "60.204.193.59"
re_message = (("a", "b", "c", "d"), 123)
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
return_value=httpx.Response(200, json={"status": "success"})
)
response = requests.post(
f"http://{dest_ip}:8002/receive_messages",
json={"Tuple": re_message, "ip": "60.204.193.58"},
)
assert response.status_code == 200
assert response.json() == {"status": "success"}
if __name__ == "__main__":
unittest.main()

View File

@@ -1,140 +0,0 @@
import sqlite3
import pytest
from fastapi.testclient import TestClient
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from server import app, validate_ip
# Create a TestClient instance
client = TestClient(app)
# Prepare test data for the database
def setup_db():
    # Create the database and insert test data
with sqlite3.connect("server.db") as db:
db.execute(
"""
CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY,
ip TEXT NOT NULL,
last_heartbeat INTEGER NOT NULL
)
"""
)
db.execute(
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)"
)
db.execute(
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)"
)
db.commit()
# Clear the database
def clear_db():
    with sqlite3.connect("server.db") as db:
        db.execute("DROP TABLE IF EXISTS nodes")  # drop the old table
db.commit()
# Test the IP validation helper
def test_validate_ip():
assert validate_ip("192.168.0.1") is True
assert validate_ip("256.256.256.256") is False
assert validate_ip("::1") is True
assert validate_ip("invalid_ip") is False
# Test the home route
def test_home():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
# Test the show_nodes route
def test_show_nodes():
setup_db()
response = client.get("/server/show_nodes")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0][1] == "192.168.0.1"
assert data[1][1] == "192.168.0.2"
# Test the get_node route
def test_get_node():
    # Make sure the database and table exist
setup_db()
valid_ip = "192.168.0.3"
invalid_ip = "256.256.256.256"
    # Test a valid IP address
response = client.get(f"/server/get_node?ip={valid_ip}")
assert response.status_code == 200
    # Test an invalid IP address
response = client.get(f"/server/get_node?ip={invalid_ip}")
assert response.status_code == 400
# Test the delete_node route
def test_delete_node():
setup_db()
valid_ip = "192.168.0.1"
invalid_ip = "192.168.0.255"
response = client.get(f"/server/delete_node?ip={valid_ip}")
assert response.status_code == 200
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
response = client.get(f"/server/delete_node?ip={invalid_ip}")
assert response.status_code == 404
# Test the heartbeat route
def test_receive_heartbeat():
setup_db()
valid_ip = "192.168.0.2"
invalid_ip = "256.256.256.256"
response = client.get(f"/server/heartbeat?ip={valid_ip}")
assert response.status_code == 200
assert response.json() == {"status": "received"}
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
assert response.status_code == 400
assert response.json() == {"message": "invalid ip format"}
# Test the send_nodes_list route
def test_send_nodes_list():
setup_db()
response = client.get("/server/send_nodes_list?count=1")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0] == "192.168.0.1"
response = client.get("/server/send_nodes_list?count=2")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
# Clean up the database after the tests run
@pytest.fixture(autouse=True)
def run_around_tests():
clear_db()
yield
clear_db()

View File

@@ -1,72 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world"
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
print("成功解密:", m)
print(f"算法总运行时间:{total_time}")
print()

View File

@@ -1,76 +0,0 @@
import os
import pytest
from fastapi.testclient import TestClient
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from client import app, init_db, clean_env, get_own_ip
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
    # Set up the test environment
    init_db()
    yield
    # Tear down the test environment
    clean_env()
def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
def test_receive_messages():
message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"}
response = client.post("/receive_messages", json=message)
assert response.status_code == 200
assert response.json().get("detail") == "Message received"
# @respx.mock
# def test_request_message():
# request_message = {
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
# "message_name": "name"
# }
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
# response = client.post("/request_message", json=request_message)
# assert response.status_code == 200
# assert "threshold" in response.json()
# assert "public_key" in response.json()
# @respx.mock
# def test_receive_request():
# ip_message = {
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
# "message_name": "name",
# "source_ip": "124.70.165.73", # 使用不同的 IP 地址
# "pk": (123, 456)
# }
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
# response = client.post("/receive_request", json=ip_message)
# assert response.status_code == 200
# assert "threshold" in response.json()
# assert "public_key" in response.json()
def test_get_pk():
response = client.get("/get_pk")
assert response.status_code == 200
assert "pkx" in response.json()
assert "pky" in response.json()
def test_recieve_pk():
pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"}
response = client.post("/recieve_pk", json=pk_data)
assert response.status_code == 200
assert response.json() == {"message": "save pk in database"}
if __name__ == "__main__":
pytest.main()

View File

@@ -1,158 +0,0 @@
import sys
import os
import hashlib
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
hash2,
hash3,
hash4,
multiply,
g,
sm2p256v1,
GenerateKeyPair,
Encrypt,
Decrypt,
GenerateReKey,
Encapsulate,
ReEncrypt,
DecryptFrags,
MergeCFrag,
)
import random
import unittest
class TestHash2(unittest.TestCase):
def setUp(self):
self.double_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash2(self.double_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash2(self.double_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash3(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash3(self.triple_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash3(self.triple_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash4(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
self.Zp = random.randint(0, sm2p256v1.N - 1)
def test_digest_type(self):
digest = hash4(self.triple_G, self.Zp)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash4(self.triple_G, self.Zp)
self.assertLess(digest, sm2p256v1.N)
class TestGenerateKeyPair(unittest.TestCase):
def test_key_pair(self):
public_key, secret_key = GenerateKeyPair()
self.assertIsInstance(public_key, tuple)
self.assertEqual(len(public_key), 2)
self.assertIsInstance(secret_key, int)
self.assertLess(secret_key, sm2p256v1.N)
self.assertGreater(secret_key, 0)
# class TestEncryptDecrypt(unittest.TestCase):
# def setUp(self):
# self.public_key, self.secret_key = GenerateKeyPair()
# self.message = b"Hello, world!"
# def test_encrypt_decrypt(self):
# encrypted_message = Encrypt(self.public_key, self.message)
# Use the SHA-256 hash function to make sure the key is 16 bytes
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # take the first 16 bytes and convert to an integer
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
# self.assertEqual(decrypted_message, self.message)
class TestGenerateReKey(unittest.TestCase):
def test_generate_rekey(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
self.assertIsInstance(rekey, list)
self.assertEqual(len(rekey), 5)
class TestEncapsulate(unittest.TestCase):
def test_encapsulate(self):
pk_A, _ = GenerateKeyPair()
K, capsule = Encapsulate(pk_A)
self.assertIsInstance(K, int)
self.assertIsInstance(capsule, tuple)
self.assertEqual(len(capsule), 3)
class TestReEncrypt(unittest.TestCase):
def test_reencrypt(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
pk_A, _ = GenerateKeyPair()
message = b"Hello, world!"
encrypted_message = Encrypt(pk_A, message)
reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
self.assertIsInstance(reencrypted_message, tuple)
self.assertEqual(len(reencrypted_message), 2)
# class TestDecryptFrags(unittest.TestCase):
# def test_decrypt_frags(self):
# sk_A = random.randint(0, sm2p256v1.N - 1)
# pk_B, sk_B = GenerateKeyPair()
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
# pk_A, _ = GenerateKeyPair()
# message = b"Hello, world!"
# encrypted_message = Encrypt(pk_A, message)
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
# cfrags = [reencrypted_message]
# merged_cfrags = MergeCFrag(cfrags)
# self.assertIsNotNone(merged_cfrags)
# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest()
# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # take the first 16 bytes and convert to an integer
# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags)
# self.assertEqual(decrypted_message, message)
if __name__ == "__main__":
unittest.main()

27
todo.md
View File

@@ -1,27 +0,0 @@
# todolist

## todo

- [ ] Write unit tests

## finished

- [x] Benchmark single-core vs. multi-core performance
  - The algorithm performs best when it gets enough CPU resources (close to or equal to one full core).
  - Allocating too little CPU hurts performance badly, while a moderate allocation (e.g. 0.1 core) already gives reasonable performance.
  - The difference between single-core and multi-core performance is small.
- [x] Benchmark performance across CPU architectures
  - Compared two CPUs: a 12th Gen Intel(R) Core(TM) i5-12490F and an rk3399.
- [x] Benchmark timings for different message lengths
  - Encrypting a 10 MB text runs the full algorithm within 1 s.
- [x] Find the maximum node count that stays within 1 s
  - On the 12th Gen i5 CPU the 1 s limit is reached at roughly 90+ nodes.
- [x] Non-Docker deployment needs to obtain the local IP
  - Added a method to get the IP from the network interface.
- [x] Review the background material
- [x] Prepare talking points to smooth over tricky questions