Fix Gradio interface - replaced BarPlot with Dataframe, added debugging, improved event handlers
Browse files
app.py
CHANGED
|
@@ -6,14 +6,24 @@ from PIL import Image
|
|
| 6 |
import io
|
| 7 |
from handler import EndpointHandler
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def classify_image(image, top_k=10):
|
| 12 |
"""
|
| 13 |
Main classification function for public interface.
|
| 14 |
"""
|
|
|
|
|
|
|
|
|
|
| 15 |
if image is None:
|
| 16 |
-
return
|
| 17 |
|
| 18 |
try:
|
| 19 |
# Convert PIL image to base64
|
|
@@ -34,25 +44,28 @@ def classify_image(image, top_k=10):
|
|
| 34 |
# Create formatted output
|
| 35 |
output_text = "**Top {} Classifications:**\n\n".format(len(result))
|
| 36 |
|
| 37 |
-
# Create
|
| 38 |
-
chart_data =
|
| 39 |
|
| 40 |
for i, item in enumerate(result, 1):
|
| 41 |
score_pct = item['score'] * 100
|
| 42 |
output_text += f"{i}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
|
| 43 |
-
chart_data
|
| 44 |
|
| 45 |
-
return
|
| 46 |
else:
|
| 47 |
-
return
|
| 48 |
|
| 49 |
except Exception as e:
|
| 50 |
-
return
|
| 51 |
|
| 52 |
def upsert_labels_admin(admin_token, new_items_json):
|
| 53 |
"""
|
| 54 |
Admin function to add new labels.
|
| 55 |
"""
|
|
|
|
|
|
|
|
|
|
| 56 |
if not admin_token:
|
| 57 |
return "Error: Admin token required"
|
| 58 |
|
|
@@ -84,6 +97,9 @@ def reload_labels_admin(admin_token, version):
|
|
| 84 |
"""
|
| 85 |
Admin function to reload a specific label version.
|
| 86 |
"""
|
|
|
|
|
|
|
|
|
|
| 87 |
if not admin_token:
|
| 88 |
return "Error: Admin token required"
|
| 89 |
|
|
@@ -114,17 +130,20 @@ def get_current_stats():
|
|
| 114 |
"""
|
| 115 |
Get current label statistics.
|
| 116 |
"""
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
|
| 119 |
version = getattr(handler, 'labels_version', 1)
|
| 120 |
device = handler.device if hasattr(handler, 'device') else "unknown"
|
| 121 |
|
| 122 |
stats = f"""
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
"""
|
| 129 |
|
| 130 |
if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
|
|
@@ -137,6 +156,7 @@ def get_current_stats():
|
|
| 137 |
return f"Error getting stats: {str(e)}"
|
| 138 |
|
| 139 |
# Create Gradio interface
|
|
|
|
| 140 |
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
| 141 |
gr.Markdown("""
|
| 142 |
# πΌοΈ MobileCLIP-B Zero-Shot Image Classifier
|
|
@@ -161,29 +181,26 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
|
| 161 |
classify_btn = gr.Button("π Classify Image", variant="primary")
|
| 162 |
|
| 163 |
with gr.Column():
|
| 164 |
-
output_chart = gr.BarPlot(
|
| 165 |
-
label="Classification Confidence",
|
| 166 |
-
x_label="Label",
|
| 167 |
-
y_label="Confidence",
|
| 168 |
-
vertical=False,
|
| 169 |
-
height=400
|
| 170 |
-
)
|
| 171 |
output_text = gr.Markdown(label="Classification Results")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
],
|
| 179 |
-
inputs=input_image,
|
| 180 |
-
label="Example Images"
|
| 181 |
)
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
|
|
|
| 185 |
inputs=[input_image, top_k_slider],
|
| 186 |
-
outputs=[
|
| 187 |
)
|
| 188 |
|
| 189 |
with gr.Tab("π§ Admin Panel"):
|
|
@@ -203,7 +220,8 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
|
| 203 |
stats_display = gr.Markdown(value=get_current_stats())
|
| 204 |
refresh_stats_btn = gr.Button("π Refresh Stats")
|
| 205 |
refresh_stats_btn.click(
|
| 206 |
-
get_current_stats,
|
|
|
|
| 207 |
outputs=stats_display
|
| 208 |
)
|
| 209 |
|
|
@@ -227,7 +245,7 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
|
| 227 |
upsert_output = gr.Markdown()
|
| 228 |
|
| 229 |
upsert_btn.click(
|
| 230 |
-
upsert_labels_admin,
|
| 231 |
inputs=[admin_token_input, new_items_input],
|
| 232 |
outputs=upsert_output
|
| 233 |
)
|
|
@@ -243,7 +261,7 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
|
| 243 |
reload_output = gr.Markdown()
|
| 244 |
|
| 245 |
reload_btn.click(
|
| 246 |
-
reload_labels_admin,
|
| 247 |
inputs=[admin_token_input, version_input],
|
| 248 |
outputs=reload_output
|
| 249 |
)
|
|
@@ -258,13 +276,13 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
|
| 258 |
- π **Fast inference**: < 30ms on GPU
|
| 259 |
- π·οΈ **Dynamic labels**: Add/update labels without redeployment
|
| 260 |
- π **Version control**: Track and reload label versions
|
| 261 |
-
- π **Visual results**:
|
| 262 |
|
| 263 |
### Environment Variables (set in Space Settings):
|
| 264 |
- `ADMIN_TOKEN`: Secret token for admin operations
|
| 265 |
-
- `HF_LABEL_REPO`: Hub repository for label storage
|
| 266 |
- `HF_WRITE_TOKEN`: Token with write permissions to label repo
|
| 267 |
-
- `HF_READ_TOKEN`: Token with read permissions (optional
|
| 268 |
|
| 269 |
### Model Details:
|
| 270 |
- **Architecture**: MobileCLIP-B with MobileOne blocks
|
|
@@ -276,5 +294,8 @@ with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
|
| 276 |
Model weights are licensed under Apple Sample Code License (ASCL).
|
| 277 |
""")
|
| 278 |
|
|
|
|
|
|
|
| 279 |
if __name__ == "__main__":
|
|
|
|
| 280 |
demo.launch()
|
|
|
|
| 6 |
import io
|
| 7 |
from handler import EndpointHandler
|
| 8 |
|
| 9 |
+
# Initialize handler
|
| 10 |
+
print("Initializing MobileCLIP handler...")
|
| 11 |
+
try:
|
| 12 |
+
handler = EndpointHandler()
|
| 13 |
+
print(f"Handler initialized successfully! Device: {handler.device}")
|
| 14 |
+
except Exception as e:
|
| 15 |
+
print(f"Error initializing handler: {e}")
|
| 16 |
+
handler = None
|
| 17 |
|
| 18 |
def classify_image(image, top_k=10):
|
| 19 |
"""
|
| 20 |
Main classification function for public interface.
|
| 21 |
"""
|
| 22 |
+
if handler is None:
|
| 23 |
+
return "Error: Handler not initialized", None
|
| 24 |
+
|
| 25 |
if image is None:
|
| 26 |
+
return "Please upload an image", None
|
| 27 |
|
| 28 |
try:
|
| 29 |
# Convert PIL image to base64
|
|
|
|
| 44 |
# Create formatted output
|
| 45 |
output_text = "**Top {} Classifications:**\n\n".format(len(result))
|
| 46 |
|
| 47 |
+
# Create data for bar chart (list of tuples)
|
| 48 |
+
chart_data = []
|
| 49 |
|
| 50 |
for i, item in enumerate(result, 1):
|
| 51 |
score_pct = item['score'] * 100
|
| 52 |
output_text += f"{i}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
|
| 53 |
+
chart_data.append((item['label'], item['score']))
|
| 54 |
|
| 55 |
+
return output_text, chart_data
|
| 56 |
else:
|
| 57 |
+
return f"Error: {result.get('error', 'Unknown error')}", None
|
| 58 |
|
| 59 |
except Exception as e:
|
| 60 |
+
return f"Error: {str(e)}", None
|
| 61 |
|
| 62 |
def upsert_labels_admin(admin_token, new_items_json):
|
| 63 |
"""
|
| 64 |
Admin function to add new labels.
|
| 65 |
"""
|
| 66 |
+
if handler is None:
|
| 67 |
+
return "Error: Handler not initialized"
|
| 68 |
+
|
| 69 |
if not admin_token:
|
| 70 |
return "Error: Admin token required"
|
| 71 |
|
|
|
|
| 97 |
"""
|
| 98 |
Admin function to reload a specific label version.
|
| 99 |
"""
|
| 100 |
+
if handler is None:
|
| 101 |
+
return "Error: Handler not initialized"
|
| 102 |
+
|
| 103 |
if not admin_token:
|
| 104 |
return "Error: Admin token required"
|
| 105 |
|
|
|
|
| 130 |
"""
|
| 131 |
Get current label statistics.
|
| 132 |
"""
|
| 133 |
+
if handler is None:
|
| 134 |
+
return "Handler not initialized"
|
| 135 |
+
|
| 136 |
try:
|
| 137 |
num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
|
| 138 |
version = getattr(handler, 'labels_version', 1)
|
| 139 |
device = handler.device if hasattr(handler, 'device') else "unknown"
|
| 140 |
|
| 141 |
stats = f"""
|
| 142 |
+
**Current Statistics:**
|
| 143 |
+
- Number of labels: {num_labels}
|
| 144 |
+
- Labels version: {version}
|
| 145 |
+
- Device: {device}
|
| 146 |
+
- Model: MobileCLIP-B
|
| 147 |
"""
|
| 148 |
|
| 149 |
if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
|
|
|
|
| 156 |
return f"Error getting stats: {str(e)}"
|
| 157 |
|
| 158 |
# Create Gradio interface
|
| 159 |
+
print("Creating Gradio interface...")
|
| 160 |
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
|
| 161 |
gr.Markdown("""
|
| 162 |
# πΌοΈ MobileCLIP-B Zero-Shot Image Classifier
|
|
|
|
| 181 |
classify_btn = gr.Button("π Classify Image", variant="primary")
|
| 182 |
|
| 183 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
output_text = gr.Markdown(label="Classification Results")
|
| 185 |
+
# Simplified bar chart using Dataframe
|
| 186 |
+
output_chart = gr.Dataframe(
|
| 187 |
+
headers=["Label", "Confidence"],
|
| 188 |
+
label="Classification Scores",
|
| 189 |
+
interactive=False
|
| 190 |
+
)
|
| 191 |
|
| 192 |
+
# Event handler for classification
|
| 193 |
+
classify_btn.click(
|
| 194 |
+
fn=classify_image,
|
| 195 |
+
inputs=[input_image, top_k_slider],
|
| 196 |
+
outputs=[output_text, output_chart]
|
|
|
|
|
|
|
|
|
|
| 197 |
)
|
| 198 |
|
| 199 |
+
# Also trigger on image upload
|
| 200 |
+
input_image.change(
|
| 201 |
+
fn=classify_image,
|
| 202 |
inputs=[input_image, top_k_slider],
|
| 203 |
+
outputs=[output_text, output_chart]
|
| 204 |
)
|
| 205 |
|
| 206 |
with gr.Tab("π§ Admin Panel"):
|
|
|
|
| 220 |
stats_display = gr.Markdown(value=get_current_stats())
|
| 221 |
refresh_stats_btn = gr.Button("π Refresh Stats")
|
| 222 |
refresh_stats_btn.click(
|
| 223 |
+
fn=get_current_stats,
|
| 224 |
+
inputs=[],
|
| 225 |
outputs=stats_display
|
| 226 |
)
|
| 227 |
|
|
|
|
| 245 |
upsert_output = gr.Markdown()
|
| 246 |
|
| 247 |
upsert_btn.click(
|
| 248 |
+
fn=upsert_labels_admin,
|
| 249 |
inputs=[admin_token_input, new_items_input],
|
| 250 |
outputs=upsert_output
|
| 251 |
)
|
|
|
|
| 261 |
reload_output = gr.Markdown()
|
| 262 |
|
| 263 |
reload_btn.click(
|
| 264 |
+
fn=reload_labels_admin,
|
| 265 |
inputs=[admin_token_input, version_input],
|
| 266 |
outputs=reload_output
|
| 267 |
)
|
|
|
|
| 276 |
- π **Fast inference**: < 30ms on GPU
|
| 277 |
- π·οΈ **Dynamic labels**: Add/update labels without redeployment
|
| 278 |
- π **Version control**: Track and reload label versions
|
| 279 |
+
- π **Visual results**: Classification scores and confidence
|
| 280 |
|
| 281 |
### Environment Variables (set in Space Settings):
|
| 282 |
- `ADMIN_TOKEN`: Secret token for admin operations
|
| 283 |
+
- `HF_LABEL_REPO`: Hub repository for label storage
|
| 284 |
- `HF_WRITE_TOKEN`: Token with write permissions to label repo
|
| 285 |
+
- `HF_READ_TOKEN`: Token with read permissions (optional)
|
| 286 |
|
| 287 |
### Model Details:
|
| 288 |
- **Architecture**: MobileCLIP-B with MobileOne blocks
|
|
|
|
| 294 |
Model weights are licensed under Apple Sample Code License (ASCL).
|
| 295 |
""")
|
| 296 |
|
| 297 |
+
print("Gradio interface created successfully!")
|
| 298 |
+
|
| 299 |
if __name__ == "__main__":
|
| 300 |
+
print("Launching Gradio app...")
|
| 301 |
demo.launch()
|