diff --git a/test/cluster/test_audit.py b/test/cluster/test_audit.py index 37dc3b3eb5..48cca7b363 100644 --- a/test/cluster/test_audit.py +++ b/test/cluster/test_audit.py @@ -29,7 +29,7 @@ import pytest from cassandra import AlreadyExists, AuthenticationFailed, ConsistencyLevel, InvalidRequest, Unauthorized, Unavailable, WriteFailure from cassandra.auth import PlainTextAuthProvider from cassandra.cluster import NoHostAvailable, Session, EXEC_PROFILE_DEFAULT -from cassandra.query import SimpleStatement, named_tuple_factory +from cassandra.query import BatchStatement, BatchType, SimpleStatement, named_tuple_factory from test.cluster.dtest.dtest_class import create_ks, wait_for from test.cluster.dtest.tools.assertions import assert_invalid @@ -1609,6 +1609,55 @@ class CQLAuditTester(AuditTester): with self.assert_entries_were_added(session, expected_entries, merge_duplicate_rows=False): session.execute(batch_query) + async def _test_batch_native_protocol(self, helper_class): + """ + Native protocol BATCH message (as opposed to CQL text batch). + + Reproducer for a bug where batches sent via the native + protocol BATCH message were not audited. The driver's BatchStatement + sends a native-protocol BATCH (opcode 0x0D) which is handled by + process_batch_internal in transport/server.cc — a different code path + from a textual BEGIN BATCH … APPLY BATCH sent as a QUERY message. + """ + with helper_class() as helper: + session = await self.prepare(helper=helper) + + session.execute( + """ + CREATE TABLE test_batch_native ( + pk int PRIMARY KEY, + v text + ) + """ + ) + + # Unprepared native-protocol batch (SimpleStatement inside BatchStatement) + batch = BatchStatement(batch_type=BatchType.UNLOGGED) + batch.add(SimpleStatement("INSERT INTO test_batch_native (pk, v) VALUES (%s, %s)"), (1, "val1")) + batch.add(SimpleStatement("INSERT INTO test_batch_native (pk, v) VALUES (%s, %s)"), (2, "val2")) + + expected_entries = [ + AuditEntry(category="DML", statement="INSERT INTO test_batch_native (pk, v) VALUES (1, 'val1')", table="test_batch_native", ks="ks", user="anonymous", cl="ONE", error=False), + AuditEntry(category="DML", statement="INSERT INTO test_batch_native (pk, v) VALUES (2, 'val2')", table="test_batch_native", ks="ks", user="anonymous", cl="ONE", error=False), + ] + + with self.assert_entries_were_added(session, expected_entries, merge_duplicate_rows=False): + session.execute(batch) + + # Prepared native-protocol batch + prepared = session.prepare("INSERT INTO test_batch_native (pk, v) VALUES (?, ?)") + batch_prepared = BatchStatement(batch_type=BatchType.UNLOGGED) + batch_prepared.add(prepared, (3, "val3")) + batch_prepared.add(prepared, (4, "val4")) + + expected_entries_prepared = [ + AuditEntry(category="DML", statement="INSERT INTO test_batch_native (pk, v) VALUES (?, ?)", table="test_batch_native", ks="ks", user="anonymous", cl="ONE", error=False), + AuditEntry(category="DML", statement="INSERT INTO test_batch_native (pk, v) VALUES (?, ?)", table="test_batch_native", ks="ks", user="anonymous", cl="ONE", error=False), + ] + + with self.assert_entries_were_added(session, expected_entries_prepared, merge_duplicate_rows=False): + session.execute(batch_prepared) + async def _test_service_level_statements(self): """ Test auditing service level statements - ones that use the ADMIN audit category. @@ -1833,6 +1882,7 @@ async def test_audit_table_noauth(manager: ManagerClient): await t._test_negative_audit_records_query() await t._test_prepare(AuditBackendTable) await t._test_batch(AuditBackendTable) + await t._test_batch_native_protocol(AuditBackendTable) # AuditBackendTable, auth (cassandra), rf=1 @@ -1916,6 +1966,7 @@ async def test_audit_syslog_noauth(manager: ManagerClient): await t._test_audit_categories_part3(Syslog) await t._test_prepare(Syslog) await t._test_batch(Syslog) + await t._test_batch_native_protocol(Syslog) # AuditBackendSyslog, auth, rf=1 @@ -1944,6 +1995,7 @@ async def test_audit_composite_noauth(manager: ManagerClient): await t._test_audit_categories_part3(Composite) await t._test_prepare(Composite) await t._test_batch(Composite) + await t._test_batch_native_protocol(Composite) # AuditBackendComposite, auth, rf=1