Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions MssqlMcp/dotnet/MssqlMcp.Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,122 @@ public async Task SqlInjection_NotExecuted_When_QueryFails()
Assert.NotNull(describeResult);
Assert.True(describeResult.Success);
}

[Fact]
public async Task ReadOnlyMode_ListTables_ReturnsSuccess_WhenReadOnlyIsTrue()
{
// Set READONLY environment variable
Environment.SetEnvironmentVariable("READONLY", "true");
try
{
var result = await _tools.ListTables() as DbOperationResult;
Assert.NotNull(result);
Assert.True(result.Success);
Assert.NotNull(result.Data);
}
finally
{
// Clean up environment variable
Environment.SetEnvironmentVariable("READONLY", null);
}
}

[Fact]
public async Task ReadOnlyMode_CreateTable_ReturnsError_WhenReadOnlyIsTrue()
{
// Set READONLY environment variable
Environment.SetEnvironmentVariable("READONLY", "true");
try
{
var sql = $"CREATE TABLE {_tableName} (Id INT PRIMARY KEY)";
var result = await _tools.CreateTable(sql) as DbOperationResult;
Assert.NotNull(result);
Assert.False(result.Success);
Assert.Contains("CREATE TABLE operation is not allowed in READONLY mode", result.Error ?? string.Empty);
}
finally
{
// Clean up environment variable
Environment.SetEnvironmentVariable("READONLY", null);
}
}

[Fact]
public async Task ReadOnlyMode_InsertData_ReturnsError_WhenReadOnlyIsTrue()
{
// Set READONLY environment variable
Environment.SetEnvironmentVariable("READONLY", "true");
try
{
var sql = $"INSERT INTO {_tableName} (Id) VALUES (1)";
var result = await _tools.InsertData(sql) as DbOperationResult;
Assert.NotNull(result);
Assert.False(result.Success);
Assert.Contains("INSERT operation is not allowed in READONLY mode", result.Error ?? string.Empty);
}
finally
{
// Clean up environment variable
Environment.SetEnvironmentVariable("READONLY", null);
}
}

[Fact]
public async Task ReadOnlyMode_UpdateData_ReturnsError_WhenReadOnlyIsTrue()
{
// Set READONLY environment variable
Environment.SetEnvironmentVariable("READONLY", "true");
try
{
var sql = $"UPDATE {_tableName} SET Id = 2 WHERE Id = 1";
var result = await _tools.UpdateData(sql) as DbOperationResult;
Assert.NotNull(result);
Assert.False(result.Success);
Assert.Contains("UPDATE operation is not allowed in READONLY mode", result.Error ?? string.Empty);
}
finally
{
// Clean up environment variable
Environment.SetEnvironmentVariable("READONLY", null);
}
}

[Fact]
public async Task ReadOnlyMode_DropTable_ReturnsError_WhenReadOnlyIsTrue()
{
// Set READONLY environment variable
Environment.SetEnvironmentVariable("READONLY", "true");
try
{
var sql = $"DROP TABLE IF EXISTS {_tableName}";
var result = await _tools.DropTable(sql) as DbOperationResult;
Assert.NotNull(result);
Assert.False(result.Success);
Assert.Contains("DROP TABLE operation is not allowed in READONLY mode", result.Error ?? string.Empty);
}
finally
{
// Clean up environment variable
Environment.SetEnvironmentVariable("READONLY", null);
}
}

[Fact]
public async Task TestConnection_ReturnsSuccess_WhenConnectionIsValid()
{
var result = await _tools.TestConnection() as DbOperationResult;
Assert.NotNull(result);
Assert.True(result.Success);
Assert.NotNull(result.Data);

var dict = result.Data as System.Collections.IDictionary;
Assert.NotNull(dict);
Assert.True(dict.Contains("ConnectionState"));
Assert.True(dict.Contains("Database"));
Assert.True(dict.Contains("ServerVersion"));
Assert.True(dict.Contains("DataSource"));
Assert.True(dict.Contains("ConnectionTimeout"));
Assert.Equal("Open", dict["ConnectionState"]?.ToString());
}
}
}
5 changes: 5 additions & 0 deletions MssqlMcp/dotnet/MssqlMcp/Tools/CreateTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public partial class Tools
public async Task<DbOperationResult> CreateTable(
[Description("CREATE TABLE SQL statement")] string sql)
{
if (IsReadOnlyMode)
{
return new DbOperationResult(success: false, error: "CREATE TABLE operation is not allowed in READONLY mode");
}

var conn = await _connectionFactory.GetOpenConnectionAsync();
try
{
Expand Down
5 changes: 5 additions & 0 deletions MssqlMcp/dotnet/MssqlMcp/Tools/DropTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public partial class Tools
public async Task<DbOperationResult> DropTable(
[Description("DROP TABLE SQL statement")] string sql)
{
if (IsReadOnlyMode)
{
return new DbOperationResult(success: false, error: "DROP TABLE operation is not allowed in READONLY mode");
}

var conn = await _connectionFactory.GetOpenConnectionAsync();
try
{
Expand Down
5 changes: 5 additions & 0 deletions MssqlMcp/dotnet/MssqlMcp/Tools/InsertData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public partial class Tools
public async Task<DbOperationResult> InsertData(
[Description("INSERT SQL statement")] string sql)
{
if (IsReadOnlyMode)
{
return new DbOperationResult(success: false, error: "INSERT operation is not allowed in READONLY mode");
}

var conn = await _connectionFactory.GetOpenConnectionAsync();
try
{
Expand Down
197 changes: 192 additions & 5 deletions MssqlMcp/dotnet/MssqlMcp/Tools/ReadData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,197 @@
// Licensed under the MIT license.

using System.ComponentModel;
using System.Text.RegularExpressions;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Server;

namespace Mssql.McpServer;
public partial class Tools
{
// List of dangerous SQL keywords that should not be allowed
private static readonly string[] DangerousKeywords =
[
"DELETE", "DROP", "UPDATE", "INSERT", "ALTER", "CREATE",
"TRUNCATE", "EXEC", "EXECUTE", "MERGE", "REPLACE",
"GRANT", "REVOKE", "COMMIT", "ROLLBACK", "TRANSACTION",
"BEGIN", "DECLARE", "SET", "USE", "BACKUP",
"RESTORE", "KILL", "SHUTDOWN", "WAITFOR", "OPENROWSET",
"OPENDATASOURCE", "OPENQUERY", "OPENXML", "BULK"
];

// Regex patterns to detect common SQL injection techniques
private static readonly Regex[] DangerousPatterns =
[
// Semicolon followed by dangerous keywords
new(@";\s*(DELETE|DROP|UPDATE|INSERT|ALTER|CREATE|TRUNCATE|EXEC|EXECUTE|MERGE|REPLACE|GRANT|REVOKE)", RegexOptions.IgnoreCase),

// UNION injection attempts with dangerous keywords
new(@"UNION\s+(?:ALL\s+)?SELECT.*?(DELETE|DROP|UPDATE|INSERT|ALTER|CREATE|TRUNCATE|EXEC|EXECUTE)", RegexOptions.IgnoreCase),

// Comment-based injection attempts
new(@"--.*?(DELETE|DROP|UPDATE|INSERT|ALTER|CREATE|TRUNCATE|EXEC|EXECUTE)", RegexOptions.IgnoreCase),
new(@"/\*.*?(DELETE|DROP|UPDATE|INSERT|ALTER|CREATE|TRUNCATE|EXEC|EXECUTE).*?\*/", RegexOptions.IgnoreCase),

// Stored procedure execution patterns
new(@"EXEC\s*\(", RegexOptions.IgnoreCase),
new(@"EXECUTE\s*\(", RegexOptions.IgnoreCase),
new(@"sp_", RegexOptions.IgnoreCase),
new(@"xp_", RegexOptions.IgnoreCase),

// Bulk operations
new(@"BULK\s+INSERT", RegexOptions.IgnoreCase),
new(@"OPENROWSET", RegexOptions.IgnoreCase),
new(@"OPENDATASOURCE", RegexOptions.IgnoreCase),

// System functions that could be dangerous
new(@"@@", RegexOptions.None),
new(@"SYSTEM_USER", RegexOptions.IgnoreCase),
new(@"USER_NAME", RegexOptions.IgnoreCase),
new(@"DB_NAME", RegexOptions.IgnoreCase),
new(@"HOST_NAME", RegexOptions.IgnoreCase),

// Time delay attacks
new(@"WAITFOR\s+DELAY", RegexOptions.IgnoreCase),
new(@"WAITFOR\s+TIME", RegexOptions.IgnoreCase),

// Multiple statements (semicolon not at end)
new(@";\s*\w", RegexOptions.None),

// String concatenation that might hide malicious code
new(@"\+\s*CHAR\s*\(", RegexOptions.IgnoreCase),
new(@"\+\s*NCHAR\s*\(", RegexOptions.IgnoreCase),
new(@"\+\s*ASCII\s*\(", RegexOptions.IgnoreCase)
];

/// <summary>
/// Validates the SQL query for security issues
/// </summary>
/// <param name="query">The SQL query to validate</param>
/// <returns>Validation result with success flag and error message if invalid</returns>
private (bool IsValid, string? Error) ValidateQuery(string query)
{
if (string.IsNullOrWhiteSpace(query))
{
return (false, "Query must be a non-empty string");
}

// Remove comments and normalize whitespace for analysis
var cleanQuery = Regex.Replace(query, @"--.*$", "", RegexOptions.Multiline) // Remove line comments
.Replace("/*", "").Replace("*/", "") // Remove block comments (simple approach)
.Trim();

cleanQuery = Regex.Replace(cleanQuery, @"\s+", " "); // Normalize whitespace

if (string.IsNullOrWhiteSpace(cleanQuery))
{
return (false, "Query cannot be empty after removing comments");
}

var upperQuery = cleanQuery.ToUpperInvariant();

// Must start with SELECT
if (!upperQuery.StartsWith("SELECT"))
{
return (false, "Query must start with SELECT for security reasons");
}

// Check for dangerous keywords in the cleaned query using word boundaries
foreach (var keyword in DangerousKeywords)
{
// Use word boundary regex to match only complete keywords, not parts of words
var keywordRegex = new Regex($@"(^|\s|[^A-Za-z0-9_]){keyword}($|\s|[^A-Za-z0-9_])", RegexOptions.IgnoreCase);
if (keywordRegex.IsMatch(upperQuery))
{
return (false, $"Dangerous keyword '{keyword}' detected in query. Only SELECT operations are allowed.");
}
}

// Check for dangerous patterns using regex
foreach (var pattern in DangerousPatterns)
{
if (pattern.IsMatch(query))
{
return (false, "Potentially malicious SQL pattern detected. Only simple SELECT queries are allowed.");
}
}

// Additional validation: Check for multiple statements
var statements = cleanQuery.Split(';', StringSplitOptions.RemoveEmptyEntries);
if (statements.Length > 1)
{
return (false, "Multiple SQL statements are not allowed. Use only a single SELECT statement.");
}

// Check for suspicious string patterns that might indicate obfuscation
if (query.Contains("CHAR(") || query.Contains("NCHAR(") || query.Contains("ASCII("))
{
return (false, "Character conversion functions are not allowed as they may be used for obfuscation.");
}

// Limit query length to prevent potential DoS
if (query.Length > 10000)
{
return (false, "Query is too long. Maximum allowed length is 10,000 characters.");
}

return (true, null);
}

/// <summary>
/// Sanitizes the query result to prevent any potential security issues
/// </summary>
/// <param name="data">The query result data</param>
/// <returns>Sanitized data</returns>
private List<Dictionary<string, object?>> SanitizeResult(List<Dictionary<string, object?>> data)
{
// Limit the number of returned records to prevent memory issues
const int maxRecords = 10000;
if (data.Count > maxRecords)
{
_logger.LogWarning("Query returned {Count} records, limiting to {MaxRecords}", data.Count, maxRecords);
data = data.Take(maxRecords).ToList();
}

return data.Select(record =>
{
var sanitized = new Dictionary<string, object?>();
foreach (var (key, value) in record)
{
// Sanitize column names (remove any suspicious characters)
var sanitizedKey = Regex.Replace(key, @"[^\w\s\-_.]", "");
if (sanitizedKey != key)
{
_logger.LogWarning("Column name sanitized: {Original} -> {Sanitized}", key, sanitizedKey);
}
sanitized[sanitizedKey] = value;
}
return sanitized;
}).ToList();
}

[McpServerTool(
Title = "Read Data",
ReadOnly = true,
Idempotent = true,
Destructive = false),
Description("Executes SQL queries against SQL Database to read data")]
Description("Executes a SELECT query on an MSSQL Database table. The query must start with SELECT and cannot contain any destructive SQL operations for security reasons.")]
public async Task<DbOperationResult> ReadData(
[Description("SQL query to execute")] string sql)
[Description("SQL SELECT query to execute (must start with SELECT and cannot contain destructive operations). Example: SELECT * FROM movies WHERE genre = 'comedy'")] string sql)
{
// Validate the query for security issues
var (isValid, error) = ValidateQuery(sql);
if (!isValid)
{
_logger.LogWarning("Security validation failed for query: {QueryStart}...", sql.Length > 100 ? sql[..100] : sql);
return new DbOperationResult(success: false, error: $"Security validation failed: {error}");
}

// Log the query for audit purposes (in production, consider more secure logging)
_logger.LogInformation("Executing validated SELECT query: {QueryStart}{Truncated}",
sql.Length > 200 ? sql[..200] : sql,
sql.Length > 200 ? "..." : "");

var conn = await _connectionFactory.GetOpenConnectionAsync();
try
{
Expand All @@ -35,13 +210,25 @@ public async Task<DbOperationResult> ReadData(
}
results.Add(row);
}
return new DbOperationResult(success: true, data: results);

// Sanitize the result
var sanitizedResults = SanitizeResult(results);

return new DbOperationResult(
success: true,
data: sanitizedResults);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "ReadData failed: {Message}", ex.Message);
return new DbOperationResult(success: false, error: ex.Message);

// Don't expose internal error details to prevent information leakage
var safeErrorMessage = ex.Message.Contains("Invalid object name")
? ex.Message
: "Database query execution failed";

return new DbOperationResult(success: false, error: $"Failed to execute query: {safeErrorMessage}");
}
}
}
}
Loading