# Copyright (c) 2012, 2023, Oracle and/or its affiliates.
|
#
|
# This program is free software; you can redistribute it and/or modify
|
# it under the terms of the GNU General Public License, version 2.0, as
|
# published by the Free Software Foundation.
|
#
|
# This program is also distributed with certain software (including
|
# but not limited to OpenSSL) that is licensed under separate terms,
|
# as designated in a particular file or component or in included license
|
# documentation. The authors of MySQL hereby grant you an
|
# additional permission to link the program and your derivative works
|
# with the separately licensed software that they have included with
|
# MySQL.
|
#
|
# Without limiting anything contained in the foregoing, this file,
|
# which is part of MySQL Connector/Python, is also subject to the
|
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
# http://oss.oracle.com/licenses/universal-foss-exception.
|
#
|
# This program is distributed in the hope that it will be useful, but
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
# See the GNU General Public License, version 2.0, for more details.
|
#
|
# You should have received a copy of the GNU General Public License
|
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
|
# mypy: disable-error-code="attr-defined"
|
|
"""Module implementing low-level socket communication with MySQL servers.
|
"""
|
|
import os
|
import socket
|
import struct
|
import warnings
|
import zlib
|
|
from collections import deque
|
|
try:
    import ssl

    # Maps the TLS version names accepted in the connection options to the
    # ssl protocol constant used to build the SSLContext.
    TLS_VERSIONS = {
        "TLSv1": ssl.PROTOCOL_TLSv1,
        "TLSv1.1": ssl.PROTOCOL_TLSv1_1,
        "TLSv1.2": ssl.PROTOCOL_TLSv1_2,
    }
    # TLSv1.3 is only reachable through the generic PROTOCOL_TLS constant,
    # which very old Pythons (3.4) expose solely as its alias PROTOCOL_SSLv23.
    if hasattr(ssl, "PROTOCOL_TLS"):
        TLS_VERSIONS["TLSv1.3"] = ssl.PROTOCOL_TLS
    else:
        TLS_VERSIONS["TLSv1.3"] = ssl.PROTOCOL_SSLv23  # Alias of PROTOCOL_TLS
    # Whether the linked OpenSSL can negotiate TLSv1.3 (False if unknown).
    TLS_V1_3_SUPPORTED = getattr(ssl, "HAS_TLSv1_3", False)
except ImportError:
    # Python was built without the ssl module, so there is no SSL support.
    # TLS_VERSIONS stays undefined; switch_to_ssl() turns the resulting
    # NameError into NotSupportedError.
    TLS_V1_3_SUPPORTED = False
|
|
from typing import Any, Deque, List, Optional, Tuple, Union
|
|
from .constants import MAX_PACKET_LENGTH
|
from .errors import InterfaceError, NotSupportedError, OperationalError
|
from .types import StrOrBytesPath
|
from .utils import init_bytearray
|
|
|
def _strioerror(err: IOError) -> str:
|
"""Reformat the IOError error message
|
|
This function reformats the IOError error message.
|
"""
|
if not err.errno:
|
return str(err)
|
return f"{err.errno} {err.strerror}"
|
|
|
def _prepare_packets(buf: bytes, pktnr: int) -> List[bytes]:
|
"""Prepare a packet for sending to the MySQL server"""
|
pkts = []
|
pllen = len(buf)
|
maxpktlen = MAX_PACKET_LENGTH
|
while pllen > maxpktlen:
|
pkts.append(b"\xff\xff\xff" + struct.pack("<B", pktnr) + buf[:maxpktlen])
|
buf = buf[maxpktlen:]
|
pllen = len(buf)
|
pktnr = pktnr + 1
|
pkts.append(struct.pack("<I", pllen)[0:3] + struct.pack("<B", pktnr) + buf)
|
return pkts
|
|
|
class BaseMySQLSocket:
    """Base class for MySQL socket communication

    This class should not be used directly but subclassed, overriding at
    least the open_connection()-method. Examples of subclasses are
      mysql.connector.network.MySQLTCPSocket
      mysql.connector.network.MySQLUnixSocket
    """

    def __init__(self) -> None:
        # holds the socket connection
        self.sock: Optional[socket.socket] = None
        # Value handed to socket.settimeout(); None keeps the socket blocking.
        self._connection_timeout: Optional[int] = None
        # Sequence ids of the last plain/compressed packet sent or received.
        # -1 means "none yet", so the first increment yields 0.
        self._packet_number: int = -1
        self._compressed_packet_number: int = -1
        # MySQL packets split out of a received (decompressed) payload that
        # recv_compressed() has not handed out yet.
        self._packet_queue: Deque[bytearray] = deque()
        self.server_host: Optional[str] = None
        # NOTE(review): not read anywhere in this module; presumably kept for
        # subclasses / API compatibility.
        self.recvsize: int = 8192

    def next_packet_number(self) -> int:
        """Increment the packet sequence id, wrapping from 255 to 0.

        Returns:
            The new packet sequence id.
        """
        self._packet_number = self._packet_number + 1
        if self._packet_number > 255:
            self._packet_number = 0
        return self._packet_number

    def next_compressed_packet_number(self) -> int:
        """Increment the compressed packet sequence id, wrapping from 255 to 0.

        Returns:
            The new compressed packet sequence id.
        """
        self._compressed_packet_number = self._compressed_packet_number + 1
        if self._compressed_packet_number > 255:
            self._compressed_packet_number = 0
        return self._compressed_packet_number

    def open_connection(self) -> Any:
        """Open the socket (must be implemented by subclasses)."""
        raise NotImplementedError

    def get_address(self) -> Any:
        """Get the location of the socket (must be implemented by subclasses)."""
        raise NotImplementedError

    def _discard_queued_packets(self) -> None:
        """Drop packets buffered by recv_compressed() for the closed connection."""
        try:
            self._packet_queue.clear()
        except AttributeError:
            # The queue was never created (e.g. __init__ did not complete).
            pass

    def shutdown(self) -> None:
        """Shut down the socket before closing it

        The socket is closed and the packet queue emptied even when the
        shutdown step fails (for instance because the socket is not
        connected), so the file descriptor is not left open. Socket errors
        are ignored because this also runs from __del__.
        """
        try:
            self.sock.shutdown(socket.SHUT_RDWR)
        except (AttributeError, OSError):
            # No socket, or it cannot be shut down; it must still be closed.
            pass
        try:
            self.sock.close()
        except (AttributeError, OSError):
            pass
        self._discard_queued_packets()

    def close_connection(self) -> None:
        """Close the socket

        Socket errors are ignored. The packet queue is emptied rather than
        deleted, so the object can be used again after open_connection().
        """
        try:
            self.sock.close()
        except (AttributeError, OSError):
            pass
        self._discard_queued_packets()

    def __del__(self) -> None:
        # Best-effort cleanup; shutdown() ignores socket errors.
        self.shutdown()

    def send_plain(
        self,
        buf: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
    ) -> None:
        """Send packets to the MySQL server

        Args:
            buf: Payload to send. It is split into several packets when it
                does not fit into one.
            packet_number: Sequence id of the first packet. When omitted, the
                stored id is incremented.
            compressed_packet_number: Ignored; kept for API compatibility.

        Raises:
            OperationalError: errno 2055 on socket errors, errno 2006 when
                there is no socket.
        """
        # Keep 'compressed_packet_number' for API backward compatibility
        _ = compressed_packet_number
        if packet_number is None:
            self.next_packet_number()
        else:
            self._packet_number = packet_number
        packets = _prepare_packets(buf, self._packet_number)
        for packet in packets:
            try:
                self.sock.sendall(packet)
            except IOError as err:
                raise OperationalError(
                    errno=2055, values=(self.get_address(), _strioerror(err))
                ) from err
            except AttributeError as err:
                raise OperationalError(errno=2006) from err

    send = send_plain

    def send_compressed(
        self,
        buf: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
    ) -> None:
        """Send compressed packets to the MySQL server

        The payload is framed as regular MySQL packets and then wrapped in
        compressed packets. Each compressed packet has a 7-byte header:
        the compressed length (3 bytes), the compressed sequence id (1 byte)
        and the uncompressed length (3 bytes, 0 when the body is sent
        uncompressed).

        Args:
            buf: Payload to send.
            packet_number: Sequence id of the (first) MySQL packet. When
                omitted, the stored id is incremented.
            compressed_packet_number: Sequence id of the first compressed
                packet. When omitted, the stored id is incremented.

        Raises:
            OperationalError: errno 2055 on socket errors, errno 2006 when
                there is no socket.
        """
        if packet_number is None:
            self.next_packet_number()
        else:
            self._packet_number = packet_number
        if compressed_packet_number is None:
            self.next_compressed_packet_number()
        else:
            self._compressed_packet_number = compressed_packet_number

        pktnr = self._packet_number
        pllen = len(buf)
        zpkts = []
        maxpktlen = MAX_PACKET_LENGTH
        # ">=": a payload of exactly maxpktlen bytes cannot be framed as one
        # packet (a full-length packet announces a continuation), so it takes
        # the multi-packet path as well.
        if pllen >= maxpktlen:
            pkts = _prepare_packets(buf, pktnr)
            tmpbuf = b"".join(pkts)
            del pkts
            # First compressed packet carries 16384 uncompressed bytes.
            zbuf = zlib.compress(tmpbuf[:16384])
            header = (
                struct.pack("<I", len(zbuf))[0:3]
                + struct.pack("<B", self._compressed_packet_number)
                + b"\x00\x40\x00"
            )
            zpkts.append(header + zbuf)
            tmpbuf = tmpbuf[16384:]
            pllen = len(tmpbuf)
            self.next_compressed_packet_number()
            while pllen > maxpktlen:
                zbuf = zlib.compress(tmpbuf[:maxpktlen])
                header = (
                    struct.pack("<I", len(zbuf))[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + b"\xff\xff\xff"
                )
                zpkts.append(header + zbuf)
                tmpbuf = tmpbuf[maxpktlen:]
                pllen = len(tmpbuf)
                self.next_compressed_packet_number()
            if tmpbuf:
                zbuf = zlib.compress(tmpbuf)
                header = (
                    struct.pack("<I", len(zbuf))[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + struct.pack("<I", pllen)[0:3]
                )
                zpkts.append(header + zbuf)
            del tmpbuf
        else:
            pkt = struct.pack("<I", pllen)[0:3] + struct.pack("<B", pktnr) + buf
            pllen = len(pkt)
            if pllen > 50:
                zbuf = zlib.compress(pkt)
                zpkts.append(
                    struct.pack("<I", len(zbuf))[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + struct.pack("<I", pllen)[0:3]
                    + zbuf
                )
            else:
                # Small packets are sent uncompressed (uncompressed length 0).
                header = (
                    struct.pack("<I", pllen)[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + struct.pack("<I", 0)[0:3]
                )
                zpkts.append(header + pkt)

        for zip_packet in zpkts:
            try:
                self.sock.sendall(zip_packet)
            except IOError as err:
                raise OperationalError(
                    errno=2055, values=(self.get_address(), _strioerror(err))
                ) from err
            except AttributeError as err:
                raise OperationalError(errno=2006) from err

    def recv_plain(self) -> bytearray:
        """Receive packets from the MySQL server

        Reads one MySQL packet: the 4-byte header (3-byte little-endian
        payload length plus the sequence id), followed by the payload.

        Returns:
            The whole packet, header included.

        Raises:
            InterfaceError: errno 2013 when the connection is closed before
                the packet is complete.
            OperationalError: errno 2055 on socket errors.
        """
        try:
            # Read the header of the MySQL packet, 4 bytes
            packet = bytearray(b"")
            packet_len = 0
            while packet_len < 4:
                chunk = self.sock.recv(4 - packet_len)
                if not chunk:
                    raise InterfaceError(errno=2013)
                packet += chunk
                packet_len = len(packet)

            # Save the packet number and payload length
            self._packet_number = packet[3]
            payload_len = struct.unpack("<I", packet[0:3] + b"\x00")[0]

            # Read the payload straight into the pre-sized buffer
            rest = payload_len
            packet.extend(bytearray(payload_len))
            packet_view = memoryview(packet)
            packet_view = packet_view[4:]
            while rest:
                read = self.sock.recv_into(packet_view, rest)
                if read == 0 and rest > 0:
                    raise InterfaceError(errno=2013)
                packet_view = packet_view[read:]
                rest -= read
            return packet
        except IOError as err:
            raise OperationalError(
                errno=2055, values=(self.get_address(), _strioerror(err))
            ) from err

    recv = recv_plain

    def _split_zipped_payload(self, packet_bunch: bytearray) -> None:
        """Split compressed payload

        Cuts a buffer holding back-to-back MySQL packets into individual
        packets (header included) and appends them to the packet queue.
        """
        while packet_bunch:
            payload_length = struct.unpack("<I", packet_bunch[0:3] + b"\x00")[0]
            self._packet_queue.append(packet_bunch[0 : payload_length + 4])
            packet_bunch = packet_bunch[payload_length + 4 :]

    def recv_compressed(self) -> Optional[bytearray]:
        """Receive compressed packets from the MySQL server

        Packets left over from a previously received payload are returned
        first. Otherwise compressed packets are read, decompressed, split
        into MySQL packets and queued. A compressed header holds the
        compressed length (3 bytes), the compressed sequence id (1 byte) and
        the uncompressed length (3 bytes; 0 means "not compressed").

        Returns:
            The next MySQL packet (header included), or None when nothing
            was received.

        Raises:
            InterfaceError: errno 2013 when the connection is closed before
                a packet is complete.
            OperationalError: errno 2055 on socket errors.
        """
        try:
            pkt = self._packet_queue.popleft()
            self._packet_number = pkt[3]
            return pkt
        except IndexError:
            pass

        header = bytearray(b"")
        packets = []
        try:
            # The header is read byte by byte. Once 7 bytes are collected,
            # the byte read last already belongs to the payload and seeds it.
            abyte = self.sock.recv(1)
            while abyte and len(header) < 7:
                header += abyte
                abyte = self.sock.recv(1)
            while header:
                if len(header) < 7:
                    raise InterfaceError(errno=2013)

                # Get length of compressed packet
                zip_payload_length = struct.unpack("<I", header[0:3] + b"\x00")[0]
                self._compressed_packet_number = header[3]

                # Get payload length before compression
                payload_length = struct.unpack("<I", header[4:7] + b"\x00")[0]

                zip_payload = init_bytearray(abyte)
                while len(zip_payload) < zip_payload_length:
                    chunk = self.sock.recv(zip_payload_length - len(zip_payload))
                    if not chunk:
                        raise InterfaceError(errno=2013)
                    zip_payload = zip_payload + chunk

                # Payload was not compressed
                if payload_length == 0:
                    self._split_zipped_payload(zip_payload)
                    pkt = self._packet_queue.popleft()
                    self._packet_number = pkt[3]
                    return pkt

                packets.append((payload_length, zip_payload))

                # NOTE(review): a compressed packet of at most 16384 bytes is
                # taken to be the last one; larger ones are assumed to be
                # followed by another compressed packet.
                if zip_payload_length <= 16384:
                    # We received the full compressed packet
                    break

                # Get next compressed packet
                header = init_bytearray(b"")
                abyte = self.sock.recv(1)
                while abyte and len(header) < 7:
                    header += abyte
                    abyte = self.sock.recv(1)

        except IOError as err:
            raise OperationalError(
                errno=2055, values=(self.get_address(), _strioerror(err))
            ) from err

        # Compressed packet can contain more than 1 MySQL packets
        # We decompress and make one so we can split it up
        tmp = init_bytearray(b"")
        for _, payload in packets:
            # The uncompressed length is never 0 here; that case returned above
            tmp += zlib.decompress(payload)
        self._split_zipped_payload(tmp)
        del tmp

        try:
            pkt = self._packet_queue.popleft()
            self._packet_number = pkt[3]
            return pkt
        except IndexError:
            pass
        return None

    def set_connection_timeout(self, timeout: Optional[int]) -> None:
        """Set the connection timeout, applying it to an open socket too."""
        self._connection_timeout = timeout
        if self.sock:
            self.sock.settimeout(timeout)

    @staticmethod
    def _match_hostname(cert: Any, hostname: str) -> None:
        """Check that a peer certificate is valid for ``hostname``.

        Uses ``ssl.match_hostname`` while the running Python still provides
        it (it was removed in Python 3.12). Otherwise the same rules are
        applied:
          - subjectAltName DNS / IP Address entries are checked first;
          - the subject commonName is checked only when there is no such
            entry;
          - a wildcard is accepted only as the whole left-most label, and it
            then matches exactly one label.

        Args:
            cert: Certificate as returned by ``SSLSocket.getpeercert()``.
            hostname: Host name or IP address to check.

        Raises:
            ssl.CertificateError: If the certificate does not match.
        """
        match_hostname = getattr(ssl, "match_hostname", None)
        if match_hostname is not None:
            # Deprecated in Python 3.7 and removed in 3.12, since OpenSSL
            # now performs hostname matching.
            match_hostname(cert, hostname)
            return

        import ipaddress  # pylint: disable=import-outside-toplevel

        if not cert:
            raise ssl.CertificateError("empty or no certificate")

        def dnsname_match(pattern: str) -> bool:
            if not pattern:
                return False
            if "*" not in pattern:
                return pattern.lower() == hostname.lower()
            leftmost, sep, remainder = pattern.partition(".")
            if leftmost != "*" or not sep or "*" in remainder:
                # Partial, sole or non-left-most wildcards never match.
                return False
            host_leftmost, host_sep, host_remainder = hostname.partition(".")
            if not host_leftmost or not host_sep:
                return False
            return remainder.lower() == host_remainder.lower()

        def ip_match(value: str) -> bool:
            try:
                return ipaddress.ip_address(value.rstrip()) == host_ip
            except ValueError:
                return False

        try:
            host_ip = ipaddress.ip_address(hostname)
        except ValueError:
            host_ip = None  # Not an IP address (common case)

        names: List[str] = []
        for key, value in cert.get("subjectAltName", ()):
            if key == "DNS":
                if host_ip is None and dnsname_match(value):
                    return
                names.append(value)
            elif key == "IP Address":
                if host_ip is not None and ip_match(value):
                    return
                names.append(value)
        if not names:
            # The subject is only checked when subjectAltName has no entry.
            for rdn in cert.get("subject", ()):
                for key, value in rdn:
                    if key == "commonName":
                        if dnsname_match(value):
                            return
                        names.append(value)
        if names:
            raise ssl.CertificateError(
                f"hostname {hostname!r} doesn't match "
                f"{', '.join(repr(name) for name in names)}"
            )
        raise ssl.CertificateError(
            "no appropriate commonName or subjectAltName fields were found"
        )

    def switch_to_ssl(
        self,
        ca: StrOrBytesPath,
        cert: StrOrBytesPath,
        key: StrOrBytesPath,
        verify_cert: bool = False,
        verify_identity: bool = False,
        cipher_suites: Optional[str] = None,
        tls_versions: Optional[List[str]] = None,
    ) -> None:
        """Switch the socket to use SSL

        Args:
            ca: CA certificate file used to verify the server, if any.
            cert: Client certificate file, if any.
            key: Private key matching ``cert``.
            verify_cert: Require a valid server certificate.
            verify_identity: Also check that the server certificate matches
                the server host name.
            cipher_suites: Cipher list passed to ``SSLContext.set_ciphers()``.
            tls_versions: Allowed TLS version names (keys of TLS_VERSIONS).
                The highest one selects the protocol. The caller's list is
                not modified.

        Raises:
            InterfaceError: No socket (errno 2048), invalid CA/certificate,
                failed identity check, or SSL/socket errors (errno 2055).
            NotSupportedError: Python was built without SSL support.
        """
        if not self.sock:
            raise InterfaceError(errno=2048)

        try:
            if verify_cert:
                cert_reqs = ssl.CERT_REQUIRED
            elif verify_identity:
                cert_reqs = ssl.CERT_OPTIONAL
            else:
                cert_reqs = ssl.CERT_NONE

            if tls_versions is None or not tls_versions:
                context = ssl.create_default_context()
                if not verify_identity:
                    context.check_hostname = False
            else:
                # Sort a copy so the caller's list is left untouched.
                tls_versions = sorted(tls_versions, reverse=True)

                tls_version = tls_versions[0]
                if (
                    not TLS_V1_3_SUPPORTED
                    and tls_version == "TLSv1.3"
                    and len(tls_versions) > 1
                ):
                    tls_version = tls_versions[1]
                ssl_protocol = TLS_VERSIONS[tls_version]
                context = ssl.SSLContext(ssl_protocol)

                if tls_version == "TLSv1.3":
                    # The TLSv1.3 entry negotiates any version: disable the
                    # versions that were not requested.
                    if "TLSv1.2" not in tls_versions:
                        context.options |= ssl.OP_NO_TLSv1_2
                    if "TLSv1.1" not in tls_versions:
                        context.options |= ssl.OP_NO_TLSv1_1
                    if "TLSv1" not in tls_versions:
                        context.options |= ssl.OP_NO_TLSv1

            # Host names are verified manually below (so that aliases can be
            # accepted) rather than during the handshake.
            context.check_hostname = False
            context.verify_mode = cert_reqs
            context.load_default_certs()

            if ca:
                try:
                    context.load_verify_locations(ca)
                except (IOError, ssl.SSLError) as err:
                    self.sock.close()
                    raise InterfaceError(f"Invalid CA Certificate: {err}") from err
            if cert:
                try:
                    context.load_cert_chain(cert, key)
                except (IOError, ssl.SSLError) as err:
                    self.sock.close()
                    raise InterfaceError(f"Invalid Certificate/Key: {err}") from err
            if cipher_suites:
                context.set_ciphers(cipher_suites)

            if hasattr(self, "server_host"):
                self.sock = context.wrap_socket(
                    self.sock, server_hostname=self.server_host
                )
            else:
                self.sock = context.wrap_socket(self.sock)

            if verify_identity:
                context.check_hostname = True
                hostnames: List[str] = [self.server_host] if self.server_host else []
                if os.name == "nt" and self.server_host == "localhost":
                    # On Windows also accept the loopback address plus the
                    # host name and aliases reported by gethostbyaddr().
                    hostnames = ["localhost", "127.0.0.1"]
                    aliases = socket.gethostbyaddr(self.server_host)
                    hostnames.extend([aliases[0]] + aliases[1])
                match_found = False
                errs = []
                for hostname in hostnames:
                    try:
                        self._match_hostname(self.sock.getpeercert(), hostname)
                    except ssl.CertificateError as err:
                        errs.append(str(err))
                    else:
                        match_found = True
                        break
                if not match_found:
                    self.sock.close()
                    raise InterfaceError(
                        f"Unable to verify server identity: {', '.join(errs)}"
                    )
        except NameError as err:
            raise NotSupportedError("Python installation has no SSL support") from err
        except (ssl.SSLError, IOError) as err:
            raise InterfaceError(
                errno=2055, values=(self.get_address(), _strioerror(err))
            ) from err
        except ssl.CertificateError as err:
            raise InterfaceError(str(err)) from err
        except NotImplementedError as err:
            raise InterfaceError(str(err)) from err
|
|
|
class MySQLUnixSocket(BaseMySQLSocket):
    """MySQL socket class using UNIX sockets

    Connects to the MySQL Server through its UNIX domain socket file.
    """

    def __init__(self, unix_socket: str = "/tmp/mysql.sock") -> None:
        super().__init__()
        # Filesystem path of the server's UNIX socket.
        self.unix_socket: str = unix_socket

    def get_address(self) -> str:
        """Return the path of the UNIX socket."""
        return self.unix_socket

    def open_connection(self) -> None:
        """Connect to the server through the UNIX socket.

        Raises:
            InterfaceError: errno 2002 for OS-level errors; any other error
                is re-raised as InterfaceError carrying its message.
        """
        try:
            unix_sock = socket.socket(
                socket.AF_UNIX, socket.SOCK_STREAM  # pylint: disable=no-member
            )
            self.sock = unix_sock
            unix_sock.settimeout(self._connection_timeout)
            unix_sock.connect(self.unix_socket)
        except IOError as err:
            details = (self.get_address(), _strioerror(err))
            raise InterfaceError(errno=2002, values=details) from err
        except Exception as err:  # pylint: disable=broad-except
            raise InterfaceError(str(err)) from err

    def switch_to_ssl(
        self, *args: Any, **kwargs: Any  # pylint: disable=unused-argument
    ) -> None:
        """Refuse to enable SSL: only a warning is emitted.

        The socket is left unchanged; all arguments are ignored.
        """
        warnings.warn(
            "SSL is disabled when using unix socket connections",
            category=Warning,
        )
|
|
|
class MySQLTCPSocket(BaseMySQLSocket):
    """MySQL socket class using TCP/IP

    Opens a TCP/IP connection to the MySQL Server.
    """

    def __init__(
        self, host: str = "127.0.0.1", port: int = 3306, force_ipv6: bool = False
    ) -> None:
        super().__init__()
        self.server_host: str = host
        self.server_port: int = port
        # When True, only an IPv6 address is acceptable for the connection.
        self.force_ipv6: bool = force_ipv6
        # Address family of the socket; 0 until open_connection() runs.
        self._family: int = 0

    def get_address(self) -> str:
        """Return the server location as ``host:port``."""
        return f"{self.server_host}:{self.server_port}"

    def open_connection(self) -> None:
        """Open the TCP/IP connection to the MySQL server

        The host is resolved with getaddrinfo(). An IPv4 address is preferred
        when several results are returned. With force_ipv6, only an IPv6
        address is accepted. Without force_ipv6, the first result is used
        when no IPv4 address exists.

        Raises:
            InterfaceError: errno 2003 when resolving or connecting fails, or
                when force_ipv6 is set and no IPv6 address was found.
            OperationalError: For any other failure while connecting.
        """
        # pylint: disable=no-member
        # Get address information
        addrinfo: Union[
            Tuple[None, None, None, None, None],
            Tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                Union[Tuple[str, int], Tuple[str, int, int, int]],
            ],
        ] = (None, None, None, None, None)
        try:
            addrinfos = socket.getaddrinfo(
                self.server_host,
                self.server_port,
                0,
                socket.SOCK_STREAM,
                socket.SOL_TCP,
            )
            # If multiple results we favor IPv4, unless IPv6 was forced, in
            # which case IPv4 results are never accepted.
            wanted_family = socket.AF_INET6 if self.force_ipv6 else socket.AF_INET
            for info in addrinfos:
                if info[0] == wanted_family:
                    addrinfo = info
                    break
            if self.force_ipv6 and addrinfo[0] is None:
                raise InterfaceError(f"No IPv6 address found for {self.server_host}")
            if addrinfo[0] is None:
                addrinfo = addrinfos[0]
        except IOError as err:
            # errno 2003 is formatted with (host, port, reason), exactly as
            # in the connect failure below.
            raise InterfaceError(
                errno=2003,
                values=(self.server_host, self.server_port, _strioerror(err)),
            ) from err

        (self._family, socktype, proto, _, sockaddr) = addrinfo

        # Instantiate the socket and connect
        try:
            self.sock = socket.socket(self._family, socktype, proto)
            self.sock.settimeout(self._connection_timeout)
            self.sock.connect(sockaddr)
        except IOError as err:
            raise InterfaceError(
                errno=2003,
                values=(
                    self.server_host,
                    self.server_port,
                    _strioerror(err),
                ),
            ) from err
        except Exception as err:
            raise OperationalError(str(err)) from err
|