#!/usr/bin/env python import os import re import tempfile from collections.abc import Iterator from threading import Thread import cv2 import gradio as gr import spaces import torch from loguru import logger from PIL import Image from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer # CSV/TXT 분석 import pandas as pd # PDF 텍스트 추출 import PyPDF2 MAX_CONTENT_CHARS = 4000 # 너무 큰 파일을 막기 위해 최대 표시 4000자 model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it") processor = AutoProcessor.from_pretrained(model_id, padding_side="left") model = Gemma3ForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" ) MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5")) ################################################## # CSV, TXT, PDF 분석 함수 ################################################## def analyze_csv_file(path: str) -> str: """ CSV 파일을 전체 문자열로 변환. 너무 길 경우 일부만 표시. """ try: df = pd.read_csv(path) # 데이터 프레임 크기 제한 (행/열 수가 많은 경우) if df.shape[0] > 50 or df.shape[1] > 10: df = df.iloc[:50, :10] df_str = df.to_string() if len(df_str) > MAX_CONTENT_CHARS: df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..." return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}" except Exception as e: return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}" def analyze_txt_file(path: str) -> str: """ TXT 파일 전문 읽기. 너무 길면 일부만 표시. """ try: with open(path, "r", encoding="utf-8") as f: text = f.read() if len(text) > MAX_CONTENT_CHARS: text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..." return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}" except Exception as e: return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}" def pdf_to_markdown(pdf_path: str) -> str: """ PDF → Markdown. 페이지별로 간단히 텍스트 추출. """ text_chunks = [] try: with open(pdf_path, "rb") as f: reader = PyPDF2.PdfReader(f) # 최대 5페이지만 처리 max_pages = min(5, len(reader.pages)) for page_num in range(max_pages): page = reader.pages[page_num] page_text = page.extract_text() or "" page_text = page_text.strip() if page_text: # 페이지별 텍스트도 제한 if len(page_text) > MAX_CONTENT_CHARS // max_pages: page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)" text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n") if len(reader.pages) > max_pages: text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...") except Exception as e: return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}" full_text = "\n".join(text_chunks) if len(full_text) > MAX_CONTENT_CHARS: full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..." return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}" ################################################## # 이미지/비디오 업로드 제한 검사 ################################################## def count_files_in_new_message(paths: list[str]) -> tuple[int, int]: image_count = 0 video_count = 0 for path in paths: if path.endswith(".mp4"): video_count += 1 elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE): image_count += 1 return image_count, video_count def count_files_in_history(history: list[dict]) -> tuple[int, int]: image_count = 0 video_count = 0 for item in history: if item["role"] != "user" or isinstance(item["content"], str): continue if isinstance(item["content"], list) and len(item["content"]) > 0: file_path = item["content"][0] if isinstance(file_path, str): if file_path.endswith(".mp4"): video_count += 1 elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE): image_count += 1 return image_count, video_count def validate_media_constraints(message: dict, history: list[dict]) -> bool: """ - 비디오 1개 초과 불가 - 비디오와 이미지 혼합 불가 - 이미지 개수 MAX_NUM_IMAGES 초과 불가 - 태그가 있으면 태그 수와 실제 이미지 수 일치 - CSV, TXT, PDF 등은 여기서 제한하지 않음 """ # 이미지와 비디오 파일만 필터링 media_files = [] for f in message["files"]: if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"): media_files.append(f) new_image_count, new_video_count = count_files_in_new_message(media_files) history_image_count, history_video_count = count_files_in_history(history) image_count = history_image_count + new_image_count video_count = history_video_count + new_video_count if video_count > 1: gr.Warning("Only one video is supported.") return False if video_count == 1: if image_count > 0: gr.Warning("Mixing images and videos is not allowed.") return False if "" in message["text"]: gr.Warning("Using tags with video files is not supported.") return False if video_count == 0 and image_count > MAX_NUM_IMAGES: gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.") return False # 이미지 태그 검증 (실제 이미지 파일만 계산) if "" in message["text"]: # 이미지 파일만 필터링 image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)] image_tag_count = message["text"].count("") if image_tag_count != len(image_files): gr.Warning("The number of tags in the text does not match the number of image files.") return False return True ################################################## # 비디오 처리 ################################################## def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]: vidcap = cv2.VideoCapture(video_path) fps = vidcap.get(cv2.CAP_PROP_FPS) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) # 더 적은 프레임을 추출하도록 조정 frame_interval = max(int(fps), int(total_frames / 10)) # 초당 1프레임 또는 최대 10프레임 frames = [] for i in range(0, total_frames, frame_interval): vidcap.set(cv2.CAP_PROP_POS_FRAMES, i) success, image = vidcap.read() if success: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image) timestamp = round(i / fps, 2) frames.append((pil_image, timestamp)) # 최대 5프레임만 사용 if len(frames) >= 5: break vidcap.release() return frames def process_video(video_path: str) -> list[dict]: content = [] frames = downsample_video(video_path) for frame in frames: pil_image, timestamp = frame with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file: pil_image.save(temp_file.name) content.append({"type": "text", "text": f"Frame {timestamp}:"}) content.append({"type": "image", "url": temp_file.name}) logger.debug(f"{content=}") return content ################################################## # interleaved 처리 ################################################## def process_interleaved_images(message: dict) -> list[dict]: parts = re.split(r"()", message["text"]) content = [] image_index = 0 # 이미지 파일만 필터링 image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)] for part in parts: if part == "" and image_index < len(image_files): content.append({"type": "image", "url": image_files[image_index]}) image_index += 1 elif part.strip(): content.append({"type": "text", "text": part.strip()}) else: # 공백이거나 \n 같은 경우 if isinstance(part, str) and part != "": content.append({"type": "text", "text": part}) return content ################################################## # PDF + CSV + TXT + 이미지/비디오 ################################################## def is_image_file(file_path: str) -> bool: """이미지 파일인지 확인""" return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE)) def is_video_file(file_path: str) -> bool: """비디오 파일인지 확인""" return file_path.endswith(".mp4") def is_document_file(file_path: str) -> bool: """문서 파일인지 확인 (PDF, CSV, TXT)""" return (file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")) def process_new_user_message(message: dict) -> list[dict]: if not message["files"]: return [{"type": "text", "text": message["text"]}] # 1) 파일 분류 video_files = [f for f in message["files"] if is_video_file(f)] image_files = [f for f in message["files"] if is_image_file(f)] csv_files = [f for f in message["files"] if f.lower().endswith(".csv")] txt_files = [f for f in message["files"] if f.lower().endswith(".txt")] pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")] # 2) 사용자 원본 text 추가 content_list = [{"type": "text", "text": message["text"]}] # 3) CSV for csv_path in csv_files: csv_analysis = analyze_csv_file(csv_path) content_list.append({"type": "text", "text": csv_analysis}) # 4) TXT for txt_path in txt_files: txt_analysis = analyze_txt_file(txt_path) content_list.append({"type": "text", "text": txt_analysis}) # 5) PDF for pdf_path in pdf_files: pdf_markdown = pdf_to_markdown(pdf_path) content_list.append({"type": "text", "text": pdf_markdown}) # 6) 비디오 (한 개만 허용) if video_files: content_list += process_video(video_files[0]) return content_list # 7) 이미지 처리 if "" in message["text"] and image_files: # interleaved interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files}) # 원본 content_list 앞부분(텍스트)을 제거하고 interleaved로 대체 if content_list[0]["type"] == "text": content_list = content_list[1:] # 원본 텍스트 제거 return interleaved_content + content_list # interleaved + 나머지 문서 분석 내용 else: # 일반 여러 장 for img_path in image_files: content_list.append({"type": "image", "url": img_path}) return content_list ################################################## # history -> LLM 메시지 변환 ################################################## def process_history(history: list[dict]) -> list[dict]: messages = [] current_user_content: list[dict] = [] for item in history: if item["role"] == "assistant": # user_content가 쌓여있다면 user 메시지로 저장 if current_user_content: messages.append({"role": "user", "content": current_user_content}) current_user_content = [] # 그 뒤 item은 assistant messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]}) else: # user content = item["content"] if isinstance(content, str): current_user_content.append({"type": "text", "text": content}) elif isinstance(content, list) and len(content) > 0: file_path = content[0] if is_image_file(file_path): current_user_content.append({"type": "image", "url": file_path}) else: # 비이미지 파일은 텍스트로 처리 current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"}) # 마지막 사용자 메시지가 처리되지 않은 경우 추가 if current_user_content: messages.append({"role": "user", "content": current_user_content}) return messages ################################################## # 메인 추론 함수 ################################################## @spaces.GPU(duration=120) def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]: if not validate_media_constraints(message, history): yield "" return try: messages = [] if system_prompt: messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) messages.extend(process_history(history)) # 사용자 메시지 처리 user_content = process_new_user_message(message) # 토큰 수를 줄이기 위해 너무 긴 텍스트는 잘라내기 for item in user_content: if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS: item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..." messages.append({"role": "user", "content": user_content}) # 모델 입력 생성 전 최종 확인 # 이미지나 비디오가 아닌 파일들은 모델의 "image" 파이프라인으로 전달되지 않도록 필터링 for msg in messages: if msg["role"] != "user": continue filtered_content = [] for item in msg["content"]: if item["type"] == "image": if is_image_file(item["url"]): filtered_content.append(item) else: # 이미지 파일이 아닌 경우 텍스트로 변환 filtered_content.append({ "type": "text", "text": f"[Non-image file: {os.path.basename(item['url'])}]" }) else: filtered_content.append(item) msg["content"] = filtered_content # 모델 입력 생성 inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(device=model.device, dtype=torch.bfloat16) # 텍스트 생성 스트리머 설정 streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True) gen_kwargs = dict( inputs, streamer=streamer, max_new_tokens=max_new_tokens, ) # 별도 스레드에서 텍스트 생성 t = Thread(target=model.generate, kwargs=gen_kwargs) t.start() # 결과 스트리밍 output = "" for new_text in streamer: output += new_text yield output except Exception as e: logger.error(f"Error in run: {str(e)}") yield f"죄송합니다. 오류가 발생했습니다: {str(e)}" ################################################## # 예시들 (한글화 버전) ################################################## examples = [ [ { "text": "PDF 파일 내용을 요약, 분석하라.", "files": ["assets/additional-examples/pdf.pdf"], } ], [ { "text": "CSV 파일 내용을 요약, 분석하라", "files": ["assets/additional-examples/sample-csv.csv"], } ], [ { "text": "동일한 막대 그래프를 그리는 matplotlib 코드를 작성해주세요.", "files": ["assets/additional-examples/barchart.png"], } ], [ { "text": "이 영상에서 이상한 점이 무엇인가요?", "files": ["assets/additional-examples/tmp.mp4"], } ], [ { "text": "이미 이 영양제를 가지고 있고, 이 제품 을 새로 사려 합니다. 함께 섭취할 때 주의해야 할 점이 있을까요?", "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"], } ], [ { "text": "이미지의 시각적 요소에서 영감을 받아 시를 작성해주세요.", "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"], } ], [ { "text": "이미지의 시각적 요소를 토대로 짧은 악곡을 작곡해주세요.", "files": [ "assets/sample-images/07-1.png", "assets/sample-images/07-2.png", "assets/sample-images/07-3.png", "assets/sample-images/07-4.png", ], } ], [ { "text": "이 집에서 무슨 일이 있었을지 짧은 이야기를 지어보세요.", "files": ["assets/sample-images/08.png"], } ], [ { "text": "이미지들의 순서를 바탕으로 짧은 이야기를 만들어 주세요.", "files": [ "assets/sample-images/09-1.png", "assets/sample-images/09-2.png", "assets/sample-images/09-3.png", "assets/sample-images/09-4.png", "assets/sample-images/09-5.png", ], } ], [ { "text": "이 세계에서 살고 있을 생물들을 상상해서 묘사해주세요.", "files": ["assets/sample-images/10.png"], } ], [ { "text": "이미지에 적힌 텍스트를 읽어주세요.", "files": ["assets/additional-examples/1.png"], } ], [ { "text": "이 티켓은 언제 발급된 것이고, 가격은 얼마인가요?", "files": ["assets/additional-examples/2.png"], } ], [ { "text": "이미지에 있는 텍스트를 그대로 읽어서 마크다운 형태로 적어주세요.", "files": ["assets/additional-examples/3.png"], } ], [ { "text": "이 적분을 풀어주세요.", "files": ["assets/additional-examples/4.png"], } ], [ { "text": "이 이미지를 간단히 캡션으로 설명해주세요.", "files": ["assets/sample-images/01.png"], } ], [ { "text": "이 표지판에는 무슨 문구가 적혀 있나요?", "files": ["assets/sample-images/02.png"], } ], [ { "text": "두 이미지를 비교해서 공통점과 차이점을 말해주세요.", "files": ["assets/sample-images/03.png"], } ], [ { "text": "이미지에 보이는 모든 사물과 그 색상을 나열해주세요.", "files": ["assets/sample-images/04.png"], } ], [ { "text": "장면의 분위기를 묘사해주세요.", "files": ["assets/sample-images/05.png"], } ], ] ################################################## # Gradio 인터페이스 설정 ################################################## demo = gr.ChatInterface( fn=run, type="messages", chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]), # .webp, .png, .jpg, .jpeg, .gif, .mp4, .csv, .txt, .pdf 모두 허용 textbox=gr.MultimodalTextbox( file_types=[ ".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf" ], file_count="multiple", autofocus=True ), multimodal=True, additional_inputs=[ gr.Textbox( label="System Prompt", value=( "You are a deeply thoughtful AI. Consider problems thoroughly and derive " "correct solutions through systematic reasoning. Please answer in korean." ) ), gr.Slider(label="Max New Tokens", minimum=100, maximum=8000, step=50, value=2000), ], stop_btn=False, title="Vidraft-Gemma-3-27B", examples=examples, run_examples_on_click=False, cache_examples=False, css_paths="style.css", delete_cache=(1800, 1800), ) if __name__ == "__main__": demo.launch()