|
|
import spaces |
|
|
import gradio as gr |
|
|
from gradio_huggingfacehub_search import HuggingfaceHubSearch |
|
|
from t2i_config import KERNELS_PREFETCH_ON_STARTUP, KERNELS_PREFETCH_REPOS |
|
|
|
|
|
def _prefetch_kernels(repo_ids):
    """Best-effort download of Hub kernels so the first request does not pay for it.

    Failures never propagate: if the optional ``kernels`` package cannot be
    imported the whole prefetch is skipped, and a failure for one repo is
    reported and skipped without aborting the remaining repos.

    Args:
        repo_ids: iterable of kernel repo ids (presumably Hub repo ids; taken
            from ``t2i_config.KERNELS_PREFETCH_REPOS``).
    """
    try:
        from kernels import has_kernel, get_kernel
    except Exception as e:
        print(f"INFO : Kernels prefetch skipped: {e}")
        return
    for repo_id in repo_ids:
        # Isolate each repo so one bad entry does not cancel the rest.
        try:
            if has_kernel(repo_id):
                get_kernel(repo_id)
        except Exception as e:
            print(f"INFO : Kernels prefetch skipped for {repo_id}: {e}")


if KERNELS_PREFETCH_ON_STARTUP:
    _prefetch_kernels(KERNELS_PREFETCH_REPOS)
|
|
|
|
|
from t2i.infer import (infer, infer_multi, infer_simple, save_image_history, save_gallery_history, |
|
|
update_param_mode_gr, update_ar_gr, |
|
|
MAX_SEED, MAX_IMAGE_SIZE, ASPECT_RATIOS, FILE_FORMATS, DEFAULT_TASKS, DEFAULT_DURATION, |
|
|
DEFAULT_I2I_STRENGTH, DEFAULT_UPSCALE_STRENGTH, DEFAULT_UPSCALE_BY, DEFAULT_CLIP_SKIP, |
|
|
models, MODEL_TYPES, SAMPLER_NAMES, PRED_TYPES, VAE_NAMES, |
|
|
UPSCALE_MODES, PARAM_MODES, PIPELINE_TYPES) |
|
|
|
|
|
# Stylesheet handed to ``demo.launch(css=css)``: centres the column whose
# elem_id is "col-container" and caps its width at 1080px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1080px;
}
"""
|
|
|
|
|
with gr.Blocks(fill_height=True, fill_width=True) as demo:
    # ------------------------------------------------------------------
    # Tab 1: text-to-image generation UI.
    # ------------------------------------------------------------------
    with gr.Tab("Image Generator"):
        # Per-session LoRA mapping; passed as an input to infer/infer_multi
        # below (no event in this block lists it as an output).
        lora_dict = gr.State({})
        with gr.Column(elem_id="col-container"):
            # Single-model generation.
            with gr.Tab("Normal"):
                with gr.Row():
                    prompt = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
                    run_button = gr.Button("Run", scale=0)
                    # Hidden button wired to infer_simple at the bottom of this block.
                    run_button_simple = gr.Button("Simple", scale=0, visible=False)
                result = gr.Image(label="Result", show_label=False, format="png", type="filepath", interactive=False,
                                  buttons=["download", "fullscreen"])

            # Generation over several models and/or several images per run.
            with gr.Tab("Multi"):
                with gr.Row():
                    prompt_multi = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
                    run_button_multi = gr.Button("Run", scale=0)
                model_name_multi = gr.Dropdown(label="Model", choices=models, value=models[0], multiselect=True, allow_custom_value=True)
                num_images = gr.Slider(label="Count", minimum=1, maximum=16, step=1, value=1)
                result_multi = gr.Gallery(label="Result", columns=2, object_fit="contain", format="png", interactive=False,
                                          buttons=["download", "fullscreen"])

            # Accumulated outputs of both tabs (filled by the result/result_multi
            # .change handlers at the bottom of this block).
            with gr.Accordion("Output History", open=False):
                history_files = gr.Files(interactive=False, visible=False)
                history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False,
                                             buttons=["download", "fullscreen"])
                history_clear_button = gr.Button(value="Clear History", variant="secondary")
                history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False,
                                           api_visibility="undocumented")

            # Settings shared by the "Normal" and "Multi" tabs.
            with gr.Group():
                negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt",
                                          value="")
                with gr.Row(equal_height=True):
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row(equal_height=True):
                    param_mode = gr.Radio(label="Parameter Settings", choices=PARAM_MODES, value=PARAM_MODES[0])
                    ar = gr.Dropdown(label="Aspect Ratio", choices=ASPECT_RATIOS, value=ASPECT_RATIOS[0])
                # Hidden at startup; ar.change / param_mode.change below write to
                # these (presumably toggling visibility - confirm in t2i.infer).
                with gr.Row(equal_height=True):
                    width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
                    height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7, visible=False)
                    num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=60, step=1, value=28,
                                                    visible=False)

            with gr.Group():
                model_name = gr.Dropdown(label="Model", choices=models, value=models[0], allow_custom_value=True)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row(equal_height=True):
                        model_type = gr.Dropdown(label="Model Type", choices=MODEL_TYPES, value=MODEL_TYPES[0])
                        vae = gr.Dropdown(label="VAE", choices=VAE_NAMES, value=VAE_NAMES[0], allow_custom_value=True)
                    with gr.Row(equal_height=True):
                        sampler = gr.Dropdown(label="Sampler", choices=SAMPLER_NAMES, value=SAMPLER_NAMES[0])
                        pred_type = gr.Dropdown(label="Sampler prediction", choices=PRED_TYPES, value=PRED_TYPES[0])
                    with gr.Row(equal_height=True):
                        pipe_type = gr.Dropdown(label="Pipeline Type", choices=PIPELINE_TYPES, value=PIPELINE_TYPES[0])
                        clip_skip = gr.Slider(label="Clip Skip", minimum=0, maximum=12, step=1, value=DEFAULT_CLIP_SKIP)
                    with gr.Row(equal_height=True):
                        task = gr.Radio(label="Task", choices=DEFAULT_TASKS, value=DEFAULT_TASKS[0])
                        strength = gr.Slider(label="Image-to-Image / Inpainting Strength", minimum=0, maximum=1., step=0.01,
                                             value=DEFAULT_I2I_STRENGTH)
                    # Source image (and white brush mask) for the image-conditioned tasks.
                    # Eraser size is an int like the brush's; Gradio types
                    # default_size as int | "auto", so the former "32" string was wrong.
                    input_image = gr.ImageEditor(label="Input Image", type="filepath", sources=["upload", "clipboard", "webcam"],
                                                 image_mode='RGB', layers=False, buttons=[], canvas_size=(384, 384),
                                                 width=384, height=512,
                                                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32),
                                                 eraser=gr.Eraser(default_size=32))
                    with gr.Row(equal_height=True):
                        upscale_mode = gr.Dropdown(label="Upscaling", choices=UPSCALE_MODES, value=UPSCALE_MODES[0])
                        upscale_strength = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.05, value=DEFAULT_UPSCALE_STRENGTH)
                        upscale_by = gr.Slider(label="Upscale by", minimum=1, maximum=1.5, step=0.1, value=DEFAULT_UPSCALE_BY)
                    with gr.Row(equal_height=True):
                        # Named output_format (not `format`) to avoid shadowing the builtin.
                        output_format = gr.Dropdown(label="Output Format", choices=FILE_FORMATS, value=FILE_FORMATS[0])
                        gpu_duration = gr.Number(minimum=0, maximum=240, value=DEFAULT_DURATION,
                                                 label="GPU time duration (seconds per image)")

    # ------------------------------------------------------------------
    # Tab 2: read generation metadata back out of an uploaded image.
    # ------------------------------------------------------------------
    with gr.Tab("PNG Info"):
        def extract_exif_data(image):
            """Return the embedded generation metadata of a PIL image as text.

            Checks ``image.info`` for the keys "parameters", "metadata",
            "prompt" and "Comment", in that order, and returns the first value
            found. If none is present, returns ``str(image.info)``. Returns ""
            when no image is given. Any error is reported as a message string
            rather than raised, so the UI always shows something.
            """
            if image is None: return ""
            try:
                metadata_keys = ["parameters", "metadata", "prompt", "Comment"]
                for key in metadata_keys:
                    if key in image.info:
                        return image.info[key]
                return str(image.info)
            except Exception as e:
                return f"Error extracting metadata: {str(e)}"

        with gr.Row():
            with gr.Column():
                image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
            with gr.Column():
                result_metadata = gr.Textbox(label="Metadata", show_label=True, buttons=["copy"], interactive=False,
                                             container=True, max_lines=99)
        image_metadata.change(fn=extract_exif_data, inputs=[image_metadata], outputs=[result_metadata],
                              api_visibility="undocumented")

    # ------------------------------------------------------------------
    # Event wiring. Input order must match the signatures in t2i.infer.
    # ------------------------------------------------------------------
    gr.on(triggers=[run_button.click, prompt.submit], fn=infer,
          inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
                  model_name, sampler, pred_type, vae, model_type, clip_skip, pipe_type, lora_dict, upscale_mode,
                  upscale_strength, upscale_by, input_image, strength, param_mode, ar, output_format, task, gpu_duration],
          outputs=[result])
    gr.on(triggers=[run_button_multi.click, prompt_multi.submit], fn=infer_multi,
          inputs=[prompt_multi, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
                  model_name_multi, sampler, pred_type, vae, clip_skip, pipe_type, lora_dict, upscale_mode,
                  upscale_strength, upscale_by, input_image, strength, param_mode, ar, output_format, num_images, task,
                  gpu_duration],
          outputs=[result_multi])
    run_button_simple.click(fn=infer_simple, inputs=[prompt, negative_prompt, seed, randomize_seed, model_name], outputs=[result])

    # Append every new result to the shared history gallery/files.
    result.change(save_image_history, [result, history_gallery, history_files], [history_gallery, history_files],
                  queue=False, api_visibility="undocumented")
    result_multi.change(save_gallery_history, [result_multi, history_gallery, history_files],
                        [history_gallery, history_files], queue=False, api_visibility="undocumented")

    # Keep the hidden size/guidance/steps controls in sync with the presets.
    ar.change(update_ar_gr, [ar], [width, height], queue=False, api_visibility="undocumented")
    param_mode.change(update_param_mode_gr, [param_mode], [guidance_scale, num_inference_steps], queue=False,
                      api_visibility="undocumented")
|
|
|
|
|
# Turn on the request queue, then start the server with server-side
# rendering disabled, the MCP server enabled, and the custom stylesheet.
demo.queue()
demo.launch(ssr_mode=False, mcp_server=True, css=css)
|
|
|