UDface11jkj committed
Commit e4cf476 · verified · 1 Parent(s): 17a9926

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +87 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,89 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
+from PIL import Image
+import torch
+import os
+import time
+import tempfile
+from safetensors.torch import load_file
+from huggingface_hub import snapshot_download
+
+class ImageGenerator:
+    def __init__(self, ae_path, dit_path, qwen2vl_model_path, max_length=640):
+        # Initialize the model with the provided paths
+        self.ae_path = ae_path
+        self.dit_path = dit_path
+        self.qwen2vl_model_path = qwen2vl_model_path
+        self.max_length = max_length
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.load_model()
+
+    def load_model(self):
+        # Load model weights or any necessary model setup here
+        pass
+
+    def to_cuda(self):
+        # Move model to GPU if available
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # Example weight loading: torch.load cannot read .safetensors files,
+        # so use safetensors.torch.load_file (it returns a state dict, not a module)
+        self.ae_state_dict = load_file(self.ae_path, device=str(self.device))
+        # Additional model loading logic for your specific case
+
+
+def inference(prompt, image, seed, size_level, model):
+    # Placeholder prediction logic: pass the image and prompt to the model
+    # and return the generated output; replace with the real inference code
+    result_image = image  # Placeholder, replace with actual generation logic
+    used_seed = seed if seed != -1 else int(time.time())  # Use random seed if -1
+    return result_image, used_seed
+
+# Set page config for a better UI layout
+st.set_page_config(page_title="Ghibli style", layout="centered")
+st.title("🖼️ Ghibli Style for Free: AI Image Editing")
+st.markdown("Turn your images into Studio Ghibli-style illustrations with AI.")
+
+# === User Inputs ===
+prompt = "Turn into an illustration in Studio Ghibli style"
+uploaded_image = st.file_uploader("📤 Upload an Image", type=["jpg", "jpeg", "png"])
+seed = st.number_input("🎲 Random Seed (-1 for random)", value=-1, step=1)
+size_level = st.number_input("📏 Size Level (minimum 512)", value=512, min_value=512, step=32)
+
+generate_button = st.button("🚀 Generate")
+
+# === Load Model (Cached) ===
+@st.cache_resource
+def load_model():
+    repo = "stepfun-ai/Step1X-Edit"
+    local_dir = "./step1x_weights"
+    os.makedirs(local_dir, exist_ok=True)
+    snapshot_download(repo_id=repo, local_dir=local_dir, local_dir_use_symlinks=False)
+
+    model = ImageGenerator(
+        ae_path=os.path.join(local_dir, "vae.safetensors"),
+        dit_path=os.path.join(local_dir, "step1x-edit-i1258.safetensors"),
+        qwen2vl_model_path="Qwen/Qwen2.5-VL-7B-Instruct",
+        max_length=640,
+    )
+    return model
+
+image_edit_model = load_model()
+
+# === Inference and Image Display ===
+if generate_button and uploaded_image is not None:
+    input_image = Image.open(uploaded_image).convert("RGB")
+    # Resize the image for faster inference (adjust to your model's requirements)
+    input_image.thumbnail((size_level, size_level))
+
+    with st.spinner("🔄 Generating edited image..."):
+        start = time.time()
+        try:
+            result_image, used_seed = inference(prompt, input_image, seed, size_level, image_edit_model)
+            end = time.time()
+
+            st.success(f"✅ Done in {end - start:.2f} seconds (seed used: {used_seed})")
+
+            # Save the result to a temporary file and display it
+            with tempfile.NamedTemporaryFile(dir="/tmp", delete=False, suffix=".png") as temp_file:
+                result_image.save(temp_file.name)
+                st.image(temp_file.name, caption="🖼️ Edited Image", use_column_width=True)
+        except Exception as e:
+            st.error(f"❌ Inference failed: {e}")
+            st.stop()
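The `inference()` shipped in this commit is still a stub that returns the uploaded image unchanged. Below is a minimal sketch of how it might be wired to the real model; `generate_image()` and its keyword arguments are assumed names for illustration, not an API confirmed by this commit, so substitute the actual Step1X-Edit entry point:

```python
import time

def inference(prompt, image, seed, size_level, model):
    # Resolve the seed first so it can be reported back to the UI
    used_seed = seed if seed != -1 else int(time.time())
    # Hypothetical call: generate_image() is an assumed method name,
    # replace it with the real Step1X-Edit inference entry point
    result_image = model.generate_image(
        prompt=prompt,          # edit instruction text
        ref_images=image,       # PIL image to be edited
        seed=used_seed,         # fixed seed for reproducibility
        size_level=size_level,  # target resolution bucket
    )
    return result_image, used_seed
```

Because `load_model()` is wrapped in `@st.cache_resource`, the snapshot download and `ImageGenerator` construction run once per server process, so repeated generations reuse the same instance instead of re-downloading the weights.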