【MySQL】caching_sha2_password 加密原理和连接过程

概述:说明 MySQL 8 默认认证插件 caching_sha2_password 的加密原理,以及客户端与服务端的连接(认证)过程。

参考文章:MYSQL caching_sha2_password 加密原理和连接过程(FULL)-腾讯云开发者社区-腾讯云

使用Python连接MySQL

下文中 RSA 加密相关的代码直接取自 pymysql 的实现,因为 RSA 本身不属于本文的讨论范围。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import hashlib
import struct
import socket
import os

#来自pymysql
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding

def btoint(bdata, t='little'):
    """Decode *bdata* as an unsigned integer.

    :param bdata: raw bytes taken from a MySQL packet.
    :param t: byte order passed to ``int.from_bytes`` ('little' by default,
        which is how the MySQL wire protocol stores fixed-length integers).
    :return: the decoded non-negative int (0 for empty input).
    """
    value = int.from_bytes(bdata, byteorder=t)
    return value

#来自pymysql
def _lenenc_int(i):
    """Encode *i* as a MySQL length-encoded integer.

    Values 0..250 occupy a single byte.  Larger values get a one-byte
    prefix (0xFC, 0xFD or 0xFE) followed by a 2-, 3- or 8-byte
    little-endian integer.

    Raises:
        ValueError: *i* is negative or does not fit in 64 bits.
    """
    if i < 0:
        raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
    if i < 0xFB:
        return bytes((i,))
    if i < 0x10000:
        return b"\xfc" + struct.pack("<H", i)
    if i < 0x1000000:
        # 0xFD carries a 3-byte integer: pack 4 bytes and drop the top one.
        return b"\xfd" + struct.pack("<I", i)[:3]
    if i < 0x10000000000000000:
        return b"\xfe" + struct.pack("<Q", i)
    raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, 1 << 64))

#就是做个异或
#来自Pymysql
def _xor_password(password, salt):
    """XOR *password* byte-by-byte with the first 20 bytes of *salt*.

    The (truncated) salt is repeated cyclically when the password is longer
    than it.

    :return: the XORed bytes, same length as *password*.
    """
    key = bytearray(salt[:20])
    key_len = len(key)
    return bytes(byte ^ key[pos % key_len] for pos, byte in enumerate(bytearray(password)))

#来自pymysql
def sha2_rsa_encrypt(password, salt, public_key):
    """Encrypt the password for caching_sha2_password full authentication.

    The NUL-terminated password is XORed with the scramble (see
    ``_xor_password``), then RSA-encrypted with the server's PEM public key
    using OAEP padding (SHA-1 for both the MGF1 mask and the hash).

    :param password: password bytes (without terminator).
    :param salt: server scramble; only its first 20 bytes are used.
    :param public_key: PEM-encoded RSA public key sent by the server.
    :return: the ciphertext bytes to send to the server.
    """
    plaintext = _xor_password(password + b"\0", salt)
    server_key = serialization.load_pem_public_key(public_key, default_backend())
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA1()),
        algorithm=hashes.SHA1(),
        label=None,
    )
    return server_key.encrypt(plaintext, oaep)

def native_password(password, salt):
    """Compute the mysql_native_password authentication response.

    response = SHA1(password) XOR SHA1(salt + SHA1(SHA1(password)))

    :param password: password bytes.
    :param salt: the server scramble (the caller in this file passes the
        first 20 bytes of the handshake salt).
    :return: a 20-byte bytearray, or an empty bytearray when *password* is
        empty -- a passwordless account must receive an empty auth response,
        not a scramble of the empty string.
    """
    if not password:
        return bytearray()
    stage1 = hashlib.sha1(password).digest()
    stage2 = hashlib.sha1(stage1).digest()

    rp = hashlib.sha1(salt)
    rp.update(stage2)
    stage3 = rp.digest()

    return bytearray(a ^ b for a, b in zip(stage3, stage1))

def sha2_password(password, salt):
    """Compute the caching_sha2_password fast-auth scramble.

    response = SHA256(password) XOR SHA256(SHA256(SHA256(password)) + salt)

    :param password: password bytes.
    :param salt: server nonce bytes, hashed exactly as given.
    :return: a 32-byte bytearray, or an empty bytearray when *password* is
        empty (an empty response is how a client signals "no password").
    """
    if not password:
        return bytearray()
    stage1 = hashlib.sha256(password).digest()
    stage2 = hashlib.sha256(stage1).digest()
    stage3 = hashlib.sha256(stage2 + salt).digest()

    return bytearray(a ^ b for a, b in zip(stage3, stage1))

def parse_handshake(bdata):
    """Return the auth-plugin-data (salt) from an initial Handshake v10 packet.

    The salt is part 1 (8 bytes) followed by part 2, whose length is read as
    max(13, auth_plugin_data_len - 8).  Part 2's last byte is kept as-is
    (normally a terminating NUL), so callers should use ``salt[:20]`` as the
    scramble.  Other handshake fields are skipped, not decoded.

    Raises:
        struct.error: the packet is too short to contain the length byte.
    """
    # protocol_version (1 byte, 0x0a for protocol 10) + NUL-terminated server_version
    i = bdata.find(b"\0") + 1
    i += 4  # connection (thread) id
    salt = bdata[i:i + 8]  # auth-plugin-data part 1
    i += 8 + 1  # part 1 + 1-byte filler
    i += 2 + 1 + 2 + 2  # capability_flags_1, character_set, status_flags, capability_flags_2
    salt_length = max(13, struct.unpack_from('<B', bdata, i)[0] - 8)
    i += 1 + 10  # auth_plugin_data_len + 10 reserved bytes
    return salt + bdata[i:i + salt_length]


class mysql(object):
    """Minimal MySQL client that walks through the connection phase.

    ``connect()`` reads the server's initial handshake, sends a
    HandshakeResponse41 offering ``mysql_native_password`` and then follows
    an auth switch into ``caching_sha2_password`` fast or full
    authentication.  Progress is printed.  ``query()`` only sends COM_QUERY;
    it does not read the result.
    """

    def __init__(self, host='192.168.101.21', port=3314, user='u1', password='123456'):
        # The defaults are the author's test server; pass your own values to reuse the class.
        self.host = host
        self.port = port
        self.user = user
        self.password = password

    def read_pack(self,):
        """Read one packet from ``self.rf`` and return its payload bytes.

        The 4-byte header is a 3-byte little-endian payload length plus a
        1-byte sequence id.  The local sequence counter is advanced; the
        received sequence id is not validated.

        Raises:
            ConnectionError: the stream ended before a whole packet arrived.
        """
        pack_header = self.rf.read(4)
        if len(pack_header) < 4:
            raise ConnectionError('connection closed while reading packet header')
        btrl, btrh, packet_seq = struct.unpack("<HBB", pack_header)
        pack_size = btrl + (btrh << 16)
        self._next_seq_id = (self._next_seq_id + 1) % 256
        bdata = self.rf.read(pack_size)
        if len(bdata) < pack_size:
            raise ConnectionError('connection closed while reading packet payload')
        return bdata

    def write_pack(self, data):
        """Frame *data* with a packet header and send it on ``self.sock``.

        Payloads of 16MB or more would have to be split; that is not done here.
        """
        header = struct.pack("<I", len(data))[:3] + bytes([self._next_seq_id])
        self.sock.sendall(header + data)
        self._next_seq_id = (self._next_seq_id + 1) % 256

    def handshake(self, bdata):
        """Parse the server's initial Handshake v10 packet into attributes.

        Sets server_version, thread_id, salt, server_capabilities,
        server_charset, server_status and server_plugname.  ``salt`` is
        auth-plugin-data part 1 (8 bytes) + part 2 (at least 13 bytes, usually
        ending in a NUL), so consumers use ``salt[:20]``.
        """
        # Byte 0 is protocol_version (only protocol 10 is handled); the server
        # version string starts right after it, so it must not include that byte.
        i = 1

        server_end = bdata.find(b"\0", i)
        self.server_version = bdata[i:server_end]
        i = server_end + 1

        self.thread_id = btoint(bdata[i:i+4])
        i += 4

        self.salt = bdata[i:i+8]  # auth-plugin-data part 1
        i += 9  # plus a 1-byte filler that is not kept

        self.server_capabilities = btoint(bdata[i:i+2])  # lower 16 capability bits
        i += 2

        self.server_charset = btoint(bdata[i:i+1])
        i += 1

        self.server_status = btoint(bdata[i:i+2])
        i += 2

        self.server_capabilities |= btoint(bdata[i:i+2]) << 16  # upper 16 capability bits
        i += 2

        # auth_plugin_data_len; CLIENT_PLUGIN_AUTH is not checked before trusting it.
        salt_length = struct.unpack('<B', bdata[i:i+1])[0]
        salt_length = max(13, salt_length-8)  # 8 bytes of the salt were already read
        i += 1

        i += 10  # reserved

        self.salt += bdata[i:i+salt_length]
        i += salt_length

        self.server_plugname = bdata[i:]

    def HandshakeResponse41(self,):
        """Send HandshakeResponse41 and drive authentication to its end.

        The response offers mysql_native_password.  If the server answers with
        an AuthSwitchRequest for caching_sha2_password, the fast path is tried
        and, when the server asks for it, full authentication follows.

        Returns:
            False if the server switches to any other plugin, otherwise None.
            The outcome itself is printed, not returned.
        """
        # Capability flags chosen by the author; no database name is sent.
        client_flag = 33531525
        charset_id = 45  # 45: utf8mb4, 33: utf8

        # client_flag(4) max_packet_size(4) character_set(1) filler(23)
        bdata = struct.pack('<iIB23s', client_flag, 2**24 - 1, charset_id, b'')
        bdata += self.user.encode() + b'\0'

        # self.salt may carry a trailing NUL (see handshake); the scramble is its first 20 bytes.
        auth_password = native_password(self.password.encode(), self.salt[:20])
        bdata += _lenenc_int(len(auth_password)) + auth_password
        bdata += b"mysql_native_password" + b'\0'

        # Connection attributes are sent only to make this client easy to spot on the server.
        attr = {'_client_name': 'ddcw_for_pymysql', '_pid': str(os.getpid()), "_client_version": '0.0.1',}
        # Each attribute is encoded as lenenc(key) key lenenc(value) value.
        connect_attrs = b"".join(
            _lenenc_int(len(part)) + part
            for k, v in attr.items()
            for part in (k.encode(), v.encode())
        )
        bdata += _lenenc_int(len(connect_attrs)) + connect_attrs
        self.write_pack(bdata)

        # First byte of the reply: 0x00 OK, 0xFE AuthSwitchRequest, 0x01 extra auth data.
        # The negotiated client_flag is not re-checked here.
        auth_pack = self.read_pack()
        if auth_pack[:1] == b'\0':
            print('OK', auth_pack)
        elif auth_pack[:1] == b'\xfe':
            print('hava switch request')
            # AuthSwitchRequest: 0xFE, NUL-terminated plugin name, plugin data.
            name_end = auth_pack.find(b'\x00')
            plugin_name = auth_pack[1:name_end]
            if plugin_name != b'caching_sha2_password':
                print('仅测试caching_sha2_password, 但当前是:', plugin_name)
                return False
            # The plugin data is the 20-byte scramble, possibly followed by a NUL.
            # Hashing that NUL would make the fast-auth scramble never match.
            # It is stored as self.salt so a following full authentication uses it too.
            self.salt = auth_pack[name_end + 1:][:20]
            scrambled = sha2_password(self.password.encode(), self.salt)
            self.write_pack(scrambled)
            auth_pack = self.read_pack()
            print(auth_pack)
            self.caching_sha2_password_auth(auth_pack)
        elif auth_pack[:1] == b'\x01':
            self.caching_sha2_password_auth(auth_pack)
        else:
            print('FAILED', auth_pack)

    def caching_sha2_password_auth(self, auth_pack):
        """Finish caching_sha2_password after an extra-auth-data (0x01) packet.

        ``auth_pack[1]`` is 0x03 when fast auth succeeded (a final packet
        follows) or 0x04 when full authentication is required.  In the full
        case the server's RSA public key is requested and the password is
        sent RSA-encrypted.  Results are printed.
        """
        status = auth_pack[1:2]
        if status == b'\x03':  # fast auth succeeded
            bdata = self.read_pack()  # the OK packet
            print('fast auth success.', bdata)
        elif status == b'\x04':  # full authentication required
            # TODO: over SSL / unix socket / shared memory the password could be sent without RSA.
            self.write_pack(b'\x02')  # ask for the server's public key
            bdata = self.read_pack()
            # The first byte is the 0x01 extra-auth-data marker; the rest is the PEM key.
            self.pubk = bdata[1:]
            password = sha2_rsa_encrypt(self.password.encode(), self.salt, self.pubk)
            self.write_pack(password)
            authpack = self.read_pack()  # final result packet
            print('full auth', authpack)
        else:
            print('???', auth_pack)

    def query(self, sql):
        """Send *sql* as a COM_QUERY packet; the result set is not read.

        The packet length is the byte length of the encoded statement, so
        non-ASCII SQL is framed correctly.  Every command restarts the
        sequence id at 0, so the next packet read is expected to be seq 1.

        Raises:
            ValueError: the payload does not fit in a single packet (16MB);
                splitting into multiple packets is not implemented.
        """
        payload = b'\x03' + sql.encode()  # 0x03 = COM_QUERY
        if len(payload) >= 0xFFFFFF:
            raise ValueError('SQL too large for a single packet (16MB); splitting is not supported')
        self._next_seq_id = 0
        self.write_pack(payload)  # leaves _next_seq_id at 1

    def connect(self):
        """Open the TCP connection and run the handshake and authentication.

        The socket is closed again if any step of the connection phase raises.
        """
        self._next_seq_id = 0
        sock = socket.create_connection((self.host, self.port))
        sock.settimeout(None)
        self.sock = sock
        self.rf = sock.makefile("rb")
        try:
            bdata = self.read_pack()
            self.handshake(bdata)
            self.HandshakeResponse41()
        except BaseException:
            self.close()
            raise

    def close(self):
        """Close the file wrapper and socket opened by connect(); safe to call twice."""
        for name in ('rf', 'sock'):
            obj = getattr(self, name, None)
            if obj is not None:
                obj.close()
                setattr(self, name, None)

【MySQL】caching_sha2_password 加密原理和连接过程
https://hodlyounger.github.io/C_OpenSource/MySQL/【MySQL】加密原理/
作者
mingming
发布于
2023年10月27日
许可协议