Skip to content

Commit ec1f61a

Browse files
committed
feat: enhance model file serving with CORS headers and refactor fetch logic
1 parent 4dce045 commit ec1f61a

File tree

2 files changed

+63
-41
lines changed
  • python
  • scratch-arduino-extensions/packages/scratch-vm/src/extensions/tmachine_image

2 files changed

+63
-41
lines changed

python/main.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
object_detection = ObjectDetection()
1010

11+
1112
def on_matrix_draw(_, data):
1213
print(f"Received frame to draw on matrix: {data}")
1314
# from 5x5 to 8x13 matrix
@@ -87,12 +88,24 @@ def on_modulino_button_pressed(btn):
8788
ui.send_message("modulino_buttons_pressed", {"btn": btn})
8889

8990

90-
9191
Bridge.provide("modulino_button_pressed", on_modulino_button_pressed)
9292

9393

94-
ui.expose_api("GET", "/my-model/model.json", lambda: FileResponse(os.path.join("/app/assets/models/tm-my-image-model", "model.json"), headers={"Cache-Control": "no-store"}))
95-
ui.expose_api("GET", "/my-model/metadata.json", lambda: FileResponse(os.path.join("/app/assets/models/tm-my-image-model", "metadata.json"), headers={"Cache-Control": "no-store"}))
96-
ui.expose_api("GET", "/my-model/weights.bin", lambda: FileResponse(os.path.join("/app/assets/models/tm-my-image-model", "weights.bin"), headers={"Cache-Control": "no-store"}))
94+
def serve_model_file(filepath):
95+
"""Serve model files with CORS headers"""
96+
return FileResponse(
97+
os.path.join("/app/assets/models/tm-my-image-model", filepath),
98+
headers={
99+
"Cache-Control": "no-store",
100+
"Access-Control-Allow-Origin": "*",
101+
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
102+
"Access-Control-Allow-Headers": "*",
103+
},
104+
)
105+
106+
107+
ui.expose_api("GET", "/my-model/model.json", lambda: serve_model_file("model.json"))
108+
ui.expose_api("GET", "/my-model/metadata.json", lambda: serve_model_file("metadata.json"))
109+
ui.expose_api("GET", "/my-model/weights.bin", lambda: serve_model_file("weights.bin"))
97110

98111
App.run()

scratch-arduino-extensions/packages/scratch-vm/src/extensions/tmachine_image/index.js

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ const ArgumentType = require(
55

66
const Video = require("../../../../../../scratch-editor/packages/scratch-vm/src/io/video");
77

8-
98
// TODO add icons
109
const iconURI = "";
1110
const menuIconURI = "";
@@ -26,12 +25,12 @@ class TeachableMachineImage {
2625

2726
async fetchModelLabels() {
2827
try {
29-
const response = await fetch('https://192.168.1.39:7000/my-model/metadata.json');
28+
const response = await fetch(this.modelURL + "/metadata.json");
3029
const metadata = await response.json();
3130
this.modelLabels = metadata.labels || [];
32-
console.log('Fetched model labels:', this.modelLabels);
31+
console.log("Fetched model labels:", this.modelLabels);
3332
} catch (error) {
34-
console.error('Error fetching model labels:', error);
33+
console.error("Error fetching model labels:", error);
3534
this.modelLabels = []; // fallback
3635
}
3736
}
@@ -40,27 +39,28 @@ class TeachableMachineImage {
4039
try {
4140
// Load TensorFlow.js and Teachable Machine library
4241
if (!window.tf) {
43-
await this.loadScript('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js');
42+
await this.loadScript("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js");
4443
}
4544
if (!window.tmImage) {
46-
await this.loadScript('https://cdn.jsdelivr.net/npm/@teachablemachine/image@latest/dist/teachablemachine-image.min.js');
45+
await this.loadScript(
46+
"https://cdn.jsdelivr.net/npm/@teachablemachine/image@latest/dist/teachablemachine-image.min.js",
47+
);
4748
}
4849

4950
const modelURL = this.modelURL + "model.json";
5051
const metadataURL = this.modelURL + "metadata.json";
5152

5253
this.model = await tmImage.load(modelURL, metadataURL);
5354
this.isModelLoaded = true;
54-
console.log('Teachable Machine model loaded successfully');
55-
55+
console.log("Teachable Machine model loaded successfully");
5656
} catch (error) {
57-
console.error('Error loading Teachable Machine model:', error);
57+
console.error("Error loading Teachable Machine model:", error);
5858
}
5959
}
6060

6161
loadScript(src) {
6262
return new Promise((resolve, reject) => {
63-
const script = document.createElement('script');
63+
const script = document.createElement("script");
6464
script.src = src;
6565
script.onload = resolve;
6666
script.onerror = reject;
@@ -70,26 +70,25 @@ class TeachableMachineImage {
7070

7171
startPredictionLoop() {
7272
if (!this.isModelLoaded) {
73-
console.log("Model not loaded");
74-
return;
75-
};
73+
console.log("Model not loaded");
74+
return;
75+
}
7676

7777
const predict = async () => {
7878
try {
79-
const canvas = this.runtime.ioDevices.video.getFrame({
80-
format: Video.FORMAT_CANVAS,
81-
dimensions: [480, 360], // the same as the stage resolution
82-
});
83-
if (!canvas) {
84-
console.log("No canvas available from video frame.");
85-
return;
86-
}
87-
const prediction = await this.model.predict(canvas);
88-
this.predictions = prediction;
89-
console.log("preditions", predictions);
90-
79+
const canvas = this.runtime.ioDevices.video.getFrame({
80+
format: Video.FORMAT_CANVAS,
81+
dimensions: [480, 360], // the same as the stage resolution
82+
});
83+
if (!canvas) {
84+
console.log("No canvas available from video frame.");
85+
return;
86+
}
87+
const prediction = await this.model.predict(canvas);
88+
console.log("preditions", prediction);
89+
this.predictions = prediction;
9190
} catch (error) {
92-
console.error('Prediction error:', error);
91+
console.error("Prediction error:", error);
9392
}
9493

9594
// Continue loop
@@ -113,7 +112,7 @@ TeachableMachineImage.prototype.getInfo = function() {
113112
menuIconURI: menuIconURI,
114113
blockIconURI: iconURI,
115114
blocks: [
116-
{
115+
{
117116
opcode: "startDetectionLoop",
118117
blockType: BlockType.COMMAND,
119118
text: "start detection",
@@ -128,13 +127,13 @@ TeachableMachineImage.prototype.getInfo = function() {
128127
arguments: {
129128
OBJECT: {
130129
type: ArgumentType.STRING,
131-
menu: 'modelLabels',
132-
defaultValue: 'ok',
130+
menu: "modelLabels",
131+
defaultValue: "ok",
133132
},
134133
THRESHOLD: {
135134
type: ArgumentType.NUMBER,
136-
defaultValue: 50
137-
}
135+
defaultValue: 50,
136+
},
138137
},
139138
},
140139
{
@@ -150,15 +149,15 @@ TeachableMachineImage.prototype.getInfo = function() {
150149
},
151150
THRESHOLD: {
152151
type: ArgumentType.NUMBER,
153-
defaultValue: 50
154-
}
152+
defaultValue: 50,
153+
},
155154
},
156155
},
157156
{
158157
opcode: "getConfidence",
159158
blockType: BlockType.REPORTER,
160159
text: "confidence of [OBJECT]",
161-
func: "getConfidenceBlock",
160+
func: "getConfidence",
162161
arguments: {
163162
OBJECT: {
164163
type: ArgumentType.STRING,
@@ -169,7 +168,7 @@ TeachableMachineImage.prototype.getInfo = function() {
169168
},
170169
],
171170
menus: {
172-
modelLabels: 'getModelLabels'
171+
modelLabels: "getModelLabels",
173172
},
174173
};
175174
};
@@ -188,9 +187,19 @@ TeachableMachineImage.prototype.isObjectDetected = function(args) {
188187
return confidence > (args.THRESHOLD / 100);
189188
};
190189

190+
TeachableMachineImage.prototype.getConfidence = function(args) {
191+
if (!this.predictions || this.predictions.length === 0) return 0;
192+
193+
const prediction = this.predictions.find(p => p.className === args.OBJECT);
194+
const confidence = prediction ? Math.round(prediction.probability * 100) : 0;
195+
196+
console.log("get confidence for", args.OBJECT, "=", confidence + "%");
197+
return confidence;
198+
};
199+
191200
TeachableMachineImage.prototype.startDetectionLoop = function(args) {
192-
this.runtime.ioDevices.video.enableVideo();
193-
this.startPredictionLoop()
201+
this.runtime.ioDevices.video.enableVideo();
202+
this.startPredictionLoop();
194203
};
195204

196205
module.exports = TeachableMachineImage;

0 commit comments

Comments
 (0)