概述：MySQL 8 的密码加密原理与连接过程说明
参考文章：《MySQL caching_sha2_password 加密原理和连接过程（FULL）》（腾讯云开发者社区）
使用 Python 连接 MySQL
RSA 加密部分直接使用了 pymysql 的实现，因为它不属于本文讨论的内容。
import hashlib
import struct
import socket
import os
#来自pymysql
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
def btoint(bdata, t='little'):
    """Return the unsigned integer encoded by the bytes *bdata*.

    *t* selects the byte order: 'little' (the default) or 'big'.
    """
    value = int.from_bytes(bdata, byteorder=t)
    return value
def _lenenc_int(i):
    """Encode a non-negative int as a MySQL length-encoded integer.

    Adapted from PyMySQL. Values below 0xFB are a single byte. Larger values
    are a marker byte (0xFC, 0xFD or 0xFE) followed by 2, 3 or 8
    little-endian bytes.

    Raises:
        ValueError: if ``i`` is negative or not below 2**64.
    """
    if i < 0:
        raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
    if i < 0xFB:
        return bytes([i])
    # (exclusive upper bound, marker byte, number of value bytes)
    encodings = (
        (1 << 16, b"\xfc", 2),
        (1 << 24, b"\xfd", 3),
        (1 << 64, b"\xfe", 8),
    )
    for bound, marker, width in encodings:
        if i < bound:
            # The low `width` bytes of a little-endian u64 equal the narrower encoding.
            return marker + struct.pack("<Q", i)[:width]
    raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
def _xor_password(password, salt):
    """XOR *password* byte-wise with *salt* (adapted from PyMySQL).

    Only the first 20 bytes of *salt* are used, repeated cyclically over the
    length of *password*. The result is bytes of the same length as
    *password*.
    """
    key = bytes(salt[:20])
    return bytes(
        byte ^ key[pos % len(key)]
        for pos, byte in enumerate(bytearray(password))
    )
def sha2_rsa_encrypt(password, salt, public_key):
    """RSA-encrypt the password for caching_sha2_password full auth.

    Adapted from PyMySQL. The NUL-terminated *password* is XORed with
    *salt*, then encrypted with the PEM-encoded *public_key* using OAEP
    padding (SHA-1 digest, MGF1 with SHA-1, no label).
    """
    plaintext = _xor_password(password + b"\0", salt)
    server_key = serialization.load_pem_public_key(public_key, default_backend())
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA1()),
        algorithm=hashes.SHA1(),
        label=None,
    )
    return server_key.encrypt(plaintext, oaep)
def native_password(password, salt):
    """Compute the mysql_native_password auth response.

    response = SHA1(password) XOR SHA1(salt + SHA1(SHA1(password)))

    :param password: password as bytes
    :param salt: server scramble (callers in this file pass its first 20 bytes)
    :return: 20-byte bytearray. An empty password yields an empty bytearray,
             matching libmysqlclient/PyMySQL, which send an empty auth
             response instead of a scramble of the empty string.
    """
    if not password:
        return bytearray()
    stage1 = hashlib.sha1(password).digest()
    stage2 = hashlib.sha1(stage1).digest()
    digest = hashlib.sha1(salt)
    digest.update(stage2)
    return bytearray(m ^ s for m, s in zip(digest.digest(), stage1))
def sha2_password(password, salt):
    """Compute the caching_sha2_password fast-auth scramble.

    response = SHA256(password) XOR SHA256(SHA256(SHA256(password)) + salt)

    :param password: password as bytes
    :param salt: server nonce
    :return: 32-byte bytearray. An empty password yields an empty bytearray,
             matching libmysqlclient/PyMySQL, which send an empty response
             for an empty password.
    """
    if not password:
        return bytearray()
    stage1 = hashlib.sha256(password).digest()
    stage2 = hashlib.sha256(stage1).digest()
    stage3 = hashlib.sha256(stage2 + salt).digest()
    return bytearray(a ^ b for a, b in zip(stage3, stage1))
def parse_handshake(bdata):
    """Return the auth-plugin-data (salt) from a HandshakeV10 payload.

    Fields walked, in order: protocol_version:1, server_version
    (NUL-terminated), thread_id:4, auth-plugin-data-part-1:8, filler:1,
    capability_flags_1:2, character_set:1, status_flags:2,
    capability_flags_2:2, auth_plugin_data_len:1, reserved:10,
    auth-plugin-data-part-2:max(13, auth_plugin_data_len - 8).

    Only the salt is extracted. The other fields are skipped; the
    mysql.handshake method parses and keeps them.

    :param bdata: handshake payload without the 4-byte packet header
    :return: part-1 + part-2 bytes. Part 2 is returned exactly as sent,
             including any trailing NUL byte.
    :raises struct.error: if the payload is too short to hold the length byte
    """
    # Skip protocol_version and the NUL-terminated server_version.
    pos = bdata.find(b"\0") + 1
    pos += 4  # thread_id
    salt = bdata[pos:pos + 8]  # auth-plugin-data-part-1
    pos += 8 + 1  # part 1 + filler
    pos += 2 + 1 + 2 + 2  # capability_flags_1, character_set, status_flags, capability_flags_2
    (plugin_data_len,) = struct.unpack('<B', bdata[pos:pos + 1])
    # CLIENT_PLUGIN_AUTH is not checked: a zero length byte still reads 13 bytes.
    part2_len = max(13, plugin_data_len - 8)
    pos += 1 + 10  # length byte + reserved
    return salt + bdata[pos:pos + part2_len]
class mysql(object):
    """Minimal MySQL client that performs the connection handshake by hand.

    It illustrates the HandshakeV10 / HandshakeResponse41 exchange and the
    mysql_native_password / caching_sha2_password authentication flows.
    Progress is reported with print() rather than returned.
    """

    def __init__(self, host='192.168.101.21', port=3314, user='u1', password='123456'):
        """Store connection settings. No network I/O happens until connect().

        The defaults are the values this script was originally hard-coded with.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password

    def read_pack(self):
        """Read one packet from the connection and return its payload.

        Packet layout: payload_length:3 (little-endian), sequence_id:1, payload.
        The local sequence counter is set to the received sequence_id + 1, so
        the next write_pack() continues the server's numbering. Payloads split
        across several 16MB-1 packets are not reassembled.

        :raises ConnectionError: if the stream ends before a full packet is read
        """
        pack_header = self.rf.read(4)
        if len(pack_header) < 4:
            raise ConnectionError('connection closed while reading packet header')
        btrl, btrh, packet_seq = struct.unpack("<HBB", pack_header)
        pack_size = btrl + (btrh << 16)
        bdata = self.rf.read(pack_size)
        if len(bdata) < pack_size:
            raise ConnectionError('connection closed while reading packet payload')
        self._next_seq_id = (packet_seq + 1) % 256
        return bdata

    def write_pack(self, data):
        """Send *data* as one packet with the current sequence id, then advance it.

        :raises ValueError: if *data* is 0xFFFFFF bytes or longer. Such payloads
            need multi-packet framing, which is not implemented, and the 3-byte
            length field would otherwise be silently truncated.
        """
        if len(data) >= 0xFFFFFF:
            raise ValueError('payload of %d bytes needs multi-packet framing, which is not supported' % len(data))
        header = struct.pack("<I", len(data))[:3] + bytes([self._next_seq_id])
        self.sock.sendall(header + data)
        self._next_seq_id = (self._next_seq_id + 1) % 256

    def handshake(self, bdata):
        """Parse the server's initial HandshakeV10 payload into attributes.

        Sets protocol_version, server_version (without the protocol byte or
        the NUL terminator), thread_id, salt (auth-plugin-data part 1 + part 2,
        part 2 kept exactly as sent), server_capabilities (both 16-bit halves
        combined), server_charset, server_status, and server_plugname (the raw
        remainder of the payload).

        :raises struct.error: if the payload is truncated
        """
        self.protocol_version = bdata[0]  # only protocol 10 is handled here
        # server_version is NUL-terminated and starts right after the protocol byte.
        server_end = bdata.find(b"\0", 1)
        self.server_version = bdata[1:server_end]
        i = server_end + 1
        (self.thread_id,) = struct.unpack_from('<I', bdata, i)
        i += 4
        self.salt = bdata[i:i + 8]  # auth-plugin-data-part-1
        i += 8 + 1  # plus a 1-byte filler that is not kept
        (cap_low, self.server_charset, self.server_status,
         cap_high, plugin_data_len) = struct.unpack_from('<HBHHB', bdata, i)
        self.server_capabilities = cap_low | (cap_high << 16)
        i += 8
        # CLIENT_PLUGIN_AUTH is not checked: a zero length byte still reads 13 bytes.
        salt_length = max(13, plugin_data_len - 8)  # 8 bytes were already read above
        i += 10  # reserved
        self.salt += bdata[i:i + salt_length]
        i += salt_length
        self.server_plugname = bdata[i:]

    def HandshakeResponse41(self):
        """Send HandshakeResponse41 and drive the authentication exchange.

        The initial auth response uses mysql_native_password. If the server
        answers with an AuthSwitchRequest for caching_sha2_password, the
        scramble is recomputed with the nonce it supplies.

        :return: False if the server switches to any other plugin, otherwise None
        """
        client_flag = 33531525  # CLIENT_CONNECT_WITH_DB not set, so no schema is sent
        charset_id = 45  # 45: utf8mb4, 33: utf8
        # client_flag:4, max_packet_size:4, character_set:1, filler:23
        bdata = struct.pack('<IIB23s', client_flag, 2**24 - 1, charset_id, b'')
        bdata += self.user.encode() + b'\0'
        auth_password = native_password(self.password.encode(), self.salt[:20])
        bdata += _lenenc_int(len(auth_password)) + auth_password
        bdata += b"mysql_native_password" + b'\0'
        # Connection attributes are sent mainly to make the session easy to spot.
        attr = {'_client_name': 'ddcw_for_pymysql', '_pid': str(os.getpid()), "_client_version": '0.0.1'}
        # Each key and each value is length-encoded.
        connect_attrs = b""
        for key, value in attr.items():
            for item in (key.encode(), value.encode()):
                connect_attrs += _lenenc_int(len(item)) + item
        bdata += _lenenc_int(len(connect_attrs)) + connect_attrs
        self.write_pack(bdata)

        # First byte of the reply:
        #   0xFE auth switch request
        #        https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_old_auth_switch_request.html
        #   0x01 extra auth data
        #   0x00 OK
        auth_pack = self.read_pack()
        status = auth_pack[:1]
        if status == b'\0':
            print('OK', auth_pack)
        elif status == b'\xfe':
            print('hava switch request')
            # Layout: 0xFE, plugin name (NUL-terminated), plugin data.
            name_end = auth_pack.find(b'\x00', 1)
            plugin_name = auth_pack[1:name_end]
            if plugin_name != b'caching_sha2_password':
                print('仅测试caching_sha2_password, 但当前是:', plugin_name)
                return False
            # The plugin data is the 20-byte nonce plus a NUL terminator.
            # libmysqlclient scrambles over the 20 nonce bytes only. The nonce
            # is also kept for a possible RSA full-auth step.
            self.salt = auth_pack[name_end + 1:][:20]
            scrambled = sha2_password(self.password.encode(), self.salt)
            self.write_pack(scrambled)
            auth_pack = self.read_pack()
            print(auth_pack)
            # Dispatch on the first byte so an OK/ERR packet is never read as a
            # fast/full-auth marker.
            if auth_pack[:1] == b'\x01':
                self.caching_sha2_password_auth(auth_pack)
            elif auth_pack[:1] == b'\0':
                print('OK', auth_pack)
            else:
                print('FAILED', auth_pack)
        elif status == b'\x01':
            self.caching_sha2_password_auth(auth_pack)
        else:
            print('FAILED', auth_pack)

    def caching_sha2_password_auth(self, auth_pack):
        """Finish caching_sha2_password authentication.

        *auth_pack* is an extra-auth-data packet whose second byte is 0x03
        (fast auth succeeded, an OK packet follows) or 0x04 (full auth needed:
        fetch the server's RSA public key and send the RSA-encrypted password).
        """
        method = auth_pack[1:2]
        if method == b'\x03':  # fast auth
            bdata = self.read_pack()  # OK packet
            print('fast auth success.', bdata)
        elif method == b'\x04':  # full auth
            # TODO: over SSL / unix socket / shared memory the password could be sent without RSA.
            self.write_pack(b'\x02')  # request the public key
            bdata = self.read_pack()
            if bdata[:1] != b'\x01':
                # Not extra-auth-data (e.g. an error packet): no key to load.
                print('FAILED', bdata)
                return
            self.pubk = bdata[1:]  # the PEM key follows the 0x01 marker
            password = sha2_rsa_encrypt(self.password.encode(), self.salt, self.pubk)
            self.write_pack(password)
            authpack = self.read_pack()  # OK or error
            print('full auth', authpack)
        else:
            print('???', auth_pack)

    def query(self, sql):
        """Send *sql* as a COM_QUERY command. The response is not read.

        Payload: command byte 0x03 (COM_QUERY) followed by the encoded SQL.
        Each command starts a new sequence, so the packet goes out with
        sequence_id 0. SQL of 16MB or more is rejected by write_pack with
        ValueError.
        """
        # The length header must count encoded bytes, not characters, or
        # non-ASCII SQL is framed incorrectly.
        payload = b'\x03' + sql.encode()
        self._next_seq_id = 0
        self.write_pack(payload)

    def connect(self):
        """Open the TCP connection, parse the server greeting and authenticate."""
        self._next_seq_id = 0
        sock = socket.create_connection((self.host, self.port))
        sock.settimeout(None)
        self.sock = sock
        self.rf = sock.makefile("rb")
        try:
            self.handshake(self.read_pack())
            self.HandshakeResponse41()
        except BaseException:
            # Do not leak the socket if the handshake fails.
            self.rf.close()
            sock.close()
            raise