Commit · 4708376
1 Parent(s): 7fef6fd

adding SMOL tool

Files changed:
- app.py +12 -35
- extract_tools.py +8 -36
- llm_service.py → llm/llm_service.py +0 -0
- hub_prompts.py → prompts/hub_prompts.py +0 -0
- tool_utils/image_description.py +75 -0
app.py
CHANGED
@@ -3,8 +3,8 @@ import streamlit as st
 from PIL import Image
 from pathlib import Path
 from QA_bot import tyre_synap_bot as bot
-from llm_service import get_llm
-from hub_prompts import PREFIX
+from llm.llm_service import get_llm
+from prompts.hub_prompts import PREFIX
 
 from extract_tools import get_all_tools
 from langchain.agents import AgentExecutor
@@ -15,15 +15,14 @@ from langchain.tools.render import render_text_description
 
 import logging
 import warnings
-warnings.filterwarnings("ignore")
 
+warnings.filterwarnings("ignore")
 logging.basicConfig(filename="newfile.log",
                     format='%(asctime)s %(message)s',
                     filemode='w')
 logger = logging.getLogger()
 
 llm = None
-tools = None
 cv_agent = None
 
 @st.cache_resource
@@ -32,14 +31,14 @@ def call_llmservice_model(option,api_key):
     return model
 
 @st.cache_resource
-def setup_agent_prompt():
+def setup_agent_prompt(_tools):
     prompt = hub.pull("hwchase17/react-json")
-    if len(tools) == 0 :
+    if len(_tools) == 0 :
         logger.error ("No Tools added")
     else :
         prompt = prompt.partial(
-            tools= render_text_description(tools),
-            tool_names= ", ".join([t.name for t in tools]),
+            tools = render_text_description(_tools),
+            tool_names= ", ".join([t.name for t in _tools]),
             additional_kwargs={
                 'system_message':PREFIX,
             }
@@ -48,7 +47,9 @@ def setup_agent_prompt():
 
 @st.cache_resource
 def agent_initalize():
-    agent_prompt = setup_agent_prompt()
+    agent_tools = get_all_tools()
+    logger.info("\tFound {} tools ".format(len(agent_tools)))
+    agent_prompt = setup_agent_prompt(_tools=agent_tools)
     lm_with_stop = llm.bind(stop=["\nObservation"])
     #### we can use create_react_agent https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/react/agent.py
     agent = (
@@ -62,29 +63,9 @@ def agent_initalize():
     )
 
     # instantiate AgentExecutor
-    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True,handle_parsing_errors=True)
+    agent_executor = AgentExecutor(agent=agent, tools=agent_tools, verbose=True,handle_parsing_errors=True)
    return agent_executor
 
-# def agent_initalize(tools,max_iterations=5):
-#     zero_shot_agent = initialize_agent(
-#         agent= AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-#         tools = tools,
-#         llm = llm,
-#         verbose = True,
-#         max_iterations = max_iterations,
-#         memory = None,
-#         handle_parsing_errors=True,
-#         agent_kwargs={
-#             'system_message':PREFIX,
-#             # 'format_instructions':FORMAT_INSTRUCTIONS,
-#             # 'suffix':SUFFIX
-#         }
-#     )
-#     # sys_message = PREFIX
-#     # zero_shot_agent.agent.llm_chain.prompt.template = sys_message
-#     return zero_shot_agent
-
-
 def main():
     database_store = 'image_store'
     st.session_state.disabled = False
@@ -137,14 +118,10 @@ def main():
         global llm
         llm = call_llmservice_model(option=option,api_key=api_key)
         logger.info("\tLLM Service {} Active ... !".format(llm.get_name()))
-
-        global tools
-        tools = get_all_tools()
-        logger.info("\tFound {} tools ".format(len(tools)))
+
         ## generate Agent
         global agent
         cv_agent = agent_initalize()
-        logger.info('\tAgent inintalized with {} tools '.format(len(tools)))
 
         with open(file_path, mode='wb') as w:
            w.write(uploaded_file.getvalue())
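One detail worth noting in the refactor above: setup_agent_prompt now takes the tool list as _tools, and the leading underscore is what tells Streamlit's st.cache_resource to skip hashing that argument (Tool objects are not hashable). A minimal sketch of the convention, using a hypothetical function name rather than the real one:

import streamlit as st

@st.cache_resource
def build_tool_summary(_tools):
    # The leading underscore excludes `_tools` from the cache key, so a list
    # of unhashable Tool objects can be passed into a cached function.
    return ", ".join(t.name for t in _tools)

The trade-off is that the cached value is not invalidated when the tool list changes, which is acceptable here because agent_initalize is itself cached and assembles the tools once.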
extract_tools.py
CHANGED
@@ -4,7 +4,7 @@ import requests
 from PIL import Image
 import logging
 import torch
-from llm_service import get_llm
+from llm.llm_service import get_llm
 from langchain_core.tools import tool,Tool
 from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_groq import ChatGroq
@@ -13,7 +13,7 @@ from typing import List
 from tool_utils.clip_segmentation import CLIPSEG
 from tool_utils.yolo_world import YoloWorld
 from tool_utils.image_qualitycheck import brightness_check,gaussian_noise_check,snr_check
-
+from tool_utils.image_description import SMOLVLM2
 try:
     from transformers import BlipProcessor, BlipForConditionalGeneration
     from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
@@ -73,37 +73,10 @@ def panoptic_image_segemntation(image_path:str)->str:
 def image_description(img_path:str)->str:
     "Use this tool to describe the image " \
     "The tool helps you to identify weather in the image as well "
-
-
-
-
-    else:
-        image = Image.open(img_path).convert('RGB')
-        try:
-            processor = BlipProcessor.from_pretrained(hf_model)
-            caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
-        except:
-            logging.error("unable to load the Blip model ")
-
-        logging.info("Image Caption model loaded ! ")
-
-        # unconditional image captioning
-        inputs = processor(image, return_tensors ='pt').to(device)
-        output = caption_model.generate(**inputs, max_new_tokens=50)
-        caption = processor.decode(output[0], skip_special_tokens=True)
-
-        # # conditional image captioning
-        # obj_text = "Total number of objects in image "
-        # inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device)
-        # out_2 = caption_model.generate(**inputs_2,max_new_tokens=50)
-        # object_caption = processor.decode(out_2[0], skip_special_tokens=True)
-
-        ## clear the GPU cache
-        with torch.no_grad():
-            torch.cuda.empty_cache()
-        text = caption + " ."
-        return text
-
+    smol_vlm = SMOLVLM2(memory_efficient=True)
+    query="Describe the image. Higlight the details in 2-3 lines"
+    response = smol_vlm.run_inference_on_image(image_path=img_path,query=query)
+    return response
 
 @tool
 def clipsegmentation_mask(input_data:str)->str:
@@ -163,12 +136,11 @@ def get_image_quality(image_path:str)->str:
 
     brightness_text = brightness_check(image)
     blurry_text = gaussian_noise_check(image)
-    snr_text = snr_check(image)
-    final_text = "Image properties are :\n{}\n{}
+    # snr_text = snr_check(image)
+    final_text = "Image properties are :\n{}\n{}".format(blurry_text, brightness_text)
     return final_text
 
 
-
 def get_all_tools():
     ## bind tools
     image_desc_tool = Tool(
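Since only the opening line of the Tool(...) wrapper appears in the context above, here is a hedged sketch of how the refactored image_description tool could be wired up and exercised on its own; the name, description, and sample path are illustrative assumptions, not taken from the commit:

from langchain_core.tools import Tool
from extract_tools import image_description  # assumes this is a plain callable

# Hypothetical wrapping, mirroring the `image_desc_tool = Tool(` context line.
image_desc_tool = Tool(
    name="image_description",                      # assumed name
    func=image_description,
    description="Describe an image from its file path, including weather cues.",  # assumed text
)

print(image_desc_tool.invoke("image_store/sample.jpg"))  # placeholder path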
llm_service.py → llm/llm_service.py
RENAMED
File without changes

hub_prompts.py → prompts/hub_prompts.py
RENAMED
File without changes
tool_utils/image_description.py
ADDED
@@ -0,0 +1,75 @@
+import os
+import cv2
+import torch
+import gc
+from transformers import AutoProcessor, AutoModelForImageTextToText
+import logging
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.cuda.empty_cache()
+os.environ['PYTORCH_CUDA_ALLOC_CONF']= 'max_split_size_mb:1024'
+gc.collect()
+
+class SMOLVLM2:
+    def __init__(self,model_name = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" , memory_efficient=True):
+        self.half = True
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        if self.support_flash_attension(device_id=0):
+            self.model = AutoModelForImageTextToText.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16,
+                _attn_implementation="flash_attention_2"
+            ).to(device)
+        else:
+            self.model = AutoModelForImageTextToText.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16,
+            ).to(device)
+        logging.info("Model loaded")
+        self.print_gpu_memory()
+
+    @staticmethod
+    def print_gpu_memory():
+        logging.info(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+        logging.info(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+    ## check for flash attension
+    @staticmethod
+    def support_flash_attension(device_id):
+        """ Check if GPU supports FalshAttension"""
+        support = False
+        major, minor = torch.cuda.get_device_capability(device_id)
+        if major<8:
+            print("GPU does not support Flash Attension")
+        else:
+            support = True
+        return support
+
+    def run_inference_on_image(self,image_path,query):
+        messages = [
+            {
+                "role":"user",
+                "content":[
+                    {"type":"image","path":image_path},
+                    {"type":"text","text":query}
+                ]
+            }
+        ]
+        inputs = self.processor.apply_chat_template(
+            messages,
+            add_generation_prompt = True,
+            tokenize = True,
+            return_dict = True,
+            return_tensors = 'pt'
+        )
+        if self.half:
+            inputs.to(torch.half).to(device)
+        else:
+            inputs.to(device)
+        generated_ids = self.model.generate(**inputs,do_sample = False , max_new_tokens = 1024)
+        generated_texts = self.processor.batch_decode(generated_ids,skip_special_tokens=True)
+        del inputs
+        torch.cuda.empty_cache()
+        return generated_texts[0].split('\n')[-1]
+
+
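For a quick standalone check of the new wrapper outside the agent, something like the following should work (the image path is a placeholder):

from tool_utils.image_description import SMOLVLM2

# Loads SmolVLM2-500M-Video-Instruct once, then captions a single image.
vlm = SMOLVLM2(memory_efficient=True)
caption = vlm.run_inference_on_image(
    image_path="image_store/sample.jpg",   # placeholder path
    query="Describe the image. Highlight the details in 2-3 lines",
)
print(caption)

Note that the class always loads the model in float16 and casts inputs to half precision, and that memory_efficient is accepted by __init__ but not referenced inside it yet.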