File size: 2,366 Bytes
e546fea 2b4b81f e546fea 958511f e546fea 19fbf4c 958511f 3618356 19fbf4c 958511f e546fea 2d64873 19fbf4c 2d64873 f0a6ca3 e546fea 3e75999 2b4b81f 19fbf4c 3e75999 2b4b81f 19fbf4c 2d64873 19fbf4c 2d64873 19fbf4c 2d64873 e546fea 19fbf4c 2b4b81f c368dca e546fea 19fbf4c e546fea 19fbf4c 20a2fe0 19fbf4c b2b24c7 958511f 19fbf4c 958511f e546fea 19fbf4c e546fea 2373e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# Select the "high" float32 matmul precision setting for PyTorch.
torch.set_float32_matmul_precision("high")

# Load the pretrained BiRefNet segmentation model. trust_remote_code is
# enabled so the model code published with the checkpoint can be used.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")  # keep the model on the CPU

# Model input pipeline: fixed 1024x1024 resize, conversion to a tensor,
# then per-channel normalization with three-channel mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def process(image):
    """Return ``image`` with its background made transparent.

    The image is passed through the module-level ``transform_image``
    pipeline and the ``birefnet`` model on the CPU. The sigmoid of the
    model's last output is resized back to the original dimensions and
    installed as the alpha channel of the result.

    Args:
        image: A PIL image. Any mode is accepted; it is converted to RGB
            for the model input only.

    Returns:
        A new RGBA PIL image with the same size as ``image`` whose alpha
        channel is the predicted mask. The input image is not modified.
    """
    image_size = image.size

    # transform_image normalizes with three-channel statistics, so feeding
    # it an RGBA / grayscale / palette image would fail. Convert only the
    # model input and keep the original pixels for the output.
    model_input = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(model_input).unsqueeze(0).to("cpu")

    # Prediction: no gradients are needed for inference.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    # Take the first (only) item of the batch and drop singleton dimensions
    # so it can be turned into a single-channel PIL mask.
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    # Apply the mask as the alpha channel (transparent background).
    result = image.convert("RGBA")
    result.putalpha(mask)
    return result
def fn(image):
    """Load ``image`` and return it with a transparent background.

    Args:
        image: Any input accepted by ``load_img``; it is forwarded
            unchanged and loaded as a PIL image.

    Returns:
        The RGBA PIL image produced by ``process``.
    """
    im = load_img(image, output_type="pil")
    # The previous implementation also made an unused full copy of the
    # image here; it has been dropped because nothing read it.
    return process(im.convert("RGB"))
def process_file(f):
    """Remove the background from the image at ``f`` and save it as a PNG.

    The output is written next to the input, using the input's path with
    its extension replaced by ``.png``.

    Args:
        f: Path of the image file to process, as a string.

    Returns:
        The path of the PNG file that was written, as a string.
    """
    import os.path

    # splitext only strips an extension from the final path component, so
    # dots in directory names (or a leading dot in the file name) no longer
    # truncate the output path the way ``rsplit(".", 1)`` did.
    # NOTE(review): when ``f`` already ends in ".png" the output path equals
    # the input path and the input file is overwritten.
    name_path = os.path.splitext(f)[0] + ".png"

    im = load_img(f, output_type="pil")
    transparent = process(im.convert("RGB"))
    transparent.save(name_path, "PNG")  # PNG keeps the alpha channel
    return name_path
# Gradio components.
# NOTE(review): only image2 and png_file are wired into the interface built
# below; slider1, slider2, image and text are not referenced in the visible
# code. Confirm nothing else uses them before removing.
slider1 = gr.Image()
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="Output PNG file")

# Example image, loaded at import time.
# NOTE(review): this variable is not used in the visible code.
chameleon = load_img("butterfly.jpg", output_type="pil")

# Interface that takes an uploaded image path and returns the PNG file.
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["butterfly.jpg"],
    api_name="png",
)

# Top-level app: a tabbed interface containing the single PNG tab.
demo = gr.TabbedInterface(
    interface_list=[tab3],
    tab_names=["PNG Output"],
    title="Body Extractor",
)

if __name__ == "__main__":
    demo.launch(share=True)
|