import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import cv2
import base64
from io import BytesIO
from PIL import Image
from transformers import SegformerForSemanticSegmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import os

app = Flask(__name__)
CORS(app)

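# Inference configuration: run on CPU and load the fine-tuned checkpoint
# on top of the SegFormer-B2 backbone.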
DEVICE = torch.device('cpu')
MODEL_PATH = "best_model.pth"
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
NUM_CLASSES = 6

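# Rebuild the SegFormer architecture with the project's number of classes, then
# restore the fine-tuned weights from the local checkpoint.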
print("Loading model...") |
|
|
model = SegformerForSemanticSegmentation.from_pretrained( |
|
|
MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True |
|
|
) |
|
|
|
|
|
|
|
|
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False) |
|
|
state_dict = checkpoint['model_state_dict'] |
|
|
|
|
|
|
|
|
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} |
|
|
model.load_state_dict(new_state_dict) |
|
|
model.to(DEVICE) |
|
|
model.eval() |
|
|
print("Model loaded!") |
|
|
|
|
|
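# Class index -> RGB colour used to render the predicted mask.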
MASK_COLOR_MAP = {
    0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 255, 0),
    3: (0, 0, 255), 4: (255, 255, 0), 5: (255, 0, 255)
}

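# Decode the uploaded bytes, convert BGR -> RGB, apply ImageNet normalization,
# and add a batch dimension; also return the original (height, width) so the
# prediction can be resized back later.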
def transform_image(image_bytes):
    nparr = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("Uploaded file could not be decoded as an image")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]
    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])
    return transform(image=image)['image'].unsqueeze(0).to(DEVICE), original_size

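# Map each class index in a (H, W) prediction to its display colour.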
def colorize_mask(mask):
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for label, color in MASK_COLOR_MAP.items():
        color_mask[mask == label] = color
    return color_mask

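# Serialize an RGB array as a base64-encoded PNG for the JSON response.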
def to_base64(image_array):
    img = Image.fromarray(image_array)
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')

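# Simple health-check endpoint.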
@app.route('/')
def home():
    return "Damage Detection API is Running!"

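# Accept a multipart upload under the 'file' field, run segmentation, and
# return the colorized mask as a base64 PNG data URI.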
@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file'}), 400
    file = request.files['file']
    try:
        input_tensor, original_size = transform_image(file.read())
        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            # SegFormer returns logits at 1/4 of the input resolution;
            # upsample them back to the original image size.
            logits = nn.functional.interpolate(
                outputs.logits, size=original_size, mode='bilinear', align_corners=False
            )
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

        rgb_mask = colorize_mask(pred_mask)
        return jsonify({'mask': f"data:image/png;base64,{to_base64(rgb_mask)}"})
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
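# Example request (assuming the server is reachable on localhost:7860):
#   curl -X POST -F "file=@/path/to/image.jpg" http://localhost:7860/predict
# The response is JSON of the form {"mask": "data:image/png;base64,..."}.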