Qwen2.5-VL Source Deployment #
A detailed guide to deploying the Qwen2.5-VL-7B-Instruct model with vLLM
Local deployment tutorial for Qwen2.5-VL, Alibaba's latest open-source model: visual understanding that beats GPT-4o!
cuDNN Acceleration Library #
cudnn-windows-x86_64-8.9.7.29_cuda12-archive
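To check which cuDNN build PyTorch actually sees, the standard torch.backends.cudnn API works. Note that pip-installed torch wheels usually bundle their own cuDNN, so the reported value may differ from the system archive above; this is a minimal sanity-check sketch:
import torch

print("cuDNN available:", torch.backends.cudnn.is_available())
print("cuDNN version:", torch.backends.cudnn.version())  # e.g. 8907 would correspond to the 8.9.7 archive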
Repository #
https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file
Model #
https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct-AWQ
Environment Comparison #
(base) PS C:\Users\admin\python\qwen\Qwen2.5-VL> pip list
Package Version
--------------------------------- ------------------
aiobotocore 2.12.3
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aioitertools 0.7.1
aiosignal 1.2.0
alabaster 0.7.16
altair 5.0.1
anaconda-anon-usage 0.4.4
anaconda-catalogs 0.2.0
anaconda-client 1.12.3
anaconda-cloud-auth 0.5.1
anaconda-navigator 2.6.3
anaconda-project 0.11.1
annotated-types 0.6.0
anyio 4.2.0
appdirs 1.4.4
archspec 0.2.3
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
arrow 1.2.3
astroid 2.14.2
astropy 6.1.3
astropy-iers-data 0.2024.9.2.0.33.23
asttokens 2.0.5
async-lru 2.0.4
atomicwrites 1.4.0
attrs 23.1.0
Automat 20.2.0
autopep8 2.0.4
Babel 2.11.0
bcrypt 3.2.0
beautifulsoup4 4.12.3
binaryornot 0.4.4
black 24.8.0
bleach 4.1.0
blinker 1.6.2
bokeh 3.6.0
boltons 23.0.0
botocore 1.34.69
Bottleneck 1.3.7
Brotli 1.0.9
cachetools 5.3.3
certifi 2024.8.30
cffi 1.17.1
chardet 4.0.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorama 0.4.6
colorcet 3.1.0
comm 0.2.1
conda 24.9.2
conda-build 24.9.0
conda-content-trust 0.2.0
conda_index 0.5.0
conda-libmamba-solver 24.9.0
conda-pack 0.7.1
conda-package-handling 2.3.0
conda_package_streaming 0.10.0
conda-repo-cli 1.0.114
conda-token 0.5.0+1.g2209e04
constantly 23.10.4
contourpy 1.2.0
cookiecutter 2.6.0
cryptography 43.0.0
cssselect 1.2.0
cycler 0.11.0
cytoolz 0.12.2
dask 2024.8.2
dask-expr 1.1.13
datashader 0.16.3
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
diff-match-patch 20200713
dill 0.3.8
distributed 2024.8.2
distro 1.9.0
docstring-to-markdown 0.11
docutils 0.18.1
et-xmlfile 1.1.0
executing 0.8.3
fastapi 0.115.5
fastjsonschema 2.16.2
filelock 3.13.1
flake8 7.0.0
Flask 3.0.3
fonttools 4.51.0
frozendict 2.4.2
frozenlist 1.4.0
fsspec 2024.6.1
gensim 4.3.3
gitdb 4.0.7
GitPython 3.1.43
greenlet 3.0.1
h11 0.14.0
h5py 3.11.0
HeapDict 1.0.1
holoviews 1.19.1
httpcore 1.0.2
httpx 0.27.0
huggingface-hub 0.30.1
hvplot 0.11.0
hyperlink 21.0.0
idna 3.7
imagecodecs 2023.1.23
imageio 2.33.1
imagesize 1.4.1
imbalanced-learn 0.12.3
importlib-metadata 7.0.1
incremental 22.10.0
inflection 0.5.1
iniconfig 1.1.1
intake 2.0.7
intervaltree 3.1.0
ipykernel 6.28.0
ipython 8.27.0
ipython-genutils 0.2.0
ipywidgets 7.8.1
isort 5.13.2
itemadapter 0.3.0
itemloaders 1.1.0
itsdangerous 2.2.0
jaraco.classes 3.2.1
jedi 0.19.1
jellyfish 1.0.1
Jinja2 3.1.4
jmespath 1.0.1
joblib 1.4.2
json5 0.9.6
jsonpatch 1.33
jsonpointer 2.1
jsonschema 4.23.0
jsonschema-specifications 2023.7.1
jupyter 1.0.0
jupyter_client 8.6.0
jupyter-console 6.6.3
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.0
jupyter_server 2.14.1
jupyter_server_terminals 0.4.4
jupyterlab 4.2.5
jupyterlab-pygments 0.1.2
jupyterlab_server 2.27.3
jupyterlab-widgets 1.0.0
keyring 24.3.1
kiwisolver 1.4.4
lazy_loader 0.4
lazy-object-proxy 1.10.0
lckr_jupyterlab_variableinspector 3.1.0
libarchive-c 5.1
libmambapy 1.5.8
linkify-it-py 2.0.0
llvmlite 0.43.0
lmdb 1.4.1
locket 1.0.0
lxml 5.2.1
lz4 4.3.2
Markdown 3.4.1
markdown-it-py 2.2.0
MarkupSafe 2.1.3
matplotlib 3.9.2
matplotlib-inline 0.1.6
mccabe 0.7.0
mdit-py-plugins 0.3.0
mdurl 0.1.0
menuinst 2.1.2
mistune 2.0.4
mkl_fft 1.3.10
mkl_random 1.2.7
mkl-service 2.4.0
more-itertools 10.3.0
mpmath 1.3.0
msgpack 1.0.3
multidict 6.0.4
multipledispatch 0.6.0
mypy 1.11.2
mypy-extensions 1.0.0
navigator-updater 0.5.1
nbclient 0.8.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.3
nltk 3.9.1
notebook 7.2.2
notebook_shim 0.2.3
numba 0.60.0
numexpr 2.8.7
numpy 1.26.4
numpydoc 1.7.0
openpyxl 3.1.5
overrides 7.4.0
packaging 24.1
pandas 2.2.2
pandocfilters 1.5.0
panel 1.5.2
param 2.1.1
paramiko 2.8.1
parsel 1.8.1
parso 0.8.3
partd 1.4.1
pathspec 0.10.3
patsy 0.5.6
pexpect 4.8.0
pickleshare 0.7.5
pillow 10.4.0
pip 24.2
pkce 1.0.3
pkginfo 1.10.0
platformdirs 3.10.0
plotly 5.24.1
pluggy 1.0.0
ply 3.11
prometheus-client 0.14.1
prompt-toolkit 3.0.43
Protego 0.1.16
protobuf 4.25.3
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 16.1.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.11.1
pycosat 0.6.6
pycparser 2.21
pyct 0.5.0
pycurl 7.45.3
pydantic 2.8.2
pydantic_core 2.20.1
pydeck 0.8.0
PyDispatcher 2.0.5
pydocstyle 6.3.0
pyerfa 2.0.1.4
pyflakes 3.2.0
Pygments 2.15.1
PyJWT 2.8.0
pylint 2.16.2
pylint-venv 3.0.3
pyls-spyder 0.4.0
PyNaCl 1.5.0
pyodbc 5.1.0
pyOpenSSL 24.2.1
pyparsing 3.1.2
PyQt5 5.15.10
PyQt5-sip 12.13.0
PyQtWebEngine 5.15.6
PySocks 1.7.1
pytest 7.4.4
python-dateutil 2.9.0.post0
python-dotenv 0.21.0
python-json-logger 2.0.7
python-lsp-black 2.0.0
python-lsp-jsonrpc 1.1.2
python-lsp-server 1.10.0
python-slugify 5.0.2
pytoolconfig 1.2.6
pytz 2024.1
pyviz_comms 3.0.2
PyWavelets 1.7.0
pywin32 305.1
pywin32-ctypes 0.2.2
pywinpty 2.0.10
PyYAML 6.0.1
pyzmq 25.1.2
QDarkStyle 3.2.3
qstylizer 0.2.2
QtAwesome 1.3.1
qtconsole 5.5.1
QtPy 2.4.1
queuelib 1.6.2
referencing 0.30.2
regex 2024.9.11
requests 2.32.3
requests-file 1.5.1
requests-toolbelt 1.0.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rope 1.12.0
rpds-py 0.10.6
Rtree 1.0.1
ruamel.yaml 0.18.6
ruamel.yaml.clib 0.2.8
ruamel-yaml-conda 0.17.21
s3fs 2024.6.1
safetensors 0.5.3
scikit-image 0.24.0
scikit-learn 1.5.1
scipy 1.13.1
Scrapy 2.11.1
seaborn 0.13.2
semver 3.0.2
Send2Trash 1.8.2
service-identity 18.1.0
setuptools 75.1.0
sip 6.7.12
six 1.16.0
smart-open 5.2.1
smmap 4.0.0
sniffio 1.3.0
snowballstemmer 2.2.0
sortedcontainers 2.4.0
soupsieve 2.5
Sphinx 7.3.7
sphinxcontrib-applehelp 1.0.2
sphinxcontrib-devhelp 1.0.2
sphinxcontrib-htmlhelp 2.0.0
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.3
sphinxcontrib-serializinghtml 1.1.10
spyder 5.5.1
spyder-kernels 2.5.0
SQLAlchemy 2.0.34
stack-data 0.2.0
starlette 0.41.3
statsmodels 0.14.2
streamlit 1.37.1
sympy 1.13.1
tables 3.10.1
tabulate 0.9.0
tblib 1.7.0
tenacity 8.2.3
terminado 0.17.1
text-unidecode 1.3
textdistance 4.2.1
threadpoolctl 3.5.0
three-merge 0.1.1
tifffile 2023.4.12
timm 1.0.15
tinycss2 1.2.1
tldextract 5.1.2
toml 0.10.2
tomli 2.0.1
tomlkit 0.11.1
toolz 0.12.0
torch 2.6.0+cu118
torchaudio 2.6.0+cu118
torchvision 0.21.0+cu118
tornado 6.4.1
tqdm 4.66.5
traitlets 5.14.3
truststore 0.8.0
Twisted 23.10.0
twisted-iocpsupport 1.0.2
typing_extensions 4.11.0
tzdata 2023.3
uc-micro-py 1.0.1
ujson 5.10.0
unicodedata2 15.1.0
Unidecode 1.3.8
urllib3 2.2.3
uvicorn 0.32.0
w3lib 2.1.2
watchdog 4.0.1
wcwidth 0.2.5
webencodings 0.5.1
websocket-client 1.8.0
Werkzeug 3.0.3
whatthepatch 1.0.2
wheel 0.44.0
widgetsnbextension 3.6.6
win-inet-pton 1.1.0
wrapt 1.14.1
xarray 2023.6.0
xlwings 0.32.1
xyzservices 2022.9.0
yapf 0.40.2
yarl 1.11.0
zict 3.0.0
zipp 3.17.0
zope.interface 5.4.0
zstandard 0.23.0
Video understanding cookbook: https://github.com/QwenLM/Qwen2.5-VL/blob/main/cookbooks/video_understanding.ipynb
Hands-On #
Base Environment #
The CUDA environment needs to be v12.4.
Install transformers and qwen-vl-utils.
If you want to use the AWQ quantized model, pin the transformers version:
pip install git+https://github.com/huggingface/transformers@v4.49.0
Otherwise, install the latest transformers from source, plus the helper packages:
pip install git+https://github.com/huggingface/transformers
pip install qwen-vl-utils
pip install openai
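A quick sanity check that the pinned versions actually landed; this is a minimal sketch, and the expected values in the comments are assumptions based on the pip list above:
import torch
import transformers

print("transformers:", transformers.__version__)  # expect 4.49.0 when pinned for AWQ
print("torch:", torch.__version__)                # 2.6.0+cu118 in the environment above
print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")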
Installing flash_attention2 #
Without flash_attention2 enabled, I ran into CUDA out-of-memory errors.
Two installation options #
- [Not tried] Build the wheel yourself from the official repository.
- Install a prebuilt third-party wheel directly (pip install xxx.whl).
  - Related repository: https://github.com/kingbri1/flash-attention [Windows]. The wheel's CUDA, torch, and Python versions must each match your environment exactly; the snippet below prints the versions to look for.
- Link: https://github.com/QwenLM/Qwen2.5-VL/blob/main/cookbooks/video_understanding.ipynb
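Prebuilt wheels only work when the versions encoded in the wheel filename match your environment. This sketch prints the triple to look for; the filename pattern in the comment is an assumption reflecting how the kingbri1 releases were named at the time of writing:
import sys
import torch

# Wheel filenames look roughly like:
#   flash_attn-<ver>+cu<CUDA>torch<TORCH>-cp<PY>-cp<PY>-win_amd64.whl
print("python tag: cp%d%d" % sys.version_info[:2])  # e.g. cp312
print("torch:", torch.__version__)                  # e.g. 2.6.0
print("torch CUDA:", torch.version.cuda)            # e.g. 11.8 -> cu118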
Installing triton #
Same choice as with flash_attention2: either build the wheel from source, or find a wheel someone else has already built.
Source: "Installing triton on Windows"
Related repository: https://github.com/woct0rdho/triton-windows
pip install -U triton-windows
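After installing, a tiny kernel is the quickest way to confirm that triton can actually compile and launch on the GPU. This is a minimal sketch based on triton's standard vector-add tutorial example, not anything Qwen-specific:
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the vectors
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(4096, device="cuda")
y = torch.rand(4096, device="cuda")
out = torch.empty_like(x)
add_kernel[(4,)](x, y, out, x.numel(), BLOCK_SIZE=1024)  # 4 programs x 1024 elements
assert torch.allclose(out, x + y), "triton kernel produced wrong results"
print("triton OK:", triton.__version__)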
Troubleshooting #
File c:\Users\hl\anaconda3\envs\ollama\Lib\site-packages\transformers\generatio
2278 input_ids, model_kwargs = self._expand_inputs_for_generation(
2279     input_ids=input_ids,...
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Fix this by changing the following parameters:
- Change torch_dtype=torch.bfloat16 to torch_dtype=torch.float16
- [Optional: only needed if automatic multi-device placement triggers the error] Change device_map="auto" to device_map="cuda:0"
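Applied to the loading code used later in this post, the corrected call looks like this (a sketch mirroring init_model() in the code section below; the local model path is carried over from there):
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "C:/Users/admin/python/qwen/Qwen2.5-VL/Qwen2.5-VL-7B-Instruct-AWQ",
    torch_dtype=torch.float16,                # was torch.bfloat16
    attn_implementation="flash_attention_2",
    device_map="cuda:0",                      # was "auto"
)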
References #
Compiling and installing Flash-attention 2.3.2 on Windows
Code #
test_image.py
import ast
import time

import torch
from PIL import Image, ImageDraw, ImageFont, ImageColor

from qwen_vl_utils import process_vision_info

# Global model and processor, populated once by init_model()
model = None
processor = None

additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]


def init_model():
    global model, processor
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct-AWQ"  # Hub ID, kept for reference
    local_model_path = "C:/Users/admin/python/qwen/Qwen2.5-VL/Qwen2.5-VL-7B-Instruct-AWQ"
    # float16 (not bfloat16) and one explicit device avoid the device-side
    # assert covered in the troubleshooting section above.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        local_model_path,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map="cuda:0",
    )
    processor = AutoProcessor.from_pretrained(local_model_path)
def inference(img_url, prompt, system_prompt="You are a helpful assistant", max_new_tokens=1024):
    # NOTE: img_url is unused while every "image" entry below stays commented
    # out, so the request is text-only.
    messages = [
        # {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                # Uncomment (and duplicate) entries like these to send one or
                # more images from ./assets/spatial_understanding/ with the prompt:
                # {"type": "image", "image": "./assets/spatial_understanding/cakes.png"},
                # {"type": "image", "image": "./assets/spatial_understanding/multiple_items.png"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print("input:\n", text)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to("cuda")
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens from each generated sequence before decoding
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for t in output_text:
        print(t)
    # For plot_bounding_boxes, the model-input resolution can be recovered via:
    # input_height = inputs['image_grid_thw'][0][1] * 14
    # input_width = inputs['image_grid_thw'][0][2] * 14
    return output_text[0]
def test():
    image_path = "./assets/spatial_understanding/cakes.png"
    ## Use a local HuggingFace model for inference.
    # Grounding prompts in Chinese ("box every cupcake / every brave person,
    # output all coordinates as JSON, one set per image"):
    # prompt = "框出每一个小蛋糕的位置,以json格式输出所有的坐标,每张图区分开"
    # prompt = "请问总共有几张图片?框出每一个见义勇为的人的位置,以json格式输出所有的坐标,每张图区分开,没有见义勇为的人的图要空json数据"
    # Text-only smoke test ("How does one eat?"):
    prompt = "如何吃饭?"
    response = inference(image_path, prompt)
    # To visualize grounding results:
    # image = Image.open(image_path)
    # print(image.size)
    # image.thumbnail([640, 640], Image.Resampling.LANCZOS)
    # plot_bounding_boxes(image, response, input_width, input_height)
    torch.cuda.empty_cache()
def plot_bounding_boxes(im, bounding_boxes, input_width, input_height):
    """
    Plots bounding boxes on an image, with a label for each box, using PIL,
    normalized coordinates, and different colors.

    Args:
        im: The PIL image to draw on.
        bounding_boxes: The model's JSON answer - a list of objects, each with
            a "bbox_2d" in [x1, y1, x2, y2] model-input coordinates and an
            optional "label".
        input_width, input_height: the model-input resolution, used to map the
            coordinates back onto the original image.
    """
    # Load the image
    img = im
    width, height = img.size
    print(img.size)

    # Create a drawing object
    draw = ImageDraw.Draw(img)

    # Define a list of colors
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'brown',
        'gray', 'beige', 'turquoise', 'cyan', 'magenta', 'lime', 'navy',
        'maroon', 'teal', 'olive', 'coral', 'lavender', 'violet', 'gold',
        'silver',
    ] + additional_colors

    # Strip the markdown fencing around the model's JSON answer
    bounding_boxes = parse_json(bounding_boxes)

    font = ImageFont.truetype("NotoSansCJK-Regular.ttc", size=14)

    try:
        json_output = ast.literal_eval(bounding_boxes)
    except Exception:
        # The answer was truncated mid-object: keep everything up to the last
        # complete entry and close the list.
        end_idx = bounding_boxes.rfind('"}') + len('"}')
        truncated_text = bounding_boxes[:end_idx] + "]"
        json_output = ast.literal_eval(truncated_text)

    # Iterate over the bounding boxes
    for i, bounding_box in enumerate(json_output):
        # Select a color from the list
        color = colors[i % len(colors)]

        # Convert model-input coordinates to absolute image coordinates
        abs_y1 = int(bounding_box["bbox_2d"][1] / input_height * height)
        abs_x1 = int(bounding_box["bbox_2d"][0] / input_width * width)
        abs_y2 = int(bounding_box["bbox_2d"][3] / input_height * height)
        abs_x2 = int(bounding_box["bbox_2d"][2] / input_width * width)
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        # Draw the bounding box
        draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)

        # Draw the label text
        if "label" in bounding_box:
            draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)

    # Display the image
    img.show()
# Parsing JSON output
def parse_json(json_output):
    # Strip the markdown fencing around the model's JSON answer
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        if line == "```json":
            json_output = "\n".join(lines[i + 1:])     # remove everything before "```json"
            json_output = json_output.split("```")[0]  # remove everything after the closing "```"
            break  # exit the loop once "```json" is found
    return json_output
if __name__ == "__main__":
    init_model()
    t1 = time.time()
    test()
    print(time.time() - t1)
test_video.py
import os
import hashlib

import numpy as np
import requests
import torch
from PIL import Image
from decord import VideoReader, cpu
from IPython.display import Markdown, display  # only used by the commented-out preview calls
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Global model and processor, populated once by init_model()
model = None
processor = None


def init_model():
    global model, processor
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct-AWQ"  # Hub ID, kept for reference
    local_model_path = "C:/Users/admin/python/qwen/Qwen2.5-VL/Qwen2.5-VL-7B-Instruct-AWQ"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        local_model_path,
        torch_dtype=torch.float16,                # half precision (FP16) to cut memory use and speed up compute
        attn_implementation="flash_attention_2",  # Flash Attention 2 accelerates self-attention
        device_map="cuda:0",                      # load the model onto the first CUDA device
        local_files_only=True,                    # load the model from local files only
    )
    # Load the processor that pairs with the model; it converts raw inputs
    # (video frames, text prompts, ...) into the format the model expects.
    # local_files_only=True again restricts loading to local files.
    processor = AutoProcessor.from_pretrained(
        local_model_path,
        local_files_only=True,
    )
def download_video(url, dest_path):
    response = requests.get(url, stream=True)
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8096):
            f.write(chunk)
    print(f"Video downloaded to {dest_path}")
def get_video_frames(video_path, num_frames=128, cache_dir='.cache'):
    os.makedirs(cache_dir, exist_ok=True)

    video_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()
    if video_path.startswith('http://') or video_path.startswith('https://'):
        video_file_path = os.path.join(cache_dir, f'{video_hash}.mp4')
        if not os.path.exists(video_file_path):
            download_video(video_path, video_file_path)
    else:
        video_file_path = video_path

    frames_cache_file = os.path.join(cache_dir, f'{video_hash}_{num_frames}_frames.npy')
    timestamps_cache_file = os.path.join(cache_dir, f'{video_hash}_{num_frames}_timestamps.npy')

    # Reuse cached frames and timestamps when available
    if os.path.exists(frames_cache_file) and os.path.exists(timestamps_cache_file):
        frames = np.load(frames_cache_file)
        timestamps = np.load(timestamps_cache_file)
        return video_file_path, frames, timestamps

    # Sample num_frames frames evenly across the whole video
    vr = VideoReader(video_file_path, ctx=cpu(0))
    total_frames = len(vr)
    indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = vr.get_batch(indices).asnumpy()
    timestamps = np.array([vr.get_frame_timestamp(idx) for idx in indices])

    np.save(frames_cache_file, frames)
    np.save(timestamps_cache_file, timestamps)

    return video_file_path, frames, timestamps
def create_image_grid(images, num_columns=8):
    pil_images = [Image.fromarray(image) for image in images]
    num_rows = (len(images) + num_columns - 1) // num_columns
    img_width, img_height = pil_images[0].size

    grid_width = num_columns * img_width
    grid_height = num_rows * img_height
    grid_image = Image.new('RGB', (grid_width, grid_height))

    for idx, image in enumerate(pil_images):
        row_idx = idx // num_columns
        col_idx = idx % num_columns
        position = (col_idx * img_width, row_idx * img_height)
        grid_image.paste(image, position)

    return grid_image
def inference(video_path, prompt, max_new_tokens=2048, total_pixels=20480 * 28 * 28, min_pixels=16 * 28 * 28):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"video": video_path, "total_pixels": total_pixels, "min_pixels": min_pixels},
        ]},
    ]
    # Render the message list into the model's chat text format
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Process the visual inputs; returns image inputs, video inputs, and extra video kwargs
    image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
    fps_inputs = video_kwargs['fps']  # sampled frame rate of the video

    print("video input:", video_inputs[0].shape)
    num_frames, _, resized_height, resized_width = video_inputs[0].shape
    # Each video token covers 2 frames and a 28x28 pixel patch
    print("num of video tokens:", int(num_frames / 2 * resized_height / 28 * resized_width / 28))

    # Pack text, image, and video inputs into model-ready tensors
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, fps=fps_inputs,
                       padding=True, return_tensors="pt")
    inputs = inputs.to('cuda')

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens, then decode the generated tokens into text
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]
def test():
    video_url = ".cache\\d830ebfba70612b83155fadc76b213c8.mp4"  # local file; an http(s) URL also works
    prompt = "请用表格总结一下视频中的商品特点"  # "Summarize the product features in the video as a table"
    ## Use a local HuggingFace model for inference.
    video_path, frames, timestamps = get_video_frames(video_url, num_frames=64)
    # In a notebook, preview the sampled frames:
    # image_grid = create_image_grid(frames, num_columns=8)
    # display(image_grid.resize((640, 640)))
    response = inference(video_path, prompt)
    print(response)
    # display(Markdown(response))
    torch.cuda.empty_cache()
if __name__ == "__main__":
    init_model()
    test()