Samarth991 committed
Commit 4708376 · 1 Parent(s): 7fef6fd

adding SMOL tool

app.py CHANGED
@@ -3,8 +3,8 @@ import streamlit as st
 from PIL import Image
 from pathlib import Path
 from QA_bot import tyre_synap_bot as bot
-from llm_service import get_llm
-from hub_prompts import PREFIX
+from llm.llm_service import get_llm
+from prompts.hub_prompts import PREFIX
 
 from extract_tools import get_all_tools
 from langchain.agents import AgentExecutor
@@ -15,15 +15,14 @@ from langchain.tools.render import render_text_description
 
 import logging
 import warnings
-warnings.filterwarnings("ignore")
 
+warnings.filterwarnings("ignore")
 logging.basicConfig(filename="newfile.log",
                     format='%(asctime)s %(message)s',
                     filemode='w')
 logger = logging.getLogger()
 
 llm = None
-tools = None
 cv_agent = None
 
 @st.cache_resource
@@ -32,14 +31,14 @@ def call_llmservice_model(option,api_key):
     return model
 
 @st.cache_resource
-def setup_agent_prompt():
+def setup_agent_prompt(_tools):
     prompt = hub.pull("hwchase17/react-json")
-    if len(tools) == 0 :
+    if len(_tools) == 0 :
         logger.error ("No Tools added")
     else :
         prompt = prompt.partial(
-            tools= render_text_description(tools),
-            tool_names= ", ".join([t.name for t in tools]),
+            tools = render_text_description(_tools),
+            tool_names= ", ".join([t.name for t in _tools]),
             additional_kwargs={
                 'system_message':PREFIX,
             }
@@ -48,7 +47,9 @@ def setup_agent_prompt():
 
 @st.cache_resource
 def agent_initalize():
-    agent_prompt = setup_agent_prompt()
+    agent_tools = get_all_tools()
+    logger.info("\tFound {} tools ".format(len(agent_tools)))
+    agent_prompt = setup_agent_prompt(_tools=agent_tools)
     lm_with_stop = llm.bind(stop=["\nObservation"])
     #### we can use create_react_agent https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/react/agent.py
     agent = (
@@ -62,29 +63,9 @@ def agent_initalize():
     )
 
     # instantiate AgentExecutor
-    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True,handle_parsing_errors=True)
+    agent_executor = AgentExecutor(agent=agent, tools=agent_tools, verbose=True,handle_parsing_errors=True)
     return agent_executor
 
-# def agent_initalize(tools,max_iterations=5):
-#     zero_shot_agent = initialize_agent(
-#         agent= AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-#         tools = tools,
-#         llm = llm,
-#         verbose = True,
-#         max_iterations = max_iterations,
-#         memory = None,
-#         handle_parsing_errors=True,
-#         agent_kwargs={
-#             'system_message':PREFIX,
-#             # 'format_instructions':FORMAT_INSTRUCTIONS,
-#             # 'suffix':SUFFIX
-#         }
-#     )
-#     # sys_message = PREFIX
-#     # zero_shot_agent.agent.llm_chain.prompt.template = sys_message
-#     return zero_shot_agent
-
-
 def main():
     database_store = 'image_store'
     st.session_state.disabled = False
@@ -137,14 +118,10 @@ def main():
         global llm
         llm = call_llmservice_model(option=option,api_key=api_key)
        logger.info("\tLLM Service {} Active ... !".format(llm.get_name()))
-        ## extract tools
-        global tools
-        tools = get_all_tools()
-        logger.info("\tFound {} tools ".format(len(tools)))
+
         ## generate Agent
         global agent
         cv_agent = agent_initalize()
-        logger.info('\tAgent inintalized with {} tools '.format(len(tools)))
 
         with open(file_path, mode='wb') as w:
            w.write(uploaded_file.getvalue())
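The leading underscore in setup_agent_prompt(_tools) follows Streamlit's caching convention: parameters whose names begin with an underscore are excluded from the cache key, so unhashable objects such as LangChain tool instances can be passed into a function decorated with @st.cache_resource. A minimal sketch of that convention, using hypothetical names rather than code from this commit:

import streamlit as st

@st.cache_resource
def build_registry(_tools, registry_name: str):
    # _tools is skipped when Streamlit hashes the arguments, so unhashable
    # objects (e.g. LangChain Tool instances) can be passed in; only
    # registry_name contributes to the cache key.
    return {"name": registry_name, "count": len(_tools)}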
extract_tools.py CHANGED
@@ -4,7 +4,7 @@ import requests
 from PIL import Image
 import logging
 import torch
-from llm_service import get_llm
+from llm.llm_service import get_llm
 from langchain_core.tools import tool,Tool
 from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_groq import ChatGroq
@@ -13,7 +13,7 @@ from typing import List
 from tool_utils.clip_segmentation import CLIPSEG
 from tool_utils.yolo_world import YoloWorld
 from tool_utils.image_qualitycheck import brightness_check,gaussian_noise_check,snr_check
-
+from tool_utils.image_description import SMOLVLM2
 try:
     from transformers import BlipProcessor, BlipForConditionalGeneration
     from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
@@ -73,37 +73,10 @@ def panoptic_image_segemntation(image_path:str)->str:
 def image_description(img_path:str)->str:
     "Use this tool to describe the image " \
     "The tool helps you to identify weather in the image as well "
-    hf_model = "Salesforce/blip-image-captioning-base"
-    text = ""
-    if img_path.startswith('https'):
-        image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
-    else:
-        image = Image.open(img_path).convert('RGB')
-    try:
-        processor = BlipProcessor.from_pretrained(hf_model)
-        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
-    except:
-        logging.error("unable to load the Blip model ")
-
-    logging.info("Image Caption model loaded ! ")
-
-    # unconditional image captioning
-    inputs = processor(image, return_tensors ='pt').to(device)
-    output = caption_model.generate(**inputs, max_new_tokens=50)
-    caption = processor.decode(output[0], skip_special_tokens=True)
-
-    # # conditional image captioning
-    # obj_text = "Total number of objects in image "
-    # inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device)
-    # out_2 = caption_model.generate(**inputs_2,max_new_tokens=50)
-    # object_caption = processor.decode(out_2[0], skip_special_tokens=True)
-
-    ## clear the GPU cache
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-    text = caption + " ."
-    return text
-
+    smol_vlm = SMOLVLM2(memory_efficient=True)
+    query="Describe the image. Higlight the details in 2-3 lines"
+    response = smol_vlm.run_inference_on_image(image_path=img_path,query=query)
+    return response
 
 @tool
 def clipsegmentation_mask(input_data:str)->str:
@@ -163,12 +136,11 @@ def get_image_quality(image_path:str)->str:
 
     brightness_text = brightness_check(image)
     blurry_text = gaussian_noise_check(image)
-    snr_text = snr_check(image)
-    final_text = "Image properties are :\n{}\n{}\n{}".format(blurry_text, brightness_text,snr_text)
+    # snr_text = snr_check(image)
+    final_text = "Image properties are :\n{}\n{}".format(blurry_text, brightness_text)
     return final_text
 
 
-
 def get_all_tools():
     ## bind tools
     image_desc_tool = Tool(
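Since the other helpers in this module are exposed through LangChain's @tool decorator, the trimmed image-quality report (with the SNR check commented out) can be exercised directly through the tool interface. A minimal sketch, assuming get_image_quality is decorated with @tool like clipsegmentation_mask and that an image exists at the hypothetical path below:

from extract_tools import get_image_quality

# Single-string input maps to the tool's image_path argument; the path is a
# placeholder for any locally stored image.
report = get_image_quality.invoke("image_store/sample.jpg")
print(report)  # expected to list brightness and blur only, since the SNR check is commented out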
llm_service.py → llm/llm_service.py RENAMED
File without changes
hub_prompts.py → prompts/hub_prompts.py RENAMED
File without changes
tool_utils/image_description.py ADDED
@@ -0,0 +1,75 @@
+import os
+import cv2
+import torch
+import gc
+from transformers import AutoProcessor, AutoModelForImageTextToText
+import logging
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.cuda.empty_cache()
+os.environ['PYTORCH_CUDA_ALLOC_CONF']= 'max_split_size_mb:1024'
+gc.collect()
+
+class SMOLVLM2:
+    def __init__(self,model_name = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" , memory_efficient=True):
+        self.half = True
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        if self.support_flash_attension(device_id=0):
+            self.model = AutoModelForImageTextToText.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16,
+                _attn_implementation="flash_attention_2"
+            ).to(device)
+        else:
+            self.model = AutoModelForImageTextToText.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16,
+            ).to(device)
+        logging.info("Model loaded")
+        self.print_gpu_memory()
+
+    @staticmethod
+    def print_gpu_memory():
+        logging.info(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+        logging.info(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+    ## check for flash attension
+    @staticmethod
+    def support_flash_attension(device_id):
+        """ Check if GPU supports FalshAttension"""
+        support = False
+        major, minor = torch.cuda.get_device_capability(device_id)
+        if major<8:
+            print("GPU does not support Flash Attension")
+        else:
+            support = True
+        return support
+
+    def run_inference_on_image(self,image_path,query):
+        messages = [
+            {
+                "role":"user",
+                "content":[
+                    {"type":"image","path":image_path},
+                    {"type":"text","text":query}
+                ]
+            }
+        ]
+        inputs = self.processor.apply_chat_template(
+            messages,
+            add_generation_prompt = True,
+            tokenize = True,
+            return_dict = True,
+            return_tensors = 'pt'
+        )
+        if self.half:
+            inputs.to(torch.half).to(device)
+        else:
+            inputs.to(device)
+        generated_ids = self.model.generate(**inputs,do_sample = False , max_new_tokens = 1024)
+        generated_texts = self.processor.batch_decode(generated_ids,skip_special_tokens=True)
+        del inputs
+        torch.cuda.empty_cache()
+        return generated_texts[0].split('\n')[-1]
+
+
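A short usage sketch for the new SMOLVLM2 wrapper, mirroring how extract_tools.image_description now calls it; the image path is hypothetical, and a CUDA-capable GPU is assumed because the constructor probes torch.cuda.get_device_capability before loading the model:

from tool_utils.image_description import SMOLVLM2

vlm = SMOLVLM2(memory_efficient=True)  # loads HuggingFaceTB/SmolVLM2-500M-Video-Instruct in float16
caption = vlm.run_inference_on_image(
    image_path="image_store/sample_tyre.jpg",  # hypothetical local image
    query="Describe the image. Highlight the details in 2-3 lines",
)
print(caption)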