#!/usr/bin/env python
import os
import re
import tempfile
from collections.abc import Iterator
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

# CSV/TXT analysis
import pandas as pd

# PDF text extraction
import PyPDF2

# Maximum number of characters of extracted document text placed in the prompt,
# so a very large file cannot flood the context.
MAX_CONTENT_CHARS = 8000

# Placeholder the user may type in the message text to position each uploaded image.
IMAGE_TAG = "<image>"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".webp")
VIDEO_EXTENSIONS = (".mp4",)

# Encodings tried, in order, for plain-text uploads. CP949 covers Korean files
# written by Windows tools, which are not valid UTF-8.
TEXT_ENCODINGS = ("utf-8-sig", "cp949")

model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)

MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))


##################################################
# File-type helpers
##################################################
def _is_image(path: str) -> bool:
    """Return True if *path* ends with a supported image extension (case-insensitive)."""
    return path.lower().endswith(IMAGE_EXTENSIONS)


def _is_video(path: str) -> bool:
    """Return True if *path* ends with a supported video extension (case-insensitive)."""
    return path.lower().endswith(VIDEO_EXTENSIONS)


def _truncate(text: str, limit: int = MAX_CONTENT_CHARS) -> str:
    """Cut *text* to *limit* characters, appending a marker when anything was dropped."""
    if len(text) > limit:
        return text[:limit] + "\n...(truncated)..."
    return text


def _decode_text(raw: bytes) -> str:
    """Decode *raw* with the first encoding in TEXT_ENCODINGS that succeeds.

    Falls back to UTF-8 with replacement characters so decoding never raises.
    """
    for encoding in TEXT_ENCODINGS:
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="replace")


##################################################
# CSV, TXT, PDF analysis
##################################################
def analyze_csv_file(path: str) -> str:
    """Render a CSV file as a text table for the prompt.

    The file is parsed with pandas and rendered with ``DataFrame.to_string()``;
    output longer than MAX_CONTENT_CHARS is truncated. Any read/parse error is
    returned as a human-readable message instead of being raised (best effort:
    one bad attachment should not abort the whole request).
    """
    name = os.path.basename(path)
    try:
        try:
            df = pd.read_csv(path)
        except UnicodeDecodeError:
            # Korean spreadsheets exported from Excel are commonly CP949-encoded.
            df = pd.read_csv(path, encoding="cp949")
        table = df.to_string()
    except Exception as e:
        return f"Failed to read CSV ({name}): {str(e)}"
    return f"**[CSV File: {name}]**\n\n{_truncate(table)}"


def analyze_txt_file(path: str) -> str:
    """Return the full text of a TXT file (truncated to MAX_CONTENT_CHARS) for the prompt.

    Decoding tries UTF-8 (BOM-aware) and then CP949; I/O errors are returned as
    a message string instead of being raised.
    """
    name = os.path.basename(path)
    try:
        with open(path, "rb") as f:
            text = _decode_text(f.read())
    except OSError as e:
        return f"Failed to read TXT ({name}): {str(e)}"
    return f"**[TXT File: {name}]**\n\n{_truncate(text)}"


def pdf_to_markdown(pdf_path: str) -> str:
    """Extract a PDF's text page by page as Markdown (``## Page N`` sections).

    Pages with no extractable text are skipped. Extraction stops as soon as the
    collected text is guaranteed to exceed MAX_CONTENT_CHARS, because later
    pages would be truncated away anyway; the returned text is identical to
    extracting every page and then truncating. Errors are returned as a message.
    """
    name = os.path.basename(pdf_path)
    text_chunks: list[str] = []
    joined_length = 0  # length of "\n".join(text_chunks) so far
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page_num, page in enumerate(reader.pages, start=1):
                page_text = (page.extract_text() or "").strip()
                if not page_text:
                    continue
                chunk = f"## Page {page_num}\n\n{page_text}\n"
                joined_length += len(chunk) + (1 if text_chunks else 0)
                text_chunks.append(chunk)
                if joined_length > MAX_CONTENT_CHARS:
                    break
    except Exception as e:
        return f"Failed to read PDF ({name}): {str(e)}"

    # Scanned/image-only PDFs yield no text; say so instead of sending an empty body.
    full_text = "\n".join(text_chunks) or "(no extractable text found)"
    return f"**[PDF File: {name}]**\n\n{_truncate(full_text)}"


# Document analyzers in the order their output is placed in a user message.
_DOCUMENT_ANALYZERS = (
    (".csv", analyze_csv_file),
    (".txt", analyze_txt_file),
    (".pdf", pdf_to_markdown),
)


def _analyze_document(path: str):
    """Return the analysis text for a CSV/TXT/PDF *path*, or None for other types."""
    lowered = path.lower()
    for extension, analyzer in _DOCUMENT_ANALYZERS:
        if lowered.endswith(extension):
            return analyzer(path)
    return None


##################################################
# Image/video upload limits
##################################################
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Return ``(image_count, video_count)`` for *paths*; other file types are ignored."""
    image_count = sum(1 for path in paths if _is_image(path))
    video_count = sum(1 for path in paths if _is_video(path))
    return image_count, video_count


def _history_file_path(content) -> str:
    """Return the file path held by a non-text history message, or "" if none.

    NOTE(review): Gradio "messages" history has been seen to store files as a
    tuple ``(path,)``; a dict with a ``"path"`` key is also accepted defensively.
    """
    if isinstance(content, dict):
        return content.get("path") or ""
    if isinstance(content, (list, tuple)) and content:
        return content[0]
    return ""


def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Return ``(image_count, video_count)`` for files uploaded in earlier user turns.

    Documents (CSV/TXT/PDF) are not counted, so they do not use up the image quota.
    """
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user" or isinstance(item["content"], str):
            continue
        path = _history_file_path(item["content"])
        if not path:
            continue
        if _is_video(path):
            video_count += 1
        elif _is_image(path):
            image_count += 1
    return image_count, video_count


def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check media limits for the new message and show a Gradio warning on violation.

    Rules (counted over the whole conversation plus the new message):
    - at most one video;
    - videos and images cannot be mixed;
    - at most MAX_NUM_IMAGES images;
    - if the text contains IMAGE_TAG placeholders, their count must equal the
      number of images in the new message.
    CSV, TXT and PDF files are not restricted here.

    Returns True if the message may be processed.
    """
    media_files = [f for f in message["files"] if _is_image(f) or _is_video(f)]

    new_image_count, new_video_count = count_files_in_new_message(media_files)
    history_image_count, history_video_count = count_files_in_history(history)
    image_count = history_image_count + new_image_count
    video_count = history_video_count + new_video_count
    has_tags = IMAGE_TAG in message["text"]

    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1:
        if image_count > 0:
            gr.Warning("Mixing images and videos is not allowed.")
            return False
        if has_tags:
            gr.Warning("Using tags with video files is not supported.")
            return False
    if video_count == 0 and image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False
    if has_tags and message["text"].count(IMAGE_TAG) != new_image_count:
        gr.Warning("The number of tags in the text does not match the number of images.")
        return False
    return True


##################################################
# Video processing
##################################################
def downsample_video(
    video_path: str, samples_per_second: float = 3.0
) -> list[tuple[Image.Image, float]]:
    """Sample about *samples_per_second* RGB frames per second of video.

    Returns ``(frame, timestamp_seconds)`` pairs, timestamps rounded to 2 decimals.
    An unreadable video (no FPS or frame count reported) yields an empty list.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # `not (fps > 0)` also rejects NaN, which would break the int()/division below.
        if not (fps > 0) or total_frames <= 0:
            return []
        # Step at least one frame: videos slower than samples_per_second fps
        # would otherwise produce a zero step and make range() raise.
        frame_interval = max(1, int(fps / samples_per_second))

        frames = []
        for i in range(0, total_frames, frame_interval):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
                frames.append((Image.fromarray(image), round(i / fps, 2)))
        return frames
    finally:
        vidcap.release()


def process_video(video_path: str) -> list[dict]:
    """Turn a video into alternating "Frame <t>:" text and frame-image content items.

    NOTE(review): frames are written to NamedTemporaryFile(delete=False) PNGs that
    are never removed here, because the processor reads them later by path.
    """
    content = []
    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
            content.append({"type": "text", "text": f"Frame {timestamp}:"})
            content.append({"type": "image", "url": temp_file.name})
    logger.debug(f"{content=}")
    return content


##################################################
# Interleaved text/image handling
##################################################
def process_interleaved_images(message: dict) -> list[dict]:
    """Split the text at IMAGE_TAG placeholders and insert ``message["files"]`` in order.

    The i-th tag becomes an image item for ``message["files"][i]``, so callers
    must pass only image files. Non-blank text segments are stripped. Whitespace-only
    segments are kept verbatim. A tag with no matching file is kept as literal text.
    """
    files = message["files"]
    parts = re.split(f"({re.escape(IMAGE_TAG)})", message["text"])
    content = []
    image_index = 0
    for part in parts:
        if part == IMAGE_TAG and image_index < len(files):
            content.append({"type": "image", "url": files[image_index]})
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        elif part:
            # Whitespace-only separator (e.g. spaces/newlines between two tags).
            content.append({"type": "text", "text": part})
    return content


##################################################
# PDF + CSV + TXT + image/video
##################################################
def process_new_user_message(message: dict) -> list[dict]:
    """Build the chat-template content list for the user's new message.

    Layout:
    - video: text, document analyses (CSV, TXT, PDF), then frames of the first video;
    - text with IMAGE_TAG: interleaved text/images, then document analyses;
    - otherwise: text, document analyses, then all images.
    """
    text = message["text"]
    files = message["files"]
    if not files:
        return [{"type": "text", "text": text}]

    video_files = [f for f in files if _is_video(f)]
    image_files = [f for f in files if _is_image(f)]

    # Grouped by type (all CSVs, then TXTs, then PDFs), each in upload order.
    document_parts = []
    for extension, analyzer in _DOCUMENT_ANALYZERS:
        for path in files:
            if path.lower().endswith(extension):
                document_parts.append({"type": "text", "text": analyzer(path)})

    if video_files:
        # Only one video is allowed (see validate_media_constraints).
        return [{"type": "text", "text": text}, *document_parts, *process_video(video_files[0])]

    if IMAGE_TAG in text:
        # Interleave over image files only, so uploaded documents cannot shift
        # the tag -> file mapping.
        interleaved = process_interleaved_images({"text": text, "files": image_files})
        return interleaved + document_parts

    image_parts = [{"type": "image", "url": path} for path in image_files]
    return [{"type": "text", "text": text}, *document_parts, *image_parts]


##################################################
# history -> LLM messages
##################################################
def _history_file_to_content(path: str) -> list[dict]:
    """Convert a file from an earlier user turn into chat-template content items."""
    if not path:
        return []
    if _is_image(path):
        return [{"type": "image", "url": path}]
    if _is_video(path):
        return process_video(path)
    analysis = _analyze_document(path)
    if analysis is not None:
        return [{"type": "text", "text": analysis}]
    return [{"type": "text", "text": f"[Attached file: {os.path.basename(path)}]"}]


def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio "messages" history into chat-template messages.

    Consecutive user items (text and files) are merged into one user message,
    which is flushed when the next assistant item appears. User items after the
    last assistant reply are not emitted.
    """
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
            continue

        content = item["content"]
        if isinstance(content, str):
            current_user_content.append({"type": "text", "text": content})
        else:
            # Files are classified by extension; documents must not be sent as images.
            current_user_content.extend(_history_file_to_content(_history_file_path(content)))
    return messages


##################################################
# Main inference function
##################################################
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    """Stream the model's reply to *message*, yielding the accumulated text so far.

    Yields a single empty string if the media constraints are violated.
    """
    if not validate_media_constraints(message, history):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    # Generation runs in a background thread; the streamer hands tokens back here.
    t = Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output


##################################################
# Examples (Korean)
##################################################
examples = [
    [
        {
            "text": "PDF 파일 내용을 요약, 분석하라.",
            "files": ["assets/additional-examples/pdf.pdf"],
        }
    ],
    [
        {
            "text": "CSV 파일 내용을 요약, 분석하라",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    [
        {
            "text": "동일한 막대 그래프를 그리는 matplotlib 코드를 작성해주세요.",
            "files": ["assets/additional-examples/barchart.png"],
        }
    ],
    [
        {
            "text": "이 영상에서 이상한 점이 무엇인가요?",
            "files": ["assets/additional-examples/tmp.mp4"],
        }
    ],
    [
        {
            "text": "이미 이 영양제를 가지고 있고, 이 제품 을 새로 사려 합니다. 함께 섭취할 때 주의해야 할 점이 있을까요?",
            "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
        }
    ],
    [
        {
            "text": "이미지의 시각적 요소에서 영감을 받아 시를 작성해주세요.",
            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
        }
    ],
    [
        {
            "text": "이미지의 시각적 요소를 토대로 짧은 악곡을 작곡해주세요.",
            "files": [
                "assets/sample-images/07-1.png",
                "assets/sample-images/07-2.png",
                "assets/sample-images/07-3.png",
                "assets/sample-images/07-4.png",
            ],
        }
    ],
    [
        {
            "text": "이 집에서 무슨 일이 있었을지 짧은 이야기를 지어보세요.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "이미지들의 순서를 바탕으로 짧은 이야기를 만들어 주세요.",
            "files": [
                "assets/sample-images/09-1.png",
                "assets/sample-images/09-2.png",
                "assets/sample-images/09-3.png",
                "assets/sample-images/09-4.png",
                "assets/sample-images/09-5.png",
            ],
        }
    ],
    [
        {
            "text": "이 세계에서 살고 있을 생물들을 상상해서 묘사해주세요.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "이미지에 적힌 텍스트를 읽어주세요.",
            "files": ["assets/additional-examples/1.png"],
        }
    ],
    [
        {
            "text": "이 티켓은 언제 발급된 것이고, 가격은 얼마인가요?",
            "files": ["assets/additional-examples/2.png"],
        }
    ],
    [
        {
            "text": "이미지에 있는 텍스트를 그대로 읽어서 마크다운 형태로 적어주세요.",
            "files": ["assets/additional-examples/3.png"],
        }
    ],
    [
        {
            "text": "이 적분을 풀어주세요.",
            "files": ["assets/additional-examples/4.png"],
        }
    ],
    [
        {
            "text": "이 이미지를 간단히 캡션으로 설명해주세요.",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "이 표지판에는 무슨 문구가 적혀 있나요?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "두 이미지를 비교해서 공통점과 차이점을 말해주세요.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "이미지에 보이는 모든 사물과 그 색상을 나열해주세요.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
    [
        {
            "text": "장면의 분위기를 묘사해주세요.",
            "files": ["assets/sample-images/05.png"],
        }
    ],
]


demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    # Accept .webp, .png, .jpg, .jpeg, .gif, .mp4, .csv, .txt and .pdf uploads.
    textbox=gr.MultimodalTextbox(
        file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(
            label="System Prompt",
            value=(
                "You are a deeply thoughtful AI. Consider problems thoroughly and derive "
                "correct solutions through systematic reasoning. Please answer in korean."
            ),
        ),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=8000, step=50, value=2000),
    ],
    stop_btn=False,
    title="Vidraft-Gemma-3-27B",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()