Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| #from utils import * | |
| import random | |
# Module-level UI state. Only ``out_img_list`` is read in this file (by
# fn_refresh); ``is_clicked`` and ``out_state_list`` are not referenced.
is_clicked = False
out_img_list = [''] * 5  # presumably one slot per style preview image
out_state_list = [False] * 5
def fn_query_on_load():
    """Return the initial text for the search box.

    Passed as the callable ``value`` of the search ``gr.Textbox``.
    """
    default_query = "Cats at sunset"
    return default_query
def fn_refresh():
    """Return the module-level ``out_img_list``.

    The list object itself is returned (not a copy), so mutating the result
    mutates module state. No event handler in this file references it.
    """
    current_images = out_img_list
    return current_images
# Markdown header rendered at the top of the page.
_HEADER_MD = (
    "\n"
    "# Stable Diffusion Image Generation\n"
    "### Enter query to generate images in various styles\n"
)

# (initial image path, label) for each style preview, in display order.
_STYLE_PREVIEWS = (
    ("out1.png", "Oil Painting"),
    ("out2.png", "Low Poly HD Style"),
    ("out3.png", "Matrix style"),
    ("out4.png", "Dreamy Painting"),
    ("out5.png", "Depth Map Style"),
)

with gr.Blocks() as app:
    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file that had lost all indentation -- confirm it matches the
    # intended layout.
    with gr.Row():
        gr.Markdown(_HEADER_MD)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                # Initial query comes from calling fn_query_on_load.
                search_text = gr.Textbox(
                    value=fn_query_on_load,
                    placeholder='Search..',
                    label=None,
                )

            with gr.Row(visible=True):
                # Read-only previews, one per style. Disabled per-style
                # controls: each image was meant to sit in its own
                # gr.Column() with a gr.Button("Submit", variant='primary')
                # beneath it (submit1..submit5).
                out1, out2, out3, out4, out5 = (
                    gr.Image(value=path, interactive=False, width=128, label=label)
                    for path, label in _STYLE_PREVIEWS
                )

            with gr.Row(visible=True):
                clear_btn = gr.ClearButton()

    # Components reset by the Clear button, in output order.
    _cleared = (out1, out2, out3, out4, out5, search_text)

    def clear_data():
        """Map every cleared component to ``None`` (empties images and query)."""
        return dict.fromkeys(_cleared)

    # Clicking Clear runs clear_data and writes its result to these outputs.
    clear_btn.click(clear_data, None, list(_cleared))
# ---------------------------------------------------------------------------
# Disabled: per-style image generation and its button wiring.
#
# Kept for reference only; none of it executes. It depends on names that are
# not defined in this file (tokenizer, torch, torch_device, text_encoder,
# pos_emb_layer, token_emb_layer, concept_embeds, get_output_embeds,
# generate_with_embs) -- presumably supplied by the commented-out
# `from utils import *` -- and on buttons submit1..submit5, which are not
# defined anywhere in this file. The `.click` wiring presumably has to run
# inside the `with gr.Blocks() as app:` context once re-enabled.
# NOTE(review): indentation below was reconstructed from a copy that had lost
# all leading whitespace.
# NOTE(review): 22373 is presumably the tokenizer id of the placeholder word
# "bulb" appended to the prompt -- confirm against the tokenizer.
# ---------------------------------------------------------------------------
# def func_generate(query, concept_idx, seed):
#     prompt = query + ' in the style of bulb'
#     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
#                            return_tensors="pt")
#     input_ids = text_input.input_ids.to(torch_device)
#     # Get token embeddings
#     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
#     position_embeddings = pos_emb_layer(position_ids)
#     s = seed
#     token_embeddings = token_emb_layer(input_ids)
#     # The new embedding - our special birb word
#     replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
#     # Insert this into the token embeddings
#     token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
#     # Combine with pos embs
#     input_embeddings = token_embeddings + position_embeddings
#     # Feed through to get final output embs
#     modified_output_embeddings = get_output_embeds(input_embeddings)
#     # And generate an image with this:
#     s = random.randint(s + 1, s + 30)
#     g = torch.manual_seed(s)
#     return generate_with_embs(text_input, modified_output_embeddings, generator=g)
#
# def generate_oil_painting(query):
#     return {
#         out1: func_generate(query, 0, 0)
#     }
#
# def generate_low_poly_hd(query):
#     return {
#         out2: func_generate(query, 1, 30)
#     }
#
# def generate_matrix_style(query):
#     return {
#         out3: func_generate(query, 2, 60)
#     }
#
# def generate_dreamy_painting(query):
#     return {
#         out4: func_generate(query, 3, 90)
#     }
#
# def generate_depth_map_style(query):
#     return {
#         out5: func_generate(query, 4, 120)
#     }
#
# submit1.click(
#     generate_oil_painting,
#     search_text,
#     out1
# )
# submit2.click(
#     generate_low_poly_hd,
#     search_text,
#     out2
# )
# submit3.click(
#     generate_matrix_style,
#     search_text,
#     out3
# )
# submit4.click(
#     generate_dreamy_painting,
#     search_text,
#     out4
# )
# submit5.click(
#     generate_depth_map_style,
#     search_text,
#     out5
# )
# Launch the Gradio app only when this file is executed as a script; importing
# the module (e.g. to reuse `app`) no longer starts a server as a side effect.
if __name__ == "__main__":
    app.launch()