diff --git a/Ch14_Computer_Vision/.idea/.gitignore b/Ch14_Computer_Vision/.idea/.gitignore
new file mode 100644
index 00000000..35410cac
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/.gitignore
@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/Ch14_Computer_Vision/.idea/Ch14_Computer_Vision.iml b/Ch14_Computer_Vision/.idea/Ch14_Computer_Vision.iml
new file mode 100644
index 00000000..f571432d
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/Ch14_Computer_Vision.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Ch14_Computer_Vision/.idea/deployment.xml b/Ch14_Computer_Vision/.idea/deployment.xml
new file mode 100644
index 00000000..aba64b06
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/deployment.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Ch14_Computer_Vision/.idea/inspectionProfiles/profiles_settings.xml b/Ch14_Computer_Vision/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 00000000..105ce2da
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Ch14_Computer_Vision/.idea/misc.xml b/Ch14_Computer_Vision/.idea/misc.xml
new file mode 100644
index 00000000..db8786c0
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Ch14_Computer_Vision/.idea/modules.xml b/Ch14_Computer_Vision/.idea/modules.xml
new file mode 100644
index 00000000..9b389ad5
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Ch14_Computer_Vision/.idea/vcs.xml b/Ch14_Computer_Vision/.idea/vcs.xml
new file mode 100644
index 00000000..6c0b8635
--- /dev/null
+++ b/Ch14_Computer_Vision/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb b/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb
index e5683066..ae0b4882 100644
--- a/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb
+++ b/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb
@@ -640,12 +640,25 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"id_cat = dict()\n",
- "id_cat[0] = 'pikachu'\n",
+ "id_cat[0] = 'pikachu'"
+ ]
+ },
+ {
+ "metadata": {
+ "jupyter": {
+ "is_executing": true
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
"\n",
"class FocalLoss(nn.Module):\n",
" def __init__(self, alpha=0.25, gamma=2, device=\"cuda:0\", eps=1e-10):\n",
@@ -658,19 +671,29 @@
" def forward(self, input, target):\n",
" p = torch.sigmoid(input)\n",
" pt = p * target.float() + (1.0 - p) * (1 - target).float()\n",
- " alpha_t = (1.0 - self.alpha) * target.float() + self.alpha * (1 - target).float()\n",
- " loss = - 1.0 * torch.pow((1 - pt), self.gamma) * torch.log(pt + self.eps)\n",
- " return loss.sum()\n",
- " \n",
+ " alpha_t = self.alpha * target.float() + (1.0 - self.alpha) * (1 - target).float()\n",
+ " loss = - alpha_t * torch.pow((1 - pt), self.gamma) * torch.log(pt + self.eps)\n",
+ " return loss.sum()\n"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": 21,
+ "source": [
+ "\n",
"class SSDLoss(nn.Module):\n",
" def __init__(self, loc_factor, jaccard_overlap, device = \"cuda:0\", **kwargs):\n",
" super().__init__()\n",
" self.fl = FocalLoss(**kwargs)\n",
" self.loc_factor = loc_factor\n",
" self.jaccard_overlap = jaccard_overlap\n",
- " \n",
"\n",
- " \n",
+ "\n",
+ "\n",
" self.device = device\n",
"\n",
" def one_hot_encoding(labels, num_classes):\n",
@@ -684,23 +707,23 @@
" torch.log((x[:, 3:4] / anchors[overlap_indicies, 3:4]))\n",
" ], dim=1)\n",
"\n",
- " def forward(self, class_hat, bb_hat, class_true, bb_true, anchors): \n",
+ " def forward(self, class_hat, bb_hat, class_true, bb_true, anchors):\n",
" loc_loss = 0.0\n",
" class_loss = 0.0\n",
- " \n",
- " \n",
+ "\n",
+ "\n",
" for i in range(len(class_true)): # Batch level\n",
"\n",
" class_hat_i = class_hat[i, :, :]\n",
"\n",
" bb_true_i = bb_true[i].float()\n",
- " \n",
- " class_true_i = class_true[i] \n",
- " \n",
+ "\n",
+ " class_true_i = class_true[i]\n",
+ "\n",
" class_target = torch.zeros(class_hat_i.shape[0]).long().to(self.device)\n",
- " \n",
+ "\n",
" overlap_list = d2l.find_overlap(bb_true_i.squeeze(0), anchors, self.jaccard_overlap)\n",
- " \n",
+ "\n",
" temp_loc_loss = 0.0\n",
" for j in range(len(overlap_list)): # BB level\n",
" overlap = overlap_list[j]\n",