|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +import tarfile |
14 | 15 | import tempfile |
15 | 16 | import unittest |
| 17 | +import zipfile |
16 | 18 | from pathlib import Path |
17 | 19 | from urllib.error import ContentTooShortError, HTTPError |
18 | 20 |
|
@@ -66,5 +68,186 @@ def test_default(self, key, file_type): |
66 | 68 | ) |
67 | 69 |
|
68 | 70 |
|
class TestPathTraversalProtection(unittest.TestCase):
    """Test cases for path traversal attack protection in extractall function.

    Each test builds a small zip or tar.gz archive inside a scratch directory
    and calls ``extractall(archive, extract_dir)``.  Well-formed archives are
    expected to extract normally; archives containing parent-directory
    components, absolute member names, or links pointing outside the
    extraction directory are expected to be rejected with ``ValueError``.

    NOTE(review): ``extractall`` is not defined in this class; it is assumed to
    be imported at module level.
    """

    def setUp(self):
        """Create a scratch directory containing an empty ``extract`` subdirectory."""
        scratch = tempfile.TemporaryDirectory()
        # addCleanup runs even when the test fails, so the directory never leaks.
        self.addCleanup(scratch.cleanup)
        self.tmp_dir = Path(scratch.name)
        self.extract_dir = self.tmp_dir / "extract"
        self.extract_dir.mkdir()

    def _make_zip(self, filename, members):
        """Write a zip archive named *filename* in the scratch directory.

        Args:
            filename: file name of the archive to create.
            members: mapping of archive member name to its text content,
                written in iteration order.

        Returns:
            Path of the created archive.
        """
        zip_path = self.tmp_dir / filename
        with zipfile.ZipFile(zip_path, "w") as zf:
            for arcname, content in members.items():
                zf.writestr(arcname, content)
        return zip_path

    def _make_tar(self, filename, members, link=None):
        """Write a gzip-compressed tar archive named *filename* in the scratch directory.

        Args:
            filename: file name of the archive to create.
            members: mapping of archive member name to its text content; each
                entry is staged as a real file and added under that name.
            link: optional ``(name, type, target)`` triple describing a link
                member (``tarfile.SYMTYPE`` or ``tarfile.LNKTYPE``) appended
                after the regular members.

        Returns:
            Path of the created archive.
        """
        tar_path = self.tmp_dir / filename
        with tarfile.open(tar_path, "w:gz") as tf:
            for index, (arcname, content) in enumerate(members.items()):
                # Stage the payload on disk so tf.add can record it under arcname.
                payload = self.tmp_dir / f"payload_{index}.txt"
                payload.write_text(content, encoding="utf-8")
                tf.add(payload, arcname=arcname)
            if link is not None:
                link_name, link_type, link_target = link
                link_info = tarfile.TarInfo(name=link_name)
                link_info.type = link_type
                link_info.linkname = link_target
                link_info.size = 0
                tf.addfile(link_info)
        return tar_path

    def _assert_rejected(self, archive_path, *markers):
        """Assert that extracting *archive_path* raises ``ValueError``.

        The lower-cased exception message must contain at least one of
        *markers*; the actual message is reported when none matches.
        """
        with self.assertRaises(ValueError) as context:
            extractall(str(archive_path), str(self.extract_dir))

        error_msg = str(context.exception).lower()
        self.assertTrue(
            any(marker in error_msg for marker in markers),
            f"expected one of {markers!r} in error message, got: {error_msg!r}",
        )

    def test_valid_zip_extraction(self):
        """Test that valid zip files extract successfully without raising exceptions."""
        zip_path = self._make_zip(
            "valid_test.zip",
            {
                "normal_file.txt": "This is a normal file",
                "subdir/nested_file.txt": "This is a nested file",
                "another_file.json": '{"key": "value"}',
            },
        )

        # No try/except wrapper: an unexpected exception already fails the test
        # with its full traceback, and wrapping would also swallow the
        # AssertionError raised by the checks below.
        extractall(str(zip_path), str(self.extract_dir))

        # Verify files were extracted correctly
        self.assertTrue((self.extract_dir / "normal_file.txt").exists())
        self.assertTrue((self.extract_dir / "subdir" / "nested_file.txt").exists())
        self.assertTrue((self.extract_dir / "another_file.json").exists())

        # Verify content
        extracted_text = (self.extract_dir / "normal_file.txt").read_text(encoding="utf-8")
        self.assertEqual(extracted_text, "This is a normal file")

    def test_malicious_zip_path_traversal(self):
        """Test that malicious zip files with path traversal attempts raise ValueError."""
        zip_path = self._make_zip(
            "malicious_test.zip",
            {
                # Member name that climbs out of the extraction directory
                "../../../etc/malicious.txt": "malicious content",
                "normal_file.txt": "normal content",
            },
        )

        self._assert_rejected(zip_path, "unsafe path")

    def test_valid_tar_extraction(self):
        """Test that valid tar files extract successfully without raising exceptions."""
        tar_path = self._make_tar(
            "valid_test.tar.gz",
            {
                "normal_file.txt": "This is a normal file",
                "subdir/nested_file.txt": "This is a nested file",
            },
        )

        # See test_valid_zip_extraction: deliberately not wrapped in try/except.
        extractall(str(tar_path), str(self.extract_dir))

        # Verify files were extracted correctly
        self.assertTrue((self.extract_dir / "normal_file.txt").exists())
        self.assertTrue((self.extract_dir / "subdir" / "nested_file.txt").exists())

        # Verify content
        extracted_text = (self.extract_dir / "normal_file.txt").read_text(encoding="utf-8")
        self.assertEqual(extracted_text, "This is a normal file")

    def test_malicious_tar_path_traversal(self):
        """Test that malicious tar files with path traversal attempts raise ValueError."""
        tar_path = self._make_tar(
            "malicious_test.tar.gz",
            # Member name that climbs out of the extraction directory
            {"../../../etc/malicious.txt": "malicious content"},
        )

        self._assert_rejected(tar_path, "unsafe path")

    def test_absolute_path_protection(self):
        """Test protection against absolute paths in archives."""
        zip_path = self._make_zip(
            "absolute_path_test.zip",
            {"/etc/passwd_bad": "malicious content"},
        )

        self._assert_rejected(zip_path, "unsafe path")

    def test_malicious_symlink_protection(self):
        """Test protection against malicious symlinks in tar archives."""
        tar_path = self._make_tar(
            "malicious_symlink_test.tar.gz",
            {"normal.txt": "normal content"},
            link=("malicious_symlink.txt", tarfile.SYMTYPE, "../../../etc/passwd_bad"),
        )

        self._assert_rejected(tar_path, "unsafe path", "symlink")

    def test_malicious_hardlink_protection(self):
        """Test protection against malicious hard links in tar archives."""
        tar_path = self._make_tar(
            "malicious_hardlink_test.tar.gz",
            {"normal.txt": "normal content"},
            link=("malicious_hardlink.txt", tarfile.LNKTYPE, "/etc/passwd_bad"),
        )

        self._assert_rejected(tar_path, "unsafe path", "hardlink")
| 250 | + |
| 251 | + |
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments