1+ import re
12import typing
23from collections import deque
34from itertools import count , islice
@@ -185,7 +186,7 @@ def executemany(self: "Cursor", operation, param_sets) -> "Cursor":
185186 self ._row_count = - 1 if - 1 in rowcounts else sum (rowcounts )
186187 return self
187188
188- def fetchone (self : "Cursor" ) -> typing .Optional ["Cursor" ]:
189+ def fetchone (self : "Cursor" ) -> typing .Optional [typing . List ]:
189190 """Fetch the next row of a query result set.
190191
191192 This method is part of the `DBAPI 2.0 specification
@@ -196,7 +197,7 @@ def fetchone(self: "Cursor") -> typing.Optional["Cursor"]:
196197 are available.
197198 """
198199 try :
199- return typing . cast ( "Cursor" , next (self ) )
200+ return next (self )
200201 except StopIteration :
201202 return None
202203 except TypeError :
@@ -271,7 +272,7 @@ def setoutputsize(self: "Cursor", size, column=None):
271272 """
272273 pass
273274
274- def __next__ (self : "Cursor" ):
275+ def __next__ (self : "Cursor" ) -> typing . List :
275276 try :
276277 return self ._cached_rows .popleft ()
277278 except IndexError :
@@ -311,16 +312,48 @@ def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> typing.
311312 return None
312313 return pandas .DataFrame (result , columns = columns )
313314
315+ def __is_valid_table (self : "Cursor" , table : str ) -> bool :
316+ split_table_name : typing .List [str ] = table .split ("." )
317+
318+ if len (split_table_name ) > 2 :
319+ return False
320+
321+ q : str = "select 1 from information_schema.tables where table_name = ?"
322+
323+ temp = self .paramstyle
324+ self .paramstyle = "qmark"
325+
326+ try :
327+ if len (split_table_name ) == 2 :
328+ q += " and table_schema = ?"
329+ self .execute (q , (split_table_name [1 ], split_table_name [0 ]))
330+ else :
331+ self .execute (q , (split_table_name [0 ],))
332+ except :
333+ raise
334+ finally :
335+ # reset paramstyle to it's original value
336+ self .paramstyle = temp
337+
338+ result = self .fetchone ()
339+
340+ return result [0 ] == 1 if result is not None else False
341+
314342 def write_dataframe (self : "Cursor" , df : "pandas.DataFrame" , table : str ) -> None :
315343 """write same structure dataframe into Redshift database"""
316344 try :
317345 import pandas
318346 except ModuleNotFoundError :
319347 raise ModuleNotFoundError (MISSING_MODULE_ERROR_MSG .format (module = "pandas" ))
320348
349+ if not self .__is_valid_table (table ):
350+ raise InterfaceError ("Invalid table name passed to write_dataframe: {}" .format (table ))
351+ sanitized_table_name : str = self .__sanitize_str (table )
321352 arrays : "numpy.ndarray" = df .values
322353 placeholder : str = ", " .join (["%s" ] * len (arrays [0 ]))
323- sql : str = "insert into {table} values ({placeholder})" .format (table = table , placeholder = placeholder )
354+ sql : str = "insert into {table} values ({placeholder})" .format (
355+ table = sanitized_table_name , placeholder = placeholder
356+ )
324357 if len (arrays ) == 1 :
325358 self .execute (sql , arrays [0 ])
326359 elif len (arrays ) > 1 :
@@ -361,16 +394,33 @@ def get_procedures(
361394 " LEFT JOIN pg_catalog.pg_namespace pn ON (c.relnamespace=pn.oid AND pn.nspname='pg_catalog') "
362395 " WHERE p.pronamespace=n.oid "
363396 )
397+ query_args : typing .List [str ] = []
364398 if schema_pattern is not None :
365- sql += " AND n.nspname LIKE {schema}" .format (schema = self .__escape_quotes (schema_pattern ))
399+ sql += " AND n.nspname LIKE ?"
400+ query_args .append (self .__sanitize_str (schema_pattern ))
366401 else :
367402 sql += "and pg_function_is_visible(p.prooid)"
368403
369404 if procedure_name_pattern is not None :
370- sql += " AND p.proname LIKE {procedure}" .format (procedure = self .__escape_quotes (procedure_name_pattern ))
405+ sql += " AND p.proname LIKE ?"
406+ query_args .append (self .__sanitize_str (procedure_name_pattern ))
371407 sql += " ORDER BY PROCEDURE_SCHEM, PROCEDURE_NAME, p.prooid::text "
372408
373- self .execute (sql )
409+ if len (query_args ) > 0 :
410+ # temporarily use qmark paramstyle
411+ temp = self .paramstyle
412+ self .paramstyle = "qmark"
413+
414+ try :
415+ self .execute (sql , tuple (query_args ))
416+ except :
417+ raise
418+ finally :
419+ # reset the original value of paramstyle
420+ self .paramstyle = temp
421+ else :
422+ self .execute (sql )
423+
374424 procedures : tuple = self .fetchall ()
375425 return procedures
376426
@@ -383,11 +433,25 @@ def get_schemas(
383433 " OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_' "
384434 " OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_')) "
385435 )
436+ query_args : typing .List [str ] = []
386437 if schema_pattern is not None :
387- sql += " AND nspname LIKE {schema}" .format (schema = self .__escape_quotes (schema_pattern ))
438+ sql += " AND nspname LIKE ?"
439+ query_args .append (self .__sanitize_str (schema_pattern ))
388440 sql += " ORDER BY TABLE_SCHEM"
389441
390- self .execute (sql )
442+ if len (query_args ) == 1 :
443+ # temporarily use qmark paramstyle
444+ temp = self .paramstyle
445+ self .paramstyle = "qmark"
446+ try :
447+ self .execute (sql , tuple (query_args ))
448+ except :
449+ raise
450+ finally :
451+ self .paramstyle = temp
452+ else :
453+ self .execute (sql )
454+
391455 schemas : tuple = self .fetchall ()
392456 return schemas
393457
@@ -418,13 +482,28 @@ def get_primary_keys(
418482 "i.indisprimary AND "
419483 "ct.relnamespace = n.oid "
420484 )
485+ query_args : typing .List [str ] = []
421486 if schema is not None :
422- sql += " AND n.nspname = {schema}" .format (schema = self .__escape_quotes (schema ))
487+ sql += " AND n.nspname = ?"
488+ query_args .append (self .__sanitize_str (schema ))
423489 if table is not None :
424- sql += " AND ct.relname = {table}" .format (table = self .__escape_quotes (table ))
490+ sql += " AND ct.relname = ?"
491+ query_args .append (self .__sanitize_str (table ))
425492
426493 sql += " ORDER BY table_name, pk_name, key_seq"
427- self .execute (sql )
494+
495+ if len (query_args ) > 0 :
496+ # temporarily use qmark paramstyle
497+ temp = self .paramstyle
498+ self .paramstyle = "qmark"
499+ try :
500+ self .execute (sql , tuple (query_args ))
501+ except :
502+ raise
503+ finally :
504+ self .paramstyle = temp
505+ else :
506+ self .execute (sql )
428507 keys : tuple = self .fetchall ()
429508 return keys
430509
@@ -437,15 +516,30 @@ def get_tables(
437516 ) -> tuple :
438517 """Returns the unique public tables which are user-defined within the system"""
439518 sql : str = ""
519+ sql_args : typing .Tuple [str , ...] = tuple ()
440520 schema_pattern_type : str = self .__schema_pattern_match (schema_pattern )
441521 if schema_pattern_type == "LOCAL_SCHEMA_QUERY" :
442- sql = self .__build_local_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
522+ sql , sql_args = self .__build_local_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
443523 elif schema_pattern_type == "NO_SCHEMA_UNIVERSAL_QUERY" :
444- sql = self .__build_universal_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
524+ sql , sql_args = self .__build_universal_schema_tables_query (
525+ catalog , schema_pattern , table_name_pattern , types
526+ )
445527 elif schema_pattern_type == "EXTERNAL_SCHEMA_QUERY" :
446- sql = self .__build_external_schema_tables_query (catalog , schema_pattern , table_name_pattern , types )
528+ sql , sql_args = self .__build_external_schema_tables_query (
529+ catalog , schema_pattern , table_name_pattern , types
530+ )
447531
448- self .execute (sql )
532+ if len (sql_args ) > 0 :
533+ temp = self .paramstyle
534+ self .paramstyle = "qmark"
535+ try :
536+ self .execute (sql , sql_args )
537+ except :
538+ raise
539+ finally :
540+ self .paramstyle = temp
541+ else :
542+ self .execute (sql )
449543 tables : tuple = self .fetchall ()
450544 return tables
451545
@@ -455,7 +549,7 @@ def __build_local_schema_tables_query(
455549 schema_pattern : typing .Optional [str ],
456550 table_name_pattern : typing .Optional [str ],
457551 types : list ,
458- ) -> str :
552+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
459553 sql : str = (
460554 "SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT, n.nspname AS TABLE_SCHEM, c.relname AS TABLE_NAME, "
461555 " CASE n.nspname ~ '^pg_' OR n.nspname = 'information_schema' "
@@ -502,32 +596,41 @@ def __build_local_schema_tables_query(
502596 " LEFT JOIN pg_catalog.pg_namespace dn ON (dn.oid=dc.relnamespace AND dn.nspname='pg_catalog') "
503597 " WHERE c.relnamespace = n.oid "
504598 )
505- filter_clause : str = self .__get_table_filter_clause (
599+ filter_clause , filter_args = self .__get_table_filter_clause (
506600 catalog , schema_pattern , table_name_pattern , types , "LOCAL_SCHEMA_QUERY"
507601 )
508602 orderby : str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
509603
510- return sql + filter_clause + orderby
604+ return sql + filter_clause + orderby , filter_args
511605
512606 def __get_table_filter_clause (
513607 self : "Cursor" ,
514608 catalog : typing .Optional [str ],
515609 schema_pattern : typing .Optional [str ],
516610 table_name_pattern : typing .Optional [str ],
517- types : list ,
611+ types : typing . List [ str ] ,
518612 schema_pattern_type : str ,
519- ) -> str :
613+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
520614 filter_clause : str = ""
521615 use_schemas : str = "SCHEMAS"
616+ query_args : typing .List [str ] = []
522617 if schema_pattern is not None :
523- filter_clause += " AND TABLE_SCHEM LIKE {schema}" .format (schema = self .__escape_quotes (schema_pattern ))
618+ filter_clause += " AND TABLE_SCHEM LIKE ?"
619+ query_args .append (self .__sanitize_str (schema_pattern ))
524620 if table_name_pattern is not None :
525- filter_clause += " AND TABLE_NAME LIKE {table}" .format (table = self .__escape_quotes (table_name_pattern ))
621+ filter_clause += " AND TABLE_NAME LIKE ?"
622+ query_args .append (self .__sanitize_str (table_name_pattern ))
526623 if len (types ) > 0 :
527624 if schema_pattern_type == "LOCAL_SCHEMA_QUERY" :
528625 filter_clause += " AND (false "
529626 orclause : str = ""
530627 for type in types :
628+ if type not in table_type_clauses .keys ():
629+ raise InterfaceError (
630+ "Invalid type: {} provided. types may only contain: {}" .format (
631+ type , table_type_clauses .keys ()
632+ )
633+ )
531634 clauses = table_type_clauses [type ]
532635 if len (clauses ) > 0 :
533636 cluase = clauses [use_schemas ]
@@ -538,21 +641,28 @@ def __get_table_filter_clause(
538641 filter_clause += " AND TABLE_TYPE IN ( "
539642 length = len (types )
540643 for type in types :
541- filter_clause += self .__escape_quotes (type )
644+ if type not in table_type_clauses .keys ():
645+ raise InterfaceError (
646+ "Invalid type: {} provided. types may only contain: {}" .format (
647+ type , table_type_clauses .keys ()
648+ )
649+ )
650+ filter_clause += "?"
651+ query_args .append (self .__sanitize_str (type ))
542652 length -= 1
543653 if length > 0 :
544654 filter_clause += ", "
545655 filter_clause += ") "
546656
547- return filter_clause
657+ return filter_clause , tuple ( query_args )
548658
549659 def __build_universal_schema_tables_query (
550660 self : "Cursor" ,
551661 catalog : typing .Optional [str ],
552662 schema_pattern : typing .Optional [str ],
553663 table_name_pattern : typing .Optional [str ],
554664 types : list ,
555- ) -> str :
665+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
556666 sql : str = (
557667 "SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
558668 " table_schema AS TABLE_SCHEM,"
@@ -583,20 +693,20 @@ def __build_universal_schema_tables_query(
583693 " FROM svv_tables)"
584694 " WHERE true "
585695 )
586- filter_clause : str = self .__get_table_filter_clause (
696+ filter_clause , filter_args = self .__get_table_filter_clause (
587697 catalog , schema_pattern , table_name_pattern , types , "NO_SCHEMA_UNIVERSAL_QUERY"
588698 )
589699 orderby : str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
590700 sql += filter_clause + orderby
591- return sql
701+ return sql , filter_args
592702
593703 def __build_external_schema_tables_query (
594704 self : "Cursor" ,
595705 catalog : typing .Optional [str ],
596706 schema_pattern : typing .Optional [str ],
597707 table_name_pattern : typing .Optional [str ],
598708 types : list ,
599- ) -> str :
709+ ) -> typing . Tuple [ str , typing . Tuple [ str , ...]] :
600710 sql : str = (
601711 "SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
602712 " schemaname AS table_schem,"
@@ -611,12 +721,12 @@ def __build_external_schema_tables_query(
611721 " FROM svv_external_tables)"
612722 " WHERE true "
613723 )
614- filter_clause : str = self .__get_table_filter_clause (
724+ filter_clause , filter_args = self .__get_table_filter_clause (
615725 catalog , schema_pattern , table_name_pattern , types , "EXTERNAL_SCHEMA_QUERY"
616726 )
617727 orderby : str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
618728 sql += filter_clause + orderby
619- return sql
729+ return sql , filter_args
620730
621731 def get_columns (
622732 self : "Cursor" ,
@@ -1477,5 +1587,8 @@ def __schema_pattern_match(self: "Cursor", schema_pattern: typing.Optional[str])
14771587 else :
14781588 return "NO_SCHEMA_UNIVERSAL_QUERY"
14791589
1590+ def __sanitize_str (self : "Cursor" , s : str ) -> str :
1591+ return re .sub (r"[-;/'\"\n\r ]" , "" , s )
1592+
14801593 def __escape_quotes (self : "Cursor" , s : str ) -> str :
1481- return "'{s}'" .format (s = s )
1594+ return "'{s}'" .format (s = self . __sanitize_str ( s ) )
0 commit comments