diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3b72af06f63e14ea65f6aac0929f38aa57590d94 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.env +__pycache__ diff --git a/main.py b/main.py deleted file mode 100644 index 41f526c049eaf635388c4fda034d0d9b5de8c0da..0000000000000000000000000000000000000000 --- a/main.py +++ /dev/null @@ -1,48 +0,0 @@ -from gradio_client import Client, handle_file - -client = Client("JeffreyXiang/TRELLIS") - -result = client.predict( - api_name="/start_session" -) -print(result) - -preprocessed_image = client.predict( - image=handle_file('https://i.pinimg.com/564x/79/b3/6d/79b36deeb891cf65e41c743f5350d6cf.jpg'), - api_name="/preprocess_image" -) - -rand_seed = client.predict( - randomize_seed=True, - seed=0, - api_name="/get_seed" -) - -result = client.predict( - image=handle_file(preprocessed_image), - multiimages=[], - seed=rand_seed, - ss_guidance_strength=7.5, - ss_sampling_steps=12, - slat_guidance_strength=3, - slat_sampling_steps=12, - multiimage_algo="stochastic", - api_name="/image_to_3d" -) -print(result) - -print("image_to_3d ok") - -client.predict( - api_name="/lambda_2" -) - -client.predict( - mesh_simplify=None, - texture_size=0.95, - api_name="/extract_glb" -) - -client.predict( - api_name="/lambda_4" -) diff --git a/requirements.txt b/requirements.txt index 4a3b36faacb4eba1c04e994d244ee7becd4986ee..f42a0bc3e4188a2833f591eafcf5fca71695be62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ +Flask>=2.0.3 +dotenv gradio_client>=1.7.0 +pydantic>=1.8.2 +requests>=2.26.0 diff --git a/src/controllers/glb_controller.py b/src/controllers/glb_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..2289ca676f6227a8c4ab03cbdb83ac9c5582c315 --- /dev/null +++ b/src/controllers/glb_controller.py @@ -0,0 +1,47 @@ +import os +from gradio_client import Client, handle_file + +class GlbController: + def get_glb(self, hf_token='', base_image_url='', ss_guidance_strength=7.5, + ss_sampling_steps=12, slat_guidance_strength=3, + slat_sampling_steps=12, multiimage_algo='stochastic', + mesh_simplify=0.95, texture_size=1024) -> str: + if hf_token == '': + hf_token = os.getenv("hf_token") + + client = Client("JeffreyXiang/TRELLIS", hf_token) + + client.predict( + api_name="/start_session" + ) + + preprocess_image_result = client.predict( + image=handle_file(base_image_url), + api_name="/preprocess_image" + ) + + get_seed_result = client.predict( + randomize_seed=True, + seed=0, + api_name="/get_seed" + ) + + client.predict( + image=handle_file(preprocess_image_result), + multiimages=[], + seed=get_seed_result, + ss_guidance_strength=ss_guidance_strength, + ss_sampling_steps=ss_sampling_steps, + slat_guidance_strength=slat_guidance_strength, + slat_sampling_steps=slat_sampling_steps, + multiimage_algo=multiimage_algo, + api_name="/image_to_3d" + ) + + extract_glb = client.predict( + mesh_simplify=mesh_simplify, + texture_size=texture_size, + api_name="/extract_glb" + ) + + return extract_glb diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ce2978cb2e4b024b2b598be5d6bf47e68cbac6 --- /dev/null +++ b/src/main.py @@ -0,0 +1,12 @@ +from dotenv import load_dotenv +from flask import Flask +from routes.glb_route import setup_routes + +load_dotenv() + +app = Flask(__name__) + +setup_routes(app) + +if __name__ == '__main__': + app.run(debug=True) diff --git a/src/routes/glb_route.py b/src/routes/glb_route.py new file mode 100644 index 0000000000000000000000000000000000000000..8c0639896c5a36602571bc48172a6202594c5e9c --- /dev/null +++ b/src/routes/glb_route.py @@ -0,0 +1,23 @@ +from flask import Blueprint, request +from controllers.glb_controller import GlbController + +glb_bp = Blueprint('glb', __name__) +controller = GlbController() + +def setup_routes(app): + @app.route('/glb', methods=['GET']) + def get_glb(): + base_image_url = request.args.get('base_image_url', type=str) + hf_token = request.args.get('hf_token', '', type=str) + ss_guidance_strength = request.args.get('ss_guidance_strength', 7.5, type=float) + ss_sampling_steps = request.args.get('ss_sampling_steps', 12, type=int) + slat_guidance_strength = request.args.get('slat_guidance_strength', 3, type=float) + slat_sampling_steps = request.args.get('slat_sampling_steps', 12, type=int) + multiimage_algo = request.args.get('multiimage_algo', 'stochastic', type=str) + mesh_simplify = request.args.get('mesh_simplify', 0.95, type=bool) + texture_size = request.args.get('texture_size', 1024, type=int) + + return controller.get_glb(hf_token, base_image_url, ss_guidance_strength, + ss_sampling_steps, slat_guidance_strength, + slat_sampling_steps, multiimage_algo, + mesh_simplify, texture_size)