AhsanAftab's picture
Update app.py
ea4cbcc verified
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import cv2
import base64
from io import BytesIO
from PIL import Image
from transformers import SegformerForSemanticSegmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import os
# Create the Flask application and wrap it with flask_cors. No CORS options
# are passed, so the extension's default policy applies to every route.
app = Flask(import_name=__name__)
CORS(app=app)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Number of segmentation classes the classification head is resized to.
NUM_CLASSES = 6
# Hugging Face hub identifier of the pretrained SegFormer backbone.
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
# Checkpoint file holding the fine-tuned weights.
MODEL_PATH = "best_model.pth"
# Inference device (Hugging Face Free Tier is CPU).
DEVICE = torch.device("cpu")
# Load Model
print("Loading model...")
# Start from the pretrained SegFormer weights but request a NUM_CLASSES-way
# head. ignore_mismatched_sizes=True lets from_pretrained proceed even though
# the pretrained classifier has a different number of labels.
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# SECURITY NOTE(review): weights_only=False makes torch.load run the full
# pickle machinery, which can execute arbitrary code embedded in the file.
# This is only safe while MODEL_PATH is a trusted checkpoint; never point it
# at a user-supplied file. Consider weights_only=True if the checkpoint
# contains nothing but tensors and plain containers.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
state_dict = checkpoint['model_state_dict']

# Fix key names: checkpoints saved from a DataParallel-wrapped model carry a
# leading 'module.' on every key. Strip it only when it is a prefix, so a key
# that merely contains the substring (e.g. 'submodule.weight') is not mangled.
_PARALLEL_PREFIX = 'module.'
new_state_dict = {
    (k[len(_PARALLEL_PREFIX):] if k.startswith(_PARALLEL_PREFIX) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)
model.to(DEVICE)
# Switch to inference mode (disables dropout / batch-norm updates).
model.eval()
print("Model loaded!")
# Colour assigned to each class label; the list position is the label id.
MASK_COLOR_MAP = dict(enumerate([
    (0, 0, 0),      # label 0
    (255, 0, 0),    # label 1
    (0, 255, 0),    # label 2
    (0, 0, 255),    # label 3
    (255, 255, 0),  # label 4
    (255, 0, 255),  # label 5
]))
def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized model input tensor.

    Args:
        image_bytes: Raw contents of an encoded image file.

    Returns:
        A tuple ``(tensor, original_size)``. ``tensor`` is the normalized
        image with a leading batch dimension added, moved to ``DEVICE``.
        ``original_size`` is the ``(height, width)`` of the decoded image.
        The image is not resized here.

    Raises:
        ValueError: If ``image_bytes`` is empty or cannot be decoded as an
            image.
    """
    # Reject empty uploads up front instead of failing inside OpenCV.
    if not image_bytes:
        raise ValueError("Empty image data")

    nparr = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # imdecode signals an undecodable buffer by returning None rather than
    # raising; surface that as a clear error here.
    if image is None:
        raise ValueError("Could not decode image data")

    # OpenCV decodes to BGR; convert to RGB before normalizing.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]

    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    tensor = transform(image=image)['image'].unsqueeze(0).to(DEVICE)
    return tensor, original_size
def colorize_mask(mask, color_map=None):
    """Convert a 2-D label mask into a 3-channel colour image.

    Args:
        mask: 2-D array-like of class labels.
        color_map: Optional mapping of label -> 3-tuple colour. Defaults to
            the module-level ``MASK_COLOR_MAP``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``. Pixels whose label
        is not present in the colour map stay ``(0, 0, 0)``.

    Raises:
        ValueError: If ``mask`` is not 2-dimensional.
    """
    palette = MASK_COLOR_MAP if color_map is None else color_map
    labels = np.asarray(mask)
    if labels.ndim != 2:
        raise ValueError(
            f"mask must be 2-D, got {labels.ndim} dimension(s)"
        )

    color_mask = np.zeros(labels.shape + (3,), dtype=np.uint8)
    for label, color in palette.items():
        # Boolean-mask assignment paints every pixel of this class at once.
        color_mask[labels == label] = color
    return color_mask
def to_base64(image_array):
    """Encode an image array as PNG and return it as a base64 text string.

    Args:
        image_array: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        The PNG file contents, base64-encoded and decoded to ``str``.
    """
    with BytesIO() as png_stream:
        Image.fromarray(image_array).save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
@app.route('/')
def home():
    """Return a plain-text message confirming the service is up."""
    status_message = "Damage Detection API is Running!"
    return status_message
@app.route('/predict', methods=['POST'])
def predict():
    """Segment an uploaded image and return the colourized mask as JSON.

    Expects a POST whose uploaded files include a ``file`` field.

    Returns:
        On success, ``{'mask': 'data:image/png;base64,...'}``.
        ``{'error': ...}`` with status 400 when the ``file`` field is
        missing, the upload is empty, or the data cannot be decoded as an
        image. ``{'error': ...}`` with status 500 for any other failure.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file'}), 400
    file = request.files['file']

    try:
        image_bytes = file.read()
        # An empty upload is a client mistake, not a server failure.
        if not image_bytes:
            return jsonify({'error': 'Empty file'}), 400

        # Undecodable image data is likewise a client error. cv2.error is
        # what OpenCV raises on bad input; ValueError covers explicit
        # validation inside transform_image.
        try:
            input_tensor, original_size = transform_image(image_bytes)
        except (ValueError, cv2.error) as e:
            return jsonify({'error': f'Invalid image: {e}'}), 400

        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            # Upsample the logits back to the input resolution before
            # taking the per-pixel argmax.
            logits = nn.functional.interpolate(
                outputs.logits,
                size=original_size,
                mode='bilinear',
                align_corners=False,
            )
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

        rgb_mask = colorize_mask(pred_mask)
        return jsonify(
            {'mask': f"data:image/png;base64,{to_base64(rgb_mask)}"}
        )
    except Exception as e:
        # Top-level boundary: keep the JSON error contract, but record the
        # traceback so the failure can be diagnosed from the server logs.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # When executed directly, serve on all interfaces at port 7860.
    server_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**server_options)