algo_present/CoreAlgorithm.py
2024-04-09 14:34:58 +08:00

517 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
from collections.abc import Mapping
from encoding import EncodedNumber
from util import invert, powmod, getprimeover
DEFAULT_KEYSIZE = 1024 # 定义默认的二进制数长度
def generate_paillier_keypair(
private_keyring=None, n_length=DEFAULT_KEYSIZE
): # 生成公钥和私钥的函数
# 生成 Paillier 密钥对函数
p = q = n = None # 初始化素数 p, q 和计算结果 n
length = 0 # 初始化计算结果 n 的长度 (即用二进制表示 n 所需要的二进制位数)
while length != n_length: # 循环直至计算结果 n 的长度达到指定长度 n_length
p = getprimeover(n_length // 2) # 随机生成一个 (n_length//2) 长的素数 p
q = p
while q == p:
# 确保 q 与 p 不相等
q = getprimeover(n_length // 2) # 随机生成一个 (n_length//2) 长的素数 q
n = p * q # 计算 n,即两个素数乘积
length = n.bit_length() # 计算 n 的二进制长度
# 创建公钥对象
public_key = PaillierPublicKey(n)
# 创建私钥对象
private_key = PaillierPrivateKey(public_key, p, q)
if private_keyring is not None: # 如果传入了私钥环对象,则将私钥添加到私钥环中
private_keyring.add(private_key)
return public_key, private_key # 返回公钥和私钥
class PaillierPublicKey(object): # 定义公钥类
def __init__(self, n):
self.g = n + 1
self.n = n # 公钥的模数
self.nsquare = n * n # n的平方
self.max_int = n // 3 - 1 # 公钥的一个属性(限制可加密/解密的最大整数值)
def __repr__(self): # 用于打印出该类的对象
public_key_hash = hex(hash(self))[2:]
return "<PaillierPublicKey {}>".format(public_key_hash[:10]) # 返回表示对象的字符串
def __eq__(self, other): # 用于比较两个对象是否相等,并返回比较结果
return self.n == other.n
def __hash__(self): # 用于返回n的Hash值
return hash(self.n)
def get_n_and_g(self): # 获取该公钥的 n 和 g 的值
return self.n, self.g
def raw_encrypt(self, plaintext, r_value=None): # 用于返回加密后的密文,其中r_value可给随机数赋值
if not isinstance(plaintext, int): # 判断plaintext是否是整数
raise TypeError("明文不是整数,而是: %s" % type(plaintext))
if self.n - self.max_int <= plaintext < self.n: # 对于非常大的明文,使用特殊的计算方法进行加密:
neg_plaintext = self.n - plaintext # = abs(plaintext - nsquare)
neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare
nude_ciphertext = invert(neg_ciphertext, self.nsquare)
else: # 如果不是非常大的明文:
nude_ciphertext = (
self.n * plaintext + 1
) % self.nsquare # (n + 1)^plaintext = n * plaintext + 1 mod n^2
# 生成一个随机数,其值为r_value。如果r_value没有值,则r随机:
r = r_value or self.get_random_lt_n()
obfuscator = powmod(r, self.n, self.nsquare) # (r ^ n) mod n^2
return (nude_ciphertext * obfuscator) % self.nsquare # 返回加密后的密文
def get_random_lt_n(self): # 返回一个1——n间的随机整数
return random.SystemRandom().randrange(1, self.n)
def encrypt(
self, value, precision=None, r_value=None
): # value表示要加密的值,precision是加密精度,r_value是随机数
# 判断value是否是EncodedNumber类型,如果是则直接赋值给encoding如果不是,则对value进行编码
if isinstance(value, EncodedNumber):
encoding = value
else:
encoding = EncodedNumber.encode(self, value, precision)
return self.encrypt_encoded(encoding, r_value)
def encrypt_encoded(self, encoding, r_value): # 将已编码的数值对象转换为加密后的数值对象,并可以选择进行混淆处理
obfuscator = r_value or 1 # 为随机数r_value,没有则默认为1
ciphertext = self.raw_encrypt(encoding.encoding, r_value=obfuscator)
encrypted_number = EncryptedNumber(self, ciphertext, encoding.exponent)
"""
PS:默认生成情况下不输入随机数r_value的情况下:
encrypt中的随机数r_value为:None
raw_encrypt中的随机数为:1
encrypt_encoded中的随机数为:None
"""
if r_value is None: # 结合上述注释,可知:密文混淆函数是会默认执行的
encrypted_number.obfuscate() # 如果encrypt_encoded没有随机数r_value,则进行密文混淆处理obfuscate()
return encrypted_number
class PaillierPrivateKey(object): # 私钥
def __init__(self, public_key, p, q):
if not p * q == public_key.n: # 如果p * q 不等于 公钥的n,则说明出错
raise ValueError("所给公钥与p,q不匹配")
if p == q: # p,q相同
raise ValueError("p,q不能相同")
self.public_key = public_key
# 给self的p q赋值:
if q < p: # 默认是p 大于等于 q
self.p = q
self.q = p
else:
self.p = p
self.q = q
self.psquare = self.p * self.p
self.qsquare = self.q * self.q
self.p_inverse = invert(self.p, self.q) # 计算p mod q 的乘法逆元
self.hp = self.h_function(self.p, self.psquare) # p mod p方
self.hq = self.h_function(self.q, self.qsquare) # q mod q方
def __repr__(self): # 用于打印出该类的对象
pub_repr = repr(self.public_key)
return "<PaillierPrivateKey for {}>".format(pub_repr)
def decrypt(self, encrypted_number): # 解密密文,并返回明文
# 执行下面这个语句前的类型为EncryptedNumber,执行完毕后类型为EncodedNumber中间会变为int型的ciphertext:
encoded = self.decrypt_encoded(encrypted_number)
return encoded.decode()
def decrypt_encoded(
self, encrypted_number, Encoding=None
): # 用于解密密文并返回解密后的EncodedNumber类型
# 检查输入信息是否是EncryptedNumber参数,如果不是:
if not isinstance(encrypted_number, EncryptedNumber):
raise TypeError(
"参数应该是EncryptedNumber," " 参数不能为: %s" % type(encrypted_number)
)
if self.public_key != encrypted_number.public_key: # 如果公钥与加密数字的公钥不一致
raise ValueError("加密信息不能被不同的公钥进行加密!")
if Encoding is None: # 将Encoding设置为未赋值的EncodedNumber变量
Encoding = EncodedNumber
"""提取 encrypted_number 中的 ciphertext
这里是禁用安全模式,
所以是直接提取ciphertext,
随后调用raw_decrypt函数对ciphertext进行处理:"""
encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
return Encoding(self.public_key, encoded, encrypted_number.exponent)
def raw_decrypt(self, ciphertext): # 对密文进行原始解密
if not isinstance(ciphertext, int): # 如果所给的密文不是int型
raise TypeError("密文应该是int型, 而不是: %s" % type(ciphertext))
# 将解密结果存放在p和q中,并将p q进行合并:
decrypt_to_p = (
self.l_function(powmod(ciphertext, self.p - 1, self.psquare), self.p)
* self.hp
% self.p
)
decrypt_to_q = (
self.l_function(powmod(ciphertext, self.q - 1, self.qsquare), self.q)
* self.hq
% self.q
)
return self.crt(decrypt_to_p, decrypt_to_q)
def h_function(self, x, xsquare): # 计算并返回h函数值[用于中国剩余定理]
return invert(self.l_function(powmod(self.public_key.g, x - 1, xsquare), x), x)
def l_function(self, mju, p): # 计算并返回l值算L(μ)
return (mju - 1) // p
def crt(self, mp, mq): # 实现中国剩余定理(Chinese remainder theorem)
u = (mq - mp) * self.p_inverse % self.q
return mp + (u * self.p)
def __eq__(self, other): # 判断两个对象的 q 与 p 是否相等
return self.p == other.p and self.q == other.q
def __hash__(self): # 计算 p 与 q 元组的哈希值
return hash((self.p, self.q))
class PaillierPrivateKeyring(Mapping): # 私钥环类,并继承了Mapping类
def __init__(self, private_keys=None): # 初始化私钥环对象(私钥环列表)
if private_keys is None:
private_keys = []
# 将私钥和公钥进行组合,并存储在私钥环中:
public_keys = [k.public_key for k in private_keys]
self.__keyring = dict(zip(public_keys, private_keys))
def __getitem__(self, key): # 通过公钥,来查找私钥环中对应的私钥
return self.__keyring[key]
def __len__(self): # 存储的私钥数量
return len(self.__keyring)
def __iter__(self): # 遍历私钥环中的公钥
return iter(self.__keyring)
def __delitem__(self, public_key): # 删除与公钥对应的私钥
del self.__keyring[public_key]
def add(self, private_key): # 向私钥环中添加私钥
if not isinstance(private_key, PaillierPrivateKey): # 对要添加的私钥进行判断
raise TypeError("私钥应该是PaillierPrivateKey类型, " "而不是 %s" % type(private_key))
self.__keyring[private_key.public_key] = private_key # 将该公钥和对用的私钥一块儿加入到私钥环中
def decrypt(self, encrypted_number): # 对密文进行解密
relevant_private_key = self.__keyring[
encrypted_number.public_key
] # 在私钥环中获取对应的私钥
return relevant_private_key.decrypt(encrypted_number) # 返回加密结果
class EncryptedNumber(object): # 浮点数或整数的Pailier加密
"""
1. D(E(a) * E(b)) = a + b
2. D(E(a)**b) = a * b
"""
def __init__(self, public_key, ciphertext, exponent=0):
self.public_key = public_key
self.__ciphertext = ciphertext # 密文
self.exponent = exponent # 用于表示指数
self.__is_obfuscated = False # 用于表示数据是否被混淆
if isinstance(self.ciphertext, EncryptedNumber): # 如果密文是EncryptedNumber
raise TypeError("密文必须是int型")
if not isinstance(
self.public_key, PaillierPublicKey
): # 如果公钥不是PaillierPublicKey
raise TypeError("公钥必须是PaillierPublicKey")
def __add__(self, other): # 运算符重载,重载为EncryptedNumber与(EncryptedNumber/整数/浮点数)的加法
if isinstance(other, EncryptedNumber):
return self._add_encrypted(other)
elif isinstance(other, EncodedNumber):
return self._add_encoded(other)
else:
return self._add_scalar(other)
def __radd__(self, other): # 反加,处理整数/浮点数与EncryptedNumber之间的加法
return self.__add__(other)
def __mul__(self, other): # 运算符重载,重载为EncryptedNumber与(整数/浮点数)的乘法
# 判断other对象是否是EncryptedNumber,如果是:
if isinstance(other, EncryptedNumber):
raise NotImplementedError("EncryptedNumber 与 EncryptedNumber 之间不能相乘!")
if isinstance(other, EncodedNumber):
encoding = other
else:
encoding = EncodedNumber.encode(self.public_key, other)
product = self._raw_mul(encoding.encoding) # 重新更新乘积
exponent = self.exponent + encoding.exponent # 重新更新指数
return EncryptedNumber(self.public_key, product, exponent)
def __rmul__(self, other): # 反乘,处理整数/浮点数与EncryptedNumber之间的乘法
return self.__mul__(other)
def __sub__(self, other): # 运算符重载,重载为EncryptedNumber与(EncryptedNumber/整数/浮点数)的减法
return self + (other * -1)
def __rsub__(self, other): # 处理整数/浮点数与EncryptedNumber之间的减法
return other + (self * -1)
def __truediv__(
self, scalar
): # 运算符重载,重载为EncryptedNumber与(EncryptedNumber/整数/浮点数)的除法
return self.__mul__(1 / scalar)
def __invert__(self): # 运算符重载~(对 数 的取反)
return self * (-1)
# def __pow__(self, exponent): # 运算符重载 ** (对密文的幂函数)
# if not isinstance(exponent, int): # 如果输入有问题
# print("指数应输入 整数 标量!")
# else:
# result = self
# for i in [1, exponent]:
# result *= self
# return result
# # 原本的幂运算 ** return self.value ** exponent
def ciphertext(self, be_secure=True): # 用于混淆密文,并返回混淆后的密文
"""
EncryptedNumber类的一个方法ciphertext,用于返回该对象的密文。
在Paillier加密中,为了提高计算性能,加法和乘法操作进行了简化,
避免对每个加法和乘法结果进行随机数的加密操作。
这样会使得内部计算快速,但会暴露一部分信息。
此外,为了保证安全,如果需要与其他人共享密文,应该使用be_secure=True。
这样,如果密文还没有被混淆,会调用obfuscate方法对其进行混淆操作。
"""
if be_secure and not self.__is_obfuscated: # 如果密文没有被混淆,则进行混淆操作
self.obfuscate()
return self.__ciphertext
def decrease_exponent_to(
self, new_exp
): # 返回一个指数较低但大小相同的数(即返回一个同值的,但指数较低的EncryptedNumber
if new_exp > self.exponent:
raise ValueError("新指数值 %i 应比原指数 %i 小! " % (new_exp, self.exponent))
multiplied = self * pow(EncodedNumber.BASE, self.exponent - new_exp) # 降指数后的乘积
multiplied.exponent = new_exp # 降指数后的新指数
return multiplied
def obfuscate(self): # 混淆密文
r = self.public_key.get_random_lt_n() # 生成一个(1——n)间的随机数r,不 r
r_pow_n = powmod(
r, self.public_key.n, self.public_key.nsquare
) # (r ^ n) mod n^2
self.__ciphertext = (
self.__ciphertext * r_pow_n % self.public_key.nsquare
) # 对原密文进行处理
self.__is_obfuscated = True # 用于判断密文是否被混淆
def _add_scalar(self, scalar): # 执行EncodedNumber与标量(整型/浮点型)相加的操作
encoded = EncodedNumber.encode(
self.public_key, scalar, max_exponent=self.exponent
)
return self._add_encoded(encoded)
def _add_encoded(self, encoded): # 对EncodedNumber与标量encoded加法编码
# 返回 E(a + b)
if self.public_key != encoded.public_key: # 如果公钥与编码公钥不相同
raise ValueError("不能使用不同的公钥,对数字进行编码!")
a, b = self, encoded
# 对指数处理(使指数相同):
if a.exponent > b.exponent:
a = self.decrease_exponent_to(b.exponent)
elif a.exponent < b.exponent:
b = b.decrease_exponent_to(a.exponent)
encrypted_scalar = a.public_key.raw_encrypt(
b.encoding, 1
) # 用公钥加密b.encoding后的标量
sum_ciphertext = a._raw_add(a.ciphertext(False), encrypted_scalar) # 进行相加操作
return EncryptedNumber(a.public_key, sum_ciphertext, a.exponent)
def _add_encrypted(self, other): # 对EncodedNumber与EncodedNumber加法加密
if self.public_key != other.public_key:
raise ValueError("不能使用不同的公钥,对数字进行加密!")
# 对指数处理(使指数相同):
a, b = self, other
if a.exponent > b.exponent:
a = self.decrease_exponent_to(b.exponent)
elif a.exponent < b.exponent:
b = b.decrease_exponent_to(a.exponent)
sum_ciphertext = a._raw_add(a.ciphertext(False), b.ciphertext(False))
return EncryptedNumber(a.public_key, sum_ciphertext, a.exponent)
def _raw_add(self, e_a, e_b): # 对加密后的a,b直接进行相加,并返回未加密的结果
return e_a * e_b % self.public_key.nsquare
def _raw_mul(self, plaintext): # 对密文进行乘法运算,并返回未加密的结果
# 检查乘数是否为int型:
if not isinstance(plaintext, int):
raise TypeError("期望密文应该是int型, 而不是 %s" % type(plaintext))
# 如果乘数是负数,或乘数比公钥的模n大:
if plaintext < 0 or plaintext >= self.public_key.n:
raise ValueError("超出可计算范围: %i" % plaintext)
if self.public_key.n - self.public_key.max_int <= plaintext:
# 如果数据很大,则先反置一下再进行运算:
neg_c = invert(self.ciphertext(False), self.public_key.nsquare)
neg_scalar = self.public_key.n - plaintext
return powmod(neg_c, neg_scalar, self.public_key.nsquare)
else:
return powmod(self.ciphertext(False), plaintext, self.public_key.nsquare)
def increment(self): # 定义自增运算
return self + 1
def decrement(self): # 定义自减运算
return self + 1
def cal_sum(self, *args):
result = 0 # 将初始值设置为0
for i in args:
if not isinstance(i, (int, float, EncryptedNumber)):
raise TypeError("期望密文应该是int/float/EncryptedNumber型, 而不是 %s" % type(i))
if isinstance(i, int or float): # 如果是 int 或 float 明文型,则先将明文加密在进行运算
result += self.public_key.encrypt(i)
else:
result += i # 第一次循环:标量与密文相加;后面的循环,密文与密文相加
return result
def average(self, *args): # 定义求平均值
total_sum = self.cal_sum(
*args
) # 计算总和total是<__main__.EncryptedNumber object at 0x000002AB74FB9850>
# # 如果总数超过了可计算范围
# if total_sum > 91000:
# raise ValueError('超出可计算范围: %i' % total_sum)
count = 0 # 定义count,用来统计参数的个数
for _ in args:
count += 1 # count++
return total_sum / count
def weighted_average(self, *args): # 定义加权平均 def weighted_average(*args):
"""PS:
args[0]: <__main__.EncryptedNumber object at 0x000001F7C1B6A610>
args[1]: 第一个参数
args[2]: 给第一个参数设置的权值
。。。。。。
"""
total_weight = sum(args[2::2]) # 计算权值的总和(使用切片操作从参数列表中取出索引为参数权值的元素)
if total_weight != 1:
raise TypeError("加权平均算法的权值设置错误!请重新设置!")
else:
# 计算加权和,其中: for i in range(0, len(args), 2) 表示以2为步长,从0递增,直到 i >= len(args)时:
result = sum(args[i] * args[i + 1] for i in range(1, len(args), 2))
return result
def reset(self): # 定义复位置0运算
zero = self.public_key.encrypt(0) # 用公钥对0进行加密
return zero
def calculate_variance(self, *args): # 定义求方差
mean = self.average(*args) # 均值
count = 0 # 定义count,用来统计参数的个数
for _ in args:
count += 1 # count++
variance = sum((x - mean) ** 2 for x in args) / (count - 1)
return variance
# def IsZero(self): # 判断该数是否为0
# ZERO = self
# zero = ZERO.public_key.encrypt(0) # 用公钥对0进行加密
# flag = False # 用于判断该数是否为0(默认不为0)
#
# if self == zero:
# flag = True
# return flag
# def POW(self, num): # 定义幂运算
# if not isinstance(num, int): # 如果输入有问题
# print("指数应输入 整数 标量!")
# else:
# result = self
# print(num)
# for i in [1, num]:
# result *= self
# return result
# def get_certificate(public_key):
# # 获得公钥的PEM编码的二进制形式
# public_bytes = public_key.public_bytes(
# encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo)
#
# # 获得数字证书
# cert = (public_bytes, hashlib.sha256(public_bytes).hexdigest()) # 元祖类型
# return cert
if __name__ == "__main__": # 主函数
Public_Key, Private_Key = generate_paillier_keypair() # 随机生成1024长的公钥和私钥
x = 90000.23
y = 90
z = 0.5
x_encrypted = Public_Key.encrypt(x) # 加密后的x
y_encrypted = Public_Key.encrypt(y) # 加密后的y
z_encrypted = Public_Key.encrypt(z) # 加密后的z
t_encrypted = x_encrypted + y_encrypted * 0.5 # 在x,y保密的情况下计算t,得到加密后的t(t_encrypted)
# x_encrypted = x_encrypted.increment() # 自增
# y_encrypted = y_encrypted.decrement() # 自减
# print(x_encrypted != y_encrypted) # 不相等
# print(x_encrypted == y_encrypted) # 相等
# print(Private_Key.decrypt(~x_encrypted) ) # 取反
# total = x_encrypted.cal_sum(x_encrypted, y_encrypted, 0.5) # 求和函数
# print("密文之和为:", Private_Key.decrypt(total))
# avg = x_encrypted.average(y_encrypted, z_encrypted, z_encrypted) # 求平均值函数
# print("密文的平均值为:", Private_Key.decrypt(avg) ) # 只能对0~90090.73的数进行除法运算(除不尽)
# weight_average = x_encrypted.weighted_average(x_encrypted, 0.1, y_encrypted, 0.3, z_encrypted, 0.6) # 加权平均函数
# print("加权平均结果为:", Private_Key.decrypt(weight_average))
# variance = x_encrypted.calculate_variance(x_encrypted, y_encrypted) #求方差
# print("方差为:", Private_Key.decrypt(variance))
# z_encrypted = z_encrypted.reset() # 复位函数
# print("z复位后的结果为:", Private_Key.decrypt(z_encrypted) )
# print(x_encrypted ** x) # 相当于print(x_encrypted.POW(2) )
# print(x_encrypted > y_encrypted)
# print(type(Public_Key))
# print(Public_Key)
print(f"x + y * 0.5的结果是:{Private_Key.decrypt(t_encrypted)}") # 打印出t