File size: 9,478 Bytes
aefd7f3
 
 
adefd5c
 
 
 
 
 
 
 
 
 
 
aefd7f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import spaces
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from t2i_config import KERNELS_PREFETCH_ON_STARTUP, KERNELS_PREFETCH_REPOS

if KERNELS_PREFETCH_ON_STARTUP:
    # Optional best-effort prefetch of the kernel repos listed in
    # KERNELS_PREFETCH_REPOS (presumably to warm the local cache before the
    # inference module is imported below). Any failure is only logged so that
    # app startup is never blocked by it.
    try:
        from kernels import has_kernel, get_kernel
    except Exception as _e:
        print(f"INFO : Kernels prefetch skipped: {_e}")
    else:
        for _repo_id in KERNELS_PREFETCH_REPOS:
            # Handle errors per repo so one failing repo does not silently
            # skip every repo after it, and report which repo failed.
            try:
                if has_kernel(_repo_id):
                    get_kernel(_repo_id)
            except Exception as _e:
                print(f"INFO : Kernels prefetch skipped for {_repo_id}: {_e}")

from t2i.infer import (infer, infer_multi, infer_simple, save_image_history, save_gallery_history,
                       update_param_mode_gr, update_ar_gr,
                       MAX_SEED, MAX_IMAGE_SIZE, ASPECT_RATIOS, FILE_FORMATS, DEFAULT_TASKS, DEFAULT_DURATION,
                       DEFAULT_I2I_STRENGTH, DEFAULT_UPSCALE_STRENGTH, DEFAULT_UPSCALE_BY, DEFAULT_CLIP_SKIP,
                       models, MODEL_TYPES, SAMPLER_NAMES, PRED_TYPES, VAE_NAMES,
                       UPSCALE_MODES, PARAM_MODES, PIPELINE_TYPES)

# Custom stylesheet handed to `demo.launch(css=...)`: centers the column with
# elem_id="col-container" and caps its width at 1080px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 1080px;\n"
    "}\n"
)

# Build the Gradio UI:
#   * "Image Generator" tab: single-model ("Normal") and multi-model ("Multi")
#     generation sharing one set of settings, plus an output history.
#   * "PNG Info" tab: shows metadata embedded in an uploaded image.
# Event wiring comes after the layout; the app is launched at the end.
with gr.Blocks(fill_height=True, fill_width=True) as demo:
    with gr.Tab("Image Generator"):
        # Session-scoped state passed as an input to infer/infer_multi; no
        # visible component in this file edits it.
        lora_dict = gr.State({})
        with gr.Column(elem_id="col-container"):
            with gr.Tab("Normal"):
                with gr.Row():
                    prompt = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
                    run_button = gr.Button("Run", scale=0)
                    run_button_simple = gr.Button("Simple", scale=0, visible=False) # for API
                result = gr.Image(label="Result", show_label=False, format="png", type="filepath", interactive=False, buttons=["download", "fullscreen"])

            with gr.Tab("Multi"):
                with gr.Row():
                    prompt_multi = gr.Text(label="Prompt", show_label=False, lines=1, placeholder="Enter your prompt", container=False)
                    run_button_multi = gr.Button("Run", scale=0)
                # multiselect=True expects a list value, so wrap the default model.
                model_name_multi = gr.Dropdown(label="Model", choices=models, value=[models[0]], multiselect=True, allow_custom_value=True)
                num_images = gr.Slider(label="Count", minimum=1, maximum=16, step=1, value=1)
                result_multi = gr.Gallery(label="Result", columns=2, object_fit="contain", format="png", interactive=False, buttons=["download", "fullscreen"])

            with gr.Accordion("Output History", open=False):
                # Hidden file list kept in sync with the gallery by the
                # save_*_history callbacks wired below.
                history_files = gr.Files(interactive=False, visible=False)
                history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, buttons=["download", "fullscreen"])
                history_clear_button = gr.Button(value="Clear History", variant="secondary")
                history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, api_visibility="undocumented")

            with gr.Group():
                negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt",
                                            value="") # nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn
                with gr.Row(equal_height=True):
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row(equal_height=True):
                    param_mode = gr.Radio(label="Parameter Settings", choices=PARAM_MODES, value=PARAM_MODES[0])
                    ar = gr.Dropdown(label="Aspect Ratio", choices=ASPECT_RATIOS, value=ASPECT_RATIOS[0])
                # Hidden by default; these are outputs of ar.change /
                # param_mode.change below (presumably revealed by those callbacks).
                with gr.Row(equal_height=True):
                    width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
                    height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, visible=False)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7, visible=False)
                    num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=60, step=1, value=28, visible=False)
                with gr.Group():
                    model_name = gr.Dropdown(label="Model", choices=models, value=models[0], allow_custom_value=True)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row(equal_height=True):
                        model_type = gr.Dropdown(label="Model Type", choices=MODEL_TYPES, value=MODEL_TYPES[0])
                        vae = gr.Dropdown(label="VAE", choices=VAE_NAMES, value=VAE_NAMES[0], allow_custom_value=True)
                    with gr.Row(equal_height=True):
                        sampler = gr.Dropdown(label="Sampler", choices=SAMPLER_NAMES, value=SAMPLER_NAMES[0])
                        pred_type = gr.Dropdown(label="Sampler prediction", choices=PRED_TYPES, value=PRED_TYPES[0])
                    with gr.Row(equal_height=True):
                        pipe_type = gr.Dropdown(label="Pipeline Type", choices=PIPELINE_TYPES, value=PIPELINE_TYPES[0])
                        clip_skip = gr.Slider(label="Clip Skip", minimum=0, maximum=12, step=1, value=DEFAULT_CLIP_SKIP)
                    with gr.Row(equal_height=True):
                        task = gr.Radio(label="Task", choices=DEFAULT_TASKS, value=DEFAULT_TASKS[0])
                        strength = gr.Slider(label="Image-to-Image / Inpainting Strength", minimum=0, maximum=1., step=0.01, value=DEFAULT_I2I_STRENGTH)
                    # Eraser default_size must be an int (or "auto"), matching the brush.
                    input_image = gr.ImageEditor(label="Input Image", type="filepath", sources=["upload", "clipboard", "webcam"], image_mode='RGB', layers=False, buttons=[], canvas_size=(384, 384), width=384, height=512,
                                                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32), eraser=gr.Eraser(default_size=32))
                    with gr.Row(equal_height=True):
                        upscale_mode = gr.Dropdown(label="Upscaling", choices=UPSCALE_MODES, value=UPSCALE_MODES[0])
                        upscale_strength = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.05, value=DEFAULT_UPSCALE_STRENGTH)
                        upscale_by = gr.Slider(label="Upscale by", minimum=1, maximum=1.5, step=0.1, value=DEFAULT_UPSCALE_BY)
                    with gr.Row(equal_height=True):
                        # Named output_format (not `format`) to avoid shadowing the builtin.
                        output_format = gr.Dropdown(label="Output Format", choices=FILE_FORMATS, value=FILE_FORMATS[0])
                        gpu_duration = gr.Number(minimum=0, maximum=240, value=DEFAULT_DURATION, label="GPU time duration (seconds per image)")

    with gr.Tab("PNG Info"):
        def extract_exif_data(image):
            """Return the metadata text embedded in *image*.

            Returns the value of the first key found in ``image.info`` among
            "parameters", "metadata", "prompt" and "Comment"; otherwise the
            whole ``image.info`` dict rendered as a string.

            Args:
                image: PIL image from the upload component (type="pil"), or
                    None when the component is cleared.

            Returns:
                Always a ``str``: the metadata text, "" when there is no image,
                or an error message if reading the metadata fails.
            """
            if image is None:
                return ""
            try:
                metadata_keys = ["parameters", "metadata", "prompt", "Comment"]
                for key in metadata_keys:
                    if key in image.info:
                        value = image.info[key]
                        # `Image.info` values are not guaranteed to be str;
                        # the Textbox needs text, so decode raw bytes.
                        if isinstance(value, (bytes, bytearray)):
                            return value.decode("utf-8", errors="replace")
                        return str(value)
                return str(image.info)
            except Exception as e:
                return f"Error extracting metadata: {e}"
        with gr.Row():
            with gr.Column():
                image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
            with gr.Column():
                result_metadata = gr.Textbox(label="Metadata", show_label=True, buttons=["copy"], interactive=False, container=True, max_lines=99)
                image_metadata.change(fn=extract_exif_data, inputs=[image_metadata], outputs=[result_metadata], api_visibility="undocumented")

    # Single-model generation: Run button or Enter in the prompt box.
    gr.on(triggers=[run_button.click, prompt.submit], fn=infer,
          inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
                  model_name, sampler, pred_type, vae, model_type, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
                  input_image, strength, param_mode, ar, output_format, task, gpu_duration],
          outputs=[result])
    # Multi-model generation (note: model_type is not passed here, unlike infer).
    gr.on(triggers=[run_button_multi.click, prompt_multi.submit], fn=infer_multi,
          inputs=[prompt_multi, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
                  model_name_multi, sampler, pred_type, vae, clip_skip, pipe_type, lora_dict, upscale_mode, upscale_strength, upscale_by,
                  input_image, strength, param_mode, ar, output_format, num_images, task, gpu_duration],
          outputs=[result_multi])
    # Hidden button: exposes a reduced-parameter endpoint for API callers.
    run_button_simple.click(fn=infer_simple, inputs=[prompt, negative_prompt, seed, randomize_seed, model_name], outputs=[result])

    # Append every new result to the shared history gallery/file list.
    result.change(save_image_history, [result, history_gallery, history_files], [history_gallery, history_files], queue=False, api_visibility="undocumented")
    result_multi.change(save_gallery_history, [result_multi, history_gallery, history_files], [history_gallery, history_files], queue=False, api_visibility="undocumented")

    # Keep the hidden size / guidance / steps controls in sync with the presets.
    ar.change(update_ar_gr, [ar], [width, height], queue=False, api_visibility="undocumented")
    param_mode.change(update_param_mode_gr, [param_mode], [guidance_scale, num_inference_steps], queue=False, api_visibility="undocumented")

demo.queue().launch(ssr_mode=False, mcp_server=True, css=css)