Compare commits

...

46 Commits

Author SHA1 Message Date
c64ccd57e3 Merge pull request 'merge into main' (#36) from fix-update into main
All checks were successful
Test CI / test speed (push) Successful in 18s
Reviewed-on: #36
2024-10-24 23:35:42 +08:00
835c908ca7 Merge remote-tracking branch 'refs/remotes/origin/fix-update' 2024-10-15 20:19:46 +08:00
fffa6728bc fix: remove duplicate import 2024-10-15 20:19:26 +08:00
7fa5bb460c Add css style 2024-10-15 19:59:20 +08:00
bf6e25b722 Add front-end functionality and improved server.py and node.py
All checks were successful
Test CI / test speed (push) Successful in 17s
2024-10-15 17:05:46 +08:00
24a5e91382 忽略 node_modules 目录 2024-10-15 17:05:46 +08:00
2308b01cee fix: fix test_get_ip_from_env and send_ip
Some checks failed
Test CI / test speed (push) Failing after 1m8s
2024-10-15 16:29:10 +08:00
04d44afac0 Update docker-compose.yml and requirements.txt 2024-10-08 23:33:40 +08:00
ff9a763cec change base image to python3.12-slim 2024-10-08 23:02:46 +08:00
0bb058df6e update install script 2024-10-08 22:56:33 +08:00
983ac91f74 Merge branch 'fix-update' of https://git.mamahaha.work/sangge/tpre-python into fix-update 2024-10-08 21:31:11 +08:00
81e34965e0 Merge branch 'fix-update' of https://git.mamahaha.work/sangge/tpre-python into fix-update 2024-10-08 21:30:50 +08:00
90a173dfa1 feat:添加docker-compose 2024-10-08 21:30:50 +08:00
f69eecf58b 更新 src/node.py 2024-10-08 21:30:50 +08:00
bba12099cd fix:serverIP从环境变量获取 2024-10-08 21:30:50 +08:00
a9a4e28d9a Improve test code files 2024-10-08 21:30:50 +08:00
d071296126 Remove unnecessary files and configurations 2024-10-08 21:30:50 +08:00
2839dd5500 Merge branch 'fix-update' of https://git.mamahaha.work/sangge/tpre-python into fix-update 2024-10-08 21:28:48 +08:00
cb76380242 feat:添加docker-compose 2024-10-08 21:28:29 +08:00
f5493ed8a1 更新 src/node.py 2024-10-08 21:28:29 +08:00
0830a6733f fix:serverIP从环境变量获取 2024-10-08 21:28:29 +08:00
ea8f17920a Improve test code files 2024-10-08 21:28:29 +08:00
e79b24f909 Remove unnecessary files and configurations 2024-10-08 21:28:22 +08:00
15e35405f0 feat:添加docker-compose 2024-10-08 20:30:35 +08:00
b26bd92328 更新 src/node.py
All checks were successful
Test CI / test speed (push) Successful in 14s
2024-10-08 20:25:36 +08:00
5b85db2427 fix:serverIP从环境变量获取
All checks were successful
Test CI / test speed (push) Successful in 14s
2024-10-08 20:25:04 +08:00
061bd5d2bf Improve test code files
All checks were successful
Test CI / test speed (push) Successful in 10s
2024-10-04 22:29:21 +08:00
aa0eb6bfbc Remove unnecessary files and configurations 2024-10-04 22:29:21 +08:00
98e0e1122b doc: 添加关于智能合约部署部分文档 2024-10-04 15:03:35 +08:00
7f1e201c33 Merge remote-tracking branch 'refs/remotes/origin/fix-update'
All checks were successful
Test CI / test speed (push) Successful in 15s
2024-10-01 14:55:00 +08:00
2f7f55fd3a test: add unit test 2024-10-01 14:55:00 +08:00
05de02f2a5 test: Add unit tests 2024-10-01 14:55:00 +08:00
acbc9ecb1a test(Add unit tests): 2024-10-01 14:54:49 +08:00
1c18726066 update eth 2024-09-30 21:51:49 +08:00
e8e7c59579 feat: add basic blockchain
All checks were successful
Test CI / test speed (push) Successful in 1m48s
2024-09-22 20:28:37 +08:00
6c240b1237 finish optimize 2024-09-06 10:48:07 +08:00
2d12e8c99c add rust implementation
All checks were successful
Test CI / test speed (push) Successful in 15s
2024-09-05 10:32:43 +08:00
9654d8504b refactor server 2024-09-05 10:32:14 +08:00
53928b7f9e update submodule 2024-09-05 10:31:44 +08:00
6f19620490 update ci 2024-09-04 10:23:42 +08:00
ca253dbb77 update doc and requirement 2024-09-04 10:23:14 +08:00
68ed843777 add submodule ecc_rs 2024-09-04 10:22:40 +08:00
d0a916afb2 revert to v3.1.1 2024-09-02 12:28:31 +08:00
3b2a728e14 update gitignore
Some checks failed
Test CI / test speed (push) Failing after 7m23s
2024-09-02 11:21:53 +08:00
82309c5505 refactor: refactor some code 2024-09-02 11:21:39 +08:00
8b89e1b722 Update submodule to the latest commit 2024-08-30 07:17:14 +08:00
53 changed files with 21518 additions and 533 deletions

View File

@ -3,7 +3,7 @@ name: Test CI
on:
push:
paths:
- 'src/**'
- "src/**"
jobs:
test:
@ -14,10 +14,10 @@ jobs:
image: catthehacker/ubuntu:act-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Checkout repository
uses: https://git.mamahaha.work/actions/checkout@v3
- name: Run script in Docker container
run: |
ls $PWD/src
docker run --rm -v .:/app git.mamahaha.work/sangge/tpre:base ls
# - name: Run script in Docker container
# run: |
# ls $PWD/src
# docker run --rm -v .:/app git.mamahaha.work/sangge/tpre:base ls

4
.gitignore vendored
View File

@ -14,3 +14,7 @@ build
src/tpre.cpython-311-x86_64-linux-gnu.so
.vscode
venv
lib
include
/frontend/node_modules/

3
.gitmodules vendored
View File

@ -1,3 +1,6 @@
[submodule "gmssl"]
path = gmssl
url = https://github.com/guanzhi/GmSSL.git
[submodule "ecc_rs"]
path = ecc_rs
url = https://git.mamahaha.work/sangge/ecc_rs.git

1
FQA.md
View File

@ -59,3 +59,4 @@
15. **Q:国密SM2 SM3 SM4** 如题
**A:**
----------------------------------------------------------------------

View File

@ -30,9 +30,9 @@
- Windows (需要自行安装gmssl的共享库)
该项目依赖以下软件:
python 3.11
gmssl
gmssl-python
python 3.12
gmssl v3.1.1
gmssl-python 2.2.2
### Docker 版本安装
@ -46,9 +46,9 @@ chmod +x install_docker.sh
docker 版本:
- 版本: 24.0.5
- API 版本: 1.43
- Go 版本: go1.20.6
- 版本: 24.0.5
- API 版本: 1.43
- Go 版本: go1.20.6
## 安装步骤
@ -56,6 +56,14 @@ docker 版本:
本项目依赖gmssl所以请提前安装好。访问 [GmSSL](https://github.com/guanzhi/GmSSL) 可以看到如何安装。
本项目也提供了submodule的方式可以直接使用
```bash
git clone --recurse-submodules https://git.mamahaha.work/sangge/tpre-python.git
chmod +x install_gmssl.sh
./install_docker.sh
```
然后安装必要的python库
```bash

View File

@ -30,9 +30,9 @@ System requirements:
The project relies on the following software:
- Python 3.11
- gmssl
- gmssl-python
- Python 3.12
- gmssl v3.1.1
- gmssl-python 2.2.2
### Docker installer
@ -46,9 +46,9 @@ chmod +x install_docker.sh
docker version:
- Version: 24.0.5
- API version: 1.43
- Go version: go1.20.6
- Version: 24.0.5
- API version: 1.43
- Go version: go1.20.6
## Installation Steps

View File

@ -1,4 +1,4 @@
FROM python:3.11
FROM python:3.12-slim
COPY requirements.txt /app/
@ -8,9 +8,12 @@ COPY requirements.txt /app/
# 根据目标平台复制相应架构的库文件
#COPY lib/${TARGETPLATFORM}/* /lib/
COPY lib/* /lib/
COPY lib/* /usr/local/lib/
WORKDIR /app
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip install --index-url https://git.mamahaha.work/api/packages/sangge/pypi/simple/ ecc-rs
RUN ldconfig

View File

@ -3,16 +3,21 @@
## Run docker
```bash
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=60.204.193.58 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=119.3.125.234 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=124.70.165.73 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/tpre-python/src:/app -e HOST_IP=192.168.8.57 -e server_address=192.168.8.57:8000 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/tpre-python/src:/app -e HOST_IP=119.3.125.234 git.mamahaha.work/sangge/tpre:base bash
docker run -it -p 8000-8002:8000-8002 -v ~/tpre-python/src:/app -e HOST_IP=124.70.165.73 git.mamahaha.work/sangge/tpre:base bash
```
```bash
tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=60.204.233.103 git.mamahaha.work/sangge/tpre:base bash
```
## Deploy contract
You should deploy the contract yourself in src/logger.sol using remix or any CLI-tools and replace the contract address in src/node.py with your actual address.
[Deployment document](https://remix-ide.readthedocs.io/zh-cn/latest/create_deploy.html)
## Start application
You should replace the wallet address/privateKey in src/node.py with your own wallet address/privateKey.
```bash
nohup python server.py &

31
docker-compose.yml Normal file
View File

@ -0,0 +1,31 @@
version: "3"
services:
server:
image: git.mamahaha.work/sangge/tpre:base
volumes:
- ./src:/app
environment:
- server_address=http://server:8000
entrypoint:
- python
- server.py
node:
image: git.mamahaha.work/sangge/tpre:base
volumes:
- ./src:/app
environment:
- server_address=http://server:8000
entrypoint:
- python
- node.py
client:
image: git.mamahaha.work/sangge/tpre:base
volumes:
- ./src:/app
environment:
- server_address=http://server:8000
entrypoint:
- python
- client.py

1
ecc_rs Submodule

@ -0,0 +1 @@
Subproject commit 880c34ce031158d3f27116b3e6d3a0a3748310db

23
frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,23 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
# testing
/coverage
# production
/build
# misc
.DS_Store
.env.local
.env.development.local
.env.test.local
.env.production.local
npm-debug.log*
yarn-debug.log*
yarn-error.log*

70
frontend/README.md Normal file
View File

@ -0,0 +1,70 @@
# Getting Started with Create React App
This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
## Available Scripts
In the project directory, you can run:
### `npm start`
Runs the app in the development mode.\
Open [http://localhost:3000](http://localhost:3000) to view it in your browser.
The page will reload when you make changes.\
You may also see any lint errors in the console.
### `npm test`
Launches the test runner in the interactive watch mode.\
See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information.
### `npm run build`
Builds the app for production to the `build` folder.\
It correctly bundles React in production mode and optimizes the build for the best performance.
The build is minified and the filenames include the hashes.\
Your app is ready to be deployed!
See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information.
### `npm run eject`
**Note: this is a one-way operation. Once you `eject`, you can't go back!**
If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project.
Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own.
You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it.
## Learn More
You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started).
To learn React, check out the [React documentation](https://reactjs.org/).
### Code Splitting
This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting)
### Analyzing the Bundle Size
This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size)
### Making a Progressive Web App
This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app)
### Advanced Configuration
This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration)
### Deployment
This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment)
### `npm run build` fails to minify
This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify)

19780
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

39
frontend/package.json Normal file
View File

@ -0,0 +1,39 @@
{
"name": "frontend",
"version": "0.1.0",
"private": true,
"dependencies": {
"@testing-library/jest-dom": "^5.17.0",
"@testing-library/react": "^13.4.0",
"@testing-library/user-event": "^13.5.0",
"axios": "^1.7.7",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-scripts": "5.0.1",
"web-vitals": "^2.1.4"
},
"scripts": {
"start": "react-scripts start",
"build": "react-scripts build",
"test": "react-scripts test",
"eject": "react-scripts eject"
},
"eslintConfig": {
"extends": [
"react-app",
"react-app/jest"
]
},
"browserslist": {
"production": [
">0.2%",
"not dead",
"not op_mini all"
],
"development": [
"last 1 chrome version",
"last 1 firefox version",
"last 1 safari version"
]
}
}

BIN
frontend/public/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

View File

@ -0,0 +1,13 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>中央服务器路由</title>
<link rel="stylesheet" href="./index.css">
</head>
<body>
<div id="root"></div>
<script src="/frontend/static/js/bundle.js"></script> <!-- 通过构建工具引入打包后的JS文件 -->
</body>
</html>

View File

@ -0,0 +1,25 @@
{
"short_name": "React App",
"name": "Create React App Sample",
"icons": [
{
"src": "favicon.ico",
"sizes": "64x64 32x32 24x24 16x16",
"type": "image/x-icon"
},
{
"src": "logo192.png",
"type": "image/png",
"sizes": "192x192"
},
{
"src": "logo512.png",
"type": "image/png",
"sizes": "512x512"
}
],
"start_url": ".",
"display": "standalone",
"theme_color": "#000000",
"background_color": "#ffffff"
}

View File

@ -0,0 +1,3 @@
# https://www.robotstxt.org/robotstxt.html
User-agent: *
Disallow:

114
frontend/src/App.css Normal file
View File

@ -0,0 +1,114 @@
@keyframes glow-border {
0% {
box-shadow: inset 0 0 0 2px #00e6e6;
}
25% {
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6;
}
50% {
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6, 0 2px 0 0 #00e6e6;
}
75% {
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6, 0 2px 0 0 #00e6e6, -2px 0 0 0 #00e6e6;
}
100% {
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6, 0 2px 0 0 #00e6e6, -2px 0 0 0 #00e6e6, 0 -2px 0 0 #00e6e6;
}
}
body {
background-color: #1e1e1e;
color: #ffffff;
font-family: Arial, sans-serif;
}
.App-header {
text-align: center;
padding: 20px;
}
h1 {
color: #eff4f0;
font-size: 4em;
margin-bottom: 20px;
}
h2 {
color: #2196f3;
font-size: 2em;
margin-top: 20px;
margin-bottom: 10px;
}
h3 {
color: #ff9800;
font-size: 1.75em;
margin-top: 15px;
margin-bottom: 10px;
}
h4 {
color: #f44336;
font-size: 1.5em;
margin-top: 10px;
margin-bottom: 5px;
}
input, button {
margin: 10px 0;
padding: 10px;
border: none;
border-radius: 5px;
}
input {
width: 80%;
max-width: 300px;
}
button {
background-color: #2196f3;
color: #ffffff;
cursor: pointer;
}
button:hover {
background-color: #1976d2;
}
section {
background-color: #333333;
padding: 20px;
margin: 20px 0;
border-radius: 10px;
animation: glow-border 4s infinite;
}
.log-info {
background-color: #444444;
color: #ffffff;
padding: 10px;
border-radius: 5px;
animation: glow-border 4s infinite;
}
ul {
list-style-type: none;
padding: 0;
}
li {
background-color: #444444;
margin: 5px 0;
padding: 10px;
border-radius: 5px;
}
.container {
display: flex;
justify-content: space-between;
}
.left-panel, .right-panel {
width: 48%;
}

118
frontend/src/App.js Normal file
View File

@ -0,0 +1,118 @@
import React, { useEffect, useState } from 'react';
import axios from 'axios';
import WebSocketComponent from './WebSocketComponent';
import './App.css';
function App() {
const [node, setNode] = useState(null);
const [heartbeat, setHeartbeat] = useState(null);
const [nodesList, setNodesList] = useState([]);
const [ip, setIp] = useState('');
const [count, setCount] = useState('');
const fetchNodes = async () => {
try {
const response = await axios.get('/server/show_nodes');
setNodesList(response.data);
} catch (error) {
console.error('Error fetching nodes:', error);
}
};
const fetchNode = async (ip) => {
try {
const response = await axios.get('/server/get_node', { params: { ip } });
setNode(response.data);
} catch (error) {
console.error('Error fetching node:', error);
}
};
const fetchHeartbeat = async (ip) => {
try {
const response = await axios.get('/server/heartbeat', { params: { ip } });
setHeartbeat(response.data);
} catch (error) {
console.error('Error fetching heartbeat:', error);
}
};
const fetchNodesList = async (count) => {
try {
const response = await axios.get('/server/send_nodes_list', { params: { count } });
setNodesList(response.data);
} catch (error) {
console.error('Error fetching nodes list:', error);
}
};
useEffect(() => {
fetchNodes(); // 获取所有节点
}, []);
const handleFetchNode = () => {
fetchNode(ip); // 根据输入的 IP 获取单个节点
};
const handleFetchHeartbeat = () => {
fetchHeartbeat(ip); // 根据输入的 IP 获取心跳信息
};
const handleFetchNodesList = () => {
fetchNodesList(count); // 根据输入的数量获取节点列表
};
return (
<div className="App">
<header className="App-header">
<h1 className="glow">The server</h1>
<div className="container">
<div className="left-panel">
<section>
<h2 className="glow">get node</h2>
<input
type="text"
value={ip}
onChange={(e) => setIp(e.target.value)}
placeholder="Enter node ip"
/>
<button onClick={handleFetchNode}>send</button>
{node ? <p>{JSON.stringify(node)}</p> : <p>here is nothing!</p>}
</section>
<section>
<h2 className="glow">heartbeat</h2>
<button onClick={handleFetchHeartbeat}>Get heartbeat</button>
{heartbeat ? <p>{JSON.stringify(heartbeat)}</p> : <p>here is nothing!</p>}
</section>
<section>
<h2 className="glow">nodes list</h2>
<input
type="number"
value={count}
onChange={(e) => setCount(e.target.value)}
placeholder="Enter the number of nodes"
/>
<button onClick={handleFetchNodesList}>Get node list</button>
{nodesList.length > 0 ? (
<ul>
{nodesList.map((node, index) => (
<li key={index}>{JSON.stringify(node)}</li>
))}
</ul>
) : (
<p>here is nothing!</p>
)}
</section>
</div>
<div className="right-panel">
<WebSocketComponent />
</div>
</div>
</header>
</div>
);
}
export default App;

View File

@ -0,0 +1,66 @@
import React, { useEffect, useState, useRef, useCallback } from 'react';
const WebSocketComponent = () => {
const [logs, setLogs] = useState([]);
const wsRef = useRef(null);
const connectWebSocket = useCallback(() => {
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
return; // 如果连接已经打开,不再重新连接
}
wsRef.current = new WebSocket('ws://localhost:8000/ws/logs');
wsRef.current.onopen = () => {
console.log('WebSocket 连接成功');
};
wsRef.current.onmessage = (event) => {
setLogs((prevLogs) => [...prevLogs, event.data]); // 直接加入收到的消息
};
wsRef.current.onerror = (error) => {
console.error('WebSocket 错误: ', error);
};
wsRef.current.onclose = () => {
console.log('WebSocket 连接关闭,尝试重新连接...');
// 确保 WebSocket 连接在关闭后再进行重连
setTimeout(() => {
connectWebSocket();
}, 5000); // 延迟 5 秒再重连
};
}, []);
useEffect(() => {
connectWebSocket();
return () => {
if (wsRef.current) {
wsRef.current.close();
}
};
}, [connectWebSocket]);
useEffect(() => {
const logContainer = document.querySelector('.log-info');
if (logContainer) {
logContainer.scrollTop = logContainer.scrollHeight;
}
}, [logs]);
return (
<div>
<h2>The logs</h2>
<div
className="log-info"
style={{ height: '550px', overflowY: 'scroll', backgroundColor: 'rgb(32, 28, 28)', padding: '10px' }}
>
{logs.map((log, index) => (
<p key={index} style={{ margin: '5px 0' }}>{log}</p>
))}
</div>
</div>
);
};
export default WebSocketComponent;

11
frontend/src/index.css Normal file
View File

@ -0,0 +1,11 @@
/* src/index.css */
body {
margin: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
code {
font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace;
}

13
frontend/src/index.js Normal file
View File

@ -0,0 +1,13 @@
import React from 'react';
import { createRoot } from 'react-dom/client';
import App from './App';
import './index.css';
const container = document.getElementById('root');
const root = createRoot(container);
root.render(
<React.StrictMode>
<App />
</React.StrictMode>
);

View File

@ -0,0 +1,10 @@
const { createProxyMiddleware } = require('http-proxy-middleware');
module.exports = function(app) {
app.use(
createProxyMiddleware('/server', {
target: 'http://localhost:8000',
changeOrigin: true,
})
);
};

2
gmssl

@ -1 +1 @@
Subproject commit 31efcb5d87f99fc50b448c181d0a0e59d7edea03
Subproject commit d655c06b3a6b0fe8cff900f293bf0e5aac6eb0a2

11
install_gmssl.sh Normal file → Executable file
View File

@ -3,10 +3,13 @@
mkdir lib
mkdir include
cp gmssl/include include
cp -r gmssl/include include
mkdir gmssl/build
cd gmssl/build
cd gmssl/build || exit
cmake ..
make
cp bin/lib* ../../lib
cd bin || exit
cp libgmssl.so libgmssl.so.3 libgmssl.so.3.1 ../../../lib
cp libsdf_dummy.so libsdf_dummy.so.3 libsdf_dummy.so.3.1 ../../../lib
cp libskf_dummy.so libskf_dummy.so.3 libskf_dummy.so.3.1 ../../../lib
sudo make install

6
package-lock.json generated Normal file
View File

@ -0,0 +1,6 @@
{
"name": "tpre-python",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

View File

@ -1,4 +1,5 @@
gmssl-python
gmssl-python>=2.2.2,<3.0.0
fastapi
uvicorn
requests
web3

View File

@ -2,14 +2,19 @@ from fastapi import FastAPI, HTTPException
import requests
import os
from typing import Tuple
from tpre import *
from tpre import (
GenerateKeyPair,
Encrypt,
DecryptFrags,
GenerateReKey,
MergeCFrag,
point,
)
import sqlite3
from contextlib import asynccontextmanager
from pydantic import BaseModel
import socket
import random
import time
import base64
import json
import pickle
from fastapi.responses import JSONResponse
@ -18,12 +23,12 @@ import asyncio
# 测试文本
test_msessgaes = {
"name": b"proxy re-encryption",
"environment": b"distributed environment"
"environment": b"distributed environment",
}
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(_: FastAPI):
init()
yield
clean_env()
@ -84,13 +89,8 @@ def init_db():
# load config from config file
def init_config():
import configparser
global server_address
config = configparser.ConfigParser()
config.read("client.ini")
server_address = config["settings"]["server_address"]
server_address = os.environ.get("server_address")
# execute on exit
@ -200,7 +200,7 @@ def check_merge(ct: int, ip: str):
try:
pkx, pky = result[0] # result[0] = (pkx, pky)
pk_sender = (int(pkx), int(pky))
except:
except IndexError:
pk_sender, T = 0, -1
T = 2
@ -212,7 +212,7 @@ def check_merge(ct: int, ip: str):
byte_length = (ct.bit_length() + 7) // 8
temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(byte_length)))
cfrags = mergecfrag(temp_cfrag_cts)
cfrags = MergeCFrag(temp_cfrag_cts)
print("sk:", type(sk))
print("pk:", type(pk))
@ -371,7 +371,7 @@ async def receive_request(i_m: IP_Message):
try:
message = test_msessgaes[i_m.message_name]
except:
except IndexError:
message = b"hello world" + random.randbytes(8)
print(f"Message to be send: {message}")
@ -391,7 +391,7 @@ def get_own_ip() -> str:
s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP
ip = s.getsockname()[0]
s.close()
except:
except IndexError:
raise ValueError("Unable to get IP")
return str(ip)
@ -464,7 +464,7 @@ async def recieve_pk(pk: pk_model):
pk = (0, 0)
sk = 0
server_address = str
server_address = os.environ.get("server_address")
node_response = False
message = bytes
local_ip = get_own_ip()

View File

@ -4,14 +4,14 @@ import json
def send_post_request(ip_addr, message_name):
url = f"http://localhost:8002/request_message"
url = "http://localhost:8002/request_message"
data = {"dest_ip": ip_addr, "message_name": message_name}
response = requests.post(url, json=data)
return response.text
def get_pk(ip_addr):
url = f"http://" + ip_addr + ":8002/get_pk"
url = "http://" + ip_addr + ":8002/get_pk"
response = requests.get(url, timeout=1)
print(response.text)
json_pk = json.loads(response.text)
@ -21,7 +21,6 @@ def get_pk(ip_addr):
return response.text
def main():
parser = argparse.ArgumentParser(description="Send POST request to a specified IP.")
parser.add_argument("ip_addr", help="IP address to send request to.")

View File

View File

@ -64,7 +64,7 @@ for N in range(4, 21, 2):
# 9
start_time = time.time()
cfrags = mergecfrag(cfrag_cts)
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time_dec = end_time - start_time

63
src/eth_logger.py Normal file
View File

@ -0,0 +1,63 @@
from web3 import Web3
import json
rpc_url = "https://ethereum-holesky-rpc.publicnode.com"
chain = Web3(Web3.HTTPProvider(rpc_url))
contract_address = "0x642C23F91bf8339044A00251BC09d1D98110C433"
contract_abi = json.loads(
"""[
{
"anonymous": false,
"inputs": [
{
"indexed": false,
"internalType": "string",
"name": "msg",
"type": "string"
}
],
"name": "messageLog",
"type": "event"
},
{
"inputs": [
{
"internalType": "string",
"name": "text",
"type": "string"
}
],
"name": "logmessage",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "nonpayable",
"type": "function"
}
]"""
)
contract = chain.eth.contract(address=contract_address, abi=contract_abi)
wallet_address = "0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6"
pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
def call_eth_logger(wallet_address, pk, message: str):
transaction = contract.functions.logmessage(message).build_transaction(
{
"chainId": 17000,
"gas": 30000,
"gasPrice": chain.to_wei("10", "gwei"),
"nonce": chain.eth.get_transaction_count(wallet_address, "pending"),
}
)
signed_tx = chain.eth.account.sign_transaction(transaction, private_key=pk)
tx_hash = chain.eth.send_raw_transaction(signed_tx.raw_transaction)
print(tx_hash)
receipt = chain.eth.wait_for_transaction_receipt(tx_hash)
transfer_event = contract.events.messageLog().process_receipt(receipt)
for event in transfer_event:
print(event["args"]["msg"])

8
src/eth_logger_test.py Normal file
View File

@ -0,0 +1,8 @@
from eth_logger import call_eth_logger
wallet_address = (
"0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6" # 修改成要使用的钱包地址/私钥
)
wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
call_eth_logger(wallet_address, wallet_pk, "hello World")

10
src/logger.sol Normal file
View File

@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract logger{
event messageLog(string msg);
function logmessage(string memory text) public returns (bool){
emit messageLog(text);
return true;
}
}

View File

@ -1,25 +1,31 @@
from fastapi import FastAPI, Request, HTTPException
import requests
from contextlib import asynccontextmanager
import socket
import asyncio
from pydantic import BaseModel
from tpre import *
import os
from typing import Any, Tuple
import base64
import json
import logging
import os
import socket
import threading
import time
from contextlib import asynccontextmanager
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from eth_logger import call_eth_logger
from tpre import ReEncrypt, capsule
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(_: FastAPI):
init()
yield
clear()
message_list = []
app = FastAPI(lifespan=lifespan)
server_address = "http://60.204.236.38:8000/server"
server_address = os.environ.get("server_address")
id = 0
ip = ""
client_ip_src = "" # 发送信息用户的ip
@ -36,17 +42,21 @@ logger = logging.getLogger("uvicorn")
# 向中心服务器发送自己的IP地址,并获取自己的id
def send_ip():
url = server_address + "/get_node?ip=" + ip # type: ignore
url = f"http://{server_address}/server/get_node?ip={ip}" # 添加 http:// 协议
# ip = get_local_ip() # type: ignore
global id
id = requests.get(url, timeout=3)
logger.info(f"中心服务器返回节点ID为: {id}")
print("中心服务器返回节点ID为: ", id)
try:
response = requests.get(url, timeout=3)
response.raise_for_status() # 检查请求是否成功
data = response.json() # 将响应内容解析为 JSON 格式
global id
id = data.get("id") # 假设返回的 JSON 包含 id 字段
logger.info(f"中心服务器返回节点ID为: {id}")
except requests.exceptions.RequestException as e:
logger.error(f"请求中心服务器失败: {e}")
# 用环境变量获取本机ip
def get_local_ip():
global ip
def get_local_ip() -> str | None:
ip = os.environ.get("HOST_IP")
if not ip: # 如果环境变量中没有IP
try:
@ -55,12 +65,16 @@ def get_local_ip():
s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP
ip = str(s.getsockname()[0])
s.close()
except:
return ip
except IndexError:
raise ValueError("Unable to get IP")
else:
return ip
def init():
get_local_ip()
global ip
ip = get_local_ip()
send_ip()
asyncio.create_task(send_heartbeat_internal())
print("Finish init")
@ -76,12 +90,12 @@ def clear():
async def send_heartbeat_internal() -> None:
timeout = 30
global ip
url = server_address + "/heartbeat?ip=" + ip # type: ignore
url = f"http://{server_address}/server/heartbeat?ip={ip}" # 添加 http:// 协议
while True:
# print('successful send my_heart')
try:
requests.get(url, timeout=3)
except:
except requests.exceptions.RequestException:
logger.error("Central server error")
print("Central server error")
@ -119,6 +133,17 @@ async def user_src(message: Req):
capsule = message.capsule
ct = message.ct
payload = {
"source_ip": source_ip,
"dest_ip": dest_ip,
"capsule": capsule,
"ct": ct,
"rk": message.rk,
}
# 将消息详情记录到区块链
global message_list
message_list.append(payload)
byte_length = (ct.bit_length() + 7) // 8
capsule_ct = (capsule, ct.to_bytes(byte_length))
@ -139,7 +164,9 @@ async def user_src(message: Req):
return HTTPException(status_code=200, detail="message recieved")
async def send_user_des_message(source_ip: str, dest_ip: str, re_message): # 发送消息给用户2
async def send_user_des_message(
source_ip: str, dest_ip: str, re_message
): # 发送消息给用户2
data = {"Tuple": re_message, "ip": source_ip} # 类型不匹配
# 发送 HTTP POST 请求
@ -151,7 +178,24 @@ async def send_user_des_message(source_ip: str, dest_ip: str, re_message): #
print("send stauts:", response.text)
def log_message():
while True:
global message_list
payload = json.dumps(message_list)
message_list = []
call_eth_logger(wallet_address, wallet_pk, payload)
time.sleep(2)
wallet_address = (
"0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6" # 修改成要使用的钱包地址/私钥
)
wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
if __name__ == "__main__":
import uvicorn # pylint: disable=e0401
import uvicorn
# threading.Thread(target=log_message).start()
uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True, log_level="debug")

View File

2
src/pytest.ini Normal file
View File

@ -0,0 +1,2 @@
[pytest]
asyncio_default_fixture_loop_scope = function

View File

@ -1,27 +1,74 @@
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import sqlite3
import asyncio
import time
import ipaddress
import logging
import os
import queue
app = FastAPI()
origins = [
"http://localhost:3000",
]
# 配置 CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"], # 允许所有方法
allow_headers=["*"], # 允许所有头
)
# 配置日志文件路径
log_dir = "logs"
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_file = os.path.join(log_dir, "server_logs.log")
# 全局日志配置
logging.basicConfig(
level=logging.INFO, # 设置全局日志级别
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # 日志格式
handlers=[
logging.FileHandler(log_file, encoding="utf-8"), # 输出到日志文件
logging.StreamHandler(), # 输出到控制台
],
)
# 获取日志记录器
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(_: FastAPI):
init()
yield
clean_env()
# 获取当前文件所在的目录
current_dir = os.path.dirname(os.path.abspath(__file__))
# 定义 frontend 的绝对路径
frontend_dir = os.path.join(current_dir, "..", "frontend")
app = FastAPI(lifespan=lifespan)
app.mount("/frontend", StaticFiles(directory="frontend/build"), name="frontend")
def init():
asyncio.create_task(receive_heartbeat_internal())
init_db()
def init_db():
conn = sqlite3.connect("server.db")
cursor = conn.cursor()
@ -39,6 +86,9 @@ def clean_env():
clear_database()
# -----------------------------------------------------------------------------------------------
@app.get("/")
async def home():
return {"message": "Hello, World!"}
@ -54,10 +104,30 @@ async def show_nodes() -> list:
for row in rows:
nodes_list.append(row)
# TODO: use JSONResponse
logger.info("节点信息已成功获取")
return nodes_list
@app.get("/nodes", response_class=HTMLResponse)
async def get_nodes_page():
with open("frontend/public/index.html") as f:
return HTMLResponse(content=f.read(), status_code=200)
def validate_ip(ip: str) -> bool:
"""
Validate an IP address.
This function checks if the provided string is a valid IP address.
Both IPv4 and IPv6 are considered valid.
Args:
ip (str): The IP address to validate.
Returns:
bool: True if the IP address is valid, False otherwise.
"""
try:
ipaddress.ip_address(ip)
return True
@ -66,7 +136,7 @@ def validate_ip(ip: str) -> bool:
@app.get("/server/get_node")
async def get_node(ip: str) -> int:
async def get_node(ip: str) -> JSONResponse:
"""
中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点
params:
@ -82,11 +152,12 @@ async def get_node(ip: str) -> int:
ip_int = 0
for i in range(4):
ip_int += int(ip_parts[i]) << (24 - (8 * i))
print("IP", ip, "对应的ID为", ip_int)
logger.info(f"IP {ip} 对应的ID为 {ip_int}")
# 获取当前时间
current_time = int(time.time())
print("当前时间: ", current_time)
logger.info(f"当前时间: {current_time}")
with sqlite3.connect("server.db") as db:
# 插入数据
@ -96,17 +167,25 @@ async def get_node(ip: str) -> int:
)
db.commit()
return ip_int
# 使用 JSONResponse 返回节点ID和当前时间
logger.info(f"节点 {ip} 已成功添加到数据库")
content = {"id": ip_int, "current_time": current_time}
return JSONResponse(content, status_code=200)
# TODO: try to use @app.delete("/node")
@app.get("/server/delete_node")
async def delete_node(ip: str) -> None:
async def delete_node(ip: str):
"""
param:
ip: 待删除节点的ip地址
return:
None
Delete a node by ip.
Args:
ip (str): The ip of the node to be deleted.
"""
if not validate_ip(ip):
logger.warning(f"收到无效 IP 格式的删除请求: {ip}")
raise HTTPException(status_code=400, detail="Invalid IP format")
with sqlite3.connect("server.db") as db:
# 查询要删除的节点
@ -117,28 +196,44 @@ async def delete_node(ip: str) -> None:
# 执行删除操作
db.execute("DELETE FROM nodes WHERE ip=?", (ip,))
db.commit()
print(f"Node with IP {ip} deleted successfully.")
logger.info(f"节点 {ip} 已成功删除")
return {"message": f"Node with IP {ip} deleted successfully."}
else:
print(f"Node with IP {ip} not found.")
logger.warning(f"节点 {ip} 未找到")
raise HTTPException(status_code=404, detail=f"Node with IP {ip} not found.")
# 接收节点心跳包
@app.get("/server/heartbeat")
async def receive_heartbeat(ip: str):
"""
Receive a heartbeat from a node.
Args:
ip (str): The IP address of the node.
Returns:
JSONResponse: A message indicating the result of the operation.
"""
if not validate_ip(ip):
content = {"message": "invalid ip "}
content = {"message": "invalid ip format"}
logger.warning(f"收到无效 IP 格式的心跳包: {ip}")
return JSONResponse(content, status_code=400)
print("收到来自", ip, "的心跳包")
logger.info(f"收到来自 {ip} 的心跳包")
with sqlite3.connect("server.db") as db:
db.execute(
"UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)
)
return {"status": "received"}
logger.info(f"成功更新节点 {ip} 的心跳时间")
content = {"status": "received"}
return JSONResponse(content, status_code=200)
async def receive_heartbeat_internal():
timeout = 70
while 1:
timeout = 70
with sqlite3.connect("server.db") as db:
# 删除超时的节点
db.execute(
@ -165,22 +260,86 @@ async def send_nodes_list(count: int) -> list:
rows = cursor.fetchall()
for row in rows:
id, ip, last_heartbeat = row
# id, ip, last_heartbeat = row
_, ip, _ = row
nodes_list.append(ip)
print("收到来自客户端的节点列表请求...")
print(nodes_list)
logger.info(f"已成功发送 {count} 个节点信息")
return nodes_list
# @app.get("/server/clear_database")
def clear_database() -> None:
with sqlite3.connect("server.db") as db:
db.execute("DELETE FROM nodes")
db.commit()
logger.info("数据库已清空")
# WebSocket连接池
connected_clients = []
log_queue = queue.Queue() # 用于存储日志的队列
@app.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
await websocket.accept()
connected_clients.append(websocket) # 添加 WebSocket 客户端
try:
# 发送历史日志
while not log_queue.empty():
log_message = log_queue.get()
await websocket.send_json({"type": "log", "message": log_message})
# 实时日志发送
while True:
await asyncio.sleep(5) # 保证 WebSocket 持续连接
except Exception as e:
print(f"WebSocket connection closed with error: {e}")
finally:
if websocket in connected_clients:
connected_clients.remove(websocket)
await websocket.close()
class WebSocketLogHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created))
log_message = f"{timestamp} - {log_entry}"
log_queue.put(log_message) # 将日志消息放入队列
for client in connected_clients:
if client.application_state == WebSocketState.CONNECTED:
# 改为异步线程安全地发送日志
asyncio.run_coroutine_threadsafe(
self.safe_send_log(client, log_message), asyncio.get_event_loop()
)
async def safe_send_log(self, client, log_message):
try:
await client.send_json({"type": "log", "message": log_message})
except RuntimeError as e:
print(f"Error while sending log to {client.application_state}: {e}")
except Exception as e:
print(f"Unexpected error: {e}")
finally:
if client in connected_clients:
connected_clients.remove(client)
# 捕获 FastAPI 和 Uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.setLevel(logging.INFO)
# 捕获 FastAPI 的日志
fastapi_logger = logging.getLogger("fastapi")
fastapi_logger.setLevel(logging.INFO)
# 将日志输出到 WebSocket
uvicorn_logger.addHandler(WebSocketLogHandler())
fastapi_logger.addHandler(WebSocketLogHandler())
if __name__ == "__main__":
import uvicorn # pylint: disable=e0401
import uvicorn
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)

View File

@ -1,29 +0,0 @@
import unittest
import sqlite3
import os
from server import *
class TestServer(unittest.TestCase):
def test_init_creates_table(self):
# 执行初始化函数
init_db()
conn = sqlite3.connect("server.db")
cursor = conn.cursor()
# 检查表是否被正确创建
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'"
)
tables = cursor.fetchall()
self.assertTrue(any("nodes" in table for table in tables))
# 关闭数据库连接
conn.close()
os.remove("server.db")
if __name__ == "__main__":
unittest.main()

View File

@ -1,14 +0,0 @@
from setuptools import setup, Extension
# 定义您的扩展
ext = Extension(
"tpreECC",
sources=["tpreECC.c"],
)
setup(
name="tpreECC",
version="1.0",
description="basic ECC written in C",
ext_modules=[ext],
)

View File

@ -1,7 +1,7 @@
from gmssl import * # pylint: disable = e0401
from typing import Tuple, Callable
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
from typing import Tuple
import random
import traceback
import ecc_rs
point = Tuple[int, int]
capsule = Tuple[point, point, int]
@ -29,23 +29,31 @@ sm2p256v1 = CurveFp(
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
point = Tuple[int, int]
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def multiply(a: point, n: int) -> point:
N = sm2p256v1.N
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
def multiply(a: point, n: int, flag: int = 0) -> point:
if flag == 1:
result = ecc_rs.multiply(a, n)
return result
else:
N = sm2p256v1.N
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
def add(a: point, b: point) -> point:
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
def add(a: point, b: point, flag: int = 0) -> point:
if flag == 1:
result = ecc_rs.add(a, b)
return result
else:
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
def inv(a: int, n: int) -> int:
@ -209,12 +217,13 @@ def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
enc_Data = sm4_enc.update(m)
enc_Data += sm4_enc.finish()
enc_message = (capsule, enc_Data)
enc_message = (capsule, bytes(enc_Data))
return enc_message
def Decapsulate(ska: int, capsule: capsule) -> int:
E, V, s = capsule
# E, V, s = capsule
E, V, _ = capsule
EVa = multiply(add(E, V), ska) # (E*V)^ska
K = KDF(EVa)
@ -233,7 +242,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish()
return dec_Data
return bytes(dec_Data)
# GenerateRekey
@ -305,8 +314,9 @@ def GenerateReKey(
# 计算KF
KF = []
for i in range(N):
y = random.randint(0, sm2p256v1.N - 1)
Y = multiply(g, y)
# seems unused?
# y = random.randint(0, sm2p256v1.N - 1)
# Y = multiply(g, y)
s_x = hash5(id_tuple[i], D) # id需要设置
r_k = f(s_x, f_modulus, T)
U1 = multiply(U, r_k)
@ -344,8 +354,10 @@ def Checkcapsule(capsule: capsule) -> bool: # 验证胶囊的有效性
def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]:
id, rk, Xa, U1 = kFrag
E, V, s = capsule
# id, rk, Xa, U1 = kFrag
id, rk, Xa, _ = kFrag
# E, V, s = capsule
E, V, _ = capsule
if not Checkcapsule(capsule):
raise ValueError("Invalid capsule")
E1 = multiply(E, rk)
@ -369,7 +381,7 @@ def ReEncrypt(
# 将加密节点加密后产生的t个capsule,ct合并在一起,产生cfrags = {{capsule1,capsule2,...},ct}
def mergecfrag(cfrag_cts: list) -> list:
def MergeCFrag(cfrag_cts: list) -> list:
ct_list = []
cfrags_list = []
cfrags = []
@ -421,7 +433,9 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
Vk = multiply(Vlist[k], bis[k])
E2 = add(Ek, E2)
V2 = add(Vk, V2)
X_Ab = multiply(X_Alist[0], sk_B) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
X_Ab = multiply(
X_Alist[0], sk_B
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
d = hash3((X_Alist[0], pk_B, X_Ab))
EV = add(E2, V2) # E2 + V2
EVd = multiply(EV, d) # (E2 + V2)^d
@ -447,4 +461,4 @@ def DecryptFrags(sk_B: int, pk_B: point, pk_A: point, cfrags: list) -> bytes:
print(e)
print("key error")
dec_Data = b""
return dec_Data
return bytes(dec_Data)

View File

@ -1,284 +0,0 @@
#include <Python.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
// define TPRE Big Number
typedef uint64_t TPRE_BN[8]
// GF(p)
typedef TPRE_BN SM2_Fp;
// GF(n)
typedef TPRE_BN SM2_Fn;
// 定义一个结构体来表示雅各比坐标系的点
typedef struct
{
TPRE_BN X;
TPRE_BN Y;
TPRE_BN Z;
} JACOBIAN_POINT;
// 定义一个结构体来表示点
typedef struct
{
uint8_t x[32];
uint8_t y[32];
} TPRE_POINT;
const TPRE_BN SM2_P = {
0xffffffff,
0xffffffff,
0x00000000,
0xffffffff,
0xffffffff,
0xffffffff,
0xffffffff,
0xfffffffe,
};
const TPRE_BN SM2_A = {
0xfffffffc,
0xffffffff,
0x00000000,
0xffffffff,
0xffffffff,
0xffffffff,
0xffffffff,
0xfffffffe,
};
const TPRE_BN SM2_B = {
0x4d940e93,
0xddbcbd41,
0x15ab8f92,
0xf39789f5,
0xcf6509a7,
0x4d5a9e4b,
0x9d9f5e34,
0x28e9fa9e,
};
// 生成元GX, GY
const SM2_JACOBIAN_POINT _SM2_G = {
{
0x334c74c7,
0x715a4589,
0xf2660be1,
0x8fe30bbf,
0x6a39c994,
0x5f990446,
0x1f198119,
0x32c4ae2c,
},
{
0x2139f0a0,
0x02df32e5,
0xc62a4740,
0xd0a9877c,
0x6b692153,
0x59bdcee3,
0xf4f6779c,
0xbc3736a2,
},
{
1,
0,
0,
0,
0,
0,
0,
0,
},
};
const SM2_JACOBIAN_POINT *SM2_G = &_SM2_G;
const TPRE_BN SM2_N = {
0x39d54123,
0x53bbf409,
0x21c6052b,
0x7203df6b,
0xffffffff,
0xffffffff,
0xffffffff,
0xfffffffe,
};
// u = (p - 1)/4, u + 1 = (p + 1)/4
const TPRE_BN SM2_U_PLUS_ONE = {
0x00000000,
0x40000000,
0xc0000000,
0xffffffff,
0xffffffff,
0xffffffff,
0xbfffffff,
0x3fffffff,
};
const TPRE_BN SM2_ONE = {1, 0, 0, 0, 0, 0, 0, 0};
const TPRE_BN SM2_TWO = {2, 0, 0, 0, 0, 0, 0, 0};
const TPRE_BN SM2_THREE = {3, 0, 0, 0, 0, 0, 0, 0};
#define GETU32(p) \
((uint32_t)(p)[0] << 24 | \
(uint32_t)(p)[1] << 16 | \
(uint32_t)(p)[2] << 8 | \
(uint32_t)(p)[3])
// 点乘
static void multiply(TPRE_POINT r, const TPRE_POINT a, int64_t n)
{
Point result;
// ...实现乘法逻辑...
return result;
}
// 点加
static void add(TPRE_POINT *R, TPRE_POINT *P, TPRE_POINT *Q)
{
JACOBIAN_POINT P_;
JACOBIAN_POINT Q_;
jacobianPoint_from_bytes(&P_, (uint8_t *)P)
jacobianPoint_from_bytes(&Q_, (uint8_t *)Q)
jacobianPoint_add(&P_, &P_, &Q_);
jacobianPoint_to_bytes(&P_, (uint8_t *)R);
}
// 求逆
static void inv()
{
}
// jacobianPoint点加
static void jacobianPoint_add(JACOBIAN_POINT *R, const JACOBIAN_POINT *P, const JACOBIAN_POINT *Q)
{
const uint64_t *X1 = P->X;
const uint64_t *Y1 = P->Y;
const uint64_t *Z1 = P->Z;
const uint64_t *x2 = Q->X;
const uint64_t *y2 = Q->Y;
SM2_BN T1;
SM2_BN T2;
SM2_BN T3;
SM2_BN T4;
SM2_BN X3;
SM2_BN Y3;
SM2_BN Z3;
if (sm2_jacobian_point_is_at_infinity(Q))
{
sm2_jacobian_point_copy(R, P);
return;
}
if (sm2_jacobian_point_is_at_infinity(P))
{
sm2_jacobian_point_copy(R, Q);
return;
}
assert(sm2_bn_is_one(Q->Z));
sm2_fp_sqr(T1, Z1);
sm2_fp_mul(T2, T1, Z1);
sm2_fp_mul(T1, T1, x2);
sm2_fp_mul(T2, T2, y2);
sm2_fp_sub(T1, T1, X1);
sm2_fp_sub(T2, T2, Y1);
if (sm2_bn_is_zero(T1))
{
if (sm2_bn_is_zero(T2))
{
SM2_JACOBIAN_POINT _Q, *Q = &_Q;
sm2_jacobian_point_set_xy(Q, x2, y2);
sm2_jacobian_point_dbl(R, Q);
return;
}
else
{
sm2_jacobian_point_set_infinity(R);
return;
}
}
sm2_fp_mul(Z3, Z1, T1);
sm2_fp_sqr(T3, T1);
sm2_fp_mul(T4, T3, T1);
sm2_fp_mul(T3, T3, X1);
sm2_fp_dbl(T1, T3);
sm2_fp_sqr(X3, T2);
sm2_fp_sub(X3, X3, T1);
sm2_fp_sub(X3, X3, T4);
sm2_fp_sub(T3, T3, X3);
sm2_fp_mul(T3, T3, T2);
sm2_fp_mul(T4, T4, Y1);
sm2_fp_sub(Y3, T3, T4);
sm2_bn_copy(R->X, X3);
sm2_bn_copy(R->Y, Y3);
sm2_bn_copy(R->Z, Z3);
}
// bytes转jacobianPoint
static void jacobianPoint_from_bytes(JACOBIAN_POINT *P, const uint8_t in[64])
{
}
// jacobianPoint转bytes
static void jacobianPoint_to_bytes(JACOBIAN_POINT *P, const uint8_t in[64])
{
}
static void BN_from_bytes(TPRE_BN *r, const uint8_t in[32])
{
int i;
for (i = 7; i >= 0; i--)
{
r[i] = GETU32(in);
in += sizeof(uint32_t);
}
}
// 点乘的Python接口函数
static PyObject *py_multiply(PyObject *self, PyObject *args)
{
return
}
// 点加的Python接口函数
static PyObject *py_add(PyObject *self, PyObject *args)
{
return
}
// 求逆的Python接口函数
static PyObject *py_inv(PyObject *self, PyObject *args)
{
return
}
// 模块方法定义
static PyMethodDef MyMethods[] = {
{"multiply", py_multiply, METH_VARARGS, "Multiply a point on the sm2p256v1 curve"},
{"add", py_add, METH_VARARGS, "Add a point on thesm2p256v1 curve"},
{"inv", py_inv, METH_VARARGS, "Calculate an inv of a number"},
{NULL, NULL, 0, NULL} // 哨兵
};
// 模块定义
static struct PyModuleDef tpreECC = {
PyModuleDef_HEAD_INIT,
"tpreECC",
NULL, // 模块文档
-1,
MyMethods};
// 初始化模块
PyMODINIT_FUNC PyInit_tpre(void)
{
return PyModule_Create(&tpreECC);
}

View File

@ -1,57 +0,0 @@
from tpre import *
import unittest
class TestHash2(unittest.TestCase):
def setUp(self):
self.double_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash2(self.double_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash2(self.double_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash3(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash3(self.triple_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash3(self.triple_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash4(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
self.Zp = random.randint(0, sm2p256v1.N - 1)
def test_digest_type(self):
digest = hash4(self.triple_G, self.Zp)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash4(self.triple_G, self.Zp)
self.assertLess(digest, sm2p256v1.N)
if __name__ == "__main__":
unittest.main()

37
tests/ecc_speed_test.py Normal file
View File

@ -0,0 +1,37 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import add, multiply, sm2p256v1
import time
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000, 1) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000, 0) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"python multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g, 1) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust add 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g, 0) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"python add 执行时间: {elapsed_time:.6f}")

View File

@ -1,4 +1,15 @@
from tpre import *
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
@ -52,7 +63,7 @@ for i in range(1, 10):
# 9
start_time = time.time()
cfrags = mergecfrag(cfrag_cts)
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time

View File

@ -1,3 +1,7 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import *
import time
@ -54,7 +58,7 @@ while total_time < 1:
# 9
start_time = time.time()
cfrags = mergecfrag(cfrag_cts)
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time

88
tests/node_test.py Normal file
View File

@ -0,0 +1,88 @@
# 测试 node.py 中的函数
import os
import unittest
from unittest.mock import patch, MagicMock, Mock
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
import node
class TestGetLocalIP(unittest.TestCase):
@patch.dict("os.environ", {"HOST_IP": "60.204.193.58"}) # 模拟设置 HOST_IP 环境变量
def test_get_ip_from_env(self):
# 调用被测函数
node.get_local_ip()
# 检查函数是否正确获取到 HOST_IP
self.assertEqual(node.ip, "60.204.193.58")
@patch("socket.socket") # Mock socket 连接行为
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
def test_get_ip_from_socket(self, mock_socket):
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
node.get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(node.ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
@patch("requests.get") # Mock requests.get 调用
def test_send_ip(self, mock_get):
# 设置模拟返回的 HTTP 响应
mock_response = Mock()
mock_response.text = "node123" # 模拟返回的节点ID
mock_response.status_code = 200
mock_get.return_value = (
mock_response # 设置 requests.get() 的返回值为 mock_response
)
# 保存原始的全局 id 值
original_id = node.id
# 调用待测函数
node.send_ip()
# 确保 requests.get 被正确调用
expected_url = f"{node.server_address}/get_node?ip={node.ip}"
mock_get.assert_called_once_with(expected_url, timeout=3)
# 检查 id 是否被正确更新
self.assertIs(node.id, mock_response) # 检查 id 是否被修改
self.assertEqual(
node.id.text, "node123"
) # 检查更新后的 id 是否与 mock_response.text 匹配
class TestNode(unittest.TestCase):
@patch("node.send_ip")
@patch("node.get_local_ip")
@patch("node.asyncio.create_task")
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
# 调用 init 函数
node.init()
# 验证 get_local_ip 和 send_ip 被调用
mock_get_local_ip.assert_called_once()
mock_send_ip.assert_called_once()
# 确保 create_task 被调用来启动心跳包
mock_create_task.assert_called_once()
if __name__ == "__main__":
unittest.main()

194
tests/node_test5.py Normal file
View File

@ -0,0 +1,194 @@
# node_test剩下部分(有问题)
import os
import unittest
import pytest
from unittest.mock import patch, MagicMock, Mock, AsyncMock
import requests
import asyncio
import httpx
import respx
from fastapi.testclient import TestClient
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from node import (
app,
send_heartbeat_internal,
Req,
send_ip,
get_local_ip,
init,
clear,
send_user_des_message,
id,
)
client = TestClient(app)
server_address = "http://60.204.236.38:8000/server"
# ip = None # 初始化全局变量 ip
# id = None # 初始化全局变量 id
class TestGetLocalIP(unittest.TestCase):
def test_get_ip_from_env(self):
os.environ["HOST_IP"] = "60.204.193.58" # 模拟设置 HOST_IP 环境变量
ip = get_local_ip()
# 检查函数是否正确获取到 HOST_IP
self.assertEqual(ip, "60.204.193.58")
@patch("socket.socket") # Mock socket 连接行为
def test_get_ip_from_socket(self, mock_socket):
os.environ.pop("HOST_IP", None)
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
ip = get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
@respx.mock
def test_send_ip(self):
global ip, id
ip = "60.204.193.58"
mock_url = f"{server_address}/get_node?ip={ip}"
respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123"))
# 调用待测函数
send_ip()
# 确保 requests.get 被正确调用
self.assertEqual(
id, "node123"
) # 检查更新后的 id 是否与 mock_response.text 匹配
class TestNode(unittest.TestCase):
@patch("node.send_ip")
@patch("node.get_local_ip")
@patch("node.asyncio.create_task")
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
# 调用 init 函数
init()
# 验证 get_local_ip 和 send_ip 被调用
mock_get_local_ip.assert_called_once()
mock_send_ip.assert_called_once()
# 确保 create_task 被调用来启动心跳包
mock_create_task.assert_called_once()
def test_clear(self):
# 调用 clear 函数
clear()
# 检查输出
self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_success():
global ip
ip = "60.204.193.58"
# 模拟心跳请求
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
return_value=httpx.Response(200)
)
# 模拟 requests.get 以避免实际请求
with patch("requests.get", return_value=httpx.Response(200)) as mock_get:
# 模拟 asyncio.sleep 以避免实际延迟
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
task = asyncio.create_task(send_heartbeat_internal())
await asyncio.sleep(0.1) # 允许任务运行一段时间
task.cancel() # 取消任务以停止无限循环
try:
await task # 确保任务被等待
except asyncio.CancelledError:
pass # 捕获取消错误
assert mock_get.called
assert mock_get.call_count > 0
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_failure():
global ip
ip = "60.204.193.58"
# 模拟心跳请求以引发异常
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
side_effect=httpx.RequestError("Central server error")
)
# 模拟 requests.get 以避免实际请求
with patch(
"requests.get", side_effect=httpx.RequestError("Central server error")
) as mock_get:
# 模拟 asyncio.sleep 以避免实际延迟
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
task = asyncio.create_task(send_heartbeat_internal())
await asyncio.sleep(0.1) # 允许任务运行一段时间
task.cancel() # 取消任务以停止无限循环
try:
await task # 确保任务被等待
except asyncio.CancelledError:
pass # 捕获取消错误
assert mock_get.called
assert mock_get.call_count > 0
def test_user_src():
# 模拟 ReEncrypt 函数
with patch(
"node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")
):
# 模拟 send_user_des_message 函数
with patch(
"node.send_user_des_message", new_callable=AsyncMock
) as mock_send_user_des_message:
message = {
"source_ip": "60.204.193.58",
"dest_ip": "60.204.193.59",
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
"ct": 456,
"rk": ["rk1", "rk2"],
}
response = client.post("/user_src", json=message)
assert response.status_code == 200
assert response.json() == {"detail": "message received"}
mock_send_user_des_message.assert_called_once()
def test_send_user_des_message():
with respx.mock:
dest_ip = "60.204.193.59"
re_message = (("a", "b", "c", "d"), 123)
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
return_value=httpx.Response(200, json={"status": "success"})
)
response = requests.post(
f"http://{dest_ip}:8002/receive_messages",
json={"Tuple": re_message, "ip": "60.204.193.58"},
)
assert response.status_code == 200
assert response.json() == {"status": "success"}
if __name__ == "__main__":
unittest.main()

140
tests/server_test.py Normal file
View File

@ -0,0 +1,140 @@
import sqlite3
import pytest
from fastapi.testclient import TestClient
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from server import app, validate_ip
# 创建 TestClient 实例
client = TestClient(app)
# 准备测试数据库数据
def setup_db():
# 创建数据库并插入测试数据
with sqlite3.connect("server.db") as db:
db.execute(
"""
CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY,
ip TEXT NOT NULL,
last_heartbeat INTEGER NOT NULL
)
"""
)
db.execute(
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)"
)
db.execute(
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)"
)
db.commit()
# 清空数据库
def clear_db():
with sqlite3.connect("server.db") as db:
db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表
db.commit()
# 测试 IP 验证功能
def test_validate_ip():
assert validate_ip("192.168.0.1") is True
assert validate_ip("256.256.256.256") is False
assert validate_ip("::1") is True
assert validate_ip("invalid_ip") is False
# 测试首页路由
def test_home():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
# 测试 show_nodes 路由
def test_show_nodes():
setup_db()
response = client.get("/server/show_nodes")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0][1] == "192.168.0.1"
assert data[1][1] == "192.168.0.2"
# 测试 get_node 路由
def test_get_node():
# 确保数据库和表的存在
setup_db()
valid_ip = "192.168.0.3"
invalid_ip = "256.256.256.256"
# 测试有效的 IP 地址
response = client.get(f"/server/get_node?ip={valid_ip}")
assert response.status_code == 200
# 测试无效的 IP 地址
response = client.get(f"/server/get_node?ip={invalid_ip}")
assert response.status_code == 400
# 测试 delete_node 路由
def test_delete_node():
setup_db()
valid_ip = "192.168.0.1"
invalid_ip = "192.168.0.255"
response = client.get(f"/server/delete_node?ip={valid_ip}")
assert response.status_code == 200
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
response = client.get(f"/server/delete_node?ip={invalid_ip}")
assert response.status_code == 404
# 测试 heartbeat 路由
def test_receive_heartbeat():
setup_db()
valid_ip = "192.168.0.2"
invalid_ip = "256.256.256.256"
response = client.get(f"/server/heartbeat?ip={valid_ip}")
assert response.status_code == 200
assert response.json() == {"status": "received"}
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
assert response.status_code == 400
assert response.json() == {"message": "invalid ip format"}
# 测试 send_nodes_list 路由
def test_send_nodes_list():
setup_db()
response = client.get("/server/send_nodes_list?count=1")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0] == "192.168.0.1"
response = client.get("/server/send_nodes_list?count=2")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
# 运行完测试后清理数据库
@pytest.fixture(autouse=True)
def run_around_tests():
clear_db()
yield
clear_db()

View File

@ -1,4 +1,15 @@
from tpre import *
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
@ -50,7 +61,7 @@ print(f"重加密算法运行时间:{elapsed_time}秒")
# 9
start_time = time.time()
cfrags = mergecfrag(cfrag_cts)
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time

76
tests/test_client.py Normal file
View File

@ -0,0 +1,76 @@
import os
import pytest
from fastapi.testclient import TestClient
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from client import app, init_db, clean_env, get_own_ip
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
# 设置测试环境
init_db()
yield
# 清理测试环境
clean_env()
def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
def test_receive_messages():
message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"}
response = client.post("/receive_messages", json=message)
assert response.status_code == 200
assert response.json().get("detail") == "Message received"
# @respx.mock
# def test_request_message():
# request_message = {
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
# "message_name": "name"
# }
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
# response = client.post("/request_message", json=request_message)
# assert response.status_code == 200
# assert "threshold" in response.json()
# assert "public_key" in response.json()
# @respx.mock
# def test_receive_request():
# ip_message = {
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
# "message_name": "name",
# "source_ip": "124.70.165.73", # 使用不同的 IP 地址
# "pk": (123, 456)
# }
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
# response = client.post("/receive_request", json=ip_message)
# assert response.status_code == 200
# assert "threshold" in response.json()
# assert "public_key" in response.json()
def test_get_pk():
response = client.get("/get_pk")
assert response.status_code == 200
assert "pkx" in response.json()
assert "pky" in response.json()
def test_recieve_pk():
pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"}
response = client.post("/recieve_pk", json=pk_data)
assert response.status_code == 200
assert response.json() == {"message": "save pk in database"}
if __name__ == "__main__":
pytest.main()

158
tests/tpre_test.py Normal file
View File

@ -0,0 +1,158 @@
import sys
import os
import hashlib
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
hash2,
hash3,
hash4,
multiply,
g,
sm2p256v1,
GenerateKeyPair,
Encrypt,
Decrypt,
GenerateReKey,
Encapsulate,
ReEncrypt,
DecryptFrags,
MergeCFrag,
)
import random
import unittest
class TestHash2(unittest.TestCase):
def setUp(self):
self.double_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash2(self.double_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash2(self.double_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash3(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash3(self.triple_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash3(self.triple_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash4(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
self.Zp = random.randint(0, sm2p256v1.N - 1)
def test_digest_type(self):
digest = hash4(self.triple_G, self.Zp)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash4(self.triple_G, self.Zp)
self.assertLess(digest, sm2p256v1.N)
class TestGenerateKeyPair(unittest.TestCase):
def test_key_pair(self):
public_key, secret_key = GenerateKeyPair()
self.assertIsInstance(public_key, tuple)
self.assertEqual(len(public_key), 2)
self.assertIsInstance(secret_key, int)
self.assertLess(secret_key, sm2p256v1.N)
self.assertGreater(secret_key, 0)
# class TestEncryptDecrypt(unittest.TestCase):
# def setUp(self):
# self.public_key, self.secret_key = GenerateKeyPair()
# self.message = b"Hello, world!"
# def test_encrypt_decrypt(self):
# encrypted_message = Encrypt(self.public_key, self.message)
# # 使用 SHA-256 哈希函数确保密钥为 16 字节
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
# self.assertEqual(decrypted_message, self.message)
class TestGenerateReKey(unittest.TestCase):
def test_generate_rekey(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
self.assertIsInstance(rekey, list)
self.assertEqual(len(rekey), 5)
class TestEncapsulate(unittest.TestCase):
def test_encapsulate(self):
pk_A, _ = GenerateKeyPair()
K, capsule = Encapsulate(pk_A)
self.assertIsInstance(K, int)
self.assertIsInstance(capsule, tuple)
self.assertEqual(len(capsule), 3)
class TestReEncrypt(unittest.TestCase):
def test_reencrypt(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
pk_A, _ = GenerateKeyPair()
message = b"Hello, world!"
encrypted_message = Encrypt(pk_A, message)
reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
self.assertIsInstance(reencrypted_message, tuple)
self.assertEqual(len(reencrypted_message), 2)
# class TestDecryptFrags(unittest.TestCase):
# def test_decrypt_frags(self):
# sk_A = random.randint(0, sm2p256v1.N - 1)
# pk_B, sk_B = GenerateKeyPair()
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
# pk_A, _ = GenerateKeyPair()
# message = b"Hello, world!"
# encrypted_message = Encrypt(pk_A, message)
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
# cfrags = [reencrypted_message]
# merged_cfrags = MergeCFrag(cfrags)
# self.assertIsNotNone(merged_cfrags)
# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest()
# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags)
# self.assertEqual(decrypted_message, message)
if __name__ == "__main__":
unittest.main()