"""Flask API serving a fine-tuned SegFormer damage-segmentation model.

POST an image to ``/predict`` (multipart field ``file``); the response is a
JSON object whose ``mask`` field is a base64 PNG data URL of the colourised
per-pixel class prediction, at the same height/width as the input image.
"""
import base64
import os
from io import BytesIO

import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from transformers import SegformerForSemanticSegmentation

app = Flask(__name__)
CORS(app)

# CONFIG
DEVICE = torch.device('cpu')  # Hugging Face Free Tier is CPU
# Overridable for local runs; defaults to the checkpoint shipped next to the app.
MODEL_PATH = os.environ.get("MODEL_PATH", "best_model.pth")
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
NUM_CLASSES = 6

# Prefix that nn.DataParallel / DDP-wrapped models add to every parameter name.
_PARALLEL_PREFIX = 'module.'


def _load_model():
    """Build the SegFormer architecture and load the fine-tuned weights.

    Returns:
        The model on ``DEVICE`` in eval mode.
    """
    print("Loading model...")
    # ignore_mismatched_sizes: the pretrained head has a different class count
    # than NUM_CLASSES; its weights are replaced by the checkpoint below.
    net = SegformerForSemanticSegmentation.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_CLASSES,
        ignore_mismatched_sizes=True,
    )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # This is only acceptable because MODEL_PATH is a trusted local file;
    # never point it at user-supplied data.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

    # Accept either a training checkpoint ({'model_state_dict': ..., ...})
    # or a bare state_dict saved directly with torch.save(model.state_dict()).
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Strip only a *leading* 'module.' (multi-GPU training wrapper); a plain
    # str.replace would also mangle any key containing 'module.' mid-name.
    state_dict = {
        (key[len(_PARALLEL_PREFIX):] if key.startswith(_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }

    net.load_state_dict(state_dict)
    net.to(DEVICE)
    net.eval()
    print("Model loaded!")
    return net


model = _load_model()

MASK_COLOR_MAP = {
    0: (0, 0, 0),
    1: (255, 0, 0),
    2: (0, 255, 0),
    3: (0, 0, 255),
    4: (255, 255, 0),
    5: (255, 0, 255),
}

# Built once at import time instead of on every request. ImageNet mean/std
# normalisation, then HWC numpy -> CHW float tensor.
_TRANSFORM = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


class InvalidImageError(ValueError):
    """Raised when uploaded bytes are empty or cannot be decoded as an image."""


def transform_image(image_bytes):
    """Decode raw image bytes into a normalised model input batch.

    Args:
        image_bytes: Encoded image file contents (any format OpenCV decodes).

    Returns:
        ``(tensor, (height, width))``. ``tensor`` is a 1-image batch on
        ``DEVICE``; the tuple is the original image size, used later to
        upsample the logits.

    Raises:
        InvalidImageError: If the bytes are empty or not a decodable image.
    """
    nparr = np.frombuffer(image_bytes, np.uint8)
    if nparr.size == 0:
        # cv2.imdecode asserts on an empty buffer; report it as bad input.
        raise InvalidImageError("Uploaded file is empty")
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals "not an image / unsupported format" by returning None.
        raise InvalidImageError("Could not decode uploaded file as an image")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]
    tensor = _TRANSFORM(image=image)['image'].unsqueeze(0).to(DEVICE)
    return tensor, original_size


def colorize_mask(mask):
    """Map a 2-D array of class ids to an RGB ``uint8`` image.

    Pixels whose label is not in ``MASK_COLOR_MAP`` stay black.
    """
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for label, color in MASK_COLOR_MAP.items():
        color_mask[mask == label] = color
    return color_mask


def to_base64(image_array):
    """Encode an RGB ``uint8`` array as a base64 PNG string (no data-URL prefix)."""
    img = Image.fromarray(image_array)
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@app.route('/')
def home():
    """Health-check endpoint."""
    return "Damage Detection API is Running!"


@app.route('/predict', methods=['POST'])
def predict():
    """Segment an uploaded image and return the colourised mask.

    Responses:
        200: ``{'mask': 'data:image/png;base64,...'}``
        400: ``{'error': ...}`` when no file is sent or it is not an image.
        500: ``{'error': ...}`` on any unexpected failure during inference.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file'}), 400
    file = request.files['file']
    try:
        input_tensor, original_size = transform_image(file.read())
        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            # SegFormer emits logits at reduced resolution; upsample back to
            # the input size before taking the per-pixel argmax.
            logits = nn.functional.interpolate(
                outputs.logits,
                size=original_size,
                mode='bilinear',
                align_corners=False,
            )
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
        rgb_mask = colorize_mask(pred_mask)
        return jsonify({'mask': f"data:image/png;base64,{to_base64(rgb_mask)}"})
    except InvalidImageError as e:
        # Client sent something that is not an image: a client error, not a 500.
        return jsonify({'error': str(e)}), 400
    except Exception as e:  # top-level request boundary: keep the JSON error contract
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)