Compare commits
46 Commits
68d87c8cb7
...
c64ccd57e3
Author | SHA1 | Date | |
---|---|---|---|
c64ccd57e3 | |||
835c908ca7 | |||
fffa6728bc | |||
7fa5bb460c | |||
bf6e25b722 | |||
24a5e91382 | |||
2308b01cee | |||
04d44afac0 | |||
ff9a763cec | |||
0bb058df6e | |||
983ac91f74 | |||
81e34965e0 | |||
90a173dfa1 | |||
f69eecf58b | |||
bba12099cd | |||
a9a4e28d9a | |||
d071296126 | |||
2839dd5500 | |||
cb76380242 | |||
f5493ed8a1 | |||
0830a6733f | |||
ea8f17920a | |||
e79b24f909 | |||
15e35405f0 | |||
b26bd92328 | |||
5b85db2427 | |||
061bd5d2bf | |||
aa0eb6bfbc | |||
98e0e1122b | |||
7f1e201c33 | |||
2f7f55fd3a | |||
05de02f2a5 | |||
acbc9ecb1a | |||
1c18726066 | |||
e8e7c59579 | |||
6c240b1237 | |||
2d12e8c99c | |||
9654d8504b | |||
53928b7f9e | |||
6f19620490 | |||
ca253dbb77 | |||
68ed843777 | |||
d0a916afb2 | |||
3b2a728e14 | |||
82309c5505 | |||
8b89e1b722 |
@ -3,7 +3,7 @@ name: Test CI
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'src/**'
|
||||
- "src/**"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@ -14,10 +14,10 @@ jobs:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Checkout repository
|
||||
uses: https://git.mamahaha.work/actions/checkout@v3
|
||||
|
||||
- name: Run script in Docker container
|
||||
run: |
|
||||
ls $PWD/src
|
||||
docker run --rm -v .:/app git.mamahaha.work/sangge/tpre:base ls
|
||||
# - name: Run script in Docker container
|
||||
# run: |
|
||||
# ls $PWD/src
|
||||
# docker run --rm -v .:/app git.mamahaha.work/sangge/tpre:base ls
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -14,3 +14,7 @@ build
|
||||
src/tpre.cpython-311-x86_64-linux-gnu.so
|
||||
.vscode
|
||||
|
||||
venv
|
||||
lib
|
||||
include
|
||||
/frontend/node_modules/
|
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -1,3 +1,6 @@
|
||||
[submodule "gmssl"]
|
||||
path = gmssl
|
||||
url = https://github.com/guanzhi/GmSSL.git
|
||||
[submodule "ecc_rs"]
|
||||
path = ecc_rs
|
||||
url = https://git.mamahaha.work/sangge/ecc_rs.git
|
||||
|
1
FQA.md
1
FQA.md
@ -59,3 +59,4 @@
|
||||
15. **Q:国密SM2 SM3 SM4** 如题
|
||||
|
||||
**A:**
|
||||
----------------------------------------------------------------------
|
||||
|
20
README.md
20
README.md
@ -30,9 +30,9 @@
|
||||
- Windows (需要自行安装gmssl的共享库)
|
||||
|
||||
该项目依赖以下软件:
|
||||
python 3.11
|
||||
gmssl
|
||||
gmssl-python
|
||||
python 3.12
|
||||
gmssl v3.1.1
|
||||
gmssl-python 2.2.2
|
||||
|
||||
### Docker 版本安装
|
||||
|
||||
@ -46,9 +46,9 @@ chmod +x install_docker.sh
|
||||
|
||||
docker 版本:
|
||||
|
||||
- 版本: 24.0.5
|
||||
- API 版本: 1.43
|
||||
- Go 版本: go1.20.6
|
||||
- 版本: 24.0.5
|
||||
- API 版本: 1.43
|
||||
- Go 版本: go1.20.6
|
||||
|
||||
## 安装步骤
|
||||
|
||||
@ -56,6 +56,14 @@ docker 版本:
|
||||
|
||||
本项目依赖gmssl,所以请提前安装好。访问 [GmSSL](https://github.com/guanzhi/GmSSL) 可以看到如何安装。
|
||||
|
||||
本项目也提供了submodule的方式,可以直接使用
|
||||
|
||||
```bash
|
||||
git clone --recurse-submodules https://git.mamahaha.work/sangge/tpre-python.git
|
||||
chmod +x install_gmssl.sh
|
||||
./install_docker.sh
|
||||
```
|
||||
|
||||
然后安装必要的python库
|
||||
|
||||
```bash
|
||||
|
12
README_en.md
12
README_en.md
@ -30,9 +30,9 @@ System requirements:
|
||||
|
||||
The project relies on the following software:
|
||||
|
||||
- Python 3.11
|
||||
- gmssl
|
||||
- gmssl-python
|
||||
- Python 3.12
|
||||
- gmssl v3.1.1
|
||||
- gmssl-python 2.2.2
|
||||
|
||||
### Docker installer
|
||||
|
||||
@ -46,9 +46,9 @@ chmod +x install_docker.sh
|
||||
|
||||
docker version:
|
||||
|
||||
- Version: 24.0.5
|
||||
- API version: 1.43
|
||||
- Go version: go1.20.6
|
||||
- Version: 24.0.5
|
||||
- API version: 1.43
|
||||
- Go version: go1.20.6
|
||||
|
||||
## Installation Steps
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
FROM python:3.11
|
||||
FROM python:3.12-slim
|
||||
|
||||
COPY requirements.txt /app/
|
||||
|
||||
@ -8,9 +8,12 @@ COPY requirements.txt /app/
|
||||
# 根据目标平台复制相应架构的库文件
|
||||
#COPY lib/${TARGETPLATFORM}/* /lib/
|
||||
|
||||
COPY lib/* /lib/
|
||||
COPY lib/* /usr/local/lib/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN pip install --index-url https://git.mamahaha.work/api/packages/sangge/pypi/simple/ ecc-rs
|
||||
|
||||
RUN ldconfig
|
@ -3,16 +3,21 @@
|
||||
## Run docker
|
||||
|
||||
```bash
|
||||
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=60.204.193.58 git.mamahaha.work/sangge/tpre:base bash
|
||||
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=119.3.125.234 git.mamahaha.work/sangge/tpre:base bash
|
||||
docker run -it -p 8000-8002:8000-8002 -v ~/mimajingsai/src:/app -e HOST_IP=124.70.165.73 git.mamahaha.work/sangge/tpre:base bash
|
||||
docker run -it -p 8000-8002:8000-8002 -v ~/tpre-python/src:/app -e HOST_IP=192.168.8.57 -e server_address=192.168.8.57:8000 git.mamahaha.work/sangge/tpre:base bash
|
||||
docker run -it -p 8000-8002:8000-8002 -v ~/tpre-python/src:/app -e HOST_IP=119.3.125.234 git.mamahaha.work/sangge/tpre:base bash
|
||||
docker run -it -p 8000-8002:8000-8002 -v ~/tpre-python/src:/app -e HOST_IP=124.70.165.73 git.mamahaha.work/sangge/tpre:base bash
|
||||
```
|
||||
|
||||
```bash
|
||||
tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=60.204.233.103 git.mamahaha.work/sangge/tpre:base bash
|
||||
```
|
||||
## Deploy contract
|
||||
You should deploy the contract yourself in src/logger.sol using remix or any CLI-tools and replace the contract address in src/node.py with your actual address.
|
||||
|
||||
[Deployment document](https://remix-ide.readthedocs.io/zh-cn/latest/create_deploy.html)
|
||||
|
||||
## Start application
|
||||
You should replace the wallet address/privateKey in src/node.py with your own wallet address/privateKey.
|
||||
|
||||
```bash
|
||||
nohup python server.py &
|
||||
|
31
docker-compose.yml
Normal file
31
docker-compose.yml
Normal file
@ -0,0 +1,31 @@
|
||||
version: "3"
|
||||
services:
|
||||
server:
|
||||
image: git.mamahaha.work/sangge/tpre:base
|
||||
volumes:
|
||||
- ./src:/app
|
||||
environment:
|
||||
- server_address=http://server:8000
|
||||
entrypoint:
|
||||
- python
|
||||
- server.py
|
||||
|
||||
node:
|
||||
image: git.mamahaha.work/sangge/tpre:base
|
||||
volumes:
|
||||
- ./src:/app
|
||||
environment:
|
||||
- server_address=http://server:8000
|
||||
entrypoint:
|
||||
- python
|
||||
- node.py
|
||||
|
||||
client:
|
||||
image: git.mamahaha.work/sangge/tpre:base
|
||||
volumes:
|
||||
- ./src:/app
|
||||
environment:
|
||||
- server_address=http://server:8000
|
||||
entrypoint:
|
||||
- python
|
||||
- client.py
|
1
ecc_rs
Submodule
1
ecc_rs
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 880c34ce031158d3f27116b3e6d3a0a3748310db
|
23
frontend/.gitignore
vendored
Normal file
23
frontend/.gitignore
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||
|
||||
# dependencies
|
||||
/node_modules
|
||||
/.pnp
|
||||
.pnp.js
|
||||
|
||||
# testing
|
||||
/coverage
|
||||
|
||||
# production
|
||||
/build
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
70
frontend/README.md
Normal file
70
frontend/README.md
Normal file
@ -0,0 +1,70 @@
|
||||
# Getting Started with Create React App
|
||||
|
||||
This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
|
||||
|
||||
## Available Scripts
|
||||
|
||||
In the project directory, you can run:
|
||||
|
||||
### `npm start`
|
||||
|
||||
Runs the app in the development mode.\
|
||||
Open [http://localhost:3000](http://localhost:3000) to view it in your browser.
|
||||
|
||||
The page will reload when you make changes.\
|
||||
You may also see any lint errors in the console.
|
||||
|
||||
### `npm test`
|
||||
|
||||
Launches the test runner in the interactive watch mode.\
|
||||
See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information.
|
||||
|
||||
### `npm run build`
|
||||
|
||||
Builds the app for production to the `build` folder.\
|
||||
It correctly bundles React in production mode and optimizes the build for the best performance.
|
||||
|
||||
The build is minified and the filenames include the hashes.\
|
||||
Your app is ready to be deployed!
|
||||
|
||||
See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information.
|
||||
|
||||
### `npm run eject`
|
||||
|
||||
**Note: this is a one-way operation. Once you `eject`, you can't go back!**
|
||||
|
||||
If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project.
|
||||
|
||||
Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own.
|
||||
|
||||
You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it.
|
||||
|
||||
## Learn More
|
||||
|
||||
You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started).
|
||||
|
||||
To learn React, check out the [React documentation](https://reactjs.org/).
|
||||
|
||||
### Code Splitting
|
||||
|
||||
This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting)
|
||||
|
||||
### Analyzing the Bundle Size
|
||||
|
||||
This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size)
|
||||
|
||||
### Making a Progressive Web App
|
||||
|
||||
This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app)
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration)
|
||||
|
||||
### Deployment
|
||||
|
||||
This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment)
|
||||
|
||||
### `npm run build` fails to minify
|
||||
|
||||
This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify)
|
19780
frontend/package-lock.json
generated
Normal file
19780
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
39
frontend/package.json
Normal file
39
frontend/package.json
Normal file
@ -0,0 +1,39 @@
|
||||
{
|
||||
"name": "frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"dependencies": {
|
||||
"@testing-library/jest-dom": "^5.17.0",
|
||||
"@testing-library/react": "^13.4.0",
|
||||
"@testing-library/user-event": "^13.5.0",
|
||||
"axios": "^1.7.7",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-scripts": "5.0.1",
|
||||
"web-vitals": "^2.1.4"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
"build": "react-scripts build",
|
||||
"test": "react-scripts test",
|
||||
"eject": "react-scripts eject"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"extends": [
|
||||
"react-app",
|
||||
"react-app/jest"
|
||||
]
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
">0.2%",
|
||||
"not dead",
|
||||
"not op_mini all"
|
||||
],
|
||||
"development": [
|
||||
"last 1 chrome version",
|
||||
"last 1 firefox version",
|
||||
"last 1 safari version"
|
||||
]
|
||||
}
|
||||
}
|
BIN
frontend/public/favicon.ico
Normal file
BIN
frontend/public/favicon.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.8 KiB |
13
frontend/public/index.html
Normal file
13
frontend/public/index.html
Normal file
@ -0,0 +1,13 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>中央服务器路由</title>
|
||||
<link rel="stylesheet" href="./index.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script src="/frontend/static/js/bundle.js"></script> <!-- 通过构建工具引入打包后的JS文件 -->
|
||||
</body>
|
||||
</html>
|
25
frontend/public/manifest.json
Normal file
25
frontend/public/manifest.json
Normal file
@ -0,0 +1,25 @@
|
||||
{
|
||||
"short_name": "React App",
|
||||
"name": "Create React App Sample",
|
||||
"icons": [
|
||||
{
|
||||
"src": "favicon.ico",
|
||||
"sizes": "64x64 32x32 24x24 16x16",
|
||||
"type": "image/x-icon"
|
||||
},
|
||||
{
|
||||
"src": "logo192.png",
|
||||
"type": "image/png",
|
||||
"sizes": "192x192"
|
||||
},
|
||||
{
|
||||
"src": "logo512.png",
|
||||
"type": "image/png",
|
||||
"sizes": "512x512"
|
||||
}
|
||||
],
|
||||
"start_url": ".",
|
||||
"display": "standalone",
|
||||
"theme_color": "#000000",
|
||||
"background_color": "#ffffff"
|
||||
}
|
3
frontend/public/robots.txt
Normal file
3
frontend/public/robots.txt
Normal file
@ -0,0 +1,3 @@
|
||||
# https://www.robotstxt.org/robotstxt.html
|
||||
User-agent: *
|
||||
Disallow:
|
114
frontend/src/App.css
Normal file
114
frontend/src/App.css
Normal file
@ -0,0 +1,114 @@
|
||||
@keyframes glow-border {
|
||||
0% {
|
||||
box-shadow: inset 0 0 0 2px #00e6e6;
|
||||
}
|
||||
25% {
|
||||
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6;
|
||||
}
|
||||
50% {
|
||||
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6, 0 2px 0 0 #00e6e6;
|
||||
}
|
||||
75% {
|
||||
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6, 0 2px 0 0 #00e6e6, -2px 0 0 0 #00e6e6;
|
||||
}
|
||||
100% {
|
||||
box-shadow: inset 0 0 0 2px #00e6e6, 2px 0 0 0 #00e6e6, 0 2px 0 0 #00e6e6, -2px 0 0 0 #00e6e6, 0 -2px 0 0 #00e6e6;
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
background-color: #1e1e1e;
|
||||
color: #ffffff;
|
||||
font-family: Arial, sans-serif;
|
||||
}
|
||||
|
||||
.App-header {
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #eff4f0;
|
||||
font-size: 4em;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
h2 {
|
||||
color: #2196f3;
|
||||
font-size: 2em;
|
||||
margin-top: 20px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
h3 {
|
||||
color: #ff9800;
|
||||
font-size: 1.75em;
|
||||
margin-top: 15px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
h4 {
|
||||
color: #f44336;
|
||||
font-size: 1.5em;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
input, button {
|
||||
margin: 10px 0;
|
||||
padding: 10px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
input {
|
||||
width: 80%;
|
||||
max-width: 300px;
|
||||
}
|
||||
|
||||
button {
|
||||
background-color: #2196f3;
|
||||
color: #ffffff;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
background-color: #1976d2;
|
||||
}
|
||||
|
||||
section {
|
||||
background-color: #333333;
|
||||
padding: 20px;
|
||||
margin: 20px 0;
|
||||
border-radius: 10px;
|
||||
animation: glow-border 4s infinite;
|
||||
}
|
||||
|
||||
.log-info {
|
||||
background-color: #444444;
|
||||
color: #ffffff;
|
||||
padding: 10px;
|
||||
border-radius: 5px;
|
||||
animation: glow-border 4s infinite;
|
||||
}
|
||||
|
||||
ul {
|
||||
list-style-type: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
li {
|
||||
background-color: #444444;
|
||||
margin: 5px 0;
|
||||
padding: 10px;
|
||||
border-radius: 5px;
|
||||
}
|
||||
|
||||
.container {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.left-panel, .right-panel {
|
||||
width: 48%;
|
||||
}
|
118
frontend/src/App.js
Normal file
118
frontend/src/App.js
Normal file
@ -0,0 +1,118 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import axios from 'axios';
|
||||
import WebSocketComponent from './WebSocketComponent';
|
||||
import './App.css';
|
||||
|
||||
function App() {
|
||||
const [node, setNode] = useState(null);
|
||||
const [heartbeat, setHeartbeat] = useState(null);
|
||||
const [nodesList, setNodesList] = useState([]);
|
||||
const [ip, setIp] = useState('');
|
||||
const [count, setCount] = useState('');
|
||||
|
||||
const fetchNodes = async () => {
|
||||
try {
|
||||
const response = await axios.get('/server/show_nodes');
|
||||
setNodesList(response.data);
|
||||
} catch (error) {
|
||||
console.error('Error fetching nodes:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchNode = async (ip) => {
|
||||
try {
|
||||
const response = await axios.get('/server/get_node', { params: { ip } });
|
||||
setNode(response.data);
|
||||
} catch (error) {
|
||||
console.error('Error fetching node:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchHeartbeat = async (ip) => {
|
||||
try {
|
||||
const response = await axios.get('/server/heartbeat', { params: { ip } });
|
||||
setHeartbeat(response.data);
|
||||
} catch (error) {
|
||||
console.error('Error fetching heartbeat:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchNodesList = async (count) => {
|
||||
try {
|
||||
const response = await axios.get('/server/send_nodes_list', { params: { count } });
|
||||
setNodesList(response.data);
|
||||
} catch (error) {
|
||||
console.error('Error fetching nodes list:', error);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchNodes(); // 获取所有节点
|
||||
}, []);
|
||||
|
||||
const handleFetchNode = () => {
|
||||
fetchNode(ip); // 根据输入的 IP 获取单个节点
|
||||
};
|
||||
|
||||
const handleFetchHeartbeat = () => {
|
||||
fetchHeartbeat(ip); // 根据输入的 IP 获取心跳信息
|
||||
};
|
||||
|
||||
const handleFetchNodesList = () => {
|
||||
fetchNodesList(count); // 根据输入的数量获取节点列表
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="App">
|
||||
<header className="App-header">
|
||||
<h1 className="glow">The server</h1>
|
||||
<div className="container">
|
||||
<div className="left-panel">
|
||||
<section>
|
||||
<h2 className="glow">get node</h2>
|
||||
<input
|
||||
type="text"
|
||||
value={ip}
|
||||
onChange={(e) => setIp(e.target.value)}
|
||||
placeholder="Enter node ip"
|
||||
/>
|
||||
<button onClick={handleFetchNode}>send</button>
|
||||
{node ? <p>{JSON.stringify(node)}</p> : <p>here is nothing!</p>}
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h2 className="glow">heartbeat</h2>
|
||||
<button onClick={handleFetchHeartbeat}>Get heartbeat</button>
|
||||
{heartbeat ? <p>{JSON.stringify(heartbeat)}</p> : <p>here is nothing!</p>}
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h2 className="glow">nodes list</h2>
|
||||
<input
|
||||
type="number"
|
||||
value={count}
|
||||
onChange={(e) => setCount(e.target.value)}
|
||||
placeholder="Enter the number of nodes"
|
||||
/>
|
||||
<button onClick={handleFetchNodesList}>Get node list</button>
|
||||
{nodesList.length > 0 ? (
|
||||
<ul>
|
||||
{nodesList.map((node, index) => (
|
||||
<li key={index}>{JSON.stringify(node)}</li>
|
||||
))}
|
||||
</ul>
|
||||
) : (
|
||||
<p>here is nothing!</p>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
<div className="right-panel">
|
||||
<WebSocketComponent />
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
66
frontend/src/WebSocketComponent.js
Normal file
66
frontend/src/WebSocketComponent.js
Normal file
@ -0,0 +1,66 @@
|
||||
import React, { useEffect, useState, useRef, useCallback } from 'react';
|
||||
|
||||
const WebSocketComponent = () => {
|
||||
const [logs, setLogs] = useState([]);
|
||||
const wsRef = useRef(null);
|
||||
|
||||
const connectWebSocket = useCallback(() => {
|
||||
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
||||
return; // 如果连接已经打开,不再重新连接
|
||||
}
|
||||
|
||||
wsRef.current = new WebSocket('ws://localhost:8000/ws/logs');
|
||||
|
||||
wsRef.current.onopen = () => {
|
||||
console.log('WebSocket 连接成功');
|
||||
};
|
||||
|
||||
wsRef.current.onmessage = (event) => {
|
||||
setLogs((prevLogs) => [...prevLogs, event.data]); // 直接加入收到的消息
|
||||
};
|
||||
|
||||
wsRef.current.onerror = (error) => {
|
||||
console.error('WebSocket 错误: ', error);
|
||||
};
|
||||
|
||||
wsRef.current.onclose = () => {
|
||||
console.log('WebSocket 连接关闭,尝试重新连接...');
|
||||
// 确保 WebSocket 连接在关闭后再进行重连
|
||||
setTimeout(() => {
|
||||
connectWebSocket();
|
||||
}, 5000); // 延迟 5 秒再重连
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
connectWebSocket();
|
||||
return () => {
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
}
|
||||
};
|
||||
}, [connectWebSocket]);
|
||||
|
||||
useEffect(() => {
|
||||
const logContainer = document.querySelector('.log-info');
|
||||
if (logContainer) {
|
||||
logContainer.scrollTop = logContainer.scrollHeight;
|
||||
}
|
||||
}, [logs]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h2>The logs</h2>
|
||||
<div
|
||||
className="log-info"
|
||||
style={{ height: '550px', overflowY: 'scroll', backgroundColor: 'rgb(32, 28, 28)', padding: '10px' }}
|
||||
>
|
||||
{logs.map((log, index) => (
|
||||
<p key={index} style={{ margin: '5px 0' }}>{log}</p>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default WebSocketComponent;
|
11
frontend/src/index.css
Normal file
11
frontend/src/index.css
Normal file
@ -0,0 +1,11 @@
|
||||
/* src/index.css */
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
}
|
||||
|
||||
code {
|
||||
font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace;
|
||||
}
|
13
frontend/src/index.js
Normal file
13
frontend/src/index.js
Normal file
@ -0,0 +1,13 @@
|
||||
import React from 'react';
|
||||
import { createRoot } from 'react-dom/client';
|
||||
import App from './App';
|
||||
import './index.css';
|
||||
|
||||
const container = document.getElementById('root');
|
||||
const root = createRoot(container);
|
||||
|
||||
root.render(
|
||||
<React.StrictMode>
|
||||
<App />
|
||||
</React.StrictMode>
|
||||
);
|
10
frontend/src/setupProxy.js
Normal file
10
frontend/src/setupProxy.js
Normal file
@ -0,0 +1,10 @@
|
||||
const { createProxyMiddleware } = require('http-proxy-middleware');
|
||||
|
||||
module.exports = function(app) {
|
||||
app.use(
|
||||
createProxyMiddleware('/server', {
|
||||
target: 'http://localhost:8000',
|
||||
changeOrigin: true,
|
||||
})
|
||||
);
|
||||
};
|
2
gmssl
2
gmssl
@ -1 +1 @@
|
||||
Subproject commit 31efcb5d87f99fc50b448c181d0a0e59d7edea03
|
||||
Subproject commit d655c06b3a6b0fe8cff900f293bf0e5aac6eb0a2
|
11
install_gmssl.sh
Normal file → Executable file
11
install_gmssl.sh
Normal file → Executable file
@ -3,10 +3,13 @@
|
||||
mkdir lib
|
||||
mkdir include
|
||||
|
||||
cp gmssl/include include
|
||||
cp -r gmssl/include include
|
||||
|
||||
mkdir gmssl/build
|
||||
cd gmssl/build
|
||||
cd gmssl/build || exit
|
||||
cmake ..
|
||||
make
|
||||
cp bin/lib* ../../lib
|
||||
cd bin || exit
|
||||
cp libgmssl.so libgmssl.so.3 libgmssl.so.3.1 ../../../lib
|
||||
cp libsdf_dummy.so libsdf_dummy.so.3 libsdf_dummy.so.3.1 ../../../lib
|
||||
cp libskf_dummy.so libskf_dummy.so.3 libskf_dummy.so.3.1 ../../../lib
|
||||
sudo make install
|
||||
|
6
package-lock.json
generated
Normal file
6
package-lock.json
generated
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"name": "tpre-python",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
gmssl-python
|
||||
gmssl-python>=2.2.2,<3.0.0
|
||||
fastapi
|
||||
uvicorn
|
||||
requests
|
||||
web3
|
@ -2,14 +2,19 @@ from fastapi import FastAPI, HTTPException
|
||||
import requests
|
||||
import os
|
||||
from typing import Tuple
|
||||
from tpre import *
|
||||
from tpre import (
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
||||
DecryptFrags,
|
||||
GenerateReKey,
|
||||
MergeCFrag,
|
||||
point,
|
||||
)
|
||||
import sqlite3
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel
|
||||
import socket
|
||||
import random
|
||||
import time
|
||||
import base64
|
||||
import json
|
||||
import pickle
|
||||
from fastapi.responses import JSONResponse
|
||||
@ -18,12 +23,12 @@ import asyncio
|
||||
# 测试文本
|
||||
test_msessgaes = {
|
||||
"name": b"proxy re-encryption",
|
||||
"environment": b"distributed environment"
|
||||
"environment": b"distributed environment",
|
||||
}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(_: FastAPI):
|
||||
init()
|
||||
yield
|
||||
clean_env()
|
||||
@ -84,13 +89,8 @@ def init_db():
|
||||
|
||||
# load config from config file
|
||||
def init_config():
|
||||
import configparser
|
||||
|
||||
global server_address
|
||||
config = configparser.ConfigParser()
|
||||
config.read("client.ini")
|
||||
|
||||
server_address = config["settings"]["server_address"]
|
||||
server_address = os.environ.get("server_address")
|
||||
|
||||
|
||||
# execute on exit
|
||||
@ -200,7 +200,7 @@ def check_merge(ct: int, ip: str):
|
||||
try:
|
||||
pkx, pky = result[0] # result[0] = (pkx, pky)
|
||||
pk_sender = (int(pkx), int(pky))
|
||||
except:
|
||||
except IndexError:
|
||||
pk_sender, T = 0, -1
|
||||
|
||||
T = 2
|
||||
@ -212,7 +212,7 @@ def check_merge(ct: int, ip: str):
|
||||
byte_length = (ct.bit_length() + 7) // 8
|
||||
temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(byte_length)))
|
||||
|
||||
cfrags = mergecfrag(temp_cfrag_cts)
|
||||
cfrags = MergeCFrag(temp_cfrag_cts)
|
||||
|
||||
print("sk:", type(sk))
|
||||
print("pk:", type(pk))
|
||||
@ -371,7 +371,7 @@ async def receive_request(i_m: IP_Message):
|
||||
try:
|
||||
message = test_msessgaes[i_m.message_name]
|
||||
|
||||
except:
|
||||
except IndexError:
|
||||
message = b"hello world" + random.randbytes(8)
|
||||
print(f"Message to be send: {message}")
|
||||
|
||||
@ -391,7 +391,7 @@ def get_own_ip() -> str:
|
||||
s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP
|
||||
ip = s.getsockname()[0]
|
||||
s.close()
|
||||
except:
|
||||
except IndexError:
|
||||
raise ValueError("Unable to get IP")
|
||||
return str(ip)
|
||||
|
||||
@ -464,7 +464,7 @@ async def recieve_pk(pk: pk_model):
|
||||
|
||||
pk = (0, 0)
|
||||
sk = 0
|
||||
server_address = str
|
||||
server_address = os.environ.get("server_address")
|
||||
node_response = False
|
||||
message = bytes
|
||||
local_ip = get_own_ip()
|
||||
|
@ -4,14 +4,14 @@ import json
|
||||
|
||||
|
||||
def send_post_request(ip_addr, message_name):
|
||||
url = f"http://localhost:8002/request_message"
|
||||
url = "http://localhost:8002/request_message"
|
||||
data = {"dest_ip": ip_addr, "message_name": message_name}
|
||||
response = requests.post(url, json=data)
|
||||
return response.text
|
||||
|
||||
|
||||
def get_pk(ip_addr):
|
||||
url = f"http://" + ip_addr + ":8002/get_pk"
|
||||
url = "http://" + ip_addr + ":8002/get_pk"
|
||||
response = requests.get(url, timeout=1)
|
||||
print(response.text)
|
||||
json_pk = json.loads(response.text)
|
||||
@ -21,7 +21,6 @@ def get_pk(ip_addr):
|
||||
return response.text
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Send POST request to a specified IP.")
|
||||
parser.add_argument("ip_addr", help="IP address to send request to.")
|
||||
|
@ -64,7 +64,7 @@ for N in range(4, 21, 2):
|
||||
|
||||
# 9
|
||||
start_time = time.time()
|
||||
cfrags = mergecfrag(cfrag_cts)
|
||||
cfrags = MergeCFrag(cfrag_cts)
|
||||
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
|
||||
end_time = time.time()
|
||||
elapsed_time_dec = end_time - start_time
|
||||
|
63
src/eth_logger.py
Normal file
63
src/eth_logger.py
Normal file
@ -0,0 +1,63 @@
|
||||
from web3 import Web3
|
||||
import json
|
||||
|
||||
rpc_url = "https://ethereum-holesky-rpc.publicnode.com"
|
||||
chain = Web3(Web3.HTTPProvider(rpc_url))
|
||||
contract_address = "0x642C23F91bf8339044A00251BC09d1D98110C433"
|
||||
contract_abi = json.loads(
|
||||
"""[
|
||||
{
|
||||
"anonymous": false,
|
||||
"inputs": [
|
||||
{
|
||||
"indexed": false,
|
||||
"internalType": "string",
|
||||
"name": "msg",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"name": "messageLog",
|
||||
"type": "event"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "string",
|
||||
"name": "text",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"name": "logmessage",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
}
|
||||
]"""
|
||||
)
|
||||
contract = chain.eth.contract(address=contract_address, abi=contract_abi)
|
||||
wallet_address = "0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6"
|
||||
pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
|
||||
|
||||
|
||||
def call_eth_logger(wallet_address, pk, message: str):
|
||||
transaction = contract.functions.logmessage(message).build_transaction(
|
||||
{
|
||||
"chainId": 17000,
|
||||
"gas": 30000,
|
||||
"gasPrice": chain.to_wei("10", "gwei"),
|
||||
"nonce": chain.eth.get_transaction_count(wallet_address, "pending"),
|
||||
}
|
||||
)
|
||||
signed_tx = chain.eth.account.sign_transaction(transaction, private_key=pk)
|
||||
tx_hash = chain.eth.send_raw_transaction(signed_tx.raw_transaction)
|
||||
print(tx_hash)
|
||||
receipt = chain.eth.wait_for_transaction_receipt(tx_hash)
|
||||
transfer_event = contract.events.messageLog().process_receipt(receipt)
|
||||
for event in transfer_event:
|
||||
print(event["args"]["msg"])
|
8
src/eth_logger_test.py
Normal file
8
src/eth_logger_test.py
Normal file
@ -0,0 +1,8 @@
|
||||
from eth_logger import call_eth_logger
|
||||
|
||||
wallet_address = (
|
||||
"0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6" # 修改成要使用的钱包地址/私钥
|
||||
)
|
||||
wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
|
||||
|
||||
call_eth_logger(wallet_address, wallet_pk, "hello World")
|
10
src/logger.sol
Normal file
10
src/logger.sol
Normal file
@ -0,0 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
pragma solidity ^0.8.0;
|
||||
|
||||
contract logger{
|
||||
event messageLog(string msg);
|
||||
function logmessage(string memory text) public returns (bool){
|
||||
emit messageLog(text);
|
||||
return true;
|
||||
}
|
||||
}
|
92
src/node.py
92
src/node.py
@ -1,25 +1,31 @@
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
import requests
|
||||
from contextlib import asynccontextmanager
|
||||
import socket
|
||||
import asyncio
|
||||
from pydantic import BaseModel
|
||||
from tpre import *
|
||||
import os
|
||||
from typing import Any, Tuple
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import requests
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from eth_logger import call_eth_logger
|
||||
from tpre import ReEncrypt, capsule
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(_: FastAPI):
|
||||
init()
|
||||
yield
|
||||
clear()
|
||||
|
||||
|
||||
message_list = []
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
server_address = "http://60.204.236.38:8000/server"
|
||||
server_address = os.environ.get("server_address")
|
||||
id = 0
|
||||
ip = ""
|
||||
client_ip_src = "" # 发送信息用户的ip
|
||||
@ -36,17 +42,21 @@ logger = logging.getLogger("uvicorn")
|
||||
|
||||
# 向中心服务器发送自己的IP地址,并获取自己的id
|
||||
def send_ip():
|
||||
url = server_address + "/get_node?ip=" + ip # type: ignore
|
||||
url = f"http://{server_address}/server/get_node?ip={ip}" # 添加 http:// 协议
|
||||
# ip = get_local_ip() # type: ignore
|
||||
global id
|
||||
id = requests.get(url, timeout=3)
|
||||
logger.info(f"中心服务器返回节点ID为: {id}")
|
||||
print("中心服务器返回节点ID为: ", id)
|
||||
try:
|
||||
response = requests.get(url, timeout=3)
|
||||
response.raise_for_status() # 检查请求是否成功
|
||||
data = response.json() # 将响应内容解析为 JSON 格式
|
||||
global id
|
||||
id = data.get("id") # 假设返回的 JSON 包含 id 字段
|
||||
logger.info(f"中心服务器返回节点ID为: {id}")
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"请求中心服务器失败: {e}")
|
||||
|
||||
|
||||
# 用环境变量获取本机ip
|
||||
def get_local_ip():
|
||||
global ip
|
||||
def get_local_ip() -> str | None:
|
||||
ip = os.environ.get("HOST_IP")
|
||||
if not ip: # 如果环境变量中没有IP
|
||||
try:
|
||||
@ -55,12 +65,16 @@ def get_local_ip():
|
||||
s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP
|
||||
ip = str(s.getsockname()[0])
|
||||
s.close()
|
||||
except:
|
||||
return ip
|
||||
except IndexError:
|
||||
raise ValueError("Unable to get IP")
|
||||
else:
|
||||
return ip
|
||||
|
||||
|
||||
def init():
|
||||
get_local_ip()
|
||||
global ip
|
||||
ip = get_local_ip()
|
||||
send_ip()
|
||||
asyncio.create_task(send_heartbeat_internal())
|
||||
print("Finish init")
|
||||
@ -76,12 +90,12 @@ def clear():
|
||||
async def send_heartbeat_internal() -> None:
|
||||
timeout = 30
|
||||
global ip
|
||||
url = server_address + "/heartbeat?ip=" + ip # type: ignore
|
||||
url = f"http://{server_address}/server/heartbeat?ip={ip}" # 添加 http:// 协议
|
||||
while True:
|
||||
# print('successful send my_heart')
|
||||
try:
|
||||
requests.get(url, timeout=3)
|
||||
except:
|
||||
except requests.exceptions.RequestException:
|
||||
logger.error("Central server error")
|
||||
print("Central server error")
|
||||
|
||||
@ -119,6 +133,17 @@ async def user_src(message: Req):
|
||||
capsule = message.capsule
|
||||
ct = message.ct
|
||||
|
||||
payload = {
|
||||
"source_ip": source_ip,
|
||||
"dest_ip": dest_ip,
|
||||
"capsule": capsule,
|
||||
"ct": ct,
|
||||
"rk": message.rk,
|
||||
}
|
||||
# 将消息详情记录到区块链
|
||||
global message_list
|
||||
message_list.append(payload)
|
||||
|
||||
byte_length = (ct.bit_length() + 7) // 8
|
||||
capsule_ct = (capsule, ct.to_bytes(byte_length))
|
||||
|
||||
@ -139,7 +164,9 @@ async def user_src(message: Req):
|
||||
return HTTPException(status_code=200, detail="message recieved")
|
||||
|
||||
|
||||
async def send_user_des_message(source_ip: str, dest_ip: str, re_message): # 发送消息给用户2
|
||||
async def send_user_des_message(
|
||||
source_ip: str, dest_ip: str, re_message
|
||||
): # 发送消息给用户2
|
||||
data = {"Tuple": re_message, "ip": source_ip} # 类型不匹配
|
||||
|
||||
# 发送 HTTP POST 请求
|
||||
@ -151,7 +178,24 @@ async def send_user_des_message(source_ip: str, dest_ip: str, re_message): #
|
||||
print("send stauts:", response.text)
|
||||
|
||||
|
||||
def log_message():
|
||||
while True:
|
||||
global message_list
|
||||
payload = json.dumps(message_list)
|
||||
message_list = []
|
||||
call_eth_logger(wallet_address, wallet_pk, payload)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
wallet_address = (
|
||||
"0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6" # 修改成要使用的钱包地址/私钥
|
||||
)
|
||||
wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn # pylint: disable=e0401
|
||||
import uvicorn
|
||||
|
||||
# threading.Thread(target=log_message).start()
|
||||
|
||||
uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True, log_level="debug")
|
||||
|
2
src/pytest.ini
Normal file
2
src/pytest.ini
Normal file
@ -0,0 +1,2 @@
|
||||
[pytest]
|
||||
asyncio_default_fixture_loop_scope = function
|
207
src/server.py
207
src/server.py
@ -1,27 +1,74 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.websockets import WebSocketState
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import sqlite3
|
||||
import asyncio
|
||||
import time
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = [
|
||||
"http://localhost:3000",
|
||||
]
|
||||
|
||||
# 配置 CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:3000"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # 允许所有方法
|
||||
allow_headers=["*"], # 允许所有头
|
||||
)
|
||||
|
||||
# 配置日志文件路径
|
||||
log_dir = "logs"
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
log_file = os.path.join(log_dir, "server_logs.log")
|
||||
|
||||
# 全局日志配置
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # 设置全局日志级别
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # 日志格式
|
||||
handlers=[
|
||||
logging.FileHandler(log_file, encoding="utf-8"), # 输出到日志文件
|
||||
logging.StreamHandler(), # 输出到控制台
|
||||
],
|
||||
)
|
||||
|
||||
# 获取日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(_: FastAPI):
|
||||
init()
|
||||
yield
|
||||
clean_env()
|
||||
|
||||
|
||||
# 获取当前文件所在的目录
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 定义 frontend 的绝对路径
|
||||
frontend_dir = os.path.join(current_dir, "..", "frontend")
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.mount("/frontend", StaticFiles(directory="frontend/build"), name="frontend")
|
||||
|
||||
|
||||
def init():
|
||||
asyncio.create_task(receive_heartbeat_internal())
|
||||
init_db()
|
||||
|
||||
|
||||
def init_db():
|
||||
conn = sqlite3.connect("server.db")
|
||||
cursor = conn.cursor()
|
||||
@ -39,6 +86,9 @@ def clean_env():
|
||||
clear_database()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home():
|
||||
return {"message": "Hello, World!"}
|
||||
@ -54,10 +104,30 @@ async def show_nodes() -> list:
|
||||
|
||||
for row in rows:
|
||||
nodes_list.append(row)
|
||||
# TODO: use JSONResponse
|
||||
logger.info("节点信息已成功获取")
|
||||
return nodes_list
|
||||
|
||||
|
||||
@app.get("/nodes", response_class=HTMLResponse)
|
||||
async def get_nodes_page():
|
||||
with open("frontend/public/index.html") as f:
|
||||
return HTMLResponse(content=f.read(), status_code=200)
|
||||
|
||||
|
||||
def validate_ip(ip: str) -> bool:
|
||||
"""
|
||||
Validate an IP address.
|
||||
|
||||
This function checks if the provided string is a valid IP address.
|
||||
Both IPv4 and IPv6 are considered valid.
|
||||
|
||||
Args:
|
||||
ip (str): The IP address to validate.
|
||||
|
||||
Returns:
|
||||
bool: True if the IP address is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
return True
|
||||
@ -66,7 +136,7 @@ def validate_ip(ip: str) -> bool:
|
||||
|
||||
|
||||
@app.get("/server/get_node")
|
||||
async def get_node(ip: str) -> int:
|
||||
async def get_node(ip: str) -> JSONResponse:
|
||||
"""
|
||||
中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点
|
||||
params:
|
||||
@ -82,11 +152,12 @@ async def get_node(ip: str) -> int:
|
||||
ip_int = 0
|
||||
for i in range(4):
|
||||
ip_int += int(ip_parts[i]) << (24 - (8 * i))
|
||||
print("IP", ip, "对应的ID为", ip_int)
|
||||
|
||||
logger.info(f"IP {ip} 对应的ID为 {ip_int}")
|
||||
|
||||
# 获取当前时间
|
||||
current_time = int(time.time())
|
||||
print("当前时间: ", current_time)
|
||||
logger.info(f"当前时间: {current_time}")
|
||||
|
||||
with sqlite3.connect("server.db") as db:
|
||||
# 插入数据
|
||||
@ -96,17 +167,25 @@ async def get_node(ip: str) -> int:
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return ip_int
|
||||
# 使用 JSONResponse 返回节点ID和当前时间
|
||||
logger.info(f"节点 {ip} 已成功添加到数据库")
|
||||
content = {"id": ip_int, "current_time": current_time}
|
||||
return JSONResponse(content, status_code=200)
|
||||
|
||||
|
||||
# TODO: try to use @app.delete("/node")
|
||||
@app.get("/server/delete_node")
|
||||
async def delete_node(ip: str) -> None:
|
||||
async def delete_node(ip: str):
|
||||
"""
|
||||
param:
|
||||
ip: 待删除节点的ip地址
|
||||
return:
|
||||
None
|
||||
Delete a node by ip.
|
||||
|
||||
Args:
|
||||
ip (str): The ip of the node to be deleted.
|
||||
|
||||
"""
|
||||
if not validate_ip(ip):
|
||||
logger.warning(f"收到无效 IP 格式的删除请求: {ip}")
|
||||
raise HTTPException(status_code=400, detail="Invalid IP format")
|
||||
|
||||
with sqlite3.connect("server.db") as db:
|
||||
# 查询要删除的节点
|
||||
@ -117,28 +196,44 @@ async def delete_node(ip: str) -> None:
|
||||
# 执行删除操作
|
||||
db.execute("DELETE FROM nodes WHERE ip=?", (ip,))
|
||||
db.commit()
|
||||
print(f"Node with IP {ip} deleted successfully.")
|
||||
|
||||
logger.info(f"节点 {ip} 已成功删除")
|
||||
return {"message": f"Node with IP {ip} deleted successfully."}
|
||||
else:
|
||||
print(f"Node with IP {ip} not found.")
|
||||
logger.warning(f"节点 {ip} 未找到")
|
||||
raise HTTPException(status_code=404, detail=f"Node with IP {ip} not found.")
|
||||
|
||||
|
||||
# 接收节点心跳包
|
||||
@app.get("/server/heartbeat")
|
||||
async def receive_heartbeat(ip: str):
|
||||
"""
|
||||
Receive a heartbeat from a node.
|
||||
|
||||
Args:
|
||||
ip (str): The IP address of the node.
|
||||
|
||||
Returns:
|
||||
JSONResponse: A message indicating the result of the operation.
|
||||
"""
|
||||
if not validate_ip(ip):
|
||||
content = {"message": "invalid ip "}
|
||||
content = {"message": "invalid ip format"}
|
||||
logger.warning(f"收到无效 IP 格式的心跳包: {ip}")
|
||||
return JSONResponse(content, status_code=400)
|
||||
print("收到来自", ip, "的心跳包")
|
||||
logger.info(f"收到来自 {ip} 的心跳包")
|
||||
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute(
|
||||
"UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)
|
||||
)
|
||||
return {"status": "received"}
|
||||
logger.info(f"成功更新节点 {ip} 的心跳时间")
|
||||
content = {"status": "received"}
|
||||
return JSONResponse(content, status_code=200)
|
||||
|
||||
|
||||
async def receive_heartbeat_internal():
|
||||
timeout = 70
|
||||
while 1:
|
||||
timeout = 70
|
||||
with sqlite3.connect("server.db") as db:
|
||||
# 删除超时的节点
|
||||
db.execute(
|
||||
@ -165,22 +260,86 @@ async def send_nodes_list(count: int) -> list:
|
||||
rows = cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
id, ip, last_heartbeat = row
|
||||
# id, ip, last_heartbeat = row
|
||||
_, ip, _ = row
|
||||
nodes_list.append(ip)
|
||||
|
||||
print("收到来自客户端的节点列表请求...")
|
||||
print(nodes_list)
|
||||
logger.info(f"已成功发送 {count} 个节点信息")
|
||||
return nodes_list
|
||||
|
||||
|
||||
# @app.get("/server/clear_database")
|
||||
def clear_database() -> None:
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute("DELETE FROM nodes")
|
||||
db.commit()
|
||||
logger.info("数据库已清空")
|
||||
|
||||
|
||||
# WebSocket连接池
|
||||
connected_clients = []
|
||||
log_queue = queue.Queue() # 用于存储日志的队列
|
||||
|
||||
|
||||
@app.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
connected_clients.append(websocket) # 添加 WebSocket 客户端
|
||||
try:
|
||||
# 发送历史日志
|
||||
while not log_queue.empty():
|
||||
log_message = log_queue.get()
|
||||
await websocket.send_json({"type": "log", "message": log_message})
|
||||
|
||||
# 实时日志发送
|
||||
while True:
|
||||
await asyncio.sleep(5) # 保证 WebSocket 持续连接
|
||||
except Exception as e:
|
||||
print(f"WebSocket connection closed with error: {e}")
|
||||
finally:
|
||||
if websocket in connected_clients:
|
||||
connected_clients.remove(websocket)
|
||||
await websocket.close()
|
||||
|
||||
|
||||
class WebSocketLogHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
log_entry = self.format(record)
|
||||
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created))
|
||||
log_message = f"{timestamp} - {log_entry}"
|
||||
log_queue.put(log_message) # 将日志消息放入队列
|
||||
for client in connected_clients:
|
||||
if client.application_state == WebSocketState.CONNECTED:
|
||||
# 改为异步线程安全地发送日志
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.safe_send_log(client, log_message), asyncio.get_event_loop()
|
||||
)
|
||||
|
||||
async def safe_send_log(self, client, log_message):
|
||||
try:
|
||||
await client.send_json({"type": "log", "message": log_message})
|
||||
except RuntimeError as e:
|
||||
print(f"Error while sending log to {client.application_state}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
finally:
|
||||
if client in connected_clients:
|
||||
connected_clients.remove(client)
|
||||
|
||||
|
||||
# 捕获 FastAPI 和 Uvicorn 的日志
|
||||
uvicorn_logger = logging.getLogger("uvicorn")
|
||||
uvicorn_logger.setLevel(logging.INFO)
|
||||
|
||||
# 捕获 FastAPI 的日志
|
||||
fastapi_logger = logging.getLogger("fastapi")
|
||||
fastapi_logger.setLevel(logging.INFO)
|
||||
|
||||
# 将日志输出到 WebSocket
|
||||
uvicorn_logger.addHandler(WebSocketLogHandler())
|
||||
fastapi_logger.addHandler(WebSocketLogHandler())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn # pylint: disable=e0401
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
|
||||
|
@ -1,29 +0,0 @@
|
||||
import unittest
|
||||
import sqlite3
|
||||
import os
|
||||
from server import *
|
||||
|
||||
|
||||
class TestServer(unittest.TestCase):
|
||||
def test_init_creates_table(self):
|
||||
# 执行初始化函数
|
||||
init_db()
|
||||
|
||||
conn = sqlite3.connect("server.db")
|
||||
cursor = conn.cursor()
|
||||
# 检查表是否被正确创建
|
||||
cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='nodes'"
|
||||
)
|
||||
tables = cursor.fetchall()
|
||||
|
||||
self.assertTrue(any("nodes" in table for table in tables))
|
||||
# 关闭数据库连接
|
||||
conn.close()
|
||||
os.remove("server.db")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
14
src/setup.py
14
src/setup.py
@ -1,14 +0,0 @@
|
||||
from setuptools import setup, Extension
|
||||
|
||||
# 定义您的扩展
|
||||
ext = Extension(
|
||||
"tpreECC",
|
||||
sources=["tpreECC.c"],
|
||||
)
|
||||
|
||||
setup(
|
||||
name="tpreECC",
|
||||
version="1.0",
|
||||
description="basic ECC written in C",
|
||||
ext_modules=[ext],
|
||||
)
|
60
src/tpre.py
60
src/tpre.py
@ -1,7 +1,7 @@
|
||||
from gmssl import * # pylint: disable = e0401
|
||||
from typing import Tuple, Callable
|
||||
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
|
||||
from typing import Tuple
|
||||
import random
|
||||
import traceback
|
||||
import ecc_rs
|
||||
|
||||
point = Tuple[int, int]
|
||||
capsule = Tuple[point, point, int]
|
||||
@ -29,23 +29,31 @@ sm2p256v1 = CurveFp(
|
||||
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
|
||||
)
|
||||
|
||||
point = Tuple[int, int]
|
||||
|
||||
# 生成元
|
||||
g = (sm2p256v1.Gx, sm2p256v1.Gy)
|
||||
|
||||
|
||||
def multiply(a: point, n: int) -> point:
|
||||
N = sm2p256v1.N
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
|
||||
def multiply(a: point, n: int, flag: int = 0) -> point:
|
||||
if flag == 1:
|
||||
|
||||
result = ecc_rs.multiply(a, n)
|
||||
return result
|
||||
else:
|
||||
N = sm2p256v1.N
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
|
||||
|
||||
|
||||
def add(a: point, b: point) -> point:
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
|
||||
def add(a: point, b: point, flag: int = 0) -> point:
|
||||
if flag == 1:
|
||||
result = ecc_rs.add(a, b)
|
||||
return result
|
||||
else:
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
|
||||
|
||||
|
||||
def inv(a: int, n: int) -> int:
|
||||
@ -209,12 +217,13 @@ def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
|
||||
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
|
||||
enc_Data = sm4_enc.update(m)
|
||||
enc_Data += sm4_enc.finish()
|
||||
enc_message = (capsule, enc_Data)
|
||||
enc_message = (capsule, bytes(enc_Data))
|
||||
return enc_message
|
||||
|
||||
|
||||
def Decapsulate(ska: int, capsule: capsule) -> int:
|
||||
E, V, s = capsule
|
||||
# E, V, s = capsule
|
||||
E, V, _ = capsule
|
||||
EVa = multiply(add(E, V), ska) # (E*V)^ska
|
||||
K = KDF(EVa)
|
||||
|
||||
@ -233,7 +242,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
|
||||
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
|
||||
dec_Data = sm4_dec.update(enc_Data)
|
||||
dec_Data += sm4_dec.finish()
|
||||
return dec_Data
|
||||
return bytes(dec_Data)
|
||||
|
||||
|
||||
# GenerateRekey
|
||||
@ -305,8 +314,9 @@ def GenerateReKey(
|
||||
# 计算KF
|
||||
KF = []
|
||||
for i in range(N):
|
||||
y = random.randint(0, sm2p256v1.N - 1)
|
||||
Y = multiply(g, y)
|
||||
# seems unused?
|
||||
# y = random.randint(0, sm2p256v1.N - 1)
|
||||
# Y = multiply(g, y)
|
||||
s_x = hash5(id_tuple[i], D) # id需要设置
|
||||
r_k = f(s_x, f_modulus, T)
|
||||
U1 = multiply(U, r_k)
|
||||
@ -344,8 +354,10 @@ def Checkcapsule(capsule: capsule) -> bool: # 验证胶囊的有效性
|
||||
|
||||
|
||||
def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]:
|
||||
id, rk, Xa, U1 = kFrag
|
||||
E, V, s = capsule
|
||||
# id, rk, Xa, U1 = kFrag
|
||||
id, rk, Xa, _ = kFrag
|
||||
# E, V, s = capsule
|
||||
E, V, _ = capsule
|
||||
if not Checkcapsule(capsule):
|
||||
raise ValueError("Invalid capsule")
|
||||
E1 = multiply(E, rk)
|
||||
@ -369,7 +381,7 @@ def ReEncrypt(
|
||||
|
||||
|
||||
# 将加密节点加密后产生的t个(capsule,ct)合并在一起,产生cfrags = {{capsule1,capsule2,...},ct}
|
||||
def mergecfrag(cfrag_cts: list) -> list:
|
||||
def MergeCFrag(cfrag_cts: list) -> list:
|
||||
ct_list = []
|
||||
cfrags_list = []
|
||||
cfrags = []
|
||||
@ -421,7 +433,9 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
|
||||
Vk = multiply(Vlist[k], bis[k])
|
||||
E2 = add(Ek, E2)
|
||||
V2 = add(Vk, V2)
|
||||
X_Ab = multiply(X_Alist[0], sk_B) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
|
||||
X_Ab = multiply(
|
||||
X_Alist[0], sk_B
|
||||
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
|
||||
d = hash3((X_Alist[0], pk_B, X_Ab))
|
||||
EV = add(E2, V2) # E2 + V2
|
||||
EVd = multiply(EV, d) # (E2 + V2)^d
|
||||
@ -447,4 +461,4 @@ def DecryptFrags(sk_B: int, pk_B: point, pk_A: point, cfrags: list) -> bytes:
|
||||
print(e)
|
||||
print("key error")
|
||||
dec_Data = b""
|
||||
return dec_Data
|
||||
return bytes(dec_Data)
|
||||
|
284
src/tpreECC.c
284
src/tpreECC.c
@ -1,284 +0,0 @@
|
||||
#include <Python.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// define TPRE Big Number
|
||||
typedef uint64_t TPRE_BN[8]
|
||||
|
||||
// GF(p)
|
||||
typedef TPRE_BN SM2_Fp;
|
||||
|
||||
// GF(n)
|
||||
typedef TPRE_BN SM2_Fn;
|
||||
|
||||
// 定义一个结构体来表示雅各比坐标系的点
|
||||
typedef struct
|
||||
{
|
||||
TPRE_BN X;
|
||||
TPRE_BN Y;
|
||||
TPRE_BN Z;
|
||||
} JACOBIAN_POINT;
|
||||
|
||||
// 定义一个结构体来表示点
|
||||
typedef struct
|
||||
{
|
||||
uint8_t x[32];
|
||||
uint8_t y[32];
|
||||
} TPRE_POINT;
|
||||
|
||||
const TPRE_BN SM2_P = {
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0x00000000,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xfffffffe,
|
||||
};
|
||||
|
||||
const TPRE_BN SM2_A = {
|
||||
0xfffffffc,
|
||||
0xffffffff,
|
||||
0x00000000,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xfffffffe,
|
||||
};
|
||||
|
||||
const TPRE_BN SM2_B = {
|
||||
0x4d940e93,
|
||||
0xddbcbd41,
|
||||
0x15ab8f92,
|
||||
0xf39789f5,
|
||||
0xcf6509a7,
|
||||
0x4d5a9e4b,
|
||||
0x9d9f5e34,
|
||||
0x28e9fa9e,
|
||||
};
|
||||
|
||||
// 生成元GX, GY
|
||||
const SM2_JACOBIAN_POINT _SM2_G = {
|
||||
{
|
||||
0x334c74c7,
|
||||
0x715a4589,
|
||||
0xf2660be1,
|
||||
0x8fe30bbf,
|
||||
0x6a39c994,
|
||||
0x5f990446,
|
||||
0x1f198119,
|
||||
0x32c4ae2c,
|
||||
},
|
||||
{
|
||||
0x2139f0a0,
|
||||
0x02df32e5,
|
||||
0xc62a4740,
|
||||
0xd0a9877c,
|
||||
0x6b692153,
|
||||
0x59bdcee3,
|
||||
0xf4f6779c,
|
||||
0xbc3736a2,
|
||||
},
|
||||
{
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
},
|
||||
};
|
||||
|
||||
const SM2_JACOBIAN_POINT *SM2_G = &_SM2_G;
|
||||
|
||||
const TPRE_BN SM2_N = {
|
||||
0x39d54123,
|
||||
0x53bbf409,
|
||||
0x21c6052b,
|
||||
0x7203df6b,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xfffffffe,
|
||||
};
|
||||
|
||||
// u = (p - 1)/4, u + 1 = (p + 1)/4
|
||||
const TPRE_BN SM2_U_PLUS_ONE = {
|
||||
0x00000000,
|
||||
0x40000000,
|
||||
0xc0000000,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xffffffff,
|
||||
0xbfffffff,
|
||||
0x3fffffff,
|
||||
};
|
||||
|
||||
const TPRE_BN SM2_ONE = {1, 0, 0, 0, 0, 0, 0, 0};
|
||||
const TPRE_BN SM2_TWO = {2, 0, 0, 0, 0, 0, 0, 0};
|
||||
const TPRE_BN SM2_THREE = {3, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
#define GETU32(p) \
|
||||
((uint32_t)(p)[0] << 24 | \
|
||||
(uint32_t)(p)[1] << 16 | \
|
||||
(uint32_t)(p)[2] << 8 | \
|
||||
(uint32_t)(p)[3])
|
||||
|
||||
// 点乘
|
||||
static void multiply(TPRE_POINT r, const TPRE_POINT a, int64_t n)
|
||||
{
|
||||
Point result;
|
||||
// ...实现乘法逻辑...
|
||||
return result;
|
||||
}
|
||||
|
||||
// 点加
|
||||
static void add(TPRE_POINT *R, TPRE_POINT *P, TPRE_POINT *Q)
|
||||
{
|
||||
JACOBIAN_POINT P_;
|
||||
JACOBIAN_POINT Q_;
|
||||
|
||||
jacobianPoint_from_bytes(&P_, (uint8_t *)P)
|
||||
jacobianPoint_from_bytes(&Q_, (uint8_t *)Q)
|
||||
jacobianPoint_add(&P_, &P_, &Q_);
|
||||
jacobianPoint_to_bytes(&P_, (uint8_t *)R);
|
||||
}
|
||||
|
||||
// 求逆
|
||||
static void inv()
|
||||
{
|
||||
}
|
||||
|
||||
// jacobianPoint点加
|
||||
static void jacobianPoint_add(JACOBIAN_POINT *R, const JACOBIAN_POINT *P, const JACOBIAN_POINT *Q)
|
||||
{
|
||||
const uint64_t *X1 = P->X;
|
||||
const uint64_t *Y1 = P->Y;
|
||||
const uint64_t *Z1 = P->Z;
|
||||
const uint64_t *x2 = Q->X;
|
||||
const uint64_t *y2 = Q->Y;
|
||||
SM2_BN T1;
|
||||
SM2_BN T2;
|
||||
SM2_BN T3;
|
||||
SM2_BN T4;
|
||||
SM2_BN X3;
|
||||
SM2_BN Y3;
|
||||
SM2_BN Z3;
|
||||
if (sm2_jacobian_point_is_at_infinity(Q))
|
||||
{
|
||||
sm2_jacobian_point_copy(R, P);
|
||||
return;
|
||||
}
|
||||
|
||||
if (sm2_jacobian_point_is_at_infinity(P))
|
||||
{
|
||||
sm2_jacobian_point_copy(R, Q);
|
||||
return;
|
||||
}
|
||||
|
||||
assert(sm2_bn_is_one(Q->Z));
|
||||
|
||||
sm2_fp_sqr(T1, Z1);
|
||||
sm2_fp_mul(T2, T1, Z1);
|
||||
sm2_fp_mul(T1, T1, x2);
|
||||
sm2_fp_mul(T2, T2, y2);
|
||||
sm2_fp_sub(T1, T1, X1);
|
||||
sm2_fp_sub(T2, T2, Y1);
|
||||
if (sm2_bn_is_zero(T1))
|
||||
{
|
||||
if (sm2_bn_is_zero(T2))
|
||||
{
|
||||
SM2_JACOBIAN_POINT _Q, *Q = &_Q;
|
||||
sm2_jacobian_point_set_xy(Q, x2, y2);
|
||||
|
||||
sm2_jacobian_point_dbl(R, Q);
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
sm2_jacobian_point_set_infinity(R);
|
||||
return;
|
||||
}
|
||||
}
|
||||
sm2_fp_mul(Z3, Z1, T1);
|
||||
sm2_fp_sqr(T3, T1);
|
||||
sm2_fp_mul(T4, T3, T1);
|
||||
sm2_fp_mul(T3, T3, X1);
|
||||
sm2_fp_dbl(T1, T3);
|
||||
sm2_fp_sqr(X3, T2);
|
||||
sm2_fp_sub(X3, X3, T1);
|
||||
sm2_fp_sub(X3, X3, T4);
|
||||
sm2_fp_sub(T3, T3, X3);
|
||||
sm2_fp_mul(T3, T3, T2);
|
||||
sm2_fp_mul(T4, T4, Y1);
|
||||
sm2_fp_sub(Y3, T3, T4);
|
||||
|
||||
sm2_bn_copy(R->X, X3);
|
||||
sm2_bn_copy(R->Y, Y3);
|
||||
sm2_bn_copy(R->Z, Z3);
|
||||
}
|
||||
|
||||
// bytes转jacobianPoint
|
||||
static void jacobianPoint_from_bytes(JACOBIAN_POINT *P, const uint8_t in[64])
|
||||
{
|
||||
}
|
||||
|
||||
// jacobianPoint转bytes
|
||||
static void jacobianPoint_to_bytes(JACOBIAN_POINT *P, const uint8_t in[64])
|
||||
{
|
||||
}
|
||||
|
||||
static void BN_from_bytes(TPRE_BN *r, const uint8_t in[32])
|
||||
{
|
||||
int i;
|
||||
for (i = 7; i >= 0; i--)
|
||||
{
|
||||
r[i] = GETU32(in);
|
||||
in += sizeof(uint32_t);
|
||||
}
|
||||
}
|
||||
|
||||
// 点乘的Python接口函数
|
||||
static PyObject *py_multiply(PyObject *self, PyObject *args)
|
||||
{
|
||||
return
|
||||
}
|
||||
|
||||
// 点加的Python接口函数
|
||||
static PyObject *py_add(PyObject *self, PyObject *args)
|
||||
{
|
||||
return
|
||||
}
|
||||
|
||||
// 求逆的Python接口函数
|
||||
static PyObject *py_inv(PyObject *self, PyObject *args)
|
||||
{
|
||||
return
|
||||
}
|
||||
|
||||
// 模块方法定义
|
||||
static PyMethodDef MyMethods[] = {
|
||||
{"multiply", py_multiply, METH_VARARGS, "Multiply a point on the sm2p256v1 curve"},
|
||||
{"add", py_add, METH_VARARGS, "Add a point on thesm2p256v1 curve"},
|
||||
{"inv", py_inv, METH_VARARGS, "Calculate an inv of a number"},
|
||||
{NULL, NULL, 0, NULL} // 哨兵
|
||||
};
|
||||
|
||||
// 模块定义
|
||||
static struct PyModuleDef tpreECC = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"tpreECC",
|
||||
NULL, // 模块文档
|
||||
-1,
|
||||
MyMethods};
|
||||
|
||||
// 初始化模块
|
||||
PyMODINIT_FUNC PyInit_tpre(void)
|
||||
{
|
||||
return PyModule_Create(&tpreECC);
|
||||
}
|
@ -1,57 +0,0 @@
|
||||
from tpre import *
|
||||
import unittest
|
||||
|
||||
|
||||
class TestHash2(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.double_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash3(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash4(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
self.Zp = random.randint(0, sm2p256v1.N - 1)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
37
tests/ecc_speed_test.py
Normal file
37
tests/ecc_speed_test.py
Normal file
@ -0,0 +1,37 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import add, multiply, sm2p256v1
|
||||
import time
|
||||
|
||||
# 生成元
|
||||
g = (sm2p256v1.Gx, sm2p256v1.Gy)
|
||||
|
||||
start_time = time.time() # 获取开始时间
|
||||
for i in range(10):
|
||||
result = multiply(g, 10000, 1) # 执行函数
|
||||
end_time = time.time() # 获取结束时间
|
||||
elapsed_time = end_time - start_time # 计算执行时间
|
||||
print(f"rust multiply 执行时间: {elapsed_time:.6f} 秒")
|
||||
|
||||
start_time = time.time() # 获取开始时间
|
||||
for i in range(10):
|
||||
result = multiply(g, 10000, 0) # 执行函数
|
||||
end_time = time.time() # 获取结束时间
|
||||
elapsed_time = end_time - start_time # 计算执行时间
|
||||
print(f"python multiply 执行时间: {elapsed_time:.6f} 秒")
|
||||
|
||||
start_time = time.time() # 获取开始时间
|
||||
for i in range(10):
|
||||
result = add(g, g, 1) # 执行函数
|
||||
end_time = time.time() # 获取结束时间
|
||||
elapsed_time = end_time - start_time # 计算执行时间
|
||||
print(f"rust add 执行时间: {elapsed_time:.6f} 秒")
|
||||
|
||||
start_time = time.time() # 获取开始时间
|
||||
for i in range(10):
|
||||
result = add(g, g, 0) # 执行函数
|
||||
end_time = time.time() # 获取结束时间
|
||||
elapsed_time = end_time - start_time # 计算执行时间
|
||||
print(f"python add 执行时间: {elapsed_time:.6f} 秒")
|
@ -1,4 +1,15 @@
|
||||
from tpre import *
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import (
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
||||
GenerateReKey,
|
||||
ReEncrypt,
|
||||
MergeCFrag,
|
||||
DecryptFrags,
|
||||
)
|
||||
import time
|
||||
|
||||
N = 20
|
||||
@ -52,7 +63,7 @@ for i in range(1, 10):
|
||||
|
||||
# 9
|
||||
start_time = time.time()
|
||||
cfrags = mergecfrag(cfrag_cts)
|
||||
cfrags = MergeCFrag(cfrag_cts)
|
||||
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
@ -1,3 +1,7 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import *
|
||||
import time
|
||||
|
||||
@ -54,7 +58,7 @@ while total_time < 1:
|
||||
|
||||
# 9
|
||||
start_time = time.time()
|
||||
cfrags = mergecfrag(cfrag_cts)
|
||||
cfrags = MergeCFrag(cfrag_cts)
|
||||
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
88
tests/node_test.py
Normal file
88
tests/node_test.py
Normal file
@ -0,0 +1,88 @@
|
||||
# 测试 node.py 中的函数
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
import node
|
||||
|
||||
|
||||
class TestGetLocalIP(unittest.TestCase):
|
||||
|
||||
@patch.dict("os.environ", {"HOST_IP": "60.204.193.58"}) # 模拟设置 HOST_IP 环境变量
|
||||
def test_get_ip_from_env(self):
|
||||
# 调用被测函数
|
||||
node.get_local_ip()
|
||||
|
||||
# 检查函数是否正确获取到 HOST_IP
|
||||
self.assertEqual(node.ip, "60.204.193.58")
|
||||
|
||||
@patch("socket.socket") # Mock socket 连接行为
|
||||
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
|
||||
def test_get_ip_from_socket(self, mock_socket):
|
||||
# 模拟 socket 返回的 IP 地址
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
|
||||
|
||||
# 调用被测函数
|
||||
node.get_local_ip()
|
||||
|
||||
# 确认 socket 被调用过
|
||||
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
# 检查是否通过 socket 获取到正确的 IP 地址
|
||||
self.assertEqual(node.ip, "110.41.155.96")
|
||||
|
||||
|
||||
class TestSendIP(unittest.TestCase):
|
||||
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
|
||||
@patch("requests.get") # Mock requests.get 调用
|
||||
def test_send_ip(self, mock_get):
|
||||
# 设置模拟返回的 HTTP 响应
|
||||
mock_response = Mock()
|
||||
mock_response.text = "node123" # 模拟返回的节点ID
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = (
|
||||
mock_response # 设置 requests.get() 的返回值为 mock_response
|
||||
)
|
||||
|
||||
# 保存原始的全局 id 值
|
||||
original_id = node.id
|
||||
|
||||
# 调用待测函数
|
||||
node.send_ip()
|
||||
|
||||
# 确保 requests.get 被正确调用
|
||||
expected_url = f"{node.server_address}/get_node?ip={node.ip}"
|
||||
mock_get.assert_called_once_with(expected_url, timeout=3)
|
||||
|
||||
# 检查 id 是否被正确更新
|
||||
self.assertIs(node.id, mock_response) # 检查 id 是否被修改
|
||||
self.assertEqual(
|
||||
node.id.text, "node123"
|
||||
) # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
|
||||
|
||||
class TestNode(unittest.TestCase):
|
||||
|
||||
@patch("node.send_ip")
|
||||
@patch("node.get_local_ip")
|
||||
@patch("node.asyncio.create_task")
|
||||
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
|
||||
# 调用 init 函数
|
||||
node.init()
|
||||
|
||||
# 验证 get_local_ip 和 send_ip 被调用
|
||||
mock_get_local_ip.assert_called_once()
|
||||
mock_send_ip.assert_called_once()
|
||||
|
||||
# 确保 create_task 被调用来启动心跳包
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
194
tests/node_test5.py
Normal file
194
tests/node_test5.py
Normal file
@ -0,0 +1,194 @@
|
||||
# node_test剩下部分(有问题)
|
||||
import os
|
||||
import unittest
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, Mock, AsyncMock
|
||||
import requests
|
||||
import asyncio
|
||||
import httpx
|
||||
import respx
|
||||
from fastapi.testclient import TestClient
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from node import (
|
||||
app,
|
||||
send_heartbeat_internal,
|
||||
Req,
|
||||
send_ip,
|
||||
get_local_ip,
|
||||
init,
|
||||
clear,
|
||||
send_user_des_message,
|
||||
id,
|
||||
)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
server_address = "http://60.204.236.38:8000/server"
|
||||
# ip = None # 初始化全局变量 ip
|
||||
# id = None # 初始化全局变量 id
|
||||
|
||||
|
||||
class TestGetLocalIP(unittest.TestCase):
|
||||
|
||||
def test_get_ip_from_env(self):
|
||||
os.environ["HOST_IP"] = "60.204.193.58" # 模拟设置 HOST_IP 环境变量
|
||||
ip = get_local_ip()
|
||||
# 检查函数是否正确获取到 HOST_IP
|
||||
self.assertEqual(ip, "60.204.193.58")
|
||||
|
||||
@patch("socket.socket") # Mock socket 连接行为
|
||||
def test_get_ip_from_socket(self, mock_socket):
|
||||
os.environ.pop("HOST_IP", None)
|
||||
|
||||
# 模拟 socket 返回的 IP 地址
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
|
||||
|
||||
# 调用被测函数
|
||||
ip = get_local_ip()
|
||||
|
||||
# 确认 socket 被调用过
|
||||
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
# 检查是否通过 socket 获取到正确的 IP 地址
|
||||
self.assertEqual(ip, "110.41.155.96")
|
||||
|
||||
|
||||
class TestSendIP(unittest.TestCase):
|
||||
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
|
||||
@respx.mock
|
||||
def test_send_ip(self):
|
||||
global ip, id
|
||||
ip = "60.204.193.58"
|
||||
mock_url = f"{server_address}/get_node?ip={ip}"
|
||||
respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123"))
|
||||
|
||||
# 调用待测函数
|
||||
send_ip()
|
||||
|
||||
# 确保 requests.get 被正确调用
|
||||
self.assertEqual(
|
||||
id, "node123"
|
||||
) # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
|
||||
|
||||
class TestNode(unittest.TestCase):
|
||||
|
||||
@patch("node.send_ip")
|
||||
@patch("node.get_local_ip")
|
||||
@patch("node.asyncio.create_task")
|
||||
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
|
||||
# 调用 init 函数
|
||||
init()
|
||||
|
||||
# 验证 get_local_ip 和 send_ip 被调用
|
||||
mock_get_local_ip.assert_called_once()
|
||||
mock_send_ip.assert_called_once()
|
||||
|
||||
# 确保 create_task 被调用来启动心跳包
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
def test_clear(self):
|
||||
# 调用 clear 函数
|
||||
clear()
|
||||
# 检查输出
|
||||
self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_send_heartbeat_internal_success():
|
||||
global ip
|
||||
ip = "60.204.193.58"
|
||||
# 模拟心跳请求
|
||||
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
|
||||
return_value=httpx.Response(200)
|
||||
)
|
||||
|
||||
# 模拟 requests.get 以避免实际请求
|
||||
with patch("requests.get", return_value=httpx.Response(200)) as mock_get:
|
||||
# 模拟 asyncio.sleep 以避免实际延迟
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
task = asyncio.create_task(send_heartbeat_internal())
|
||||
await asyncio.sleep(0.1) # 允许任务运行一段时间
|
||||
task.cancel() # 取消任务以停止无限循环
|
||||
try:
|
||||
await task # 确保任务被等待
|
||||
except asyncio.CancelledError:
|
||||
pass # 捕获取消错误
|
||||
|
||||
assert mock_get.called
|
||||
assert mock_get.call_count > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_send_heartbeat_internal_failure():
|
||||
global ip
|
||||
ip = "60.204.193.58"
|
||||
# 模拟心跳请求以引发异常
|
||||
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
|
||||
side_effect=httpx.RequestError("Central server error")
|
||||
)
|
||||
|
||||
# 模拟 requests.get 以避免实际请求
|
||||
with patch(
|
||||
"requests.get", side_effect=httpx.RequestError("Central server error")
|
||||
) as mock_get:
|
||||
# 模拟 asyncio.sleep 以避免实际延迟
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
task = asyncio.create_task(send_heartbeat_internal())
|
||||
await asyncio.sleep(0.1) # 允许任务运行一段时间
|
||||
task.cancel() # 取消任务以停止无限循环
|
||||
try:
|
||||
await task # 确保任务被等待
|
||||
except asyncio.CancelledError:
|
||||
pass # 捕获取消错误
|
||||
|
||||
assert mock_get.called
|
||||
assert mock_get.call_count > 0
|
||||
|
||||
|
||||
def test_user_src():
|
||||
# 模拟 ReEncrypt 函数
|
||||
with patch(
|
||||
"node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")
|
||||
):
|
||||
# 模拟 send_user_des_message 函数
|
||||
with patch(
|
||||
"node.send_user_des_message", new_callable=AsyncMock
|
||||
) as mock_send_user_des_message:
|
||||
message = {
|
||||
"source_ip": "60.204.193.58",
|
||||
"dest_ip": "60.204.193.59",
|
||||
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
|
||||
"ct": 456,
|
||||
"rk": ["rk1", "rk2"],
|
||||
}
|
||||
response = client.post("/user_src", json=message)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"detail": "message received"}
|
||||
mock_send_user_des_message.assert_called_once()
|
||||
|
||||
|
||||
def test_send_user_des_message():
|
||||
with respx.mock:
|
||||
dest_ip = "60.204.193.59"
|
||||
re_message = (("a", "b", "c", "d"), 123)
|
||||
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
|
||||
return_value=httpx.Response(200, json={"status": "success"})
|
||||
)
|
||||
response = requests.post(
|
||||
f"http://{dest_ip}:8002/receive_messages",
|
||||
json={"Tuple": re_message, "ip": "60.204.193.58"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "success"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
140
tests/server_test.py
Normal file
140
tests/server_test.py
Normal file
@ -0,0 +1,140 @@
|
||||
import sqlite3
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from server import app, validate_ip
|
||||
|
||||
# 创建 TestClient 实例
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# 准备测试数据库数据
|
||||
def setup_db():
|
||||
# 创建数据库并插入测试数据
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
id INTEGER PRIMARY KEY,
|
||||
ip TEXT NOT NULL,
|
||||
last_heartbeat INTEGER NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)"
|
||||
)
|
||||
db.execute(
|
||||
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)"
|
||||
)
|
||||
db.commit()
|
||||
|
||||
|
||||
# 清空数据库
|
||||
def clear_db():
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表
|
||||
db.commit()
|
||||
|
||||
|
||||
# 测试 IP 验证功能
|
||||
def test_validate_ip():
|
||||
assert validate_ip("192.168.0.1") is True
|
||||
assert validate_ip("256.256.256.256") is False
|
||||
assert validate_ip("::1") is True
|
||||
assert validate_ip("invalid_ip") is False
|
||||
|
||||
|
||||
# 测试首页路由
|
||||
def test_home():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, World!"}
|
||||
|
||||
|
||||
# 测试 show_nodes 路由
|
||||
def test_show_nodes():
|
||||
setup_db()
|
||||
|
||||
response = client.get("/server/show_nodes")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
assert data[0][1] == "192.168.0.1"
|
||||
assert data[1][1] == "192.168.0.2"
|
||||
|
||||
|
||||
# 测试 get_node 路由
|
||||
def test_get_node():
|
||||
# 确保数据库和表的存在
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.3"
|
||||
invalid_ip = "256.256.256.256"
|
||||
|
||||
# 测试有效的 IP 地址
|
||||
response = client.get(f"/server/get_node?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# 测试无效的 IP 地址
|
||||
response = client.get(f"/server/get_node?ip={invalid_ip}")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# 测试 delete_node 路由
|
||||
def test_delete_node():
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.1"
|
||||
invalid_ip = "192.168.0.255"
|
||||
|
||||
response = client.get(f"/server/delete_node?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
|
||||
|
||||
response = client.get(f"/server/delete_node?ip={invalid_ip}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# 测试 heartbeat 路由
|
||||
def test_receive_heartbeat():
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.2"
|
||||
invalid_ip = "256.256.256.256"
|
||||
|
||||
response = client.get(f"/server/heartbeat?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "received"}
|
||||
|
||||
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"message": "invalid ip format"}
|
||||
|
||||
|
||||
# 测试 send_nodes_list 路由
|
||||
def test_send_nodes_list():
|
||||
setup_db()
|
||||
|
||||
response = client.get("/server/send_nodes_list?count=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0] == "192.168.0.1"
|
||||
|
||||
response = client.get("/server/send_nodes_list?count=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
||||
|
||||
# 运行完测试后清理数据库
|
||||
@pytest.fixture(autouse=True)
|
||||
def run_around_tests():
|
||||
clear_db()
|
||||
yield
|
||||
clear_db()
|
@ -1,4 +1,15 @@
|
||||
from tpre import *
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import (
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
||||
GenerateReKey,
|
||||
ReEncrypt,
|
||||
MergeCFrag,
|
||||
DecryptFrags,
|
||||
)
|
||||
import time
|
||||
|
||||
N = 20
|
||||
@ -50,7 +61,7 @@ print(f"重加密算法运行时间:{elapsed_time}秒")
|
||||
|
||||
# 9
|
||||
start_time = time.time()
|
||||
cfrags = mergecfrag(cfrag_cts)
|
||||
cfrags = MergeCFrag(cfrag_cts)
|
||||
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
76
tests/test_client.py
Normal file
76
tests/test_client.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from client import app, init_db, clean_env, get_own_ip
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_and_teardown():
|
||||
# 设置测试环境
|
||||
init_db()
|
||||
yield
|
||||
# 清理测试环境
|
||||
clean_env()
|
||||
|
||||
|
||||
def test_read_root():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, World!"}
|
||||
|
||||
|
||||
def test_receive_messages():
|
||||
message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"}
|
||||
response = client.post("/receive_messages", json=message)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("detail") == "Message received"
|
||||
|
||||
|
||||
# @respx.mock
|
||||
# def test_request_message():
|
||||
# request_message = {
|
||||
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
|
||||
# "message_name": "name"
|
||||
# }
|
||||
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
|
||||
# response = client.post("/request_message", json=request_message)
|
||||
# assert response.status_code == 200
|
||||
# assert "threshold" in response.json()
|
||||
# assert "public_key" in response.json()
|
||||
|
||||
# @respx.mock
|
||||
# def test_receive_request():
|
||||
# ip_message = {
|
||||
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
|
||||
# "message_name": "name",
|
||||
# "source_ip": "124.70.165.73", # 使用不同的 IP 地址
|
||||
# "pk": (123, 456)
|
||||
# }
|
||||
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
|
||||
# response = client.post("/receive_request", json=ip_message)
|
||||
# assert response.status_code == 200
|
||||
# assert "threshold" in response.json()
|
||||
# assert "public_key" in response.json()
|
||||
|
||||
|
||||
def test_get_pk():
|
||||
response = client.get("/get_pk")
|
||||
assert response.status_code == 200
|
||||
assert "pkx" in response.json()
|
||||
assert "pky" in response.json()
|
||||
|
||||
|
||||
def test_recieve_pk():
|
||||
pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"}
|
||||
response = client.post("/recieve_pk", json=pk_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "save pk in database"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
158
tests/tpre_test.py
Normal file
158
tests/tpre_test.py
Normal file
@ -0,0 +1,158 @@
|
||||
import sys
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import (
|
||||
hash2,
|
||||
hash3,
|
||||
hash4,
|
||||
multiply,
|
||||
g,
|
||||
sm2p256v1,
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
||||
Decrypt,
|
||||
GenerateReKey,
|
||||
Encapsulate,
|
||||
ReEncrypt,
|
||||
DecryptFrags,
|
||||
MergeCFrag,
|
||||
)
|
||||
import random
|
||||
import unittest
|
||||
|
||||
|
||||
class TestHash2(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.double_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash3(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash4(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
self.Zp = random.randint(0, sm2p256v1.N - 1)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestGenerateKeyPair(unittest.TestCase):
|
||||
def test_key_pair(self):
|
||||
public_key, secret_key = GenerateKeyPair()
|
||||
self.assertIsInstance(public_key, tuple)
|
||||
self.assertEqual(len(public_key), 2)
|
||||
self.assertIsInstance(secret_key, int)
|
||||
self.assertLess(secret_key, sm2p256v1.N)
|
||||
self.assertGreater(secret_key, 0)
|
||||
|
||||
|
||||
# class TestEncryptDecrypt(unittest.TestCase):
|
||||
# def setUp(self):
|
||||
# self.public_key, self.secret_key = GenerateKeyPair()
|
||||
# self.message = b"Hello, world!"
|
||||
|
||||
# def test_encrypt_decrypt(self):
|
||||
# encrypted_message = Encrypt(self.public_key, self.message)
|
||||
# # 使用 SHA-256 哈希函数确保密钥为 16 字节
|
||||
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
|
||||
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数
|
||||
|
||||
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
|
||||
# self.assertEqual(decrypted_message, self.message)
|
||||
|
||||
|
||||
class TestGenerateReKey(unittest.TestCase):
|
||||
def test_generate_rekey(self):
|
||||
sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
pk_B, _ = GenerateKeyPair()
|
||||
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
self.assertIsInstance(rekey, list)
|
||||
self.assertEqual(len(rekey), 5)
|
||||
|
||||
|
||||
class TestEncapsulate(unittest.TestCase):
|
||||
def test_encapsulate(self):
|
||||
pk_A, _ = GenerateKeyPair()
|
||||
K, capsule = Encapsulate(pk_A)
|
||||
self.assertIsInstance(K, int)
|
||||
self.assertIsInstance(capsule, tuple)
|
||||
self.assertEqual(len(capsule), 3)
|
||||
|
||||
|
||||
class TestReEncrypt(unittest.TestCase):
|
||||
def test_reencrypt(self):
|
||||
sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
pk_B, _ = GenerateKeyPair()
|
||||
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
pk_A, _ = GenerateKeyPair()
|
||||
message = b"Hello, world!"
|
||||
encrypted_message = Encrypt(pk_A, message)
|
||||
reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
self.assertIsInstance(reencrypted_message, tuple)
|
||||
self.assertEqual(len(reencrypted_message), 2)
|
||||
|
||||
|
||||
# class TestDecryptFrags(unittest.TestCase):
|
||||
# def test_decrypt_frags(self):
|
||||
# sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
# pk_B, sk_B = GenerateKeyPair()
|
||||
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
# pk_A, _ = GenerateKeyPair()
|
||||
# message = b"Hello, world!"
|
||||
# encrypted_message = Encrypt(pk_A, message)
|
||||
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
# cfrags = [reencrypted_message]
|
||||
# merged_cfrags = MergeCFrag(cfrags)
|
||||
|
||||
# self.assertIsNotNone(merged_cfrags)
|
||||
|
||||
# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest()
|
||||
# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # 取前 16 字节并转换为整数
|
||||
|
||||
# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags)
|
||||
# self.assertEqual(decrypted_message, message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user