Spaces:
Sleeping
Sleeping
zhangfeng144
commited on
Commit
·
66e44c7
1
Parent(s):
670f33a
add batch upload
Browse files- app/database.py +62 -1
- app/image_utils.py +51 -1
- app/services.py +30 -42
- app/ui.py +2 -5
- main.py +2 -2
app/database.py
CHANGED
|
@@ -129,7 +129,6 @@ class Database:
|
|
| 129 |
)
|
| 130 |
|
| 131 |
exists = len(results) > 0
|
| 132 |
-
logger.info(f"Check file exists - hash: {image_hash}, exists: {exists}, results: {results}")
|
| 133 |
return exists
|
| 134 |
except Exception as e:
|
| 135 |
logger.error(f"Failed to check file exists: {str(e)}")
|
|
@@ -149,6 +148,68 @@ class Database:
|
|
| 149 |
except Exception as e:
|
| 150 |
logger.error(f"Failed to delete sticker: {str(e)}")
|
| 151 |
return f"Failed to delete sticker: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
# 初始化 Milvus 数据库
|
| 153 |
|
| 154 |
# 创建数据库实例
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
exists = len(results) > 0
|
|
|
|
| 132 |
return exists
|
| 133 |
except Exception as e:
|
| 134 |
logger.error(f"Failed to check file exists: {str(e)}")
|
|
|
|
| 148 |
except Exception as e:
|
| 149 |
logger.error(f"Failed to delete sticker: {str(e)}")
|
| 150 |
return f"Failed to delete sticker: {str(e)}"
|
| 151 |
+
|
| 152 |
+
def batch_store_stickers(self, stickers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
|
| 153 |
+
"""批量存储贴纸数据到Milvus
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
stickers (List[Dict[str, Any]]): 贴纸数据列表,每个元素包含以下字段:
|
| 157 |
+
- title: str
|
| 158 |
+
- description: str
|
| 159 |
+
- tags: Union[str, List[str]]
|
| 160 |
+
- file_path: str
|
| 161 |
+
- image_hash: str (可选)
|
| 162 |
+
batch_size (int, optional): 每批处理的数量. Defaults to 100.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
bool: 是否全部插入成功
|
| 166 |
+
"""
|
| 167 |
+
try:
|
| 168 |
+
total_stickers = len(stickers)
|
| 169 |
+
if total_stickers == 0:
|
| 170 |
+
logger.warning("No stickers to store")
|
| 171 |
+
return True
|
| 172 |
+
|
| 173 |
+
logger.info(f"Starting batch store of {total_stickers} stickers")
|
| 174 |
+
|
| 175 |
+
# 分批处理
|
| 176 |
+
for i in range(0, total_stickers, batch_size):
|
| 177 |
+
batch = stickers[i:i + batch_size]
|
| 178 |
+
batch_data = []
|
| 179 |
+
|
| 180 |
+
for sticker in batch:
|
| 181 |
+
# 处理标签格式
|
| 182 |
+
tags = sticker.get("tags", [])
|
| 183 |
+
if isinstance(tags, str):
|
| 184 |
+
tags = tags.split(",")
|
| 185 |
+
|
| 186 |
+
# 编码描述文本
|
| 187 |
+
vector = self.encode_text(sticker.get("description", ""))
|
| 188 |
+
|
| 189 |
+
batch_data.append({
|
| 190 |
+
"vector": vector,
|
| 191 |
+
"title": sticker.get("title", ""),
|
| 192 |
+
"description": sticker.get("description", ""),
|
| 193 |
+
"tags": tags,
|
| 194 |
+
"file_name": sticker.get("file_path", ""),
|
| 195 |
+
"image_hash": sticker.get("image_hash")
|
| 196 |
+
})
|
| 197 |
+
|
| 198 |
+
# 批量插入
|
| 199 |
+
if batch_data:
|
| 200 |
+
self.client.insert(
|
| 201 |
+
collection_name=self.collection_name,
|
| 202 |
+
data=batch_data
|
| 203 |
+
)
|
| 204 |
+
logger.info(f"Batch {i//batch_size + 1} stored successfully - {len(batch_data)} stickers")
|
| 205 |
+
|
| 206 |
+
logger.info("All stickers stored successfully ✅")
|
| 207 |
+
return True
|
| 208 |
+
|
| 209 |
+
except Exception as e:
|
| 210 |
+
logger.error(f"Failed to batch store stickers: {str(e)}")
|
| 211 |
+
return False
|
| 212 |
+
|
| 213 |
# 初始化 Milvus 数据库
|
| 214 |
|
| 215 |
# 创建数据库实例
|
app/image_utils.py
CHANGED
|
@@ -71,6 +71,19 @@ def upload_to_huggingface(temp_file_path: str) -> tuple:
|
|
| 71 |
logger.info(f"Image uploaded successfully: {file_path}")
|
| 72 |
return file_path, image_filename
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
def get_image_cdn_url(file_path: str) -> str:
|
| 76 |
"""获取图片CDN URL"""
|
|
@@ -104,4 +117,41 @@ def format_image_url(file_path: str) -> str:
|
|
| 104 |
return os.path.abspath(file_path)
|
| 105 |
|
| 106 |
# 如果是HuggingFace路径,返回完整的URL
|
| 107 |
-
return f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{file_path}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
logger.info(f"Image uploaded successfully: {file_path}")
|
| 72 |
return file_path, image_filename
|
| 73 |
|
| 74 |
+
def upload_folder_to_huggingface(folder_path: str) -> None:
|
| 75 |
+
"""上传目录到 HuggingFace"""
|
| 76 |
+
logger.info(f"Uploading folder to HuggingFace: {folder_path}")
|
| 77 |
+
api.upload_folder(
|
| 78 |
+
folder_path=folder_path,
|
| 79 |
+
path_in_repo="images/",
|
| 80 |
+
repo_id=DATASET_ID,
|
| 81 |
+
token=HUGGING_FACE_TOKEN,
|
| 82 |
+
repo_type="dataset"
|
| 83 |
+
)
|
| 84 |
+
logger.info(f"Image uploaded successfully: {folder_path}")
|
| 85 |
+
return
|
| 86 |
+
|
| 87 |
|
| 88 |
def get_image_cdn_url(file_path: str) -> str:
|
| 89 |
"""获取图片CDN URL"""
|
|
|
|
| 117 |
return os.path.abspath(file_path)
|
| 118 |
|
| 119 |
# 如果是HuggingFace路径,返回完整的URL
|
| 120 |
+
return f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{file_path}"
|
| 121 |
+
|
| 122 |
+
def generate_temp_image(temp_dir: str, image: Image.Image, image_filename: str) -> str:
|
| 123 |
+
"""根据图片和文件名生成临时图片文件
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
temp_dir (str): 临时目录路径
|
| 127 |
+
image (Image.Image): PIL图片对象
|
| 128 |
+
image_filename (str): 图片文件名,格式为 image_XXXXXX.png
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
str: 临时图片文件的路径
|
| 132 |
+
|
| 133 |
+
Raises:
|
| 134 |
+
ValueError: 如果临时目录不存在或无法访问
|
| 135 |
+
IOError: 如果保存图片失败
|
| 136 |
+
"""
|
| 137 |
+
try:
|
| 138 |
+
# 确保临时目录存在
|
| 139 |
+
if not os.path.exists(temp_dir):
|
| 140 |
+
os.makedirs(temp_dir)
|
| 141 |
+
logger.info(f"Created temporary directory: {temp_dir}")
|
| 142 |
+
|
| 143 |
+
# 构建临时文件路径
|
| 144 |
+
temp_file_path = os.path.join(temp_dir, image_filename)
|
| 145 |
+
|
| 146 |
+
# 保存图片到临时文件
|
| 147 |
+
image.save(temp_file_path, "PNG")
|
| 148 |
+
logger.info(f"Generated temporary image: {temp_file_path}")
|
| 149 |
+
|
| 150 |
+
return temp_file_path
|
| 151 |
+
|
| 152 |
+
except OSError as e:
|
| 153 |
+
logger.error(f"Failed to create temporary directory: {str(e)}")
|
| 154 |
+
raise ValueError(f"Failed to create temporary directory: {str(e)}")
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logger.error(f"Failed to generate temporary image: {str(e)}")
|
| 157 |
+
raise IOError(f"Failed to generate temporary image: {str(e)}")
|
app/services.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import logging
|
|
|
|
| 3 |
import zipfile
|
| 4 |
import tempfile
|
| 5 |
import shutil
|
|
@@ -16,6 +17,8 @@ from app.config import (
|
|
| 16 |
from app.database import db
|
| 17 |
from app.image_utils import (
|
| 18 |
save_image_temp,
|
|
|
|
|
|
|
| 19 |
upload_to_huggingface,
|
| 20 |
get_image_cdn_url,
|
| 21 |
get_image_description,
|
|
@@ -78,35 +81,31 @@ class StickerService:
|
|
| 78 |
def import_stickers(
|
| 79 |
sticker_dataset: str,
|
| 80 |
upload: bool = False,
|
| 81 |
-
gen_description: bool = False,
|
| 82 |
progress_callback: callable = None,
|
| 83 |
-
total_files: int = 0
|
| 84 |
) -> List[str]:
|
| 85 |
"""导入表情包数据集
|
| 86 |
|
| 87 |
Args:
|
| 88 |
sticker_dataset (str): 表情包数据集路径
|
| 89 |
upload (bool, optional): 是否上传到HuggingFace. Defaults to False.
|
| 90 |
-
gen_description (bool, optional): 是否生成AI描述. Defaults to False.
|
| 91 |
progress_callback (callable, optional): 进度回调函数. Defaults to None.
|
| 92 |
-
total_files (int, optional): 总文件数. Defaults to 0.
|
| 93 |
"""
|
| 94 |
results = []
|
| 95 |
descriptions = {}
|
| 96 |
|
| 97 |
try:
|
| 98 |
# 创建临时目录
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
# 解压数据集
|
| 103 |
with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
|
| 104 |
-
zip_ref.extractall(
|
| 105 |
-
|
| 106 |
-
logger.info(f"Extracted dataset to: {temp_dir}")
|
| 107 |
|
| 108 |
# 尝试读取data.json文件
|
| 109 |
-
data_json_path = os.path.join(temp_dir, 'data.json')
|
| 110 |
if os.path.exists(data_json_path):
|
| 111 |
with open(data_json_path, 'r', encoding='utf-8') as f:
|
| 112 |
data = json.load(f)
|
|
@@ -114,7 +113,7 @@ class StickerService:
|
|
| 114 |
logger.info(f"Loaded descriptions from data.json")
|
| 115 |
|
| 116 |
# 遍历解压后的目录
|
| 117 |
-
for root, dirs, files in os.walk(
|
| 118 |
for file in files:
|
| 119 |
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp')):
|
| 120 |
image_path = os.path.join(root, file)
|
|
@@ -122,8 +121,6 @@ class StickerService:
|
|
| 122 |
try:
|
| 123 |
# 打开图片
|
| 124 |
image = Image.open(image_path)
|
| 125 |
-
|
| 126 |
-
# 检查文件名是否已存在
|
| 127 |
image_hash = calculate_image_hash(image)
|
| 128 |
if db.check_image_exists(image_hash):
|
| 129 |
results.append(f"跳过已存在的图片: {file}")
|
|
@@ -131,40 +128,23 @@ class StickerService:
|
|
| 131 |
progress_callback(file, "Skipped (exists)")
|
| 132 |
continue
|
| 133 |
|
| 134 |
-
# 保存到临时文件
|
| 135 |
-
temp_file_path = save_image_temp(image)
|
| 136 |
-
|
| 137 |
-
# 上传到 HuggingFace
|
| 138 |
-
if upload:
|
| 139 |
-
file_path, image_filename = upload_to_huggingface(temp_file_path)
|
| 140 |
-
else:
|
| 141 |
-
file_path = temp_file_path
|
| 142 |
-
image_url = file_path
|
| 143 |
-
|
| 144 |
# 获取图片描述
|
| 145 |
description = None
|
| 146 |
if file in descriptions:
|
| 147 |
description = descriptions[file]
|
| 148 |
-
|
| 149 |
-
image_cdn_url = ''
|
| 150 |
-
if (PUBLIC_URL):
|
| 151 |
-
image_cdn_url = f'{PUBLIC_URL}/gradio_api/file={temp_file_path}'
|
| 152 |
-
else:
|
| 153 |
-
image_cdn_url = get_image_cdn_url(file_path)
|
| 154 |
-
description = get_image_description(image_cdn_url)
|
| 155 |
-
|
| 156 |
if not description:
|
| 157 |
-
|
| 158 |
if progress_callback:
|
| 159 |
progress_callback(file, "Skipped (no description)")
|
| 160 |
continue
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
db.store_sticker("", description, "", file_path, image_hash)
|
| 167 |
-
results.append(f"成功导入: {
|
| 168 |
|
| 169 |
if progress_callback:
|
| 170 |
progress_callback(file, "Imported")
|
|
@@ -174,7 +154,11 @@ class StickerService:
|
|
| 174 |
results.append(f"处理失败 {file}: {str(e)}")
|
| 175 |
if progress_callback:
|
| 176 |
progress_callback(file, f"Failed: {str(e)}")
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
return results
|
| 179 |
|
| 180 |
except Exception as e:
|
|
@@ -184,9 +168,13 @@ class StickerService:
|
|
| 184 |
|
| 185 |
finally:
|
| 186 |
# 清理临时目录
|
| 187 |
-
if
|
| 188 |
-
shutil.rmtree(
|
| 189 |
-
logger.info(f"Cleaned up temporary directory: {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
@staticmethod
|
| 192 |
def search_stickers(description: str, limit: int = 2, reranking : bool = False) -> List[Dict[str, Any]]:
|
|
|
|
| 1 |
import os
|
| 2 |
import logging
|
| 3 |
+
import random
|
| 4 |
import zipfile
|
| 5 |
import tempfile
|
| 6 |
import shutil
|
|
|
|
| 17 |
from app.database import db
|
| 18 |
from app.image_utils import (
|
| 19 |
save_image_temp,
|
| 20 |
+
generate_temp_image,
|
| 21 |
+
upload_folder_to_huggingface,
|
| 22 |
upload_to_huggingface,
|
| 23 |
get_image_cdn_url,
|
| 24 |
get_image_description,
|
|
|
|
| 81 |
def import_stickers(
|
| 82 |
sticker_dataset: str,
|
| 83 |
upload: bool = False,
|
|
|
|
| 84 |
progress_callback: callable = None,
|
|
|
|
| 85 |
) -> List[str]:
|
| 86 |
"""导入表情包数据集
|
| 87 |
|
| 88 |
Args:
|
| 89 |
sticker_dataset (str): 表情包数据集路径
|
| 90 |
upload (bool, optional): 是否上传到HuggingFace. Defaults to False.
|
|
|
|
| 91 |
progress_callback (callable, optional): 进度回调函数. Defaults to None.
|
|
|
|
| 92 |
"""
|
| 93 |
results = []
|
| 94 |
descriptions = {}
|
| 95 |
|
| 96 |
try:
|
| 97 |
# 创建临时目录
|
| 98 |
+
cache_folder = os.path.join(TEMP_DIR, 'cache/')
|
| 99 |
+
img_folder = os.path.join(TEMP_DIR, 'images/')
|
| 100 |
+
data_json_path = os.path.join(cache_folder, 'data.json')
|
| 101 |
+
|
| 102 |
+
logger.info(f"start import dataset")
|
| 103 |
# 解压数据集
|
| 104 |
with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
|
| 105 |
+
zip_ref.extractall(cache_folder)
|
| 106 |
+
logger.info(f"Extracted dataset to: {cache_folder}")
|
|
|
|
| 107 |
|
| 108 |
# 尝试读取data.json文件
|
|
|
|
| 109 |
if os.path.exists(data_json_path):
|
| 110 |
with open(data_json_path, 'r', encoding='utf-8') as f:
|
| 111 |
data = json.load(f)
|
|
|
|
| 113 |
logger.info(f"Loaded descriptions from data.json")
|
| 114 |
|
| 115 |
# 遍历解压后的目录
|
| 116 |
+
for root, dirs, files in os.walk(cache_folder):
|
| 117 |
for file in files:
|
| 118 |
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp')):
|
| 119 |
image_path = os.path.join(root, file)
|
|
|
|
| 121 |
try:
|
| 122 |
# 打开图片
|
| 123 |
image = Image.open(image_path)
|
|
|
|
|
|
|
| 124 |
image_hash = calculate_image_hash(image)
|
| 125 |
if db.check_image_exists(image_hash):
|
| 126 |
results.append(f"跳过已存在的图片: {file}")
|
|
|
|
| 128 |
progress_callback(file, "Skipped (exists)")
|
| 129 |
continue
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
# 获取图片描述
|
| 132 |
description = None
|
| 133 |
if file in descriptions:
|
| 134 |
description = descriptions[file]
|
| 135 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
if not description:
|
| 137 |
+
results.append(f"跳过无描述的图片: {file}")
|
| 138 |
if progress_callback:
|
| 139 |
progress_callback(file, "Skipped (no description)")
|
| 140 |
continue
|
| 141 |
|
| 142 |
+
image_filename = f"image_{random.randint(100000, 999999)}.png"
|
| 143 |
+
file_path = f"images/{image_filename}"
|
| 144 |
+
generate_temp_image(img_folder, image, image_filename)
|
| 145 |
+
|
| 146 |
db.store_sticker("", description, "", file_path, image_hash)
|
| 147 |
+
results.append(f"成功导入: {image_filename}")
|
| 148 |
|
| 149 |
if progress_callback:
|
| 150 |
progress_callback(file, "Imported")
|
|
|
|
| 154 |
results.append(f"处理失败 {file}: {str(e)}")
|
| 155 |
if progress_callback:
|
| 156 |
progress_callback(file, f"Failed: {str(e)}")
|
| 157 |
+
|
| 158 |
+
# 上传到 HuggingFace
|
| 159 |
+
if upload:
|
| 160 |
+
upload_folder_to_huggingface(img_folder)
|
| 161 |
+
|
| 162 |
return results
|
| 163 |
|
| 164 |
except Exception as e:
|
|
|
|
| 168 |
|
| 169 |
finally:
|
| 170 |
# 清理临时目录
|
| 171 |
+
if cache_folder and os.path.exists(cache_folder):
|
| 172 |
+
shutil.rmtree(cache_folder)
|
| 173 |
+
logger.info(f"Cleaned up temporary directory: {cache_folder}")
|
| 174 |
+
|
| 175 |
+
if img_folder and os.path.exists(img_folder):
|
| 176 |
+
shutil.rmtree(img_folder)
|
| 177 |
+
logger.info(f"Cleaned up temporary directory: {img_folder}")
|
| 178 |
|
| 179 |
@staticmethod
|
| 180 |
def search_stickers(description: str, limit: int = 2, reranking : bool = False) -> List[Dict[str, Any]]:
|
app/ui.py
CHANGED
|
@@ -144,7 +144,6 @@ class StickerUI:
|
|
| 144 |
|
| 145 |
with gr.Row():
|
| 146 |
self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
|
| 147 |
-
self.gen_desc_checkbox = gr.Checkbox(label="Generate AI Descriptions", value=False)
|
| 148 |
|
| 149 |
with gr.Row():
|
| 150 |
self.import_button.render()
|
|
@@ -155,13 +154,12 @@ class StickerUI:
|
|
| 155 |
fn=self._import_stickers_with_progress,
|
| 156 |
inputs=[
|
| 157 |
self.dataset_input,
|
| 158 |
-
self.upload_checkbox
|
| 159 |
-
self.gen_desc_checkbox
|
| 160 |
],
|
| 161 |
outputs=self.import_output
|
| 162 |
)
|
| 163 |
|
| 164 |
-
def _import_stickers_with_progress(self, dataset_path, upload,
|
| 165 |
"""Import stickers with progress tracking."""
|
| 166 |
try:
|
| 167 |
# Count total files first
|
|
@@ -189,7 +187,6 @@ class StickerUI:
|
|
| 189 |
results = sticker_service.import_stickers(
|
| 190 |
dataset_path,
|
| 191 |
upload=upload,
|
| 192 |
-
gen_description=gen_description,
|
| 193 |
progress_callback=update_progress,
|
| 194 |
total_files=total_files
|
| 195 |
)
|
|
|
|
| 144 |
|
| 145 |
with gr.Row():
|
| 146 |
self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
|
|
|
|
| 147 |
|
| 148 |
with gr.Row():
|
| 149 |
self.import_button.render()
|
|
|
|
| 154 |
fn=self._import_stickers_with_progress,
|
| 155 |
inputs=[
|
| 156 |
self.dataset_input,
|
| 157 |
+
self.upload_checkbox
|
|
|
|
| 158 |
],
|
| 159 |
outputs=self.import_output
|
| 160 |
)
|
| 161 |
|
| 162 |
+
def _import_stickers_with_progress(self, dataset_path, upload, progress=gr.Progress()):
|
| 163 |
"""Import stickers with progress tracking."""
|
| 164 |
try:
|
| 165 |
# Count total files first
|
|
|
|
| 187 |
results = sticker_service.import_stickers(
|
| 188 |
dataset_path,
|
| 189 |
upload=upload,
|
|
|
|
| 190 |
progress_callback=update_progress,
|
| 191 |
total_files=total_files
|
| 192 |
)
|
main.py
CHANGED
|
@@ -72,7 +72,7 @@ async def api_delete_stickers(request: dict):
|
|
| 72 |
|
| 73 |
|
| 74 |
@app.post("/api/import_dataset")
|
| 75 |
-
async def api_import_dataset(file: UploadFile = File(...), upload: bool = False
|
| 76 |
"""Import sticker dataset from ZIP file"""
|
| 77 |
try:
|
| 78 |
# Create a temporary file to store the uploaded ZIP
|
|
@@ -82,7 +82,7 @@ async def api_import_dataset(file: UploadFile = File(...), upload: bool = False,
|
|
| 82 |
temp_file_path = temp_file.name
|
| 83 |
|
| 84 |
# Import the dataset
|
| 85 |
-
results = sticker_service.import_stickers(temp_file_path, upload
|
| 86 |
|
| 87 |
# Clean up the temporary file
|
| 88 |
os.unlink(temp_file_path)
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
@app.post("/api/import_dataset")
|
| 75 |
+
async def api_import_dataset(file: UploadFile = File(...), upload: bool = False):
|
| 76 |
"""Import sticker dataset from ZIP file"""
|
| 77 |
try:
|
| 78 |
# Create a temporary file to store the uploaded ZIP
|
|
|
|
| 82 |
temp_file_path = temp_file.name
|
| 83 |
|
| 84 |
# Import the dataset
|
| 85 |
+
results = sticker_service.import_stickers(temp_file_path, upload)
|
| 86 |
|
| 87 |
# Clean up the temporary file
|
| 88 |
os.unlink(temp_file_path)
|