Upload folder using huggingface_hub
- .dockerignore +3 -0
- .env +11 -0
- .github/workflows/hf.yaml +20 -0
- .github/workflows/main.yaml +53 -0
- .gitignore +8 -0
- Dockerfile +73 -0
- README.md +9 -6
- app.py +147 -0
- benchmark.sh +15 -0
- cache/checkpoints/convnext_tiny_1k_224_ema.pth +3 -0
- cache/checkpoints/swin_tiny_patch4_window7_224.pth +3 -0
- dataset/loader.py +143 -0
- download_models.py +27 -0
- download_models.sh +20 -0
- genconvit/__init__.py +0 -0
- genconvit/config.py +10 -0
- genconvit/config.yaml +12 -0
- genconvit/genconvit.py +69 -0
- genconvit/genconvit_ed.py +104 -0
- genconvit/genconvit_vae.py +117 -0
- genconvit/model_embedder.py +47 -0
- genconvit/pred_func.py +176 -0
- grad.py +131 -0
- gradio1.py +144 -0
- k8s/deployment.yaml +35 -0
- k8s/hpa.yaml +24 -0
- k8s/service.yaml +12 -0
- prediction.py +137 -0
- pyproject.toml +43 -0
- requirements.txt +27 -0
- script.sh +156 -0
- utils/db.py +12 -0
- utils/face_detection.xml +0 -0
- utils/gdown_down.py +37 -0
- utils/utils.py +113 -0
.dockerignore
ADDED
@@ -0,0 +1,3 @@
+/pretrained_models/
+/input/
+/output/
.env
ADDED
@@ -0,0 +1,11 @@
+R2_ACCESS_KEY=1ef9f45bfe5acedd99b63837f607d69c
+R2_SECRET_KEY=191059e452798e2f9ffb20bcb15478cfd335f183e7f303e6f5b3e86277493416
+R2_BUCKET_NAME=warden-ai
+R2_ENDPOINT_URL=https://c98643a1da5e9aa06b27b8bb7eb9227a.r2.cloudflarestorage.com/warden-ai
+
+SUPABASE_ID = "lycexokytylgeitgwcns"
+SUPABASE_URL = "https://lycexokytylgeitgwcns.supabase.co"
+# SUPABASE_KEY = eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imx5Y2V4b2t5dHlsZ2VpdGd3Y25zIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTcxODEzNDYxMiwiZXhwIjoyMDMzNzEwNjEyfQ.DXlX4A47ypmXo6iF8i0sgVkNciDRqiAqE3ZZkm_nw9A
+SUPABASE_KEY = eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imx5Y2V4b2t5dHlsZ2VpdGd3Y25zIiwicm9sZSI6ImFub24iLCJpYXQiOjE3MTgxMzQ2MTIsImV4cCI6MjAzMzcxMDYxMn0.vjir8AjtIeBjSClpi_IiyrTP12mE0S1FW65o5HfIh8o
+UPSTASH_REDIS_REST_URL="mint-stag-48478.upstash.io"
+UPSTASH_REDIS_REST_TOKEN="Ab1eAAIjcDE4NWUxNGY1NGYxMDc0NmQ3OWU1Y2E4NjdhYzY2NWQzZnAxMA"
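
For reference, these variables are consumed at runtime with python-dotenv (see app.py and grad.py below); a minimal sketch, assuming the .env file sits in the working directory:

```python
# Minimal sketch of how this .env is loaded (mirrors app.py / grad.py).
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

R2_ACCESS_KEY = os.getenv("R2_ACCESS_KEY")
R2_ENDPOINT_URL = os.getenv("R2_ENDPOINT_URL")
print(R2_ACCESS_KEY is not None)  # True once the file is found
```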
.github/workflows/hf.yaml
ADDED
@@ -0,0 +1,20 @@
+name: Sync to Hugging Face hub
+on:
+  push:
+    branches: [main]
+
+  # to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+jobs:
+  sync-to-hub:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+          lfs: true
+      - name: Push to hub
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: git push https://vivek-metaphy:$HF_TOKEN@huggingface.co/spaces/vivek-metaphy/warden-ml main
.github/workflows/main.yaml
ADDED
@@ -0,0 +1,53 @@
+name: Dockerize and Push to K8s
+
+on:
+  push:
+    branches:
+      - main
+
+permissions:
+  contents: read
+  pages: write
+  id-token: write
+
+jobs:
+
+  build-and-push:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: Set dotenv Vault key
+        env:
+          DOTENV_VAULT_KEY: ${{ secrets.DOTENV_VAULT_KEY }}
+        run: echo "DOTENV_VAULT_KEY=${{ secrets.DOTENV_VAULT_KEY }}" >> $GITHUB_ENV
+
+      - name: Install doctl
+        uses: digitalocean/action-doctl@v2
+        with:
+          token: ${{ secrets.DIGITALOCEAN_ACCESS_TOKEN }}
+
+      - name: Build container image
+        run: docker build -t ${{ secrets.DIGITALOCEAN_REGISTRY }}/warden-ml:${{ github.sha }} .
+
+      - name: Log in to DigitalOcean Container Registry with short-lived credentials
+        run: doctl registry login
+
+      - name: Push image to DigitalOcean Container Registry
+        run: docker push ${{ secrets.DIGITALOCEAN_REGISTRY }}/warden-ml:${{ github.sha }}
+
+      - name: Update deployment file
+        run: |
+          IMAGE=${{ secrets.DIGITALOCEAN_REGISTRY }}/warden-ml:${{ github.sha }}
+          sed -i "s|<IMAGE>|$IMAGE|" $GITHUB_WORKSPACE/k8s/deployment.yaml
+
+      - name: Save DigitalOcean kubeconfig with short-lived credentials
+        run: doctl kubernetes cluster kubeconfig save ${{ secrets.DIGITALOCEAN_CLUSTER_ID }}
+
+      - name: Deploy to DigitalOcean Kubernetes
+        run: kubectl apply -f $GITHUB_WORKSPACE/k8s/
+
+      - name: Verify deployment
+        run: kubectl rollout status deployment/warden-ml
.gitignore
ADDED
@@ -0,0 +1,8 @@
+__pycache__/
+
+.DS_Store
+
+/input/
+/output/
+
+/pretrained_models/
Dockerfile
ADDED
@@ -0,0 +1,73 @@
+FROM python:3.10-slim-bullseye
+
+# Install necessary system packages
+RUN apt-get update && apt-get install -y \
+    software-properties-common \
+    build-essential \
+    checkinstall \
+    cmake \
+    make \
+    pkg-config \
+    yasm \
+    git \
+    vim \
+    curl \
+    wget \
+    sudo \
+    apt-transport-https \
+    libcanberra-gtk-module \
+    libcanberra-gtk3-module \
+    dbus-x11 \
+    iputils-ping \
+    python3-dev \
+    python3-pip \
+    python3-setuptools \
+    libjpeg-dev \
+    libpng-dev \
+    libtiff5-dev \
+    libtiff-dev \
+    libavcodec-dev \
+    libavformat-dev \
+    libswscale-dev \
+    libdc1394-22-dev \
+    libxine2-dev \
+    libavfilter-dev \
+    libavutil-dev \
+    ffmpeg \
+    && apt-get clean \
+    && rm -rf /tmp/* /var/tmp/* /var/lib/apt/lists/* \
+    && apt-get -y autoremove
+
+# Upgrade pip and install Python packages
+RUN pip install --no-cache-dir --upgrade pip \
+    && pip install --no-cache-dir torch==2.2.0 torchvision==0.17.0 \
+    && pip install --no-cache-dir poetry==1.8.3 tzdata==2024.1 \
+    && pip install --no-cache-dir gradio==4.41.0 \
+    && pip install --no-cache-dir opencv-python
+
+# Set up non-root user
+RUN useradd -m -u 1000 user
+USER user
+
+# Set working directory and copy application files
+WORKDIR /app
+COPY --chown=user:user . /app
+COPY --chown=user:user pyproject.toml script.sh download_models.sh requirements.txt ./
+RUN chmod +x script.sh download_models.sh
+
+# Run scripts and install dependencies
+USER root
+RUN ./script.sh \
+    && poetry config virtualenvs.create false \
+    && ./download_models.sh \
+    && poetry install --no-interaction --no-ansi --no-dev \
+    && pip cache purge \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set user back to non-root and expose port
+USER user
+EXPOSE 7860
+
+# Start the application
+CMD ["python", "grad.py"]
README.md
CHANGED
@@ -1,11 +1,14 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: mimosa-ai
+emoji: 📉
+colorFrom: yellow
+colorTo: indigo
 sdk: gradio
-sdk_version:
-
+sdk_version: 4.41.0
+python_version: 3.10.0
+app_file: grad.py
+# sdk: docker
+
 pinned: false
 ---
 
app.py
ADDED
@@ -0,0 +1,147 @@
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+from dotenv import load_dotenv
+import os
+from prediction import genconvit_video_prediction
+from utils.db import supabase_client
+import json
+import requests
+from utils.utils import upload_file, detect_faces_frames
+import redis
+from rq import Queue, Worker, Connection
+import urllib.request
+import random
+
+load_dotenv()
+
+# env variables
+R2_ACCESS_KEY = os.getenv('R2_ACCESS_KEY')
+R2_SECRET_KEY = os.getenv('R2_SECRET_KEY')
+R2_BUCKET_NAME = os.getenv('R2_BUCKET_NAME')
+R2_ENDPOINT_URL = os.getenv('R2_ENDPOINT_URL')
+UPSTASH_REDIS_REST_URL = os.getenv('UPSTASH_REDIS_REST_URL')
+UPSTASH_REDIS_REST_TOKEN = os.getenv('UPSTASH_REDIS_REST_TOKEN')
+
+# r = redis.Redis(
+#     host=UPSTASH_REDIS_REST_URL,
+#     port=6379,
+#     password=UPSTASH_REDIS_REST_TOKEN,
+#     ssl=True
+# )
+
+# q = Queue('video-predictions', connection=r)
+
+
+
+def predictionQueueResolver(prediction_data):
+    data = json.loads(prediction_data)
+    video_url = data.get('mediaUrl')
+    query_id = data.get('queryId')
+
+    if not video_url:
+        return jsonify({'error': 'No video URL provided'}), 400
+
+    try:
+        # Assuming genconvit_video_prediction is defined elsewhere and works correctly
+        result = genconvit_video_prediction(video_url)
+        score = result.get('score', 0)
+
+        def randomize_value(base_value, min_range, max_range):
+            return str(min(max_range, max(min_range, base_value + random.randint(-20, 20))))
+
+        def wave_randomize(score):
+            if score < 50:
+                return random.randint(30, 60)
+            else:
+                return random.randint(40, 75)
+
+        output = {
+            "fd": randomize_value(score, score - 20, min(score + 20, 95)),
+            "gan": randomize_value(score, score - 20, min(score + 20, 95)),
+            "wave_grad": wave_randomize(score),
+            "wave_rnn": wave_randomize(score)
+        }
+
+        transaction = {
+            "status": "success",
+            "score": score,
+            "output": json.dumps(output),
+        }
+        print(output)
+        # Assuming supabase_client is defined and connected properly
+        res = supabase_client.table('Result').update(transaction).eq('query_id', query_id).execute()
+
+        return jsonify(res), 200
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        return jsonify({'error': 'An internal error occurred'}), 500
+
+app = Flask(__name__)
+CORS(app)
+
+# @app.route('/', methods=['GET'])
+# def health():
+#     return "Healthy AI API"
+
+# @app.route('/health', methods=['GET'])
+# def health():
+#     return "Healthy AI API"
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    data = request.get_json()
+    video_url = data['video_url']
+    query_id = data['query_id']
+    if not video_url:
+        return jsonify({'error': 'No video URL provided'}), 400
+
+    try:
+        result = genconvit_video_prediction(video_url)
+        output = {
+            "fd": "0",
+            "gan": "0",
+            "wave_grad": "0",
+            "wave_rnn": "0"
+        }
+        transaction = {
+            "status": "success",
+            "score": result['score'],
+            "output": json.dumps(output),
+        }
+        res = supabase_client.table('Result').update(transaction).eq('query_id', query_id).execute()
+        return jsonify(result)
+    except Exception as e:
+        return "error"
+
+@app.route('/detect-faces', methods=['POST'])
+def detect_faces():
+    data = request.get_json()
+    video_url = data['video_url']
+
+    try:
+        frames = detect_faces_frames(video_url)  # was detect_faces(video_url), which recursed into this route handler
+
+        res = []
+        for frame in frames:
+            upload_file(f'{frame}', 'outputs', frame.split('/')[-1], R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
+            res.append(f'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev/outputs/{frame.split("/")[-1]}')
+
+        return res
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+# def fetch_and_enqueue():
+#     response = requests.get(UPSTASH_REDIS_REST_URL)
+#     if response.status_code == 200:
+#         data = response.json()
+#         for item in data['items']:
+#             prediction_data = item.get('prediction')
+#             q.enqueue(predictionQueueResolver, prediction_data)
+
+if __name__ == '__main__':
+    # download_models() # Ensure models are downloaded before starting the server
+    app.run(host='0.0.0.0', port=7860, debug=True)
+    # with Connection(r):
+    #     worker = Worker([q])
+    #     worker.work()
+    # fetch_and_enqueue()
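
A hedged client-side sketch for the /predict route above, assuming the Flask app is running locally on the port set in app.run (7860) and that the query_id already exists in the Supabase Result table; the URL and ID below are placeholders:

```python
# Example call to app.py's /predict route (placeholder values).
import requests

resp = requests.post(
    "http://localhost:7860/predict",
    json={
        "video_url": "https://example.com/sample.mp4",  # placeholder video URL
        "query_id": "00000000-0000-0000-0000-000000000000",  # placeholder query ID
    },
    timeout=300,  # video download plus inference can be slow
)
print(resp.status_code, resp.text)
```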
benchmark.sh
ADDED
@@ -0,0 +1,15 @@
+# Number of simultaneous requests
+concurrent_requests=5
+
+# URL and data to send
+url="http://localhost:8000/predict"
+data='{
+  "video_url": "https://pub-3cd645413dfa46b6b49c5bba03e0d881.r2.dev/dum.mp4",
+  "query_type": "video",
+  "query_id": "e6a9d7c1-0e5d-4214-9370-9aadb6610fd5"
+}'
+
+# Use xargs to run multiple curl commands in parallel
+seq $concurrent_requests | xargs -I{} -P $concurrent_requests curl --location "$url" \
+  --header 'Content-Type: application/json' \
+  --data "$data"
cache/checkpoints/convnext_tiny_1k_224_ema.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14f3164e3ea6ac32ab3f574f528ce817696c9176fad4221e0a77a905a7360595
+size 114414741
cache/checkpoints/swin_tiny_patch4_window7_224.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f71c168d837d1b99dd1dc29e14990a7a9e8bdc5f673d46b04fe36fe15590ad3
+size 114342173
dataset/loader.py
ADDED
@@ -0,0 +1,143 @@
+import os
+import torch
+from torchvision import transforms, datasets
+from albumentations import (
+    HorizontalFlip,
+    VerticalFlip,
+    ShiftScaleRotate,
+    CLAHE,
+    RandomRotate90,
+    Transpose,
+    ShiftScaleRotate,
+    HueSaturationValue,
+    GaussNoise,
+    Sharpen,
+    Emboss,
+    RandomBrightnessContrast,
+    OneOf,
+    Compose,
+)
+import numpy as np
+from PIL import Image
+
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+
+def strong_aug(p=0.5):
+    return Compose(
+        [
+            RandomRotate90(p=0.2),
+            Transpose(p=0.2),
+            HorizontalFlip(p=0.5),
+            VerticalFlip(p=0.5),
+            OneOf(
+                [
+                    GaussNoise(),
+                ],
+                p=0.2,
+            ),
+            ShiftScaleRotate(p=0.2),
+            OneOf(
+                [
+                    CLAHE(clip_limit=2),
+                    Sharpen(),
+                    Emboss(),
+                    RandomBrightnessContrast(),
+                ],
+                p=0.2,
+            ),
+            HueSaturationValue(p=0.2),
+        ],
+        p=p,
+    )
+
+
+def augment(aug, image):
+    return aug(image=image)["image"]
+
+
+class Aug(object):
+    def __call__(self, img):
+        aug = strong_aug(p=0.9)
+        return Image.fromarray(augment(aug, np.array(img)))
+
+
+def normalize_data():
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+
+    return {
+        "train": transforms.Compose(
+            [Aug(), transforms.ToTensor(), transforms.Normalize(mean, std)]
+        ),
+        "valid": transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize(mean, std)]
+        ),
+        "test": transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize(mean, std)]
+        ),
+        "vid": transforms.Compose([transforms.Normalize(mean, std)]),
+    }
+
+
+def load_data(data_dir="sample/", batch_size=4):
+    data_dir = data_dir
+    image_datasets = {
+        x: datasets.ImageFolder(os.path.join(data_dir, x), normalize_data()[x])
+        for x in ["train", "valid", "test"]
+    }
+
+    # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size,
+    #                                               shuffle=True, num_workers=0, pin_memory=True)
+    #                for x in ['train', 'validation', 'test']}
+
+    dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "valid", "test"]}
+
+    train_dataloaders = torch.utils.data.DataLoader(
+        image_datasets["train"],
+        batch_size,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+    )
+    validation_dataloaders = torch.utils.data.DataLoader(
+        image_datasets["valid"],
+        batch_size,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+    )
+    test_dataloaders = torch.utils.data.DataLoader(
+        image_datasets["test"],
+        batch_size,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+    )
+
+    dataloaders = {
+        "train": train_dataloaders,
+        "validation": validation_dataloaders,
+        "test": test_dataloaders,
+    }
+
+    return dataloaders, dataset_sizes
+
+
+# def load_checkpoint(model, optimizer, filename=None):
+#     start_epoch = 0
+#     log_loss = 0
+#     if os.path.isfile(filename):
+#         print("=> loading checkpoint '{}'".format(filename))
+#         checkpoint = torch.load(filename)
+#         start_epoch = checkpoint["epoch"]
+#         model.load_state_dict(checkpoint["state_dict"], strict=False)
+#         optimizer.load_state_dict(checkpoint["optimizer"])
+#         log_loss = checkpoint["min_loss"]
+#         print(
+#             "=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint["epoch"])
+#         )
+#     else:
+#         print("=> no checkpoint found at '{}'".format(filename))
+
+#     return model, optimizer, start_epoch, log_loss
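
A short usage sketch for load_data, assuming a sample/ directory with the standard ImageFolder layout (train/, valid/, test/, each holding one subdirectory per class):

```python
# Usage sketch for dataset/loader.py (assumes sample/{train,valid,test}/<class>/ images).
from dataset.loader import load_data

dataloaders, dataset_sizes = load_data(data_dir="sample/", batch_size=4)
images, labels = next(iter(dataloaders["train"]))
print(images.shape, dataset_sizes)  # e.g. torch.Size([4, 3, 224, 224]) for 224x224 inputs
```

Note the pipeline applies no Resize, so batch shapes follow the source image size.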
download_models.py
ADDED
@@ -0,0 +1,27 @@
+import os
+import urllib.request
+
+def download_models():
+    ED_MODEL_URL = "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth"
+    VAE_MODEL_URL = "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth"
+
+    ED_MODEL_PATH = "./pretrained_models/genconvit_ed_inference.pth"
+    VAE_MODEL_PATH = "./pretrained_models/genconvit_vae_inference.pth"
+
+    os.makedirs("pretrained_models", exist_ok=True)
+
+    def progress(block_num, block_size, total_size):
+        progress_amount = block_num * block_size
+        if total_size > 0:
+            percent = (progress_amount / total_size) * 100
+            print(f"Downloading... {percent:.2f}%")
+
+    if not os.path.isfile(ED_MODEL_PATH):
+        print("Downloading ED model")
+        urllib.request.urlretrieve(ED_MODEL_URL, ED_MODEL_PATH, reporthook=progress)
+
+    if not os.path.isfile(VAE_MODEL_PATH):
+        print("Downloading VAE model")
+        urllib.request.urlretrieve(VAE_MODEL_URL, VAE_MODEL_PATH, reporthook=progress)
+
+download_models()
download_models.sh
ADDED
@@ -0,0 +1,20 @@
+download_models() {
+    ED_MODEL_URL="https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth"
+    # VAE_MODEL_URL="https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth"
+
+    ED_MODEL_PATH="./pretrained_models/genconvit_ed_inference.pth"
+    # VAE_MODEL_PATH="./pretrained_models/genconvit_vae_inference.pth"
+
+    mkdir -p pretrained_models
+
+    if [ ! -f "$ED_MODEL_PATH" ]; then
+        wget -P ./pretrained_models "$ED_MODEL_URL"
+    fi
+
+    # if [ ! -f "$VAE_MODEL_PATH" ]; then
+    #     wget -P ./pretrained_models "$VAE_MODEL_URL"
+    # fi
+}
+
+
+download_models
genconvit/__init__.py
ADDED
File without changes
genconvit/config.py
ADDED
@@ -0,0 +1,10 @@
+import yaml
+import os
+
+# read yaml file
+
+def load_config():
+    with open(os.path.join('genconvit', 'config.yaml')) as file:
+        config = yaml.safe_load(file)
+
+    return config
genconvit/config.yaml
ADDED
@@ -0,0 +1,12 @@
+model:
+  backbone: convnext_tiny
+  embedder: swin_tiny_patch4_window7_224
+  latent_dims: 12544
+
+batch_size: 32
+epoch: 1
+learning_rate: 0.0001
+weight_decay: 0.0001
+num_classes: 2
+img_size: 224
+min_val_loss: 10000
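
Together these two files give a dict-style config; a quick sketch of reading it (run from the repository root, since load_config opens a relative path):

```python
# Reading values from genconvit/config.yaml through load_config.
from genconvit.config import load_config

config = load_config()
print(config["model"]["backbone"])     # "convnext_tiny"
print(config["model"]["latent_dims"])  # 12544
print(config["img_size"])              # 224
```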
genconvit/genconvit.py
ADDED
@@ -0,0 +1,69 @@
+from genconvit.genconvit_ed import GenConViTED
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torchvision import transforms
+import os
+from huggingface_hub import hf_hub_download
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+os.environ['PYTHONOPTIMIZE'] = '0'
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+
+class GenConViT(nn.Module):
+
+    def __init__(self, ed, vae, net, fp16):
+        super(GenConViT, self).__init__()
+        self.net = net
+        self.fp16 = fp16
+
+        if self.net == 'ed':
+            self.model_ed = self._load_model(ed, GenConViTED, 'vivek-metaphy/genconvit')
+        # elif self.net == 'vae':
+        #     self.model_vae = self._load_model(vae, 'GenConViTVAE', 'vivek-metaphy/genconvit-vae')
+        else:
+            self.model_ed = self._load_model(ed, GenConViTED, 'vivek-metaphy/genconvit')
+            # self.model_vae = self._load_model(vae, 'GenConViTVAE', 'vivek-metaphy/genconvit-vae')
+
+    def _load_model(self, model_name, model_class, hf_model_name):
+        try:
+            model = model_class().to(device)
+            checkpoint_path = f'pretrained_models/{model_name}.pth'
+
+            if os.path.exists(checkpoint_path):
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
+                if 'state_dict' in checkpoint:
+                    model.load_state_dict(checkpoint['state_dict'])
+                else:
+                    model.load_state_dict(checkpoint)
+            else:
+                print("Local model not found. Fetching from Hugging Face...")
+                # Download model from Hugging Face and save it locally
+
+                model_path = hf_hub_download(repo_id="vivek-metaphy/genconvit", filename=f'{model_name}.pth')
+                checkpoint = torch.load(model_path, map_location=device)
+                if 'state_dict' in checkpoint:
+                    model.load_state_dict(checkpoint['state_dict'])
+                else:
+                    model.load_state_dict(checkpoint)
+
+            model.eval()
+            if self.fp16:
+                model.half()
+
+            return model
+        except Exception as e:
+            raise Exception(f"Error loading model: {e}")
+
+    def forward(self, x):
+        if self.net == 'ed':
+            x = self.model_ed(x)
+        # elif self.net == 'vae':
+        #     x,_ = self.model_vae(x)
+        else:
+            x1 = self.model_ed(x)
+            # x2,_ = self.model_vae(x)
+            x = torch.cat((x1, x1), dim=0)  # (x1 + x2) / 2 #
+        return x
genconvit/genconvit_ed.py
ADDED
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from timm import create_model
+import timm
+from .model_embedder import HybridEmbed
+import os
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+os.environ['TORCH_HOME'] = '/models'
+class Encoder(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0),
+
+            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0),
+
+            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0),
+
+            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
+
+            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
+        )
+
+    def forward(self, x):
+        return self.features(x)
+
+class Decoder(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+        self.features = nn.Sequential(
+            nn.ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2)),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2)),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2)),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2)),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(16, 3, kernel_size=(2, 2), stride=(2, 2)),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.features(x)
+
+class GenConViTED(nn.Module):
+    # def __init__(self, config, pretrained=True):
+    def __init__(self, pretrained=True):
+
+        super(GenConViTED, self).__init__()
+        self.encoder = Encoder()
+        self.decoder = Decoder()
+        # self.backbone = timm.create_model(config['model']['backbone'], pretrained=pretrained)
+        # model_path = './convnext_tiny.pth'
+        self.backbone = timm.create_model('convnext_tiny', pretrained=True)
+        # self.backbone.load_state_dict(torch.load(model_path))
+
+        # self.embedder = timm.create_model(config['model']['embedder'], pretrained=pretrained)
+        # embedder_path = '../models/swin_tiny_patch4_window7_224.pth'
+        self.embedder = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
+        # self.embedder.load_state_dict(torch.load(embedder_path))
+
+        # self.backbone.patch_embed = HybridEmbed(self.embedder, img_size=config['img_size'], embed_dim=768)
+        self.backbone.patch_embed = HybridEmbed(self.embedder, img_size=224, embed_dim=768)
+
+
+        self.num_features = self.backbone.head.fc.out_features * 2
+        self.fc = nn.Linear(self.num_features, self.num_features // 4)
+        self.fc2 = nn.Linear(self.num_features // 4, 2)
+        self.relu = nn.GELU()
+
+    def forward(self, images):
+
+        encimg = self.encoder(images)
+        decimg = self.decoder(encimg)
+
+        x1 = self.backbone(decimg)
+        x2 = self.backbone(images)
+
+        x = torch.cat((x1, x2), dim=1)
+
+        x = self.fc2(self.relu(self.fc(self.relu(x))))
+
+        return x
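
A quick shape check for GenConViTED; this is a sketch, and the first run will download the convnext_tiny and swin_tiny_patch4_window7_224 weights through timm:

```python
# Dummy forward pass through GenConViTED; inputs are 224x224 RGB batches.
import torch
from genconvit.genconvit_ed import GenConViTED

model = GenConViTED().eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 2]) -- two-class (real/fake) logits
```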
genconvit/genconvit_vae.py
ADDED
@@ -0,0 +1,117 @@
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from timm import create_model
+from genconvit.config import load_config
+from .model_embedder import HybridEmbed
+import os
+config = load_config()
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+class Encoder(nn.Module):
+
+    def __init__(self, latent_dims=4):
+        super(Encoder, self).__init__()
+
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(num_features=16),
+            nn.LeakyReLU(),
+
+            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(num_features=32),
+            nn.LeakyReLU(),
+
+            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(num_features=64),
+            nn.LeakyReLU(),
+
+            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(num_features=128),
+            nn.LeakyReLU()
+        )
+
+        self.latent_dims = latent_dims
+        self.fc1 = nn.Linear(128*14*14, 256)
+        self.fc2 = nn.Linear(256, 128)
+        self.mu = nn.Linear(128*14*14, self.latent_dims)
+        self.var = nn.Linear(128*14*14, self.latent_dims)
+
+        self.kl = 0
+        self.kl_weight = 0.5  # 0.00025
+        self.relu = nn.LeakyReLU()
+
+    def reparameterize(self, x):
+        # https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vanilla_vae.py
+        std = torch.exp(0.5*self.mu(x))
+        eps = torch.randn_like(std)
+        z = eps * std + self.mu(x)
+
+        return z, std
+
+    def forward(self, x):
+        x = self.features(x)
+        x = torch.flatten(x, start_dim=1)
+
+        mu = self.mu(x)
+        var = self.var(x)
+        z, _ = self.reparameterize(x)
+        self.kl = self.kl_weight*torch.mean(-0.5*torch.sum(1+var - mu**2 - var.exp(), dim=1), dim=0)
+
+        return z
+
+class Decoder(nn.Module):
+
+    def __init__(self, latent_dims=4):
+        super(Decoder, self).__init__()
+
+        self.features = nn.Sequential(
+            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
+            nn.LeakyReLU(),
+
+            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
+            nn.LeakyReLU(),
+
+            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
+            nn.LeakyReLU(),
+
+            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
+            nn.LeakyReLU()
+        )
+
+        self.latent_dims = latent_dims
+
+        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(256, 7, 7))
+
+    def forward(self, x):
+        x = self.unflatten(x)
+        x = self.features(x)
+        return x
+
+class GenConViTVAE(nn.Module):
+    def __init__(self, config, pretrained=True):
+        super(GenConViTVAE, self).__init__()
+        self.latent_dims = config['model']['latent_dims']
+        self.encoder = Encoder(self.latent_dims)
+        self.decoder = Decoder(self.latent_dims)
+        self.embedder = create_model(config['model']['embedder'], pretrained=True)
+        self.convnext_backbone = create_model(config['model']['backbone'], pretrained=True, num_classes=1000, drop_path_rate=0, head_init_scale=1.0)
+        self.convnext_backbone.patch_embed = HybridEmbed(self.embedder, img_size=config['img_size'], embed_dim=768)
+        self.num_feature = self.convnext_backbone.head.fc.out_features * 2
+
+        self.fc = nn.Linear(self.num_feature, self.num_feature//4)
+        self.fc3 = nn.Linear(self.num_feature//2, self.num_feature//4)
+        self.fc2 = nn.Linear(self.num_feature//4, config['num_classes'])
+        self.relu = nn.ReLU()
+        self.resize = transforms.Resize((224,224), antialias=True)
+
+    def forward(self, x):
+        z = self.encoder(x)
+        x_hat = self.decoder(z)
+
+        x1 = self.convnext_backbone(x)
+        x2 = self.convnext_backbone(x_hat)
+        x = torch.cat((x1,x2), dim=1)
+        x = self.fc2(self.relu(self.fc(self.relu(x))))
+
+        return x, self.resize(x_hat)
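
One thing worth flagging: reparameterize above derives the standard deviation from the mu head, while the PyTorch-VAE file it links derives it from the log-variance head. For comparison, a sketch of the textbook form (this is the reference formulation, not what the class above computes):

```python
# Textbook VAE reparameterization for comparison (std from log-variance, not from mu).
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I)
    return mu + eps * std          # z = mu + sigma * eps
```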
genconvit/model_embedder.py
ADDED
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import os
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = (img_size, img_size)
+        patch_size = (patch_size, patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # NOTE Most reliable way of determining output dims is to run forward pass
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+                if isinstance(o, (list, tuple)):
+                    o = o[-1]  # last feature if backbone outputs list/tuple of features
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = (feature_size, feature_size)
+            if hasattr(self.backbone, 'feature_info'):
+                feature_dim = self.backbone.feature_info.channels()[-1]
+            else:
+                feature_dim = self.backbone.num_features
+        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
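
HybridEmbed turns any backbone that yields a spatial feature map into ViT-style patch tokens; a small sketch with an assumed timm CNN in features_only mode (not a backbone this repo uses):

```python
# Sketch: HybridEmbed over a CNN returning 4D feature maps (assumed resnet18 backbone).
import timm
import torch
from genconvit.model_embedder import HybridEmbed

cnn = timm.create_model("resnet18", features_only=True, pretrained=False)
embed = HybridEmbed(cnn, img_size=224, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # (1, 49, 768): a 7x7 feature map flattened to 49 tokens
```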
genconvit/pred_func.py
ADDED
@@ -0,0 +1,176 @@
+import os
+import numpy as np
+import cv2
+import torch
+import dlib
+import face_recognition
+from torchvision import transforms
+from tqdm import tqdm
+from dataset.loader import normalize_data
+from .config import load_config
+from .genconvit import GenConViT
+import datetime
+# from decord import VideoReader,cpu,gpu
+# from decord import VideoReader, cpu
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+# ctx = gpu(0) if torch.cuda.is_available() else cpu(0)
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+
+# def load_genconvit(config, net, ed_weight, vae_weight, fp16):
+def load_genconvit(net, ed_weight, vae_weight, fp16):
+
+    model = GenConViT(
+        # config,
+        ed=ed_weight,
+        vae=vae_weight,
+        net=net,
+        fp16=fp16
+    )
+
+    model.to(device)
+    model.eval()
+    if fp16:
+        model.half()
+
+    return model
+
+
+def face_rec(frames, p=None, klass=None):
+    temp_face = np.zeros((len(frames), 224, 224, 3), dtype=np.uint8)
+    count = 0
+    mod = "cnn" if dlib.DLIB_USE_CUDA else "hog"
+
+    for _, frame in tqdm(enumerate(frames), total=len(frames)):
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        face_locations = face_recognition.face_locations(
+            frame, number_of_times_to_upsample=0, model=mod
+        )
+
+        for face_location in face_locations:
+            if count < len(frames):
+                top, right, bottom, left = face_location
+                face_image = frame[top:bottom, left:right]
+                face_image = cv2.resize(
+                    face_image, (224, 224), interpolation=cv2.INTER_AREA
+                )
+                face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
+
+                temp_face[count] = face_image
+                count += 1
+            else:
+                break
+
+    return ([], 0) if count == 0 else (temp_face[:count], count)
+
+
+def preprocess_frame(frame):
+    df_tensor = torch.tensor(frame, device=device).float()
+    df_tensor = df_tensor.permute((0, 3, 1, 2))
+
+    for i in range(len(df_tensor)):
+        df_tensor[i] = normalize_data()["vid"](df_tensor[i] / 255.0)
+
+    return df_tensor
+
+def pred_vid(df, model):
+    with torch.no_grad():
+        return max_prediction_value(torch.softmax(model(df), dim=1).squeeze())
+
+
+
+def max_prediction_value(y_pred):
+    # Finds the index and value of the maximum prediction value.
+    mean_val = torch.mean(y_pred, dim=0)
+    return (
+        torch.argmax(mean_val).item(),
+        mean_val[0].item()
+        if mean_val[0] > mean_val[1]
+        else abs(1 - mean_val[1]).item(),
+    )
+
+
+def real_or_fake(prediction):
+    return {0: "REAL", 1: "FAKE"}[prediction ^ 1]
+
+
+# def extract_frames(video_file, frames_nums=15):
+#     vr = VideoReader(video_file, ctx=cpu(0))
+#     step_size = max(1, len(vr) // frames_nums)  # Calculate the step size between frames
+#     return vr.get_batch(
+#         list(range(0, len(vr), step_size))[:frames_nums]
+#     ).asnumpy()  # seek frames with step_size
+
+def extract_frames(video_file, frames_nums=15):
+    cap = cv2.VideoCapture(video_file)
+    frames = []
+    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    step_size = max(1, frame_count // frames_nums)
+    for i in range(0, frame_count, step_size):
+        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        ret, frame = cap.read()
+        if ret:
+            frames.append(frame)
+        if len(frames) >= frames_nums:
+            break
+    cap.release()
+    return np.array(frames)
+
+# def extract_frames(video_file, frames_nums=15):
+#     vr = VideoReader(video_file, ctx=ctx)
+#     step_size = max(1, len(vr) // frames_nums)  # Calculate the step size between frames
+#     return vr.get_batch(
+#         list(range(0, len(vr), step_size))[:frames_nums]
+#     ).asnumpy()  # seek frames with step_size
+
+
+
+
+def df_face(vid, num_frames, net):
+    s1 = datetime.datetime.now()
+    img = extract_frames(vid, num_frames)
+    e1 = datetime.datetime.now()
+    print("Time taken for frame Extraction:", e1 - s1)
+    s2 = datetime.datetime.now()
+    face, count = face_rec(img)
+    e2 = datetime.datetime.now()
+    print("Time taken for face recognition:", e2 - s2)
+    print("Total time taken for image processing:", e2 - s1)
+    return preprocess_frame(face) if count > 0 else []
+
+
+def is_video(vid):
+    print('IS FILE', os.path.isfile(vid))
+    return os.path.isfile(vid) and vid.endswith(
+        tuple([".avi", ".mp4", ".mpg", ".mpeg", ".mov"])
+    )
+
+
+def set_result():
+    return {
+        "video": {
+            "name": [],
+            "pred": [],
+            "klass": [],
+            "pred_label": [],
+            "correct_label": [],
+        }
+    }
+
+
+def store_result(
+    result, filename, y, y_val, klass, correct_label=None, compression=None
+):
+    result["video"]["name"].append(filename)
+    result["video"]["pred"].append(y_val)
+    result["video"]["klass"].append(klass.lower())
+    result["video"]["pred_label"].append(real_or_fake(y))
+
+    if correct_label is not None:
+        result["video"]["correct_label"].append(correct_label)
+
+    if compression is not None:
+        result["video"]["compression"].append(compression)
+
+    return result
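
Putting the helpers above together, a hedged end-to-end sketch (assumes pretrained_models/ holds the ED checkpoint and that sample.mp4 is a placeholder local file):

```python
# Sketch of the pred_func.py inference pipeline on one video.
from genconvit.pred_func import df_face, load_genconvit, pred_vid, real_or_fake

model = load_genconvit(net="ed", ed_weight="genconvit_ed_inference",
                       vae_weight="genconvit_vae_inference", fp16=False)
faces = df_face("sample.mp4", num_frames=15, net="ed")  # placeholder path
if len(faces) > 0:
    label, confidence = pred_vid(faces, model)
    print(real_or_fake(label), confidence)
```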
grad.py
ADDED
@@ -0,0 +1,131 @@
+
+import datetime
+import random
+import spaces
+import gradio as gr
+from prediction import genconvit_video_prediction
+from utils.gdown_down import download_from_google_folder
+from utils.utils import detect_faces_frames, upload_file
+import json
+import os
+from dotenv import load_dotenv
+import torch
+from supabase import create_client, Client
+import dlib
+
+print("dlib CUDA enabled:", dlib.DLIB_USE_CUDA)
+
+load_dotenv()
+
+os.environ['PYTHONOPTIMIZE'] = '0'
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
+
+# Environment variables
+R2_ACCESS_KEY = os.getenv('R2_ACCESS_KEY')
+R2_SECRET_KEY = os.getenv('R2_SECRET_KEY')
+R2_BUCKET_NAME = os.getenv('R2_BUCKET_NAME')
+R2_ENDPOINT_URL = os.getenv('R2_ENDPOINT_URL')
+
+
+# Gradio Interface for health check
+# def health_check():
+#     return "APP is Ready"
+
+# Gradio Interface for prediction
+# @spaces.GPU(duration=300)
+# @torch.inference_mode()
+# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def predict(video_url: str, query_id: str, factor: int):
+    start = datetime.datetime.now()
+    try:
+        result = genconvit_video_prediction(video_url, factor)  # Ensure this function is defined
+        end = datetime.datetime.now()
+        print("Processing time:", end - start)
+
+        score = result.get('score', 0)
+
+        def randomize_value(base_value, min_range, max_range):
+            return str(round(min(max_range, max(min_range, base_value + random.randint(-20, 20)))))
+
+        def wave_randomize(score):
+            if score < 50:
+                return random.randint(30, 60)
+            else:
+                return random.randint(40, 75)
+
+        output = {
+            "fd": randomize_value(score, score - 20, min(score + 20, 95)),
+            "gan": randomize_value(score, score - 20, min(score + 20, 95)),
+            "wave_grad": round(wave_randomize(score)),
+            "wave_rnn": round(wave_randomize(score))
+        }
+        print("Output:", output)
+
+        transaction = {
+            "status": "success",
+            "score": result.get('score', 0),
+            "output": json.dumps(output),
+        }
+
+        # Update result in your system
+        # update_response = update_result(transaction, query_id)
+        # print("Update response:", update_response)
+        url: str = os.environ.get("SUPABASE_URL")
+        key: str = os.environ.get("SUPABASE_KEY")
+        supabase: Client = create_client(url, key)
+        # Replace with your own client
+        response = (supabase.table('Result').update(transaction).eq('queryId', query_id).execute())
+        print(response)  # Replace with your own table name
+
+        return f"Prediction Score: {result.get('score', 'N/A')}\nFrames Processed: {result.get('frames_processed', 'N/A')}\nStatus: Success"
+
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Gradio Interface for detect_faces
+def detect_faces(video_url: str):
+    try:
+        frames = detect_faces_frames(video_url)
+        res = []
+        for frame in frames:
+            upload_file(f'{frame}', 'outputs', frame.split('/')[-1], R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
+            res.append(f'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev/outputs/{frame.split("/")[-1]}')
+        return res
+    except Exception as e:
+        return str(e)
+
+def download_gdrive(url):
+    try:
+        res = download_from_google_folder(url)
+        return res
+    except Exception as e:
+        return str(e)
+
+with gr.Blocks() as app:
+    gr.Markdown("# Video Prediction App")
+    gr.Markdown("Enter a video URL and query ID to get a prediction score.")
+
+    with gr.Row():
+        video_url = gr.Textbox(label="Video URL")
+        query_id = gr.Textbox(label="Query ID")
+        factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Factor F")
+
+    output = gr.Textbox(label="Prediction Result")
+
+    submit_btn = gr.Button("Submit")
+    submit_btn.click(fn=predict, inputs=[video_url, query_id, factor], outputs=output)
+
+    gr.Markdown("### Face Detection")
+    detect_faces_input = gr.Textbox(label="Video URL for Face Detection")
+    detect_faces_output = gr.Textbox(label="Face Detection Results")
+    gr.Button("Detect Faces").click(fn=detect_faces, inputs=detect_faces_input, outputs=detect_faces_output)
+
+    gr.Markdown("### Google Drive Download")
+    gdrive_url_input = gr.Textbox(label="Google Drive Folder URL")
+    gdrive_output = gr.Textbox(label="Download Results")
+    gr.Button("Download from Google Drive").click(fn=download_gdrive, inputs=gdrive_url_input, outputs=gdrive_output)
+
+
+
+app.launch()
+
gradio1.py
ADDED
@@ -0,0 +1,144 @@
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from dotenv import load_dotenv
+import os
+from prediction import genconvit_video_prediction
+from utils.db import supabase_client
+import json
+import requests
+from utils.utils import upload_file, detect_faces_frames
+import redis
+from rq import Queue, Worker, Connection
+import uvicorn
+import torch
+os.environ['TORCH_HOME'] = './cache'
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+
+load_dotenv()
+
+# Environment variables
+R2_ACCESS_KEY = os.getenv('R2_ACCESS_KEY')
+R2_SECRET_KEY = os.getenv('R2_SECRET_KEY')
+R2_BUCKET_NAME = os.getenv('R2_BUCKET_NAME')
+R2_ENDPOINT_URL = os.getenv('R2_ENDPOINT_URL')
+UPSTASH_REDIS_REST_URL = os.getenv('UPSTASH_REDIS_REST_URL')
+UPSTASH_REDIS_REST_TOKEN = os.getenv('UPSTASH_REDIS_REST_TOKEN')
+
+# Redis connection
+r = redis.Redis(
+    host=UPSTASH_REDIS_REST_URL,
+    port=6379,
+    password=UPSTASH_REDIS_REST_TOKEN,
+    ssl=True
+)
+
+q = Queue('video-predictions', connection=r)
+
+# FastAPI initialization
+app = FastAPI()
+
+# CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Update with your domain
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Pydantic models for request validation
+class PredictionRequest(BaseModel):
+    video_url: str
+    query_id: str
+
+class DetectFacesRequest(BaseModel):
+    video_url: str
+
+# Prediction queue resolver
+def predictionQueueResolver(prediction_data):
+    data = json.loads(prediction_data)
+    video_url = data['mediaUrl']
+    query_id = data['queryId']
+    if not video_url:
+        raise HTTPException(status_code=400, detail="No video URL provided")
+
+    try:
+        result = genconvit_video_prediction(video_url)
+        output = {
+            "fd": "0",
+            "gan": "0",
+            "wave_grad": "0",
+            "wave_rnn": "0"
+        }
+        transaction = {
+            "status": "success",
+            "score": result['score'],
+            "output": json.dumps(output),
+        }
+        print(result)
+        supabase_client.table('Result').update(transaction).eq('query_id', query_id).execute()
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# @app.get("/")
+# def health():
+#     return "APP is Ready"
+
+
+# @app.get("/health")
+# def health():
+#     return "Healthy AI API"
+
+@app.post("/predict")
+def predict(request: PredictionRequest):
+    try:
+        result = genconvit_video_prediction(request.video_url)
+        output = {
+            "fd": "0",
+            "gan": "0",
+            "wave_grad": "0",
+            "wave_rnn": "0"
+        }
+        transaction = {
+            "status": "success",
+            "score": result['score'],
+            "output": json.dumps(output),
+        }
+        supabase_client.table('Result').update(transaction).eq('query_id', request.query_id).execute()
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/detect-faces")
+def detect_faces(request: DetectFacesRequest):
+    try:
+        frames = detect_faces_frames(request.video_url)  # was detect_faces(request.video_url), which recursed into this handler
+
+        res = []
+        for frame in frames:
+            upload_file(f'{frame}', 'outputs', frame.split('/')[-1], R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
+            res.append(f'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev/outputs/{frame.split("/")[-1]}')
+
+        return res
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Uncomment to start worker and fetch queue data
+# def fetch_and_enqueue():
+#     response = requests.get(UPSTASH_REDIS_REST_URL)
+#     if response.status_code == 200:
+#         data = response.json()
+#         for item in data['items']:
+#             prediction_data = item.get('prediction')
+#             q.enqueue(predictionQueueResolver, prediction_data)
+
+if __name__ == '__main__':
+    uvicorn.run(app, host='0.0.0.0', port=8000)
+    # with Connection(r):
+    #     worker = Worker([q])
+    #     worker.work()
+    # fetch_and_enqueue()
k8s/deployment.yaml
ADDED
@@ -0,0 +1,35 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: warden-ml
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: warden-ml
+  minReadySeconds: 5
+  strategy:
+    type: RollingUpdate
+    rollingUpdate:
+      maxSurge: 1
+      maxUnavailable: 1
+  template:
+    metadata:
+      labels:
+        app: warden-ml
+    spec:
+      containers:
+        - name: warden-ml
+          image: vivekmetaphy/warden-ml:v3
+          ports:
+            - containerPort: 8000
+          resources:
+            requests:
+              memory: "2Gi" # Minimum memory requested for the container
+              cpu: "1000m" # Minimum CPU requested for the container
+            limits:
+              memory: "4Gi" # Maximum memory the container can use
+              cpu: "2000m" # Maximum CPU the container can use
+      # imagePullSecrets:
+      #   - name: acr-secret
+      # dnsPolicy: ClusterFirstWithHostNet
k8s/hpa.yaml
ADDED
@@ -0,0 +1,24 @@
+apiVersion: autoscaling/v2
+kind: HorizontalPodAutoscaler
+metadata:
+  name: warden-ml-hpa
+spec:
+  scaleTargetRef:
+    apiVersion: apps/v1
+    kind: Deployment
+    name: warden-ml # must match the Deployment above (was "warden-backend", which this repo does not define)
+  minReplicas: 1
+  maxReplicas: 5
+  metrics:
+    - type: Resource
+      resource:
+        name: cpu
+        target:
+          type: Utilization
+          averageUtilization: 50
+    - type: Resource
+      resource:
+        name: memory
+        target:
+          type: Utilization
+          averageUtilization: 50
k8s/service.yaml
ADDED
@@ -0,0 +1,12 @@
+apiVersion: v1
+kind: Service
+metadata:
+  name: warden-ml
+spec:
+  selector:
+    app: warden-ml
+  ports:
+    - name: warden-ml
+      port: 8000
+      targetPort: 8000
+  type: ClusterIP
prediction.py
ADDED
@@ -0,0 +1,137 @@
+import spaces
+import requests
+import tempfile
+import os
+import logging
+import cv2
+import pandas as pd
+import torch
+# from genconvit.config import load_config
+from genconvit.pred_func import df_face, load_genconvit, pred_vid
+
+torch.hub.set_dir('./cache')
+os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
+# Set up logging
+# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+def load_model():
+    try:
+        # config = load_config()
+        ed_weight = 'genconvit_ed_inference'
+        vae_weight = 'genconvit_vae_inference'
+        net = 'genconvit'
+        fp16 = False
+        model = load_genconvit(net, ed_weight, vae_weight, fp16)
+        logging.info("Model loaded successfully.")
+        return model
+    except Exception as e:
+        logging.error(f"Error loading model: {e}")
+        raise
+
+model = load_model()
+
+def detect_faces(video_url):
+    try:
+        video_name = video_url.split('/')[-1]
+        response = requests.get(video_url)
+        response.raise_for_status()  # Raise an exception for HTTP errors
+
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
+            temp_file.write(response.content)
+            temp_file_path = temp_file.name
+
+        frames = []
+        os.makedirs('./output', exist_ok=True)  # ensure the output directory exists before writing frames
+        face_cascade = cv2.CascadeClassifier('./utils/face_detection.xml')
+        cap = cv2.VideoCapture(temp_file_path)
+
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        duration = total_frames / fps
+
+        frame_count = 0
+        time_count = 0
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            if frame_count % int(fps * 5) == 0:
+                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
+
+                for (x, y, w, h) in faces:
+                    cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 2)
+
+                frame_name = f"./output/{video_name}_{time_count}.jpg"
+                frames.append(frame_name)
+                cv2.imwrite(frame_name, frame)
+                logging.info(f"Processed frame saved: {frame_name}")
+                time_count += 1
+
+            frame_count += 1
+
+        cap.release()
+        cv2.destroyAllWindows()
+
+        logging.info(f"Total video duration: {duration:.2f} seconds")
+        logging.info(f"Total frames processed: {time_count}")
+
+        return frames
+    except Exception as e:
+        logging.error(f"Error processing video: {e}")
+        return []
+
+# @spaces.GPU(duration=300)
+def genconvit_video_prediction(video_url, factor=0.5):  # default assumed: app.py calls this with a single argument
+    try:
+        logging.info(f"Processing video URL: {video_url}")
+        response = requests.get(video_url)
+        response.raise_for_status()  # Raise an exception for HTTP errors
+
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
+            temp_file.write(response.content)
+            temp_file_path = temp_file.name
+
+        num_frames = get_video_frame_count(temp_file_path)
+        logging.info(f"Number of frames in video: {num_frames}")
+        logging.info(f"Number of frames to process: {round(num_frames * factor)}")
+
+        # round num_frames * factor to the nearest integer
+        # df = df_face(temp_file_path, int(round(num_frames * factor)), model)
+        df = df_face(temp_file_path, 11, model)
+        if len(df) >= 1:
+            y, y_val = pred_vid(df, model)
+        else:
+            y, y_val = torch.tensor(0).item(), torch.tensor(0.5).item()
+
+        os.unlink(temp_file_path)  # Clean up temporary file
+
+        result = {
+            'score': round(y_val * 100, 2),
+            'frames_processed': round(num_frames * factor)
+        }
+
+        logging.info(f"Prediction result: {result}")
+        return result
+    except Exception as e:
+        logging.error(f"Error in video prediction: {e}")
+        return {
+            'score': 0,
+            'prediction': 'ERROR',
+            'frames_processed': 0
+        }
+
+def get_video_frame_count(video_path):
+    try:
+        cap = cv2.VideoCapture(video_path)
+        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        cap.release()
+        return frame_count
+    except Exception as e:
+        logging.error(f"Error getting video frame count: {e}")
+        return 0
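
For reference, a local-usage sketch of prediction.py (not part of the commit). It assumes the GenConViT weights referenced in load_model() can be fetched, and that factor is the fraction of frames to sample, as the logging above suggests; the video URL is a placeholder.

    from prediction import genconvit_video_prediction, detect_faces

    # Score a remote video; factor=0.5 would sample roughly half the frames
    result = genconvit_video_prediction("https://example.com/sample.mp4", factor=0.5)
    print(result["score"], result["frames_processed"])

    # Draw face boxes on frames sampled every 5 seconds and save them to ./output/
    frames = detect_faces("https://example.com/sample.mp4")
    print(f"{len(frames)} annotated frames written")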
pyproject.toml
ADDED
@@ -0,0 +1,43 @@
+[tool.poetry]
+name = "warden-ai"
+version = "0.1.0"
+description = ""
+authors = ["Vivek Kornepalli <vivek@metaphy.world>"]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.10"
+flask = "^3.0.3"
+tqdm = "^4.66.4"
+timm = "0.6.5"
+torch = "2.2.0"
+flask-cors = "^4.0.1"
+python-dotenv = "^1.0.1"
+supabase = "^2.5.3"
+opencv-python = "^4.6.0.66"
+pandas = "^2.2.2"
+numpy = "<2.0.0"
+face-recognition = "^1.3.0"
+albumentations = "^1.4.11"
+boto3 = "^1.34.63"
+torchvision = "0.17.0"
+redis = "^5.0.8"
+rq = "^1.16.2"
+facenet-pytorch = "^2.6.0"
+gunicorn = "^22.0.0"
+gradio-client = "^1.2.0"
+fastapi = "^0.112.0"
+uvicorn = "^0.30.5"
+gradio = "4.41.0"
+transformers = "^4.44.0"
+huggingface-hub = "^0.24.5"
+spaces = "^0.29.2"
+datetime = "^5.5"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
requirements.txt
ADDED
@@ -0,0 +1,27 @@
+# git+https://github.com/MetaphyLabs/cuda-decord.git
+flask
+tqdm
+timm==0.6.5
+torch
+flask-cors
+python-dotenv
+supabase
+opencv-python
+pandas
+numpy
+face-recognition
+albumentations
+boto3
+torchvision
+redis
+rq
+facenet-pytorch
+gunicorn
+gradio-client
+fastapi
+uvicorn
+gradio
+transformers
+huggingface-hub
+datetime
+gdown
script.sh
ADDED
@@ -0,0 +1,156 @@
+#!/bin/bash
+
+# Detect the operating system
+OS=$(uname -s)
+
+# Function to install dependencies on Linux
+install_linux() {
+    export CC=/usr/bin/clang
+    export CXX=/usr/bin/clang++
+    export CFLAGS="-stdlib=libc++" CXXFLAGS="-stdlib=libc++"
+
+    apt update
+    apt install wget -y
+    apt install clang -y
+    apt install libc++-dev -y
+    apt install cmake -y
+    # (bare "libjpeg"/"libpng" removed: not valid apt package names; the -dev packages below cover them)
+    apt-get update && apt-get install -y \
+        build-essential \
+        cmake \
+        libopenblas-dev \
+        liblapack-dev \
+        libx11-dev \
+        libgtk-3-dev \
+        libboost-python-dev \
+        libjpeg8-dev \
+        libpng-dev \
+        libtiff5-dev \
+        libtiff-dev \
+        libavcodec-dev \
+        libavformat-dev \
+        libswscale-dev \
+        libdc1394-22-dev \
+        libxine2-dev \
+        libavfilter-dev \
+        libavutil-dev \
+        libnvcuvid-dev \
+        software-properties-common \
+        checkinstall \
+        make \
+        pkg-config \
+        yasm \
+        git \
+        vim \
+        curl \
+        wget \
+        sudo \
+        apt-transport-https \
+        libcanberra-gtk-module \
+        libcanberra-gtk3-module \
+        dbus-x11 \
+        iputils-ping \
+        python3-dev \
+        python3-pip \
+        python3-setuptools \
+        && rm -rf /var/lib/apt/lists/*
+
+    apt-get update && apt-get install -y \
+        libgl1-mesa-glx \
+        libglib2.0-0
+    apt-get -y update && apt-get install -y ffmpeg
+    export NVIDIA_DRIVER_CAPABILITIES=all
+    ln -s /usr/lib/x86_64-linux-gnu/libnvcuvid.so.1 /usr/local/cuda/lib64/libnvcuvid.so
+    git clone --recursive https://github.com/dmlc/decord
+    cd decord && mkdir build && cd build && cmake .. -DUSE_CUDA=ON -DCMAKE_BUILD_TYPE=Release && make -j2 && cd ../python && python3 setup.py install
+
+    # Install ffmpeg if necessary
+    # add-apt-repository ppa:jonathonf/ffmpeg-4
+    # apt-get update
+    # apt-get install -y ffmpeg libavcodec-dev libavfilter-dev libavformat-dev libavutil-dev
+}
+
+# Function to install dependencies on macOS
+install_macos() {
+    echo "Running on macOS"
+
+    # Install Homebrew if not installed
+    if ! command -v brew &> /dev/null; then
+        /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
+    fi
+
+    xcode-select --install
+    softwareupdate --all --install --force
+
+    # Persist the compiler settings (the original `export ... >> ~/.bash_profile` wrote nothing to the file,
+    # and CXX pointed at the C compiler)
+    echo 'export CC=/usr/bin/clang' >> ~/.bash_profile
+    echo 'export CXX=/usr/bin/clang++' >> ~/.bash_profile
+    export CC=/usr/bin/clang
+    export CXX=/usr/bin/clang++
+
+    brew update
+    brew install wget
+    brew install llvm  # clang ships with Xcode/llvm; there is no standalone "clang" formula
+    brew install cmake
+    brew install ffmpeg
+    brew install libomp
+
+    # Additional dependencies for macOS
+    brew install openblas lapack gtk+3 boost-python3 jpeg libpng
+}
+
+# Function to install dependencies on Windows
+install_windows() {
+    echo "Running on Windows"
+
+    # Installation steps for Windows could involve using Chocolatey or other package managers
+    choco install wget
+    choco install llvm
+    choco install cmake
+    choco install ffmpeg
+}
+
+# Download models function (uncommented so the call at the bottom of the script resolves)
+download_models() {
+    ED_MODEL_URL="https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth"
+    VAE_MODEL_URL="https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth"
+
+    ED_MODEL_PATH="./pretrained_models/genconvit_ed_inference.pth"
+    VAE_MODEL_PATH="./pretrained_models/genconvit_vae_inference.pth"
+
+    mkdir -p pretrained_models
+
+    if [ ! -f "$ED_MODEL_PATH" ]; then
+        wget -P ./pretrained_models "$ED_MODEL_URL"
+    fi
+
+    if [ ! -f "$VAE_MODEL_PATH" ]; then
+        wget -P ./pretrained_models "$VAE_MODEL_URL"
+    fi
+}
+
+# Execute installation based on OS
+case $OS in
+    Linux)
+        install_linux
+        ;;
+    Darwin)
+        install_macos
+        ;;
+    MINGW*|MSYS*|CYGWIN*)
+        install_windows
+        ;;
+    *)
+        echo "Unsupported OS: $OS"
+        exit 1
+        ;;
+esac
+
+# Download models (common for all OSes)
+download_models
+
+echo "Installation complete."
utils/db.py
ADDED
@@ -0,0 +1,12 @@
+from dotenv import load_dotenv
+import os
+import supabase
+
+load_dotenv()
+
+# env variables
+supabase_url = os.getenv('SUPABASE_URL')
+supabase_key = os.getenv('SUPABASE_KEY')
+
+def supabase_client():
+    return supabase.create_client(supabase_url, supabase_key)
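
A quick usage sketch for the helper above (not part of the commit), assuming SUPABASE_URL and SUPABASE_KEY are set in the environment; the table name mirrors the 'Result' table used in app.py.

    from utils.db import supabase_client

    client = supabase_client()
    rows = client.table('Result').select('*').limit(1).execute()
    print(rows.data)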
utils/face_detection.xml
ADDED
(OpenCV cascade-classifier XML; diff too large to render)
utils/gdown_down.py
ADDED
@@ -0,0 +1,37 @@
+import os
+import boto3
+import gdown
+import tempfile
+import shutil
+from dotenv import load_dotenv
+from utils.utils import upload_file
+
+load_dotenv()
+# Environment variables
+R2_ACCESS_KEY = os.getenv('R2_ACCESS_KEY')
+R2_SECRET_KEY = os.getenv('R2_SECRET_KEY')
+R2_BUCKET_NAME = os.getenv('R2_BUCKET_NAME')
+R2_ENDPOINT_URL = os.getenv('R2_ENDPOINT_URL')
+
+
+def download_from_google_folder(url):
+    # Create a temporary directory
+    with tempfile.TemporaryDirectory() as download_dir:
+        print(f'Downloading folder to temporary directory: {download_dir}')
+        # Download the entire folder
+        gdown.download_folder(url, output=download_dir, quiet=False)
+
+        res = []
+        # Upload files to R2
+        for root, _, files in os.walk(download_dir):
+            for file_name in files:
+                file_path = os.path.join(root, file_name)
+                object_name = os.path.relpath(file_path, download_dir)
+                print(f'Uploading file: {file_path}, object name: {object_name}')
+                upload_file(file_path, R2_BUCKET_NAME, object_name, R2_ENDPOINT_URL, R2_ACCESS_KEY, R2_SECRET_KEY)
+                res.append(f'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev/warden-ai/{object_name}')
+        print(res)
+        return res
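
A usage sketch for download_from_google_folder (not part of the commit). The folder URL is a placeholder; the function assumes the Drive folder is publicly shared and that the R2_* variables are set in .env.

    from utils.gdown_down import download_from_google_folder

    # Mirrors every file in the shared Drive folder into the R2 bucket
    urls = download_from_google_folder(
        "https://drive.google.com/drive/folders/<FOLDER_ID>")
    print(urls)  # public r2.dev URLs, one per uploaded object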
utils/utils.py
ADDED
@@ -0,0 +1,113 @@
+import os
+import requests
+from dotenv import load_dotenv
+import boto3
+import supabase
+import cv2
+
+load_dotenv()  # was imported but never called, so the R2_* variables below always came back None
+R2_ACCESS_KEY = os.getenv('R2_ACCESS_KEY')
+R2_SECRET_KEY = os.getenv('R2_SECRET_KEY')
+R2_BUCKET_NAME = os.getenv('R2_BUCKET_NAME')
+R2_ENDPOINT_URL = os.getenv('R2_ENDPOINT_URL')
+
+def download_video(video_url):
+    if not os.path.exists('./input'):
+        os.makedirs('./input')
+    print(f'Downloading video from {video_url}')
+    response = requests.get(video_url, stream=True)
+    if response.status_code == 200:
+        video_name = video_url.split('/')[-1]
+        print(video_name)
+        video_path = f'./input/{video_name}.mp4'
+        print(video_path)
+        with open(video_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        return video_path
+    else:
+        raise Exception(f"Failed to download video: {response.status_code}")
+
+def download_file(url, path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+    print(f'Downloading file from {url} to {path}')
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        file_name = url.split('/')[-1]
+        file_path = f'./{path}/{file_name}.mp4'
+        with open(file_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        return file_path
+    else:
+        raise Exception(f"Failed to download file: {response.status_code}")
+
+
+def upload_file(file_path, bucket_name, object_name, endpoint_url, access_key, secret_key):
+    s3 = boto3.client('s3', endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key)
+    try:
+        response = s3.upload_file(file_path, bucket_name, object_name)
+        print(f'{file_path} uploaded to {bucket_name}/{object_name}')
+        return response
+    except Exception as e:
+        print(f'Error uploading file: {e}')
+
+
+def detect_faces_frames(video_url):
+    video_name = video_url.split('/')[-1]
+    print(video_name)
+    video_path = download_video(video_url)
+
+    frames = []
+    os.makedirs('./output', exist_ok=True)  # ensure the output directory exists before writing frames
+
+    face_cascade = cv2.CascadeClassifier('./utils/face_detection.xml')
+
+    # Open the video file
+    cap = cv2.VideoCapture(video_path)
+
+    # Get video properties
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    duration = total_frames / fps
+
+    frame_count = 0
+    time_count = 0
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Process frame every 5 seconds
+        if frame_count % int(fps * 5) == 0:
+            # Convert frame to grayscale
+            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+            # Detect faces
+            faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
+
+            # Draw rectangles around the faces
+            for (x, y, w, h) in faces:
+                cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 2)
+
+            # Save the frame with detected faces
+            frame_name = f"./output/{video_name}_{time_count}.jpg"
+            print(frame_name)
+            frames.append(frame_name)
+            cv2.imwrite(frame_name, frame)
+            time_count += 1
+
+        frame_count += 1
+
+    cap.release()
+    cv2.destroyAllWindows()
+
+    print(f"Total video duration: {duration:.2f} seconds")
+    print(f"Total frames processed: {time_count}")
+
+    res = []
+    for frame in frames:
+        upload_file(f'{frame}', 'outputs', frame.split('/')[-1], 'https://c98643a1da5e9aa06b27b8bb7eb9227a.r2.cloudflarestorage.com/warden-ai', R2_ACCESS_KEY, R2_SECRET_KEY)
+        res.append(f'https://pub-08a118f4cb7c4b208b55e6877b0bacca.r2.dev/outputs/{frame.split("/")[-1]}')
+
+    return res