_auth.py
7.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
Implements auth methods
"""
from ._compat import text_type, PY2
from .constants import CLIENT
from .err import OperationalError
from .util import byte2int, int2byte
try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
_have_cryptography = True
except ImportError:
_have_cryptography = False
from functools import partial
import hashlib
import io
import struct
import warnings
# When true, the auth helpers in this module print progress messages
# (and received public keys) to stdout.
DEBUG = False
# Number of leading bytes of the server message that are fed into the
# mysql_native_password scramble.
SCRAMBLE_LENGTH = 20
# Factory for SHA-1 hash objects: sha1_new(data) == hashlib.new('sha1', data).
sha1_new = partial(hashlib.new, 'sha1')
# mysql_native_password
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
def scramble_native_password(password, message):
    """Scramble used for mysql_native_password.

    The result is SHA1(password) XORed with
    SHA1(message[:SCRAMBLE_LENGTH] + SHA1(SHA1(password))).

    :param password: password as a byte string.
    :param message: scramble message sent by the server; only the first
        ``SCRAMBLE_LENGTH`` bytes are used.
    :return: the scrambled password, or ``b''`` when ``password`` is empty.
    """
    if password:
        inner = sha1_new(password).digest()
        outer = sha1_new(inner).digest()

        hasher = sha1_new()
        for part in (message[:SCRAMBLE_LENGTH], outer):
            hasher.update(part)

        return _my_crypt(hasher.digest(), inner)
    return b''
def _my_crypt(message1, message2):
    """Return ``message1`` XORed byte-by-byte with ``message2``.

    Both arguments are converted with ``bytearray()``, which yields integer
    items on Python 2 and Python 3 alike, so no interpreter-specific branch
    is needed.

    :param message1: bytes-like value; its length fixes the result length.
    :param message2: bytes-like mask, at least as long as ``message1``.
    :return: ``bytes`` of length ``len(message1)``.
    :raises IndexError: if ``message2`` is shorter than ``message1``.
    """
    mask = bytearray(message2)
    mixed = bytearray(message1)
    # Index into the mask (instead of zip) so a too-short mask still fails
    # loudly rather than silently truncating the result.
    for pos, value in enumerate(mixed):
        mixed[pos] = value ^ mask[pos]
    return bytes(mixed)
# old_passwords support ported from libmysql/password.c
# https://dev.mysql.com/doc/internals/en/old-password-authentication.html
# Number of leading bytes of the server message used by the old-password
# scramble; also the upper bound on the scramble's output length.
SCRAMBLE_LENGTH_323 = 8
class RandStruct_323(object):
    """Deterministic pseudo-random generator driven by two integer seeds.

    Both seeds are kept reduced modulo ``max_value`` (0x3FFFFFFF).
    """

    def __init__(self, seed1, seed2):
        """Store the modulus and the two seeds reduced by it."""
        modulus = 0x3FFFFFFF
        self.max_value = modulus
        self.seed1 = seed1 % modulus
        self.seed2 = seed2 % modulus

    def my_rnd(self):
        """Advance both seeds and return ``seed1 / max_value`` as a float."""
        modulus = self.max_value
        next_seed1 = (self.seed1 * 3 + self.seed2) % modulus
        # The second seed is derived from the already-updated first seed.
        next_seed2 = (next_seed1 + self.seed2 + 33) % modulus
        self.seed1, self.seed2 = next_seed1, next_seed2
        return float(next_seed1) / float(modulus)
def scramble_old_password(password, message):
    """Scramble for old_password.

    Emits a warning on every call, because this legacy method is slated
    for removal.

    :param password: password as a byte string.
    :param message: scramble message from the server; only the first
        ``SCRAMBLE_LENGTH_323`` bytes are hashed.
    :return: scrambled bytes, ``min(SCRAMBLE_LENGTH_323, len(message))`` long.
    """
    warnings.warn("old password (for MySQL <4.1) is used. Upgrade your password with newer auth method.\n"
                  "old password support will be removed in future PyMySQL version")

    pass_digest = _hash_password_323(password)
    msg_digest = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
    pass_hi, pass_lo = struct.unpack(">LL", pass_digest)
    msg_hi, msg_lo = struct.unpack(">LL", msg_digest)

    # The generator is seeded with the XOR of the two hashes.
    rng = RandStruct_323(pass_hi ^ msg_hi, pass_lo ^ msg_lo)

    count = min(SCRAMBLE_LENGTH_323, len(message))
    raw = b''.join(
        int2byte(int(rng.my_rnd() * 31) + 64) for _ in range(count)
    )

    # One more random draw gives the byte every output byte is XORed with.
    mask = byte2int(int2byte(int(rng.my_rnd() * 31)))
    return b''.join(int2byte(byte2int(ch) ^ mask) for ch in raw)
def _hash_password_323(password):
    """Hash ``password`` for the old-password scramble.

    Spaces and tabs in the input are skipped.

    :param password: byte string to hash.
    :return: 8 bytes: two big-endian unsigned 32-bit integers, each with
        its top bit cleared.
    """
    mask32 = 0xFFFFFFFF
    mask31 = (1 << 31) - 1  # kill sign bits
    nr, nr2, add = 1345345333, 0x12345671, 7

    # Items are ints on py3 and one-char strings on py2, so both spellings
    # of space and tab are listed.
    skipped = (' ', '\t', 32, 9)
    codes = [byte2int(item) for item in password if item not in skipped]

    for code in codes:
        nr ^= ((((nr & 63) + add) * code) + (nr << 8)) & mask32
        nr2 = (nr2 + ((nr2 << 8) ^ nr)) & mask32
        add = (add + code) & mask32

    return struct.pack(">LL", nr & mask31, nr2 & mask31)
# sha256_password
def _roundtrip(conn, send_data):
    """Send ``send_data`` as one packet and return the checked reply.

    :param conn: connection providing ``write_packet`` and ``_read_packet``.
    :param send_data: payload handed to ``conn.write_packet``.
    :return: the packet read back, after its ``check_error()`` has been
        called.
    """
    conn.write_packet(send_data)
    reply = conn._read_packet()
    reply.check_error()
    return reply
def _xor_password(password, salt):
    """XOR ``password`` with ``salt``, repeating the salt as needed.

    :param password: bytes-like value to mask.
    :param salt: bytes-like key; it is cycled when shorter than ``password``.
    :return: ``bytes`` of the same length as ``password``.
    """
    masked = bytearray(password)
    key = bytearray(salt)  # bytearray yields ints on PY2 as well
    key_len = len(key)
    for pos, value in enumerate(masked):
        masked[pos] = value ^ key[pos % key_len]
    return bytes(masked)
def sha2_rsa_encrypt(password, salt, public_key):
    """Encrypt password with salt and public_key.

    Used for sha256_password and caching_sha2_password.

    The NUL-terminated password is XORed with ``salt`` and the result is
    RSA-encrypted using OAEP padding with SHA-1.

    :param password: password as a byte string.
    :param salt: byte string the password is XORed with.
    :param public_key: PEM-encoded public key.
    :return: the encrypted message.
    :raises RuntimeError: if the ``cryptography`` package could not be
        imported.
    """
    if not _have_cryptography:
        raise RuntimeError("cryptography is required for sha256_password or caching_sha2_password")

    masked = _xor_password(password + b'\0', salt)
    key = serialization.load_pem_public_key(public_key, default_backend())
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA1()),
        algorithm=hashes.SHA1(),
        label=None,
    )
    return key.encrypt(masked, oaep)
def sha256_password_auth(conn, pkt):
    """Run the sha256_password authentication exchange on ``conn``.

    On a secure connection the NUL-terminated password is sent as-is.
    Otherwise the password is RSA-encrypted with the server's public key,
    which is requested from the server first if it is not known yet.

    :param conn: connection; ``_secure``, ``password``, ``salt`` and
        ``server_public_key`` are read, and the latter two may be updated.
    :param pkt: packet from the server that triggered this exchange.
    :return: the server's reply to the final authentication payload.
    :raises OperationalError: if a password is set but no server public key
        is available.
    """
    if conn._secure:
        if DEBUG:
            print("sha256: Sending plain password")
        return _roundtrip(conn, conn.password + b'\0')

    if pkt.is_auth_switch_request():
        conn.salt = pkt.read_all()
        key_missing = not conn.server_public_key
        if key_missing and conn.password:
            # Request server public key
            if DEBUG:
                print("sha256: Requesting server public key")
            pkt = _roundtrip(conn, b'\1')

    if pkt.is_extra_auth_data():
        # Everything after the first byte of the packet is the key.
        conn.server_public_key = pkt._data[1:]
        if DEBUG:
            print("Received public key:\n", conn.server_public_key.decode('ascii'))

    # An empty password is answered with an empty payload.
    if not conn.password:
        return _roundtrip(conn, b'')

    if not conn.server_public_key:
        raise OperationalError("Couldn't receive server's public key")

    encrypted = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
    return _roundtrip(conn, encrypted)
def scramble_caching_sha2(password, nonce):
    # (bytes, bytes) -> bytes
    """Scramble algorithm used in the caching_sha2_password fast path.

    Computes::

        XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + nonce))

    :param password: password as a byte string.
    :param nonce: byte string appended to the double hash before the
        final SHA-256.
    :return: the 32-byte scramble, or ``b''`` when ``password`` is empty.
    """
    if not password:
        return b''

    p1 = hashlib.sha256(password).digest()
    p2 = hashlib.sha256(p1).digest()
    p3 = hashlib.sha256(p2 + nonce).digest()

    # bytearray() yields integer items on Python 2 and 3 alike, so the XOR
    # needs no interpreter-specific branch.
    return bytes(bytearray(
        a ^ b for a, b in zip(bytearray(p1), bytearray(p3))
    ))
def caching_sha2_password_auth(conn, pkt):
    """Run the caching_sha2_password authentication exchange on ``conn``.

    The fast path (scrambled password) is tried first when the server sent
    an auth switch request. If the server then asks for full
    authentication, the password is sent in clear over a secure
    connection, or RSA-encrypted with the server's public key otherwise.
    The key is requested from the server if it is not known yet.

    :param conn: connection; ``password``, ``_secure``, ``salt`` and
        ``server_public_key`` are read, and the latter two may be updated.
    :param pkt: packet from the server that triggered this exchange.
    :return: the server's reply packet to the last payload sent. Every
        path returns it, including the RSA full-auth path.
    :raises OperationalError: on an unexpected packet or fast-auth result.
    """
    # No password fast path
    if not conn.password:
        return _roundtrip(conn, b'')

    if pkt.is_auth_switch_request():
        # Try from fast auth
        if DEBUG:
            print("caching sha2: Trying fast path")
        conn.salt = pkt.read_all()
        scrambled = scramble_caching_sha2(conn.password, conn.salt)
        pkt = _roundtrip(conn, scrambled)
    # else: fast auth is tried in initial handshake

    if not pkt.is_extra_auth_data():
        raise OperationalError(
            "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]
        )

    # magic numbers:
    # 2 - request public key
    # 3 - fast auth succeeded
    # 4 - need full auth

    # Skip one byte (presumably the extra-auth-data marker), then read the
    # fast-auth result code.
    pkt.advance(1)
    n = pkt.read_uint8()

    if n == 3:
        if DEBUG:
            print("caching sha2: succeeded by fast path.")
        pkt = conn._read_packet()
        pkt.check_error()  # pkt must be OK packet
        return pkt

    if n != 4:
        raise OperationalError("caching sha2: Unknown result for fast auth: %s" % n)

    if DEBUG:
        print("caching sha2: Trying full auth...")

    if conn._secure:
        if DEBUG:
            print("caching sha2: Sending plain password via secure connection")
        return _roundtrip(conn, conn.password + b'\0')

    if not conn.server_public_key:
        pkt = _roundtrip(conn, b'\x02')  # Request public key
        if not pkt.is_extra_auth_data():
            raise OperationalError(
                "caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
            )

        # Everything after the first byte of the packet is the key.
        conn.server_public_key = pkt._data[1:]
        if DEBUG:
            print(conn.server_public_key.decode('ascii'))

    data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
    # Return the reply like every other path does; previously this path fell
    # off the end of the function and returned None.
    return _roundtrip(conn, data)