# HiDream-I1-Dev / app.py
# Last commit: "Update app.py" (75fc5f1, verified) by cai-qi
import logging
import os
import random
import time
import traceback
from io import BytesIO
import gradio as gr
import requests
from PIL import Image, PngImagePlugin
from dotenv import load_dotenv
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load environment variables (python-dotenv reads a local .env file if present)
load_dotenv()

# API configuration - every value comes from the environment
API_TOKEN = os.environ.get("HIDREAM_API_TOKEN")
API_REQUEST_URL = os.environ.get("API_REQUEST_URL")
API_RESULT_URL = os.environ.get("API_RESULT_URL")
API_IMAGE_URL = os.environ.get("API_IMAGE_URL")
API_VERSION = os.environ.get("API_VERSION")
API_MODEL_NAME = os.environ.get("API_MODEL_NAME")

# Retry / polling configuration. A fallback is used when a variable is unset
# or empty so that importing the module no longer dies with
# "int() argument must be ... not 'NoneType'".
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT") or 3)
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL") or 1.0)
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME") or 300)

# Aspect ratio options offered in the UI and accepted by create_request
ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]

# Surface missing required settings early; without them requests would be sent
# to a None URL or with a "Bearer None" authorization header.
_missing_settings = [
    name
    for name, value in (
        ("HIDREAM_API_TOKEN", API_TOKEN),
        ("API_REQUEST_URL", API_REQUEST_URL),
        ("API_RESULT_URL", API_RESULT_URL),
        ("API_IMAGE_URL", API_IMAGE_URL),
    )
    if not value
]
if _missing_settings:
    logger.warning(f"Missing environment variables: {', '.join(_missing_settings)}")

# Log configuration details (the API token is deliberately not logged)
logger.info(f"API configuration loaded: REQUEST_URL={API_REQUEST_URL}, RESULT_URL={API_RESULT_URL}, VERSION={API_VERSION}, MODEL={API_MODEL_NAME}")
logger.info(f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s")
class APIError(Exception):
    """Raised when a call to the API fails or its response is unusable."""
def create_request(prompt, aspect_ratio="1:1", seed=-1):
    """
    Create an image generation request to the API.

    Args:
        prompt (str): Text prompt describing the image to generate
        aspect_ratio (str): Aspect ratio of the output image; must be one of
            ASPECT_RATIO_OPTIONS
        seed (int): Seed for reproducibility, -1 for random. Values outside
            [-1, 1000000] are replaced by 8888.

    Returns:
        tuple: (task_id, seed) - Task ID if successful and the seed used

    Raises:
        ValueError: If the prompt is empty, the aspect ratio is not supported
            or the seed cannot be converted to an integer
        APIError: If the API request fails
    """
    # Validate before slicing the prompt for the log line, so a None prompt
    # produces the documented ValueError rather than a TypeError.
    if not prompt or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")

    logger.info(f"Starting create_request with prompt='{prompt[:50]}...', aspect_ratio={aspect_ratio}, seed={seed}")

    # Validate aspect ratio
    if aspect_ratio not in ASPECT_RATIO_OPTIONS:
        logger.error(f"Invalid aspect ratio: {aspect_ratio}. Valid options: {', '.join(ASPECT_RATIO_OPTIONS)}")
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(ASPECT_RATIO_OPTIONS)}")

    # Coerce the seed first so that "-1" or -1.0 are also recognised as the
    # "random" sentinel instead of being forwarded to the API as -1.
    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError) as e:
        logger.error(f"Seed validation failed: {str(e)}")
        raise ValueError(f"Seed must be an integer but got {seed}") from e

    if seed == -1:
        # Generate random seed if not provided
        seed = random.randint(1, 1000000)
        logger.info(f"Generated random seed: {seed}")
    elif seed < -1 or seed > 1000000:
        logger.info(f"Invalid seed value: {seed}, forcing to 8888")
        seed = 8888

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
        "X-source": "api",
        "Content-Type": "application/json",
    }
    generate_data = {
        "module": "txt2img",
        "prompt": prompt,
        "params": {
            "batch_size": 1,
            "wh_ratio": aspect_ratio,
            "seed": seed
        },
        "version": API_VERSION,
    }

    # Every handler below either returns, raises, or falls through to the next
    # attempt, so a bounded for-loop replaces the manual retry counter.
    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Sending API request [attempt {attempt}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'")
            response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
            # Log response status code
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()

            result = response.json()
            if not isinstance(result, dict) or "result" not in result:
                logger.error(f"Invalid API response format: {str(result)}")
                raise APIError(f"Invalid response format from API when sending request: {str(result)}")

            # "result" may be present but null; treat that like a missing task id.
            task_id = (result.get("result") or {}).get("task_id")
            if not task_id:
                logger.error(f"No task ID in API response: {str(result)}")
                raise APIError(f"No task ID returned from API: {str(result)}")

            logger.info(f"Successfully created task with ID: {task_id}, seed: {seed}")
            return task_id, seed
        except APIError:
            # Already logged with a precise message; do not let the generic
            # handler below re-wrap it as an "Unexpected error".
            raise
        except requests.exceptions.Timeout:
            logger.warning(f"Request timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"
            try:
                error_detail = e.response.json()
            except ValueError:
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")
            else:
                error_message += f": {error_detail}"
                logger.error(f"API response error content: {error_detail}")

            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            elif status_code == 429:
                wait_time = min(2 ** attempt, 10)  # Exponential backoff, capped at 10s
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({attempt}/{MAX_RETRY_COUNT})...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                # Other client errors will not succeed on retry.
                logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                logger.warning(f"Server error: {error_message}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
                time.sleep(1)
        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e
        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
def get_results(task_id):
    """
    Check the status of an image generation task.

    Transient problems (timeouts, network errors, non-401 HTTP errors,
    malformed responses) are logged and reported as None so that the caller
    can simply poll again.

    Args:
        task_id (str): The task ID to check

    Returns:
        dict | None: Task result information (a dict containing a "result"
        key), or None when no usable response was obtained

    Raises:
        ValueError: If task_id is empty
        APIError: If the API rejects the token (HTTP 401)
    """
    logger.debug(f"Checking status for task ID: {task_id}")
    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
    }

    try:
        # Let requests build and escape the query string instead of
        # concatenating the task id into the URL by hand.
        response = requests.get(API_RESULT_URL, params={"task_id": task_id}, headers=headers, timeout=10)
        logger.debug(f"Status check response code: {response.status_code}")
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        logger.warning(f"HTTP error {status_code} when checking task {task_id}")
        try:
            error_content = e.response.json()
        except ValueError:
            error_content = None
            logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")
        else:
            logger.error(f"Error response content: {error_content}")

        if status_code == 401:
            # Polling again cannot fix a rejected token, so this one is fatal.
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") from e

        if 400 <= status_code < 500:
            if isinstance(error_content, dict):
                error_message = f"HTTP error {status_code}: {error_content.get('message', 'Client error')}"
            else:
                error_message = f"HTTP error {status_code}"
            logger.error(error_message)
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
        return None
    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return None

    # A malformed body is handled like any other transient failure. Returning
    # None directly avoids raising an APIError only to have it swallowed and
    # logged as an "unexpected" error with a traceback.
    if not isinstance(result, dict) or "result" not in result:
        logger.warning(f"Invalid response format from API when checking task {task_id}: {str(result)}")
        return None
    return result
def download_image(image_url):
    """
    Download an image from a URL and return it as a PIL Image.
    Converts non-PNG formats (WebP, JPEG, ...) to PNG and carries over the
    string-valued entries of the original image's info dictionary.

    Timeouts, network errors and 5xx responses are retried up to
    MAX_RETRY_COUNT times; 4xx responses and decoding failures are not.

    Args:
        image_url (str): URL of the image

    Returns:
        PIL.Image: Downloaded image object in PNG format

    Raises:
        ValueError: If image_url is empty
        APIError: If the download or the image processing fails
    """
    logger.info(f"Starting download_image from URL: {image_url}")
    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Downloading image [attempt {attempt}/{MAX_RETRY_COUNT}] from {image_url}")
            response = requests.get(image_url, timeout=15)
            logger.debug(f"Image download response status: {response.status_code}, Content-Type: {response.headers.get('Content-Type')}, Content-Length: {response.headers.get('Content-Length')}")
            response.raise_for_status()

            # Open the image from response content. Image.open is lazy, so
            # force a full decode here: truncated or corrupt data then fails
            # inside this function instead of later in the caller.
            image = Image.open(BytesIO(response.content))
            image.load()
            logger.info(f"Image opened successfully. Format: {image.format}, Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")

            # Get original metadata before conversion (string pairs only)
            original_metadata = {
                key: value
                for key, value in image.info.items()
                if isinstance(key, str) and isinstance(value, str)
            }
            logger.debug(f"Original image metadata: {original_metadata}")

            # Convert to PNG regardless of original format (WebP, JPEG, etc.)
            if image.format != 'PNG':
                logger.info(f"Converting image from {image.format} to PNG format")
                png_buffer = BytesIO()
                # If the image has an alpha channel, preserve it, otherwise convert to RGB
                if 'A' in image.getbands():
                    logger.debug("Preserving alpha channel in image conversion")
                    image_to_save = image
                else:
                    logger.debug("Converting image to RGB mode")
                    image_to_save = image.convert('RGB')
                image_to_save.save(png_buffer, format='PNG')
                png_buffer.seek(0)
                image = Image.open(png_buffer)
                logger.debug(f"Image converted to PNG. New size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")

                # Preserve original metadata
                image.info.update(original_metadata)
                logger.debug("Original metadata preserved in converted image")

            logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
            return image
        except requests.exceptions.Timeout:
            logger.warning(f"Download timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
            try:
                error_content = e.response.text[:500]
                logger.error(f"Error response content: {error_content}")
            except Exception:
                # Best effort only: the body is logged for diagnostics.
                logger.error("Could not read error response content")

            if 400 <= status_code < 500:
                # Client errors are treated as permanent.
                error_message = f"HTTP error {status_code} when downloading image"
                logger.error(error_message)
                raise APIError(error_message) from e

            logger.warning(f"Server error {status_code}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.RequestException as e:
            logger.warning(f"Network error during image download: {str(e)}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            logger.debug(f"Network error details: {traceback.format_exc()}")
            time.sleep(1)
        except Exception as e:
            logger.error(f"Error processing image from {image_url}: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}") from e

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
def add_metadata_to_image(image, metadata):
    """
    Embed metadata into a PIL image as PNG text chunks.

    The image is written to an in-memory PNG together with its existing
    string metadata merged with ``metadata`` and then reopened from that
    buffer.

    Args:
        image (PIL.Image): The image to add metadata to
        metadata (dict): Metadata to add to the image; values are stringified

    Returns:
        PIL.Image: The reloaded image carrying the metadata, the original
        image if anything goes wrong, or None when no image was given
    """
    logger.debug(f"Adding metadata to image: {metadata}")
    if not image:
        logger.error("Null image provided to add_metadata_to_image")
        return None

    try:
        # Keep only the string/string pairs already attached to the image
        current_info = {
            name: text
            for name, text in image.info.items()
            if isinstance(name, str) and isinstance(text, str)
        }
        logger.debug(f"Existing image metadata: {current_info}")

        # Later entries win, so the new metadata overrides existing keys
        merged = {**current_info, **metadata}
        logger.debug(f"Combined metadata: {merged}")

        # Build the PNG text chunks
        png_info = PngImagePlugin.PngInfo()
        for name, text in merged.items():
            png_info.add_text(name, str(text))

        # Round-trip through an in-memory PNG so the chunks are part of the file
        stream = BytesIO()
        image.save(stream, format='PNG', pnginfo=png_info)
        logger.debug("Image saved to buffer with metadata")

        stream.seek(0)
        reloaded = Image.open(stream)
        logger.debug("Image reloaded from buffer with metadata")
        return reloaded
    except Exception as e:
        logger.error(f"Failed to add metadata to image: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # Fall back to the untouched image rather than failing the request
        return image
# Create Gradio interface
def create_ui():
    """
    Build the Gradio Blocks application.

    The layout has a header row, an input column (prompt, aspect ratio, seed,
    buttons, seed used, status line) and an output column (generated image and
    a collapsible JSON details panel), followed by a set of cached examples.

    Returns:
        gr.Blocks: The assembled, not yet launched, interface
    """
    logger.info("Creating Gradio UI")

    # Used both as the textbox placeholder and as the first example.
    default_prompt = "A vibrant and dynamic graffiti mural adorns a weathered brick wall in a bustling urban alleyway, a burst of color and energy amidst the city's grit. Boldly spray-painted letters declare \"HiDream.ai\" alongside other intricate street art designs, a testament to creative expression in the urban landscape."

    # Built by concatenation so the Markdown carries no source indentation.
    header_markdown = (
        "\n"
        "# HiDream-I1-Dev Image Generator\n"
        "Generate high-quality images from text descriptions using state-of-the-art AI\n"
        "[🤗 HuggingFace](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |\n"
        "[GitHub](https://github.com/HiDream-ai/HiDream-I1) |\n"
        "[Twitter](https://x.com/vivago_ai)\n"
        "<span style=\"color: #FF5733; font-weight: bold\">For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).</span>\n"
    )

    with gr.Blocks(title="HiDream-I1-Dev Image Generator", theme=gr.themes.Soft()) as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                gr.Markdown(header_markdown)

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder=default_prompt,
                    lines=3
                )
                with gr.Row():
                    aspect_ratio = gr.Radio(
                        choices=ASPECT_RATIO_OPTIONS,
                        value=ASPECT_RATIO_OPTIONS[2],
                        label="Aspect Ratio",
                        info="Select image aspect ratio"
                    )
                    seed = gr.Number(
                        label="Seed (use -1 for random)",
                        value=82706,
                        precision=0
                    )
                with gr.Row():
                    generate_btn = gr.Button("Generate Image", variant="primary")
                    clear_btn = gr.Button("Clear")
                seed_used = gr.Number(label="Seed Used", interactive=False)
                status_msg = gr.Markdown("Status: Ready")
            with gr.Column(scale=1):
                output_image = gr.Image(label="Generated Image", format="png", type="pil", interactive=False)
                with gr.Accordion("Image Information", open=False):
                    image_info = gr.JSON(label="Details")

        # Generate function with status updates
        def generate_with_status(prompt, aspect_ratio, seed, progress=gr.Progress()):
            """
            Run one generation and stream UI updates.

            Yields 4-tuples ``(image, seed, status_text, info)`` matching the
            outputs ``[output_image, seed_used, status_msg, image_info]``.
            Only the final successful yield carries an image and info; every
            error is reported through ``status_text`` instead of being raised.

            The ``gr.Progress()`` default is how the handler receives a
            progress tracker from Gradio.
            """
            # Normalise before slicing so a None prompt is reported as an
            # empty prompt instead of crashing outside the try block.
            prompt = prompt or ""
            logger.info(f"Starting image generation with prompt='{prompt[:50]}...', aspect_ratio={aspect_ratio}, seed={seed}")
            status_update = "Sending request to API..."
            yield None, seed, status_update, None

            try:
                if not prompt.strip():
                    logger.error("Empty prompt provided in UI")
                    status_update = "Error: Prompt cannot be empty"
                    yield None, seed, status_update, None
                    return

                # Create request
                logger.info("Creating API request")
                task_id, used_seed = create_request(prompt, aspect_ratio, seed)
                status_update = f"Request sent. Task ID: {task_id}. Waiting for results..."
                yield None, used_seed, status_update, None

                # Poll for results
                start_time = time.time()
                last_completion_ratio = 0
                progress(0, desc="Initializing...")
                logger.info(f"Starting to poll for results for task ID: {task_id}")

                while time.time() - start_time < MAX_POLL_TIME:
                    elapsed_time = time.time() - start_time
                    logger.debug(f"Polling for results - Task ID: {task_id}, Elapsed time: {elapsed_time:.2f}s")

                    result = get_results(task_id)
                    if not result:
                        logger.debug(f"No result yet for task ID: {task_id}, waiting {POLL_INTERVAL}s...")
                        time.sleep(POLL_INTERVAL)
                        continue

                    # Both levels may be present but null while the task is queued.
                    sub_results = (result.get("result") or {}).get("sub_task_results") or []
                    if not sub_results:
                        logger.debug(f"No sub-task results yet for task ID: {task_id}, waiting {POLL_INTERVAL}s...")
                        time.sleep(POLL_INTERVAL)
                        continue

                    sub_task = sub_results[0]
                    status = sub_task.get("task_status")
                    logger.debug(f"Task status for ID {task_id}: {status}")

                    # Get and display completion ratio (treat null as 0)
                    completion_ratio = (sub_task.get("task_completion") or 0) * 100
                    ratio_changed = completion_ratio != last_completion_ratio
                    if ratio_changed:
                        # Only update UI when completion ratio changes
                        last_completion_ratio = completion_ratio
                        status_update = f"Generating image: {completion_ratio:.0f}% complete"
                        progress(completion_ratio / 100, desc="Generating image")
                        logger.info(f"Generation progress - Task ID: {task_id}, Completion: {completion_ratio:.1f}%")
                        yield None, used_seed, status_update, None

                    # Check task status
                    if status == 1:  # Success
                        logger.info(f"Task completed successfully - Task ID: {task_id}")
                        progress(1.0, desc="Generation complete")

                        image_name = sub_task.get("image")
                        if not image_name:
                            logger.error(f"No image name in successful response. Response: {sub_task}")
                            status_update = "Error: No image name in successful response"
                            yield None, used_seed, status_update, None
                            return

                        status_update = "Downloading generated image..."
                        yield None, used_seed, status_update, None

                        image_url = f"{API_IMAGE_URL}{image_name}.png"
                        logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
                        image = download_image(image_url)

                        if image is None:
                            logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
                            status_update = "Error: Failed to download the generated image"
                            yield None, used_seed, status_update, None
                            return

                        # Add metadata to the image
                        logger.info(f"Adding metadata to image - Task ID: {task_id}")
                        metadata = {
                            "prompt": prompt,
                            "seed": str(used_seed),
                            "model": API_MODEL_NAME,
                            "aspect_ratio": aspect_ratio,
                            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                            "generated_by": "HiDream-I1-Dev Generator"
                        }
                        image_with_metadata = add_metadata_to_image(image, metadata)

                        # Create info for display
                        info = {
                            "model": API_MODEL_NAME,
                            "prompt": prompt,
                            "seed": used_seed,
                            "aspect_ratio": aspect_ratio,
                            "generated_at": time.strftime("%Y-%m-%d %H:%M:%S")
                        }
                        status_update = "Image generated successfully!"
                        logger.info(f"Image generation complete - Task ID: {task_id}")
                        yield image_with_metadata, used_seed, status_update, info
                        return
                    elif status in {3, 4}:  # Failed or Canceled
                        error_msg = sub_task.get("task_error", "Unknown error")
                        logger.error(f"Task failed - Task ID: {task_id}, Status: {status}, Error: {error_msg}")
                        status_update = f"Error: Task failed with status {status}: {error_msg}"
                        yield None, used_seed, status_update, None
                        return

                    # Only show the elapsed-time message when the ratio did not
                    # change in this round. Comparing against
                    # last_completion_ratio here would always be true, because
                    # it has just been synchronised above, and the progress
                    # message would be overwritten immediately.
                    if not ratio_changed:
                        status_update = f"Waiting for image generation... {completion_ratio:.0f}% complete ({int(time.time() - start_time)}s elapsed)"
                        yield None, used_seed, status_update, None

                    time.sleep(POLL_INTERVAL)

                logger.error(f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s")
                status_update = f"Error: Timeout waiting for image generation after {MAX_POLL_TIME} seconds"
                yield None, used_seed, status_update, None
            except APIError as e:
                logger.error(f"API Error during generation: {str(e)}")
                status_update = f"API Error: {str(e)}"
                yield None, seed, status_update, None
            except ValueError as e:
                logger.error(f"Value Error during generation: {str(e)}")
                status_update = f"Value Error: {str(e)}"
                yield None, seed, status_update, None
            except Exception as e:
                logger.error(f"Unexpected error during image generation: {str(e)}")
                logger.error(f"Full traceback: {traceback.format_exc()}")
                status_update = f"Unexpected error: {str(e)}"
                yield None, seed, status_update, None

        # Set up event handlers
        generate_btn.click(
            fn=generate_with_status,
            inputs=[prompt, aspect_ratio, seed],
            outputs=[output_image, seed_used, status_msg, image_info]
        )

        def clear_outputs():
            """Reset the image, seed-used, status and details outputs."""
            logger.info("Clearing UI outputs")
            return None, -1, "Status: Ready", None

        clear_btn.click(
            fn=clear_outputs,
            inputs=None,
            outputs=[output_image, seed_used, status_msg, image_info]
        )

        # Examples
        gr.Examples(
            examples=[
                [
                    default_prompt,
                    "4:3", 82706],
                [
                    "A modern art interpretation of a traditional landscape painting, using bold colors and abstract forms to represent mountains, rivers, and mist. Incorporate calligraphic elements and a sense of dynamic energy.",
                    "1:1", 661320],
                [
                    "Intimate portrait of a young woman from a nomadic tribe in ancient China, wearing fur-trimmed clothing and intricate silver jewelry. Wind-swept hair and a resilient gaze. Background of a vast, open grassland under a dramatic sky.",
                    "1:1", 34235],
                [
                    "Time-lapse concept: A single tree shown through four seasons simultaneously, spring blossoms, summer green, autumn colors, winter snow, blended seamlessly.",
                    "1:1", 241106]
            ],
            inputs=[prompt, aspect_ratio, seed],
            outputs=[output_image, seed_used, status_msg, image_info],
            fn=generate_with_status,
            cache_examples=True,
        )

    logger.info("Gradio UI created successfully")
    return demo
# Launch app
if __name__ == "__main__":
    logger.info("Starting HiDream-I1-Dev Image Generator application")
    demo = create_ui()

    logger.info("Launching Gradio interface with queue")
    # Enable the request queue (at most 50 waiting, 4 handled concurrently by
    # default) before starting the server; queue() returns the Blocks itself.
    queued_demo = demo.queue(max_size=50, default_concurrency_limit=4)
    queued_demo.launch(show_api=False)

    # Reached only after launch() returns
    logger.info("Application shutdown")