# Copyright (c) 2016, 2022, Oracle and/or its affiliates.
|
#
|
# This program is free software; you can redistribute it and/or modify
|
# it under the terms of the GNU General Public License, version 2.0, as
|
# published by the Free Software Foundation.
|
#
|
# This program is also distributed with certain software (including
|
# but not limited to OpenSSL) that is licensed under separate terms,
|
# as designated in a particular file or component or in included license
|
# documentation. The authors of MySQL hereby grant you an
|
# additional permission to link the program and your derivative works
|
# with the separately licensed software that they have included with
|
# MySQL.
|
#
|
# Without limiting anything contained in the foregoing, this file,
|
# which is part of MySQL Connector/Python, is also subject to the
|
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
# http://oss.oracle.com/licenses/universal-foss-exception.
|
#
|
# This program is distributed in the hope that it will be useful, but
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
# See the GNU General Public License, version 2.0, for more details.
|
#
|
# You should have received a copy of the GNU General Public License
|
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
|
# mypy: disable-error-code="return-value"
|
|
"""Implementation of Statements."""
|
|
from __future__ import annotations
|
|
import copy
|
import json
|
import warnings
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
from .constants import LockContention
|
from .dbdoc import DbDoc
|
from .errors import NotSupportedError, ProgrammingError
|
from .expr import ExprParser
|
from .helpers import deprecated
|
from .protobuf import mysqlxpb_enum
|
from .result import DocResult, Result, RowResult, SqlResult
|
from .types import (
|
ConnectionType,
|
DatabaseTargetType,
|
MessageType,
|
ProtobufMessageCextType,
|
ProtobufMessageType,
|
SchemaType,
|
)
|
|
# Error message template for an index name that is not valid; the "{}"
# placeholder is where the offending index name goes.
ERR_INVALID_INDEX_NAME = 'The given index name "{}" is not valid'
|
|
|
class Expr:
    """Minimal holder that carries an expression object as-is.

    Args:
        expr: The expression to hold. It is stored unchanged on
            ``self.expr``; no parsing or validation happens here.
    """

    def __init__(self, expr: Any) -> None:
        held: Any = expr
        self.expr: Any = held
|
|
|
def flexible_params(*values: Any) -> Union[List, Tuple]:
    """Normalize variadic arguments that may also arrive packed in one sequence.

    Callers may pass items spread out (``f(a, b)``) or as a single list or
    tuple (``f([a, b])``); both forms yield the same items.

    Args:
        *values: The positional arguments exactly as received.

    Returns:
        The lone list/tuple argument itself (not a copy) when exactly one
        list or tuple was passed; otherwise the tuple of all arguments.
    """
    packed = len(values) == 1 and isinstance(values[0], (list, tuple))
    return values[0] if packed else values
|
|
|
def is_quoted_identifier(identifier: str, sql_mode: str = "") -> bool:
    """Check if the given identifier is enclosed in identifier quotes.

    Backticks are always accepted as identifier quotes. When ``sql_mode``
    contains ``ANSI_QUOTES``, double quotes are accepted as well.

    Args:
        identifier (string): Identifier to check.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        `True` if the identifier starts and ends with the same accepted quote
        character, and `False` otherwise. Strings shorter than two characters
        cannot hold both quotes, so they always return `False`.
    """
    # Guard: an empty string would raise IndexError on identifier[0], and a
    # lone quote character would otherwise match itself as both the opening
    # and the closing quote.
    if len(identifier) < 2:
        return False
    quote_chars = ("`", '"') if "ANSI_QUOTES" in sql_mode else ("`",)
    return identifier[0] in quote_chars and identifier[-1] == identifier[0]
|
|
|
def quote_identifier(identifier: str, sql_mode: str = "") -> str:
    """Wrap an identifier in quotes, escaping embedded quote characters.

    Backticks are used by default. When ``sql_mode`` contains
    ``ANSI_QUOTES``, double quotes are used instead. Every occurrence of the
    chosen quote character inside the identifier is doubled, which is the
    escape sequence for that quote.

    Args:
        identifier (string): Identifier to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        The quoted identifier. An empty identifier always yields two
        backticks, regardless of ``sql_mode``.
    """
    if len(identifier) == 0:
        return "``"
    ansi = "ANSI_QUOTES" in sql_mode
    quote, escaped_quote = ('"', '""') if ansi else ("`", "``")
    return quote + identifier.replace(quote, escaped_quote) + quote
|
|
|
def quote_multipart_identifier(identifiers: Iterable[str], sql_mode: str = "") -> str:
    """Quote each part of a multi-part identifier and join the parts with dots.

    Every part is quoted on its own by ``quote_identifier``, so a dot that
    appears inside a part stays inside that part's quotes.

    Args:
        identifiers (iterable): The identifier parts, e.g. schema then table.
        sql_mode (Optional[string]): SQL mode, forwarded to
            ``quote_identifier`` for every part.

    Returns:
        The quoted parts joined by ".".
    """
    return ".".join(quote_identifier(part, sql_mode) for part in identifiers)
|
|
|
def parse_table_name(
    default_schema: str, table_name: str, sql_mode: str = ""
) -> Tuple[str, str]:
    """Split a possibly schema-qualified table name into (schema, table).

    The name is split at the first "." that is not inside identifier quotes.
    Identifier quotes are backticks, or double quotes when ``sql_mode``
    contains ``ANSI_QUOTES``. Any mix of quoted and unquoted parts works, for
    example: schema.tbl, `schema`.`tbl`, `schema`.tbl and schema.`t.bl`.
    Leading and trailing quote characters are stripped from both parts.

    Args:
        default_schema (str): Schema returned when ``table_name`` is not
            schema-qualified.
        table_name (str): The table name, optionally prefixed by a schema.
        sql_mode (Optional[str]): The SQL mode.

    Returns:
        Tuple[str, str]: The ``(schema, table)`` pair.
    """
    quote = '"' if "ANSI_QUOTES" in sql_mode else "`"
    inside_quotes = False
    for pos, char in enumerate(table_name):
        if char == quote:
            # An escaped (doubled) quote toggles twice, so the state is kept.
            inside_quotes = not inside_quotes
        elif char == "." and not inside_quotes:
            return (
                table_name[:pos].strip(quote),
                table_name[pos + 1 :].strip(quote),
            )
    # No unquoted dot: the whole string is the table name.
    return default_schema, table_name.strip(quote)
|
|
|
class Statement:
    """Common state and bookkeeping shared by every statement type.

    Keeps the target database object, whether the statement is document
    based, and the prepared-statement bookkeeping: statement ID, execution
    counter and the changed/prepared/deallocate flags.

    Args:
        target (object): The target database object, it can be
            :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target: DatabaseTargetType, doc_based: bool = True) -> None:
        self._target: DatabaseTargetType = target
        self._doc_based: bool = doc_based
        # A falsy target (e.g. None for SQL statements) provides no connection.
        connection: Optional[ConnectionType] = None
        if target:
            connection = target.get_connection()
        self._connection: Optional[ConnectionType] = connection
        self._stmt_id: Optional[int] = None
        self._exec_counter: int = 0
        self._changed: bool = True
        self._prepared: bool = False
        self._deallocate_prepare_execute: bool = False

    # ----- read-only views -------------------------------------------------

    @property
    def target(self) -> DatabaseTargetType:
        """object: The database object this statement operates on."""
        return self._target

    @property
    def schema(self) -> SchemaType:
        """:class:`mysqlx.Schema`: The schema of the target object."""
        return self._target.schema

    @property
    def exec_counter(self) -> int:
        """int: How many times this statement has been executed."""
        return self._exec_counter

    @property
    def repeated(self) -> bool:
        """bool: `True` once the statement has been executed at least twice."""
        return self._exec_counter >= 2

    def is_doc_based(self) -> bool:
        """Tell whether the statement is document based.

        Returns:
            bool: `True` if it is document based.
        """
        return self._doc_based

    # ----- read/write flags ------------------------------------------------

    @property
    def stmt_id(self) -> int:
        """The statement ID (`None` until one is assigned).

        Returns:
            int: The statement ID.
        """
        return self._stmt_id

    @stmt_id.setter
    def stmt_id(self, value: int) -> None:
        self._stmt_id = value

    @property
    def changed(self) -> bool:
        """bool: `True` while the statement has unsent changes."""
        return self._changed

    @changed.setter
    def changed(self, value: bool) -> None:
        self._changed = value

    @property
    def prepared(self) -> bool:
        """bool: `True` if the statement has been prepared."""
        return self._prepared

    @prepared.setter
    def prepared(self, value: bool) -> None:
        self._prepared = value

    @property
    def deallocate_prepare_execute(self) -> bool:
        """bool: `True` to deallocate, re-prepare and then execute."""
        return self._deallocate_prepare_execute

    @deallocate_prepare_execute.setter
    def deallocate_prepare_execute(self, value: bool) -> None:
        self._deallocate_prepare_execute = value

    # ----- execution counter ------------------------------------------------

    def increment_exec_counter(self) -> None:
        """Record one more execution of this statement."""
        self._exec_counter = self._exec_counter + 1

    def reset_exec_counter(self) -> None:
        """Set the execution counter back to zero."""
        self._exec_counter = 0

    def execute(self) -> Any:
        """Execute the statement; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
class FilterableStatement(Statement):
    """A statement to be used with filterable statements.

    Holds the optional clauses shared by filterable statements: search
    condition, sort, grouping, having, projection, limit/offset and
    placeholder bindings. The private clause setters (``_set_where``,
    ``_sort``, ``_set_group_by``, ``_set_having``, ``_set_projection``)
    parse their input before updating any state. A setter that raises
    therefore leaves the statement as it was, instead of flagging a clause
    whose expression was never built.

    Args:
        target (object): The target database object, it can be
            :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
            (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
            documents or records.
    """

    def __init__(
        self,
        target: DatabaseTargetType,
        doc_based: bool = True,
        condition: Optional[str] = None,
    ) -> None:
        super().__init__(target=target, doc_based=doc_based)
        self._binding_map: Dict[str, Any] = {}
        self._bindings: Union[Dict[str, Any], List] = {}
        self._having: Optional[MessageType] = None
        self._grouping_str: str = ""
        self._grouping: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._limit_offset: int = 0
        self._limit_row_count: Optional[int] = None
        self._projection_str: str = ""
        self._projection_expr: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._sort_str: str = ""
        self._sort_expr: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._where_str: str = ""
        self._where_expr: MessageType = None
        self.has_bindings: bool = False
        self.has_limit: bool = False
        self.has_group_by: bool = False
        self.has_having: bool = False
        self.has_projection: bool = False
        self.has_sort: bool = False
        self.has_where: bool = False
        if condition:
            self._set_where(condition)

    def _bind_single(self, obj: Union[DbDoc, Dict[str, Any], str]) -> None:
        """Bind every key/value pair of a single document-like object.

        Dicts and DbDoc objects are serialized to a JSON string and re-bound.
        A JSON string is decoded and each top-level key is bound to its value.

        Args:
            obj (:class:`mysqlx.DbDoc`, dict or str): The object to bind.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``obj`` is not a dict, DbDoc
                or string, if the string is not valid JSON, or if the JSON
                does not decode to an object (dictionary).
        """
        if isinstance(obj, dict):
            self.bind(DbDoc(obj).as_str())
        elif isinstance(obj, DbDoc):
            self.bind(obj.as_str())
        elif isinstance(obj, str):
            try:
                res = json.loads(obj)
                if not isinstance(res, dict):
                    raise ValueError
            except ValueError as err:
                raise ProgrammingError("Invalid JSON string to bind") from err
            for key, value in res.items():
                self.bind(key, value)
        else:
            raise ProgrammingError("Invalid JSON string or object to bind")

    def _sort(self, *clauses: str) -> FilterableStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        sort_str = ",".join(flexible_params(*clauses))
        sort_expr = ExprParser(sort_str, not self._doc_based).parse_order_spec()
        # Commit only after parsing succeeded.
        self._sort_str = sort_str
        self._sort_expr = sort_expr
        self.has_sort = True
        self._changed = True
        return self

    def _set_where(self, condition: str) -> FilterableStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the condition cannot be
                parsed; the statement is left unchanged in that case.
        """
        try:
            expr = ExprParser(condition, not self._doc_based)
            where_expr = expr.expr()
        except ValueError as err:
            raise ProgrammingError("Invalid condition") from err
        # Commit only after parsing succeeded.
        self._where_str = condition
        self._where_expr = where_expr
        self._binding_map = expr.placeholder_name_to_position
        self.has_where = True
        self._changed = True
        return self

    def _set_group_by(self, *fields: str) -> None:
        """Set group by.

        Args:
            *fields: List of fields.
        """
        grouping_str = ",".join(flexible_params(*fields))
        grouping = ExprParser(grouping_str, not self._doc_based).parse_expr_list()
        self._grouping_str = grouping_str
        self._grouping = grouping
        self.has_group_by = True
        self._changed = True

    def _set_having(self, condition: str) -> None:
        """Set having.

        Args:
            condition (str): The condition.
        """
        having = ExprParser(condition, not self._doc_based).expr()
        self._having = having
        self.has_having = True
        self._changed = True

    def _set_projection(self, *fields: str) -> FilterableStatement:
        """Set the projection.

        Args:
            *fields: List of fields.

        Returns:
            :class:`mysqlx.FilterableStatement`: Returns self.
        """
        projection_str = ",".join(flexible_params(*fields))
        projection_expr = ExprParser(
            projection_str, not self._doc_based
        ).parse_table_select_projection()
        self._projection_str = projection_str
        self._projection_expr = projection_expr
        self.has_projection = True
        self._changed = True
        return self

    def get_binding_map(self) -> Dict[str, Any]:
        """Returns the binding map dictionary.

        Returns:
            dict: The placeholder-name-to-position map of the last condition.
        """
        return self._binding_map

    def get_bindings(self) -> Union[Dict[str, Any], List]:
        """Returns the bindings.

        Returns:
            dict: The placeholder-name-to-value bindings.
        """
        return self._bindings

    def get_grouping(self) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]:
        """Returns the grouping expression list.

        Returns:
            `list`: The grouping expression list.
        """
        return self._grouping

    def get_having(self) -> MessageType:
        """Returns the having expression.

        Returns:
            object: The having expression.
        """
        return self._having

    def get_limit_row_count(self) -> int:
        """Returns the limit row count.

        Returns:
            int: The limit row count (`None` if no limit was set).
        """
        return self._limit_row_count

    def get_limit_offset(self) -> int:
        """Returns the limit offset.

        Returns:
            int: The limit offset.
        """
        return self._limit_offset

    def get_where_expr(self) -> MessageType:
        """Returns the where expression.

        Returns:
            object: The where expression.
        """
        return self._where_expr

    def get_projection_expr(
        self,
    ) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]:
        """Returns the projection expression.

        Returns:
            object: The projection expression.
        """
        return self._projection_expr

    def get_sort_expr(
        self,
    ) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]:
        """Returns the sort expression.

        Returns:
            object: The sort expression.
        """
        return self._sort_expr

    @deprecated("8.0.12")
    def where(self, condition: str) -> FilterableStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._set_where(condition)

    @deprecated("8.0.12")
    def sort(self, *clauses: str) -> FilterableStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._sort(*clauses)

    def limit(
        self, row_count: int, offset: Optional[int] = None
    ) -> FilterableStatement:
        """Sets the maximum number of items to be returned.

        Args:
            row_count (int): The maximum number of items (zero or more).
            offset (Optional[int]): Deprecated; use :meth:`offset` instead.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``row_count`` is not a non-negative integer.

        .. versionchanged:: 8.0.12
           The usage of ``offset`` was deprecated.
        """
        if not isinstance(row_count, int) or row_count < 0:
            raise ValueError("The 'row_count' value must be a positive integer")
        if not self.has_limit:
            # Adding a limit changes the statement shape: before the first
            # execution it is simply marked as changed; after that the
            # statement is flagged for deallocate + prepare + execute.
            self._changed = self._exec_counter == 0
            self._deallocate_prepare_execute = self._exec_counter != 0

        self._limit_row_count = row_count
        self.has_limit = True
        if offset:
            self.offset(offset)
            warnings.warn(
                "'limit(row_count, offset)' is deprecated, please "
                "use 'offset(offset)' to set the number of items to "
                "skip",
                category=DeprecationWarning,
            )
        return self

    def offset(self, offset: int) -> FilterableStatement:
        """Sets the number of items to skip.

        Args:
            offset (int): The number of items to skip (zero or more).

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``offset`` is not a non-negative integer.

        .. versionadded:: 8.0.12
        """
        if not isinstance(offset, int) or offset < 0:
            raise ValueError("The 'offset' value must be a positive integer")
        self._limit_offset = offset
        return self

    def bind(self, *args: Any) -> FilterableStatement:
        """Binds value(s) to a specific placeholder(s).

        Args:
            *args: Either the name of the placeholder and the value to bind,
                or a single :class:`mysqlx.DbDoc`, dict or JSON string whose
                keys are bound to their values.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the number of arguments is invalid, or the
                single object cannot be bound.
        """
        count = len(args)
        if count == 1:
            self._bind_single(args[0])
        elif count == 2:
            self._bindings[args[0]] = args[1]
        else:
            raise ProgrammingError("Invalid number of arguments to bind")
        # Only flag bindings once the arguments were accepted.
        self.has_bindings = True
        return self

    def execute(self) -> Any:
        """Execute the statement.

        Raises:
            NotImplementedError: This method must be implemented.
        """
        raise NotImplementedError
|
|
|
class SqlStatement(Statement):
    """A statement for SQL execution.

    Args:
        connection (mysqlx.connection.Connection): Connection object.
        sql (string): The sql statement to be executed.
    """

    def __init__(self, connection: ConnectionType, sql: str) -> None:
        # There is no target object; the connection is supplied directly.
        super().__init__(target=None, doc_based=False)
        self._connection: ConnectionType = connection
        self._sql: str = sql
        self._binding_map: Optional[Dict[str, Any]] = None
        self._bindings: Union[List, Tuple] = []
        self.has_bindings: bool = False
        self.has_limit: bool = False

    @property
    def sql(self) -> str:
        """string: The SQL text statement."""
        return self._sql

    def get_binding_map(self) -> Dict[str, Any]:
        """Returns the binding map dictionary.

        Returns:
            dict: The binding map dictionary (always `None` for SQL
                statements, which use positional bindings).
        """
        return self._binding_map

    def get_bindings(self) -> Union[Tuple, List]:
        """Returns the bindings list.

        Returns:
            `list`: The bindings list.
        """
        return self._bindings

    def bind(self, *args: Any) -> SqlStatement:
        """Binds value(s) to the statement's positional placeholders.

        The values may be passed spread out (``bind(1, 2)``) or as a single
        list/tuple (``bind([1, 2])``). Each call replaces any values bound by
        a previous call.

        Args:
            *args: The value(s) to bind.

        Returns:
            mysqlx.SqlStatement: SqlStatement object.

        Raises:
            ProgrammingError: If no value is given.
        """
        if len(args) == 0:
            raise ProgrammingError("Invalid number of arguments to bind")
        self.has_bindings = True
        # flexible_params() always returns a list or tuple here, so the
        # bindings are replaced wholesale (the former append branch was
        # unreachable and would have failed on a tuple).
        self._bindings = flexible_params(*args)
        return self

    def execute(self) -> SqlResult:
        """Execute the statement.

        Returns:
            mysqlx.SqlResult: SqlResult object.
        """
        return self._connection.send_sql(self)
|
|
|
class WriteStatement(Statement):
    """Base class for statements that write values to their target.

    Subclasses append the rows or documents to write to ``self._values``.

    Args:
        target (object): The target database object.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target: DatabaseTargetType, doc_based: bool) -> None:
        super().__init__(target, doc_based)
        # Starts empty; subclasses fill it through their builder methods.
        self._values: List[
            Union[
                int,
                str,
                DbDoc,
                Dict[str, Any],
                List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]],
            ]
        ] = []

    def get_values(
        self,
    ) -> List[
        Union[
            int,
            str,
            DbDoc,
            Dict[str, Any],
            List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]],
        ]
    ]:
        """Give access to the accumulated values.

        Returns:
            `list`: The values list itself (not a copy).
        """
        return self._values

    def execute(self) -> Any:
        """Execute the statement; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
class AddStatement(WriteStatement):
    """Statement that adds documents to a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
    """

    def __init__(self, collection: DatabaseTargetType) -> None:
        super().__init__(collection, True)
        self._upsert: bool = False
        self.ids: List = []

    def is_upsert(self) -> bool:
        """Tell whether the upsert flag is set.

        Returns:
            bool: `True` if it's an upsert.
        """
        return self._upsert

    def upsert(self, value: bool = True) -> AddStatement:
        """Set the upsert flag to ``value``.

        When the flag is set, matched rows/documents are updated with the
        provided value instead of being rejected.

        Args:
            value (optional[bool]): Set or unset the upsert flag.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        self._upsert = value
        return self

    def add(self, *values: DbDoc) -> AddStatement:
        """Queue documents to be added to the collection.

        Items that are not already :class:`mysqlx.DbDoc` instances are wrapped
        in one before being stored.

        Args:
            *values: The documents, spread out or as a single list/tuple.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        for doc in flexible_params(*values):
            self._values.append(doc if isinstance(doc, DbDoc) else DbDoc(doc))
        return self

    def execute(self) -> Result:
        """Send the queued documents.

        Returns:
            mysqlx.Result: Result object; an empty ``Result()`` without any
                server round-trip when nothing was queued.
        """
        if not self._values:
            return Result()
        return self._connection.send_insert(self)
|
|
|
class UpdateSpec:
    """A single update operation: its type, source and value.

    A ``SET`` operation is delegated to :meth:`_table_set`, which parses
    ``source`` as a table update field. Every other type parses ``source``
    as a document field.

    Args:
        update_type (int): The update type.
        source (str): The source.
        value (Optional[str]): The value.

    Raises:
        ProgrammingError: If `source` is not a valid document field
            (non-``SET`` operations only).
    """

    def __init__(self, update_type: int, source: str, value: Any = None) -> None:
        set_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        if update_type == set_type:
            self._table_set(source, value)
            return

        self.update_type: int = update_type
        try:
            self.source: Any = ExprParser(source, False).document_field().identifier
        except ValueError as err:
            raise ProgrammingError(f"{err}") from err
        self.value: Any = value

    def _table_set(self, source: str, value: Any) -> None:
        """Populate this spec as a table ``SET`` operation.

        Args:
            source (str): The column expression to parse.
            value (str): The value.
        """
        self.update_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        self.source = ExprParser(source, True).parse_table_update_field()
        self.value = value
|
|
|
class ModifyStatement(FilterableStatement):
    """Statement that updates documents of a collection.

    Update operations are stored in a dict keyed by document path, or by
    ``"patch"`` for :meth:`patch`. The dict is exposed through
    :meth:`get_update_ops`.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
            to be modified.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter is now mandatory.
    """

    def __init__(self, collection: DatabaseTargetType, condition: str) -> None:
        super().__init__(target=collection, condition=condition)
        self._update_ops: Dict[str, Any] = {}

    def _register_update(
        self, key: str, update_type_name: str, source: str, *value: Any
    ) -> None:
        """Build an UpdateSpec for ``source`` and store it under ``key``.

        NOTE(review): because operations are keyed, registering a second
        operation under the same key (for example two ``array_append`` calls
        on one path) replaces the first one; confirm that this is intended.
        """
        self._update_ops[key] = UpdateSpec(
            mysqlxpb_enum(update_type_name), source, *value
        )

    def sort(self, *clauses: str) -> ModifyStatement:
        """Set the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._sort(*clauses)

    def get_update_ops(self) -> Dict[str, Any]:
        """Give access to the registered update operations.

        Returns:
            dict: Update operations keyed by document path (or ``"patch"``).
        """
        return self._update_ops

    def set(self, doc_path: str, value: Any) -> ModifyStatement:
        """Set or update an attribute on the matched documents.

        Args:
            doc_path (string): The document path of the item to be set.
            value (string): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._register_update(
            doc_path,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_SET",
            doc_path,
            value,
        )
        self._changed = True
        return self

    @deprecated("8.0.12")
    def change(self, doc_path: str, value: Any) -> ModifyStatement:
        """Replace the field at ``doc_path`` with ``value`` if it exists.

        Args:
            doc_path (string): The document path of the item to be set.
            value (object): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        .. deprecated:: 8.0.12
        """
        self._register_update(
            doc_path,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REPLACE",
            doc_path,
            value,
        )
        self._changed = True
        return self

    def unset(self, *doc_paths: str) -> ModifyStatement:
        """Remove attributes from the matched documents.

        Args:
            doc_paths (list): The document paths of the attributes to remove,
                spread out or as a single list/tuple.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        for path in flexible_params(*doc_paths):
            self._register_update(
                path, "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REMOVE", path
            )
        self._changed = True
        return self

    def array_insert(self, field: str, value: Any) -> ModifyStatement:
        """Insert a value into an array attribute of the matched documents.

        Args:
            field (string): A document path that identifies the array attribute
                and position where the value will be inserted.
            value (object): The value to be inserted.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._register_update(
            field,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_INSERT",
            field,
            value,
        )
        self._changed = True
        return self

    def array_append(self, doc_path: str, value: Any) -> ModifyStatement:
        """Append a value to an array attribute of the matched documents.

        Args:
            doc_path (string): A document path that identifies the array
                attribute that receives the value.
            value (object): The value to be appended.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._register_update(
            doc_path,
            "Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_APPEND",
            doc_path,
            value,
        )
        self._changed = True
        return self

    def patch(self, doc: Union[Dict, DbDoc, ExprParser, str]) -> ModifyStatement:
        """Merge-patch the matched documents with ``doc``.

        Args:
            doc (object): A generic document (DbDoc), string in JSON format,
                dict or ExprParser with the changes to apply. `None` is
                treated as an empty string.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        Raises:
            ProgrammingError: If ``doc`` has an unsupported type.
        """
        if doc is None:
            doc = ""
        if not isinstance(doc, (ExprParser, dict, DbDoc, str)):
            raise ProgrammingError(
                "Invalid data for update operation on document collection table"
            )
        patch_value = doc.expr() if isinstance(doc, ExprParser) else doc
        self._register_update(
            "patch",
            "Mysqlx.Crud.UpdateOperation.UpdateType.MERGE_PATCH",
            "$",
            patch_value,
        )
        self._changed = True
        return self

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If no condition was set.
        """
        if not self.has_where:
            raise ProgrammingError("No condition was found for modify")
        return self._connection.send_update(self)
|
|
|
class ReadStatement(FilterableStatement):
    """Provide base functionality for Read operations.

    Args:
        target (object): The target database object, it can be
            :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
            (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
            documents or records.
    """

    def __init__(
        self,
        target: DatabaseTargetType,
        doc_based: bool = True,
        condition: Optional[str] = None,
    ) -> None:
        super().__init__(target, doc_based, condition)
        self._lock_exclusive: bool = False
        self._lock_shared: bool = False
        self._lock_contention: LockContention = LockContention.DEFAULT

    @property
    def lock_contention(self) -> LockContention:
        """:class:`mysqlx.LockContention`: The lock contention value."""
        return self._lock_contention

    def _set_lock_contention(self, lock_contention: LockContention) -> None:
        """Set the lock contention.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Raises:
            ProgrammingError: If ``lock_contention`` is not a valid
                :class:`mysqlx.LockContention` value (including objects that
                have no ``value`` attribute at all, such as plain strings).
        """
        try:
            # Round-trip through the enum to reject unknown values.
            LockContention(lock_contention.value)
        except (AttributeError, ValueError) as err:
            raise ProgrammingError(
                "Invalid lock contention mode. Use 'NOWAIT' or 'SKIP_LOCKED'"
            ) from err
        self._lock_contention = lock_contention

    def is_lock_exclusive(self) -> bool:
        """Returns `True` if is `EXCLUSIVE LOCK`.

        Returns:
            bool: `True` if is `EXCLUSIVE LOCK`.
        """
        return self._lock_exclusive

    def is_lock_shared(self) -> bool:
        """Returns `True` if is `SHARED LOCK`.

        Returns:
            bool: `True` if is `SHARED LOCK`.
        """
        return self._lock_shared

    def lock_shared(
        self, lock_contention: LockContention = LockContention.DEFAULT
    ) -> ReadStatement:
        """Execute a read operation with `SHARED LOCK`. Only one lock can be
        active at a time.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.

        Raises:
            ProgrammingError: If ``lock_contention`` is invalid; the current
                lock mode is left unchanged in that case.
        """
        # Validate first so an invalid value does not switch the lock mode.
        self._set_lock_contention(lock_contention)
        self._lock_exclusive = False
        self._lock_shared = True
        return self

    def lock_exclusive(
        self, lock_contention: LockContention = LockContention.DEFAULT
    ) -> ReadStatement:
        """Execute a read operation with `EXCLUSIVE LOCK`. Only one lock can be
        active at a time.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.

        Raises:
            ProgrammingError: If ``lock_contention`` is invalid; the current
                lock mode is left unchanged in that case.
        """
        # Validate first so an invalid value does not switch the lock mode.
        self._set_lock_contention(lock_contention)
        self._lock_exclusive = True
        self._lock_shared = False
        return self

    def group_by(self, *fields: str) -> ReadStatement:
        """Sets a grouping criteria for the resultset.

        Args:
            *fields: The string expressions identifying the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_group_by(*fields)
        return self

    def having(self, condition: str) -> ReadStatement:
        """Sets a condition for records to be considered in aggregate function
        operations.

        Args:
            condition (string): A condition on the aggregate functions used on
                the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_having(condition)
        return self

    def execute(self) -> Union[DocResult, RowResult]:
        """Execute the statement.

        Returns:
            mysqlx.DocResult or mysqlx.RowResult: The result returned by the
                connection's ``send_find``.
        """
        return self._connection.send_find(self)
|
|
|
class FindStatement(ReadStatement):
    """Statement that retrieves documents from a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (Optional[str]): Optional expression selecting the documents
            to retrieve. Without it every document is included, unless a
            limit is set.
    """

    def __init__(
        self, collection: DatabaseTargetType, condition: Optional[str] = None
    ) -> None:
        super().__init__(collection, True, condition)

    def sort(self, *clauses: str) -> FindStatement:
        """Set the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        return self._sort(*clauses)

    def fields(self, *fields: str) -> FindStatement:
        """Restrict the result to the given document fields.

        Args:
            *fields: The string expressions identifying the fields to be
                extracted.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        return self._set_projection(*fields)
|
|
|
class SelectStatement(ReadStatement):
    """A statement for record retrieval operations on a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be retrieved.
    """

    def __init__(self, table: DatabaseTargetType, *fields: str) -> None:
        super().__init__(table, False)
        # Original text of the HAVING condition. The inherited state only
        # keeps the parsed expression, which cannot be rendered in get_sql().
        self._having_str: str = ""
        self._set_projection(*fields)

    def where(self, condition: str) -> SelectStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        return self._set_where(condition)

    def order_by(self, *clauses: str) -> SelectStatement:
        """Sets the order by criteria.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        return self._sort(*clauses)

    def having(self, condition: str) -> SelectStatement:
        """Sets a condition for records to be considered in aggregate function
        operations, remembering its text for :meth:`get_sql`.

        Args:
            condition (string): A condition on the aggregate functions used on
                the grouping criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        super().having(condition)
        # Recorded only after the condition parsed successfully.
        self._having_str = condition
        return self

    def get_sql(self) -> str:
        """Returns the generated SQL.

        NOTE(review): schema and table names are interpolated unquoted.

        Returns:
            str: The generated SQL.
        """
        where = f" WHERE {self._where_str}" if self.has_where else ""
        group_by = f" GROUP BY {self._grouping_str}" if self.has_group_by else ""
        # Use the condition text; self._having is the parsed expression.
        having = f" HAVING {self._having_str}" if self.has_having else ""
        order_by = f" ORDER BY {self._sort_str}" if self.has_sort else ""
        limit = (
            f" LIMIT {self._limit_row_count} OFFSET {self._limit_offset}"
            if self.has_limit
            else ""
        )
        stmt = (
            f"SELECT {self._projection_str or '*'} "
            f"FROM {self.schema.name}.{self.target.name}"
            f"{where}{group_by}{having}{order_by}{limit}"
        )
        return stmt
|
|
|
class InsertStatement(WriteStatement):
    """Statement that inserts rows into a table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The column names, spread out or as a single list/tuple.
    """

    def __init__(self, table: DatabaseTargetType, *fields: Any) -> None:
        super().__init__(table, False)
        self._fields: Union[List, Tuple] = flexible_params(*fields)

    def values(self, *values: Any) -> InsertStatement:
        """Queue one row of column values.

        Args:
            *values: The row's values, spread out or as a single list/tuple;
                they are copied into a new list.

        Returns:
            mysqlx.InsertStatement: InsertStatement object.
        """
        row = list(flexible_params(*values))
        self._values.append(row)
        return self

    def execute(self) -> Result:
        """Send the queued rows.

        Returns:
            mysqlx.Result: Result object.
        """
        return self._connection.send_insert(self)
|
|
|
class UpdateStatement(FilterableStatement):
    """A statement for record update operations on a Table.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``fields`` parameters were removed.
    """

    def __init__(self, table: DatabaseTargetType) -> None:
        super().__init__(target=table, doc_based=False)
        # Pending SET operations keyed by column name; calling ``set`` again
        # for the same column replaces the earlier operation.
        self._update_ops: Dict[str, Any] = {}

    def where(self, condition: str) -> UpdateStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        return self._set_where(condition)

    def order_by(self, *clauses: str) -> UpdateStatement:
        """Sets the order by criteria.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        return self._sort(*clauses)

    def get_update_ops(self) -> Dict[str, Any]:
        """Returns the update operations registered through :meth:`set`.

        Returns:
            `dict`: Mapping of column name to its update operation
            (an ``UpdateSpec``).
        """
        return self._update_ops

    def set(self, field: str, value: Any) -> UpdateStatement:
        """Updates the column value on records in a table.

        Args:
            field (str): The column name to be updated.
            value (object): The value to be set on the specified column.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        self._update_ops[field] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET"),
            field,
            value,
        )
        # Mark the statement as modified (presumably so a previously prepared
        # version is not reused - TODO confirm against Statement).
        self._changed = True
        return self

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object

        Raises:
            ProgrammingError: If condition was not set.
        """
        # An update without a condition is rejected client-side.
        if not self.has_where:
            raise ProgrammingError("No condition was found for update")
        return self._connection.send_update(self)
|
|
|
class RemoveStatement(FilterableStatement):
    """A statement for document removal from a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be removed.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was added.
    """

    def __init__(self, collection: DatabaseTargetType, condition: str) -> None:
        # The condition is handed to the base class, which presumably applies
        # it as the WHERE filter (checked via ``has_where`` in ``execute``).
        super().__init__(target=collection, condition=condition)

    def sort(self, *clauses: str) -> RemoveStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.RemoveStatement: RemoveStatement object.
        """
        return self._sort(*clauses)

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        # A remove without a condition is rejected client-side.
        if not self.has_where:
            raise ProgrammingError("No condition was found for remove")
        return self._connection.send_delete(self)
|
|
|
class DeleteStatement(FilterableStatement):
    """A statement that deletes records from a table.

    A search condition must be set with :meth:`where` before the statement
    can be executed.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was removed.
    """

    def __init__(self, table: DatabaseTargetType) -> None:
        super().__init__(target=table, doc_based=False)

    def where(self, condition: str) -> DeleteStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        return self._set_where(condition)

    def order_by(self, *clauses: str) -> DeleteStatement:
        """Sets the order by criteria.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        return self._sort(*clauses)

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        # A delete without a condition is rejected client-side.
        if not self.has_where:
            raise ProgrammingError("No condition was found for delete")
        return self._connection.send_delete(self)
|
|
|
class CreateCollectionIndexStatement(Statement):
    """A statement that creates an index on a collection.

    Args:
        collection (mysqlx.Collection): Collection.
        index_name (string): Index name.
        index_desc (dict): A dictionary containing the fields members that
                           constraints the index to be created. It must have
                           the form as shown in the following::

                               {"fields": [{"field": member_path,
                                            "type": member_type,
                                            "required": member_required,
                                            "array": member_array,
                                            "collation": collation,
                                            "options": options,
                                            "srid": srid},
                                            # {... more members,
                                            #      repeated as many times
                                            #      as needed}
                                            ],
                                "type": type}

                           The index ``"type"`` defaults to ``"INDEX"``. A
                           truthy ``"unique"`` member is rejected with
                           NotSupportedError. ``"required"`` and ``"array"``
                           default to ``False`` and must be booleans.
                           ``"collation"`` is only allowed for ``TEXT``
                           field types; ``"options"`` and ``"srid"`` are
                           only allowed for ``GEOJSON`` field types.
    """

    def __init__(
        self,
        collection: DatabaseTargetType,
        index_name: str,
        index_desc: Dict[str, Any],
    ) -> None:
        super().__init__(target=collection)
        # Deep copy so the caller's description is never mutated.
        self._index_desc: Dict[str, Any] = copy.deepcopy(index_desc)
        self._index_name: str = index_name
        self._fields_desc: List[Dict[str, Any]] = self._index_desc.pop("fields", [])

    def _validate_index_name(self) -> None:
        """Raise ProgrammingError unless the index name parses as an identifier."""
        if self._index_name is None:
            raise ProgrammingError(ERR_INVALID_INDEX_NAME.format(self._index_name))
        try:
            parsed_ident = ExprParser(self._index_name).expr().get_message()
            # The message is type dict when the Protobuf cext is used,
            # otherwise it is a message object with attributes.
            if isinstance(parsed_ident, dict):
                parsed_type = parsed_ident["type"]
            else:
                parsed_type = parsed_ident.type
            is_ident = parsed_type == mysqlxpb_enum("Mysqlx.Expr.Expr.Type.IDENT")
        except (ValueError, AttributeError) as err:
            raise ProgrammingError(
                ERR_INVALID_INDEX_NAME.format(self._index_name)
            ) from err
        if not is_ident:
            raise ProgrammingError(ERR_INVALID_INDEX_NAME.format(self._index_name))

    @staticmethod
    def _build_constraint(
        index_type: str, field_desc: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Consume the members of one field description into a constraint.

        Recognized members are popped from ``field_desc``; whatever remains
        afterwards is reported by the caller as unidentified.

        Args:
            index_type (str): The index type (e.g. ``"INDEX"``, ``"SPATIAL"``).
            field_desc (dict): One entry of the ``"fields"`` list; mutated.

        Returns:
            dict: The constraint passed to the server for this field.

        Raises:
            KeyError: If the ``"field"`` or ``"type"`` member is missing.
            TypeError: If ``"required"`` or ``"array"`` is not a bool.
            ProgrammingError: If the members are inconsistent with the
                index or field type.
        """
        constraint: Dict[str, Any] = {
            "member": field_desc.pop("field"),
            "type": field_desc.pop("type"),
            "required": field_desc.pop("required", False),
            "array": field_desc.pop("array", False),
        }
        if not isinstance(constraint["required"], bool):
            raise TypeError("Field member 'required' must be Boolean")
        if not isinstance(constraint["array"], bool):
            raise TypeError("Field member 'array' must be Boolean")

        if index_type.upper() == "SPATIAL" and not constraint["required"]:
            raise ProgrammingError(
                "Field member 'required' must be set to 'True' when "
                "index type is set to 'SPATIAL'"
            )
        # Compare the field type case-insensitively, as the 'options' and
        # 'srid' checks below do, so that e.g. "geojson" is caught too.
        if (
            index_type.upper() == "INDEX"
            and constraint["type"].upper() == "GEOJSON"
        ):
            raise ProgrammingError(
                "Index 'type' must be set to 'SPATIAL' when field "
                "type is set to 'GEOJSON'"
            )

        if "collation" in field_desc:
            if not constraint["type"].upper().startswith("TEXT"):
                raise ProgrammingError(
                    "The 'collation' member can only be used when "
                    "field type is set to 'TEXT', not "
                    f"'{constraint['type'].upper()}'"
                )
            constraint["collation"] = field_desc.pop("collation")

        # "options" and "srid" fields in IndexField can be
        # present only if the field "type" is set to "GEOJSON"
        if "options" in field_desc:
            if constraint["type"].upper() != "GEOJSON":
                raise ProgrammingError(
                    "The 'options' member can only be used when "
                    "field type is set to 'GEOJSON'"
                )
            constraint["options"] = field_desc.pop("options")
        if "srid" in field_desc:
            if constraint["type"].upper() != "GEOJSON":
                raise ProgrammingError(
                    "The 'srid' member can only be used when field "
                    "type is set to 'GEOJSON'"
                )
            constraint["srid"] = field_desc.pop("srid")
        return constraint

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If the index name is not a valid identifier,
                the description is malformed, or it has unknown members.
            NotSupportedError: If a unique index is requested.
            TypeError: If a field's ``required`` or ``array`` member is not
                a bool.
        """
        self._validate_index_name()

        # Validate members that constraint the index
        if not self._fields_desc:
            raise ProgrammingError(
                "Required member 'fields' not found in the given index "
                f"description: {self._index_desc}"
            )
        if not isinstance(self._fields_desc, list):
            raise ProgrammingError("Required member 'fields' must contain a list")

        # Validation consumes members with pop(); work on copies so the
        # statement's own state is left intact and execute() can be repeated.
        index_desc = copy.deepcopy(self._index_desc)
        fields_desc = copy.deepcopy(self._fields_desc)

        args: Dict[str, Any] = {
            "name": self._index_name,
            "collection": self._target.name,
            "schema": self._target.schema.name,
            "type": index_desc.pop("type", "INDEX"),
            "unique": index_desc.pop("unique", False),
        }
        # Currently unique indexes are not supported:
        if args["unique"]:
            raise NotSupportedError("Unique indexes are not supported.")
        args["constraint"] = []

        if index_desc:
            raise ProgrammingError(f"Unidentified fields: {index_desc}")

        for original_desc, field_desc in zip(self._fields_desc, fields_desc):
            try:
                constraint = self._build_constraint(args["type"], field_desc)
            except KeyError as err:
                # Report the description as the user wrote it, not the
                # partially consumed copy.
                raise ProgrammingError(
                    f"Required inner member {err} not found in constraint: "
                    f"{original_desc}"
                ) from err
            args["constraint"].append(constraint)

        # Anything not consumed by _build_constraint is unknown.
        for field_desc in fields_desc:
            if field_desc:
                raise ProgrammingError(f"Unidentified inner fields: {field_desc}")

        return self._connection.execute_nonquery(
            "mysqlx", "create_collection_index", True, args
        )
|