forked from sangge/tpre-python
		
	Compare commits
	
		
			5 Commits
		
	
	
		
			c2698455de
			...
			1.0
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| af81cb6c00 | |||
| c09c326bde | |||
| b5a88fa9b5 | |||
| 6cd279e88c | |||
| 647c92c8dc | 
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | database/dns.db  | ||||||
|  | .vscode | ||||||
							
								
								
									
										6
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | { | ||||||
|  |     "[python]": { | ||||||
|  |         "editor.defaultFormatter": "ms-python.autopep8" | ||||||
|  |     }, | ||||||
|  |     "python.formatting.provider": "none" | ||||||
|  | } | ||||||
							
								
								
									
										64
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										64
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,27 +1,67 @@ | |||||||
| # my-tor-core | # xiaomian DNS | ||||||
|  |  | ||||||
| 本项目是计算机网络的课程设计项目,该项目是一个类 Tor 的网络通信协议,旨在保护用户的隐私和匿名性。 | 本项目是计算机网络的课程设计项目,该项目是一个私有DNS的简单实现。 | ||||||
|  |  | ||||||
| ## 项目原理 | ## 项目原理 | ||||||
|  |  | ||||||
| 本项目目标是实现三层代理 + dns 解析自定义域名.xiaomian + 搭建匿名网站。 | 在server/xiaomiandns.py中实现了DNSserver和APIserver两个类。通过server/main.py启动实例化的server。配置文件在server/serverconf.yaml   | ||||||
| server路径下包含一个目录服务器,用于创建目录服务器,记录加入节点。 |  | ||||||
| client路径下是客户端程序,客户端程序通过访问目录服务器获取当前的路由,并通过随机路由算法选择代理节点。 |  | ||||||
| 本项目选择sqlite作为数据库,存储节点信息等数据。 | 本项目选择sqlite作为数据库,存储节点信息等数据。 | ||||||
|  |  | ||||||
| ## 环境依赖 | ## 环境依赖 | ||||||
|  |  | ||||||
| 该项目依赖以下软件:   | 该项目依赖以下软件:   | ||||||
| python 3.11 | python 3 | ||||||
| sqlite3 |  | ||||||
|  |  | ||||||
| ## 安装步骤 | ## 安装步骤 | ||||||
|  | ```console | ||||||
|  | # 安装依赖 | ||||||
|  | pip install -r requirements.txt | ||||||
|  |  | ||||||
|  | #初始化数据库 | ||||||
|  | python3 database/initdb.py | ||||||
|  | ``` | ||||||
|  |  | ||||||
| ## 使用说明 | ## 使用说明 | ||||||
| 本项目中包含三种角色,client, node和server。每种角色运行所需要的代码在相应的项目文件夹下面。 | ```python | ||||||
| client: 即客户端。可以通过用户端连接小面网络、创建小面网站、访问别人创建的小面网站。 | python3 server/main.py | ||||||
| node: 即代理节点。运行此程序可以将计算机加入小面网络中,代理连接流量 | ``` | ||||||
| server: 即DNS服务器和小面网站目录服务器。运行此程序可以作为server接受请求。 | dns默认端口为53,域名api默认端口为81 | ||||||
|  | 需要添加数据可以通过post方法向API提交数据 | ||||||
|  | usage: use POST method | ||||||
|  |            /add | ||||||
|  |            data: domian=xxxx&ip=xx.xx.xx.xx | ||||||
|  |            /delete | ||||||
|  |            data: domian=xxxx&ip=xx.xx.xx.xx | ||||||
|  | 注意域名只能以xiaomian作为根域名 | ||||||
|  |  | ||||||
|  | # 测试DNS功能: | ||||||
|  | ```console | ||||||
|  | # Linux | ||||||
|  | sudo apt install dnsutils | ||||||
|  | dig @DNS_SERVER -p PORT DOMAIN  | ||||||
|  |  | ||||||
|  | # Windows | ||||||
|  | nslookup DOMAIN DNS_SERVER | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 测试API功能 | ||||||
|  | ```console | ||||||
|  | # Linux | ||||||
|  | # 添加解析 | ||||||
|  | curl -d "domain=qqqwwweee.xiaomian&ip=123.12.23.34" -X POST http://10.20.117.208:81/add -i | ||||||
|  |  | ||||||
|  | # 删除解析 | ||||||
|  | curl -d "domain=qqqwwweee.xiaomian&ip=123.12.23.34" -X POST http://10.20.117.208:81/delete -i | ||||||
|  |  | ||||||
|  | # Windows | ||||||
|  | Invoke-WebRequest工具一直收不到post的body,不知道问题出在哪里 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## 未实现的功能 | ||||||
|  | 目前只能把解析记录作为a记录返回,还未实现添加其他解析记录。 | ||||||
|  |  | ||||||
| ## 许可证 | ## 许可证 | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										
											BIN
										
									
								
								api_add.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								api_add.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 85 KiB | 
							
								
								
									
										
											BIN
										
									
								
								api_delete.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								api_delete.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 72 KiB | 
| @@ -1,24 +0,0 @@ | |||||||
| import dns.resolver |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def resolver(domain): |  | ||||||
|     # 构造 DNS 查询请求 |  | ||||||
|     qtype = 'A' |  | ||||||
|  |  | ||||||
|     # 发送 DNS 查询请求 |  | ||||||
|     resolver = dns.resolver.Resolver() |  | ||||||
|     resolver.nameservers = ["127.0.0.1"] |  | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         ip = resolver.resolve(domain, qtype)[0] |  | ||||||
|         return ip |  | ||||||
|     except dns.resolver.NXDOMAIN: |  | ||||||
|         print("can't find IP") |  | ||||||
|          |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     domain = 'mamahaha.work' |  | ||||||
|     ip = resolver(domain) |  | ||||||
|     print(ip) |  | ||||||
| @@ -1,66 +0,0 @@ | |||||||
| from cryptography.hazmat.primitives.asymmetric import rsa, padding |  | ||||||
| from cryptography.hazmat.primitives import serialization, hashes |  | ||||||
| import base64 |  | ||||||
| import random |  | ||||||
|  |  | ||||||
| # 生产随机域名 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_domain() -> str: |  | ||||||
|     domain = random.getrandbits(64) |  | ||||||
|     domain = hex(domain)[2:] |  | ||||||
|     return domain + ".xiaomian" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_key(): |  | ||||||
|  |  | ||||||
|     # Generate a new RSA key pair |  | ||||||
|     private_key = rsa.generate_private_key( |  | ||||||
|         public_exponent=65537, |  | ||||||
|         key_size=2048 |  | ||||||
|     ) |  | ||||||
|     public_key = private_key.public_key() |  | ||||||
|  |  | ||||||
|     # Convert keys to bytes |  | ||||||
|     private_key_bytes = private_key.private_bytes( |  | ||||||
|         encoding=serialization.Encoding.PEM, |  | ||||||
|         format=serialization.PrivateFormat.PKCS8, |  | ||||||
|         encryption_algorithm=serialization.NoEncryption() |  | ||||||
|     ) |  | ||||||
|     public_key_bytes = public_key.public_bytes( |  | ||||||
|         encoding=serialization.Encoding.PEM, |  | ||||||
|         format=serialization.PublicFormat.SubjectPublicKeyInfo |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # Encode bytes as base64 |  | ||||||
|     private_key_base64 = base64.b64encode(private_key_bytes).decode('utf-8') |  | ||||||
|     public_key_base64 = base64.b64encode(public_key_bytes).decode('utf-8') |  | ||||||
|  |  | ||||||
|     return private_key_base64, public_key_base64 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # # Encrypt a message using the public key |  | ||||||
| # message = b"Hello World" |  | ||||||
| # encrypted_message = public_key.encrypt( |  | ||||||
| #     message, |  | ||||||
| #     padding.OAEP( |  | ||||||
| #         mgf=padding.MGF1(algorithm=hashes.SHA256()), |  | ||||||
| #         algorithm=hashes.SHA256(), |  | ||||||
| #         label=None |  | ||||||
| #     ) |  | ||||||
| # ) |  | ||||||
|  |  | ||||||
| # # Decrypt the message using the private key |  | ||||||
| # decrypted_message = private_key.decrypt( |  | ||||||
| #     encrypted_message, |  | ||||||
| #     padding.OAEP( |  | ||||||
| #         mgf=padding.MGF1(algorithm=hashes.SHA256()), |  | ||||||
| #         algorithm=hashes.SHA256(), |  | ||||||
| #         label=None |  | ||||||
| #     ) |  | ||||||
| # ) |  | ||||||
| # print(decrypted_message) |  | ||||||
| if __name__ == '__main__': |  | ||||||
|     print("Welcome to my xiaomiao tor network") |  | ||||||
|     domain = generate_domain() |  | ||||||
|     private_key_base64, public_key_base64 = generate_key() |  | ||||||
| @@ -1,15 +1,33 @@ | |||||||
| import sqlite3 | import sqlite3 | ||||||
|  | import argparse | ||||||
|  |  | ||||||
|  | # 用于创建 | ||||||
| db_file = 'database/dns.db' | db_file = 'database/dns.db' | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     conn = sqlite3.connect(db_file) |     conn = sqlite3.connect(db_file) | ||||||
|     cursor = conn.cursor() |     cursor = conn.cursor() | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument('--add', action='store_true', help='add test data') | ||||||
|  |     args = parser.parse_args() | ||||||
|     try: |     try: | ||||||
|         cursor.execute( |         cursor.execute( | ||||||
|             '''CREATE TABLE xiaomiandns(domain TEXT PRIMARY KEY, ip TEXT, pubkey TEXT, nodetype TEXT,timestamp DATETIME)''') |             '''CREATE TABLE xiaomiandns(domain TEXT PRIMARY KEY, ip TEXT,timestamp DATETIME)''') | ||||||
|         # node type contain 3 types: client, node, server |  | ||||||
|     except sqlite3.OperationalError: |     except sqlite3.OperationalError: | ||||||
|         print("table xiaomiandns already exists") |         print("table xiaomiandns already exists") | ||||||
|     conn.commit() |     conn.commit() | ||||||
|  |     if args.add: | ||||||
|  |      | ||||||
|  |         test_data = [ | ||||||
|  |         ('example.xiaomian', '192.168.1.1'), | ||||||
|  |         ('google.xiaomian', '8.8.8.8'), | ||||||
|  |         ('yahoo.xiaomian', '98.138.219.231') | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         for data in test_data: | ||||||
|  |             domain, ip = data | ||||||
|  |             cursor.execute("INSERT INTO xiaomiandns (domain, ip, timestamp) VALUES (?, ?, DATETIME('now'))", (domain, ip)) | ||||||
|  |  | ||||||
|  |     # 提交更改到数据库 | ||||||
|  |     conn.commit() | ||||||
|     cursor.close() |     cursor.close() | ||||||
|     conn.close() |     conn.close() | ||||||
|   | |||||||
| @@ -1,23 +0,0 @@ | |||||||
| import sqlite3 |  | ||||||
|  |  | ||||||
| db_file = 'database/dns.db' |  | ||||||
| if __name__ == '__main__': |  | ||||||
|     conn = sqlite3.connect(db_file) |  | ||||||
|     cursor = conn.cursor() |  | ||||||
|     domain = 'mamahaha.wor12' |  | ||||||
|     ip = "1.1.1.11" |  | ||||||
|     pubkey = "asdfasdfadfsdf" |  | ||||||
|     cursor.execute("SELECT * FROM xiaomiandns WHERE domain = ? OR ip = ? OR pubkey = ?", |  | ||||||
|               (domain, ip, pubkey)) |  | ||||||
|     existing_data = cursor.fetchall() |  | ||||||
|     if existing_data: |  | ||||||
|         print("qqqqqq") |  | ||||||
|     else: |  | ||||||
|         # Insert the new data |  | ||||||
|         cursor.execute( |  | ||||||
|             "INSERT INTO xiaomiandns (domain, ip, pubkey) VALUES (?, ?, ?)", (domain, ip, pubkey)) |  | ||||||
|         print("Data inserted successfully") |  | ||||||
|          |  | ||||||
|     conn.commit() |  | ||||||
|     cursor.close() |  | ||||||
|     conn.close() |  | ||||||
							
								
								
									
										
											BIN
										
									
								
								database_add.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								database_add.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 11 KiB | 
							
								
								
									
										
											BIN
										
									
								
								database_delete.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								database_delete.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 8.7 KiB | 
| @@ -1,2 +0,0 @@ | |||||||
| import yaml |  | ||||||
|  |  | ||||||
| @@ -1,55 +0,0 @@ | |||||||
| import socket |  | ||||||
| import socketserver |  | ||||||
| import struct |  | ||||||
| import select |  | ||||||
|  |  | ||||||
| class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): |  | ||||||
|     pass |  | ||||||
|  |  | ||||||
| class Socks5Handler(socketserver.BaseRequestHandler): |  | ||||||
|     VERSION = 5 |  | ||||||
|  |  | ||||||
|     def handle(self): |  | ||||||
|         # 客户端发送版本和方法 |  | ||||||
|         version, nmethods = struct.unpack('!BB', self.request.recv(2)) |  | ||||||
|         self.request.recv(nmethods) |  | ||||||
|  |  | ||||||
|         # 发送版本和方法响应 |  | ||||||
|         self.request.sendall(struct.pack('!BB', self.VERSION, 0)) |  | ||||||
|  |  | ||||||
|         # 获取请求详情 |  | ||||||
|         version, cmd, _, address_type = struct.unpack('!BBBB', self.request.recv(4)) |  | ||||||
|         if address_type == 1:  # IPv4 |  | ||||||
|             address = socket.inet_ntoa(self.request.recv(4)) |  | ||||||
|         else: |  | ||||||
|             raise NotImplementedError('Only IPv4 is supported.') |  | ||||||
|         port = struct.unpack('!H', self.request.recv(2))[0] |  | ||||||
|  |  | ||||||
|         # 发送响应 |  | ||||||
|         self.request.sendall(struct.pack('!BBBBIH', self.VERSION, 0, 0, 1, |  | ||||||
|                                          int(socket.inet_aton('0.0.0.0').hex(), 16), 0)) |  | ||||||
|  |  | ||||||
|         # 建立连接 |  | ||||||
|         if cmd == 1:  # CONNECT |  | ||||||
|             remote = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |  | ||||||
|             remote.connect((address, port)) |  | ||||||
|             self.exchange_loop(self.request, remote) |  | ||||||
|         else: |  | ||||||
|             raise NotImplementedError('Only CONNECT is supported.') |  | ||||||
|  |  | ||||||
|     def exchange_loop(self, client, remote): |  | ||||||
|         while True: |  | ||||||
|             # Simple data exchange between client and remote |  | ||||||
|             rlist, _, _ = select.select([client, remote], [], []) |  | ||||||
|             if client in rlist: |  | ||||||
|                 data = client.recv(4096) |  | ||||||
|                 if remote.send(data) <= 0: |  | ||||||
|                     break |  | ||||||
|             if remote in rlist: |  | ||||||
|                 data = remote.recv(4096) |  | ||||||
|                 if client.send(data) <= 0: |  | ||||||
|                     break |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': |  | ||||||
|     with ThreadingTCPServer(('0.0.0.0', 1080), Socks5Handler) as server: |  | ||||||
|         server.serve_forever() |  | ||||||
							
								
								
									
										2
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | dnspython==2.3.0 | ||||||
|  | PyYAML==6.0 | ||||||
| @@ -3,7 +3,7 @@ import yaml | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     with open('serverconf.yaml', 'r') as f: |     with open('server/serverconf.yaml', 'r') as f: | ||||||
|         config = yaml.safe_load(f) |         config = yaml.safe_load(f) | ||||||
|     db_file = config['database']['db_file'] |     db_file = config['database']['db_file'] | ||||||
|     DNS_port = config['DNS']['port'] |     DNS_port = config['DNS']['port'] | ||||||
| @@ -11,5 +11,10 @@ if __name__ == '__main__': | |||||||
|     API_port = config['API']['port'] |     API_port = config['API']['port'] | ||||||
|     API_listen_host = config['API']['listen_host'] |     API_listen_host = config['API']['listen_host'] | ||||||
|  |  | ||||||
|     DNSServer = xiaomiandns.DNSServer(DNS_listen_host, DNS_port, db_file) |     # start dns server | ||||||
|     DNSServer.run() |     server = xiaomiandns.DNSServer(DNS_listen_host, DNS_port, db_file) | ||||||
|  |     server.run() | ||||||
|  |  | ||||||
|  |     # start dns api server | ||||||
|  |     APIserver = xiaomiandns.DNSAPI(API_listen_host, API_port, db_file) | ||||||
|  |     APIserver.run() | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| database: | database: | ||||||
|   db_file : '../database/dns.db' |   db_file : 'database/dns.db' | ||||||
| DNS: | DNS: | ||||||
|   port : 53 |   port : 53 | ||||||
|   listen_host : "0.0.0.0" |   listen_host : "0.0.0.0" | ||||||
|   | |||||||
| @@ -7,16 +7,14 @@ import dns.rdatatype | |||||||
| import dns.flags | import dns.flags | ||||||
| import dns.rcode | import dns.rcode | ||||||
| import dns.rrset | import dns.rrset | ||||||
| import time |  | ||||||
| import sqlite3 | import sqlite3 | ||||||
| import re | import re | ||||||
| import base64 | import yaml | ||||||
| from cryptography.hazmat.primitives.asymmetric import rsa, padding |  | ||||||
| from cryptography.hazmat.primitives import serialization, hashes |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DNSServer: | class DNSServer: | ||||||
|  |      | ||||||
|     def __init__(self, hostname, port, db_file): |     def __init__(self, hostname, port, db_file): | ||||||
|         self.hostname = hostname |         self.hostname = hostname | ||||||
|         self.port = port |         self.port = port | ||||||
| @@ -72,7 +70,7 @@ class DNSServer: | |||||||
|         name = question.question[0].name |         name = question.question[0].name | ||||||
|         # search domain in database |         # search domain in database | ||||||
|         dbcursor.execute( |         dbcursor.execute( | ||||||
|             "SELECT ip FROM xiaomiandns WHERE domain = ? AND nodetype = client", (str(name)[:-1],)) |             "SELECT ip FROM xiaomiandns WHERE domain = ?", (str(name)[:-1],)) | ||||||
|         result = dbcursor.fetchone() |         result = dbcursor.fetchone() | ||||||
|  |  | ||||||
|         # Create a new RRset for the answer |         # Create a new RRset for the answer | ||||||
| @@ -90,12 +88,14 @@ class DNSServer: | |||||||
|  |  | ||||||
|  |  | ||||||
| class DNSAPI: | class DNSAPI: | ||||||
|     # usage: use POST method |     """ | ||||||
|     #        /add |     usage: use POST method | ||||||
|     #        data: domian=xxxx&ip=xx.xx.xx.xx&pubkey=xxxxx&nodetype=xxxx |            /add | ||||||
|     #        /delete |            data: domian=xxxx&ip=xx.xx.xx.xx | ||||||
|     #        data: domian=xxxx&ip=xx.xx.xx.xx&prikey=xxxxx&nodetype=xxxx |            /delete | ||||||
|  |            data: domian=xxxx&ip=xx.xx.xx.xx | ||||||
|  |     """ | ||||||
|  |      | ||||||
|     def __init__(self, hostname, port, db_file): |     def __init__(self, hostname, port, db_file): | ||||||
|         self.hostname = hostname |         self.hostname = hostname | ||||||
|         self.port = port |         self.port = port | ||||||
| @@ -136,40 +136,32 @@ class DNSAPI: | |||||||
|         return response |         return response | ||||||
|  |  | ||||||
|     def handle_get_request(self, url): |     def handle_get_request(self, url): | ||||||
|  |         status_code = 400 | ||||||
|         # check url start with /add |         reason_phrase = 'unsupport method, please use POST method' | ||||||
|         # if re.match(r'^/add\?', url): |  | ||||||
|         #     status_code = self.add_data(url[5:]) |  | ||||||
|         #     if status_code == 200: |  | ||||||
|         #         reason_phrase = 'Add data successful' |  | ||||||
|         #     else: |  | ||||||
|         #         reason_phrase = 'Add data unsuccessful' |  | ||||||
|         # # check url start with /delete |  | ||||||
|         # elif re.match(r'^/delete\?', url): |  | ||||||
|         #     status_code = self.delete_data(url[9:]) |  | ||||||
|         #     if status_code == 200: |  | ||||||
|         #         reason_phrase = 'Delete data successful' |  | ||||||
|         #     else: |  | ||||||
|         #         reason_phrase = 'Delete data unsuccessful' |  | ||||||
|         # else: |  | ||||||
|         #     status_code = 400 |  | ||||||
|         #     reason_phrase = 'unsupport api' |  | ||||||
|  |  | ||||||
|         response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase) |         response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase) | ||||||
|         return response.encode("utf-8") |         return response.encode("utf-8") | ||||||
|  |  | ||||||
|     def handle_post_request(self, url, data): |     def handle_post_request(self, url:str, data:str)->str: | ||||||
|         # 处理 POST 请求,data 是 POST 方法提交的数据 |         """处理 POST 请求 | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             url (str): url前缀 | ||||||
|  |             data (str): POST 方法提交的数据 | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             str: http response | ||||||
|  |         """ | ||||||
|  |          | ||||||
|         # check url start with /add |         # check url start with /add | ||||||
|         if re.match(r'^/add\?', url): |         if re.match(r'^/add', url): | ||||||
|             status_code = self.add_data(data) |             status_code = self.add_data(data) | ||||||
|             if status_code == 200: |             if status_code == 200: | ||||||
|                 reason_phrase = 'Add data successful' |                 reason_phrase = 'Add data successful' | ||||||
|             else: |             else: | ||||||
|                 reason_phrase = 'Add data unsuccessful' |                 reason_phrase = 'Add data unsuccessful' | ||||||
|  |                  | ||||||
|         # check url start with /delete |         # check url start with /delete | ||||||
|         elif re.match(r'^/delete\?', url): |         elif re.match(r'^/delete', url): | ||||||
|             status_code = self.delete_data(data) |             status_code = self.delete_data(data) | ||||||
|             if status_code == 200: |             if status_code == 200: | ||||||
|                 reason_phrase = 'Delete data successful' |                 reason_phrase = 'Delete data successful' | ||||||
| @@ -182,7 +174,7 @@ class DNSAPI: | |||||||
|         response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase) |         response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase) | ||||||
|         return response.encode("utf-8") |         return response.encode("utf-8") | ||||||
|  |  | ||||||
|     def handle_error_request(self, request): |     def handle_error_request(self): | ||||||
|         status_code = 400 |         status_code = 400 | ||||||
|         reason_phrase = "unsupport method" |         reason_phrase = "unsupport method" | ||||||
|         headers = { |         headers = { | ||||||
| @@ -190,129 +182,125 @@ class DNSAPI: | |||||||
|             'Connection': 'close', |             'Connection': 'close', | ||||||
|         } |         } | ||||||
|         response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase) |         response = 'HTTP/1.1 {} {}\r\n'.format(status_code, reason_phrase) | ||||||
|  |         for key, value in headers.items(): | ||||||
|  |             response += '{}: {}\r\n'.format(key, value) | ||||||
|  |         response += '\r\n' | ||||||
|         return response.encode("utf-8") |         return response.encode("utf-8") | ||||||
|  |  | ||||||
|     def add_data(self, data): |     def add_data(self, data:str)->int: | ||||||
|  |         """add data to database | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             data (str): domain and ip | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             int: status code | ||||||
|  |         """ | ||||||
|         # parse and check validation |         # parse and check validation | ||||||
|         domain, ip, pubkey, nodetype = self.parse_data(data) |         domain, ip = self.parse_data(data) | ||||||
|  |  | ||||||
|         if not self.check_data(domain,ip,nodetype): |         if not self.check_data(domain,ip): | ||||||
|             return 400 |             return 400 | ||||||
|  |  | ||||||
|         # connect db |         # connect db | ||||||
|         conn = sqlite3.connect(self.db_file) |         conn = sqlite3.connect(self.db_file) | ||||||
|         c = conn.cursor() |         c = conn.cursor() | ||||||
|  |  | ||||||
|         # Check if the data already exists |         # Check if the domain already exists | ||||||
|         c.execute( |         c.execute( | ||||||
|             "SELECT * FROM xiaomiandns WHERE domain = ? OR ip = ? OR pubkey = ? OR nodetype = ?", (domain, ip, pubkey, nodetype)) |             "SELECT * FROM xiaomiandns WHERE domain = ?", [domain]) | ||||||
|         existing_data = c.fetchall() |         existing_domain = c.fetchall() | ||||||
|  |          | ||||||
|  |  | ||||||
|         cursor.close() |         if existing_domain: | ||||||
|         conn.close() |             c.execute("UPDATE xiaomiandns SET ip = ?, timestamp = DATETIME('now')  WHERE domain = ?",(ip,domain)) | ||||||
|  |  | ||||||
|         if existing_data: |  | ||||||
|             return 400 |  | ||||||
|         else: |         else: | ||||||
|             # Insert the new data |             # Insert the new data | ||||||
|             c.execute( |             c.execute( | ||||||
|                 "INSERT INTO xiaomiandns (domain,ip,pubkey,nodetype,timestamp) VALUES (?,?,?,?,DATETIME('now'))", (domain, ip, pubkey, nodetype)) |                 "INSERT INTO xiaomiandns (domain,ip,timestamp) VALUES (?,?,DATETIME('now'))", (domain, ip)) | ||||||
|             return 200 |         conn.commit() | ||||||
|  |         c.close() | ||||||
|  |         conn.close() | ||||||
|  |         status_code = 200 | ||||||
|  |         return status_code | ||||||
|  |  | ||||||
|     def delete_data(self, data): |     def delete_data(self, data:str) -> int: | ||||||
|  |         """delete record in database | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             data (str): domain and ip | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             int: status code | ||||||
|  |         """ | ||||||
|         # parse and check validation |         # parse and check validation | ||||||
|         domain, ip, private_key_base64, nodetype = self.parse_data(data) |         domain, ip = self.parse_data(data) | ||||||
|          |          | ||||||
|         if not self.check_data(domain, ip ,nodetype): |         if not self.check_data(domain, ip): | ||||||
|             return 400 |             return 400 | ||||||
|  |  | ||||||
|         # connect db |         # connect db | ||||||
|         conn = sqlite3.connect(self.db_file) |         conn = sqlite3.connect(self.db_file) | ||||||
|         c = conn.cursor() |         c = conn.cursor() | ||||||
|         c.execute( |         c.execute( | ||||||
|             "SELECT pubkey FROM xiaomiandns WHERE domain = ? AND ip = ? AND nodetype = ?", (domain, ip, nodetype)) |             "DELETE FROM xiaomiandns WHERE domain = ? ", [domain]) | ||||||
|         public_key_base64 = c.fetchone() |         deleted_rows = c.rowcount | ||||||
|         cursor.close() |  | ||||||
|  |         if deleted_rows == 0: | ||||||
|  |             # unmatch | ||||||
|  |             status_code = 400 | ||||||
|  |         else: | ||||||
|  |             # deleted | ||||||
|  |             status_code = 200 | ||||||
|  |         conn.commit() | ||||||
|  |         c.close() | ||||||
|         conn.close() |         conn.close() | ||||||
|  |  | ||||||
|         if public_key_base64 != None: |  | ||||||
|             public_key_base64 = public_key_base64[0] |  | ||||||
|         else: |  | ||||||
|             return 400 |  | ||||||
|  |  | ||||||
|         private_key_bytes = base64.b64decode( |  | ||||||
|             private_key_base64).decode("utf-8") |  | ||||||
|  |  | ||||||
|         private_key = serialization.load_pem_private_key( |  | ||||||
|             private_key_bytes, |  | ||||||
|             password=None |  | ||||||
|         ) |  | ||||||
|          |          | ||||||
|         gen_public_key = private_key.public_key() |         return status_code | ||||||
|         gen_public_key_bytes = gen_public_key.public_bytes( |  | ||||||
|             encoding=serialization.Encoding.PEM, |  | ||||||
|             format=serialization.PublicFormat.SubjectPublicKeyInfo |  | ||||||
|         ) |  | ||||||
|         gen_public_key_base64 = base64.b64encode(gen_public_key_bytes).decode('utf-8') |  | ||||||
|          |  | ||||||
|         if gen_public_key_base64 == public_key_base64: |  | ||||||
|             conn = sqlite3.connect(self.db_file) |  | ||||||
|             c = conn.cursor() |  | ||||||
|             c.execute( |  | ||||||
|                 "DELETE FROM xiaomiandns WHERE domain = ? AND ip = ? AND nodetype = ?", (domain, ip, nodetype)) |  | ||||||
|             cursor.close() |  | ||||||
|             conn.close() |  | ||||||
|             return 200 |  | ||||||
|         else: |  | ||||||
|             return 400 |  | ||||||
|  |  | ||||||
|     def parse_data(self, data): |     def parse_data(self, data:str) -> str : | ||||||
|  |         """parse data form post data | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             data (str): post data | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             domain: domain name | ||||||
|  |             ip: ip | ||||||
|  |         """ | ||||||
|         domain = re.search(r'domain=([^&]+)', data) |         domain = re.search(r'domain=([^&]+)', data) | ||||||
|         ip = re.search(r'ip=([^&]+)', data) |         ip = re.search(r'ip=([^&]+)', data) | ||||||
|         pubkey = re.search(r'pubkey=([^&]+)', data) |  | ||||||
|         privkey = re.search(r'privkey=([^&]+)', data) |  | ||||||
|         nodetype = re.search(r'nodetype=([^]+)', data) |  | ||||||
|  |  | ||||||
|         if domain and ip and nodetype: |         if domain and ip: | ||||||
|             domain = domain.group(1) |             domain = domain.group(1) | ||||||
|             ip = ip.group(1) |             ip = ip.group(1) | ||||||
|             nodetype = nodetype.group(1) |         return domain, ip | ||||||
|             if bool(pubkey) != bool(privkey): |  | ||||||
|                 if pubkey: |  | ||||||
|                     key = pubkey.group(1) |  | ||||||
|                 else: |  | ||||||
|                     key = privkey.group(1) |  | ||||||
|         return domain, ip, key, nodetype |  | ||||||
|  |  | ||||||
|     def check_data(self, domain, ip, nodetype): |     def check_data(self, domain, ip): | ||||||
|  |         """_summary_ | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             domain (_type_): _description_ | ||||||
|  |             ip (_type_): _description_ | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             _type_: _description_ | ||||||
|  |         """ | ||||||
|         # check domain |         # check domain | ||||||
|         pattern = r'^[a-z0-9]{16}\.xiaomian$' |         domain_pattern = r'^[a-z0-9]+\.xiaomian$' | ||||||
|  |  | ||||||
|         if re.match(pattern, domain): |  | ||||||
|             return True |  | ||||||
|         else: |  | ||||||
|             return False |  | ||||||
|  |  | ||||||
|         # check ip |         # check ip | ||||||
|         pattern = r'^(\d{1,3}\.){3}\d{1,3}$' |         ip_pattern = r'^(\d{1,3}\.){3}\d{1,3}$' | ||||||
|         if re.match(pattern, ip): |         if re.match(domain_pattern, domain) and re.match(ip_pattern, ip) : | ||||||
|             octets = ip.split('.') |             octets = ip.split('.') | ||||||
|             if all(int(octet) < 256 for octet in octets): |             if all(int(octet) < 256 for octet in octets): | ||||||
|                 return True |                 return True | ||||||
|         return False |             else: | ||||||
|  |                 return False | ||||||
|  |  | ||||||
|         # check nodetype |  | ||||||
|         if nodetype in {"server", "client", "node"}: |  | ||||||
|             return True |  | ||||||
|         else: |  | ||||||
|             return False |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  |      | ||||||
|     with open('serverconf.yaml', 'r') as f: |     with open('serverconf.yaml', 'r') as f: | ||||||
|         config = yaml.safe_load(f) |         config = yaml.safe_load(f) | ||||||
|     db_file = config['database']['db_file'] |     db_file = config['database']['db_file'] | ||||||
| @@ -321,16 +309,6 @@ if __name__ == '__main__': | |||||||
|     API_port = config['API']['port'] |     API_port = config['API']['port'] | ||||||
|     API_listen_host = config['API']['listen_host'] |     API_listen_host = config['API']['listen_host'] | ||||||
|  |  | ||||||
|      |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|          |  | ||||||
|     # start dns server |     # start dns server | ||||||
|     server = DNSServer(API_listen_host, DNS_port, db_file) |     server = DNSServer(API_listen_host, DNS_port, db_file) | ||||||
|     server.run() |     server.run() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user