Spaces:
Sleeping
Sleeping
| import torch | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import numpy as np | |
| import cv2 | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| from transformers import SegformerForSemanticSegmentation | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import torch.nn as nn | |
| import os | |
# Flask application. CORS(app) uses flask-cors default settings so a browser
# front-end served from another origin can call this API.
app = Flask(__name__)
CORS(app)

# CONFIG
# Device and checkpoint path can be overridden via environment variables
# (defaults: "cpu" and "best_model.pth").
DEVICE = torch.device(os.environ.get("TORCH_DEVICE", "cpu"))
MODEL_PATH = os.environ.get("MODEL_PATH", "best_model.pth")
# Hugging Face model id used to build the SegFormer architecture before the
# local checkpoint's weights are loaded into it.
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
# Number of output classes of the fine-tuned segmentation head.
NUM_CLASSES = 6
# Load Model
def _load_model():
    """Build the SegFormer model and load the fine-tuned checkpoint into it.

    Returns the model moved to DEVICE and switched to eval mode. Loading is
    done inside a function so the checkpoint dict (and any extra entries it
    carries) is released once its weights have been copied into the model,
    instead of staying alive as module-level globals.
    """
    print("Loading model...")
    net = SegformerForSemanticSegmentation.from_pretrained(
        MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
    )
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from
    # a trusted source (consider weights_only=True on recent torch versions).
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    # Accept a training checkpoint {'model_state_dict': ...} or a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    # Keys presumably carry a 'module.' prefix when saved from a
    # DataParallel/DDP-wrapped model; strip only that leading prefix.
    prefix = 'module.'
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    net.load_state_dict(cleaned)
    net.to(DEVICE)
    net.eval()
    print("Model loaded!")
    return net


model = _load_model()
# Class index -> RGB colour used when rendering a predicted label mask.
MASK_COLOR_MAP = dict(enumerate([
    (0, 0, 0),      # 0: black
    (255, 0, 0),    # 1: red
    (0, 255, 0),    # 2: green
    (0, 0, 255),    # 3: blue
    (255, 255, 0),  # 4: yellow
    (255, 0, 255),  # 5: magenta
]))
def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched model input.

    Args:
        image_bytes: encoded image file contents (any format OpenCV decodes).

    Returns:
        (tensor, original_size): the RGB image normalized with ImageNet
        mean/std, converted to a CHW tensor with a leading batch dimension
        and moved to DEVICE; and ``(height, width)`` of the decoded image.

    Raises:
        ValueError: if ``image_bytes`` is empty or cannot be decoded.
    """
    if not image_bytes:
        # cv2.imdecode raises an opaque assertion error on an empty buffer.
        raise ValueError("Empty image data")
    buffer = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode reports corrupt/unsupported data by returning None.
        raise ValueError("Could not decode image data")
    # OpenCV decodes to BGR; the normalization below uses RGB channel order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]
    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    tensor = transform(image=image)['image']
    return tensor.unsqueeze(0).to(DEVICE), original_size
def colorize_mask(mask, color_map=None):
    """Render an integer label mask as an RGB image.

    Args:
        mask: array-like of integer class labels (e.g. shape (H, W)).
        color_map: optional mapping ``label -> (R, G, B)``; defaults to
            MASK_COLOR_MAP.

    Returns:
        uint8 array of shape ``mask.shape + (3,)``. Pixels whose label has
        no entry in the colour map stay black (0, 0, 0).
    """
    if color_map is None:
        color_map = MASK_COLOR_MAP
    mask = np.asarray(mask)
    color_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for label, color in color_map.items():
        color_mask[mask == label] = color
    return color_mask
def to_base64(image_array):
    """Encode an image array as PNG and return it as a base64 text string.

    ``image_array`` is handed directly to ``PIL.Image.fromarray``. The
    returned string has no ``data:`` URI prefix; callers add one if needed.
    """
    with BytesIO() as png_stream:
        Image.fromarray(image_array).save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
@app.route('/')
def home():
    """Root endpoint: return a plain-text message confirming the API is up."""
    return "Damage Detection API is Running!"
@app.route('/predict', methods=['POST'])
def predict():
    """Segment an uploaded image and return the colourized mask.

    Expects a multipart form upload in the field ``file``.

    Returns:
        200 with ``{'mask': 'data:image/png;base64,...'}`` on success;
        400 with ``{'error': ...}`` when the file part is missing or empty,
        or when processing raises ValueError;
        500 with ``{'error': ...}`` for any other failure (also logged).
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file'}), 400
    image_bytes = request.files['file'].read()
    if not image_bytes:
        # A form submitted without choosing a file still sends an empty part.
        return jsonify({'error': 'Empty file'}), 400
    try:
        input_tensor, original_size = transform_image(image_bytes)
        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            # Resize logits to the uploaded image's (H, W) so the mask aligns
            # pixel-for-pixel with the input.
            logits = nn.functional.interpolate(
                outputs.logits, size=original_size, mode='bilinear', align_corners=False
            )
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
        rgb_mask = colorize_mask(pred_mask)
        return jsonify({'mask': f"data:image/png;base64,{to_base64(rgb_mask)}"})
    except ValueError as e:
        # Treated as bad client input (e.g. bytes that are not an image).
        return jsonify({'error': str(e)}), 400
    except Exception as e:
        # Keep the traceback in the server log instead of discarding it.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500
# --- Entry point when run directly (python app.py) ---
if __name__ == '__main__':
    # Hugging Face Spaces expects port 7860; PORT overrides it on other hosts.
    # host='0.0.0.0' listens on all interfaces, not just localhost.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))