Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ my_bitflags! {
UnknownMariadbCapabilityFlags,
u32,

/// Mariadb client capability flags
/// MariaDB client capability flags
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
pub struct MariadbCapabilities: u32 {
/// Permits feedback during long-running operations
Expand Down Expand Up @@ -431,6 +431,20 @@ my_bitflags! {
}
}

my_bitflags! {
StmtBulkExecuteParamsFlags,
#[error("Unknown flags in the raw value of StmtBulkExecuteParamsFlags (raw={0:b})")]
UnknownStmtBulkExecuteParamsFlags,
u16,

/// MySql stmt execute params flags.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
pub struct StmtBulkExecuteParamsFlags: u16 {
const SEND_UNIT_RESULTS = 64_u16;
const SEND_TYPES_TO_SERVER = 128_u16;
}
}

my_bitflags! {
ColumnFlags,
#[error("Unknown flags in the raw value of ColumnFlags (raw={0:b})")]
Expand Down Expand Up @@ -528,6 +542,22 @@ pub enum Command {
COM_BINLOG_DUMP_GTID,
COM_RESET_CONNECTION,
COM_END,
COM_STMT_BULK_EXECUTE = 0xfa_u8,
}

/// MariaDB bulk execute parameter value indicators
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
#[repr(u8)]
pub enum MariadbBulkIndicator {
/// No special indicator, normal value
BULK_INDICATOR_NONE = 0x00_u8,
/// NULL value
BULK_INDICATOR_NULL = 0x01_u8,
/// For INSERT/UPDATE, value is default. Not used
BULK_INDICATOR_DEFAULT = 0x02_u8,
/// Value is default for insert, Is ignored for update. Not used.
BULK_INDICATOR_IGNORE = 0x03_u8,
}

/// Type of state change information (part of MySql's Ok packet).
Expand Down
191 changes: 186 additions & 5 deletions src/packets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ use std::{
};

use crate::collations::CollationId;
use crate::constants::StmtBulkExecuteParamsFlags;
use crate::scramble::create_response_for_ed25519;
use crate::{
constants::{
CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags,
StmtExecuteParamsFlags,
MariadbBulkIndicator, MariadbCapabilities, SessionStateType, StatusFlags,
StmtExecuteParamFlags, StmtExecuteParamsFlags,
},
io::{BufMutExt, ParseBuf},
misc::{
Expand Down Expand Up @@ -2762,6 +2763,182 @@ impl MySerialize for ComStmtClose {
}
}

/// Sends array of parameters to the server for the bulk execution of a prepared statement with
/// COM_STMT_BULK_EXECUTE command. This command is MariaDB only and may not be used for queries w/out
/// parameters and with empty parameter sets.
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtBulkExecuteRequestBuilder {
pub stmt_id: u32,
pub with_types: bool,
pub paramset: Vec<Vec<Value>>,
pub payload_len: usize,
pub max_payload_len: usize, /* max_allowed_packet(if known) - 4 */
}

impl ComStmtBulkExecuteRequestBuilder {
pub fn new(stmt_id: u32, max_payload: usize) -> Self {
Self {
stmt_id,
with_types: true,
paramset: Vec::new(),
payload_len: 0,
max_payload_len: max_payload,
}
}

// Resets the builder to start building a new bulk execute request. In particular - without types.
// If it's called - means that there is row to be added that did not fit previous packet. So, it should
// be always followed by add_row(). That is something it can do on its own.
pub fn next(&mut self, params: &Vec<Value>) -> () {
self.with_types = false;
self.paramset.clear();
self.payload_len = 0;
self.add_row(params);
}

// Adds a new row of parameters to the bulk execute request.
// Returns true if adding this row would exceed the max allowed packet size.
pub fn add_row(&mut self, params: &Vec<Value>) -> bool {
if self.with_types && self.payload_len == 0 {
self.payload_len = params.len() * 2;
}
let mut data_len = 0;
for p in params {
// bin_len() includes lenght encoding bytes
match p.bin_len() as usize {
0 => data_len += 1, // NULLs take 1 byte for the indicator
x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator
}
}
// 7 = 1(command id) + 4 (stmt_id) + 2 (flags). If it's 1st row - we take it to return error
// later(when the packet is sent). In this way we can avoid eternal loops of trying to add this row.
if 7 + self.payload_len + data_len > self.max_payload_len && !self.paramset.is_empty() {
return true;
}
self.paramset.push(params.to_vec());
self.payload_len += data_len;
false
}

pub fn has_rows(&self) -> bool {
!self.paramset.is_empty()
}

pub fn build(&self) -> ComStmtBulkExecuteRequest<'_> {
ComStmtBulkExecuteRequest {
com_stmt_bulk_execute: ConstU8::new(),
stmt_id: RawInt::new(self.stmt_id),
bulk_flags: if self.with_types {
Const::new(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER)
} else {
Const::new(StmtBulkExecuteParamsFlags::empty())
},
params: &self.paramset,
}
}
}

define_header!(
ComStmtBulkExecuteHeader,
COM_STMT_BULK_EXECUTE,
InvalidComStmtBulkExecuteHeader
);

#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtBulkExecuteRequest<'a> {
com_stmt_bulk_execute: ComStmtBulkExecuteHeader,
stmt_id: RawInt<LeU32>,
bulk_flags: Const<StmtBulkExecuteParamsFlags, LeU16>,
// max params / bits per byte = 8192
params: &'a Vec<Vec<Value>>,
}

impl<'a> ComStmtBulkExecuteRequest<'a> {
pub fn stmt_id(&self) -> u32 {
self.stmt_id.0
}

pub fn bulk_flags(&self) -> StmtBulkExecuteParamsFlags {
self.bulk_flags.0
}

pub fn params(&self) -> &[Vec<Value>] {
self.params.as_ref()
}
}

impl MySerialize for ComStmtBulkExecuteRequest<'_> {
fn serialize(&self, buf: &mut Vec<u8>) {
self.com_stmt_bulk_execute.serialize(&mut *buf);
self.stmt_id.serialize(&mut *buf);
self.bulk_flags.serialize(&mut *buf);

if self
.bulk_flags
.0
.contains(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER)
&& !self.params.is_empty()
{
for param in &self.params[0] {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. This will panic if SEND_TYPES_TO_SERVER is set and the params is empty that seems reachable from ComStmtBulkExecuteRequestBuilder 🤔

let (column_type, flags) = match param {
Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
Value::Bytes(_) => (
ColumnType::MYSQL_TYPE_VAR_STRING,
StmtExecuteParamFlags::empty(),
),
Value::Int(_) => (
ColumnType::MYSQL_TYPE_LONGLONG,
StmtExecuteParamFlags::empty(),
),
Value::UInt(_) => (
ColumnType::MYSQL_TYPE_LONGLONG,
StmtExecuteParamFlags::UNSIGNED,
),
Value::Float(_) => {
(ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty())
}
Value::Double(_) => (
ColumnType::MYSQL_TYPE_DOUBLE,
StmtExecuteParamFlags::empty(),
),
Value::Date(..) => (
ColumnType::MYSQL_TYPE_DATETIME,
StmtExecuteParamFlags::empty(),
),
Value::Time(..) => {
(ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty())
}
};
buf.put_slice(&[column_type as u8, flags.bits()]);
}
}

for row in self.params {
for param in row {
match param {
Value::Int(_)
| Value::UInt(_)
| Value::Float(_)
| Value::Double(_)
| Value::Date(..)
| Value::Time(..) => {
buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL
param.serialize(buf);
}
Value::Bytes(_) => {
buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL
param.serialize(buf);
}
Value::NULL => {
buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NULL as u8); // NULL indicator
}
}
}
}
}
}
// ------------------------------------------------------------------------------

define_header!(
ComRegisterSlaveHeader,
COM_REGISTER_SLAVE,
Expand Down Expand Up @@ -4129,7 +4306,7 @@ mod test {
fn should_parse_handshake_packet_with_mariadb_ext_capabilities() {
const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\
\x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x00\x00\x2a\x34\x64\
\x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";

let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
Expand All @@ -4150,6 +4327,7 @@ mod test {
assert_eq!(
hsp.mariadb_ext_capabilities(),
MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
| MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS
);
let mut output = Vec::new();
hsp.serialize(&mut output);
Expand All @@ -4169,7 +4347,10 @@ mod test {
None,
1_u32.to_be(),
)
.with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA);
.with_mariadb_ext_capabilities(
MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
| MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS,
);
let mut actual = Vec::new();
response.serialize(&mut actual);

Expand All @@ -4179,7 +4360,7 @@ mod test {
0x2d, // charset
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, // reserved
0x10, 0x00, 0x00, 0x00, // mariadb capabilities
0x14, 0x00, 0x00, 0x00, // mariadb capabilities
0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
0x00, // blank scramble
0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
Expand Down