import argparse import gradio as gr import os from PIL import Image import spaces from kimi_vl.serve.frontend import reload_javascript from kimi_vl.serve.utils import ( configure_logger, pil_to_base64, parse_and_draw_response, strip_stop_words, is_variable_assigned, ) from kimi_vl.serve.gradio_utils import ( cancel_outputing, delete_last_conversation, reset_state, reset_textbox, transfer_input, wrap_gen_fn, ) from kimi_vl.serve.chat_utils import ( generate_prompt_with_history, convert_conversation_to_prompts, highlight_instruction, to_gradio_chatbot, to_gradio_history, ) from kimi_vl.serve.inference import kimi_vl_generate, load_model from kimi_vl.serve.examples import get_examples TITLE = """

Chat with Kimi-VL-A3B-Instruct

""" DESCRIPTION_TOP = """Kimi-VL-A3B is a multi-modal LLM that can understand text, single-image, multi-image, and video, and generate reply. For thinking version, please try [Kimi-VL-A3B-Thinking](https://huggingface.co/spaces/moonshotai/Kimi-VL-A3B-Thinking).""" DESCRIPTION = """""" ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) DEPLOY_MODELS = dict() logger = configure_logger() def resize_image(image: Image.Image, max_size: int = 640, min_size: int = 28): width, height = image.size if width < min_size or height < min_size: # Double both dimensions while maintaining aspect ratio scale = min_size / min(width, height) new_width = int(width * scale) new_height = int(height * scale) image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) elif max_size > 0 and (width > max_size or height > max_size): # Double both dimensions while maintaining aspect ratio scale = max_size / max(width, height) new_width = int(width * scale) new_height = int(height * scale) image = image.resize((new_width, new_height)) return image def load_frames(video_file, max_num_frames=64, long_edge=448): from decord import VideoReader vr = VideoReader(video_file) duration = len(vr) fps = vr.get_avg_fps() length = int(duration / fps) num_frames = min(max_num_frames, length) frame_timestamps = [int(duration / num_frames * (i+0.5)) / fps for i in range(num_frames)] frame_indices = [int(duration / num_frames * (i+0.5)) for i in range(num_frames)] frames_data = vr.get_batch(frame_indices).asnumpy() imgs = [] for idx in range(num_frames): img = resize_image(Image.fromarray(frames_data[idx]).convert("RGB"), long_edge) imgs.append(img) return imgs, frame_timestamps def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="Kimi-VL-A3B-Instruct") parser.add_argument( "--local-path", type=str, default="", help="huggingface ckpt, optional", ) # lazy load parser.add_argument("--lazy-load", action='store_true') parser.add_argument("--ip", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7860) return parser.parse_args() def fetch_model(model_name: str): global args, DEPLOY_MODELS if args.local_path: model_path = args.local_path else: model_path = f"moonshotai/{args.model}" if model_name in DEPLOY_MODELS: model_info = DEPLOY_MODELS[model_name] print(f"{model_name} has been loaded.") else: print(f"{model_name} is loading...") DEPLOY_MODELS[model_name] = load_model(model_path) print(f"Load {model_name} successfully...") model_info = DEPLOY_MODELS[model_name] return model_info def preview_images(files) -> list[str]: if files is None: return [] image_paths = [] for file in files: image_paths.append(file.name) return image_paths def get_prompt(conversation) -> str: """ Get the prompt for the conversation. """ system_prompt = conversation.system_template.format(system_message=conversation.system_message) return system_prompt @wrap_gen_fn @spaces.GPU(duration=180) def predict( text, images, chatbot, history, top_p, temperature, max_length_tokens, max_context_length_tokens, video_num_frames, video_long_edge, chunk_size: int = 512, ): """ Predict the response for the input text and images. Args: text (str): The input text. images (list[PIL.Image.Image]): The input images. chatbot (list): The chatbot. history (list): The history. top_p (float): The top-p value. temperature (float): The temperature value. repetition_penalty (float): The repetition penalty value. max_length_tokens (int): The max length tokens. max_context_length_tokens (int): The max context length tokens. video_num_frames (int): The max number of frames for one video. chunk_size (int): The chunk size. """ print("running the prediction function") try: model, processor = fetch_model(args.model) if text == "": yield chatbot, history, "Empty context." return except KeyError: yield [[text, "No Model Found"]], [], "No Model Found" return if images is None: images = [] # load images pil_images = [] timestamps = None for img_or_file in images: try: # load as pil image if isinstance(images, Image.Image): pil_images.append(img_or_file) else: image = Image.open(img_or_file.name).convert("RGB") pil_images.append(image) except: try: pil_images, timestamps = load_frames(img_or_file, video_num_frames, video_long_edge) ## Only allow one video as input break except Exception as e: print(f"Error loading image or video: {e}") # generate prompt conversation = generate_prompt_with_history( text, pil_images, timestamps, history, processor, max_length=max_context_length_tokens, ) all_conv, last_image = convert_conversation_to_prompts(conversation) stop_words = conversation.stop_str gradio_chatbot_output = to_gradio_chatbot(conversation) full_response = "" # with torch.no_grad(): for x in kimi_vl_generate( conversations=all_conv, model=model, processor=processor, stop_words=stop_words, max_length=max_length_tokens, temperature=temperature, top_p=top_p, ): full_response += x response = strip_stop_words(full_response, stop_words) conversation.update_last_message(response) gradio_chatbot_output[-1][1] = highlight_instruction(response) yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..." if last_image is not None: vg_image = parse_and_draw_response(response, last_image) if vg_image is not None: vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400) # the end of the last message will be ```python ``` gradio_chatbot_output[-1][1] += '\n\n' + vg_base64 yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..." logger.info("flushed result to gradio") if is_variable_assigned("x"): print( f"temperature: {temperature}, " f"top_p: {top_p}, " f"max_length_tokens: {max_length_tokens}" ) yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success" def retry( text, images, chatbot, history, top_p, temperature, max_length_tokens, max_context_length_tokens, chunk_size: int = 512, ): """ Retry the response for the input text and images. """ if len(history) == 0: yield (chatbot, history, "Empty context") return chatbot.pop() history.pop() text = history.pop()[-1] if type(text) is tuple: text, _ = text yield from predict( text, images, chatbot, history, top_p, temperature, max_length_tokens, max_context_length_tokens, chunk_size, ) def build_demo(args: argparse.Namespace) -> gr.Blocks: if args.lazy_load: fetch_model(args.model) with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo: history = gr.State([]) input_text = gr.State() input_images = gr.State() with gr.Row(): gr.HTML(TITLE) status_display = gr.Markdown("Success", elem_id="status_display") gr.Markdown(DESCRIPTION_TOP) with gr.Row(equal_height=True): with gr.Column(scale=4): with gr.Row(): chatbot = gr.Chatbot( elem_id=f"{args.model}-chatbot", show_share_button=True, bubble_full_width=False, height=600, ) with gr.Row(): with gr.Column(scale=4): text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False) with gr.Column(min_width=70): submit_btn = gr.Button("Send") with gr.Column(min_width=70): cancel_btn = gr.Button("Stop") with gr.Row(): empty_btn = gr.Button("๐Ÿงน New Conversation") retry_btn = gr.Button("๐Ÿ”„ Regenerate") del_last_btn = gr.Button("๐Ÿ—‘๏ธ Remove Last Turn") with gr.Column(): # add note no more than 2 images once gr.Markdown("Note: you can upload no more than 10 images once") upload_images = gr.Files(file_types=["image", "video"], show_label=True) gallery = gr.Gallery(columns=[3], height="200px", show_label=True) upload_images.change(preview_images, inputs=upload_images, outputs=gallery) # Parameter Setting Tab for control the generation parameters with gr.Tab(label="Parameter Setting"): top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p") temperature = gr.Slider( minimum=0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature" ) max_length_tokens = gr.Slider( minimum=512, maximum=8192, value=4096, step=64, interactive=True, label="Max Generation Tokens" ) max_context_length_tokens = gr.Slider( minimum=512, maximum=16384, value=4096, step=64, interactive=True, label="Max Context Length Tokens" ) video_num_frames = gr.Slider( minimum=1, maximum=32, value=16, step=1, interactive=True, label="Max Number of Frames for Video" ) video_long_edge = gr.Slider( minimum=28, maximum=896, value=448, step=28, interactive=True, label="Long Edge of Video" ) show_images = gr.HTML(visible=False) gr.Examples( examples=get_examples(ROOT_DIR), inputs=[upload_images, show_images, text_box], ) gr.Markdown() input_widgets = [ input_text, input_images, chatbot, history, top_p, temperature, max_length_tokens, max_context_length_tokens, video_num_frames, video_long_edge, ] output_widgets = [chatbot, history, status_display] transfer_input_args = dict( fn=transfer_input, inputs=[text_box, upload_images], outputs=[input_text, input_images, text_box, upload_images, submit_btn], show_progress=True, ) predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True) retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True) reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display]) predict_events = [ text_box.submit(**transfer_input_args).then(**predict_args), submit_btn.click(**transfer_input_args).then(**predict_args), ] empty_btn.click(reset_state, outputs=output_widgets, show_progress=True) empty_btn.click(**reset_args) retry_btn.click(**retry_args) del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True) cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events) demo.title = f"{args.model} Chatbot" return demo def main(args: argparse.Namespace): demo = build_demo(args) reload_javascript() # concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS favicon_path = os.path.join("kimi_vl/serve/assets/favicon.ico") demo.queue().launch( favicon_path=favicon_path, server_name=args.ip, server_port=args.port, ) if __name__ == "__main__": args = parse_args() main(args)