tpre-python/server/xiaomiandns.py
Smart-SangGe 6cd279e88c modified: README.md
new file:   api_add.png
	modified:   database/initdb.py
	new file:   database_add.png
	new file:   dns1.png
	new file:   dns2.png
	new file:   requirements.txt
	modified:   server/main.py
	modified:   server/serverconf.yaml
	modified:   server/xiaomiandns.py
2023-06-09 17:44:51 +08:00

320 lines
9.4 KiB
Python

import ipaddress
import re
import socket
import sqlite3
import threading
import urllib.parse

import dns.flags
import dns.message
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.resolver
import dns.rrset
import yaml
class DNSServer:
    """Answer DNS queries from the ``xiaomiandns`` SQLite table.

    Queries are served over both UDP and TCP on the same host/port by a
    small pool of worker threads started from :meth:`run`.
    """

    # Number of worker threads started per transport (UDP and TCP).
    WORKERS_PER_TRANSPORT = 3
    # Largest UDP datagram read from a client.
    UDP_BUFFER_SIZE = 4096
    # Seconds a TCP client may stay idle before its connection is dropped.
    TCP_TIMEOUT = 10.0

    def __init__(self, hostname, port, db_file):
        self.hostname = hostname
        self.port = port
        self.db_file = db_file

    def run(self):
        """Bind the UDP and TCP sockets and start the worker threads.

        Returns as soon as the workers are started; the (non-daemon) worker
        threads keep serving in the background.
        """
        self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.udp_socket.bind((self.hostname, self.port))
        self.tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.tcp_socket.bind((self.hostname, self.port))
        self.tcp_socket.listen(5)
        print(f"DNS server running on {self.hostname}:{self.port}")
        for _ in range(self.WORKERS_PER_TRANSPORT):
            threading.Thread(target=self.handle_udp_request).start()
            threading.Thread(target=self.handle_tcp_request).start()

    def handle_udp_request(self):
        """Serve UDP queries in a loop until the UDP socket is closed."""
        while True:
            try:
                data, address = self.udp_socket.recvfrom(self.UDP_BUFFER_SIZE)
            except ConnectionResetError:
                # Some platforms (e.g. Windows) surface an ICMP error from an
                # earlier sendto() this way; it is not fatal for the socket.
                continue
            except OSError:
                # The socket was closed: stop this worker.
                return
            try:
                response = self.handle_request(data)
                self.udp_socket.sendto(response, address)
            except Exception as exc:
                # A malformed packet or a DB failure must not kill the worker.
                print(f"DNS UDP request from {address} failed: {exc}")

    def handle_tcp_request(self):
        """Serve TCP queries in a loop, one query per accepted connection.

        Over TCP every DNS message is preceded by a two-byte big-endian
        length field (RFC 1035 section 4.2.2), both in the query and in the
        response.
        """
        while True:
            try:
                connection, address = self.tcp_socket.accept()
            except OSError:
                # The listening socket was closed: stop this worker.
                return
            with connection:
                try:
                    connection.settimeout(self.TCP_TIMEOUT)
                    length = int.from_bytes(self._recv_exact(connection, 2), "big")
                    data = self._recv_exact(connection, length)
                    response = self.handle_request(data)
                    connection.sendall(len(response).to_bytes(2, "big") + response)
                except Exception as exc:
                    # Bad framing, timeouts or malformed messages only affect
                    # this client; keep the worker alive.
                    print(f"DNS TCP request from {address} failed: {exc}")

    @staticmethod
    def _recv_exact(connection, count):
        """Read exactly *count* bytes from *connection*.

        Raises:
            ConnectionError: the peer closed the connection before *count*
                bytes arrived.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                raise ConnectionError("connection closed mid-message")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def handle_request(self, data):
        """Parse one wire-format DNS query and return the wire-format reply.

        Raises whatever ``dns.message.from_wire`` raises for malformed input,
        and ``sqlite3.Error`` on database failures.
        """
        question = dns.message.from_wire(data)
        conn = sqlite3.connect(self.db_file)
        try:
            return self.build_response(question, conn.cursor())
        finally:
            # Always release the per-request connection.
            conn.close()

    def build_response(self, question, dbcursor, rcode=dns.rcode.NOERROR, answer=None):
        """Build the wire-format reply to *question* using *dbcursor*.

        Args:
            question: the parsed query message.
            dbcursor: sqlite3 cursor on a database with a ``xiaomiandns``
                table holding ``domain`` and ``ip`` columns.
            rcode: rcode used when the domain is found.
            answer: ignored; kept only for backward compatibility.

        Returns:
            bytes: the response message in wire format. The rcode is FORMERR
            for a query without a question, NXDOMAIN for unknown domains, and
            *rcode* otherwise.
        """
        response = dns.message.Message()
        response.id = question.id
        # RFC 1035 4.1.1: RD is set in a query and copied into the response.
        response.flags = dns.flags.QR | dns.flags.RA | (question.flags & dns.flags.RD)
        response.question = question.question

        if not question.question:
            response.set_rcode(dns.rcode.FORMERR)
            return response.to_wire()

        query = question.question[0]
        name = query.name
        # DNS names are case-insensitive; the API only stores lowercase names.
        domain = name.to_text(omit_final_dot=True).lower()
        dbcursor.execute(
            "SELECT ip FROM xiaomiandns WHERE domain = ?", (domain,))
        result = dbcursor.fetchone()

        if result is None:
            response.set_rcode(dns.rcode.NXDOMAIN)
            return response.to_wire()

        # Only A records are stored. Other query types for a known name get
        # an empty NOERROR (NODATA) answer instead of a mismatched A record.
        if query.rdtype in (dns.rdatatype.A, dns.rdatatype.ANY):
            response.answer.append(dns.rrset.from_text(
                name, 0, dns.rdataclass.IN, dns.rdatatype.A, result[0]))
        response.set_rcode(rcode)
        return response.to_wire()
class DNSAPI:
    """Minimal HTTP API for managing records in the ``xiaomiandns`` table.

    usage: use POST method
        /add
            data: domain=xxxx&ip=xx.xx.xx.xx
        /delete
            data: domain=xxxx&ip=xx.xx.xx.xx

    Each connection is served on its own thread, answered with one HTTP
    response whose reason phrase (also sent as the body) describes the
    outcome, and then closed.
    """

    # Upper bound on how many request bytes are buffered per connection.
    MAX_REQUEST_BYTES = 64 * 1024
    # Seconds a client may stay idle before its connection is dropped.
    CLIENT_TIMEOUT = 10.0

    def __init__(self, hostname, port, db_file):
        self.hostname = hostname
        self.port = port
        self.db_file = db_file

    def run(self):
        """Listen on (hostname, port) forever, one thread per connection."""
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Bind the IP address and port number
        server_socket.bind((self.hostname, self.port))
        # Listen for connections
        server_socket.listen(5)
        print(f"API server running on {self.hostname}:{self.port}")
        with server_socket:
            while True:
                # Accept a connection and handle it on its own thread
                conn, addr = server_socket.accept()
                t = threading.Thread(target=self.handle_tcp_request, args=(conn,))
                t.start()

    def handle_tcp_request(self, conn):
        """Read one HTTP request from *conn*, send the response, close *conn*."""
        try:
            conn.settimeout(self.CLIENT_TIMEOUT)
            # Undecodable bytes must not crash the handler.
            request = self._read_request(conn).decode("utf-8", errors="replace")
            try:
                response = self.handle_http_request(request)
            except sqlite3.Error as exc:
                print(f"API database error: {exc}")
                response = self._format_response(500, "database error")
            # send() may write only part of the buffer; sendall() does not.
            conn.sendall(response)
        except OSError as exc:
            # Client went away or timed out; nothing more to do.
            print(f"API connection error: {exc}")
        finally:
            conn.close()

    def _read_request(self, conn):
        """Return the raw bytes of one request.

        Reads until the end of the header block, then until ``Content-Length``
        body bytes have arrived (both capped at MAX_REQUEST_BYTES). Returns
        whatever was received if the peer closes early.
        """
        buffer = bytearray()
        while b"\r\n\r\n" not in buffer and len(buffer) < self.MAX_REQUEST_BYTES:
            chunk = conn.recv(4096)
            if not chunk:
                return bytes(buffer)
            buffer += chunk
        head, sep, body = bytes(buffer).partition(b"\r\n\r\n")
        if not sep:
            return bytes(buffer)
        wanted = min(self._content_length(head), self.MAX_REQUEST_BYTES)
        while len(body) < wanted:
            chunk = conn.recv(4096)
            if not chunk:
                break
            body += chunk
        return head + sep + body

    @staticmethod
    def _content_length(head):
        """Return the Content-Length declared in header bytes *head*, or 0."""
        for line in head.split(b"\r\n")[1:]:
            name, sep, value = line.partition(b":")
            if sep and name.strip().lower() == b"content-length":
                try:
                    return max(int(value.strip()), 0)
                except ValueError:
                    return 0
        return 0

    @staticmethod
    def _format_response(status_code, reason_phrase):
        """Build a complete HTTP/1.1 response whose body is *reason_phrase*."""
        body = (reason_phrase + "\n").encode("utf-8")
        head = (
            f"HTTP/1.1 {status_code} {reason_phrase}\r\n"
            "Content-Type: text/plain; charset=utf-8\r\n"
            f"Content-Length: {len(body)}\r\n"
            "Connection: close\r\n"
            "\r\n"
        )
        return head.encode("utf-8") + body

    def handle_http_request(self, request):
        """Dispatch a decoded HTTP request string and return response bytes."""
        head, _, body = request.partition("\r\n\r\n")
        request_line = head.split("\r\n", 1)[0]
        parts = request_line.split(" ", 2)
        if len(parts) != 3:
            return self._format_response(400, "malformed request")
        method, url, _version = parts
        if method == "GET":
            return self.handle_get_request(url)
        if method == "POST":
            return self.handle_post_request(url, body)
        return self.handle_error_request()

    def handle_get_request(self, url):
        """Reject GET requests; only POST is supported."""
        return self._format_response(400, 'unsupport method, please use POST method')

    def handle_post_request(self, url: str, data: str) -> bytes:
        """Handle a POST request.

        Args:
            url (str): request target; its path must be ``/add`` or
                ``/delete`` (a query string or trailing slash is ignored).
            data (str): form-encoded POST body.

        Returns:
            bytes: HTTP response
        """
        path = url.split("?", 1)[0].rstrip("/")
        if path == "/add":
            status_code = self.add_data(data)
            if status_code == 200:
                reason_phrase = 'Add data successful'
            else:
                reason_phrase = 'Add data unsuccessful'
        elif path == "/delete":
            status_code = self.delete_data(data)
            if status_code == 200:
                reason_phrase = 'Delete data successful'
            else:
                reason_phrase = 'Delete data unsuccessful'
        else:
            status_code = 400
            reason_phrase = 'unsupport api'
        return self._format_response(status_code, reason_phrase)

    def handle_error_request(self):
        """Reject any HTTP method other than GET/POST."""
        return self._format_response(400, "unsupport method")

    def add_data(self, data: str) -> int:
        """Insert or update the record for a domain.

        Args:
            data (str): form-encoded ``domain=...&ip=...``

        Returns:
            int: 200 on success, 400 if the input fails validation.
        """
        domain, ip = self.parse_data(data)
        if not self.check_data(domain, ip):
            return 400
        conn = sqlite3.connect(self.db_file)
        try:
            cur = conn.cursor()
            # UPDATE first and INSERT only if nothing matched. Both writes run
            # in the same transaction, so concurrent requests cannot both
            # insert the same domain between a read and a write.
            cur.execute(
                "UPDATE xiaomiandns SET ip = ?, timestamp = DATETIME('now') WHERE domain = ?",
                (ip, domain))
            if cur.rowcount == 0:
                cur.execute(
                    "INSERT INTO xiaomiandns (domain,ip,timestamp) VALUES (?,?,DATETIME('now'))",
                    (domain, ip))
            conn.commit()
        finally:
            conn.close()
        return 200

    def delete_data(self, data: str) -> int:
        """Delete the record(s) for a domain.

        The ip must be valid, but only the domain selects the rows to delete.

        Args:
            data (str): form-encoded ``domain=...&ip=...``

        Returns:
            int: 200 if a row was deleted, 400 on invalid input or no match.
        """
        domain, ip = self.parse_data(data)
        if not self.check_data(domain, ip):
            return 400
        conn = sqlite3.connect(self.db_file)
        try:
            cur = conn.cursor()
            cur.execute("DELETE FROM xiaomiandns WHERE domain = ? ", [domain])
            deleted_rows = cur.rowcount
            conn.commit()
        finally:
            conn.close()
        return 200 if deleted_rows > 0 else 400

    def parse_data(self, data: str) -> tuple:
        """Parse form-encoded POST data.

        Args:
            data (str): post data, e.g. ``domain=a.xiaomian&ip=1.2.3.4``

        Returns:
            tuple: ``(domain, ip)``; each is ``None`` when the field is
            missing or empty. Values are URL-decoded and whitespace-stripped.
        """
        fields = urllib.parse.parse_qs(data or "")

        def first(key):
            values = fields.get(key)
            return values[0].strip() if values else None

        return first("domain"), first("ip")

    def check_data(self, domain, ip) -> bool:
        """Validate a domain/ip pair.

        Args:
            domain: must look like ``<lowercase letters/digits>.xiaomian``.
            ip: must be an IPv4 address in canonical dotted-quad form
                (no leading zeros, ASCII digits only).

        Returns:
            bool: True when both values are valid.
        """
        if not isinstance(domain, str) or not isinstance(ip, str):
            return False
        # fullmatch: unlike ``$``, it does not accept a trailing newline.
        if re.fullmatch(r"[a-z0-9]+\.xiaomian", domain) is None:
            return False
        try:
            address = ipaddress.IPv4Address(ip)
        except ValueError:
            return False
        # Reject non-canonical spellings such as "01.2.3.4".
        return str(address) == ip
if __name__ == '__main__':
    # Load listen addresses, ports and the database path from the config file.
    with open('serverconf.yaml', 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    db_file = config['database']['db_file']
    DNS_port = config['DNS']['port']
    DNS_listen_host = config['DNS']['listen_host']
    API_port = config['API']['port']
    API_listen_host = config['API']['listen_host']
    # Start the DNS server on its own configured host. run() returns once its
    # worker threads are started.
    server = DNSServer(DNS_listen_host, DNS_port, db_file)
    server.run()
    # Start the DNS API server; its run() loops forever accepting connections.
    APIserver = DNSAPI(API_listen_host, API_port, db_file)
    APIserver.run()