borso271 commited on
Commit
810ff2d
Β·
1 Parent(s): 7f45a02

Fix Gradio interface - replaced BarPlot with Dataframe, added debugging, improved event handlers

Browse files
Files changed (1) hide show
  1. app.py +58 -37
app.py CHANGED
@@ -6,14 +6,24 @@ from PIL import Image
6
  import io
7
  from handler import EndpointHandler
8
 
9
- handler = EndpointHandler()
 
 
 
 
 
 
 
10
 
11
  def classify_image(image, top_k=10):
12
  """
13
  Main classification function for public interface.
14
  """
 
 
 
15
  if image is None:
16
- return None, "Please upload an image"
17
 
18
  try:
19
  # Convert PIL image to base64
@@ -34,25 +44,28 @@ def classify_image(image, top_k=10):
34
  # Create formatted output
35
  output_text = "**Top {} Classifications:**\n\n".format(len(result))
36
 
37
- # Create a dictionary for the bar chart
38
- chart_data = {}
39
 
40
  for i, item in enumerate(result, 1):
41
  score_pct = item['score'] * 100
42
  output_text += f"{i}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
43
- chart_data[item['label']] = item['score']
44
 
45
- return chart_data, output_text
46
  else:
47
- return None, f"Error: {result.get('error', 'Unknown error')}"
48
 
49
  except Exception as e:
50
- return None, f"Error: {str(e)}"
51
 
52
  def upsert_labels_admin(admin_token, new_items_json):
53
  """
54
  Admin function to add new labels.
55
  """
 
 
 
56
  if not admin_token:
57
  return "Error: Admin token required"
58
 
@@ -84,6 +97,9 @@ def reload_labels_admin(admin_token, version):
84
  """
85
  Admin function to reload a specific label version.
86
  """
 
 
 
87
  if not admin_token:
88
  return "Error: Admin token required"
89
 
@@ -114,17 +130,20 @@ def get_current_stats():
114
  """
115
  Get current label statistics.
116
  """
 
 
 
117
  try:
118
  num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
119
  version = getattr(handler, 'labels_version', 1)
120
  device = handler.device if hasattr(handler, 'device') else "unknown"
121
 
122
  stats = f"""
123
- **Current Statistics:**
124
- - Number of labels: {num_labels}
125
- - Labels version: {version}
126
- - Device: {device}
127
- - Model: MobileCLIP-B
128
  """
129
 
130
  if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
@@ -137,6 +156,7 @@ def get_current_stats():
137
  return f"Error getting stats: {str(e)}"
138
 
139
  # Create Gradio interface
 
140
  with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
141
  gr.Markdown("""
142
  # πŸ–ΌοΈ MobileCLIP-B Zero-Shot Image Classifier
@@ -161,29 +181,26 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
161
  classify_btn = gr.Button("πŸš€ Classify Image", variant="primary")
162
 
163
  with gr.Column():
164
- output_chart = gr.BarPlot(
165
- label="Classification Confidence",
166
- x_label="Label",
167
- y_label="Confidence",
168
- vertical=False,
169
- height=400
170
- )
171
  output_text = gr.Markdown(label="Classification Results")
 
 
 
 
 
 
172
 
173
- gr.Examples(
174
- examples=[
175
- ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/cheetah.jpg"],
176
- ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/elephant.jpg"],
177
- ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/image_classifier/examples/giraffe.jpg"]
178
- ],
179
- inputs=input_image,
180
- label="Example Images"
181
  )
182
 
183
- classify_btn.click(
184
- classify_image,
 
185
  inputs=[input_image, top_k_slider],
186
- outputs=[output_chart, output_text]
187
  )
188
 
189
  with gr.Tab("πŸ”§ Admin Panel"):
@@ -203,7 +220,8 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
203
  stats_display = gr.Markdown(value=get_current_stats())
204
  refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats")
205
  refresh_stats_btn.click(
206
- get_current_stats,
 
207
  outputs=stats_display
208
  )
209
 
@@ -227,7 +245,7 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
227
  upsert_output = gr.Markdown()
228
 
229
  upsert_btn.click(
230
- upsert_labels_admin,
231
  inputs=[admin_token_input, new_items_input],
232
  outputs=upsert_output
233
  )
@@ -243,7 +261,7 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
243
  reload_output = gr.Markdown()
244
 
245
  reload_btn.click(
246
- reload_labels_admin,
247
  inputs=[admin_token_input, version_input],
248
  outputs=reload_output
249
  )
@@ -258,13 +276,13 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
258
  - πŸš€ **Fast inference**: < 30ms on GPU
259
  - 🏷️ **Dynamic labels**: Add/update labels without redeployment
260
  - πŸ”„ **Version control**: Track and reload label versions
261
- - πŸ“Š **Visual results**: Bar charts and confidence scores
262
 
263
  ### Environment Variables (set in Space Settings):
264
  - `ADMIN_TOKEN`: Secret token for admin operations
265
- - `HF_LABEL_REPO`: Hub repository for label storage (e.g., "username/labels")
266
  - `HF_WRITE_TOKEN`: Token with write permissions to label repo
267
- - `HF_READ_TOKEN`: Token with read permissions (optional, defaults to write token)
268
 
269
  ### Model Details:
270
  - **Architecture**: MobileCLIP-B with MobileOne blocks
@@ -276,5 +294,8 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
276
  Model weights are licensed under Apple Sample Code License (ASCL).
277
  """)
278
 
 
 
279
  if __name__ == "__main__":
 
280
  demo.launch()
 
6
  import io
7
  from handler import EndpointHandler
8
 
9
+ # Initialize handler
10
+ print("Initializing MobileCLIP handler...")
11
+ try:
12
+ handler = EndpointHandler()
13
+ print(f"Handler initialized successfully! Device: {handler.device}")
14
+ except Exception as e:
15
+ print(f"Error initializing handler: {e}")
16
+ handler = None
17
 
18
  def classify_image(image, top_k=10):
19
  """
20
  Main classification function for public interface.
21
  """
22
+ if handler is None:
23
+ return "Error: Handler not initialized", None
24
+
25
  if image is None:
26
+ return "Please upload an image", None
27
 
28
  try:
29
  # Convert PIL image to base64
 
44
  # Create formatted output
45
  output_text = "**Top {} Classifications:**\n\n".format(len(result))
46
 
47
+ # Create data for bar chart (list of tuples)
48
+ chart_data = []
49
 
50
  for i, item in enumerate(result, 1):
51
  score_pct = item['score'] * 100
52
  output_text += f"{i}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
53
+ chart_data.append((item['label'], item['score']))
54
 
55
+ return output_text, chart_data
56
  else:
57
+ return f"Error: {result.get('error', 'Unknown error')}", None
58
 
59
  except Exception as e:
60
+ return f"Error: {str(e)}", None
61
 
62
  def upsert_labels_admin(admin_token, new_items_json):
63
  """
64
  Admin function to add new labels.
65
  """
66
+ if handler is None:
67
+ return "Error: Handler not initialized"
68
+
69
  if not admin_token:
70
  return "Error: Admin token required"
71
 
 
97
  """
98
  Admin function to reload a specific label version.
99
  """
100
+ if handler is None:
101
+ return "Error: Handler not initialized"
102
+
103
  if not admin_token:
104
  return "Error: Admin token required"
105
 
 
130
  """
131
  Get current label statistics.
132
  """
133
+ if handler is None:
134
+ return "Handler not initialized"
135
+
136
  try:
137
  num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
138
  version = getattr(handler, 'labels_version', 1)
139
  device = handler.device if hasattr(handler, 'device') else "unknown"
140
 
141
  stats = f"""
142
+ **Current Statistics:**
143
+ - Number of labels: {num_labels}
144
+ - Labels version: {version}
145
+ - Device: {device}
146
+ - Model: MobileCLIP-B
147
  """
148
 
149
  if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
 
156
  return f"Error getting stats: {str(e)}"
157
 
158
  # Create Gradio interface
159
+ print("Creating Gradio interface...")
160
  with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
161
  gr.Markdown("""
162
  # πŸ–ΌοΈ MobileCLIP-B Zero-Shot Image Classifier
 
181
  classify_btn = gr.Button("πŸš€ Classify Image", variant="primary")
182
 
183
  with gr.Column():
 
 
 
 
 
 
 
184
  output_text = gr.Markdown(label="Classification Results")
185
+ # Simplified bar chart using Dataframe
186
+ output_chart = gr.Dataframe(
187
+ headers=["Label", "Confidence"],
188
+ label="Classification Scores",
189
+ interactive=False
190
+ )
191
 
192
+ # Event handler for classification
193
+ classify_btn.click(
194
+ fn=classify_image,
195
+ inputs=[input_image, top_k_slider],
196
+ outputs=[output_text, output_chart]
 
 
 
197
  )
198
 
199
+ # Also trigger on image upload
200
+ input_image.change(
201
+ fn=classify_image,
202
  inputs=[input_image, top_k_slider],
203
+ outputs=[output_text, output_chart]
204
  )
205
 
206
  with gr.Tab("πŸ”§ Admin Panel"):
 
220
  stats_display = gr.Markdown(value=get_current_stats())
221
  refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats")
222
  refresh_stats_btn.click(
223
+ fn=get_current_stats,
224
+ inputs=[],
225
  outputs=stats_display
226
  )
227
 
 
245
  upsert_output = gr.Markdown()
246
 
247
  upsert_btn.click(
248
+ fn=upsert_labels_admin,
249
  inputs=[admin_token_input, new_items_input],
250
  outputs=upsert_output
251
  )
 
261
  reload_output = gr.Markdown()
262
 
263
  reload_btn.click(
264
+ fn=reload_labels_admin,
265
  inputs=[admin_token_input, version_input],
266
  outputs=reload_output
267
  )
 
276
  - πŸš€ **Fast inference**: < 30ms on GPU
277
  - 🏷️ **Dynamic labels**: Add/update labels without redeployment
278
  - πŸ”„ **Version control**: Track and reload label versions
279
+ - πŸ“Š **Visual results**: Classification scores and confidence
280
 
281
  ### Environment Variables (set in Space Settings):
282
  - `ADMIN_TOKEN`: Secret token for admin operations
283
+ - `HF_LABEL_REPO`: Hub repository for label storage
284
  - `HF_WRITE_TOKEN`: Token with write permissions to label repo
285
+ - `HF_READ_TOKEN`: Token with read permissions (optional)
286
 
287
  ### Model Details:
288
  - **Architecture**: MobileCLIP-B with MobileOne blocks
 
294
  Model weights are licensed under Apple Sample Code License (ASCL).
295
  """)
296
 
297
+ print("Gradio interface created successfully!")
298
+
299
  if __name__ == "__main__":
300
+ print("Launching Gradio app...")
301
  demo.launch()