Skip to content
Snippets Groups Projects
suggest_controller.py 987 B
import logging
import os

from flask import jsonify, Response
from gradio_client import Client

class SuggestController:
    """Ask a hosted chat model to rewrite a prompt for an image model.

    The request goes through ``gradio_client`` to the
    ``mmcgovern574/Mistral-Small-24B-Instruct-2501`` Space's ``/chat``
    endpoint, and the raw reply is returned wrapped in a Flask JSON response.
    """

    # Logger used instead of a bare ``print`` so the model output goes
    # through the application's logging configuration (level, handlers).
    _logger = logging.getLogger(__name__)

    def get_suggest(self, hf_token='', message='', system_prompt='', temperature=.0,
                    max_new_tokens=0, top_k_sampling=0, repetition_penalty=.0,
                    top_p_sampling=.0) -> Response:
        """Return a JSON response containing an "enhanced" version of *message*.

        Args:
            hf_token: Token passed to ``gradio_client.Client``. When empty or
                ``None``, the value of the ``hf_token`` environment variable
                is used instead (which may itself be ``None``).
            message: The user prompt to enhance. It is embedded after a fixed
                instruction string before being sent.
            system_prompt, temperature, max_new_tokens, top_k_sampling,
            repetition_penalty, top_p_sampling: Forwarded unchanged as
                ``param_2`` .. ``param_7`` of the Space's ``/chat`` endpoint.
                Their meaning is defined by the remote Space (presumably
                standard sampling settings; confirm against its API page).

        Returns:
            ``jsonify({"message": result})``, where ``result`` is whatever
            the ``/chat`` endpoint returned. ``jsonify`` needs an active
            Flask application context.

        Raises:
            Exceptions raised by ``gradio_client`` (connection, auth, or
            prediction failures) are not caught here and propagate.
        """
        # ``not hf_token`` also treats an explicit ``None`` as "missing". The
        # previous ``== ''`` check let ``None`` through and skipped the
        # environment-variable fallback.
        if not hf_token:
            hf_token = os.getenv("hf_token")

        # A new Client is built per call; the second positional argument is
        # the token.
        client = Client("mmcgovern574/Mistral-Small-24B-Instruct-2501", hf_token)

        result = client.predict(
            message=f'Enhance this prompt. The prompt will be used by an Image2Image model (only give one answer, don\'t add anything else) : {message}',
            param_2=system_prompt,
            param_3=temperature,
            param_4=max_new_tokens,
            param_5=top_k_sampling,
            param_6=repetition_penalty,
            param_7=top_p_sampling,
            api_name="/chat"
        )

        # Lazy %-formatting: the result is only rendered if debug logging is on.
        self._logger.debug("suggest result: %s", result)

        return jsonify({"message": result})