File size: 2,366 Bytes
e546fea
 
 
 
 
 
 
2b4b81f
e546fea
958511f
e546fea
19fbf4c
958511f
 
 
3618356
19fbf4c
 
958511f
 
 
 
 
 
 
e546fea
2d64873
 
19fbf4c
2d64873
f0a6ca3
e546fea
 
 
 
 
3e75999
2b4b81f
19fbf4c
 
3e75999
2b4b81f
19fbf4c
 
 
 
 
 
 
 
 
 
 
2d64873
19fbf4c
 
2d64873
 
 
19fbf4c
2d64873
e546fea
19fbf4c
 
2b4b81f
c368dca
e546fea
19fbf4c
e546fea
19fbf4c
20a2fe0
19fbf4c
b2b24c7
958511f
19fbf4c
 
 
 
 
 
 
958511f
e546fea
19fbf4c
 
 
e546fea
2373e76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image

# Let float32 matmuls use the faster reduced-internal-precision path
# ("high") rather than full float32 ("highest").
torch.set_float32_matmul_precision("high")

# Fetch the BiRefNet segmentation model from the Hugging Face Hub. Its model
# code is hosted in that repository, which is why trust_remote_code is needed.
# Module.to() returns the module itself, so the chained call keeps the same
# object bound to `birefnet`, now placed on the CPU.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cpu")

# Model input preprocessing:
#   1. resize to the fixed 1024x1024 input size,
#   2. convert the PIL image to a float (C, H, W) tensor scaled to [0, 1],
#   3. normalize each channel with the standard ImageNet mean / std.
transform_image = transforms.Compose([
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def process(image):
    """Cut the BiRefNet-predicted foreground out of *image* onto transparency.

    Args:
        image: A PIL image in any mode. It is converted to RGB before
            inference, because ``transform_image`` normalizes exactly three
            channels.

    Returns:
        A new RGBA PIL image with the same size as *image*. Its colour
        channels come from the RGB version of *image*, and its alpha channel
        is the predicted mask resized back to the original resolution
        (higher probability -> more opaque).
    """
    # Grayscale / RGBA / palette inputs would break the 3-channel Normalize
    # step, so always feed the model an RGB copy.
    rgb = image.convert("RGB")
    input_images = transform_image(rgb).unsqueeze(0).to("cpu")

    # Prediction (no autograd bookkeeping needed for inference).
    with torch.no_grad():
        # The last element of the model output is used as the final map.
        # NOTE(review): assumed from the [-1] indexing; confirm with BiRefNet.
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    # Drop the size-1 batch/channel dims. ToPILImage maps the [0, 1]
    # probabilities to an 8-bit single-channel image. Then scale the
    # 1024x1024 mask back to the source size.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(rgb.size)

    # Apply the mask as the alpha channel (transparent background).
    cutout = rgb.convert("RGBA")
    cutout.putalpha(mask)
    return cutout  # ready to be saved as a transparent PNG


def fn(image):
    """Load *image* with ``load_img`` and return its transparent cutout.

    Args:
        image: Any input ``load_img`` accepts. NOTE(review): presumably a
            path, URL, array or PIL image; confirm against loadimg.

    Returns:
        The RGBA PIL image produced by ``process``.
    """
    # Normalize to RGB before segmentation. The previous unused
    # `origin = im.copy()` made a wasted full-image copy and has been removed.
    im = load_img(image, output_type="pil").convert("RGB")
    return process(im)


def process_file(f):
    """Segment the image file at *f* and save the cutout as a transparent PNG.

    Args:
        f: Filesystem path of the uploaded image, or None when no image was
            provided.

    Returns:
        The path of the written PNG, or None if *f* is None. The PNG has the
        same directory and base name as *f*, with a ``.png`` extension.
        NOTE: if *f* already ends in ``.png``, the input file is overwritten.
        The input is fully decoded (via ``convert``) before the save.
    """
    import os

    if f is None:
        # Nothing uploaded (e.g. the input was cleared): produce no output
        # instead of crashing on None.rsplit.
        return None

    # splitext only strips a real extension. rsplit(".", 1) would also cut at
    # a dot inside a directory name when the file itself has no extension.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path, "PNG")  # PNG preserves the alpha channel
    return name_path


# Gradio UI Components
# NOTE(review): slider1, slider2, image and text are not referenced by any
# interface in the code shown here; they look like leftovers. Confirm before
# removing them.
slider1 = gr.Image()
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
# Input of the PNG tab: type="filepath" hands process_file a path on disk.
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
# Output of the PNG tab: shows the file whose path process_file returns.
png_file = gr.File(label="Output PNG file")

# Example image
# NOTE(review): `chameleon` is loaded at import time but never used in the
# code shown here. The interface below passes the path "butterfly.jpg"
# directly as its example.
chameleon = load_img("butterfly.jpg", output_type="pil")

# Tab for PNG output: upload an image and get back a transparent PNG file.
# api_name="png" is the name under which Gradio exposes this function in its
# API.
tab3 = gr.Interface(
    process_file,
    inputs=image2,
    outputs=png_file,
    examples=["butterfly.jpg"],
    api_name="png",
)

# Main demo: a tabbed wrapper that currently holds only the PNG tab.
demo = gr.TabbedInterface([tab3], ["PNG Output"], title="Body Extractor")

if __name__ == "__main__":
    # share=True asks Gradio for a temporary public share link when run
    # locally. NOTE(review): confirm that a publicly reachable link is
    # intended.
    demo.launch(share=True)