File size: 9,478 Bytes
aefd7f3 adefd5c aefd7f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import spaces
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from t2i_config import KERNELS_PREFETCH_ON_STARTUP, KERNELS_PREFETCH_REPOS
# Optional startup warm-up, switched on by KERNELS_PREFETCH_ON_STARTUP from t2i_config.
# For every repo id in KERNELS_PREFETCH_REPOS that has_kernel() accepts, get_kernel() is called.
# Best-effort: the first exception (including a failed import of `kernels`) ends the whole
# prefetch, is reported on stdout, and startup continues.
if KERNELS_PREFETCH_ON_STARTUP:
    try:
        from kernels import get_kernel, has_kernel

        for _repo_id in KERNELS_PREFETCH_REPOS:
            if not has_kernel(_repo_id):
                continue  # nothing to prefetch for this repo id
            get_kernel(_repo_id)
    except Exception as _e:
        print(f"INFO : Kernels prefetch skipped: {_e}")
from t2i.infer import (infer, infer_multi, infer_simple, save_image_history, save_gallery_history,
update_param_mode_gr, update_ar_gr,
MAX_SEED, MAX_IMAGE_SIZE, ASPECT_RATIOS, FILE_FORMATS, DEFAULT_TASKS, DEFAULT_DURATION,
DEFAULT_I2I_STRENGTH, DEFAULT_UPSCALE_STRENGTH, DEFAULT_UPSCALE_BY, DEFAULT_CLIP_SKIP,
models, MODEL_TYPES, SAMPLER_NAMES, PRED_TYPES, VAE_NAMES,
UPSCALE_MODES, PARAM_MODES, PIPELINE_TYPES)
# Custom stylesheet passed to launch(): targets the column created with
# elem_id="col-container", centring it horizontally and capping its width at 1080px.
css = """
#col-container {
margin: 0 auto;
max-width: 1080px;
}
"""
# Root of the Gradio UI: the Blocks app bound to `demo`, which is launched at the end of the file.
# NOTE(review): indentation is missing from this listing, so the parent container of each
# widget cannot be read from here; comments describe individual statements only.
with gr.Blocks(fill_height=True, fill_width=True) as demo:
with gr.Tab("Image Generator"):
# State holder initialised to an empty dict; passed as an input to infer and infer_multi.
lora_dict = gr.State({})
# elem_id matches the "#col-container" selector in the css string.
with gr.Column(elem_id="col-container"):
# Single-image generation: prompt box, Run button and one result image.
with gr.Tab("Normal"):
with gr.Row():
prompt = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
run_button = gr.Button("Run", scale=0)
# Hidden button (visible=False); its click is wired to infer_simple further down.
run_button_simple = gr.Button("Simple", scale=0, visible=False) # for API
result = gr.Image(label="Result", show_label=False, format="png", type="filepath", interactive=False, buttons=["download", "fullscreen"])
# Multi-image generation: own prompt box and Run button, a multi-select model list,
# an image count (1-16) and a gallery for the results.
with gr.Tab("Multi"):
with gr.Row():
prompt_multi = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
run_button_multi = gr.Button("Run", scale=0)
model_name_multi = gr.Dropdown(label="Model", choices=models, value=models[0], multiselect=True, allow_custom_value=True)
num_images = gr.Slider(label="Count", minimum=1, maximum=16, step=1, value=1)
result_multi = gr.Gallery(label="Result", columns=2, object_fit="contain", format="png", interactive=False, buttons=["download", "fullscreen"])
# Output history: a hidden file list plus a visible gallery; both are updated by the
# save_image_history / save_gallery_history handlers registered further down.
with gr.Accordion("Output History", open=False):
history_files = gr.Files(interactive=False, visible=False)
history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, buttons=["download", "fullscreen"])
history_clear_button = gr.Button(value="Clear History", variant="secondary")
# Clearing takes no inputs and resets both the gallery and the file list to empty lists.
history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, api_visibility="undocumented")
# Generation parameters: every widget created here is listed among the inputs of both the
# infer and infer_multi event handlers.
with gr.Group():
# Starts empty; the trailing comment on the `value` line keeps a sample negative prompt.
negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt",
value="") # nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn
with gr.Row(equal_height=True):
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row(equal_height=True):
# Changes to these two selectors run update_param_mode_gr / update_ar_gr (registered at the
# end of the file), which write to the sliders created just below.
param_mode = gr.Radio(label="Parameter Settings", choices=PARAM_MODES, value=PARAM_MODES[0])
ar = gr.Dropdown(label="Aspect Ratio", choices=ASPECT_RATIOS, value=ASPECT_RATIOS[0])
with gr.Row(equal_height=True):
# All four sliders are created hidden (visible=False). width/height are the outputs of the
# `ar` change handler; guidance_scale/num_inference_steps are those of the `param_mode` one.
width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7, visible=False)
num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=60, step=1, value=28, visible=False)
# Model selection plus the "Advanced Settings" accordion (model type, VAE, sampler,
# prediction type, pipeline type, clip skip, task and strength).
with gr.Group():
# Model used by the "Normal" tab and by infer_simple; values outside `models` are accepted
# because allow_custom_value=True.
model_name = gr.Dropdown(label="Model", choices=models, value=models[0], allow_custom_value=True)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row(equal_height=True):
# model_type is passed to infer only; it is not in infer_multi's input list.
model_type = gr.Dropdown(label="Model Type", choices=MODEL_TYPES, value=MODEL_TYPES[0])
vae = gr.Dropdown(label="VAE", choices=VAE_NAMES, value=VAE_NAMES[0], allow_custom_value=True)
with gr.Row(equal_height=True):
sampler = gr.Dropdown(label="Sampler", choices=SAMPLER_NAMES, value=SAMPLER_NAMES[0])
pred_type = gr.Dropdown(label="Sampler prediction", choices=PRED_TYPES, value=PRED_TYPES[0])
with gr.Row(equal_height=True):
pipe_type = gr.Dropdown(label="Pipeline Type", choices=PIPELINE_TYPES, value=PIPELINE_TYPES[0])
clip_skip = gr.Slider(label="Clip Skip", minimum=0, maximum=12, step=1, value=DEFAULT_CLIP_SKIP)
with gr.Row(equal_height=True):
# Every dropdown/radio in this section defaults to the first entry of its imported choice list.
task = gr.Radio(label="Task", choices=DEFAULT_TASKS, value=DEFAULT_TASKS[0])
strength = gr.Slider(label="Image-to-Image / Inpainting Strength", minimum=0, maximum=1., step=0.01, value=DEFAULT_I2I_STRENGTH)
# Editable input image handed to infer / infer_multi as a file path (type="filepath").
# Presumably the source/mask for the image-to-image and inpainting tasks - confirm in t2i.infer.
# Layers are disabled and the brush is locked to white.
input_image = gr.ImageEditor(
    label="Input Image", type="filepath", sources=["upload", "clipboard", "webcam"], image_mode='RGB',
    layers=False, buttons=[], canvas_size=(384, 384), width=384, height=512,
    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32),
    # default_size is documented as an int (or "auto"); it was the string "32" before.
    # Use the same int the brush uses.
    eraser=gr.Eraser(default_size=32),
)
# Upscaling controls and output options; all five widgets are inputs of infer and infer_multi.
with gr.Row(equal_height=True):
upscale_mode = gr.Dropdown(label="Upscaling", choices=UPSCALE_MODES, value=UPSCALE_MODES[0])
upscale_strength = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.05, value=DEFAULT_UPSCALE_STRENGTH)
upscale_by = gr.Slider(label="Upscale by", minimum=1, maximum=1.5, step=0.1, value=DEFAULT_UPSCALE_BY)
with gr.Row(equal_height=True):
# NOTE(review): this name shadows the builtin `format` for the rest of the module; renaming it
# would also require updating the infer / infer_multi input lists.
format = gr.Dropdown(label="Output Format", choices=FILE_FORMATS, value=FILE_FORMATS[0])
# Labelled as seconds per image, limited to the range 0-240.
gpu_duration = gr.Number(minimum=0, maximum=240, value=DEFAULT_DURATION, label="GPU time duration (seconds per image)")
# Tab for reading the metadata embedded in an uploaded image; its helper and widgets follow.
with gr.Tab("PNG Info"):
def extract_exif_data(image):
    """Return the metadata text stored on *image*.

    ``image.info`` is searched for the keys "parameters", "metadata",
    "prompt" and "Comment", in that order, and the value of the first key
    present is returned unchanged. If none is present, the whole mapping is
    returned as ``str(image.info)``. ``None`` yields an empty string.

    Errors are never raised to the caller: any exception is turned into an
    "Error extracting metadata: ..." string.
    """
    if image is None:
        return ""
    try:
        wanted = ("parameters", "metadata", "prompt", "Comment")
        # First preferred key that the image actually carries, if any.
        hit = next((name for name in wanted if name in image.info), None)
        if hit is None:
            return str(image.info)
        return image.info[hit]
    except Exception as err:
        return f"Error extracting metadata: {str(err)}"
# PNG Info widgets: an upload-only image input and a read-only textbox with a copy button.
with gr.Row():
with gr.Column():
# type="pil" so the handler receives an image object whose `.info` mapping it can read.
image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
with gr.Column():
result_metadata = gr.Textbox(label="Metadata", show_label=True, buttons=["copy"], interactive=False, container=True, max_lines=99)
# Re-run the extraction whenever the uploaded image changes.
image_metadata.change(fn=extract_exif_data, inputs=[image_metadata], outputs=[result_metadata], api_visibility="undocumented")
# Event wiring. Inputs are passed positionally, so the order of each list has to match the
# parameter order of the handler imported from t2i.infer.
# "Normal" tab: clicking Run or submitting the prompt box calls infer and fills `result`.
gr.on(triggers=[run_button.click, prompt.submit], fn=infer,
inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
model_name, sampler, pred_type, vae, model_type, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
input_image, strength, param_mode, ar, format, task, gpu_duration],
outputs=[result])
# "Multi" tab: same triggers on its own widgets. Compared with the infer list above, this one
# uses prompt_multi / model_name_multi, omits model_type and inserts num_images after format.
gr.on(triggers=[run_button_multi.click, prompt_multi.submit], fn=infer_multi,
inputs=[prompt_multi, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
model_name_multi, sampler, pred_type, vae, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
input_image, strength, param_mode, ar, format, num_images, task, gpu_duration],
outputs=[result_multi])
# Reduced five-input entry point on the hidden "Simple" button (marked "for API" where it is created).
run_button_simple.click(fn=infer_simple, inputs=[prompt, negative_prompt, seed, randomize_seed, model_name], outputs=[result])
# Whenever a result changes, the history handlers receive it together with the current
# gallery/file list and return the updated pair.
result.change(save_image_history, [result, history_gallery, history_files], [history_gallery, history_files], queue=False, api_visibility="undocumented")
result_multi.change(save_gallery_history, [result_multi, history_gallery, history_files], [history_gallery, history_files], queue=False, api_visibility="undocumented")
# UI-sync handlers: aspect ratio drives width/height; parameter mode drives guidance scale and step count.
ar.change(update_ar_gr, [ar], [width, height], queue=False, api_visibility="undocumented")
param_mode.change(update_param_mode_gr, [param_mode], [guidance_scale, num_inference_steps], queue=False, api_visibility="undocumented")
# Enable the queue and start the app; the stylesheet is supplied here rather than to gr.Blocks.
demo.queue().launch(ssr_mode=False, mcp_server=True, css=css)
|