diff --git a/projects/videochat/.gitignore b/projects/videochat/.gitignore
new file mode 100644
index 0000000000..c296bd9895
--- /dev/null
+++ b/projects/videochat/.gitignore
@@ -0,0 +1,3 @@
+/pretrained_models/
+/images/
+/extract/
diff --git a/projects/videochat/README.md b/projects/videochat/README.md
new file mode 100644
index 0000000000..3cd4caaa43
--- /dev/null
+++ b/projects/videochat/README.md
@@ -0,0 +1,183 @@
+# VideoChat
+
+Batch question-answering (QA) tests on videos.
+
+# Run
+
+```shell
+# We recommend using conda to manage the environment and use python3.8.16
+conda create -n videochat python=3.8.16
+conda activate videochat
+
+# Enter the project directory (after cloning the repository):
+cd videochat
+
+# Install dependencies:
+pip install -r requirements.txt
+pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz
+python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+pip install pydantic==1.10.12
+apt-get install ffmpeg
+
+# Download the checkpoints
+mkdir pretrained_models
+wget -P ./pretrained_models https://huggingface.co/spaces/xinyu1205/Tag2Text/resolve/main/tag2text_swin_14m.pth
+wget -P ./pretrained_models https://datarelease.blob.core.windows.net/grit/models/grit_b_densecap_objectdet.pth
+wget -P ./pretrained_models https://huggingface.co/Sn4kehead/TransNetV2/resolve/main/transnetv2-pytorch-weights.pth
+git clone https://huggingface.co/mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback ./pretrained_models/flan-t5-large-finetuned-openai-summarize_from_feedback
+cd ./pretrained_models/flan-t5-large-finetuned-openai-summarize_from_feedback
+git lfs pull
+cd ../..
+
+# Configure the necessary ChatGPT APIs
+export OPENAI_API_KEY={Your_Private_Openai_Key}
+
+# Run the VideoChat test.
+python test.py
+```
+
+# File structure
+
+- test.py: entry point. Initializes the vision models and their inference calls and runs the QA tests.
+- chatbot.py: builds the ChatGPT prompts and sends the requests.
+- configs.ini: holds all configuration options.
+- pretrained_models: the .pth files of the models must be downloaded into this directory.
+- models: classes of the underlying base models.
+- transforms.py: data-augmentation classes.
+- util.py: video-loading helpers.
+
+# Details
+
+## configs.ini
+
+- device: GPU to use.
+- videos_path: directory of test videos. It should contain several subdirectories, each holding a video video.mp4 and the questions data.json.
+- output_path: directory where the QA results are written.
+- images_path: directory for the visualizations of the detection boxes predicted by the GRiT model.
+- evaluate_path: set this path to re-score the QA results of a previous run.
+- appid: appid registered on the iFlytek website, required to call the iFlytek API.
+- secret_key: secret_key obtained from the iFlytek website, required to call the iFlytek API.
+- segment_length: interval in seconds between InternVideo action predictions.
+- remarks: free-form notes about this test run.
+- llm: ChatGPT model used for shot summaries and QA.
+- predict: whether to run QA (when set to False, the results under evaluate_path can be re-evaluated instead).
+- evaluate: whether to score the QA results.
+- mode: normal runs without shot segmentation; shot runs with shot segmentation.
+
+## pretrained_models
+
+Must contain the following directories and files:
+
+- flan-t5-large-finetuned-openai-summarize_from_feedback directory
+- uniformerv2 directory
+- grit_b_densecap_objectdet.pth
+- tag2text_swin_14m.pth
+- transnetv2-pytorch-weights.pth
+
+## test.py
+
+Contains three classes.
+
+### InputVideo
+
+1. init: reads the files under videos_path. If the features file does not exist, it calls self.extract_features to extract features and saves them in the video's directory.
+2. self.extract_features: first calls self.video_chat.inference_second to extract the features, then calls ProcessSubtitle to merge the whisper and iFlytek results, and ShotProcessor to convert the shot boundaries from frames to seconds; all results are written to features.json.
+3. start_test: loads the features from features.json, then runs either unsegmented or shot-based inference according to mode. Shot inference relies on features\['time_intervals'\] (see models/transnetv2.py): the features are split by shot time span, and the features of each shot are sent to ChatGPT for a summary (a sketch of this step follows below).
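+
+The shot-based flow of start_test can be illustrated with a short sketch. This is only an illustration, not the code in test.py: the helpers split_features_by_shot and summarize_shots are hypothetical, and the layout of features\['time_intervals'\] (assumed here to be a list of \[begin_second, end_second\] pairs) should be checked against ShotProcessor. The ConversationBot methods it calls (init_agent_shot, run_text) do exist in chatbot.py.
+
+```python
+def split_features_by_shot(features):
+    """Group every per-second feature item into the shot it falls into (sketch)."""
+    shots = []
+    for begin, end in features['time_intervals']:  # assumed [begin, end] in seconds
+        shot = {}
+        for name, items in features.items():
+            if name == 'time_intervals':
+                continue
+            # Keep only dict-style items ({'begin': int, 'text': str}) inside this shot.
+            shot[name] = [item for item in items
+                          if isinstance(item, dict) and begin <= item['begin'] < end]
+        shots.append((begin, end, shot))
+    return shots
+
+
+def summarize_shots(bot, openai_api_key, features, llm='gpt-3.5-turbo-16k'):
+    """Ask ChatGPT for one summary per shot; the result feeds init_agent_with_summary."""
+    summaries = []
+    for begin, end, shot_features in split_features_by_shot(features):
+        _, question = bot.init_agent_shot(openai_api_key, shot_features)
+        text = bot.run_text(question, llm)
+        summaries.append({'begin': begin, 'end': end, 'text': text})
+    return summaries
+
+
+# Usage sketch (hypothetical):
+#   from chatbot import ConversationBot
+#   bot = ConversationBot()
+#   summaries = summarize_shots(bot, openai_api_key, features)
+```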
+
+### VideoChat
+
+1. load_model: loads the base models.
+2. inference_second: samples one frame per second from the video and extracts features. Every feature is a list whose items are objects of the form {'begin': int, 'text': str}. The extracted features include whisper (subtitles without speakers), the iFlytek subtitle (subtitles with speakers), dense (descriptions of the objects in the frame), dense_with_pos (detection boxes for dense), the frame caption (a description of the frame), shot (the start and end frame of each shot), synth_caption (a summary description built from the frame captions), and tag (the object categories appearing in the video).
+
+| Feature | Model | Called from |
+| :-------------: | :------------------: | :--------------: |
+| shot | models/transnetv2.py | inference_second |
+| subtitle | models/subtitle.py | inference_second |
+| whisper | whisper | inference_second |
+| dense | models/grit_model.py | inference_second |
+| dense_with_pos | models/grit_model.py | inference_second |
+| frame | models/tag2text.py | inference_second |
+| synth_caption | simplet5 | inference_second |
+| time_intervals | models/transnetv2.py | extract_features |
+| merged_subtitle | models/subtitle.py | extract_features |
+| ocr | models/ocr.py | inference_second |
+| dense_with_ocr | models/ocr.py | extract_features |
+| final_subtitle | models/ocr.py | extract_features |
+
+## chatbot.py
+
+Builds the prompts and sends the requests to ChatGPT.
+
+### ConversationBot
+
+1. init_agent: builds the prompt for the non-shot case; the input is all features.
+2. init_agent_shot: builds the prompt for the shot case; the input is the features of one shot.
+3. init_agent_with_summary: after all shots have been summarized, builds the prompt from the summaries; the input is the summary list.
+4. run_text: sends the QA request and returns the answer.
+
+### ChainOfThought
+
+Uses LangChain to build a sequential chain and to constrain the format of ChatGPT's output.
+
+## models/subtitle.py
+
+Contains two classes: one calls the iFlytek API, the other merges the iFlytek results with whisper.
+
+### RequestApi
+
+- Called from extract_features in test.py; produces features\['subtitle'\].
+
+1. upload: sends a request to the registered speech-transcription service; the response body contains an orderId, which is then used to request the transcription result.
+2. get_result: fetches the transcription result for the given orderId.
+3. result2text: the transcription result is returned word by word, so this converts it into objects of the form {'begin': int, 'end': int, 'speaker': int, 'text': str}.
+
+### ProcessSubtitle
+
+- Called in test.py after feature extraction; produces features\['merged_subtitle'\].
+
+1. merge_whisper_and_xunfei: iFlytek recognizes Chinese better and identifies speakers, while whisper recognizes other languages better. If whisper judges a sentence to be non-Chinese, the whisper result is kept and the speaker is taken from the iFlytek sentence whose begin time is closest; if whisper judges it to be Chinese, the closest iFlytek object is used directly (see the example at the end of this README).
+2. find_match_subtitle: finds the object with the closest begin time.
+3. remove_duplicates_by_text_key: whisper and iFlytek split sentences differently, so the steps above may append the same sentence twice; duplicates are removed here.
+
+## models/transnetv2.py
+
+Contains all the methods for computing the shot boundaries of a video and converting the boundary frames to times.
+
+### Shot
+
+- Called from extract_features in test.py; produces features\['shot'\].
+
+1. init: loads the model.
+2. inference: returns the start and end frame of each shot.
+
+### ShotProcessor
+
+- Called in test.py after feature extraction; produces features\['time_intervals'\].
+
+1. shot: uses the video frame rate to convert the start and end frame of each shot into times.
+
+## models/ocr.py
+
+Uses PaddleOCR for text extraction; currently not compatible with this setup.
+
+### ProcessOCR
+
+1. inference: uses PaddleOCR to get the text content and position in each frame and saves them to features\['ocr'\].
+2. merge: merges the OCR results with features\['merged_subtitle'\].
+   Based on where the OCR text appears: if it lies in the bottom 20% of the frame, it is treated as a subtitle. Since simple rules cannot merge OCR subtitles with speech subtitles, GPT-4 is called to do it.
+3. find_text_in_dense: if the OCR text is judged not to be a subtitle, it is compared against features\['dense'\]: the smallest object in dense whose bbox contains the OCR bbox is found, and the text is attached to that object.
+
+# References
+
+The project is based on
+
+- [Ask-Anything](https://github.com/OpenGVLab/Ask-Anything/)
+- [InternVideo](https://github.com/OpenGVLab/InternVideo)
+- [Tag2Text](https://github.com/xinyu1205/Tag2Text)
+- [GRiT](https://github.com/JialianW/GRiT)
+- [mrm8488](https://huggingface.co/mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback)
+- [ChatGPT](https://openai.com/blog/chatgpt)
+- [TransNetV2](https://github.com/soCzech/TransNetV2)
+- [LangChain](https://github.com/langchain-ai/langchain)
+
+Thanks to the authors for their efforts.
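+
+# Example: merging whisper and iFlytek subtitles
+
+The merge rule described for ProcessSubtitle.merge_whisper_and_xunfei above can be sketched as follows. This is a simplified illustration rather than the code in models/subtitle.py: the whisper items are assumed to carry a 'language' field, and the iFlytek items are assumed to follow the {'begin', 'end', 'speaker', 'text'} format produced by result2text.
+
+```python
+def find_match_subtitle(begin, xunfei_items):
+    """Return the iFlytek item whose begin time is closest to `begin` (sketch)."""
+    return min(xunfei_items, key=lambda item: abs(item['begin'] - begin))
+
+
+def merge_whisper_and_xunfei(whisper_items, xunfei_items):
+    merged = []
+    for sentence in whisper_items:
+        nearest = find_match_subtitle(sentence['begin'], xunfei_items)
+        if sentence.get('language') != 'zh':
+            # Non-Chinese sentence: keep whisper's text, borrow the speaker from iFlytek.
+            merged.append({'begin': sentence['begin'],
+                           'speaker': nearest.get('speaker'),
+                           'text': sentence['text']})
+        else:
+            # Chinese sentence: trust the closest iFlytek result directly.
+            merged.append(nearest)
+    # remove_duplicates_by_text_key would then drop sentences appended twice.
+    return merged
+```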
diff --git a/projects/videochat/__init__.py b/projects/videochat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/chatbot.py b/projects/videochat/chatbot.py new file mode 100644 index 0000000000..7125eaad65 --- /dev/null +++ b/projects/videochat/chatbot.py @@ -0,0 +1,443 @@ +""" +Description: +Version: 1.0 +Author: ZhuYichen +Date: 2023-07-03 10:56:48 +LastEditors: ZhuYichen +LastEditTime: 2023-07-09 19:22:10 +""" +import time + +import openai +import tiktoken +from langchain.chains import LLMChain +from langchain.llms import OpenAI +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain.prompts import PromptTemplate + + +def remove_duplicates(input_string): + words = input_string.split(' | ') + unique_words = [] + seen_words = set() + + for word in words: + if word not in seen_words: + unique_words.append(word) + seen_words.add(word) + + result_string = ' | '.join(unique_words) + return result_string + + +class ConversationBot: + + def __init__(self): + self.system_prompt = None + self.openai_api_key = None + + def run_text(self, question, llm, change_prompt=None, t=0): + openai.api_key = self.openai_api_key + system_prompt = change_prompt if change_prompt else self.system_prompt + try: + response = openai.ChatCompletion.create( + model=llm, + messages=[{ + 'role': 'system', + 'content': system_prompt + }, { + 'role': 'user', + 'content': question + }], + temperature=t, + max_tokens=2048, + top_p=1, + frequency_penalty=0, + presence_penalty=0) + answer = response.choices[0].message.content + except Exception as e: + print(e) + time.sleep(60) # Wait for 1 minute before retrying + return self.run_text(question, llm, change_prompt) + print(f'\nQuestion: {question}\nAnswer: {answer}\n') + return answer + + def init_agent(self, openai_api_key, features, dense_intervals=5): + self.openai_api_key = openai_api_key + prompt_dict = dict() + for feature_type, feature_content in features.items(): + prompt_dict[feature_type] = '' + for item in feature_content: + if feature_type == 'subtitle' \ + or feature_type == 'merged_subtitle': + if item.get('speaker', None): + prompt_dict[ + feature_type] += \ + 'Second{}: Speaker{}: {}\n'.format( + item['begin'], item['speaker'], item['text']) + else: + prompt_dict[feature_type] += 'Second{}: {}\n'.format( + item['begin'], item['text']) + elif feature_type == 'whisper': + prompt_dict[feature_type] += 'Second{}: {}\n'.format( + item['begin'], item['text']) + elif feature_type == 'dense': + if int(item['begin']) % dense_intervals != 0: + continue + prompt_dict[feature_type] += 'Second{}: {}\n'.format( + int(item['begin']), item['text']) + elif feature_type == 'frame': + prompt_dict[feature_type] += 'Second{}: {}'.format( + int(item['begin']), item['text']) + # for ocr in features['dense_with_ocr']: + # if ocr['begin'] == item['begin']: + # prompt_dict[feature_type] += \ + # ', {}'.format(ocr['text']) + prompt_dict[feature_type] += '\n' + elif feature_type == 'tag': + prompt_dict[feature_type] += item + ' | ' + elif feature_type == 'synth_caption': + prompt_dict[feature_type] = item + elif feature_type == 'dense_with_ocr': + prompt_dict[feature_type] += 'Second{}: {}\n'.format( + int(item['begin']), item['text']) + prompt_dict['tag'] = remove_duplicates(prompt_dict['tag']) + system_prompt = \ + """You are a chatbot that conducts conversations based on video + descriptions. 
You mainly answer based on the given description, + and you can also modify the content according to the tag + information, and you can also answer the relevant knowledge of + the person or object contained in the video. The second + description is a description for one second, so that you can + convert it into time. When describing, please mainly refer to + the second description. Dense caption is to give content every + five seconds, you can disambiguate them in timing. But you don't + create a video plot out of nothing. + + Begin! + + The second description of the video is: {} + + The dense caption of the video is: {} + + The general description of the video is: {} + + The subtitle of the video is: {} + + """.format( + prompt_dict['frame'], + prompt_dict['dense'], + prompt_dict['synth_caption'], + prompt_dict['whisper'] + ) + if not openai_api_key.startswith('sk-'): + print('OPEN_API_KEY ERROR') + self.system_prompt = system_prompt + token_count = num_tokens_from_string(system_prompt) + print(token_count) + if token_count > 14000: + return self.init_agent(openai_api_key, features, dense_intervals=5) + return self.system_prompt + + def init_agent_with_summary(self, openai_api_key, output): + self.openai_api_key = openai_api_key + summary = '' + for item in output['summary']: + summary += 'Second{} to Second{}: {}\n'.format( + item['begin'], item['end'], item['text']) + subtitle = '' + for item in output['subtitle']: + subtitle += 'Second{} to Second{}: Speaker{} said: {}\n'.format( + item['begin'], item['end'], item['speaker'], item['text']) + system_prompt = \ + f""" + 下面是一个视频的若干片段的文字描述以及字幕,描述和字幕都注明了开始时间和结束时间。你需要根据这些信息来回答用户的问题,对于用户的问题,你必须根据信息猜测一个答案。你可以按照以下流程思考: + 1. 找出用户提问的关键词 + 2. 在描述中搜索和关键词相关的信息 + 3. 根据片段中的内容猜测答案 + 下面是文字描述: + 1.描述: + {summary} + 2.字幕: + {subtitle} + """ + if not openai_api_key.startswith('sk-'): + print('OPEN_API_KEY ERROR') + self.system_prompt = system_prompt + print(self.system_prompt) + return self.system_prompt + + def init_agent_shot(self, openai_api_key, features): + self.openai_api_key = openai_api_key + prompt_dict = dict() + for feature_type, feature_content in features.items(): + prompt_dict[feature_type] = '' + for item in feature_content: + if feature_type == 'subtitle': + prompt_dict[ + feature_type] += 'Second{}: Speaker{}: {}\n'.format( + item['begin'], item['speaker'], item['text']) + elif feature_type == 'dense' or feature_type == 'ocr': + continue + elif feature_type == 'frame': + prompt_dict[feature_type] += 'Second{}: {}'.format( + item['begin'], item['text']) + # for ocr in features['dense_with_ocr']: + # if ocr['begin'] == item['begin']: + # prompt_dict[feature_type] += \ + # ', {}'.format(ocr['text']) + prompt_dict[feature_type] += '\n' + + system_prompt = \ + """ + 你需要根据用户提供的视频片段的描述来总结视频,描述分为三部分 + 1.帧描述:帧描述提供了每一秒视频中的主体和事件 + 2.字幕:字幕提供了某个时间点,视频中的人说的话,你可以根据它来推断视频中发生了什么。 + 你需要注意视频片段中出现的对话和行为。下面是一个参考案例: + ''' + 用户提供的描述: + 1. 
+ 帧描述: + Second1: a man walking down a street next to parked cars + Second2: a man riding a scooter on a street with cars and + pedestrians + Second3: a busy street with people riding mopeds, motorcycles + and cars + Second4: a man riding a scooter down a busy street with cars + and people on motorcycles + Second5: a man riding a black moped down a busy street with + cars + Second6: a man riding a scooter down a street next to cars + Second7: a man riding a black moped down a street next to cars + on the sidewalk + Second8: a man riding a moped down a busy street with parked + cars on the sidewalk + Second9: a man riding a motor scooter down a street next to a + sidewalk + Second10: a man riding a moped down a street with watermelons + 2. + 字幕: + Second0: Speaker1: 有一个人前来买瓜 + ''' + 参考答案: + 一个男人骑着摩托车沿着街道行驶,周围停着一些汽车。街道上有许多人骑着摩托车、 + 机动车和自行车行驶。然后这个骑着黑色摩托车的男人到了一个西瓜摊边上。 + ''' + 下面你将收到用户提供的描述: + """ + question = \ + """ + 1. + 帧描述: + {} + 2. + 字幕: + {} + 你的总结: + """.format(prompt_dict['frame'], prompt_dict['subtitle']) + if not openai_api_key.startswith('sk-'): + print('OPEN_API_KEY ERROR') + self.system_prompt = system_prompt + return self.system_prompt, question + + def evaluate_qa(self, openai_api_key, qa, llm): + openai.api_key = openai_api_key + system_prompt = \ + """ + 对比标准答案,评估以下视频问答模型的预测准确性,按照0-5的评分标准: + + ‘’’ + + 0:预测答案与标准答案完全不一致或无关。 + + 1:预测答案与标准答案部分一致,但主要的信息都没有涵盖。 + + 2:预测答案中包含了部分标准答案,但并没有完全回答问题。 + + 3:预测答案包含了标准答案的大部分内容,但缺少一些关键信息。 + + 4:预测答案和标准答案几乎完全一致,只是在细节上有一些小的遗漏。 + + 5:预测答案和标准答案完全一致。 + + ‘’’ + 你需要给出评分的理由,再进行评分。 + + 你的回答必须以字典格式写出: + + {'reason': 评分理由,'score': 分数,} + + 问题、标准答案以及预测答案分别如下: + """ + response = openai.ChatCompletion.create( + model=llm, + messages=[{ + 'role': 'system', + 'content': system_prompt + }, { + 'role': 'user', + 'content': qa + }], + temperature=0, + max_tokens=2048, + top_p=1, + frequency_penalty=0, + presence_penalty=0) + answer = response.choices[0].message.content + print(f'\nAnswer: {answer}\n') + answer = answer.strip('{}').replace("'", '') + try: + score_part = answer.split(', ')[-1] + reason_part = answer[:-len(score_part)].rstrip(', ') + dict_answer = dict( + item.split(': ') for item in [reason_part, score_part]) + except Exception as e: + dict_answer = { + 'reason': '返回评分格式错误', + 'score': 0, + } + print(e) + return dict_answer + + +class ChainOfThought: + + def __init__(self, description): + self.llm = OpenAI(temperature=0, model_name='gpt-3.5-turbo-16k') + self.description = description + + def get_question_type(self, question): + is_summary_schema = ResponseSchema( + name='is_summary', + description='Is this question a requirement to summarize this ' + 'paragraph? \ + Answer true if yes,\ + false if not or unknown.') + + question_type_schema = ResponseSchema( + name='question_type', + description='What is the type of the question? 
\ + Answer 1 if it is a type of reasoning, such as ' + 'reasoning about character relationships, ' + 'storylines, etc.\ + Answer 2 if it is a visual type, such as counting, ' + 'object recognition, etc.\ + Answer 0 if it does not belong to the above two ' + 'types or cannot be determined.') + + response_schemas = [is_summary_schema, question_type_schema] + + output_parser = StructuredOutputParser.from_response_schemas( + response_schemas) + + format_instructions = output_parser.get_format_instructions() + + prompt = PromptTemplate( + template='The user is asking a question about a video clip, ' + 'and you need to determine the type of question.\n{' + 'format_instructions}\n{question}', + input_variables=['question'], + partial_variables={'format_instructions': format_instructions}) + model = OpenAI(temperature=0, model_name='gpt-3.5-turbo-16k') + _input = prompt.format_prompt(question=question) + output = model(_input.to_string()) + res = output_parser.parse(output) + return res + + def summarize(self): + # summary_chain = + # load_summarize_chain(self.llm, chain_type="map_reduce") + # summarize_document_chain = + # AnalyzeDocumentChain(combine_docs_chain=summary_chain) + # res = summarize_document_chain.run(self.description) + prompt = PromptTemplate( + input_variables=['description'], + template='The following is a textual description of a video. ' + 'Please summarize what happened in the video into one ' + 'paragraph in Chinese.:\n {description}?', + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + res = chain.run(self.description) + print(res) + return res + + def think(self, question): + response_schemas = [ + ResponseSchema( + name='step1', + description='Identify the keywords that users ask questions ' + 'about.'), + ResponseSchema( + name='step2', + description='Search for descriptions related to keywords in ' + 'the description. '), + ResponseSchema( + name='step3', + description='Reasoning answers based on the content in the ' + 'description found. '), + ResponseSchema( + name='answer', description='Final answer to this question.') + ] + output_parser = StructuredOutputParser.from_response_schemas( + response_schemas) + format_instructions = output_parser.get_format_instructions() + prompt = PromptTemplate( + template=''' + {description}\n + Guess the possible answer to the question instead of answering + unknown.You can think through the following process: 1. Identify + the keywords that users ask questions about; 2. Search for + descriptions related to keywords in the description; 3. Reasoning + answers based on the content in the description. \n + Please provide your thinking process and answer.\n + {format_instructions}\n + The question is:{question} + ''', + input_variables=['description', 'question'], + partial_variables={'format_instructions': format_instructions}) + model = OpenAI(temperature=1, model_name='gpt-3.5-turbo-16k') + _input = prompt.format_prompt( + question=question, description=self.description) + output = model(_input.to_string()) + res = output_parser.parse(output) + print('Revised answer:', res['answer']) + return res + + def get_answer_type(self, question, answer): + is_solved_schema = ResponseSchema( + name='is_solved', + description='Did the provided answer address the question? 
\ + Answer true if yes,\ + false if not or unknown, for example, the answer ' + "is' 没有提到 'or' 无法确定 '.") + response_schemas = [is_solved_schema] + output_parser = StructuredOutputParser.from_response_schemas( + response_schemas) + format_instructions = output_parser.get_format_instructions() + prompt = PromptTemplate( + template=''' + You need to determine if the answer below answers the question.\n + {format_instructions}\n + The question is: {question}\n + The answer is: {answer}\n + ''', + input_variables=['question', 'answer'], + partial_variables={'format_instructions': format_instructions}) + model = OpenAI(temperature=0, model_name='gpt-3.5-turbo-16k') + _input = prompt.format_prompt(question=question, answer=answer) + output = model(_input.to_string()) + res = output_parser.parse(output) + return res + + +def num_tokens_from_string(string: str) -> int: + """Returns the number of tokens in a text string.""" + encoding = tiktoken.encoding_for_model('gpt-3.5-turbo-16k') + num_tokens = len(encoding.encode(string)) + return num_tokens + + +if __name__ == '__main__': + import pdb + + pdb.set_trace() diff --git a/projects/videochat/configs.ini b/projects/videochat/configs.ini new file mode 100644 index 0000000000..6749892e33 --- /dev/null +++ b/projects/videochat/configs.ini @@ -0,0 +1,15 @@ +[Arguments] +device = 0 +videos_path = /mnt/data/test_video +output_path = /mnt/videochat/output +images_path = /mnt/videochat/images +evaluate_path = /mnt/videochat/output/20230915014800 +appid = None +secret_key = None +segment_length = 5 +remarks = remarks +llm = gpt-3.5-turbo-16k +predict = True +evaluate = True +mode = normal +qa_mode = normal diff --git a/projects/videochat/configs/med_config.json b/projects/videochat/configs/med_config.json new file mode 100644 index 0000000000..391d5ca7e9 --- /dev/null +++ b/projects/videochat/configs/med_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true +} diff --git a/projects/videochat/configs/q2l_config.json b/projects/videochat/configs/q2l_config.json new file mode 100644 index 0000000000..1b7443c8f5 --- /dev/null +++ b/projects/videochat/configs/q2l_config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 4, + "num_hidden_layers": 2, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30522, + "encoder_width": 768, + "add_cross_attention": true, + "add_tag_cross_attention": false + } diff --git a/projects/videochat/configs/swin/config_swinB_224.json b/projects/videochat/configs/swin/config_swinB_224.json new file mode 100644 index 0000000000..f85acd9451 --- /dev/null +++ b/projects/videochat/configs/swin/config_swinB_224.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", + "vision_width": 1024, + "image_res": 224, + "window_size": 7, + "embed_dim": 128, + "depths": 
[ 2, 2, 18, 2 ], + "num_heads": [ 4, 8, 16, 32 ] + } diff --git a/projects/videochat/configs/swin/config_swinB_384.json b/projects/videochat/configs/swin/config_swinB_384.json new file mode 100644 index 0000000000..82a68889cb --- /dev/null +++ b/projects/videochat/configs/swin/config_swinB_384.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", + "vision_width": 1024, + "image_res": 384, + "window_size": 12, + "embed_dim": 128, + "depths": [ 2, 2, 18, 2 ], + "num_heads": [ 4, 8, 16, 32 ] + } diff --git a/projects/videochat/configs/swin/config_swinB_480.json b/projects/videochat/configs/swin/config_swinB_480.json new file mode 100644 index 0000000000..f1adb58950 --- /dev/null +++ b/projects/videochat/configs/swin/config_swinB_480.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", + "vision_width": 1024, + "image_res": 480, + "window_size": 15, + "embed_dim": 128, + "depths": [ 2, 2, 18, 2 ], + "num_heads": [ 4, 8, 16, 32 ] +} diff --git a/projects/videochat/configs/swin/config_swinB_576.json b/projects/videochat/configs/swin/config_swinB_576.json new file mode 100644 index 0000000000..f6220dac7e --- /dev/null +++ b/projects/videochat/configs/swin/config_swinB_576.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", + "vision_width": 1024, + "image_res": 576, + "window_size": 18, + "embed_dim": 128, + "depths": [ 2, 2, 18, 2 ], + "num_heads": [ 4, 8, 16, 32 ] +} diff --git a/projects/videochat/configs/swin/config_swinB_608.json b/projects/videochat/configs/swin/config_swinB_608.json new file mode 100644 index 0000000000..453e704338 --- /dev/null +++ b/projects/videochat/configs/swin/config_swinB_608.json @@ -0,0 +1,9 @@ +{ + "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", + "vision_width": 1024, + "image_res": 608, + "window_size": 19, + "embed_dim": 128, + "depths": [ 2, 2, 18, 2 ], + "num_heads": [ 4, 8, 16, 32 ] +} diff --git a/projects/videochat/configs/tag2text_caption.yaml b/projects/videochat/configs/tag2text_caption.yaml new file mode 100644 index 0000000000..51dc658127 --- /dev/null +++ b/projects/videochat/configs/tag2text_caption.yaml @@ -0,0 +1,32 @@ +image_root: '/home/notebook/data/group/projects/tagging/caption/datasets/public/coco/' + +ann_root: 'dataset/caption_dataset' +coco_gt_root: 'dataset/caption_dataset' + +pretrained: '/home/notebook/code/personal/S9049611/BLIP/output/pretrain_caption_tagtotext_v2_bert_asl' + +# size of vit model; base or large +vit: 'swin_b' +vit_grad_ckpt: False +vit_ckpt_layer: 0 + +batch_size: 35 +init_lr: 5e-6 + +image_size: 384 + +# generation configs +max_length: 20 +min_length: 5 +num_beams: 3 +prompt: 'a picture of ' + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 10 + +text_pretrain: 'bert' + +class_num: 3429 +threshold: 0.7 diff --git a/projects/videochat/eval.py b/projects/videochat/eval.py new file mode 100644 index 0000000000..9d56bbb7e9 --- /dev/null +++ b/projects/videochat/eval.py @@ -0,0 +1,62 @@ +import json +import os +import time + +import openai + + +def get_answer(question, answer_list): + try: + answer = '\n'.join(answer_list) + query = f'''问题:"{question}"\n回答列表:\n"{answer}"''' + response = openai.ChatCompletion.create( + model='gpt-4', + messages=[{ + 'role': + 'system', + 'content': + '假设有一个视频,用户从其中提取了若干帧图像,并对每一帧图像都提问了相同的问' + '题,并有一个视觉模型做出了相应的回答。但实际上该问题只与其中的部分图像相关' + ',因此你需要从其中筛选并整合出比较合理的回答。以如下json格式回答' + "{'answer':string, 'reason':string}" + }, { + 'role': 'user', + 'content': query, + }], 
+ temperature=1, + max_tokens=1024, + top_p=1, + frequency_penalty=0, + presence_penalty=0) + answer = response.choices[0].message.content + answer = answer.strip('{}').replace("\"", "'") + result = dict() + result['answer'] = answer.split( + ', \'reason\':')[0][len('\'answer\': '):].strip('\'') + print(answer) + except openai.error.RateLimitError as e: + print(e) + time.sleep(60) # Wait for 1 minute before retrying + return get_answer(question, answer_list) + return result['answer'] + + +def main(): + main_path = '/mnt/data.coronaryct.1/ZhuYichen/videochat/output' \ + '/20230915014800 ' + for folder in os.listdir(main_path): + file_path = os.path.join( + os.path.join(main_path, folder), 'output.json') + new_file_path = os.path.join( + os.path.join(main_path, folder), 'output_final.json') + with open(file_path) as file: + data = json.load(file) + for i, qa in enumerate(data['qa']): + answer = get_answer(qa['q'], qa['predict']) + data['qa'][i]['predict'] = answer + with open(new_file_path, 'w') as json_file: + json.dump(data, json_file, indent=4, ensure_ascii=False) + + +if __name__ == '__main__': + main() diff --git a/projects/videochat/models/__init__.py b/projects/videochat/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/__init__.py b/projects/videochat/models/centernet/__init__.py new file mode 100644 index 0000000000..2c743bfb4b --- /dev/null +++ b/projects/videochat/models/centernet/__init__.py @@ -0,0 +1,8 @@ +from .modeling.dense_heads.centernet import CenterNet +from .modeling.meta_arch.centernet_detector import CenterNetDetector +from .modeling.roi_heads.custom_roi_heads import (CustomCascadeROIHeads, + CustomROIHeads) + + +def main(): + print(CenterNet, CenterNetDetector, CustomCascadeROIHeads, CustomROIHeads) diff --git a/projects/videochat/models/centernet/config.py b/projects/videochat/models/centernet/config.py new file mode 100644 index 0000000000..040cc40790 --- /dev/null +++ b/projects/videochat/models/centernet/config.py @@ -0,0 +1,89 @@ +from detectron2.config import CfgNode as CN + + +def add_centernet_config(cfg): + _C = cfg + + _C.MODEL.CENTERNET = CN() + _C.MODEL.CENTERNET.NUM_CLASSES = 80 + _C.MODEL.CENTERNET.IN_FEATURES = ['p3', 'p4', 'p5', 'p6', 'p7'] + _C.MODEL.CENTERNET.FPN_STRIDES = [8, 16, 32, 64, 128] + _C.MODEL.CENTERNET.PRIOR_PROB = 0.01 + _C.MODEL.CENTERNET.INFERENCE_TH = 0.05 + _C.MODEL.CENTERNET.CENTER_NMS = False + _C.MODEL.CENTERNET.NMS_TH_TRAIN = 0.6 + _C.MODEL.CENTERNET.NMS_TH_TEST = 0.6 + _C.MODEL.CENTERNET.PRE_NMS_TOPK_TRAIN = 1000 + _C.MODEL.CENTERNET.POST_NMS_TOPK_TRAIN = 100 + _C.MODEL.CENTERNET.PRE_NMS_TOPK_TEST = 1000 + _C.MODEL.CENTERNET.POST_NMS_TOPK_TEST = 100 + _C.MODEL.CENTERNET.NORM = 'GN' + _C.MODEL.CENTERNET.USE_DEFORMABLE = False + _C.MODEL.CENTERNET.NUM_CLS_CONVS = 4 + _C.MODEL.CENTERNET.NUM_BOX_CONVS = 4 + _C.MODEL.CENTERNET.NUM_SHARE_CONVS = 0 + _C.MODEL.CENTERNET.LOC_LOSS_TYPE = 'giou' + _C.MODEL.CENTERNET.SIGMOID_CLAMP = 1e-4 + _C.MODEL.CENTERNET.HM_MIN_OVERLAP = 0.8 + _C.MODEL.CENTERNET.MIN_RADIUS = 4 + _C.MODEL.CENTERNET.SOI = [[0, 80], [64, 160], [128, 320], [256, 640], + [512, 10000000]] + _C.MODEL.CENTERNET.POS_WEIGHT = 1. + _C.MODEL.CENTERNET.NEG_WEIGHT = 1. + _C.MODEL.CENTERNET.REG_WEIGHT = 2. 
+ _C.MODEL.CENTERNET.HM_FOCAL_BETA = 4 + _C.MODEL.CENTERNET.HM_FOCAL_ALPHA = 0.25 + _C.MODEL.CENTERNET.LOSS_GAMMA = 2.0 + _C.MODEL.CENTERNET.WITH_AGN_HM = False + _C.MODEL.CENTERNET.ONLY_PROPOSAL = False + _C.MODEL.CENTERNET.AS_PROPOSAL = False + _C.MODEL.CENTERNET.IGNORE_HIGH_FP = -1. + _C.MODEL.CENTERNET.MORE_POS = False + _C.MODEL.CENTERNET.MORE_POS_THRESH = 0.2 + _C.MODEL.CENTERNET.MORE_POS_TOPK = 9 + _C.MODEL.CENTERNET.NOT_NORM_REG = True + _C.MODEL.CENTERNET.NOT_NMS = False + _C.MODEL.CENTERNET.NO_REDUCE = False + + _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False + _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01 + _C.MODEL.ROI_BOX_HEAD.USE_EQL_LOSS = False + _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = \ + 'datasets/lvis/lvis_v1_train_cat_info.json' + _C.MODEL.ROI_BOX_HEAD.EQL_FREQ_CAT = 200 + _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False + _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50 + _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5 + _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False + + _C.MODEL.BIFPN = CN() + _C.MODEL.BIFPN.NUM_LEVELS = 5 + _C.MODEL.BIFPN.NUM_BIFPN = 6 + _C.MODEL.BIFPN.NORM = 'GN' + _C.MODEL.BIFPN.OUT_CHANNELS = 160 + _C.MODEL.BIFPN.SEPARABLE_CONV = False + + _C.MODEL.DLA = CN() + _C.MODEL.DLA.OUT_FEATURES = ['dla2'] + _C.MODEL.DLA.USE_DLA_UP = True + _C.MODEL.DLA.NUM_LAYERS = 34 + _C.MODEL.DLA.MS_OUTPUT = False + _C.MODEL.DLA.NORM = 'BN' + _C.MODEL.DLA.DLAUP_IN_FEATURES = ['dla3', 'dla4', 'dla5'] + _C.MODEL.DLA.DLAUP_NODE = 'conv' + + _C.SOLVER.RESET_ITER = False + _C.SOLVER.TRAIN_ITER = -1 + + _C.INPUT.CUSTOM_AUG = '' + _C.INPUT.TRAIN_SIZE = 640 + _C.INPUT.TEST_SIZE = 640 + _C.INPUT.SCALE_RANGE = (0.1, 2.) + # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE + _C.INPUT.TEST_INPUT_TYPE = 'default' + + _C.DEBUG = False + _C.SAVE_DEBUG = False + _C.SAVE_PTH = False + _C.VIS_THRESH = 0.3 + _C.DEBUG_SHOW_NAME = False diff --git a/projects/videochat/models/centernet/data/__init__.py b/projects/videochat/models/centernet/data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/data/custom_build_augmentation.py b/projects/videochat/models/centernet/data/custom_build_augmentation.py new file mode 100644 index 0000000000..309474b9c8 --- /dev/null +++ b/projects/videochat/models/centernet/data/custom_build_augmentation.py @@ -0,0 +1,42 @@ +from detectron2.data import transforms as T + +from .transforms.custom_augmentation_impl import EfficientDetResizeCrop + + +def build_custom_augmentation(cfg, is_train): + """Create a list of default :class:`Augmentation` from config. Now it + includes resizing and flipping. + + Returns: + list[Augmentation] + """ + if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge': + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = 'choice' + augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop': + if is_train: + scale = cfg.INPUT.SCALE_RANGE + size = cfg.INPUT.TRAIN_SIZE + else: + scale = (1, 1) + size = cfg.INPUT.TEST_SIZE + augmentation = [EfficientDetResizeCrop(size, scale)] + else: + assert 0, cfg.INPUT.CUSTOM_AUG + + if is_train: + augmentation.append(T.RandomFlip()) + return augmentation + + +build_custom_transform_gen = build_custom_augmentation +""" +Alias for backward-compatibility. 
+""" diff --git a/projects/videochat/models/centernet/data/custom_dataset_dataloader.py b/projects/videochat/models/centernet/data/custom_dataset_dataloader.py new file mode 100644 index 0000000000..43d5549a40 --- /dev/null +++ b/projects/videochat/models/centernet/data/custom_dataset_dataloader.py @@ -0,0 +1,234 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import itertools +import logging +from collections import defaultdict +from typing import Optional + +import torch +import torch.utils.data +from detectron2.data.build import (build_batch_data_loader, + check_metadata_consistency, + filter_images_with_few_keypoints, + filter_images_with_only_crowd_annotations, + get_detection_dataset_dicts, + print_instances_class_histogram) +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog +from detectron2.data.common import DatasetFromList, MapDataset +from detectron2.data.samplers import (RepeatFactorTrainingSampler, + TrainingSampler) +from detectron2.utils import comm +from torch.utils.data.sampler import Sampler + +# from .custom_build_augmentation import build_custom_augmentation + + +def build_custom_train_loader(cfg, mapper=None): + """Modified from detectron2.data.build.build_custom_train_loader, but + supports different samplers.""" + source_aware = cfg.DATALOADER.SOURCE_AWARE + if source_aware: + dataset_dicts = get_detection_dataset_dicts_with_source( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN + if cfg.MODEL.LOAD_PROPOSALS else None, + ) + sizes = [0 for _ in range(len(cfg.DATASETS.TRAIN))] + for d in dataset_dicts: + sizes[d['dataset_source']] += 1 + print('dataset sizes', sizes) + else: + dataset_dicts = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN + if cfg.MODEL.LOAD_PROPOSALS else None, + ) + dataset = DatasetFromList(dataset_dicts, copy=False) + + if mapper is None: + assert 0 + # mapper = DatasetMapper(cfg, True) + dataset = MapDataset(dataset, mapper) + + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + logger = logging.getLogger(__name__) + logger.info('Using training sampler {}'.format(sampler_name)) + if sampler_name == 'TrainingSampler': + sampler = TrainingSampler(len(dataset)) + elif sampler_name == 'MultiDatasetSampler': + assert source_aware + sampler = MultiDatasetSampler(cfg, sizes, dataset_dicts) + elif sampler_name == 'RepeatFactorTrainingSampler': + repeat_factors = \ + RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD) + sampler = RepeatFactorTrainingSampler(repeat_factors) + elif sampler_name == 'ClassAwareSampler': + sampler = ClassAwareSampler(dataset_dicts) + else: + raise ValueError('Unknown training sampler: {}'.format(sampler_name)) + + return build_batch_data_loader( + dataset, + sampler, + cfg.SOLVER.IMS_PER_BATCH, + aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING, + num_workers=cfg.DATALOADER.NUM_WORKERS, + ) + + +class ClassAwareSampler(Sampler): + + def __init__(self, dataset_dicts, seed: Optional[int] = None): + """ + Args: size (int): the total number of data of the underlying dataset + to sample from seed (int): the initial seed of the 
shuffle. Must be + the same across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + self._size = len(dataset_dicts) + assert self._size > 0 + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + self.weights = self._get_class_balance_factor(dataset_dicts) + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, + self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + ids = torch.multinomial( + self.weights, self._size, generator=g, replacement=True) + yield from ids + + def _get_class_balance_factor(self, dataset_dicts, ll=1.): + # 1. For each category c, compute the fraction of images that + # contain it: f(c) + ret = [] + category_freq = defaultdict(int) + for dataset_dict in dataset_dicts: # For each image (without repeats) + cat_ids = { + ann['category_id'] + for ann in dataset_dict['annotations'] + } + for cat_id in cat_ids: + category_freq[cat_id] += 1 + for i, dataset_dict in enumerate(dataset_dicts): + cat_ids = { + ann['category_id'] + for ann in dataset_dict['annotations'] + } + ret.append( + sum([1. / (category_freq[cat_id]**ll) for cat_id in cat_ids])) + return torch.tensor(ret).float() + + +def get_detection_dataset_dicts_with_source(dataset_names, + filter_empty=True, + min_keypoints=0, + proposal_files=None): + assert len(dataset_names) + dataset_dicts = [ + DatasetCatalog.get(dataset_name) for dataset_name in dataset_names + ] + for dataset_name, dicts in zip(dataset_names, dataset_dicts): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + + for source_id, (dataset_name, + dicts) in enumerate(zip(dataset_names, dataset_dicts)): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + for d in dicts: + d['dataset_source'] = source_id + + if 'annotations' in dicts[0]: + try: + class_names = MetadataCatalog.get(dataset_name).thing_classes + check_metadata_consistency('thing_classes', dataset_name) + print_instances_class_histogram(dicts, class_names) + except AttributeError: + # class names are not available for this dataset + pass + + assert proposal_files is None + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = 'annotations' in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations( + dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints( + dataset_dicts, min_keypoints) + + return dataset_dicts + + +class MultiDatasetSampler(Sampler): + + def __init__(self, cfg, sizes, dataset_dicts, seed: Optional[int] = None): + """ + Args: size (int): the total number of data of the underlying dataset + to sample from seed (int): the initial seed of the shuffle. Must be + the same across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). 
+ """ + self.sizes = sizes + dataset_ratio = cfg.DATALOADER.DATASET_RATIO + self._batch_size = cfg.SOLVER.IMS_PER_BATCH + assert len(dataset_ratio) == len(sizes), \ + 'length of dataset ratio ' \ + '{} should be equal to number if dataset {}'.format( + len(dataset_ratio), len(sizes) + ) + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + self._ims_per_gpu = self._batch_size // self._world_size + self.dataset_ids = torch.tensor( + [d['dataset_source'] for d in dataset_dicts], dtype=torch.long) + + dataset_weight = \ + [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) + for i, (r, s) in enumerate(zip(dataset_ratio, sizes))] + dataset_weight = torch.cat(dataset_weight) + self.weights = dataset_weight + self.sample_epoch_size = len(self.weights) + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, + self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + ids = torch.multinomial( + self.weights, + self.sample_epoch_size, + generator=g, + replacement=True) + nums = [(self.dataset_ids[ids] == i).sum().int().item() + for i in range(len(self.sizes))] + print('_rank, len, nums', self._rank, len(ids), nums, flush=True) + # print('_rank, len, nums, self.dataset_ids[ids[:10]], ', + # self._rank, len(ids), nums, self.dataset_ids[ids[:10]], + # flush=True) + yield from ids diff --git a/projects/videochat/models/centernet/data/datasets/__init__.py b/projects/videochat/models/centernet/data/datasets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/data/datasets/coco.py b/projects/videochat/models/centernet/data/datasets/coco.py new file mode 100644 index 0000000000..98a556baad --- /dev/null +++ b/projects/videochat/models/centernet/data/datasets/coco.py @@ -0,0 +1,55 @@ +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.builtin_meta import _get_builtin_metadata +from detectron2.data.datasets.coco import load_coco_json +from detectron2.data.datasets.register_coco import register_coco_instances + + +def register_distill_coco_instances(name, metadata, json_file, image_root): + """add extra_annotation_keys.""" + assert isinstance(name, str), name + assert isinstance(json_file, (str, os.PathLike)), json_file + assert isinstance(image_root, (str, os.PathLike)), image_root + # 1. register a function which returns dicts + DatasetCatalog.register( + name, lambda: load_coco_json( + json_file, image_root, name, extra_annotation_keys=['score'])) + + # 2. 
Optionally, add metadata about this dataset, + # since they might be useful in evaluation, visualization or logging + MetadataCatalog.get(name).set( + json_file=json_file, + image_root=image_root, + evaluator_type='coco', + **metadata) + + +_PREDEFINED_SPLITS_COCO = { + 'coco_2017_unlabeled': + ('coco/unlabeled2017', 'coco/annotations/image_info_unlabeled2017.json'), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_COCO.items(): + register_coco_instances( + key, + _get_builtin_metadata('coco'), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) + +_PREDEFINED_SPLITS_DISTILL_COCO = { + 'coco_un_yolov4_55_0.5': + ('coco/unlabeled2017', + 'coco/annotations/yolov4_cocounlabeled_55_ann0.5.json'), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_DISTILL_COCO.items(): + register_distill_coco_instances( + key, + _get_builtin_metadata('coco'), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) diff --git a/projects/videochat/models/centernet/data/datasets/nuimages.py b/projects/videochat/models/centernet/data/datasets/nuimages.py new file mode 100644 index 0000000000..1d40aa8821 --- /dev/null +++ b/projects/videochat/models/centernet/data/datasets/nuimages.py @@ -0,0 +1,75 @@ +import os + +from detectron2.data.datasets.register_coco import register_coco_instances + +categories = [ + { + 'id': 0, + 'name': 'car' + }, + { + 'id': 1, + 'name': 'truck' + }, + { + 'id': 2, + 'name': 'trailer' + }, + { + 'id': 3, + 'name': 'bus' + }, + { + 'id': 4, + 'name': 'construction_vehicle' + }, + { + 'id': 5, + 'name': 'bicycle' + }, + { + 'id': 6, + 'name': 'motorcycle' + }, + { + 'id': 7, + 'name': 'pedestrian' + }, + { + 'id': 8, + 'name': 'traffic_cone' + }, + { + 'id': 9, + 'name': 'barrier' + }, +] + + +def _get_builtin_metadata(): + id_to_name = {x['id']: x['name'] for x in categories} + thing_dataset_id_to_contiguous_id = {i: i for i in range(len(categories))} + thing_classes = [id_to_name[k] for k in sorted(id_to_name)] + return { + 'thing_dataset_id_to_contiguous_id': thing_dataset_id_to_contiguous_id, + 'thing_classes': thing_classes + } + + +_PREDEFINED_SPLITS = { + 'nuimages_train': + ('nuimages', 'nuimages/annotations/nuimages_v1.0-train.json'), + 'nuimages_val': + ('nuimages', 'nuimages/annotations/nuimages_v1.0-val.json'), + 'nuimages_mini': + ('nuimages', 'nuimages/annotations/nuimages_v1.0-mini.json'), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS.items(): + register_coco_instances( + key, + _get_builtin_metadata(), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) diff --git a/projects/videochat/models/centernet/data/datasets/objects365.py b/projects/videochat/models/centernet/data/datasets/objects365.py new file mode 100644 index 0000000000..d248157652 --- /dev/null +++ b/projects/videochat/models/centernet/data/datasets/objects365.py @@ -0,0 +1,1496 @@ +import os + +from detectron2.data.datasets.register_coco import register_coco_instances + +categories_v1 = [ + { + 'id': 164, + 'name': 'cutting/chopping board' + }, + { + 'id': 49, + 'name': 'tie' + }, + { + 'id': 306, + 'name': 'crosswalk sign' + }, + { + 'id': 145, + 'name': 'gun' + }, + { + 'id': 14, + 'name': 'street lights' + }, + { + 'id': 223, + 'name': 'bar soap' + }, + { + 'id': 74, + 'name': 'wild bird' + }, + { + 'id': 219, + 'name': 'ice cream' + }, + { + 'id': 37, + 
'name': 'stool' + }, + { + 'id': 25, + 'name': 'storage box' + }, + { + 'id': 153, + 'name': 'giraffe' + }, + { + 'id': 52, + 'name': 'pen/pencil' + }, + { + 'id': 61, + 'name': 'high heels' + }, + { + 'id': 340, + 'name': 'mangosteen' + }, + { + 'id': 22, + 'name': 'bracelet' + }, + { + 'id': 155, + 'name': 'piano' + }, + { + 'id': 162, + 'name': 'vent' + }, + { + 'id': 75, + 'name': 'laptop' + }, + { + 'id': 236, + 'name': 'toaster' + }, + { + 'id': 231, + 'name': 'fire truck' + }, + { + 'id': 42, + 'name': 'basket' + }, + { + 'id': 150, + 'name': 'zebra' + }, + { + 'id': 124, + 'name': 'head phone' + }, + { + 'id': 90, + 'name': 'sheep' + }, + { + 'id': 322, + 'name': 'steak' + }, + { + 'id': 39, + 'name': 'couch' + }, + { + 'id': 209, + 'name': 'toothbrush' + }, + { + 'id': 59, + 'name': 'bicycle' + }, + { + 'id': 336, + 'name': 'red cabbage' + }, + { + 'id': 228, + 'name': 'golf ball' + }, + { + 'id': 120, + 'name': 'tomato' + }, + { + 'id': 132, + 'name': 'computer box' + }, + { + 'id': 8, + 'name': 'cup' + }, + { + 'id': 183, + 'name': 'basketball' + }, + { + 'id': 298, + 'name': 'butterfly' + }, + { + 'id': 250, + 'name': 'garlic' + }, + { + 'id': 12, + 'name': 'desk' + }, + { + 'id': 141, + 'name': 'microwave' + }, + { + 'id': 171, + 'name': 'strawberry' + }, + { + 'id': 200, + 'name': 'kettle' + }, + { + 'id': 63, + 'name': 'van' + }, + { + 'id': 300, + 'name': 'cheese' + }, + { + 'id': 215, + 'name': 'marker' + }, + { + 'id': 100, + 'name': 'blackboard/whiteboard' + }, + { + 'id': 186, + 'name': 'printer' + }, + { + 'id': 333, + 'name': 'bread/bun' + }, + { + 'id': 243, + 'name': 'penguin' + }, + { + 'id': 364, + 'name': 'iron' + }, + { + 'id': 180, + 'name': 'ladder' + }, + { + 'id': 34, + 'name': 'flag' + }, + { + 'id': 78, + 'name': 'cell phone' + }, + { + 'id': 97, + 'name': 'fan' + }, + { + 'id': 224, + 'name': 'scale' + }, + { + 'id': 151, + 'name': 'duck' + }, + { + 'id': 319, + 'name': 'flute' + }, + { + 'id': 156, + 'name': 'stop sign' + }, + { + 'id': 290, + 'name': 'rickshaw' + }, + { + 'id': 128, + 'name': 'sailboat' + }, + { + 'id': 165, + 'name': 'tennis racket' + }, + { + 'id': 241, + 'name': 'cigar' + }, + { + 'id': 101, + 'name': 'balloon' + }, + { + 'id': 308, + 'name': 'hair drier' + }, + { + 'id': 167, + 'name': 'skating and skiing shoes' + }, + { + 'id': 237, + 'name': 'helicopter' + }, + { + 'id': 65, + 'name': 'sink' + }, + { + 'id': 129, + 'name': 'tangerine' + }, + { + 'id': 330, + 'name': 'crab' + }, + { + 'id': 320, + 'name': 'measuring cup' + }, + { + 'id': 260, + 'name': 'fishing rod' + }, + { + 'id': 346, + 'name': 'saw' + }, + { + 'id': 216, + 'name': 'ship' + }, + { + 'id': 46, + 'name': 'coffee table' + }, + { + 'id': 194, + 'name': 'facial mask' + }, + { + 'id': 281, + 'name': 'stapler' + }, + { + 'id': 118, + 'name': 'refrigerator' + }, + { + 'id': 40, + 'name': 'belt' + }, + { + 'id': 349, + 'name': 'starfish' + }, + { + 'id': 87, + 'name': 'hanger' + }, + { + 'id': 116, + 'name': 'baseball glove' + }, + { + 'id': 261, + 'name': 'cherry' + }, + { + 'id': 334, + 'name': 'baozi' + }, + { + 'id': 267, + 'name': 'screwdriver' + }, + { + 'id': 158, + 'name': 'converter' + }, + { + 'id': 335, + 'name': 'lion' + }, + { + 'id': 170, + 'name': 'baseball' + }, + { + 'id': 111, + 'name': 'skis' + }, + { + 'id': 136, + 'name': 'broccoli' + }, + { + 'id': 342, + 'name': 'eraser' + }, + { + 'id': 337, + 'name': 'polar bear' + }, + { + 'id': 139, + 'name': 'shovel' + }, + { + 'id': 193, + 'name': 'extension cord' + }, + { + 'id': 284, + 'name': 'goldfish' + 
}, + { + 'id': 174, + 'name': 'pepper' + }, + { + 'id': 138, + 'name': 'stroller' + }, + { + 'id': 328, + 'name': 'yak' + }, + { + 'id': 83, + 'name': 'clock' + }, + { + 'id': 235, + 'name': 'tricycle' + }, + { + 'id': 248, + 'name': 'parking meter' + }, + { + 'id': 274, + 'name': 'trophy' + }, + { + 'id': 324, + 'name': 'binoculars' + }, + { + 'id': 51, + 'name': 'traffic light' + }, + { + 'id': 314, + 'name': 'donkey' + }, + { + 'id': 45, + 'name': 'barrel/bucket' + }, + { + 'id': 292, + 'name': 'pomegranate' + }, + { + 'id': 13, + 'name': 'handbag' + }, + { + 'id': 262, + 'name': 'tablet' + }, + { + 'id': 68, + 'name': 'apple' + }, + { + 'id': 226, + 'name': 'cabbage' + }, + { + 'id': 23, + 'name': 'flower' + }, + { + 'id': 58, + 'name': 'faucet' + }, + { + 'id': 206, + 'name': 'tong' + }, + { + 'id': 291, + 'name': 'trombone' + }, + { + 'id': 160, + 'name': 'carrot' + }, + { + 'id': 172, + 'name': 'bow tie' + }, + { + 'id': 122, + 'name': 'tent' + }, + { + 'id': 163, + 'name': 'cookies' + }, + { + 'id': 115, + 'name': 'remote' + }, + { + 'id': 175, + 'name': 'coffee machine' + }, + { + 'id': 238, + 'name': 'green beans' + }, + { + 'id': 233, + 'name': 'cello' + }, + { + 'id': 28, + 'name': 'wine glass' + }, + { + 'id': 295, + 'name': 'mushroom' + }, + { + 'id': 344, + 'name': 'scallop' + }, + { + 'id': 125, + 'name': 'lantern' + }, + { + 'id': 123, + 'name': 'shampoo/shower gel' + }, + { + 'id': 285, + 'name': 'meat balls' + }, + { + 'id': 266, + 'name': 'key' + }, + { + 'id': 296, + 'name': 'calculator' + }, + { + 'id': 168, + 'name': 'scissors' + }, + { + 'id': 103, + 'name': 'cymbal' + }, + { + 'id': 6, + 'name': 'bottle' + }, + { + 'id': 264, + 'name': 'nuts' + }, + { + 'id': 234, + 'name': 'notepaper' + }, + { + 'id': 211, + 'name': 'mango' + }, + { + 'id': 287, + 'name': 'toothpaste' + }, + { + 'id': 196, + 'name': 'chopsticks' + }, + { + 'id': 140, + 'name': 'baseball bat' + }, + { + 'id': 244, + 'name': 'hurdle' + }, + { + 'id': 195, + 'name': 'tennis ball' + }, + { + 'id': 144, + 'name': 'surveillance camera' + }, + { + 'id': 271, + 'name': 'volleyball' + }, + { + 'id': 94, + 'name': 'keyboard' + }, + { + 'id': 339, + 'name': 'seal' + }, + { + 'id': 11, + 'name': 'picture/frame' + }, + { + 'id': 348, + 'name': 'okra' + }, + { + 'id': 191, + 'name': 'sausage' + }, + { + 'id': 166, + 'name': 'candy' + }, + { + 'id': 62, + 'name': 'ring' + }, + { + 'id': 311, + 'name': 'dolphin' + }, + { + 'id': 273, + 'name': 'eggplant' + }, + { + 'id': 84, + 'name': 'drum' + }, + { + 'id': 143, + 'name': 'surfboard' + }, + { + 'id': 288, + 'name': 'antelope' + }, + { + 'id': 204, + 'name': 'clutch' + }, + { + 'id': 207, + 'name': 'slide' + }, + { + 'id': 43, + 'name': 'towel/napkin' + }, + { + 'id': 352, + 'name': 'durian' + }, + { + 'id': 276, + 'name': 'board eraser' + }, + { + 'id': 315, + 'name': 'electric drill' + }, + { + 'id': 312, + 'name': 'sushi' + }, + { + 'id': 198, + 'name': 'pie' + }, + { + 'id': 106, + 'name': 'pickup truck' + }, + { + 'id': 176, + 'name': 'bathtub' + }, + { + 'id': 26, + 'name': 'vase' + }, + { + 'id': 133, + 'name': 'elephant' + }, + { + 'id': 256, + 'name': 'sandwich' + }, + { + 'id': 327, + 'name': 'noodles' + }, + { + 'id': 10, + 'name': 'glasses' + }, + { + 'id': 109, + 'name': 'airplane' + }, + { + 'id': 95, + 'name': 'tripod' + }, + { + 'id': 247, + 'name': 'CD' + }, + { + 'id': 121, + 'name': 'machinery vehicle' + }, + { + 'id': 365, + 'name': 'flashlight' + }, + { + 'id': 53, + 'name': 'microphone' + }, + { + 'id': 270, + 'name': 'pliers' + }, + { + 
'id': 362, + 'name': 'chainsaw' + }, + { + 'id': 259, + 'name': 'bear' + }, + { + 'id': 197, + 'name': 'electronic stove and gas stove' + }, + { + 'id': 89, + 'name': 'pot/pan' + }, + { + 'id': 220, + 'name': 'tape' + }, + { + 'id': 338, + 'name': 'lighter' + }, + { + 'id': 177, + 'name': 'snowboard' + }, + { + 'id': 214, + 'name': 'violin' + }, + { + 'id': 217, + 'name': 'chicken' + }, + { + 'id': 2, + 'name': 'sneakers' + }, + { + 'id': 161, + 'name': 'washing machine' + }, + { + 'id': 131, + 'name': 'kite' + }, + { + 'id': 354, + 'name': 'rabbit' + }, + { + 'id': 86, + 'name': 'bus' + }, + { + 'id': 275, + 'name': 'dates' + }, + { + 'id': 282, + 'name': 'camel' + }, + { + 'id': 88, + 'name': 'nightstand' + }, + { + 'id': 179, + 'name': 'grapes' + }, + { + 'id': 229, + 'name': 'pine apple' + }, + { + 'id': 56, + 'name': 'necklace' + }, + { + 'id': 18, + 'name': 'leather shoes' + }, + { + 'id': 358, + 'name': 'hoverboard' + }, + { + 'id': 345, + 'name': 'pencil case' + }, + { + 'id': 359, + 'name': 'pasta' + }, + { + 'id': 157, + 'name': 'radiator' + }, + { + 'id': 201, + 'name': 'hamburger' + }, + { + 'id': 268, + 'name': 'globe' + }, + { + 'id': 332, + 'name': 'barbell' + }, + { + 'id': 329, + 'name': 'mop' + }, + { + 'id': 252, + 'name': 'horn' + }, + { + 'id': 350, + 'name': 'eagle' + }, + { + 'id': 169, + 'name': 'folder' + }, + { + 'id': 137, + 'name': 'toilet' + }, + { + 'id': 5, + 'name': 'lamp' + }, + { + 'id': 27, + 'name': 'bench' + }, + { + 'id': 249, + 'name': 'swan' + }, + { + 'id': 76, + 'name': 'knife' + }, + { + 'id': 341, + 'name': 'comb' + }, + { + 'id': 64, + 'name': 'watch' + }, + { + 'id': 105, + 'name': 'telephone' + }, + { + 'id': 3, + 'name': 'chair' + }, + { + 'id': 33, + 'name': 'boat' + }, + { + 'id': 107, + 'name': 'orange' + }, + { + 'id': 60, + 'name': 'bread' + }, + { + 'id': 147, + 'name': 'cat' + }, + { + 'id': 135, + 'name': 'gas stove' + }, + { + 'id': 307, + 'name': 'papaya' + }, + { + 'id': 227, + 'name': 'router/modem' + }, + { + 'id': 357, + 'name': 'asparagus' + }, + { + 'id': 73, + 'name': 'motorcycle' + }, + { + 'id': 77, + 'name': 'traffic sign' + }, + { + 'id': 67, + 'name': 'fish' + }, + { + 'id': 326, + 'name': 'radish' + }, + { + 'id': 213, + 'name': 'egg' + }, + { + 'id': 203, + 'name': 'cucumber' + }, + { + 'id': 17, + 'name': 'helmet' + }, + { + 'id': 110, + 'name': 'luggage' + }, + { + 'id': 80, + 'name': 'truck' + }, + { + 'id': 199, + 'name': 'frisbee' + }, + { + 'id': 232, + 'name': 'peach' + }, + { + 'id': 1, + 'name': 'person' + }, + { + 'id': 29, + 'name': 'boots' + }, + { + 'id': 310, + 'name': 'chips' + }, + { + 'id': 142, + 'name': 'skateboard' + }, + { + 'id': 44, + 'name': 'slippers' + }, + { + 'id': 4, + 'name': 'hat' + }, + { + 'id': 178, + 'name': 'suitcase' + }, + { + 'id': 24, + 'name': 'tv' + }, + { + 'id': 119, + 'name': 'train' + }, + { + 'id': 82, + 'name': 'power outlet' + }, + { + 'id': 245, + 'name': 'swing' + }, + { + 'id': 15, + 'name': 'book' + }, + { + 'id': 294, + 'name': 'jellyfish' + }, + { + 'id': 192, + 'name': 'fire extinguisher' + }, + { + 'id': 212, + 'name': 'deer' + }, + { + 'id': 181, + 'name': 'pear' + }, + { + 'id': 347, + 'name': 'table tennis paddle' + }, + { + 'id': 113, + 'name': 'trolley' + }, + { + 'id': 91, + 'name': 'guitar' + }, + { + 'id': 202, + 'name': 'golf club' + }, + { + 'id': 221, + 'name': 'wheelchair' + }, + { + 'id': 254, + 'name': 'saxophone' + }, + { + 'id': 117, + 'name': 'paper towel' + }, + { + 'id': 303, + 'name': 'race car' + }, + { + 'id': 240, + 'name': 'carriage' + }, 
+ { + 'id': 246, + 'name': 'radio' + }, + { + 'id': 318, + 'name': 'parrot' + }, + { + 'id': 251, + 'name': 'french fries' + }, + { + 'id': 98, + 'name': 'dog' + }, + { + 'id': 112, + 'name': 'soccer' + }, + { + 'id': 355, + 'name': 'french horn' + }, + { + 'id': 79, + 'name': 'paddle' + }, + { + 'id': 283, + 'name': 'lettuce' + }, + { + 'id': 9, + 'name': 'car' + }, + { + 'id': 258, + 'name': 'kiwi fruit' + }, + { + 'id': 325, + 'name': 'llama' + }, + { + 'id': 187, + 'name': 'billiards' + }, + { + 'id': 210, + 'name': 'facial cleanser' + }, + { + 'id': 81, + 'name': 'cow' + }, + { + 'id': 331, + 'name': 'microscope' + }, + { + 'id': 148, + 'name': 'lemon' + }, + { + 'id': 302, + 'name': 'pomelo' + }, + { + 'id': 85, + 'name': 'fork' + }, + { + 'id': 154, + 'name': 'pumpkin' + }, + { + 'id': 289, + 'name': 'shrimp' + }, + { + 'id': 71, + 'name': 'teddy bear' + }, + { + 'id': 184, + 'name': 'potato' + }, + { + 'id': 102, + 'name': 'air conditioner' + }, + { + 'id': 208, + 'name': 'hot dog' + }, + { + 'id': 222, + 'name': 'plum' + }, + { + 'id': 316, + 'name': 'spring rolls' + }, + { + 'id': 230, + 'name': 'crane' + }, + { + 'id': 149, + 'name': 'liquid soap' + }, + { + 'id': 55, + 'name': 'canned' + }, + { + 'id': 35, + 'name': 'speaker' + }, + { + 'id': 108, + 'name': 'banana' + }, + { + 'id': 297, + 'name': 'treadmill' + }, + { + 'id': 99, + 'name': 'spoon' + }, + { + 'id': 104, + 'name': 'mouse' + }, + { + 'id': 182, + 'name': 'american football' + }, + { + 'id': 299, + 'name': 'egg tart' + }, + { + 'id': 127, + 'name': 'cleaning products' + }, + { + 'id': 313, + 'name': 'urinal' + }, + { + 'id': 286, + 'name': 'medal' + }, + { + 'id': 239, + 'name': 'brush' + }, + { + 'id': 96, + 'name': 'hockey' + }, + { + 'id': 279, + 'name': 'dumbbell' + }, + { + 'id': 32, + 'name': 'umbrella' + }, + { + 'id': 272, + 'name': 'hammer' + }, + { + 'id': 16, + 'name': 'plate' + }, + { + 'id': 21, + 'name': 'potted plant' + }, + { + 'id': 242, + 'name': 'earphone' + }, + { + 'id': 70, + 'name': 'candle' + }, + { + 'id': 185, + 'name': 'paint brush' + }, + { + 'id': 48, + 'name': 'toy' + }, + { + 'id': 130, + 'name': 'pizza' + }, + { + 'id': 255, + 'name': 'trumpet' + }, + { + 'id': 361, + 'name': 'hotair balloon' + }, + { + 'id': 188, + 'name': 'fire hydrant' + }, + { + 'id': 50, + 'name': 'bed' + }, + { + 'id': 253, + 'name': 'avocado' + }, + { + 'id': 293, + 'name': 'coconut' + }, + { + 'id': 257, + 'name': 'cue' + }, + { + 'id': 280, + 'name': 'hamimelon' + }, + { + 'id': 66, + 'name': 'horse' + }, + { + 'id': 173, + 'name': 'pigeon' + }, + { + 'id': 190, + 'name': 'projector' + }, + { + 'id': 69, + 'name': 'camera' + }, + { + 'id': 30, + 'name': 'bowl' + }, + { + 'id': 269, + 'name': 'broom' + }, + { + 'id': 343, + 'name': 'pitaya' + }, + { + 'id': 305, + 'name': 'tuba' + }, + { + 'id': 309, + 'name': 'green onion' + }, + { + 'id': 363, + 'name': 'lobster' + }, + { + 'id': 225, + 'name': 'watermelon' + }, + { + 'id': 47, + 'name': 'suv' + }, + { + 'id': 31, + 'name': 'dining table' + }, + { + 'id': 54, + 'name': 'sandals' + }, + { + 'id': 351, + 'name': 'monkey' + }, + { + 'id': 218, + 'name': 'onion' + }, + { + 'id': 36, + 'name': 'trash bin/can' + }, + { + 'id': 20, + 'name': 'glove' + }, + { + 'id': 277, + 'name': 'rice' + }, + { + 'id': 152, + 'name': 'sports car' + }, + { + 'id': 360, + 'name': 'target' + }, + { + 'id': 205, + 'name': 'blender' + }, + { + 'id': 19, + 'name': 'pillow' + }, + { + 'id': 72, + 'name': 'cake' + }, + { + 'id': 93, + 'name': 'tea pot' + }, + { + 'id': 353, + 'name': 
'game board' + }, + { + 'id': 38, + 'name': 'backpack' + }, + { + 'id': 356, + 'name': 'ambulance' + }, + { + 'id': 146, + 'name': 'life saver' + }, + { + 'id': 189, + 'name': 'goose' + }, + { + 'id': 278, + 'name': 'tape measure/ruler' + }, + { + 'id': 92, + 'name': 'traffic cone' + }, + { + 'id': 134, + 'name': 'toiletries' + }, + { + 'id': 114, + 'name': 'oven' + }, + { + 'id': 317, + 'name': 'tortoise/turtle' + }, + { + 'id': 265, + 'name': 'corn' + }, + { + 'id': 126, + 'name': 'donut' + }, + { + 'id': 57, + 'name': 'mirror' + }, + { + 'id': 7, + 'name': 'cabinet/shelf' + }, + { + 'id': 263, + 'name': 'green vegetables' + }, + { + 'id': 159, + 'name': 'tissue ' + }, + { + 'id': 321, + 'name': 'shark' + }, + { + 'id': 301, + 'name': 'pig' + }, + { + 'id': 41, + 'name': 'carpet' + }, + { + 'id': 304, + 'name': 'rice cooker' + }, + { + 'id': 323, + 'name': 'poker card' + }, +] + + +def _get_builtin_metadata(version): + if version == 'v1': + id_to_name = {x['id']: x['name'] for x in categories_v1} + else: + assert 0, version + thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(365)} + thing_classes = [id_to_name[k] for k in sorted(id_to_name)] + return { + 'thing_dataset_id_to_contiguous_id': thing_dataset_id_to_contiguous_id, + 'thing_classes': thing_classes + } + + +_PREDEFINED_SPLITS_OBJECTS365 = { + 'objects365_train': + ('objects365/train', 'objects365/annotations/objects365_train.json'), + 'objects365_val': + ('objects365/val', 'objects365/annotations/objects365_val.json'), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_OBJECTS365.items(): + register_coco_instances( + key, + _get_builtin_metadata('v1'), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) diff --git a/projects/videochat/models/centernet/data/transforms/__init__.py b/projects/videochat/models/centernet/data/transforms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/data/transforms/custom_augmentation_impl.py b/projects/videochat/models/centernet/data/transforms/custom_augmentation_impl.py new file mode 100644 index 0000000000..fd7bd29a63 --- /dev/null +++ b/projects/videochat/models/centernet/data/transforms/custom_augmentation_impl.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified by Xingyi Zhou +"""Implement many useful :class:`Augmentation`.""" +import numpy as np +from detectron2.data.transforms.augmentation import Augmentation +from PIL import Image + +from .custom_transform import EfficientDetResizeCropTransform + +__all__ = [ + 'EfficientDetResizeCrop', +] + + +class EfficientDetResizeCrop(Augmentation): + """Scale the shorter edge to the given size, with a limit of `max_size` on + the longer edge. + + If `max_size` is reached, then downscale so that the longer edge does not + exceed max_size. + """ + + def __init__(self, size, scale, interp=Image.BILINEAR): + """ + Args: + """ + super().__init__() + self.target_size = (size, size) + self.scale = scale + self.interp = interp + + def get_transform(self, img): + # Select a random scale factor. + scale_factor = np.random.uniform(*self.scale) + scaled_target_height = scale_factor * self.target_size[0] + scaled_target_width = scale_factor * self.target_size[1] + # Recompute the accurate scale_factor using rounded scaled image size. 
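+        # The image is scaled so that it fits entirely inside the jittered
+        # target, and the crop offsets computed further below are sampled
+        # uniformly from the remaining slack (scaled size minus target size)
+        # along each axis.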
+ width, height = img.shape[1], img.shape[0] + img_scale_y = scaled_target_height / height + img_scale_x = scaled_target_width / width + img_scale = min(img_scale_y, img_scale_x) + + # Select non-zero random offset (x, y) if scaled image is larger + # than target size + scaled_h = int(height * img_scale) + scaled_w = int(width * img_scale) + offset_y = scaled_h - self.target_size[0] + offset_x = scaled_w - self.target_size[1] + offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1)) + offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1)) + return EfficientDetResizeCropTransform(scaled_h, scaled_w, offset_y, + offset_x, img_scale, + self.target_size, self.interp) diff --git a/projects/videochat/models/centernet/data/transforms/custom_transform.py b/projects/videochat/models/centernet/data/transforms/custom_transform.py new file mode 100644 index 0000000000..fc7dad9aea --- /dev/null +++ b/projects/videochat/models/centernet/data/transforms/custom_transform.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified by Xingyi Zhou +# File: transform.py + +import numpy as np +import torch +import torch.nn.functional as F +from fvcore.transforms.transform import Transform +from PIL import Image + +try: + import cv2 # noqa +except ImportError: + # OpenCV is an optional dependency at the moment + pass + +__all__ = [ + 'EfficientDetResizeCropTransform', +] + + +class EfficientDetResizeCropTransform(Transform): + """""" + + def __init__(self, + scaled_h, + scaled_w, + offset_y, + offset_x, + img_scale, + target_size, + interp=None): + """ + Args: + h, w (int): original image size + new_h, new_w (int): new image size + interp: PIL interpolation methods, defaults to bilinear. + """ + super().__init__() + if interp is None: + interp = Image.BILINEAR + self._set_attributes(locals()) + + def apply_image(self, img, interp=None): + # assert img.shape[:2] == (self.h, self.w) + assert len(img.shape) <= 4 + + if img.dtype == np.uint8: + pil_image = Image.fromarray(img) + interp_method = interp if interp is not None else self.interp + pil_image = pil_image.resize((self.scaled_w, self.scaled_h), + interp_method) + ret = np.asarray(pil_image) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + # img = img.crop((self.offset_x, self.offset_y, right, lower)) + if len(ret.shape) <= 3: + ret = ret[self.offset_y:lower, self.offset_x:right] + else: + ret = ret[..., self.offset_y:lower, self.offset_x:right, :] + else: + # PIL only supports uint8 + img = torch.from_numpy(img) + shape = list(img.shape) + shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] + img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw + _PIL_RESIZE_TO_INTERPOLATE_MODE = { + Image.BILINEAR: 'bilinear', + Image.BICUBIC: 'bicubic' + } + mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp] + img = F.interpolate( + img, (self.scaled_h, self.scaled_w), + mode=mode, + align_corners=False) + shape[:2] = (self.scaled_h, self.scaled_w) + ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + if len(ret.shape) <= 3: + ret = ret[self.offset_y:lower, self.offset_x:right] + else: + ret = ret[..., self.offset_y:lower, self.offset_x:right, :] + return ret + + def apply_coords(self, coords): + coords[:, 0] = coords[:, 0] * self.img_scale + coords[:, 1] = 
coords[:, 1] * self.img_scale + coords[:, 0] -= self.offset_x + coords[:, 1] -= self.offset_y + return coords + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=Image.NEAREST) + return segmentation + + def inverse(self): + raise NotImplementedError + # return ResizeTransform(self.new_h, self.new_w, self.h, self.w, + # self.interp) diff --git a/projects/videochat/models/centernet/modeling/__init__.py b/projects/videochat/models/centernet/modeling/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/modeling/backbone/__init__.py b/projects/videochat/models/centernet/modeling/backbone/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/modeling/backbone/bifpn.py b/projects/videochat/models/centernet/modeling/backbone/bifpn.py new file mode 100644 index 0000000000..057e7003a6 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/backbone/bifpn.py @@ -0,0 +1,554 @@ +# Modified from https://github.com/rwightman/efficientdet-pytorch/blob +# /master/effdet/efficientdet.py The original file is under Apache-2.0 License +import math +from collections import OrderedDict + +import torch +from detectron2.layers import Conv2d, ShapeSpec +from detectron2.layers.batch_norm import get_norm +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.resnet import build_resnet_backbone +from torch import nn + +from .dlafpn import dla34 + + +def get_fpn_config(base_reduction=8): + """BiFPN config with sum.""" + p = { + 'nodes': [ + { + 'reduction': base_reduction << 3, + 'inputs_offsets': [3, 4] + }, + { + 'reduction': base_reduction << 2, + 'inputs_offsets': [2, 5] + }, + { + 'reduction': base_reduction << 1, + 'inputs_offsets': [1, 6] + }, + { + 'reduction': base_reduction, + 'inputs_offsets': [0, 7] + }, + { + 'reduction': base_reduction << 1, + 'inputs_offsets': [1, 7, 8] + }, + { + 'reduction': base_reduction << 2, + 'inputs_offsets': [2, 6, 9] + }, + { + 'reduction': base_reduction << 3, + 'inputs_offsets': [3, 5, 10] + }, + { + 'reduction': base_reduction << 4, + 'inputs_offsets': [4, 11] + }, + ], + 'weight_method': + 'fastattn', + } + return p + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +class SequentialAppend(nn.Sequential): + + def __init__(self, *args): + super(SequentialAppend, self).__init__(*args) + + def forward(self, x): + for module in self: + x.append(module(x)) + return x + + +class SequentialAppendLast(nn.Sequential): + + def __init__(self, *args): + super(SequentialAppendLast, self).__init__(*args) + + # def forward(self, x: List[torch.Tensor]): + def forward(self, x): + for module in self: + x.append(module(x[-1])) + return x + + +class ConvBnAct2d(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + padding='', + bias=False, + norm='', + act_layer=Swish): + super(ConvBnAct2d, self).__init__() + # self.conv = create_conv2d( in_channels, out_channels, kernel_size, + # stride=stride, dilation=dilation, padding=padding, bias=bias) + self.conv = Conv2d( 
+ in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + bias=(norm == '')) + self.bn = get_norm(norm, out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class SeparableConv2d(nn.Module): + """Separable Conv.""" + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + padding='', + bias=False, + channel_multiplier=1.0, + pw_kernel_size=1, + act_layer=Swish, + norm=''): + super(SeparableConv2d, self).__init__() + + # self.conv_dw = create_conv2d( in_channels, int(in_channels * + # channel_multiplier), kernel_size, stride=stride, + # dilation=dilation, padding=padding, depthwise=True) + + self.conv_dw = Conv2d( + in_channels, + int(in_channels * channel_multiplier), + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + bias=bias, + groups=out_channels) + # print('conv_dw', kernel_size, stride) self.conv_pw = + # create_conv2d( int(in_channels * channel_multiplier), + # out_channels, pw_kernel_size, padding=padding, bias=bias) + + self.conv_pw = Conv2d( + int(in_channels * channel_multiplier), + out_channels, + kernel_size=pw_kernel_size, + padding=pw_kernel_size // 2, + bias=(norm == '')) + # print('conv_pw', pw_kernel_size) + + self.bn = get_norm(norm, out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class ResampleFeatureMap(nn.Sequential): + + def __init__(self, + in_channels, + out_channels, + reduction_ratio=1., + pad_type='', + pooling_type='max', + norm='', + apply_bn=False, + conv_after_downsample=False, + redundant_bias=False): + super(ResampleFeatureMap, self).__init__() + pooling_type = pooling_type or 'max' + self.in_channels = in_channels + self.out_channels = out_channels + self.reduction_ratio = reduction_ratio + self.conv_after_downsample = conv_after_downsample + + conv = None + if in_channels != out_channels: + conv = ConvBnAct2d( + in_channels, + out_channels, + kernel_size=1, + padding=pad_type, + norm=norm if apply_bn else '', + bias=not apply_bn or redundant_bias, + act_layer=None) + + if reduction_ratio > 1: + stride_size = int(reduction_ratio) + if conv is not None and not self.conv_after_downsample: + self.add_module('conv', conv) + self.add_module( + 'downsample', + # create_pool2d( pooling_type, kernel_size=stride_size + 1, + # stride=stride_size, padding=pad_type) nn.MaxPool2d( + # kernel_size=stride_size + 1, stride=stride_size, + # padding=pad_type) + nn.MaxPool2d(kernel_size=stride_size, stride=stride_size)) + if conv is not None and self.conv_after_downsample: + self.add_module('conv', conv) + else: + if conv is not None: + self.add_module('conv', conv) + if reduction_ratio < 1: + scale = int(1 // reduction_ratio) + self.add_module('upsample', + nn.UpsamplingNearest2d(scale_factor=scale)) + + +class FpnCombine(nn.Module): + + def __init__(self, + feature_info, + fpn_config, + fpn_channels, + inputs_offsets, + target_reduction, + pad_type='', + pooling_type='max', + norm='', + apply_bn_for_resampling=False, + conv_after_downsample=False, + redundant_bias=False, + weight_method='attn'): + super(FpnCombine, self).__init__() + self.inputs_offsets = inputs_offsets + 
self.weight_method = weight_method + + self.resample = nn.ModuleDict() + for idx, offset in enumerate(inputs_offsets): + in_channels = fpn_channels + if offset < len(feature_info): + in_channels = feature_info[offset]['num_chs'] + input_reduction = feature_info[offset]['reduction'] + else: + node_idx = offset - len(feature_info) + # print('node_idx, len', node_idx, len(fpn_config['nodes'])) + input_reduction = fpn_config['nodes'][node_idx]['reduction'] + reduction_ratio = target_reduction / input_reduction + self.resample[str(offset)] = ResampleFeatureMap( + in_channels, + fpn_channels, + reduction_ratio=reduction_ratio, + pad_type=pad_type, + pooling_type=pooling_type, + norm=norm, + apply_bn=apply_bn_for_resampling, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias) + + if weight_method == 'attn' or weight_method == 'fastattn': + # WSM + self.edge_weights = nn.Parameter( + torch.ones(len(inputs_offsets)), requires_grad=True) + else: + self.edge_weights = None + + def forward(self, x): + dtype = x[0].dtype + nodes = [] + for offset in self.inputs_offsets: + input_node = x[offset] + input_node = self.resample[str(offset)](input_node) + nodes.append(input_node) + + if self.weight_method == 'attn': + normalized_weights = torch.softmax( + self.edge_weights.type(dtype), dim=0) + x = torch.stack(nodes, dim=-1) * normalized_weights + elif self.weight_method == 'fastattn': + edge_weights = nn.functional.relu(self.edge_weights.type(dtype)) + weights_sum = torch.sum(edge_weights) + x = torch.stack([(nodes[i] * edge_weights[i]) / + (weights_sum + 0.0001) + for i in range(len(nodes))], + dim=-1) + elif self.weight_method == 'sum': + x = torch.stack(nodes, dim=-1) + else: + raise ValueError('unknown weight_method {}'.format( + self.weight_method)) + x = torch.sum(x, dim=-1) + return x + + +class BiFpnLayer(nn.Module): + + def __init__(self, + feature_info, + fpn_config, + fpn_channels, + num_levels=5, + pad_type='', + pooling_type='max', + norm='', + act_layer=Swish, + apply_bn_for_resampling=False, + conv_after_downsample=True, + conv_bn_relu_pattern=False, + separable_conv=True, + redundant_bias=False): + super(BiFpnLayer, self).__init__() + self.fpn_config = fpn_config + self.num_levels = num_levels + self.conv_bn_relu_pattern = False + + self.feature_info = [] + self.fnode = SequentialAppend() + for i, fnode_cfg in enumerate(fpn_config['nodes']): + # logging.debug('fnode {} : {}'.format(i, fnode_cfg)) + # print('fnode {} : {}'.format(i, fnode_cfg)) + fnode_layers = OrderedDict() + + # combine features + reduction = fnode_cfg['reduction'] + fnode_layers['combine'] = FpnCombine( + feature_info, + fpn_config, + fpn_channels, + fnode_cfg['inputs_offsets'], + target_reduction=reduction, + pad_type=pad_type, + pooling_type=pooling_type, + norm=norm, + apply_bn_for_resampling=apply_bn_for_resampling, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + weight_method=fpn_config['weight_method']) + self.feature_info.append( + dict(num_chs=fpn_channels, reduction=reduction)) + + # after combine ops + after_combine = OrderedDict() + if not conv_bn_relu_pattern: + after_combine['act'] = act_layer(inplace=True) + conv_bias = redundant_bias + conv_act = None + else: + conv_bias = False + conv_act = act_layer + conv_kwargs = dict( + in_channels=fpn_channels, + out_channels=fpn_channels, + kernel_size=3, + padding=pad_type, + bias=conv_bias, + norm=norm, + act_layer=conv_act) + after_combine['conv'] = SeparableConv2d( + **conv_kwargs) if separable_conv 
else ConvBnAct2d( + **conv_kwargs) + fnode_layers['after_combine'] = nn.Sequential(after_combine) + + self.fnode.add_module(str(i), nn.Sequential(fnode_layers)) + + self.feature_info = self.feature_info[-num_levels::] + + def forward(self, x): + x = self.fnode(x) + return x[-self.num_levels::] + + +class BiFPN(Backbone): + + def __init__( + self, + cfg, + bottom_up, + in_features, + out_channels, + norm='', + num_levels=5, + num_bifpn=4, + separable_conv=False, + ): + super(BiFPN, self).__init__() + assert isinstance(bottom_up, Backbone) + + # Feature map strides and channels from the bottom up network (e.g. + # ResNet) + input_shapes = bottom_up.output_shape() + in_strides = [input_shapes[f].stride for f in in_features] + in_channels = [input_shapes[f].channels for f in in_features] + + self.num_levels = num_levels + self.num_bifpn = num_bifpn + self.bottom_up = bottom_up + self.in_features = in_features + self._size_divisibility = 128 + levels = [int(math.log2(s)) for s in in_strides] + self._out_feature_strides = { + 'p{}'.format(int(math.log2(s))): s + for s in in_strides + } + if len(in_features) < num_levels: + for ll in range(num_levels - len(in_features)): + s = ll + levels[-1] + self._out_feature_strides['p{}'.format(s + 1)] = 2**(s + 1) + self._out_features = list(sorted(self._out_feature_strides.keys())) + self._out_feature_channels = { + k: out_channels + for k in self._out_features + } + + # print('self._out_feature_strides', self._out_feature_strides) + # print('self._out_feature_channels', self._out_feature_channels) + + feature_info = [{ + 'num_chs': in_channels[level], + 'reduction': in_strides[level] + } for level in range(len(self.in_features))] + # self.config = config + fpn_config = get_fpn_config() + self.resample = SequentialAppendLast() + for level in range(num_levels): + if level < len(feature_info): + in_chs = in_channels[level] # feature_info[level]['num_chs'] + reduction = in_strides[ + level] # feature_info[level]['reduction'] + else: + # Adds a coarser level by downsampling the last feature map + reduction_ratio = 2 + self.resample.add_module( + str(level), + ResampleFeatureMap( + in_channels=in_chs, + out_channels=out_channels, + pad_type='same', + pooling_type=None, + norm=norm, + reduction_ratio=reduction_ratio, + apply_bn=True, + conv_after_downsample=False, + redundant_bias=False, + )) + in_chs = out_channels + reduction = int(reduction * reduction_ratio) + feature_info.append(dict(num_chs=in_chs, reduction=reduction)) + + self.cell = nn.Sequential() + for rep in range(self.num_bifpn): + # logging.debug('building cell {}'.format(rep)) + # print('building cell {}'.format(rep)) + fpn_layer = BiFpnLayer( + feature_info=feature_info, + fpn_config=fpn_config, + fpn_channels=out_channels, + num_levels=self.num_levels, + pad_type='same', + pooling_type=None, + norm=norm, + act_layer=Swish, + separable_conv=separable_conv, + apply_bn_for_resampling=True, + conv_after_downsample=False, + conv_bn_relu_pattern=False, + redundant_bias=False, + ) + self.cell.add_module(str(rep), fpn_layer) + feature_info = fpn_layer.feature_info + # import pdb; pdb.set_trace() + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + # print('input shapes', x.shape) + bottom_up_features = self.bottom_up(x) + x = [bottom_up_features[f] for f in self.in_features] + assert len(self.resample) == self.num_levels - len(x) + x = self.resample(x) + # shapes = [xx.shape for xx in x] + # print('resample shapes', shapes) + x = self.cell(x) + out = 
{f: xx for f, xx in zip(self._out_features, x)} + # import pdb; pdb.set_trace() + return out + + +@BACKBONE_REGISTRY.register() +def build_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: backbone (Backbone): backbone module, must be a subclass of + :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p37_dla_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + bottom_up = dla34(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + assert cfg.MODEL.BIFPN.NUM_LEVELS == 5 + + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone diff --git a/projects/videochat/models/centernet/modeling/backbone/bifpn_fcos.py b/projects/videochat/models/centernet/modeling/backbone/bifpn_fcos.py new file mode 100644 index 0000000000..0c58884f97 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/backbone/bifpn_fcos.py @@ -0,0 +1,501 @@ +# This file is modified from +# https://github.com/aim-uofa/AdelaiDet/blob/master/adet/modeling/backbone +# /bifpn.py The original file is under 2-clause BSD License for academic +# use, and *non-commercial use*. 
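+#
+# The module provides SingleBiFPN, a single bidirectional fusion pass whose node
+# inputs are combined with ReLU-normalized learnable edge weights and a swish
+# activation, and BiFPN, a detectron2 Backbone that stacks several SingleBiFPN
+# layers on top of a bottom-up network (optionally extended with extra top
+# levels such as p6/p7 via BackboneWithTopLevels).
+#
+# A typical use goes through the registered builders at the bottom of this file,
+# e.g. setting cfg.MODEL.BACKBONE.NAME = 'build_fcos_resnet_bifpn_backbone' so
+# that detectron2's build_backbone(cfg, input_shape) returns this BiFPN.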
+import torch +import torch.nn.functional as F +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import BACKBONE_REGISTRY +from detectron2.modeling.backbone import Backbone, build_resnet_backbone +from torch import nn + +from .dlafpn import dla34 + +__all__ = [] + + +def swish(x): + return x * x.sigmoid() + + +def split_name(name): + for i, c in enumerate(name): + if not c.isalpha(): + return name[:i], int(name[i:]) + raise ValueError() + + +class FeatureMapResampler(nn.Module): + + def __init__(self, in_channels, out_channels, stride, norm=''): + super(FeatureMapResampler, self).__init__() + if in_channels != out_channels: + self.reduction = Conv2d( + in_channels, + out_channels, + kernel_size=1, + bias=(norm == ''), + norm=get_norm(norm, out_channels), + activation=None) + else: + self.reduction = None + + assert stride <= 2 + self.stride = stride + + def forward(self, x): + if self.reduction is not None: + x = self.reduction(x) + + if self.stride == 2: + x = F.max_pool2d( + x, kernel_size=self.stride + 1, stride=self.stride, padding=1) + elif self.stride == 1: + pass + else: + raise NotImplementedError() + return x + + +class BackboneWithTopLevels(Backbone): + + def __init__(self, backbone, out_channels, num_top_levels, norm=''): + super(BackboneWithTopLevels, self).__init__() + self.backbone = backbone + backbone_output_shape = backbone.output_shape() + + self._out_feature_channels = { + name: shape.channels + for name, shape in backbone_output_shape.items() + } + self._out_feature_strides = { + name: shape.stride + for name, shape in backbone_output_shape.items() + } + self._out_features = list(self._out_feature_strides.keys()) + + last_feature_name = max( + self._out_feature_strides.keys(), key=lambda x: split_name(x)[1]) + self.last_feature_name = last_feature_name + self.num_top_levels = num_top_levels + + last_channels = self._out_feature_channels[last_feature_name] + last_stride = self._out_feature_strides[last_feature_name] + + prefix, suffix = split_name(last_feature_name) + prev_channels = last_channels + for i in range(num_top_levels): + name = prefix + str(suffix + i + 1) + self.add_module( + name, FeatureMapResampler(prev_channels, out_channels, 2, + norm)) + prev_channels = out_channels + + self._out_feature_channels[name] = out_channels + self._out_feature_strides[name] = last_stride * 2**(i + 1) + self._out_features.append(name) + + def forward(self, x): + outputs = self.backbone(x) + last_features = outputs[self.last_feature_name] + prefix, suffix = split_name(self.last_feature_name) + + x = last_features + for i in range(self.num_top_levels): + name = prefix + str(suffix + i + 1) + x = self.__getattr__(name)(x) + outputs[name] = x + + return outputs + + +class SingleBiFPN(Backbone): + """This module implements Feature Pyramid Network. + + It creates pyramid features built on top of some input feature maps. + """ + + def __init__(self, in_channels_list, out_channels, norm=''): + """ + Args: bottom_up (Backbone): module representing the bottom up + subnetwork. Must be a subclass of :class:`Backbone`. The multi-scale + feature maps generated by the bottom up network, and listed in + `in_features`, are used to generate FPN levels. in_features (list[ + str]): names of the input feature maps coming from the backbone to + which FPN is attached. For example, if the backbone produces [ + "res2", "res3", "res4"], any *contiguous* sublist of these may be + used; order must be from high to low resolution. 
out_channels (int): + number of channels in the output feature maps. norm (str): the + normalization to use. + """ + super(SingleBiFPN, self).__init__() + + self.out_channels = out_channels + # build 5-levels bifpn + if len(in_channels_list) == 5: + self.nodes = [ + { + 'feat_level': 3, + 'inputs_offsets': [3, 4] + }, + { + 'feat_level': 2, + 'inputs_offsets': [2, 5] + }, + { + 'feat_level': 1, + 'inputs_offsets': [1, 6] + }, + { + 'feat_level': 0, + 'inputs_offsets': [0, 7] + }, + { + 'feat_level': 1, + 'inputs_offsets': [1, 7, 8] + }, + { + 'feat_level': 2, + 'inputs_offsets': [2, 6, 9] + }, + { + 'feat_level': 3, + 'inputs_offsets': [3, 5, 10] + }, + { + 'feat_level': 4, + 'inputs_offsets': [4, 11] + }, + ] + elif len(in_channels_list) == 3: + self.nodes = [ + { + 'feat_level': 1, + 'inputs_offsets': [1, 2] + }, + { + 'feat_level': 0, + 'inputs_offsets': [0, 3] + }, + { + 'feat_level': 1, + 'inputs_offsets': [1, 3, 4] + }, + { + 'feat_level': 2, + 'inputs_offsets': [2, 5] + }, + ] + else: + raise NotImplementedError + + node_info = [_ for _ in in_channels_list] + + num_output_connections = [0 for _ in in_channels_list] + for fnode in self.nodes: + feat_level = fnode['feat_level'] + inputs_offsets = fnode['inputs_offsets'] + inputs_offsets_str = '_'.join(map(str, inputs_offsets)) + for input_offset in inputs_offsets: + num_output_connections[input_offset] += 1 + + in_channels = node_info[input_offset] + if in_channels != out_channels: + lateral_conv = Conv2d( + in_channels, + out_channels, + kernel_size=1, + norm=get_norm(norm, out_channels)) + self.add_module( + 'lateral_{}_f{}'.format(input_offset, feat_level), + lateral_conv) + node_info.append(out_channels) + num_output_connections.append(0) + + # generate attention weights + name = 'weights_f{}_{}'.format(feat_level, inputs_offsets_str) + self.__setattr__( + name, + nn.Parameter( + torch.ones(len(inputs_offsets), dtype=torch.float32), + requires_grad=True)) + + # generate convolutions after combination + name = 'outputs_f{}_{}'.format(feat_level, inputs_offsets_str) + self.add_module( + name, + Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm=get_norm(norm, out_channels), + bias=(norm == ''))) + + def forward(self, feats): + """ + Args: input (dict[str->Tensor]): mapping feature map name (e.g., + "p5") to feature map tensor for each feature level in high to low + resolution order. Returns: dict[str->Tensor]: mapping from feature + map name to FPN feature map tensor in high to low resolution order. + Returned feature names follow the FPN paper convention: "p", + where stage has stride = 2 ** stage e.g., ["n2", "n3", ..., "n6"]. 
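+        Note that this implementation actually takes and returns a *list* of
+        tensors, one per pyramid level; each output is the last fusion node
+        produced at that level.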
+ """ + feats = [_ for _ in feats] + num_levels = len(feats) + num_output_connections = [0 for _ in feats] + for fnode in self.nodes: + feat_level = fnode['feat_level'] + inputs_offsets = fnode['inputs_offsets'] + inputs_offsets_str = '_'.join(map(str, inputs_offsets)) + input_nodes = [] + _, _, target_h, target_w = feats[feat_level].size() + for input_offset in inputs_offsets: + num_output_connections[input_offset] += 1 + input_node = feats[input_offset] + + # reduction + if input_node.size(1) != self.out_channels: + name = 'lateral_{}_f{}'.format(input_offset, feat_level) + input_node = self.__getattr__(name)(input_node) + + # maybe downsample + _, _, h, w = input_node.size() + if h > target_h and w > target_w: + height_stride_size = int((h - 1) // target_h + 1) + width_stride_size = int((w - 1) // target_w + 1) + assert height_stride_size == width_stride_size == 2 + input_node = F.max_pool2d( + input_node, + kernel_size=(height_stride_size + 1, + width_stride_size + 1), + stride=(height_stride_size, width_stride_size), + padding=1) + elif h <= target_h and w <= target_w: + if h < target_h or w < target_w: + input_node = F.interpolate( + input_node, + size=(target_h, target_w), + mode='nearest') + else: + raise NotImplementedError() + input_nodes.append(input_node) + + # attention + name = 'weights_f{}_{}'.format(feat_level, inputs_offsets_str) + weights = F.relu(self.__getattr__(name)) + norm_weights = weights / (weights.sum() + 0.0001) + + new_node = torch.stack(input_nodes, dim=-1) + new_node = (norm_weights * new_node).sum(dim=-1) + new_node = swish(new_node) + + name = 'outputs_f{}_{}'.format(feat_level, inputs_offsets_str) + feats.append(self.__getattr__(name)(new_node)) + + num_output_connections.append(0) + + output_feats = [] + for idx in range(num_levels): + for i, fnode in enumerate(reversed(self.nodes)): + if fnode['feat_level'] == idx: + output_feats.append(feats[-1 - i]) + break + else: + raise ValueError() + return output_feats + + +class BiFPN(Backbone): + """This module implements Feature Pyramid Network. + + It creates pyramid features built on top of some input feature maps. + """ + + def __init__(self, + bottom_up, + in_features, + out_channels, + num_top_levels, + num_repeats, + norm=''): + """ + Args: bottom_up (Backbone): module representing the bottom up + subnetwork. Must be a subclass of :class:`Backbone`. The multi-scale + feature maps generated by the bottom up network, and listed in + `in_features`, are used to generate FPN levels. in_features (list[ + str]): names of the input feature maps coming from the backbone to + which FPN is attached. For example, if the backbone produces [ + "res2", "res3", "res4"], any *contiguous* sublist of these may be + used; order must be from high to low resolution. out_channels (int): + number of channels in the output feature maps. num_top_levels (int): + the number of the top levels (p6 or p7). num_repeats (int): the + number of repeats of BiFPN. norm (str): the normalization to use. 
+ """ + super(BiFPN, self).__init__() + assert isinstance(bottom_up, Backbone) + + # add extra feature levels (i.e., 6 and 7) + self.bottom_up = BackboneWithTopLevels(bottom_up, out_channels, + num_top_levels, norm) + bottom_up_output_shapes = self.bottom_up.output_shape() + + in_features = sorted(in_features, key=lambda x: split_name(x)[1]) + self._size_divisibility = 128 + # bottom_up_output_shapes[in_features[-1]].stride + self.out_channels = out_channels + self.min_level = split_name(in_features[0])[1] + + # add the names for top blocks + prefix, last_suffix = split_name(in_features[-1]) + for i in range(num_top_levels): + in_features.append(prefix + str(last_suffix + i + 1)) + self.in_features = in_features + + # generate output features + self._out_features = [ + 'p{}'.format(split_name(name)[1]) for name in in_features + ] + self._out_feature_strides = { + out_name: bottom_up_output_shapes[in_name].stride + for out_name, in_name in zip(self._out_features, in_features) + } + self._out_feature_channels = { + k: out_channels + for k in self._out_features + } + + # build bifpn + self.repeated_bifpn = nn.ModuleList() + for i in range(num_repeats): + if i == 0: + in_channels_list = [ + bottom_up_output_shapes[name].channels + for name in in_features + ] + else: + in_channels_list = [ + self._out_feature_channels[name] + for name in self._out_features + ] + self.repeated_bifpn.append( + SingleBiFPN(in_channels_list, out_channels, norm)) + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + """ + Args: input (dict[str->Tensor]): mapping feature map name (e.g., + "p5") to feature map tensor for each feature level in high to low + resolution order. Returns: dict[str->Tensor]: mapping from feature + map name to FPN feature map tensor in high to low resolution order. + Returned feature names follow the FPN paper convention: "p", + where stage has stride = 2 ** stage e.g., ["n2", "n3", ..., "n6"]. + """ + bottom_up_features = self.bottom_up(x) + feats = [bottom_up_features[f] for f in self.in_features] + + for bifpn in self.repeated_bifpn: + feats = bifpn(feats) + + return dict(zip(self._out_features, feats)) + + +def _assert_strides_are_log2_contiguous(strides): + """Assert that each stride is 2x times its preceding stride, i.e. + "contiguous in log2".""" + for i, stride in enumerate(strides[1:], 1): + assert stride == 2 * strides[ + i - 1], 'Strides {} {} are not log2 contiguous'.format( + stride, strides[i - 1]) + + +@BACKBONE_REGISTRY.register() +def build_fcos_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + top_levels = 2 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_fcos_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. 
+ """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + top_levels = 0 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_fcos_dla_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + bottom_up = dla34(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + top_levels = 0 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p37_fcos_dla_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + bottom_up = dla34(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + assert cfg.MODEL.BIFPN.NUM_LEVELS == 5 + top_levels = 2 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM) + return backbone diff --git a/projects/videochat/models/centernet/modeling/backbone/dla.py b/projects/videochat/models/centernet/modeling/backbone/dla.py new file mode 100644 index 0000000000..6b613c9bb6 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/backbone/dla.py @@ -0,0 +1,604 @@ +import math +from os.path import join + +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from detectron2.layers import (Conv2d, DeformConv, ModulatedDeformConv, + ShapeSpec, get_norm) +from detectron2.modeling.backbone.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +from detectron2.modeling.backbone.resnet import (BasicStem, BottleneckBlock, + DeformBottleneckBlock) +from torch import nn + +__all__ = [ + 'BottleneckBlock', + 'DeformBottleneckBlock', + 'BasicStem', +] + +DCNV1 = False + +HASH = { + 34: 'ba72cf86', + 60: '24839fc4', +} + + +def get_model_url(data, name, hash): + return join('http://dl.yf.io/dla/models', data, + '{}-{}.pth'.format(name, hash)) + + +class BasicBlock(nn.Module): + + def __init__(self, inplanes, planes, stride=1, dilation=1, norm='BN'): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation) + self.bn1 = get_norm(norm, planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=dilation, + bias=False, + dilation=dilation) + self.bn2 = get_norm(norm, planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = 
self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, dilation=1, norm='BN'): + super(Bottleneck, self).__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d( + inplanes, bottle_planes, kernel_size=1, bias=False) + self.bn1 = get_norm(norm, bottle_planes) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation) + self.bn2 = get_norm(norm, bottle_planes) + self.conv3 = nn.Conv2d( + bottle_planes, planes, kernel_size=1, bias=False) + self.bn3 = get_norm(norm, planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + residual, + norm='BN'): + super(Root, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=1, + bias=False, + padding=(kernel_size - 1) // 2) + self.bn = get_norm(norm, out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + + def __init__(self, + levels, + block, + in_channels, + out_channels, + stride=1, + level_root=False, + root_dim=0, + root_kernel_size=1, + dilation=1, + root_residual=False, + norm='BN'): + super(Tree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block( + in_channels, + out_channels, + stride, + dilation=dilation, + norm=norm) + self.tree2 = block( + out_channels, out_channels, 1, dilation=dilation, norm=norm) + else: + self.tree1 = Tree( + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + norm=norm) + self.tree2 = Tree( + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + norm=norm) + if levels == 1: + self.root = Root( + root_dim, + out_channels, + root_kernel_size, + root_residual, + norm=norm) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False), get_norm(norm, out_channels)) + + def forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + 
children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + + def __init__(self, + num_layers, + levels, + channels, + block=BasicBlock, + residual_root=False, + norm='BN'): + """ + Args: + """ + super(DLA, self).__init__() + self.norm = norm + self.channels = channels + self.base_layer = nn.Sequential( + nn.Conv2d( + 3, channels[0], kernel_size=7, stride=1, padding=3, + bias=False), get_norm(self.norm, channels[0]), + nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level(channels[0], channels[0], + levels[0]) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root, + norm=norm) + self.level3 = Tree( + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root, + norm=norm) + self.level4 = Tree( + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root, + norm=norm) + self.level5 = Tree( + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root, + norm=norm) + self.load_pretrained_model( + data='imagenet', + name='dla{}'.format(num_layers), + hash=HASH[num_layers]) + + def load_pretrained_model(self, data, name, hash): + model_url = get_model_url(data, name, hash) + model_weights = model_zoo.load_url(model_url) + num_classes = len(model_weights[list(model_weights.keys())[-1]]) + self.fc = nn.Conv2d( + self.channels[-1], + num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True) + print('Loading pretrained') + self.load_state_dict(model_weights, strict=False) + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation), + get_norm(self.norm, planes), + nn.ReLU(inplace=True) + ]) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = [] + x = self.base_layer(x) + for i in range(6): + x = getattr(self, 'level{}'.format(i))(x) + y.append(x) + return y + + +def fill_up_weights(up): + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. 
* f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class _DeformConv(nn.Module): + + def __init__(self, chi, cho, norm='BN'): + super(_DeformConv, self).__init__() + self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True)) + if DCNV1: + self.offset = Conv2d( + chi, 18, kernel_size=3, stride=1, padding=1, dilation=1) + self.conv = DeformConv( + chi, + cho, + kernel_size=(3, 3), + stride=1, + padding=1, + dilation=1, + deformable_groups=1) + else: + self.offset = Conv2d( + chi, 27, kernel_size=3, stride=1, padding=1, dilation=1) + self.conv = ModulatedDeformConv( + chi, + cho, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + deformable_groups=1) + nn.init.constant_(self.offset.weight, 0) + nn.init.constant_(self.offset.bias, 0) + + def forward(self, x): + if DCNV1: + offset = self.offset(x) + x = self.conv(x, offset) + else: + offset_mask = self.offset(x) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + x = self.conv(x, offset, mask) + x = self.actf(x) + return x + + +class IDAUp(nn.Module): + + def __init__(self, o, channels, up_f, norm='BN'): + super(IDAUp, self).__init__() + for i in range(1, len(channels)): + c = channels[i] + f = int(up_f[i]) + proj = _DeformConv(c, o, norm=norm) + node = _DeformConv(o, o, norm=norm) + + up = nn.ConvTranspose2d( + o, + o, + f * 2, + stride=f, + padding=f // 2, + output_padding=0, + groups=o, + bias=False) + fill_up_weights(up) + + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + setattr(self, 'node_' + str(i), node) + + def forward(self, layers, startp, endp): + for i in range(startp + 1, endp): + upsample = getattr(self, 'up_' + str(i - startp)) + project = getattr(self, 'proj_' + str(i - startp)) + layers[i] = upsample(project(layers[i])) + node = getattr(self, 'node_' + str(i - startp)) + layers[i] = node(layers[i] + layers[i - 1]) + + +class DLAUp(nn.Module): + + def __init__(self, startp, channels, scales, in_channels=None, norm='BN'): + super(DLAUp, self).__init__() + self.startp = startp + if in_channels is None: + in_channels = channels + self.channels = channels + channels = list(channels) + scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr( + self, 'ida_{}'.format(i), + IDAUp( + channels[j], + in_channels[j:], + scales[j:] // scales[j], + norm=norm)) + scales[j + 1:] = scales[j] + in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] + + def forward(self, layers): + out = [layers[-1]] # start with 32 + for i in range(len(layers) - self.startp - 1): + ida = getattr(self, 'ida_{}'.format(i)) + ida(layers, len(layers) - i - 2, len(layers)) + out.insert(0, layers[-1]) + return out + + +DLA_CONFIGS = { + 34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], BasicBlock), + 60: ([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], Bottleneck) +} + + +class DLASeg(Backbone): + + def __init__(self, + num_layers, + out_features, + use_dla_up=True, + ms_output=False, + norm='BN'): + super(DLASeg, self).__init__() + # depth = 34 + levels, channels, Block = DLA_CONFIGS[num_layers] + self.base = DLA( + num_layers=num_layers, + levels=levels, + channels=channels, + block=Block, + norm=norm) + down_ratio = 4 + self.first_level = int(np.log2(down_ratio)) + self.ms_output = ms_output + self.last_level = 5 if not 
self.ms_output else 6 + channels = self.base.channels + scales = [2**i for i in range(len(channels[self.first_level:]))] + self.use_dla_up = use_dla_up + if self.use_dla_up: + self.dla_up = DLAUp( + self.first_level, + channels[self.first_level:], + scales, + norm=norm) + out_channel = channels[self.first_level] + if not self.ms_output: # stride 4 DLA + self.ida_up = IDAUp( + out_channel, + channels[self.first_level:self.last_level], + [2**i for i in range(self.last_level - self.first_level)], + norm=norm) + self._out_features = out_features + self._out_feature_channels = { + 'dla{}'.format(i): channels[i] + for i in range(6) + } + self._out_feature_strides = {'dla{}'.format(i): 2**i for i in range(6)} + self._size_divisibility = 32 + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + x = self.base(x) + if self.use_dla_up: + x = self.dla_up(x) + if not self.ms_output: # stride 4 dla + y = [] + for i in range(self.last_level - self.first_level): + y.append(x[i].clone()) + self.ida_up(y, 0, len(y)) + ret = {} + for i in range(self.last_level - self.first_level): + out_feature = 'dla{}'.format(i) + if out_feature in self._out_features: + ret[out_feature] = y[i] + else: + ret = {} + st = self.first_level if self.use_dla_up else 0 + for i in range(self.last_level - st): + out_feature = 'dla{}'.format(i + st) + if out_feature in self._out_features: + ret[out_feature] = x[i] + + return ret + + +@BACKBONE_REGISTRY.register() +def build_dla_backbone(cfg, input_shape): + """Create a ResNet instance from config. + + Returns: + ResNet: a :class:`ResNet` instance. + """ + return DLASeg( + out_features=cfg.MODEL.DLA.OUT_FEATURES, + num_layers=cfg.MODEL.DLA.NUM_LAYERS, + use_dla_up=cfg.MODEL.DLA.USE_DLA_UP, + ms_output=cfg.MODEL.DLA.MS_OUTPUT, + norm=cfg.MODEL.DLA.NORM) + + +class LastLevelP6P7(nn.Module): + """This module is used in RetinaNet to generate extra layers, P6 and P7 + from C5 feature.""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.num_levels = 2 + self.in_feature = 'dla5' + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_retinanet_dla_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_dla_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + in_channels_p6p7 = bottom_up.output_shape()['dla5'].channels + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7(in_channels_p6p7, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/projects/videochat/models/centernet/modeling/backbone/dlafpn.py b/projects/videochat/models/centernet/modeling/backbone/dlafpn.py new file mode 100644 index 0000000000..0031706b1b --- /dev/null +++ b/projects/videochat/models/centernet/modeling/backbone/dlafpn.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python + +# this file is from https://github.com/ucbdrive/dla/blob/master/dla.py. 
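+# It bundles the DLA-34 backbone (a conv stem followed by hierarchical
+# aggregation Tree blocks) with the IDAUp/DLAUP upsampling modules: the
+# transposed convolutions in IDAUp are initialized to bilinear interpolation
+# kernels by fill_up_weights(), and the optional DeformConv node type wraps
+# detectron2's ModulatedDeformConv for the aggregation step.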
+ +import math +from os.path import join + +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from detectron2.layers import Conv2d, ModulatedDeformConv, ShapeSpec +from detectron2.layers.batch_norm import get_norm +from detectron2.modeling.backbone import FPN, Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from torch import nn + +WEB_ROOT = 'http://dl.yf.io/dla/models' + + +def get_model_url(data, name, hash): + return join('http://dl.yf.io/dla/models', data, + '{}-{}.pth'.format(name, hash)) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, cfg, inplanes, planes, stride=1, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation) + self.bn1 = get_norm(cfg.MODEL.DLA.NORM, planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=dilation, + bias=False, + dilation=dilation) + self.bn2 = get_norm(cfg.MODEL.DLA.NORM, planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, cfg, inplanes, planes, stride=1, dilation=1): + super(Bottleneck, self).__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d( + inplanes, bottle_planes, kernel_size=1, bias=False) + self.bn1 = get_norm(cfg.MODEL.DLA.NORM, bottle_planes) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation) + self.bn2 = get_norm(cfg.MODEL.DLA.NORM, bottle_planes) + self.conv3 = nn.Conv2d( + bottle_planes, planes, kernel_size=1, bias=False) + self.bn3 = get_norm(cfg.MODEL.DLA.NORM, planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + + def __init__(self, cfg, in_channels, out_channels, kernel_size, residual): + super(Root, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + bias=False, + padding=(kernel_size - 1) // 2) + self.bn = get_norm(cfg.MODEL.DLA.NORM, out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + + def __init__(self, + cfg, + levels, + block, + in_channels, + out_channels, + stride=1, + level_root=False, + root_dim=0, + root_kernel_size=1, + dilation=1, + root_residual=False): + super(Tree, self).__init__() + if root_dim == 0: + 
root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block( + cfg, in_channels, out_channels, stride, dilation=dilation) + self.tree2 = block( + cfg, out_channels, out_channels, 1, dilation=dilation) + else: + self.tree1 = Tree( + cfg, + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual) + self.tree2 = Tree( + cfg, + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual) + if levels == 1: + self.root = Root(cfg, root_dim, out_channels, root_kernel_size, + root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False), get_norm(cfg.MODEL.DLA.NORM, out_channels)) + + def forward(self, x, residual=None, children=None): + if self.training and residual is not None: + x = x + residual.sum() * 0.0 + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(Backbone): + + def __init__(self, + cfg, + levels, + channels, + block=BasicBlock, + residual_root=False): + super(DLA, self).__init__() + self.cfg = cfg + self.channels = channels + + self._out_features = ['dla{}'.format(i) for i in range(6)] + self._out_feature_channels = { + k: channels[i] + for i, k in enumerate(self._out_features) + } + self._out_feature_strides = { + k: 2**i + for i, k in enumerate(self._out_features) + } + + self.base_layer = nn.Sequential( + nn.Conv2d( + 3, channels[0], kernel_size=7, stride=1, padding=3, + bias=False), get_norm(cfg.MODEL.DLA.NORM, channels[0]), + nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level(channels[0], channels[0], + levels[0]) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + cfg, + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root) + self.level3 = Tree( + cfg, + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root) + self.level4 = Tree( + cfg, + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root) + self.level5 = Tree( + cfg, + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + + self.load_pretrained_model( + data='imagenet', name='dla34', hash='ba72cf86') + + def load_pretrained_model(self, data, name, hash): + model_url = get_model_url(data, name, hash) + model_weights = model_zoo.load_url(model_url) + del model_weights['fc.weight'] + del model_weights['fc.bias'] + print('Loading pretrained DLA!') + self.load_state_dict(model_weights, strict=True) + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation), + get_norm(self.cfg.MODEL.DLA.NORM, planes), + nn.ReLU(inplace=True) + ]) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = {} + x = self.base_layer(x) + for i in range(6): + name = 'level{}'.format(i) + x = getattr(self, name)(x) + y['dla{}'.format(i)] = x + return y + + +def fill_up_weights(up): + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class Conv(nn.Module): + + def __init__(self, chi, cho, norm): + super(Conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(chi, cho, kernel_size=1, stride=1, bias=False), + get_norm(norm, cho), nn.ReLU(inplace=True)) + + def forward(self, x): + return self.conv(x) + + +class DeformConv(nn.Module): + + def __init__(self, chi, cho, norm): + super(DeformConv, self).__init__() + self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True)) + self.offset = Conv2d( + chi, 27, kernel_size=3, stride=1, padding=1, dilation=1) + self.conv = ModulatedDeformConv( + chi, + cho, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + deformable_groups=1) + nn.init.constant_(self.offset.weight, 0) + nn.init.constant_(self.offset.bias, 0) + + def forward(self, x): + offset_mask = self.offset(x) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + x = self.conv(x, offset, mask) + x = self.actf(x) + return x + + +class IDAUp(nn.Module): + + def __init__(self, o, channels, up_f, norm='FrozenBN', node_type=Conv): + super(IDAUp, self).__init__() + for i in range(1, len(channels)): + c = channels[i] + f = int(up_f[i]) + proj = node_type(c, o, norm) + node = node_type(o, o, norm) + + up = nn.ConvTranspose2d( + o, + o, + f * 2, + stride=f, + padding=f // 2, + output_padding=0, + groups=o, + bias=False) + fill_up_weights(up) + + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + setattr(self, 'node_' + str(i), node) + + def forward(self, layers, startp, endp): + for i in range(startp + 1, endp): + upsample = getattr(self, 'up_' + str(i - startp)) + project = getattr(self, 'proj_' + str(i - startp)) + layers[i] = upsample(project(layers[i])) + node = getattr(self, 'node_' + str(i - startp)) + layers[i] = node(layers[i] + layers[i - 1]) + + +DLAUP_NODE_MAP = { + 'conv': Conv, + 'dcn': DeformConv, +} + + +class DLAUP(Backbone): + + def __init__(self, bottom_up, in_features, norm, dlaup_node='conv'): + super(DLAUP, self).__init__() + assert isinstance(bottom_up, Backbone) + self.bottom_up = bottom_up + input_shapes = bottom_up.output_shape() + in_strides = [input_shapes[f].stride for f in in_features] + in_channels = 
[input_shapes[f].channels for f in in_features] + in_levels = [ + int(math.log2(input_shapes[f].stride)) for f in in_features + ] + self.in_features = in_features + out_features = ['dlaup{}'.format(ll) for ll in in_levels] + self._out_features = out_features + self._out_feature_channels = { + 'dlaup{}'.format(ll): in_channels[i] + for i, ll in enumerate(in_levels) + } + self._out_feature_strides = { + 'dlaup{}'.format(ll): 2**ll + for ll in in_levels + } + + print('self._out_features', self._out_features) + print('self._out_feature_channels', self._out_feature_channels) + print('self._out_feature_strides', self._out_feature_strides) + self._size_divisibility = 32 + + node_type = DLAUP_NODE_MAP[dlaup_node] + + self.startp = int(math.log2(in_strides[0])) + self.channels = in_channels + channels = list(in_channels) + scales = np.array([2**i for i in range(len(out_features))], dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr( + self, 'ida_{}'.format(i), + IDAUp( + channels[j], + in_channels[j:], + scales[j:] // scales[j], + norm=norm, + node_type=node_type)) + scales[j + 1:] = scales[j] + in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + bottom_up_features = self.bottom_up(x) + layers = [bottom_up_features[f] for f in self.in_features] + out = [layers[-1]] # start with 32 + for i in range(len(layers) - 1): + ida = getattr(self, 'ida_{}'.format(i)) + ida(layers, len(layers) - i - 2, len(layers)) + out.insert(0, layers[-1]) + ret = {} + for k, v in zip(self._out_features, out): + ret[k] = v + # import pdb; pdb.set_trace() + return ret + + +def dla34(cfg, pretrained=None): # DLA-34 + model = DLA( + cfg, [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock) + return model + + +class LastLevelP6P7(nn.Module): + """This module is used in RetinaNet to generate extra layers, P6 and P7 + from C5 feature.""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.num_levels = 2 + self.in_feature = 'dla5' + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_dla_fpn3_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + + depth_to_creator = {'dla34': dla34} + bottom_up = depth_to_creator['dla{}'.format(cfg.MODEL.DLA.NUM_LAYERS)](cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=None, + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + + return backbone + + +@BACKBONE_REGISTRY.register() +def build_dla_fpn5_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. 
+ """ + + depth_to_creator = {'dla34': dla34} + bottom_up = depth_to_creator['dla{}'.format(cfg.MODEL.DLA.NUM_LAYERS)](cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + in_channels_top = bottom_up.output_shape()['dla5'].channels + + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7(in_channels_top, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + + return backbone + + +@BACKBONE_REGISTRY.register() +def build_dlaup_backbone(cfg, input_shape: ShapeSpec): + """ + Args: cfg: a detectron2 CfgNode Returns: backbone (Backbone): backbone + module, must be a subclass of :class:`Backbone`. + """ + + depth_to_creator = {'dla34': dla34} + bottom_up = depth_to_creator['dla{}'.format(cfg.MODEL.DLA.NUM_LAYERS)](cfg) + + backbone = DLAUP( + bottom_up=bottom_up, + in_features=cfg.MODEL.DLA.DLAUP_IN_FEATURES, + norm=cfg.MODEL.DLA.NORM, + dlaup_node=cfg.MODEL.DLA.DLAUP_NODE, + ) + + return backbone diff --git a/projects/videochat/models/centernet/modeling/backbone/fpn_p5.py b/projects/videochat/models/centernet/modeling/backbone/fpn_p5.py new file mode 100644 index 0000000000..70d86bc917 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/backbone/fpn_p5.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import fvcore.nn.weight_init as weight_init +import torch.nn.functional as F +from detectron2.layers import ShapeSpec +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +from detectron2.modeling.backbone.resnet import build_resnet_backbone +from torch import nn + + +class LastLevelP6P7_P5(nn.Module): + """This module is used in RetinaNet to generate extra layers, P6 and P7 + from C5 feature.""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.num_levels = 2 + self.in_feature = 'p5' + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_p67_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: backbone (Backbone): backbone module, must be a subclass of + :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: backbone (Backbone): backbone module, must be a subclass of + :class:`Backbone`. 
+ """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=None, + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/projects/videochat/models/centernet/modeling/backbone/res2net.py b/projects/videochat/models/centernet/modeling/backbone/res2net.py new file mode 100644 index 0000000000..5908db0aef --- /dev/null +++ b/projects/videochat/models/centernet/modeling/backbone/res2net.py @@ -0,0 +1,811 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved This +# file is modified from https://github.com/Res2Net/Res2Net-detectron2/blob +# /master/detectron2/modeling/backbone/resnet.py The original file is under +# Apache-2.0 License +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +import torch.nn.functional as F +from detectron2.layers import (CNNBlockBase, Conv2d, DeformConv, + ModulatedDeformConv, ShapeSpec, get_norm) +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +from torch import nn + +from .bifpn import BiFPN +from .fpn_p5 import LastLevelP6P7_P5 + +__all__ = [ + 'ResNetBlockBase', + 'BasicBlock', + 'BottleneckBlock', + 'DeformBottleneckBlock', + 'BasicStem', + 'ResNet', + 'make_stage', + 'build_res2net_backbone', +] + +ResNetBlockBase = CNNBlockBase +""" +Alias for backward compatibiltiy. +""" + + +class BasicBlock(CNNBlockBase): + """The basic residual block for ResNet-18 and ResNet-34, with two 3x3 conv + layers and a projection shortcut if needed.""" + + def __init__(self, in_channels, out_channels, *, stride=1, norm='BN'): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the first conv. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + self.conv2 = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + out = self.conv2(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BottleneckBlock(CNNBlockBase): + """The standard bottle2neck residual block used by Res2Net-50, 101 and + 152.""" + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm='BN', + stride_in_1x1=False, + dilation=1, + basewidth=26, + scale=4, + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. 
+ num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False), + Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + norm=get_norm(norm, out_channels), + )) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t + # implementations have stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + width = bottleneck_channels // scale + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + if scale == 1: + self.nums = 1 + else: + self.nums = scale - 1 + if self.in_channels != self.out_channels and stride_3x3 != 2: + self.pool = nn.AvgPool2d( + kernel_size=3, stride=stride_3x3, padding=1) + + convs = [] + bns = [] + for i in range(self.nums): + convs.append( + nn.Conv2d( + width, + width, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + )) + bns.append(get_norm(norm, width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + self.scale = scale + self.width = width + self.in_channels = in_channels + self.out_channels = out_channels + self.stride_3x3 = stride_3x3 + for layer in [self.conv1, self.conv3]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + if self.shortcut is not None: + for layer in self.shortcut.modules(): + if isinstance(layer, Conv2d): + weight_init.c2_msra_fill(layer) + + for layer in self.convs: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. See Sec 5.1 in + # "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": "For + # BN layers, the learnable scaling coefficient γ is initialized to + # be 1, except for each residual block's last BN where γ is + # initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) Add it as an option + # when we need to use this code to train a backbone. 
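+        # Res2Net "bottle2neck" layout: conv1 expands the input to
+        # `bottleneck_channels`, which forward() splits into `scale` chunks
+        # of `width` channels; the first `scale - 1` chunks run through the
+        # 3x3 convs above in a hierarchical cascade and are re-concatenated,
+        # while the last chunk is passed through unchanged (or average-pooled
+        # when the block downsamples) before the final 1x1 conv3 projection.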
+ + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0 or self.in_channels != self.out_channels: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = F.relu_(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + if self.scale != 1 and self.stride_3x3 == 1: + out = torch.cat((out, spx[self.nums]), 1) + elif self.scale != 1 and self.stride_3x3 == 2: + out = torch.cat((out, self.pool(spx[self.nums])), 1) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class DeformBottleneckBlock(ResNetBlockBase): + """Not implemented for res2net yet. + + Similar to :class:`BottleneckBlock`, but with deformable conv in the 3x3 + convolution. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm='BN', + stride_in_1x1=False, + dilation=1, + deform_modulated=False, + deform_num_groups=1, + basewidth=26, + scale=4, + ): + super().__init__(in_channels, out_channels, stride) + self.deform_modulated = deform_modulated + + if in_channels != out_channels: + # self.shortcut = Conv2d( + # in_channels, + # out_channels, + # kernel_size=1, + # stride=stride, + # bias=False, + # norm=get_norm(norm, out_channels), + # ) + self.shortcut = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False), + Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + norm=get_norm(norm, out_channels), + )) + else: + self.shortcut = None + + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + width = bottleneck_channels // scale + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + if scale == 1: + self.nums = 1 + else: + self.nums = scale - 1 + if self.in_channels != self.out_channels and stride_3x3 != 2: + self.pool = nn.AvgPool2d( + kernel_size=3, stride=stride_3x3, padding=1) + + if deform_modulated: + deform_conv_op = ModulatedDeformConv + # offset channels are 2 or 3 (if with modulated) * kernel_size * + # kernel_size + offset_channels = 27 + else: + deform_conv_op = DeformConv + offset_channels = 18 + + # self.conv2_offset = Conv2d( + # bottleneck_channels, + # offset_channels * deform_num_groups, + # kernel_size=3, + # stride=stride_3x3, + # padding=1 * dilation, + # dilation=dilation, + # ) + # self.conv2 = deform_conv_op( + # bottleneck_channels, + # bottleneck_channels, + # kernel_size=3, + # stride=stride_3x3, + # padding=1 * dilation, + # bias=False, + # groups=num_groups, + # dilation=dilation, + # deformable_groups=deform_num_groups, + # norm=get_norm(norm, bottleneck_channels), + # ) + + conv2_offsets = [] + convs = [] + bns = [] + for i in range(self.nums): + conv2_offsets.append( + Conv2d( + width, + offset_channels * deform_num_groups, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + )) + convs.append( + deform_conv_op( + width, + width, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + deformable_groups=deform_num_groups, + )) + bns.append(get_norm(norm, width)) + self.conv2_offsets = nn.ModuleList(conv2_offsets) + self.convs = 
nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + self.scale = scale + self.width = width + self.in_channels = in_channels + self.out_channels = out_channels + self.stride_3x3 = stride_3x3 + # for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + # if layer is not None: # shortcut can be None + # weight_init.c2_msra_fill(layer) + + # nn.init.constant_(self.conv2_offset.weight, 0) + # nn.init.constant_(self.conv2_offset.bias, 0) + for layer in [self.conv1, self.conv3]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + if self.shortcut is not None: + for layer in self.shortcut.modules(): + if isinstance(layer, Conv2d): + weight_init.c2_msra_fill(layer) + + for layer in self.convs: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + for layer in self.conv2_offsets: + if layer.weight is not None: + nn.init.constant_(layer.weight, 0) + if layer.bias is not None: + nn.init.constant_(layer.bias, 0) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + # if self.deform_modulated: + # offset_mask = self.conv2_offset(out) + # offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + # offset = torch.cat((offset_x, offset_y), dim=1) + # mask = mask.sigmoid() + # out = self.conv2(out, offset, mask) + # else: + # offset = self.conv2_offset(out) + # out = self.conv2(out, offset) + # out = F.relu_(out) + + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0 or self.in_channels != self.out_channels: + sp = spx[i].contiguous() + else: + sp = sp + spx[i].contiguous() + + # sp = self.convs[i](sp) + if self.deform_modulated: + offset_mask = self.conv2_offsets[i](sp) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + sp = self.convs[i](sp, offset, mask) + else: + offset = self.conv2_offsets[i](sp) + sp = self.convs[i](sp, offset) + sp = F.relu_(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + if self.scale != 1 and self.stride_3x3 == 1: + out = torch.cat((out, spx[self.nums]), 1) + elif self.scale != 1 and self.stride_3x3 == 2: + out = torch.cat((out, self.pool(spx[self.nums])), 1) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +def make_stage(block_class, num_blocks, first_stride, *, in_channels, + out_channels, **kwargs): + """Create a list of blocks just like those in a ResNet stage. + + Args: block_class (type): a subclass of ResNetBlockBase num_blocks ( + int): first_stride (int): the stride of the first block. The other + blocks will have stride=1. in_channels (int): input channels of the + entire stage. out_channels (int): output channels of **every block** in + the stage. kwargs: other arguments passed to the constructor of every + block. Returns: list[nn.Module]: a list of block module. + """ + assert 'stride' not in kwargs, 'Stride of blocks in make_stage ' \ + 'cannot be changed. 
' + blocks = [] + for i in range(num_blocks): + blocks.append( + block_class( + in_channels=in_channels, + out_channels=out_channels, + stride=first_stride if i == 0 else 1, + **kwargs, + )) + in_channels = out_channels + return blocks + + +class BasicStem(CNNBlockBase): + """The standard ResNet stem (layers before the first residual block).""" + + def __init__(self, in_channels=3, out_channels=64, norm='BN'): + """ + Args: + norm (str or callable): norm after the first conv layer. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, 4) + self.in_channels = in_channels + self.conv1 = nn.Sequential( + Conv2d( + in_channels, + 32, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + get_norm(norm, 32), + nn.ReLU(inplace=True), + Conv2d( + 32, + 32, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + get_norm(norm, 32), + nn.ReLU(inplace=True), + Conv2d( + 32, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ) + self.bn1 = get_norm(norm, out_channels) + + for layer in self.conv1: + if isinstance(layer, Conv2d): + weight_init.c2_msra_fill(layer) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +class ResNet(Backbone): + + def __init__(self, stem, stages, num_classes=None, out_features=None): + """ + Args: stem (nn.Module): a stem module stages (list[list[ + CNNBlockBase]]): several (typically 4) stages, each contains + multiple :class:`CNNBlockBase`. num_classes (None or int): if None, + will not perform classification. Otherwise, will create a linear + layer. out_features (list[str]): name of the layers whose outputs + should be returned in forward. Can be anything in "stem", "linear", + or "res2" ... If None, will return the output of the last layer. + """ + super(ResNet, self).__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {'stem': current_stride} + self._out_feature_channels = {'stem': self.stem.out_channels} + + self.stages_and_names = [] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = 'res' + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stages_and_names.append((stage, name)) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks])) + self._out_feature_channels[name] = curr_channels = blocks[ + -1].out_channels + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet + # in 1 Hour": "The 1000-way fully-connected layer is initialized + # by drawing weights from a zero-mean Gaussian with standard + # deviation of 0.01." 
+ nn.init.normal_(self.linear.weight, std=0.01) + name = 'linear' + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, 'Available children: {}'.format( + ', '.join(children)) + + def forward(self, x): + outputs = {} + x = self.stem(x) + if 'stem' in self._out_features: + outputs['stem'] = x + for stage, name in self.stages_and_names: + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if 'linear' in self._out_features: + outputs['linear'] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name]) + for name in self._out_features + } + + def freeze(self, freeze_at=0): + """Freeze the first several stages of the ResNet. + + Commonly used in + fine-tuning. + Args: + freeze_at (int): number of stem and stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + the first stage, etc. + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, (stage, _) in enumerate(self.stages_and_names, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + +@BACKBONE_REGISTRY.register() +def build_res2net_backbone(cfg, input_shape): + """Create a Res2Net instance from config. + + Returns: + ResNet: a :class:`ResNet` instance. + """ + # need registration of new blocks/stems? + norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + + # fmt: off + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + scale = 4 + bottleneck_channels = num_groups * width_per_group * scale + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + # fmt: on + assert res5_dilation in { + 1, 2 + }, 'res5_dilation cannot be {}.'.format(res5_dilation) + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + + if depth in [18, 34]: + assert out_channels == 64, \ + 'Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34' + assert not any( + deform_on_per_stage + ), 'MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34' + assert res5_dilation == 1, \ + 'Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34' + assert num_groups == 1, \ + 'Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34' + + stages = [] + + # Avoid creating variables without gradients + # It consumes extra memory and may cause allreduce to fail + out_stage_idx = [{ + 'res2': 2, + 'res3': 3, + 'res4': 4, + 'res5': 5 + }[f] for f in out_features] + max_stage_idx = max(out_stage_idx) + for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): + dilation = 
res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 + and dilation == 2) else 2 + stage_kargs = { + 'num_blocks': num_blocks_per_stage[idx], + 'first_stride': first_stride, + 'in_channels': in_channels, + 'out_channels': out_channels, + 'norm': norm, + } + # Use BasicBlock for R18 and R34. + if depth in [18, 34]: + stage_kargs['block_class'] = BasicBlock + else: + stage_kargs['bottleneck_channels'] = bottleneck_channels + stage_kargs['stride_in_1x1'] = stride_in_1x1 + stage_kargs['dilation'] = dilation + stage_kargs['num_groups'] = num_groups + stage_kargs['scale'] = scale + + if deform_on_per_stage[idx]: + stage_kargs['block_class'] = DeformBottleneckBlock + stage_kargs['deform_modulated'] = deform_modulated + stage_kargs['deform_num_groups'] = deform_num_groups + else: + stage_kargs['block_class'] = BottleneckBlock + blocks = make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features).freeze(freeze_at) + + +@BACKBONE_REGISTRY.register() +def build_p67_res2net_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: backbone (Backbone): backbone module, must be a subclass of + :class:`Backbone`. + """ + bottom_up = build_res2net_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_res2net_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: backbone (Backbone): backbone module, must be a subclass of + :class:`Backbone`. 
+ """ + bottom_up = build_res2net_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone diff --git a/projects/videochat/models/centernet/modeling/debug.py b/projects/videochat/models/centernet/modeling/debug.py new file mode 100644 index 0000000000..1a5b34d031 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/debug.py @@ -0,0 +1,302 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * + 255).astype(np.uint8).reshape(1300, 1, 1, 3) + + +def _get_color_image(heatmap): + heatmap = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], + heatmap.shape[2], 1) + if heatmap.shape[0] == 1: + color_map = (heatmap * np.ones((1, 1, 1, 3), np.uint8) * + 255).max(axis=0).astype(np.uint8) # H, W, 3 + else: + color_map = (heatmap * COLORS[:heatmap.shape[0]]).max( + axis=0).astype(np.uint8) # H, W, 3 + + return color_map + + +def _blend_image(image, color_map, a=0.7): + color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) + ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8) + return ret + + +def _blend_image_heatmaps(image, color_maps, a=0.7): + merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32) + for color_map in color_maps: + color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) + merges = np.maximum(merges, color_map) + ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8) + return ret + + +def _decompose_level(x, shapes_per_level, N): + ''' + x: LNHiWi x C + ''' + x = x.view(x.shape[0], -1) + ret = [] + st = 0 + for ll in range(len(shapes_per_level)): + ret.append([]) + h = shapes_per_level[ll][0].int().item() + w = shapes_per_level[ll][1].int().item() + for i in range(N): + ret[ll].append(x[st + h * w * i:st + h * w * + (i + 1)].view(h, w, -1).permute(2, 0, 1)) + st += h * w * N + return ret + + +def _imagelist_to_tensor(images): + images = [x for x in images] + image_sizes = [x.shape[-2:] for x in images] + h = max([size[0] for size in image_sizes]) + w = max([size[1] for size in image_sizes]) + S = 32 + h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S + images = [ + F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) for x in images + ] + images = torch.stack(images) + return images + + +def _ind2il(ind, shapes_per_level, N): + r = ind + ll = 0 + S = 0 + while r - S >= N * shapes_per_level[ll][0] * shapes_per_level[ll][1]: + S += N * shapes_per_level[ll][0] * shapes_per_level[ll][1] + ll += 1 + i = (r - S) // (shapes_per_level[ll][0] * shapes_per_level[ll][1]) + return i, ll + + +def debug_train(images, gt_instances, flattened_hms, reg_targets, labels, + pos_inds, shapes_per_level, locations, strides): + ''' + images: N x 3 x H x W + flattened_hms: LNHiWi x C + shapes_per_level: L x 2 [(H_i, W_i)] + locations: LNHiWi x 2 + ''' + reg_inds = torch.nonzero(reg_targets.max(dim=1)[0] > 0).squeeze(1) + N = len(images) + images = _imagelist_to_tensor(images) + repeated_locations = [torch.cat([loc] * N, dim=0) for loc in locations] + locations = torch.cat(repeated_locations, dim=0) + gt_hms = _decompose_level(flattened_hms, shapes_per_level, N) + masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1)) + 
masks[pos_inds] = 1 + masks = _decompose_level(masks, shapes_per_level, N) + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0) + color_maps = [] + for ll in range(len(gt_hms)): + color_map = _get_color_image(gt_hms[ll][i].detach().cpu().numpy()) + color_maps.append(color_map) + cv2.imshow('gthm_{}'.format(ll), color_map) + blend = _blend_image_heatmaps(image.copy(), color_maps) + if gt_instances is not None: + bboxes = gt_instances[i].gt_boxes.tensor + for j in range(len(bboxes)): + bbox = bboxes[j] + cv2.rectangle(blend, (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), (0, 0, 255), 3, + cv2.LINE_AA) + + for j in range(len(pos_inds)): + image_id, ll = _ind2il(pos_inds[j], shapes_per_level, N) + if image_id != i: + continue + loc = locations[pos_inds[j]] + cv2.drawMarker( + blend, (int(loc[0]), int(loc[1])), (0, 255, 255), + markerSize=(ll + 1) * 16) + + for j in range(len(reg_inds)): + image_id, ll = _ind2il(reg_inds[j], shapes_per_level, N) + if image_id != i: + continue + ltrb = reg_targets[reg_inds[j]] + ltrb *= strides[ll] + loc = locations[reg_inds[j]] + bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]), (loc[0] + ltrb[2]), + (loc[1] + ltrb[3])] + cv2.rectangle(blend, (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), (255, 0, 0), 1, + cv2.LINE_AA) + cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1) + + cv2.imshow('blend', blend) + cv2.waitKey() + + +def debug_test(images, + logits_pred, + reg_pred, + agn_hm_pred=[], + preds=[], + vis_thresh=0.3, + debug_show_name=False, + mult_agn=False): + ''' + images: N x 3 x H x W + class_target: LNHiWi x C + cat_agn_heatmap: LNHiWi + shapes_per_level: L x 2 [(H_i, W_i)] + ''' + # N = len(images) + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0) + # result = image.copy().astype(np.uint8) + pred_image = image.copy().astype(np.uint8) + color_maps = [] + L = len(logits_pred) + for ll in range(L): + if logits_pred[0] is not None: + stride = min(image.shape[0], image.shape[1]) / min( + logits_pred[ll][i].shape[1], logits_pred[ll][i].shape[2]) + else: + stride = min(image.shape[0], image.shape[1]) / min( + agn_hm_pred[ll][i].shape[1], agn_hm_pred[ll][i].shape[2]) + stride = stride if stride < 60 else 64 if stride < 100 else 128 + if logits_pred[0] is not None: + if mult_agn: + logits_pred[ll][i] = \ + logits_pred[ll][i] * agn_hm_pred[ll][i] + color_map = _get_color_image( + logits_pred[ll][i].detach().cpu().numpy()) + color_maps.append(color_map) + cv2.imshow('predhm_{}'.format(ll), color_map) + + if debug_show_name: + from detectron2.data.datasets.lvis_v1_categories import \ + LVIS_CATEGORIES + cat2name = [x['name'] for x in LVIS_CATEGORIES] + for j in range(len(preds[i].scores) if preds is not None else 0): + if preds[i].scores[j] > vis_thresh: + bbox = preds[i].proposal_boxes[j] \ + if preds[i].has('proposal_boxes') else \ + preds[i].pred_boxes[j] + bbox = bbox.tensor[0].detach().cpu().numpy().astype( + np.int32) + cat = int(preds[i].pred_classes[j]) \ + if preds[i].has('pred_classes') else 0 + cl = COLORS[cat, 0, 0] + cv2.rectangle(pred_image, (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (int(cl[0]), int(cl[1]), int(cl[2])), 2, + cv2.LINE_AA) + if debug_show_name: + txt = '{}{:.1f}'.format( + cat2name[cat] if cat > 0 else '', + preds[i].scores[j]) + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + pred_image, + (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), 
+ (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), -1) + cv2.putText( + pred_image, + txt, (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA) + + if agn_hm_pred[ll] is not None: + agn_hm_ = \ + agn_hm_pred[ll][i, 0, :, :, None].detach().cpu().numpy() + agn_hm_ = (agn_hm_ * + np.array([255, 255, 255]).reshape(1, 1, 3)).astype( + np.uint8) + cv2.imshow('agn_hm_{}'.format(ll), agn_hm_) + blend = _blend_image_heatmaps(image.copy(), color_maps) + cv2.imshow('blend', blend) + cv2.imshow('preds', pred_image) + cv2.waitKey() + + +global cnt +cnt = 0 + + +def debug_second_stage(images, + instances, + proposals=None, + vis_thresh=0.3, + save_debug=False, + debug_show_name=False): + images = _imagelist_to_tensor(images) + if debug_show_name: + from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES + cat2name = [x['name'] for x in LVIS_CATEGORIES] + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype( + np.uint8).copy() + if instances[i].has('gt_boxes'): + bboxes = instances[i].gt_boxes.tensor.cpu().numpy() + scores = np.ones(bboxes.shape[0]) + cats = instances[i].gt_classes.cpu().numpy() + else: + bboxes = instances[i].pred_boxes.tensor.cpu().numpy() + scores = instances[i].scores.cpu().numpy() + cats = instances[i].pred_classes.cpu().numpy() + for j in range(len(bboxes)): + if scores[j] > vis_thresh: + bbox = bboxes[j] + cl = COLORS[cats[j], 0, 0] + cl = (int(cl[0]), int(cl[1]), int(cl[2])) + cv2.rectangle(image, (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), cl, 2, cv2.LINE_AA) + if debug_show_name: + cat = cats[j] + txt = '{}{:.1f}'.format(cat2name[cat] if cat > 0 else '', + scores[j]) + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + image, (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), + (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), -1) + cv2.putText( + image, + txt, (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA) + if proposals is not None: + proposal_image = images[i].detach().cpu().numpy().transpose( + 1, 2, 0).astype(np.uint8).copy() + bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy() + if proposals[i].has('scores'): + scores = proposals[i].scores.cpu().numpy() + else: + scores = proposals[i].objectness_logits.sigmoid().cpu().numpy() + for j in range(len(bboxes)): + if scores[j] > vis_thresh: + bbox = bboxes[j] + cl = (209, 159, 83) + cv2.rectangle(proposal_image, (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), cl, 2, + cv2.LINE_AA) + + cv2.imshow('image', image) + if proposals is not None: + cv2.imshow('proposals', proposal_image) + if save_debug: + global cnt + cnt += 1 + cv2.imwrite('output/save_debug/{}.jpg'.format(cnt), + proposal_image) + cv2.waitKey() diff --git a/projects/videochat/models/centernet/modeling/dense_heads/__init__.py b/projects/videochat/models/centernet/modeling/dense_heads/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/modeling/dense_heads/centernet.py b/projects/videochat/models/centernet/modeling/dense_heads/centernet.py new file mode 100644 index 0000000000..c1726ad195 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/dense_heads/centernet.py @@ -0,0 +1,937 @@ +import torch +from detectron2.config import configurable +from detectron2.layers import cat +from 
detectron2.modeling.proposal_generator.build import \ + PROPOSAL_GENERATOR_REGISTRY +from detectron2.structures import Boxes, Instances +from detectron2.utils.comm import get_world_size +from torch import nn + +from ..debug import debug_test, debug_train +from ..layers.heatmap_focal_loss import (binary_heatmap_focal_loss, + heatmap_focal_loss_jit) +from ..layers.iou_loss import IOULoss +from ..layers.ml_nms import ml_nms +from .centernet_head import CenterNetHead +from .utils import _transpose, reduce_sum + +__all__ = ['CenterNet'] + +INF = 100000000 + + +@PROPOSAL_GENERATOR_REGISTRY.register() +class CenterNet(nn.Module): + + @configurable + def __init__( + self, + # input_shape: Dict[str, ShapeSpec], + in_channels=256, + *, + num_classes=80, + in_features=('p3', 'p4', 'p5', 'p6', 'p7'), + strides=(8, 16, 32, 64, 128), + score_thresh=0.05, + hm_min_overlap=0.8, + loc_loss_type='giou', + min_radius=4, + hm_focal_alpha=0.25, + hm_focal_beta=4, + loss_gamma=2.0, + reg_weight=2.0, + not_norm_reg=True, + with_agn_hm=False, + only_proposal=False, + as_proposal=False, + not_nms=False, + pos_weight=1., + neg_weight=1., + sigmoid_clamp=1e-4, + ignore_high_fp=-1., + center_nms=False, + sizes_of_interest=[[0, 80], [64, 160], [128, 320], [256, 640], + [512, 10000000]], + more_pos=False, + more_pos_thresh=0.2, + more_pos_topk=9, + pre_nms_topk_train=1000, + pre_nms_topk_test=1000, + post_nms_topk_train=100, + post_nms_topk_test=100, + nms_thresh_train=0.6, + nms_thresh_test=0.6, + no_reduce=False, + debug=False, + vis_thresh=0.5, + pixel_mean=[103.530, 116.280, 123.675], + pixel_std=[1.0, 1.0, 1.0], + device='cuda', + centernet_head=None, + ): + super().__init__() + self.num_classes = num_classes + self.in_features = in_features + self.strides = strides + self.score_thresh = score_thresh + self.min_radius = min_radius + self.hm_focal_alpha = hm_focal_alpha + self.hm_focal_beta = hm_focal_beta + self.loss_gamma = loss_gamma + self.reg_weight = reg_weight + self.not_norm_reg = not_norm_reg + self.with_agn_hm = with_agn_hm + self.only_proposal = only_proposal + self.as_proposal = as_proposal + self.not_nms = not_nms + self.pos_weight = pos_weight + self.neg_weight = neg_weight + self.sigmoid_clamp = sigmoid_clamp + self.ignore_high_fp = ignore_high_fp + self.center_nms = center_nms + self.sizes_of_interest = sizes_of_interest + self.more_pos = more_pos + self.more_pos_thresh = more_pos_thresh + self.more_pos_topk = more_pos_topk + self.pre_nms_topk_train = pre_nms_topk_train + self.pre_nms_topk_test = pre_nms_topk_test + self.post_nms_topk_train = post_nms_topk_train + self.post_nms_topk_test = post_nms_topk_test + self.nms_thresh_train = nms_thresh_train + self.nms_thresh_test = nms_thresh_test + self.no_reduce = no_reduce + self.debug = debug + self.vis_thresh = vis_thresh + if self.center_nms: + self.not_nms = True + self.iou_loss = IOULoss(loc_loss_type) + assert (not self.only_proposal) or self.with_agn_hm + # delta for rendering heatmap + self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap) + if centernet_head is None: + self.centernet_head = CenterNetHead( + in_channels=in_channels, + num_levels=len(in_features), + with_agn_hm=with_agn_hm, + only_proposal=only_proposal) + else: + self.centernet_head = centernet_head + if self.debug: + pixel_mean = torch.Tensor(pixel_mean).to( + torch.device(device)).view(3, 1, 1) + pixel_std = torch.Tensor(pixel_std).to(torch.device(device)).view( + 3, 1, 1) + self.denormalizer = lambda x: x * pixel_std + pixel_mean + + @classmethod + def from_config(cls, cfg, 
input_shape): + ret = { + # 'input_shape': input_shape, + 'in_channels': + input_shape[cfg.MODEL.CENTERNET.IN_FEATURES[0]].channels, + 'num_classes': + cfg.MODEL.CENTERNET.NUM_CLASSES, + 'in_features': + cfg.MODEL.CENTERNET.IN_FEATURES, + 'strides': + cfg.MODEL.CENTERNET.FPN_STRIDES, + 'score_thresh': + cfg.MODEL.CENTERNET.INFERENCE_TH, + 'loc_loss_type': + cfg.MODEL.CENTERNET.LOC_LOSS_TYPE, + 'hm_min_overlap': + cfg.MODEL.CENTERNET.HM_MIN_OVERLAP, + 'min_radius': + cfg.MODEL.CENTERNET.MIN_RADIUS, + 'hm_focal_alpha': + cfg.MODEL.CENTERNET.HM_FOCAL_ALPHA, + 'hm_focal_beta': + cfg.MODEL.CENTERNET.HM_FOCAL_BETA, + 'loss_gamma': + cfg.MODEL.CENTERNET.LOSS_GAMMA, + 'reg_weight': + cfg.MODEL.CENTERNET.REG_WEIGHT, + 'not_norm_reg': + cfg.MODEL.CENTERNET.NOT_NORM_REG, + 'with_agn_hm': + cfg.MODEL.CENTERNET.WITH_AGN_HM, + 'only_proposal': + cfg.MODEL.CENTERNET.ONLY_PROPOSAL, + 'as_proposal': + cfg.MODEL.CENTERNET.AS_PROPOSAL, + 'not_nms': + cfg.MODEL.CENTERNET.NOT_NMS, + 'pos_weight': + cfg.MODEL.CENTERNET.POS_WEIGHT, + 'neg_weight': + cfg.MODEL.CENTERNET.NEG_WEIGHT, + 'sigmoid_clamp': + cfg.MODEL.CENTERNET.SIGMOID_CLAMP, + 'ignore_high_fp': + cfg.MODEL.CENTERNET.IGNORE_HIGH_FP, + 'center_nms': + cfg.MODEL.CENTERNET.CENTER_NMS, + 'sizes_of_interest': + cfg.MODEL.CENTERNET.SOI, + 'more_pos': + cfg.MODEL.CENTERNET.MORE_POS, + 'more_pos_thresh': + cfg.MODEL.CENTERNET.MORE_POS_THRESH, + 'more_pos_topk': + cfg.MODEL.CENTERNET.MORE_POS_TOPK, + 'pre_nms_topk_train': + cfg.MODEL.CENTERNET.PRE_NMS_TOPK_TRAIN, + 'pre_nms_topk_test': + cfg.MODEL.CENTERNET.PRE_NMS_TOPK_TEST, + 'post_nms_topk_train': + cfg.MODEL.CENTERNET.POST_NMS_TOPK_TRAIN, + 'post_nms_topk_test': + cfg.MODEL.CENTERNET.POST_NMS_TOPK_TEST, + 'nms_thresh_train': + cfg.MODEL.CENTERNET.NMS_TH_TRAIN, + 'nms_thresh_test': + cfg.MODEL.CENTERNET.NMS_TH_TEST, + 'no_reduce': + cfg.MODEL.CENTERNET.NO_REDUCE, + 'debug': + cfg.DEBUG, + 'vis_thresh': + cfg.VIS_THRESH, + 'pixel_mean': + cfg.MODEL.PIXEL_MEAN, + 'pixel_std': + cfg.MODEL.PIXEL_STD, + 'device': + cfg.MODEL.DEVICE, + 'centernet_head': + CenterNetHead( + cfg, + [input_shape[f] for f in cfg.MODEL.CENTERNET.IN_FEATURES]), + } + return ret + + def forward(self, images, features_dict, gt_instances): + features = [features_dict[f] for f in self.in_features] + clss_per_level, reg_pred_per_level, agn_hm_pred_per_level = \ + self.centernet_head(features) + grids = self.compute_grids(features) + shapes_per_level = grids[0].new_tensor([(x.shape[2], x.shape[3]) + for x in reg_pred_per_level]) + + if not self.training: + return self.inference(images, clss_per_level, reg_pred_per_level, + agn_hm_pred_per_level, grids) + else: + pos_inds, labels, reg_targets, flattened_hms = \ + self._get_ground_truth( + grids, shapes_per_level, gt_instances) + # logits_pred: M x F, reg_pred: M x 4, agn_hm_pred: M + logits_pred, reg_pred, agn_hm_pred = self._flatten_outputs( + clss_per_level, reg_pred_per_level, agn_hm_pred_per_level) + + if self.more_pos: + # add more pixels as positive if \ 1. they are within the + # center3x3 region of an object 2. 
their regression losses + # are small (= 0).squeeze(1) + reg_pred = reg_pred[reg_inds] + reg_targets_pos = reg_targets[reg_inds] + reg_weight_map = flattened_hms.max(dim=1)[0] + reg_weight_map = reg_weight_map[reg_inds] + reg_weight_map = reg_weight_map * 0 + 1 \ + if self.not_norm_reg else reg_weight_map + if self.no_reduce: + reg_norm = max(reg_weight_map.sum(), 1) + else: + reg_norm = max( + reduce_sum(reg_weight_map.sum()).item() / num_gpus, 1) + + reg_loss = self.reg_weight * self.iou_loss( + reg_pred, reg_targets_pos, reg_weight_map, + reduction='sum') / reg_norm + losses['loss_centernet_loc'] = reg_loss + + if self.with_agn_hm: + cat_agn_heatmap = flattened_hms.max(dim=1)[0] # M + agn_pos_loss, agn_neg_loss = binary_heatmap_focal_loss( + agn_hm_pred, + cat_agn_heatmap, + pos_inds, + alpha=self.hm_focal_alpha, + beta=self.hm_focal_beta, + gamma=self.loss_gamma, + sigmoid_clamp=self.sigmoid_clamp, + ignore_high_fp=self.ignore_high_fp, + ) + agn_pos_loss = self.pos_weight * agn_pos_loss / num_pos_avg + agn_neg_loss = self.neg_weight * agn_neg_loss / num_pos_avg + losses['loss_centernet_agn_pos'] = agn_pos_loss + losses['loss_centernet_agn_neg'] = agn_neg_loss + + if self.debug: + print('losses', losses) + print('total_num_pos', total_num_pos) + return losses + + def compute_grids(self, features): + grids = [] + for level, feature in enumerate(features): + h, w = feature.size()[-2:] + shifts_x = torch.arange( + 0, + w * self.strides[level], + step=self.strides[level], + dtype=torch.float32, + device=feature.device) + shifts_y = torch.arange( + 0, + h * self.strides[level], + step=self.strides[level], + dtype=torch.float32, + device=feature.device) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \ + self.strides[level] // 2 + grids.append(grids_per_level) + return grids + + def _get_ground_truth(self, grids, shapes_per_level, gt_instances): + """ + Input: + grids: list of tensors [(hl x wl, 2)]_l + shapes_per_level: list of tuples L x 2: + gt_instances: gt instances + Retuen: + pos_inds: N + labels: N + reg_targets: M x 4 + flattened_hms: M x C or M x 1 + N: number of objects in all images + M: number of pixels from all FPN levels + """ + + # get positive pixel index + if not self.more_pos: + pos_inds, labels = self._get_label_inds(gt_instances, + shapes_per_level) + else: + pos_inds, labels = None, None + heatmap_channels = self.num_classes + L = len(grids) + num_loc_list = [len(loc) for loc in grids] + strides = torch.cat([ + shapes_per_level.new_ones(num_loc_list[ll]) * self.strides[ll] + for ll in range(L) + ]).float() # M + reg_size_ranges = torch.cat([ + shapes_per_level.new_tensor( + self.sizes_of_interest[ll]).float().view(1, 2).expand( + num_loc_list[ll], 2) for ll in range(L) + ]) # M x 2 + grids = torch.cat(grids, dim=0) # M x 2 + M = grids.shape[0] + + reg_targets = [] + flattened_hms = [] + for i in range(len(gt_instances)): # images + boxes = gt_instances[i].gt_boxes.tensor # N x 4 + area = gt_instances[i].gt_boxes.area() # N + gt_classes = gt_instances[ + i].gt_classes # N in [0, self.num_classes] + + N = boxes.shape[0] + if N == 0: + reg_targets.append(grids.new_zeros((M, 4)) - INF) + flattened_hms.append( + grids.new_zeros( + (M, 1 if self.only_proposal else heatmap_channels))) + continue + + ll = grids[:, 0].view(M, 1) - boxes[:, 0].view(1, N) # M x N + t = grids[:, 1].view(M, 1) - boxes[:, 1].view(1, N) # M x N + r = boxes[:, 2].view(1, 
N) - grids[:, 0].view(M, 1) # M x N + b = boxes[:, 3].view(1, N) - grids[:, 1].view(M, 1) # M x N + reg_target = torch.stack([ll, t, r, b], dim=2) # M x N x 4 + + centers = ((boxes[:, [0, 1]] + boxes[:, [2, 3]]) / 2) # N x 2 + centers_expanded = centers.view(1, N, 2).expand(M, N, + 2) # M x N x 2 + strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) + centers_discret = ((centers_expanded / strides_expanded).int() * + strides_expanded).float() + strides_expanded / 2 + # M x N x 2 + + is_peak = (((grids.view(M, 1, 2).expand(M, N, 2) - + centers_discret)**2).sum(dim=2) == 0) # M x N + is_in_boxes = reg_target.min(dim=2)[0] > 0 # M x N + is_center3x3 = self.get_center3x3(grids, centers, + strides) & is_in_boxes # M x N + is_cared_in_the_level = self.assign_reg_fpn( + reg_target, reg_size_ranges) # M x N + reg_mask = is_center3x3 & is_cared_in_the_level # M x N + + dist2 = ((grids.view(M, 1, 2).expand(M, N, 2) - + centers_expanded)**2).sum(dim=2) # M x N + dist2[is_peak] = 0 + radius2 = self.delta**2 * 2 * area # N + radius2 = torch.clamp(radius2, min=self.min_radius**2) + weighted_dist2 = dist2 / radius2.view(1, N).expand(M, N) # M x N + reg_target = self._get_reg_targets(reg_target, + weighted_dist2.clone(), + reg_mask, area) # M x 4 + + if self.only_proposal: + flattened_hm = self._create_agn_heatmaps_from_dist( + weighted_dist2.clone()) # M x 1 + else: + flattened_hm = self._create_heatmaps_from_dist( + weighted_dist2.clone(), + gt_classes, + channels=heatmap_channels) # M x C + + reg_targets.append(reg_target) + flattened_hms.append(flattened_hm) + + # transpose im first training_targets to level first ones + reg_targets = _transpose(reg_targets, num_loc_list) + flattened_hms = _transpose(flattened_hms, num_loc_list) + for ll in range(len(reg_targets)): + reg_targets[ll] = reg_targets[ll] / float(self.strides[ll]) + reg_targets = cat([x for x in reg_targets], dim=0) # MB x 4 + flattened_hms = cat([x for x in flattened_hms], dim=0) # MB x C + + return pos_inds, labels, reg_targets, flattened_hms + + def _get_label_inds(self, gt_instances, shapes_per_level): + """ + Inputs: + gt_instances: [n_i], sum n_i = N + shapes_per_level: L x 2 [(h_l, w_l)]_L + Returns: + pos_inds: N' + labels: N' + """ + pos_inds = [] + labels = [] + L = len(self.strides) + B = len(gt_instances) + shapes_per_level = shapes_per_level.long() + loc_per_level = (shapes_per_level[:, 0] * + shapes_per_level[:, 1]).long() # L + level_bases = [] + s = 0 + for ll in range(L): + level_bases.append(s) + s = s + B * loc_per_level[ll] + level_bases = shapes_per_level.new_tensor(level_bases).long() # L + strides_default = shapes_per_level.new_tensor( + self.strides).float() # L + for im_i in range(B): + targets_per_im = gt_instances[im_i] + bboxes = targets_per_im.gt_boxes.tensor # n x 4 + n = bboxes.shape[0] + centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2) # n x 2 + centers = centers.view(n, 1, 2).expand(n, L, 2) + strides = strides_default.view(1, L, 1).expand(n, L, 2) + centers_inds = (centers / strides).long() # n x L x 2 + Ws = shapes_per_level[:, 1].view(1, L).expand(n, L) + pos_ind = \ + level_bases.view(1, L).expand(n, L) + \ + im_i * loc_per_level.view(1, L).expand(n, L) + \ + centers_inds[:, :, 1] * Ws + \ + centers_inds[:, :, 0] # n x L + is_cared_in_the_level = self.assign_fpn_level(bboxes) + pos_ind = pos_ind[is_cared_in_the_level].view(-1) + label = targets_per_im.gt_classes.view(n, 1).expand( + n, L)[is_cared_in_the_level].view(-1) + + pos_inds.append(pos_ind) # n' + labels.append(label) # n' + pos_inds 
= torch.cat(pos_inds, dim=0).long() + labels = torch.cat(labels, dim=0) + return pos_inds, labels # N, N + + def assign_fpn_level(self, boxes): + """ + Inputs: + boxes: n x 4 + size_ranges: L x 2 + Return: + is_cared_in_the_level: n x L + """ + size_ranges = boxes.new_tensor(self.sizes_of_interest).view( + len(self.sizes_of_interest), 2) # L x 2 + crit = ((boxes[:, 2:] - boxes[:, :2])**2).sum(dim=1)**0.5 / 2 # n + n, L = crit.shape[0], size_ranges.shape[0] + crit = crit.view(n, 1).expand(n, L) + size_ranges_expand = size_ranges.view(1, L, 2).expand(n, L, 2) + is_cared_in_the_level = (crit >= size_ranges_expand[:, :, 0]) & \ + (crit <= size_ranges_expand[:, :, 1]) + return is_cared_in_the_level + + def assign_reg_fpn(self, reg_targets_per_im, size_ranges): + """ + Inputs: + reg_targets_per_im: M x N x 4 + size_ranges: M x 2 + """ + crit = \ + ((reg_targets_per_im[:, :, :2] + reg_targets_per_im[:, :, 2:]) + ** 2).sum(dim=2) ** 0.5 / 2 + # M x N + is_cared_in_the_level = (crit >= size_ranges[:, [0]]) & \ + (crit <= size_ranges[:, [1]]) + return is_cared_in_the_level + + def _get_reg_targets(self, reg_targets, dist, mask, area): + """ + reg_targets (M x N x 4): long tensor + dist (M x N) + is_*: M x N + """ + dist[mask == 0] = INF * 1.0 + min_dist, min_inds = dist.min(dim=1) # M + reg_targets_per_im = reg_targets[range(len(reg_targets)), + min_inds] # M x N x 4 --> M x 4 + reg_targets_per_im[min_dist == INF] = -INF + return reg_targets_per_im + + def _create_heatmaps_from_dist(self, dist, labels, channels): + ''' + dist: M x N + labels: N + return: + heatmaps: M x C + ''' + heatmaps = dist.new_zeros((dist.shape[0], channels)) + for c in range(channels): + inds = (labels == c) # N + if inds.int().sum() == 0: + continue + heatmaps[:, c] = torch.exp(-dist[:, inds].min(dim=1)[0]) + zeros = heatmaps[:, c] < 1e-4 + heatmaps[zeros, c] = 0 + return heatmaps + + def _create_agn_heatmaps_from_dist(self, dist): + ''' + dist: M x N + return: + heatmaps: M x 1 + ''' + heatmaps = dist.new_zeros((dist.shape[0], 1)) + heatmaps[:, 0] = torch.exp(-dist.min(dim=1)[0]) + zeros = heatmaps < 1e-4 + heatmaps[zeros] = 0 + return heatmaps + + def _flatten_outputs(self, clss, reg_pred, agn_hm_pred): + # Reshape: (N, F, Hl, Wl) -> (N, Hl, Wl, F) -> (sum_l N*Hl*Wl, F) + clss = cat( + [x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]) for x in clss], + dim=0) if clss[0] is not None else None + reg_pred = cat( + [x.permute(0, 2, 3, 1).reshape(-1, 4) for x in reg_pred], dim=0) + agn_hm_pred = cat( + [x.permute(0, 2, 3, 1).reshape(-1) + for x in agn_hm_pred], dim=0) if self.with_agn_hm else None + return clss, reg_pred, agn_hm_pred + + def get_center3x3(self, locations, centers, strides): + ''' + Inputs: + locations: M x 2 + centers: N x 2 + strides: M + ''' + M, N = locations.shape[0], centers.shape[0] + locations_expanded = locations.view(M, 1, 2).expand(M, N, + 2) # M x N x 2 + centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 + strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N + centers_discret = ((centers_expanded / strides_expanded).int() * + strides_expanded).float() + strides_expanded / 2 + # M x N x 2 + dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs() + dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs() + return (dist_x <= strides_expanded[:, :, 0]) & \ + (dist_y <= strides_expanded[:, :, 0]) + + def inference(self, images, clss_per_level, reg_pred_per_level, + agn_hm_pred_per_level, grids): + logits_pred = [ + x.sigmoid() if x is not None 
else None for x in clss_per_level + ] + agn_hm_pred_per_level = [ + x.sigmoid() if x is not None else None + for x in agn_hm_pred_per_level + ] + + if self.only_proposal: + proposals = self.predict_instances( + grids, agn_hm_pred_per_level, reg_pred_per_level, + images.image_sizes, [None for _ in agn_hm_pred_per_level]) + else: + proposals = self.predict_instances(grids, logits_pred, + reg_pred_per_level, + images.image_sizes, + agn_hm_pred_per_level) + if self.as_proposal or self.only_proposal: + for p in range(len(proposals)): + proposals[p].proposal_boxes = proposals[p].get('pred_boxes') + proposals[p].objectness_logits = proposals[p].get('scores') + proposals[p].remove('pred_boxes') + + if self.debug: + debug_test([self.denormalizer(x) for x in images], + logits_pred, + reg_pred_per_level, + agn_hm_pred_per_level, + preds=proposals, + vis_thresh=self.vis_thresh, + debug_show_name=False) + return proposals, {} + + def predict_instances(self, + grids, + logits_pred, + reg_pred, + image_sizes, + agn_hm_pred, + is_proposal=False): + sampled_boxes = [] + for ll in range(len(grids)): + sampled_boxes.append( + self.predict_single_level( + grids[ll], + logits_pred[ll], + reg_pred[ll] * self.strides[ll], + image_sizes, + agn_hm_pred[ll], + ll, + is_proposal=is_proposal)) + boxlists = list(zip(*sampled_boxes)) + boxlists = [Instances.cat(boxlist) for boxlist in boxlists] + boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms) + return boxlists + + def predict_single_level(self, + grids, + heatmap, + reg_pred, + image_sizes, + agn_hm, + level, + is_proposal=False): + N, C, H, W = heatmap.shape + # put in the same format as grids + if self.center_nms: + heatmap_nms = nn.functional.max_pool2d( + heatmap, (3, 3), stride=1, padding=1) + heatmap = heatmap * (heatmap_nms == heatmap).float() + heatmap = heatmap.permute(0, 2, 3, 1) # N x H x W x C + heatmap = heatmap.reshape(N, -1, C) # N x HW x C + box_regression = reg_pred.view(N, 4, H, W).permute(0, 2, 3, + 1) # N x H x W x 4 + box_regression = box_regression.reshape(N, -1, 4) + + candidate_inds = heatmap > self.score_thresh # 0.05 + pre_nms_top_n = candidate_inds.view(N, -1).sum(1) # N + pre_nms_topk = self.pre_nms_topk_train \ + if self.training else self.pre_nms_topk_test + pre_nms_top_n = pre_nms_top_n.clamp(max=pre_nms_topk) # N + + if agn_hm is not None: + agn_hm = agn_hm.view(N, 1, H, W).permute(0, 2, 3, 1) + agn_hm = agn_hm.reshape(N, -1) + heatmap = heatmap * agn_hm[:, :, None] + + results = [] + for i in range(N): + per_box_cls = heatmap[i] # HW x C + per_candidate_inds = candidate_inds[i] # n + per_box_cls = per_box_cls[per_candidate_inds] # n + + per_candidate_nonzeros = per_candidate_inds.nonzero() # n + per_box_loc = per_candidate_nonzeros[:, 0] # n + per_class = per_candidate_nonzeros[:, 1] # n + + per_box_regression = box_regression[i] # HW x 4 + per_box_regression = per_box_regression[per_box_loc] # n x 4 + per_grids = grids[per_box_loc] # n x 2 + + per_pre_nms_top_n = pre_nms_top_n[i] # 1 + + if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): + per_box_cls, top_k_indices = \ + per_box_cls.topk(per_pre_nms_top_n, sorted=False) + per_class = per_class[top_k_indices] + per_box_regression = per_box_regression[top_k_indices] + per_grids = per_grids[top_k_indices] + + detections = torch.stack([ + per_grids[:, 0] - per_box_regression[:, 0], + per_grids[:, 1] - per_box_regression[:, 1], + per_grids[:, 0] + per_box_regression[:, 2], + per_grids[:, 1] + per_box_regression[:, 3], + ], + dim=1) # n x 4 + + # avoid invalid 
boxes in RoI heads + detections[:, 2] = torch.max(detections[:, 2], + detections[:, 0] + 0.01) + detections[:, 3] = torch.max(detections[:, 3], + detections[:, 1] + 0.01) + boxlist = Instances(image_sizes[i]) + boxlist.scores = torch.sqrt(per_box_cls) \ + if self.with_agn_hm else per_box_cls # n + # import pdb; pdb.set_trace() + boxlist.pred_boxes = Boxes(detections) + boxlist.pred_classes = per_class + results.append(boxlist) + return results + + def nms_and_topK(self, boxlists, nms=True): + num_images = len(boxlists) + results = [] + for i in range(num_images): + nms_thresh = self.nms_thresh_train if self.training else \ + self.nms_thresh_test + result = ml_nms(boxlists[i], nms_thresh) if nms else boxlists[i] + if self.debug: + print('#proposals before nms', len(boxlists[i])) + print('#proposals after nms', len(result)) + num_dets = len(result) + post_nms_topk = self.post_nms_topk_train if self.training else \ + self.post_nms_topk_test + if num_dets > post_nms_topk: + cls_scores = result.scores + image_thresh, _ = torch.kthvalue(cls_scores.float().cpu(), + num_dets - post_nms_topk + 1) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + if self.debug: + print('#proposals after filter', len(result)) + results.append(result) + return results + + def _add_more_pos(self, reg_pred, gt_instances, shapes_per_level): + labels, level_masks, c33_inds, c33_masks, c33_regs = \ + self._get_c33_inds(gt_instances, shapes_per_level) + N, L, K = labels.shape[0], len(self.strides), 9 + c33_inds[c33_masks == 0] = 0 + reg_pred_c33 = reg_pred[c33_inds].detach() # N x L x K + invalid_reg = c33_masks == 0 + c33_regs_expand = c33_regs.view(N * L * K, 4).clamp(min=0) + if N > 0: + with torch.no_grad(): + c33_reg_loss = self.iou_loss( + reg_pred_c33.view(N * L * K, 4), + c33_regs_expand, + None, + reduction='none').view(N, L, K).detach() # N x L x K + else: + c33_reg_loss = reg_pred_c33.new_zeros((N, L, K)).detach() + c33_reg_loss[invalid_reg] = INF # N x L x K + c33_reg_loss.view(N * L, K)[level_masks.view(N * L), + 4] = 0 # real center + c33_reg_loss = c33_reg_loss.view(N, L * K) + if N == 0: + loss_thresh = c33_reg_loss.new_ones((N)).float() + else: + loss_thresh = torch.kthvalue( + c33_reg_loss, self.more_pos_topk, dim=1)[0] # N + loss_thresh[ + loss_thresh > self.more_pos_thresh] = self.more_pos_thresh # N + new_pos = c33_reg_loss.view(N, L, K) < \ + loss_thresh.view(N, 1, 1).expand(N, L, K) + pos_inds = c33_inds[new_pos].view(-1) # P + labels = labels.view(N, 1, 1).expand(N, L, K)[new_pos].view(-1) + return pos_inds, labels + + def _get_c33_inds(self, gt_instances, shapes_per_level): + ''' + Get the center (and the 3x3 region near center) locations of each + objects Inputs: gt_instances: [n_i], sum n_i = N shapes_per_level: L + x 2 [(h_l, w_l)]_L + ''' + labels = [] + level_masks = [] + c33_inds = [] + c33_masks = [] + c33_regs = [] + L = len(self.strides) + B = len(gt_instances) + shapes_per_level = shapes_per_level.long() + loc_per_level = (shapes_per_level[:, 0] * + shapes_per_level[:, 1]).long() # L + level_bases = [] + s = 0 + for ll in range(L): + level_bases.append(s) + s = s + B * loc_per_level[ll] + level_bases = shapes_per_level.new_tensor(level_bases).long() # L + strides_default = shapes_per_level.new_tensor( + self.strides).float() # L + K = 9 + dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0, 1]).long() + dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1]).long() + for im_i in range(B): + targets_per_im = 
gt_instances[im_i] + bboxes = targets_per_im.gt_boxes.tensor # n x 4 + n = bboxes.shape[0] + if n == 0: + continue + centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2) # n x 2 + centers = centers.view(n, 1, 2).expand(n, L, 2) + + strides = strides_default.view(1, L, 1).expand(n, L, 2) # + centers_inds = (centers / strides).long() # n x L x 2 + center_grids = centers_inds * strides + strides // 2 # n x L x 2 + ll = center_grids[:, :, 0] - bboxes[:, 0].view(n, 1).expand(n, L) + t = center_grids[:, :, 1] - bboxes[:, 1].view(n, 1).expand(n, L) + r = bboxes[:, 2].view(n, 1).expand(n, L) - center_grids[:, :, 0] + b = bboxes[:, 3].view(n, 1).expand(n, L) - center_grids[:, :, 1] + # n x L + reg = torch.stack([ll, t, r, b], dim=2) # n x L x 4 + reg = reg / strides_default.view(1, L, 1).expand(n, L, 4).float() + + Ws = shapes_per_level[:, 1].view(1, L).expand(n, L) + Hs = shapes_per_level[:, 0].view(1, L).expand(n, L) + expand_Ws = Ws.view(n, L, 1).expand(n, L, K) + expand_Hs = Hs.view(n, L, 1).expand(n, L, K) + label = targets_per_im.gt_classes.view(n).clone() + mask = reg.min(dim=2)[0] >= 0 # n x L + mask = mask & self.assign_fpn_level(bboxes) + labels.append(label) # n + level_masks.append(mask) # n x L + + Dy = dy.view(1, 1, K).expand(n, L, K) + Dx = dx.view(1, 1, K).expand(n, L, K) + c33_ind = \ + level_bases.view(1, L, 1).expand(n, L, K) + \ + im_i * loc_per_level.view(1, L, 1).expand(n, L, K) + \ + (centers_inds[:, :, 1:2].expand(n, L, K) + Dy) * expand_Ws + \ + (centers_inds[:, :, 0:1].expand(n, L, K) + Dx) # n x L x K + + c33_mask = \ + ((centers_inds[:, :, 1:2] + .expand(n, L, K) + dy) < expand_Hs) & \ + ((centers_inds[:, :, 1:2] + .expand(n, L, K) + dy) >= 0) & \ + ((centers_inds[:, :, 0:1] + .expand(n, L, K) + dx) < expand_Ws) & \ + ((centers_inds[:, :, 0:1] + .expand(n, L, K) + dx) >= 0) + # Currently it hard codes the 3x3 region + c33_reg = reg.view(n, L, 1, 4).expand(n, L, K, 4).clone() + c33_reg[:, :, [0, 3, 6], 0] -= 1 + c33_reg[:, :, [0, 3, 6], 2] += 1 + c33_reg[:, :, [2, 5, 8], 0] += 1 + c33_reg[:, :, [2, 5, 8], 2] -= 1 + c33_reg[:, :, [0, 1, 2], 1] -= 1 + c33_reg[:, :, [0, 1, 2], 3] += 1 + c33_reg[:, :, [6, 7, 8], 1] += 1 + c33_reg[:, :, [6, 7, 8], 3] -= 1 + c33_mask = c33_mask & (c33_reg.min(dim=3)[0] >= 0) # n x L x K + c33_inds.append(c33_ind) + c33_masks.append(c33_mask) + c33_regs.append(c33_reg) + + if len(level_masks) > 0: + labels = torch.cat(labels, dim=0) + level_masks = torch.cat(level_masks, dim=0) + c33_inds = torch.cat(c33_inds, dim=0).long() + c33_regs = torch.cat(c33_regs, dim=0) + c33_masks = torch.cat(c33_masks, dim=0) + else: + labels = shapes_per_level.new_zeros((0)).long() + level_masks = shapes_per_level.new_zeros((0, L)).bool() + c33_inds = shapes_per_level.new_zeros((0, L, K)).long() + c33_regs = shapes_per_level.new_zeros((0, L, K, 4)).float() + c33_masks = shapes_per_level.new_zeros((0, L, K)).bool() + return labels, level_masks, c33_inds, c33_masks, c33_regs + # N x L, N x L x K diff --git a/projects/videochat/models/centernet/modeling/dense_heads/centernet_head.py b/projects/videochat/models/centernet/modeling/dense_heads/centernet_head.py new file mode 100644 index 0000000000..1759979eb2 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/dense_heads/centernet_head.py @@ -0,0 +1,175 @@ +import math + +import torch +from detectron2.config import configurable +from detectron2.layers import get_norm +from torch import nn +from torch.nn import functional as F + +from ..layers.deform_conv import DFConv2d + +__all__ = ['CenterNetHead'] + + 
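+# The head below builds one small conv tower per task ('cls', 'bbox' and an
+# optional 'share' tower), predicts a 4-channel box regression per FPN level
+# through a learnable per-level Scale, plus (optionally) a class-agnostic
+# heatmap and per-class logits. Rough call shape, assuming the three FPN
+# levels named in Base.yaml:
+#     clss, bbox_reg, agn_hms = head([p3, p4, p5])  # one entry per level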
+class Scale(nn.Module): + + def __init__(self, init_value=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class CenterNetHead(nn.Module): + + @configurable + def __init__( + self, + # input_shape: List[ShapeSpec], + in_channels, + num_levels, + *, + num_classes=80, + with_agn_hm=False, + only_proposal=False, + norm='GN', + num_cls_convs=4, + num_box_convs=4, + num_share_convs=0, + use_deformable=False, + prior_prob=0.01): + super().__init__() + self.num_classes = num_classes + self.with_agn_hm = with_agn_hm + self.only_proposal = only_proposal + self.out_kernel = 3 + + head_configs = { + 'cls': + (num_cls_convs if not self.only_proposal else 0, use_deformable), + 'bbox': (num_box_convs, use_deformable), + 'share': (num_share_convs, use_deformable) + } + + # in_channels = [s.channels for s in input_shape] + # assert len(set(in_channels)) == 1, \ + # "Each level must have the same channel!" + # in_channels = in_channels[0] + channels = { + 'cls': in_channels, + 'bbox': in_channels, + 'share': in_channels, + } + for head in head_configs: + tower = [] + num_convs, use_deformable = head_configs[head] + channel = channels[head] + for i in range(num_convs): + if use_deformable and i == num_convs - 1: + conv_func = DFConv2d + else: + conv_func = nn.Conv2d + tower.append( + conv_func( + in_channels if i == 0 else channel, + channel, + kernel_size=3, + stride=1, + padding=1, + bias=True)) + if norm == 'GN' and channel % 32 != 0: + tower.append(nn.GroupNorm(25, channel)) + elif norm != '': + tower.append(get_norm(norm, channel)) + tower.append(nn.ReLU()) + self.add_module('{}_tower'.format(head), nn.Sequential(*tower)) + + self.bbox_pred = nn.Conv2d( + in_channels, + 4, + kernel_size=self.out_kernel, + stride=1, + padding=self.out_kernel // 2) + + self.scales = nn.ModuleList( + [Scale(init_value=1.0) for _ in range(num_levels)]) + + for modules in [ + self.cls_tower, + self.bbox_tower, + self.share_tower, + self.bbox_pred, + ]: + for ll in modules.modules(): + if isinstance(ll, nn.Conv2d): + torch.nn.init.normal_(ll.weight, std=0.01) + torch.nn.init.constant_(ll.bias, 0) + + torch.nn.init.constant_(self.bbox_pred.bias, 8.) 
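+        # Initializing bbox_pred's bias to 8 gives the initial (post-ReLU)
+        # regression output a non-trivial box size; the prior_prob-based bias
+        # below is the standard focal-loss initialization for the logit branches.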
+ prior_prob = prior_prob + bias_value = -math.log((1 - prior_prob) / prior_prob) + + if self.with_agn_hm: + self.agn_hm = nn.Conv2d( + in_channels, + 1, + kernel_size=self.out_kernel, + stride=1, + padding=self.out_kernel // 2) + torch.nn.init.constant_(self.agn_hm.bias, bias_value) + torch.nn.init.normal_(self.agn_hm.weight, std=0.01) + + if not self.only_proposal: + cls_kernel_size = self.out_kernel + self.cls_logits = nn.Conv2d( + in_channels, + self.num_classes, + kernel_size=cls_kernel_size, + stride=1, + padding=cls_kernel_size // 2, + ) + + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + + @classmethod + def from_config(cls, cfg, input_shape): + ret = { + # 'input_shape': input_shape, + 'in_channels': [s.channels for s in input_shape][0], + 'num_levels': len(input_shape), + 'num_classes': cfg.MODEL.CENTERNET.NUM_CLASSES, + 'with_agn_hm': cfg.MODEL.CENTERNET.WITH_AGN_HM, + 'only_proposal': cfg.MODEL.CENTERNET.ONLY_PROPOSAL, + 'norm': cfg.MODEL.CENTERNET.NORM, + 'num_cls_convs': cfg.MODEL.CENTERNET.NUM_CLS_CONVS, + 'num_box_convs': cfg.MODEL.CENTERNET.NUM_BOX_CONVS, + 'num_share_convs': cfg.MODEL.CENTERNET.NUM_SHARE_CONVS, + 'use_deformable': cfg.MODEL.CENTERNET.USE_DEFORMABLE, + 'prior_prob': cfg.MODEL.CENTERNET.PRIOR_PROB, + } + return ret + + def forward(self, x): + clss = [] + bbox_reg = [] + agn_hms = [] + for ll, feature in enumerate(x): + feature = self.share_tower(feature) + cls_tower = self.cls_tower(feature) + bbox_tower = self.bbox_tower(feature) + if not self.only_proposal: + clss.append(self.cls_logits(cls_tower)) + else: + clss.append(None) + + if self.with_agn_hm: + agn_hms.append(self.agn_hm(bbox_tower)) + else: + agn_hms.append(None) + reg = self.bbox_pred(bbox_tower) + reg = self.scales[ll](reg) + bbox_reg.append(F.relu(reg)) + + return clss, bbox_reg, agn_hms diff --git a/projects/videochat/models/centernet/modeling/dense_heads/utils.py b/projects/videochat/models/centernet/modeling/dense_heads/utils.py new file mode 100644 index 0000000000..9919f435d5 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/dense_heads/utils.py @@ -0,0 +1,32 @@ +import torch +# from .data import CenterNetCrop +from detectron2.utils.comm import get_world_size + +__all__ = ['reduce_sum', '_transpose'] + +INF = 1000000000 + + +def _transpose(training_targets, num_loc_list): + """ + This function is used to transpose image first training targets to + level first ones + :return: level first training targets + """ + for im_i in range(len(training_targets)): + training_targets[im_i] = torch.split( + training_targets[im_i], num_loc_list, dim=0) + + targets_level_first = [] + for targets_per_level in zip(*training_targets): + targets_level_first.append(torch.cat(targets_per_level, dim=0)) + return targets_level_first + + +def reduce_sum(tensor): + world_size = get_world_size() + if world_size < 2: + return tensor + tensor = tensor.clone() + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + return tensor diff --git a/projects/videochat/models/centernet/modeling/layers/__init__.py b/projects/videochat/models/centernet/modeling/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/modeling/layers/deform_conv.py b/projects/videochat/models/centernet/modeling/layers/deform_conv.py new file mode 100644 index 0000000000..602f1e2ee0 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/layers/deform_conv.py @@ -0,0 +1,106 
@@ +import torch +from detectron2.layers import Conv2d +from torch import nn + + +class _NewEmptyTensorOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return _NewEmptyTensorOp.apply(grad, shape), None + + +class DFConv2d(nn.Module): + """Deformable convolutional layer.""" + + def __init__(self, + in_channels, + out_channels, + with_modulated_dcn=True, + kernel_size=3, + stride=1, + groups=1, + dilation=1, + deformable_groups=1, + bias=False, + padding=None): + super(DFConv2d, self).__init__() + if isinstance(kernel_size, (list, tuple)): + assert isinstance(stride, (list, tuple)) + assert isinstance(dilation, (list, tuple)) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(dilation) == 2 + padding = (dilation[0] * (kernel_size[0] - 1) // 2, + dilation[1] * (kernel_size[1] - 1) // 2) + offset_base_channels = kernel_size[0] * kernel_size[1] + else: + padding = dilation * (kernel_size - 1) // 2 + offset_base_channels = kernel_size * kernel_size + if with_modulated_dcn: + from detectron2.layers.deform_conv import ModulatedDeformConv + offset_channels = offset_base_channels * 3 # default: 27 + conv_block = ModulatedDeformConv + else: + from detectron2.layers.deform_conv import DeformConv + offset_channels = offset_base_channels * 2 # default: 18 + conv_block = DeformConv + self.offset = Conv2d( + in_channels, + deformable_groups * offset_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=1, + dilation=dilation) + nn.init.constant_(self.offset.weight, 0) + nn.init.constant_(self.offset.bias, 0) + ''' + for l in [self.offset, ]: + nn.init.kaiming_uniform_(l.weight, a=1) + torch.nn.init.constant_(l.bias, 0.) + ''' + self.conv = conv_block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + deformable_groups=deformable_groups, + bias=bias) + self.with_modulated_dcn = with_modulated_dcn + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.offset_split = offset_base_channels * deformable_groups * 2 + + def forward(self, x, return_offset=False): + if x.numel() > 0: + if not self.with_modulated_dcn: + offset_mask = self.offset(x) + x = self.conv(x, offset_mask) + else: + offset_mask = self.offset(x) + offset = offset_mask[:, :self.offset_split, :, :] + mask = offset_mask[:, self.offset_split:, :, :].sigmoid() + x = self.conv(x, offset, mask) + if return_offset: + return x, offset_mask + return x + # get output shape + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip(x.shape[-2:], self.padding, self. 
+ dilation, self.kernel_size, self.stride) + ] + output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) diff --git a/projects/videochat/models/centernet/modeling/layers/heatmap_focal_loss.py b/projects/videochat/models/centernet/modeling/layers/heatmap_focal_loss.py new file mode 100644 index 0000000000..e9bc0a14cc --- /dev/null +++ b/projects/videochat/models/centernet/modeling/layers/heatmap_focal_loss.py @@ -0,0 +1,97 @@ +import torch + + +def heatmap_focal_loss( + inputs, + targets, + pos_inds, + labels, + alpha: float = -1, + beta: float = 4, + gamma: float = 2, + reduction: str = 'sum', + sigmoid_clamp: float = 1e-4, + ignore_high_fp: float = -1., +): + """Loss used in RetinaNet for dense detection: + https://arxiv.org/abs/1708.02002. + + Args: + inputs: (sum_l N*Hl*Wl, C) + targets: (sum_l N*Hl*Wl, C) + pos_inds: N + labels: N + Returns: + Loss tensor with the reduction option applied. + """ + pred = torch.clamp( + inputs.sigmoid_(), min=sigmoid_clamp, max=1 - sigmoid_clamp) + neg_weights = torch.pow(1 - targets, beta) + pos_pred_pix = pred[pos_inds] # N x C + pos_pred = pos_pred_pix.gather(1, labels.unsqueeze(1)) + pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, gamma) + neg_loss = torch.log(1 - pred) * torch.pow(pred, gamma) * neg_weights + + if ignore_high_fp > 0: + not_high_fp = (pred < ignore_high_fp).float() + neg_loss = not_high_fp * neg_loss + + if reduction == 'sum': + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if alpha >= 0: + pos_loss = alpha * pos_loss + neg_loss = (1 - alpha) * neg_loss + + return -pos_loss, -neg_loss + + +heatmap_focal_loss_jit = torch.jit.script(heatmap_focal_loss) +# heatmap_focal_loss_jit = heatmap_focal_loss + + +def binary_heatmap_focal_loss( + inputs, + targets, + pos_inds, + alpha: float = -1, + beta: float = 4, + gamma: float = 2, + sigmoid_clamp: float = 1e-4, + ignore_high_fp: float = -1., +): + """ + Args: + inputs: (sum_l N*Hl*Wl,) + targets: (sum_l N*Hl*Wl,) + pos_inds: N + Returns: + Loss tensor with the reduction option applied. 
+ """ + pred = torch.clamp( + inputs.sigmoid_(), min=sigmoid_clamp, max=1 - sigmoid_clamp) + neg_weights = torch.pow(1 - targets, beta) + for i, ind in enumerate(pos_inds): + if ind >= pred.shape[0]: + print('%' * 100) + print(pred.shape, ind, pos_inds) + pos_inds[i] = pred.shape[0] - 1 + pos_pred = pred[pos_inds] # N + pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, gamma) + neg_loss = torch.log(1 - pred) * torch.pow(pred, gamma) * neg_weights + if ignore_high_fp > 0: + not_high_fp = (pred < ignore_high_fp).float() + neg_loss = not_high_fp * neg_loss + + pos_loss = -pos_loss.sum() + neg_loss = -neg_loss.sum() + + if alpha >= 0: + pos_loss = alpha * pos_loss + neg_loss = (1 - alpha) * neg_loss + + return pos_loss, neg_loss + + +# binary_heatmap_focal_loss_jit = torch.jit.script(binary_heatmap_focal_loss) diff --git a/projects/videochat/models/centernet/modeling/layers/iou_loss.py b/projects/videochat/models/centernet/modeling/layers/iou_loss.py new file mode 100644 index 0000000000..d288736e11 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/layers/iou_loss.py @@ -0,0 +1,120 @@ +import torch +from torch import nn + + +class IOULoss(nn.Module): + + def __init__(self, loc_loss_type='iou'): + super(IOULoss, self).__init__() + self.loc_loss_type = loc_loss_type + + def forward(self, pred, target, weight=None, reduction='sum'): + pred_left = pred[:, 0] + pred_top = pred[:, 1] + pred_right = pred[:, 2] + pred_bottom = pred[:, 3] + + target_left = target[:, 0] + target_top = target[:, 1] + target_right = target[:, 2] + target_bottom = target[:, 3] + + target_aera = (target_left + target_right) * \ + (target_top + target_bottom) + pred_aera = (pred_left + pred_right) * \ + (pred_top + pred_bottom) + + w_intersect = torch.min(pred_left, target_left) + \ + torch.min(pred_right, target_right) + h_intersect = torch.min(pred_bottom, target_bottom) + \ + torch.min(pred_top, target_top) + + g_w_intersect = torch.max(pred_left, target_left) + \ + torch.max(pred_right, target_right) + g_h_intersect = torch.max(pred_bottom, target_bottom) + \ + torch.max(pred_top, target_top) + ac_uion = g_w_intersect * g_h_intersect + + area_intersect = w_intersect * h_intersect + area_union = target_aera + pred_aera - area_intersect + + ious = (area_intersect + 1.0) / (area_union + 1.0) + gious = ious - (ac_uion - area_union) / ac_uion + if self.loc_loss_type == 'iou': + losses = -torch.log(ious) + elif self.loc_loss_type == 'linear_iou': + losses = 1 - ious + elif self.loc_loss_type == 'giou': + losses = 1 - gious + else: + raise NotImplementedError + + if weight is not None: + losses = losses * weight + else: + losses = losses + + if reduction == 'sum': + return losses.sum() + elif reduction == 'batch': + return losses.sum(dim=[1]) + elif reduction == 'none': + return losses + else: + raise NotImplementedError + + +def giou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = 'none', + eps: float = 1e-7, +) -> torch.Tensor: + """Generalized Intersection over Union Loss (Hamid Rezatofighi et. + + al) https://arxiv.org/abs/1902.09630 Gradient-friendly IoU loss with an + additional penalty that is non-zero when the boxes do not overlap and + scales with the size of their smallest enclosing box. This loss is + symmetric, so the boxes1 and boxes2 arguments are interchangeable. Args: + boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or ( + 4,). reduction: 'none' | 'mean' | 'sum' 'none': No reduction will be + applied to the output. 
'mean': The output will be averaged. 'sum': The + output will be summed. eps (float): small number to prevent division by + zero + """ + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + assert (x2 >= x1).all(), 'bad box: x1 larger than x2' + assert (y2 >= y1).all(), 'bad box: y1 larger than y2' + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsctk = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + iouk = intsctk / (unionk + eps) + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + + area_c = (xc2 - xc1) * (yc2 - yc1) + miouk = iouk - ((area_c - unionk) / (area_c + eps)) + + loss = 1 - miouk + + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + + return loss diff --git a/projects/videochat/models/centernet/modeling/layers/ml_nms.py b/projects/videochat/models/centernet/modeling/layers/ml_nms.py new file mode 100644 index 0000000000..a035f11e45 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/layers/ml_nms.py @@ -0,0 +1,34 @@ +from detectron2.layers import batched_nms + + +def ml_nms(boxlist, + nms_thresh, + max_proposals=-1, + score_field='scores', + label_field='labels'): + """Performs non-maximum suppression on a boxlist, with scores specified in + a boxlist field via score_field. + + Arguments: + boxlist(BoxList) + nms_thresh (float) + max_proposals (int): if > 0, then only the top max_proposals are kept + after non-maximum suppression + score_field (str) + """ + if nms_thresh <= 0: + return boxlist + if boxlist.has('pred_boxes'): + boxes = boxlist.pred_boxes.tensor + labels = boxlist.pred_classes + else: + boxes = boxlist.proposal_boxes.tensor + labels = boxlist.proposal_boxes.tensor.new_zeros( + len(boxlist.proposal_boxes.tensor)) + scores = boxlist.scores + + keep = batched_nms(boxes, scores, labels, nms_thresh) + if max_proposals > 0: + keep = keep[:max_proposals] + boxlist = boxlist[keep] + return boxlist diff --git a/projects/videochat/models/centernet/modeling/meta_arch/__init__.py b/projects/videochat/models/centernet/modeling/meta_arch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/modeling/meta_arch/centernet_detector.py b/projects/videochat/models/centernet/modeling/meta_arch/centernet_detector.py new file mode 100644 index 0000000000..867ceab366 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/meta_arch/centernet_detector.py @@ -0,0 +1,65 @@ +import torch +from detectron2.modeling import (build_backbone, build_proposal_generator, + detector_postprocess) +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.structures import ImageList +from torch import nn + + +@META_ARCH_REGISTRY.register() +class CenterNetDetector(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.mean, self.std = cfg.MODEL.PIXEL_MEAN, cfg.MODEL.PIXEL_STD + self.register_buffer('pixel_mean', + torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)) + self.register_buffer('pixel_std', + torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1)) + + self.backbone = build_backbone(cfg) + self.proposal_generator = build_proposal_generator( + cfg, 
self.backbone.output_shape()) + + def forward(self, batched_inputs): + if not self.training: + return self.inference(batched_inputs) + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + gt_instances = [x['instances'].to(self.device) for x in batched_inputs] + + _, proposal_losses = self.proposal_generator(images, features, + gt_instances) + return proposal_losses + + @property + def device(self): + return self.pixel_mean.device + + @torch.no_grad() + def inference(self, batched_inputs, do_postprocess=True): + images = self.preprocess_image(batched_inputs) + inp = images.tensor + features = self.backbone(inp) + proposals, _ = self.proposal_generator(images, features, None) + + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + proposals, batched_inputs, images.image_sizes): + if do_postprocess: + height = input_per_image.get('height', image_size[0]) + width = input_per_image.get('width', image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({'instances': r}) + else: + r = results_per_image + processed_results.append(r) + return processed_results + + def preprocess_image(self, batched_inputs): + """Normalize, pad and batch the input images.""" + images = [x['image'].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, + self.backbone.size_divisibility) + return images diff --git a/projects/videochat/models/centernet/modeling/roi_heads/__init__.py b/projects/videochat/models/centernet/modeling/roi_heads/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/centernet/modeling/roi_heads/custom_fast_rcnn.py b/projects/videochat/models/centernet/modeling/roi_heads/custom_fast_rcnn.py new file mode 100644 index 0000000000..1b68b0bc87 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/roi_heads/custom_fast_rcnn.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved Part +# of the code is from https://github.com/tztztztztz/eql.detectron2/blob +# /master/projects/EQL/eql/fast_rcnn.py + +import torch +from detectron2.layers import ShapeSpec, cat +from detectron2.modeling.roi_heads.fast_rcnn import (FastRCNNOutputLayers, + _log_classification_stats, + fast_rcnn_inference) +from torch.nn import functional as F + +__all__ = ['CustomFastRCNNOutputLayers'] + + +class CustomFastRCNNOutputLayers(FastRCNNOutputLayers): + + def __init__(self, cfg, input_shape: ShapeSpec, **kwargs): + super().__init__(cfg, input_shape, **kwargs) + + self.cfg = cfg + + def losses(self, predictions, proposals): + """enable advanced loss.""" + scores, proposal_deltas = predictions + gt_classes = ( + cat([p.gt_classes for p in proposals], dim=0) + if len(proposals) else torch.empty(0)) + # num_classes = self.num_classes + _log_classification_stats(scores, gt_classes) + + if len(proposals): + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], + dim=0) # Nx4 + assert not proposal_boxes.requires_grad, \ + 'Proposals should not require gradients!' 
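+            # When a proposal has no matched gt box (all-background images),
+            # fall back to its own proposal box as a placeholder; box_reg_loss
+            # only uses foreground rows (selected via gt_classes), so the
+            # placeholder never contributes to the loss.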
+ gt_boxes = cat( + [(p.gt_boxes if p.has('gt_boxes') else p.proposal_boxes).tensor + for p in proposals], + dim=0, + ) + else: + proposal_boxes = gt_boxes = torch.empty( + (0, 4), device=proposal_deltas.device) + + loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes) + return { + 'loss_cls': + loss_cls, + 'loss_box_reg': + self.box_reg_loss(proposal_boxes, gt_boxes, proposal_deltas, + gt_classes) + } + + def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes): + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros( + [1])[0] # This is more robust than .sum() * 0. + + B = pred_class_logits.shape[0] + C = pred_class_logits.shape[1] - 1 + + target = pred_class_logits.new_zeros(B, C + 1) + target[range(len(gt_classes)), gt_classes] = 1 # B x (C + 1) + target = target[:, :C] # B x C + + weight = 1 + + cls_loss = F.binary_cross_entropy_with_logits( + pred_class_logits[:, :-1], target, reduction='none') # B x C + loss = torch.sum(cls_loss * weight) / B + return loss + + def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes): + """change _no_instance handling.""" + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] + + loss = F.cross_entropy(pred_class_logits, gt_classes, reduction='mean') + return loss + + def inference(self, predictions, proposals): + """enable use proposal boxes.""" + boxes = self.predict_boxes(predictions, proposals) + scores = self.predict_probs(predictions, proposals) + if self.cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE: + proposal_scores = [p.get('objectness_logits') for p in proposals] + scores = [(s * ps[:, None])**0.5 + for s, ps in zip(scores, proposal_scores)] + image_shapes = [x.image_size for x in proposals] + return fast_rcnn_inference( + boxes, + scores, + image_shapes, + self.test_score_thresh, + self.test_nms_thresh, + self.test_topk_per_image, + ) + + def predict_probs(self, predictions, proposals): + """support sigmoid.""" + scores, _ = predictions + num_inst_per_image = [len(p) for p in proposals] + probs = F.softmax(scores, dim=-1) + return probs.split(num_inst_per_image, dim=0) diff --git a/projects/videochat/models/centernet/modeling/roi_heads/custom_roi_heads.py b/projects/videochat/models/centernet/modeling/roi_heads/custom_roi_heads.py new file mode 100644 index 0000000000..4ca3942a79 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/roi_heads/custom_roi_heads.py @@ -0,0 +1,190 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +import torch +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads +from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference +from detectron2.modeling.roi_heads.roi_heads import (ROI_HEADS_REGISTRY, + StandardROIHeads) +from detectron2.utils.events import get_event_storage + +from .custom_fast_rcnn import CustomFastRCNNOutputLayers + + +@ROI_HEADS_REGISTRY.register() +class CustomROIHeads(StandardROIHeads): + + @classmethod + def _init_box_head(self, cfg, input_shape): + ret = super()._init_box_head(cfg, input_shape) + del ret['box_predictor'] + ret['box_predictor'] = CustomFastRCNNOutputLayers( + cfg, ret['box_head'].output_shape) + self.debug = cfg.DEBUG + if self.debug: + self.debug_show_name = cfg.DEBUG_SHOW_NAME + self.save_debug = cfg.SAVE_DEBUG + self.vis_thresh = cfg.VIS_THRESH + self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to( + torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to( + torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + return ret + + def forward(self, images, features, proposals, targets=None): + """enable debug.""" + if not self.debug: + del images + if self.training: + assert targets + proposals = self.label_and_sample_proposals(proposals, targets) + del targets + + if self.training: + losses = self._forward_box(features, proposals) + losses.update(self._forward_mask(features, proposals)) + losses.update(self._forward_keypoint(features, proposals)) + return proposals, losses + else: + pred_instances = self._forward_box(features, proposals) + pred_instances = self.forward_with_given_boxes( + features, pred_instances) + if self.debug: + from ..debug import debug_second_stage + debug_second_stage([self.denormalizer(images[0].clone())], + pred_instances, + proposals=proposals, + debug_show_name=self.debug_show_name) + return pred_instances, {} + + def denormalizer(self, x): + return x * self.pixel_std + self.pixel_mean + + +@ROI_HEADS_REGISTRY.register() +class CustomCascadeROIHeads(CascadeROIHeads): + + @classmethod + def _init_box_head(self, cfg, input_shape): + self.mult_proposal_score = cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE + ret = super()._init_box_head(cfg, input_shape) + del ret['box_predictors'] + cascade_bbox_reg_weights = \ + cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS + box_predictors = [] + for box_head, bbox_reg_weights in zip(ret['box_heads'], + cascade_bbox_reg_weights): + box_predictors.append( + CustomFastRCNNOutputLayers( + cfg, + box_head.output_shape, + box2box_transform=Box2BoxTransform( + weights=bbox_reg_weights))) + ret['box_predictors'] = box_predictors + self.debug = cfg.DEBUG + if self.debug: + self.debug_show_name = cfg.DEBUG_SHOW_NAME + self.save_debug = cfg.SAVE_DEBUG + self.vis_thresh = cfg.VIS_THRESH + self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to( + torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to( + torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + return ret + + def _forward_box(self, features, proposals, targets=None): + """Add mult proposal scores at testing.""" + if (not self.training) and self.mult_proposal_score: + if len(proposals) > 0 and proposals[0].has('scores'): + proposal_scores = [p.get('scores') for p in proposals] + else: + proposal_scores = [ + p.get('objectness_logits') for p in proposals + ] + + features = [features[f] for f in self.box_in_features] + head_outputs = [] # 
(predictor, predictions, proposals) + prev_pred_boxes = None + image_sizes = [x.image_size for x in proposals] + for k in range(self.num_cascade_stages): + if k > 0: + proposals = self._create_proposals_from_boxes( + prev_pred_boxes, image_sizes) + if self.training: + proposals = self._match_and_label_boxes( + proposals, k, targets) + predictions = self._run_stage(features, proposals, k) + prev_pred_boxes = self.box_predictor[k].predict_boxes( + predictions, proposals) + head_outputs.append( + (self.box_predictor[k], predictions, proposals)) + + if self.training: + losses = {} + storage = get_event_storage() + for stage, (predictor, predictions, + proposals) in enumerate(head_outputs): + with storage.name_scope('stage{}'.format(stage)): + stage_losses = predictor.losses(predictions, proposals) + losses.update({ + k + '_stage{}'.format(stage): v + for k, v in stage_losses.items() + }) + return losses + else: + # Each is a list[Tensor] of length #image. Each tensor is Ri x ( + # K+1) + scores_per_stage = [ + h[0].predict_probs(h[1], h[2]) for h in head_outputs + ] + scores = [ + sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) + for scores_per_image in zip(*scores_per_stage) + ] + + if self.mult_proposal_score: + scores = [(s * ps[:, None])**0.5 + for s, ps in zip(scores, proposal_scores)] + + predictor, predictions, proposals = head_outputs[-1] + boxes = predictor.predict_boxes(predictions, proposals) + pred_instances, _ = fast_rcnn_inference( + boxes, + scores, + image_sizes, + predictor.test_score_thresh, + predictor.test_nms_thresh, + predictor.test_topk_per_image, + ) + + return pred_instances + + def forward(self, images, features, proposals, targets=None): + """enable debug.""" + if not self.debug: + del images + if self.training: + proposals = self.label_and_sample_proposals(proposals, targets) + + if self.training: + losses = self._forward_box(features, proposals, targets) + losses.update(self._forward_mask(features, proposals)) + losses.update(self._forward_keypoint(features, proposals)) + return proposals, losses + else: + # import pdb; pdb.set_trace() + pred_instances = self._forward_box(features, proposals) + pred_instances = self.forward_with_given_boxes( + features, pred_instances) + if self.debug: + from ..debug import debug_second_stage + debug_second_stage( + [self.denormalizer(x.clone()) for x in images], + pred_instances, + proposals=proposals, + save_debug=self.save_debug, + debug_show_name=self.debug_show_name, + vis_thresh=self.vis_thresh) + return pred_instances, {} + + def denormalizer(self, x): + return x * self.pixel_std + self.pixel_mean diff --git a/projects/videochat/models/centernet/modeling/roi_heads/fed_loss.py b/projects/videochat/models/centernet/modeling/roi_heads/fed_loss.py new file mode 100644 index 0000000000..1ed6065cd7 --- /dev/null +++ b/projects/videochat/models/centernet/modeling/roi_heads/fed_loss.py @@ -0,0 +1,32 @@ +import json + +import torch + + +def load_class_freq(path='datasets/lvis/lvis_v1_train_cat_info.json', + freq_weight=0.5): + cat_info = json.load(open(path, 'r')) + cat_info = torch.tensor( + [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])]) + freq_weight = cat_info.float()**freq_weight + return freq_weight + + +def get_fed_loss_inds(gt_classes, + num_sample_cats=50, + C=1203, + weight=None, + fed_cls_inds=-1): + appeared = torch.unique(gt_classes) # C' + prob = appeared.new_ones(C + 1).float() + prob[-1] = 0 + if len(appeared) < num_sample_cats: + if weight is not None: + prob[:C] = 
weight.float().clone() + prob[appeared] = 0 + if fed_cls_inds > 0: + prob[fed_cls_inds:] = 0 + more_appeared = torch.multinomial( + prob, num_sample_cats - len(appeared), replacement=False) + appeared = torch.cat([appeared, more_appeared]) + return appeared diff --git a/projects/videochat/models/grit_model.py b/projects/videochat/models/grit_model.py new file mode 100644 index 0000000000..851a3f3fba --- /dev/null +++ b/projects/videochat/models/grit_model.py @@ -0,0 +1,69 @@ +''' +Description: +Version: 1.0 +Author: ZhuYichen +Date: 2023-07-03 16:59:09 +LastEditors: ZhuYichen +LastEditTime: 2023-07-10 15:48:04 +''' +import os + +from detectron2.data.detection_utils import read_image + +from projects.videochat.models.grit_src.image_dense_captions import ( + dense_pred_to_caption, dense_pred_to_caption_only_name, image_caption_api, + init_demo) + + +class DenseCaptioning(): + + def __init__(self, device): + self.device = device + self.demo = None + + def initialize_model(self): + self.demo = init_demo(self.device) + + def image_dense_caption_debug(self, image_src): + dense_caption = """ + 1. the broccoli is green, [0, 0, 333, 325]; + 2. a piece of broccoli, [0, 147, 143, 324]; + 3. silver fork on plate, [4, 547, 252, 612]; + """ + return dense_caption + + def image_dense_caption(self, image_src): + dense_caption = image_caption_api(image_src, self.device) + print('\033[1;35m' + '*' * 100 + '\033[0m') + print('Step2, Dense Caption:\n') + print(dense_caption) + print('\033[1;35m' + '*' * 100 + '\033[0m') + return dense_caption + + def run_caption_api(self, image_src): + img = read_image(image_src, format='BGR') + print(img.shape) + predictions, visualized_output = self.demo.run_on_image(img) + new_caption = dense_pred_to_caption_only_name(predictions) + return new_caption + + def run_caption_tensor(self, + img, + video_path=None, + index=0, + images_path=None): + # img = read_image(image_src, format="BGR") + # print(img.shape) + predictions, visualized_output = self.demo.run_on_image(img) + # print('predictions:',predictions) + if video_path and images_path: + folder = os.path.basename(os.path.dirname(video_path)) + folder_path = os.path.join(images_path, folder) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + file_name = 'output_image_{}.jpg'.format(index) + visualized_output.save(os.path.join(folder_path, file_name)) + new_caption_only_name = dense_pred_to_caption_only_name(predictions) + new_caption = dense_pred_to_caption(predictions) + # print('new_caption:',new_caption) + return new_caption_only_name, new_caption diff --git a/projects/videochat/models/grit_src/__init__.py b/projects/videochat/models/grit_src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/projects/videochat/models/grit_src/configs/Base.yaml b/projects/videochat/models/grit_src/configs/Base.yaml new file mode 100644 index 0000000000..620bee8eda --- /dev/null +++ b/projects/videochat/models/grit_src/configs/Base.yaml @@ -0,0 +1,77 @@ +MODEL: + META_ARCHITECTURE: "GRiT" + MASK_ON: True + PROPOSAL_GENERATOR: + NAME: "CenterNet" + FPN: + IN_FEATURES: ["layer3", "layer4", "layer5"] + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + ROI_HEADS: + NAME: GRiTROIHeadsAndTextDecoder + IN_FEATURES: ["p3", "p4", "p5"] + IOU_THRESHOLDS: [0.6] + NUM_CLASSES: 1 + SCORE_THRESH_TEST: 0.02 + NMS_THRESH_TEST: 0.5 + OBJECT_FEAT_POOLER_RES: 14 + ROI_BOX_CASCADE_HEAD: + IOUS: [0.6, 0.7, 0.8] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 
+ CLS_AGNOSTIC_BBOX_REG: True + MULT_PROPOSAL_SCORE: True + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 + CLS_AGNOSTIC_MASK: True + CENTERNET: + NUM_CLASSES: 1 + REG_WEIGHT: 1. + NOT_NORM_REG: True + ONLY_PROPOSAL: True + WITH_AGN_HM: True + INFERENCE_TH: 0.0001 + PRE_NMS_TOPK_TRAIN: 4000 + POST_NMS_TOPK_TRAIN: 2000 + PRE_NMS_TOPK_TEST: 1000 + POST_NMS_TOPK_TEST: 256 + NMS_TH_TRAIN: 0.9 + NMS_TH_TEST: 0.9 + POS_WEIGHT: 0.5 + NEG_WEIGHT: 0.5 + IGNORE_HIGH_FP: 0.85 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1] + DATASET_INPUT_SIZE: [1024] + DATASET_INPUT_SCALE: [[0.1, 2.0]] + FILTER_EMPTY_ANNOTATIONS: False + NUM_WORKERS: 8 +TEST: + DETECTIONS_PER_IMAGE: 256 +SOLVER: + LR_SCHEDULER_NAME: "WarmupCosineLR" + CHECKPOINT_PERIOD: 10000 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 + USE_CUSTOM_SOLVER: True + OPTIMIZER: "ADAMW" + MAX_ITER: 180000 + IMS_PER_BATCH: 64 + BASE_LR: 0.00008 + VIT_LAYER_DECAY: True + CLIP_GRADIENTS: + ENABLED: True +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 +USE_ACT_CHECKPOINT: True +VERSION: 2 diff --git a/projects/videochat/models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml b/projects/videochat/models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml new file mode 100644 index 0000000000..03ba335674 --- /dev/null +++ b/projects/videochat/models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml @@ -0,0 +1,23 @@ +_BASE_: "Base.yaml" +MODEL: + TRAIN_TASK: ["ObjectDet", "DenseCap"] + TEST_TASK: "DenseCap" # DenseCap or ObjectDet: Choose one for testing + MASK_ON: True + ROI_HEADS: + SOFT_NMS_ENABLED: False + BEAM_SIZE: 1 + WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth" + BACKBONE: + NAME: build_vit_fpn_backbone + VIT_LAYERS: 12 +SOLVER: + VIT_LAYER_DECAY_RATE: 0.7 +DATASETS: + TRAIN: ("GRiT_coco2017_train", "vg_train") + TEST: ("coco_2017_test-dev",) +DATALOADER: + DATASET_RATIO: [1, 1] + DATASET_BS: 2 + DATASET_INPUT_SIZE: [1024, 1024] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.1, 2.0]] +OUTPUT_DIR: "./output/GRiT_B_DenseCap_ObjectDet" diff --git a/projects/videochat/models/grit_src/grit/__init__.py b/projects/videochat/models/grit_src/grit/__init__.py new file mode 100644 index 0000000000..75e4647ff8 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/__init__.py @@ -0,0 +1,9 @@ +from .data.datasets import grit_coco, object365, vg +from .modeling.backbone import vit +from .modeling.meta_arch import grit +from .modeling.roi_heads import grit_roi_heads + + +def main(): + # 这里是你的代码逻辑 + print(grit_coco, object365, vg, vit, grit, grit_roi_heads) diff --git a/projects/videochat/models/grit_src/grit/config.py b/projects/videochat/models/grit_src/grit/config.py new file mode 100644 index 0000000000..62c5c61dec --- /dev/null +++ b/projects/videochat/models/grit_src/grit/config.py @@ -0,0 +1,50 @@ +from detectron2.config import CfgNode as CN + + +def add_grit_config(cfg): + _C = cfg + + _C.MODEL.BEAM_SIZE = 1 + _C.MODEL.TRAIN_TASK = ['ObjectDet', 'DenseCap'] + _C.MODEL.TEST_TASK = 'DenseCap' + + _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use + _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False + + _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0 + _C.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES = 14 + _C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False + + # Backbones + _C.MODEL.VIT_LAYERS = 12 + + # Text Decoder + _C.TEXT_DECODER = CN() + _C.TEXT_DECODER.VOCAB_SIZE = 30522 + _C.TEXT_DECODER.HIDDEN_SIZE 
= 768 + _C.TEXT_DECODER.NUM_LAYERS = 6 + _C.TEXT_DECODER.ATTENTION_HEADS = 12 + _C.TEXT_DECODER.FEEDFORWARD_SIZE = 768 * 4 + + # Multi-dataset dataloader + _C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio + _C.DATALOADER.DATASET_BS = 1 + _C.DATALOADER.DATASET_INPUT_SIZE = [1024, 1024] + _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.1, 2.0)] + _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (640, 800)] + _C.DATALOADER.DATASET_MAX_SIZES = [1333, 1333] + + _C.SOLVER.USE_CUSTOM_SOLVER = True + _C.SOLVER.OPTIMIZER = 'ADAMW' + _C.SOLVER.VIT_LAYER_DECAY = True + _C.SOLVER.VIT_LAYER_DECAY_RATE = 0.7 + + _C.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop' + _C.INPUT.TRAIN_SIZE = 1024 + _C.INPUT.TEST_SIZE = 1024 + _C.INPUT.SCALE_RANGE = (0.1, 2.) + # 'default' for fixed short / long edge + _C.INPUT.TEST_INPUT_TYPE = 'default' + + _C.FIND_UNUSED_PARAM = True + _C.USE_ACT_CHECKPOINT = True diff --git a/projects/videochat/models/grit_src/grit/custom_solver.py b/projects/videochat/models/grit_src/grit/custom_solver.py new file mode 100644 index 0000000000..34f02b78bc --- /dev/null +++ b/projects/videochat/models/grit_src/grit/custom_solver.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob +# /main/detic/custom_solver.py +import itertools +from typing import Any, Dict, List, Set + +import torch +from detectron2.config import CfgNode +from detectron2.solver.build import maybe_add_gradient_clipping + + +def build_custom_optimizer(cfg: CfgNode, + model: torch.nn.Module) -> torch.optim.Optimizer: + params: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + optimizer_type = cfg.SOLVER.OPTIMIZER + + for key, value in model.named_parameters(recurse=True): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + + if cfg.SOLVER.VIT_LAYER_DECAY: + lr = lr * get_vit_lr_decay_rate( + key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS) + + param = {'params': [value], 'lr': lr} + if optimizer_type != 'ADAMW': + param['weight_decay'] = weight_decay + params += [param] + + def maybe_add_full_model_gradient_clipping( + optim): # optim: the optimizer class + # detectron2 doesn't have full model gradient clipping now + clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE + enable = ( + cfg.SOLVER.CLIP_GRADIENTS.ENABLED + and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == 'full_model' + and clip_norm_val > 0.0) + + class FullModelGradientClippingOptimizer(optim): + + def step(self, closure=None): + all_params = itertools.chain( + *[x['params'] for x in self.param_groups]) + torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) + super().step(closure=closure) + + return FullModelGradientClippingOptimizer if enable else optim + + if optimizer_type == 'SGD': + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( + params, + cfg.SOLVER.BASE_LR, + momentum=cfg.SOLVER.MOMENTUM, + nesterov=cfg.SOLVER.NESTEROV) + elif optimizer_type == 'ADAMW': + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( + params, cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) + else: + raise NotImplementedError(f'no optimizer type {optimizer_type}') + if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == 'full_model': + optimizer = maybe_add_gradient_clipping(cfg, optimizer) + return optimizer + + +def get_vit_lr_decay_rate(name, 
lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith('backbone'): + if '.pos_embed' in name or '.patch_embed' in name: + layer_id = 0 + elif '.blocks.' in name and '.residual.' not in name: + layer_id = int(name[name.find('.blocks.'):].split('.')[2]) + 1 + + return lr_decay_rate**(num_layers + 1 - layer_id) diff --git a/projects/videochat/models/grit_src/grit/data/custom_build_augmentation.py b/projects/videochat/models/grit_src/grit/data/custom_build_augmentation.py new file mode 100644 index 0000000000..e03d4c8539 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/custom_build_augmentation.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.data import transforms as T + +from .transforms.custom_augmentation_impl import EfficientDetResizeCrop + + +def build_custom_augmentation(cfg, + is_train, + scale=None, + size=None, + min_size=None, + max_size=None): + """Create a list of default :class:`Augmentation` from config. Now it + includes resizing and flipping. + + Returns: + list[Augmentation] + """ + if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge': + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN \ + if min_size is None else min_size + max_size = cfg.INPUT.MAX_SIZE_TRAIN \ + if max_size is None else max_size + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = 'choice' + augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop': + if is_train: + scale = cfg.INPUT.SCALE_RANGE if scale is None else scale + size = cfg.INPUT.TRAIN_SIZE if size is None else size + else: + scale = (1, 1) + size = cfg.INPUT.TEST_SIZE + augmentation = [EfficientDetResizeCrop(size, scale)] + else: + assert 0, cfg.INPUT.CUSTOM_AUG + + if is_train: + augmentation.append(T.RandomFlip()) + return augmentation + + +build_custom_transform_gen = build_custom_augmentation +""" +Alias for backward-compatibility. +""" diff --git a/projects/videochat/models/grit_src/grit/data/custom_dataset_dataloader.py b/projects/videochat/models/grit_src/grit/data/custom_dataset_dataloader.py new file mode 100644 index 0000000000..06306e655e --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/custom_dataset_dataloader.py @@ -0,0 +1,274 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
Modified by Jialian Wu +# from https://github.com/facebookresearch/Detic/blob/main/detic/data +# /custom_dataset_dataloader.py +import itertools +import operator +from typing import Optional + +import torch +import torch.utils.data +from detectron2.config import configurable +from detectron2.data.build import (check_metadata_consistency, + filter_images_with_few_keypoints, + filter_images_with_only_crowd_annotations, + get_detection_dataset_dicts, + print_instances_class_histogram, + worker_init_reset_seed) +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog +from detectron2.data.common import DatasetFromList, MapDataset +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.data.samplers import TrainingSampler +from detectron2.utils import comm +from detectron2.utils.comm import get_world_size +from torch.utils.data.sampler import Sampler + + +def _custom_train_loader_from_config(cfg, + mapper=None, + *, + dataset=None, + sampler=None): + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + if 'MultiDataset' in sampler_name: + dataset_dicts = get_detection_dataset_dicts_with_source( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN + if cfg.MODEL.LOAD_PROPOSALS else None, + ) + else: + dataset_dicts = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN + if cfg.MODEL.LOAD_PROPOSALS else None, + ) + + if mapper is None: + mapper = DatasetMapper(cfg, True) + + if sampler is not None: + pass + elif sampler_name == 'TrainingSampler': + sampler = TrainingSampler(len(dataset)) + elif sampler_name == 'MultiDatasetSampler': + sampler = MultiDatasetSampler( + dataset_dicts, + dataset_ratio=cfg.DATALOADER.DATASET_RATIO, + ) + else: + raise ValueError('Unknown training sampler: {}'.format(sampler_name)) + + return { + 'dataset': dataset_dicts, + 'sampler': sampler, + 'mapper': mapper, + 'total_batch_size': cfg.SOLVER.IMS_PER_BATCH, + 'num_workers': cfg.DATALOADER.NUM_WORKERS, + 'dataset_bs': cfg.DATALOADER.DATASET_BS, + 'num_datasets': len(cfg.DATASETS.TRAIN) + } + + +@configurable(from_config=_custom_train_loader_from_config) +def build_custom_train_loader(dataset, + *, + mapper, + sampler, + total_batch_size=16, + num_workers=0, + num_datasets=1, + dataset_bs=1): + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MapDataset(dataset, mapper) + if sampler is None: + sampler = TrainingSampler(len(dataset)) + assert isinstance(sampler, torch.utils.data.sampler.Sampler) + + return build_dataset_batch_data_loader( + dataset_bs, + dataset, + sampler, + total_batch_size, + num_datasets=num_datasets, + num_workers=num_workers, + ) + + +def build_dataset_batch_data_loader(dataset_bs, + dataset, + sampler, + total_batch_size, + num_datasets, + num_workers=0): + world_size = get_world_size() + assert ( + total_batch_size > 0 and total_batch_size % world_size == 0 + ), 'Total batch size ({}) must be divisible by the number of gpus ({}).' 
\ + .format(total_batch_size, world_size) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + num_workers=num_workers, + batch_sampler=None, + collate_fn=operator.itemgetter( + 0), # don't batch, but yield individual elements + worker_init_fn=worker_init_reset_seed, + ) + + if num_datasets > 1: + return MultiDatasets(data_loader, dataset_bs, num_datasets) + else: + return SingleDataset(data_loader, dataset_bs) + + +def get_detection_dataset_dicts_with_source(dataset_names, + filter_empty=True, + min_keypoints=0, + proposal_files=None): + assert len(dataset_names) + dataset_dicts = [ + DatasetCatalog.get(dataset_name) for dataset_name in dataset_names + ] + for dataset_name, dicts in zip(dataset_names, dataset_dicts): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + + for source_id, (dataset_name, dicts) in \ + enumerate(zip(dataset_names, dataset_dicts)): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + for d in dicts: + d['dataset_source'] = source_id + + if 'annotations' in dicts[0]: + try: + class_names = MetadataCatalog.get(dataset_name).thing_classes + check_metadata_consistency('thing_classes', dataset_name) + print_instances_class_histogram(dicts, class_names) + except AttributeError: + pass + + assert proposal_files is None + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = 'annotations' in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations( + dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints( + dataset_dicts, min_keypoints) + + return dataset_dicts + + +class MultiDatasetSampler(Sampler): + + def __init__( + self, + dataset_dicts, + dataset_ratio, + seed: Optional[int] = None, + ): + sizes = [0 for _ in range(len(dataset_ratio))] + for d in dataset_dicts: + sizes[d['dataset_source']] += 1 + print('dataset sizes', sizes) + self.sizes = sizes + assert len(dataset_ratio) == len(sizes), \ + 'length of dataset ' \ + 'ratio {} should be equal to number if dataset {}'.format( + len(dataset_ratio), len(sizes) + ) + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + self.dataset_ids = torch.tensor( + [d['dataset_source'] for d in dataset_dicts], dtype=torch.long) + self.dataset_ratio = dataset_ratio + + dataset_weight = [ + torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) + for i, (r, s) in enumerate(zip(dataset_ratio, sizes)) + ] + dataset_weight = torch.cat(dataset_weight) + + self.weights = dataset_weight + self.sample_epoch_size = len(self.weights) + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, + self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + if len(self.dataset_ratio) > 1: + # multiple datasets + ids = torch.multinomial( + self.weights, + self.sample_epoch_size, + generator=g, + replacement=True) + # nums = [(self.dataset_ids[ids] == i).sum().int().item() + # for i in range(len(self.sizes))] + yield from ids + else: + # single dataset + yield from torch.randperm(self.sizes[0], generator=g).tolist() + + +class SingleDataset(torch.utils.data.IterableDataset): + + def __init__(self, dataset, batch_sizes): + self.dataset = dataset + self.batch_sizes = batch_sizes + self._buckets = [[] for _ in range(2)] + + def 
__iter__(self): + for d in self.dataset: + w, h = d['width'], d['height'] + aspect_ratio_bucket_id = 0 if w > h else 1 + bucket_id = aspect_ratio_bucket_id + bucket = self._buckets[bucket_id] + bucket.append(d) + if len(bucket) == self.batch_sizes: + yield bucket[:] + del bucket[:] + + +class MultiDatasets(torch.utils.data.IterableDataset): + + def __init__(self, dataset, batch_sizes, num_datasets): + self.dataset = dataset + self.batch_sizes = batch_sizes + self._buckets = [[] for _ in range(2 * num_datasets)] + self.iter_idx = 0 + self.num_datasets = num_datasets + + def __iter__(self): + for d in self.dataset: + w, h = d['width'], d['height'] + aspect_ratio_bucket_id = 0 if w > h else 1 + bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id + bucket = self._buckets[bucket_id] + if len(bucket) < self.batch_sizes: + bucket.append(d) + selected_dataset = self.iter_idx % self.num_datasets + if len(bucket) == self.batch_sizes \ + and selected_dataset == d['dataset_source']: + self.iter_idx += 1 + yield bucket[:] + del bucket[:] diff --git a/projects/videochat/models/grit_src/grit/data/custom_dataset_mapper.py b/projects/videochat/models/grit_src/grit/data/custom_dataset_mapper.py new file mode 100644 index 0000000000..decce6f42b --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/custom_dataset_mapper.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob +# /main/detic/data/custom_dataset_mapper.py +import copy +import logging +from itertools import compress + +import numpy as np +import torch +from detectron2.config import configurable +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.data.dataset_mapper import DatasetMapper + +from .custom_build_augmentation import build_custom_augmentation + +__all__ = ['CustomDatasetMapper', 'ObjDescription'] +logger = logging.getLogger(__name__) + + +class CustomDatasetMapper(DatasetMapper): + + @configurable + def __init__(self, is_train: bool, dataset_augs=[], **kwargs): + if is_train: + self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs] + super().__init__(is_train, **kwargs) + + @classmethod + def from_config(cls, cfg, is_train: bool = True): + ret = super().from_config(cfg, is_train) + if is_train: + if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop': + dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE + dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE + ret['dataset_augs'] = [ + build_custom_augmentation(cfg, True, scale, size) + for scale, size in zip(dataset_scales, dataset_sizes) + ] + else: + assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge' + min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES + max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES + ret['dataset_augs'] = [ + build_custom_augmentation( + cfg, True, min_size=mi, max_size=ma) + for mi, ma in zip(min_sizes, max_sizes) + ] + else: + ret['dataset_augs'] = [] + + return ret + + def __call__(self, dataset_dict): + dataset_dict_out = self.prepare_data(dataset_dict) + + # When augmented image is too small, do re-augmentation + retry = 0 + while (dataset_dict_out['image'].shape[1] < 32 + or dataset_dict_out['image'].shape[2] < 32): + retry += 1 + if retry == 100: + logger.info( + 'Retry 100 times for augmentation. Make sure the image ' + 'size is not too small. 
') + logger.info('Find image information below') + logger.info(dataset_dict) + dataset_dict_out = self.prepare_data(dataset_dict) + + return dataset_dict_out + + def prepare_data(self, dataset_dict_in): + dataset_dict = copy.deepcopy(dataset_dict_in) + if 'file_name' in dataset_dict: + ori_image = utils.read_image( + dataset_dict['file_name'], format=self.image_format) + else: + ori_image, _, _ = self.tar_dataset[dataset_dict['tar_index']] + ori_image = utils._apply_exif_orientation(ori_image) + ori_image = utils.convert_PIL_to_numpy(ori_image, + self.image_format) + utils.check_image_size(dataset_dict, ori_image) + + aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=None) + if self.is_train: + transforms = \ + self.dataset_augs[dataset_dict['dataset_source']](aug_input) + else: + transforms = self.augmentations(aug_input) + image, _ = aug_input.image, aug_input.sem_seg + + image_shape = image.shape[:2] + dataset_dict['image'] = torch.as_tensor( + np.ascontiguousarray(image.transpose(2, 0, 1))) + + if not self.is_train: + # USER: Modify this if you want to keep them for some reason. + dataset_dict.pop('annotations', None) + return dataset_dict + + if 'annotations' in dataset_dict: + if len(dataset_dict['annotations']) > 0: + object_descriptions = [ + an['object_description'] + for an in dataset_dict['annotations'] + ] + else: + object_descriptions = [] + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict['annotations']: + if not self.use_instance_mask: + anno.pop('segmentation', None) + if not self.use_keypoint: + anno.pop('keypoints', None) + + all_annos = [(utils.transform_instance_annotations( + obj, + transforms, + image_shape, + keypoint_hflip_indices=self.keypoint_hflip_indices, + ), obj.get('iscrowd', 0)) + for obj in dataset_dict.pop('annotations')] + annos = [ann[0] for ann in all_annos if ann[1] == 0] + instances = utils.annotations_to_instances( + annos, image_shape, mask_format=self.instance_mask_format) + + instances.gt_object_descriptions = ObjDescription( + object_descriptions) + + del all_annos + if self.recompute_boxes: + instances.gt_boxes = instances.gt_masks.get_bounding_boxes() + dataset_dict['instances'] = utils.filter_empty_instances(instances) + + return dataset_dict + + +class ObjDescription: + + def __init__(self, object_descriptions): + self.data = object_descriptions + + def __getitem__(self, item): + assert type(item) == torch.Tensor + assert item.dim() == 1 + if len(item) > 0: + assert item.dtype == torch.int64 or item.dtype == torch.bool + if item.dtype == torch.int64: + return ObjDescription([self.data[x.item()] for x in item]) + elif item.dtype == torch.bool: + return ObjDescription(list(compress(self.data, item))) + + return ObjDescription(list(compress(self.data, item))) + + def __len__(self): + return len(self.data) + + def __repr__(self): + return 'ObjDescription({})'.format(self.data) diff --git a/projects/videochat/models/grit_src/grit/data/datasets/grit_coco.py b/projects/videochat/models/grit_src/grit/data/datasets/grit_coco.py new file mode 100644 index 0000000000..a34555af86 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/datasets/grit_coco.py @@ -0,0 +1,115 @@ +import logging +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager +from fvcore.common.timer import Timer +from lvis import LVIS + +logger = logging.getLogger(__name__) + +__all__ = ['load_GRiTcoco_json', 
'register_GRiTcoco_instances'] + + +def register_GRiTcoco_instances(name, metadata, json_file, image_root): + """""" + DatasetCatalog.register( + name, lambda: load_GRiTcoco_json(json_file, image_root, name)) + MetadataCatalog.get(name).set( + json_file=json_file, + image_root=image_root, + evaluator_type='coco', + **metadata) + + +def get_GRiTcoco_meta(): + categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}] + categories = sorted(categories, key=lambda x: x['id']) + thing_classes = [k['name'] for k in categories] + meta = {'thing_classes': thing_classes} + return meta + + +def load_GRiTcoco_json(json_file, image_root, dataset_name=None): + """Load COCO class name text for object description for GRiT.""" + + json_file = PathManager.get_local_path(json_file) + + timer = Timer() + lvis_api = LVIS(json_file) + if timer.seconds() > 1: + logger.info('Loading {} takes {:.2f} seconds.'.format( + json_file, timer.seconds())) + + class_names = {} + sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id']) + for x in sort_cat: + class_names[x['id']] = x['name'] + + img_ids = sorted(lvis_api.imgs.keys()) + imgs = lvis_api.load_imgs(img_ids) + anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] + + ann_ids = [ann['id'] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), \ + "Annotation ids in '{}' are not unique".format(json_file) + + imgs_anns = list(zip(imgs, anns)) + logger.info('Loaded {} images in the LVIS v1 format from {}'.format( + len(imgs_anns), json_file)) + + dataset_dicts = [] + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + if 'file_name' in img_dict: + file_name = img_dict['file_name'] + record['file_name'] = os.path.join(image_root, file_name) + + record['height'] = int(img_dict['height']) + record['width'] = int(img_dict['width']) + image_id = record['image_id'] = img_dict['id'] + + objs = [] + for anno in anno_dict_list: + assert anno['image_id'] == image_id + if anno.get('iscrowd', 0) > 0: + continue + obj = {'bbox': anno['bbox'], 'bbox_mode': BoxMode.XYWH_ABS} + obj['category_id'] = 0 + obj['object_description'] = class_names[anno['category_id']] + if 'segmentation' in anno: + segm = anno['segmentation'] + valid_segm = [ + poly for poly in segm + if len(poly) % 2 == 0 and len(poly) >= 6 + ] + if not len(segm) == len(valid_segm): + print('Annotation contains an invalid polygon with < 3 ' + 'points ') + assert len(segm) > 0 + obj['segmentation'] = segm + objs.append(obj) + record['annotations'] = objs + if len(record['annotations']) == 0: + continue + record['task'] = 'ObjectDet' + dataset_dicts.append(record) + + return dataset_dicts + + +_CUSTOM_SPLITS_LVIS = { + 'GRiT_coco2017_train': + ('coco/train2017/', 'coco/annotations/instances_train2017.json'), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items(): + register_GRiTcoco_instances( + key, + get_GRiTcoco_meta(), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) diff --git a/projects/videochat/models/grit_src/grit/data/datasets/object365.py b/projects/videochat/models/grit_src/grit/data/datasets/object365.py new file mode 100644 index 0000000000..9eb8543a14 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/datasets/object365.py @@ -0,0 +1,112 @@ +import logging +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager 
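+# Note: as in grit_coco.py above, Objects365 annotations are read through the
+# LVIS API; every box is mapped to the single 'object' category and the
+# human-readable class name is kept in 'object_description', which GRiT later
+# uses as the text target for the description decoder.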
+from fvcore.common.timer import Timer +from lvis import LVIS + +logger = logging.getLogger(__name__) + +__all__ = ['load_o365_json', 'register_o365_instances'] + + +def register_o365_instances(name, metadata, json_file, image_root): + DatasetCatalog.register( + name, lambda: load_o365_json(json_file, image_root, name)) + MetadataCatalog.get(name).set( + json_file=json_file, + image_root=image_root, + evaluator_type='lvis', + **metadata) + + +def get_o365_meta(): + categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}] + o365_categories = sorted(categories, key=lambda x: x['id']) + thing_classes = [k['name'] for k in o365_categories] + meta = {'thing_classes': thing_classes} + return meta + + +def load_o365_json(json_file, image_root, dataset_name=None): + """Load Object365 class name text for object description for GRiT.""" + + json_file = PathManager.get_local_path(json_file) + + timer = Timer() + lvis_api = LVIS(json_file) + if timer.seconds() > 1: + logger.info('Loading {} takes {:.2f} seconds.'.format( + json_file, timer.seconds())) + + class_names = {} + sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id']) + for x in sort_cat: + if '/' in x['name']: + text = '' + for xx in x['name'].split('/'): + text += xx + text += ' ' + text = text[:-1] + else: + text = x['name'] + class_names[x['id']] = text + + img_ids = sorted(lvis_api.imgs.keys()) + imgs = lvis_api.load_imgs(img_ids) + anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] + + ann_ids = [ann['id'] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), \ + "Annotation ids in '{}' are not unique".format(json_file) + + imgs_anns = list(zip(imgs, anns)) + logger.info('Loaded {} images in the LVIS v1 format from {}'.format( + len(imgs_anns), json_file)) + + dataset_dicts = [] + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + if 'file_name' in img_dict: + file_name = img_dict['file_name'] + record['file_name'] = os.path.join(image_root, file_name) + + record['height'] = int(img_dict['height']) + record['width'] = int(img_dict['width']) + image_id = record['image_id'] = img_dict['id'] + + objs = [] + for anno in anno_dict_list: + assert anno['image_id'] == image_id + if anno.get('iscrowd', 0) > 0: + continue + obj = {'bbox': anno['bbox'], 'bbox_mode': BoxMode.XYWH_ABS} + obj['category_id'] = 0 + obj['object_description'] = class_names[anno['category_id']] + + objs.append(obj) + record['annotations'] = objs + if len(record['annotations']) == 0: + continue + record['task'] = 'ObjectDet' + dataset_dicts.append(record) + + return dataset_dicts + + +_CUSTOM_SPLITS_LVIS = { + 'object365_train': + ('object365/images/train/', 'object365/annotations/train_v1.json'), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items(): + register_o365_instances( + key, + get_o365_meta(), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) diff --git a/projects/videochat/models/grit_src/grit/data/datasets/vg.py b/projects/videochat/models/grit_src/grit/data/datasets/vg.py new file mode 100644 index 0000000000..2757c8663b --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/datasets/vg.py @@ -0,0 +1,99 @@ +import logging +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager +from fvcore.common.timer import Timer +from lvis import LVIS + +logger = 
logging.getLogger(__name__) + +__all__ = ['load_vg_json', 'register_vg_instances'] + + +def register_vg_instances(name, metadata, json_file, image_root): + """""" + DatasetCatalog.register(name, + lambda: load_vg_json(json_file, image_root, name)) + MetadataCatalog.get(name).set( + json_file=json_file, + image_root=image_root, + evaluator_type='vg', + **metadata) + + +def get_vg_meta(): + categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}] + vg_categories = sorted(categories, key=lambda x: x['id']) + thing_classes = [k['name'] for k in vg_categories] + meta = {'thing_classes': thing_classes} + return meta + + +def load_vg_json(json_file, image_root, dataset_name=None): + + json_file = PathManager.get_local_path(json_file) + + timer = Timer() + lvis_api = LVIS(json_file) + if timer.seconds() > 1: + logger.info('Loading {} takes {:.2f} seconds.'.format( + json_file, timer.seconds())) + + img_ids = sorted(lvis_api.imgs.keys()) + imgs = lvis_api.load_imgs(img_ids) + anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] + + ann_ids = [ann['id'] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), \ + "Annotation ids in '{}' are not unique".format(json_file) + + imgs_anns = list(zip(imgs, anns)) + logger.info('Loaded {} images in the LVIS v1 format from {}'.format( + len(imgs_anns), json_file)) + + dataset_dicts = [] + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + if 'file_name' in img_dict: + file_name = img_dict['file_name'] + record['file_name'] = os.path.join(image_root, file_name) + + record['height'] = int(img_dict['height']) + record['width'] = int(img_dict['width']) + image_id = record['image_id'] = img_dict['id'] + + objs = [] + for anno in anno_dict_list: + assert anno['image_id'] == image_id + if anno.get('iscrowd', 0) > 0: + continue + obj = {'bbox': anno['bbox'], 'bbox_mode': BoxMode.XYWH_ABS} + obj['category_id'] = 0 + obj['object_description'] = anno['caption'] + + objs.append(obj) + record['annotations'] = objs + if len(record['annotations']) == 0: + continue + record['task'] = 'DenseCap' + dataset_dicts.append(record) + + return dataset_dicts + + +_CUSTOM_SPLITS_LVIS = { + 'vg_train': ('vg/images', 'vg/annotations/train.json'), + 'vg_test': ('vg/images', 'vg/annotations/test.json'), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items(): + register_vg_instances( + key, + get_vg_meta(), + os.path.join('datasets', json_file) + if '://' not in json_file else json_file, + os.path.join('datasets', image_root), + ) diff --git a/projects/videochat/models/grit_src/grit/data/transforms/custom_augmentation_impl.py b/projects/videochat/models/grit_src/grit/data/transforms/custom_augmentation_impl.py new file mode 100644 index 0000000000..edce8ec6f6 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/transforms/custom_augmentation_impl.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved Part +# of the code is from https://github.com/rwightman/efficientdet-pytorch/blob +# /master/effdet/data/transforms.py Modified by Xingyi Zhou The original +# code is under Apache-2.0 License +import numpy as np +from detectron2.data.transforms.augmentation import Augmentation +from PIL import Image + +from .custom_transform import EfficientDetResizeCropTransform + +__all__ = [ + 'EfficientDetResizeCrop', +] + + +class EfficientDetResizeCrop(Augmentation): + """Scale the shorter edge to the given size, with a limit of `max_size` on + the longer edge. 
+ + If `max_size` is reached, then downscale so that the longer edge does not + exceed max_size. + """ + + def __init__(self, size, scale, interp=Image.BILINEAR): + """""" + super().__init__() + self.target_size = (size, size) + self.scale = scale + self.interp = interp + + def get_transform(self, img): + # Select a random scale factor. + scale_factor = np.random.uniform(*self.scale) + scaled_target_height = scale_factor * self.target_size[0] + scaled_target_width = scale_factor * self.target_size[1] + # Recompute the accurate scale_factor using rounded scaled image size. + width, height = img.shape[1], img.shape[0] + img_scale_y = scaled_target_height / height + img_scale_x = scaled_target_width / width + img_scale = min(img_scale_y, img_scale_x) + + # Select non-zero random offset (x, y) if scaled image is larger + # than target size + scaled_h = int(height * img_scale) + scaled_w = int(width * img_scale) + offset_y = scaled_h - self.target_size[0] + offset_x = scaled_w - self.target_size[1] + offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1)) + offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1)) + return EfficientDetResizeCropTransform(scaled_h, scaled_w, offset_y, + offset_x, img_scale, + self.target_size, self.interp) diff --git a/projects/videochat/models/grit_src/grit/data/transforms/custom_transform.py b/projects/videochat/models/grit_src/grit/data/transforms/custom_transform.py new file mode 100644 index 0000000000..c9840cac6b --- /dev/null +++ b/projects/videochat/models/grit_src/grit/data/transforms/custom_transform.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved Part +# of the code is from https://github.com/rwightman/efficientdet-pytorch/blob +# /master/effdet/data/transforms.py Modified by Xingyi Zhou The original +# code is under Apache-2.0 License +import numpy as np +import torch +import torch.nn.functional as F +from fvcore.transforms.transform import Transform +from PIL import Image + +try: + import cv2 # noqa +except ImportError: + # OpenCV is an optional dependency at the moment + pass + +__all__ = [ + 'EfficientDetResizeCropTransform', +] + + +class EfficientDetResizeCropTransform(Transform): + """""" + + def __init__(self, + scaled_h, + scaled_w, + offset_y, + offset_x, + img_scale, + target_size, + interp=None): + """ + Args: + h, w (int): original image size + new_h, new_w (int): new image size + interp: PIL interpolation methods, defaults to bilinear. 
+ """ + super().__init__() + if interp is None: + interp = Image.BILINEAR + self._set_attributes(locals()) + + def apply_image(self, img, interp=None): + assert len(img.shape) <= 4 + + if img.dtype == np.uint8: + pil_image = Image.fromarray(img) + interp_method = interp if interp is not None else self.interp + pil_image = pil_image.resize((self.scaled_w, self.scaled_h), + interp_method) + ret = np.asarray(pil_image) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + if len(ret.shape) <= 3: + ret = ret[self.offset_y:lower, self.offset_x:right] + else: + ret = ret[..., self.offset_y:lower, self.offset_x:right, :] + else: + # PIL only supports uint8 + img = torch.from_numpy(img) + shape = list(img.shape) + shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] + img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw + _PIL_RESIZE_TO_INTERPOLATE_MODE = { + Image.BILINEAR: 'bilinear', + Image.BICUBIC: 'bicubic' + } + mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp] + img = F.interpolate( + img, (self.scaled_h, self.scaled_w), + mode=mode, + align_corners=False) + shape[:2] = (self.scaled_h, self.scaled_w) + ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + if len(ret.shape) <= 3: + ret = ret[self.offset_y:lower, self.offset_x:right] + else: + ret = ret[..., self.offset_y:lower, self.offset_x:right, :] + return ret + + def apply_coords(self, coords): + coords[:, 0] = coords[:, 0] * self.img_scale + coords[:, 1] = coords[:, 1] * self.img_scale + coords[:, 0] -= self.offset_x + coords[:, 1] -= self.offset_y + return coords + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=Image.NEAREST) + return segmentation + + def inverse(self): + raise NotImplementedError + + def inverse_apply_coords(self, coords): + coords[:, 0] += self.offset_x + coords[:, 1] += self.offset_y + coords[:, 0] = coords[:, 0] / self.img_scale + coords[:, 1] = coords[:, 1] / self.img_scale + return coords + + def inverse_apply_box(self, box: np.ndarray) -> np.ndarray: + """""" + idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten() + coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2) + coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2)) + minxy = coords.min(axis=1) + maxxy = coords.max(axis=1) + trans_boxes = np.concatenate((minxy, maxxy), axis=1) + return trans_boxes diff --git a/projects/videochat/models/grit_src/grit/evaluation/eval.py b/projects/videochat/models/grit_src/grit/evaluation/eval.py new file mode 100644 index 0000000000..7a42bc017a --- /dev/null +++ b/projects/videochat/models/grit_src/grit/evaluation/eval.py @@ -0,0 +1,160 @@ +import itertools +import json +import os + +from detectron2.evaluation.coco_evaluation import ( + COCOEvaluator, _evaluate_predictions_on_coco) +from detectron2.structures import BoxMode +from detectron2.utils.file_io import PathManager + + +class GRiTCOCOEvaluator(COCOEvaluator): + + def process(self, inputs, outputs): + for input, output in zip(inputs, outputs): + prediction = {'image_id': input['image_id']} + + if 'instances' in output: + instances = output['instances'].to(self._cpu_device) + prediction['instances'] = instances_to_coco_json( + instances, input['image_id']) + + if len(prediction) > 1: + self._predictions.append(prediction) + + def 
_eval_predictions(self, predictions, img_ids=None): + self._logger.info('Preparing results for COCO format ...') + coco_results = list( + itertools.chain(*[x['instances'] for x in predictions])) + tasks = self._tasks or self._tasks_from_predictions(coco_results) + + if self._output_dir: + file_path = os.path.join(self._output_dir, + 'coco_instances_results.json') + self._logger.info('Saving results to {}'.format(file_path)) + with PathManager.open(file_path, 'w') as f: + f.write(json.dumps(coco_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info('Annotations are not available for evaluation.') + return + + self._logger.info('Evaluating predictions with {} COCO API...'.format( + 'unofficial' if self._use_fast_impl else 'official')) + + coco_results = self.convert_classname_to_id(coco_results) + + for task in sorted(tasks): + assert task in {'bbox', 'segm', + 'keypoints'}, f'Got unknown task: {task}!' + coco_eval = ( + _evaluate_predictions_on_coco( + self._coco_api, + coco_results, + task, + kpt_oks_sigmas=self._kpt_oks_sigmas, + use_fast_impl=self._use_fast_impl, + img_ids=img_ids, + max_dets_per_image=self._max_dets_per_image, + ) if len(coco_results) > 0 else + None # cocoapi does not handle empty results very well + ) + + res = self._derive_coco_results( + coco_eval, + task, + class_names=self._metadata.get('thing_classes')) + self._results[task] = res + + def convert_classname_to_id(self, results): + outputs = [] + class_name_to_id = {} + categories = sorted( + self._coco_api.dataset['categories'], key=lambda x: x['id']) + + for cat in categories: + class_name_to_id[cat['name']] = cat['id'] + + for pred in results: + if pred['object_descriptions'] in class_name_to_id: + pred['category_id'] = class_name_to_id[ + pred['object_descriptions']] + del pred['object_descriptions'] + outputs.append(pred) + + return outputs + + +class GRiTVGEvaluator(COCOEvaluator): + + def process(self, inputs, outputs): + for input, output in zip(inputs, outputs): + assert input['image_id'] == int( + input['file_name'].split('/')[-1].split('.')[0]) + prediction = {'image_id': input['image_id']} + + if 'instances' in output: + instances = output['instances'].to(self._cpu_device) + prediction['instances'] = instances_to_coco_json( + instances, input['image_id'], output_logits=True) + h = input['height'] + w = input['width'] + scale = 720.0 / max(h, w) + scaled_inst = [] + for inst in prediction['instances']: + inst['bbox'][0] = inst['bbox'][0] * scale + inst['bbox'][1] = inst['bbox'][1] * scale + inst['bbox'][2] = inst['bbox'][2] * scale + inst['bbox'][3] = inst['bbox'][3] * scale + scaled_inst.append(inst) + if len(scaled_inst) > 0: + prediction['instances'] = scaled_inst + if len(prediction) > 1: + self._predictions.append(prediction) + + def _eval_predictions(self, predictions, img_ids=None): + """This is only for saving the results to json file.""" + self._logger.info('Preparing results for COCO format ...') + coco_results = list( + itertools.chain(*[x['instances'] for x in predictions])) + + if self._output_dir: + file_path = os.path.join(self._output_dir, + 'vg_instances_results.json') + self._logger.info('Saving results to {}'.format(file_path)) + with PathManager.open(file_path, 'w') as f: + f.write(json.dumps(coco_results)) + f.flush() + + +def instances_to_coco_json(instances, img_id, output_logits=False): + """Add object_descriptions and logit (if applicable) to detectron2's + instances_to_coco_json.""" + num_instance = len(instances) + if num_instance == 0: + return [] + + boxes 
= instances.pred_boxes.tensor.numpy() + boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + boxes = boxes.tolist() + scores = instances.scores.tolist() + classes = instances.pred_classes.tolist() + object_descriptions = instances.pred_object_descriptions.data + if output_logits: + logits = instances.logits.tolist() + + results = [] + for k in range(num_instance): + result = { + 'image_id': img_id, + 'category_id': classes[k], + 'bbox': boxes[k], + 'score': scores[k], + 'object_descriptions': object_descriptions[k], + } + if output_logits: + result['logit'] = logits[k] + + results.append(result) + return results diff --git a/projects/videochat/models/grit_src/grit/modeling/backbone/utils.py b/projects/videochat/models/grit_src/grit/modeling/backbone/utils.py new file mode 100644 index 0000000000..cea681fa2b --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/backbone/utils.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved This +# code is from https://github.com/facebookresearch/detectron2/blob/main +# /detectron2/modeling/backbone/utils.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + 'window_partition', + 'window_unpartition', + 'add_decomposed_rel_pos', + 'get_abs_pos', + 'PatchEmbed', +] + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: windows: windows after partition with [B * num_windows, + window_size, window_size, C]. (Hp, Wp): padded height and width before + partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """Window unpartition into original sequences and removing padding. Args: + x. + + (tensor): input tokens with [B * num_windows, window_size, window_size, + C]. window_size (int): window size. pad_hw (Tuple): padded height and + width (Hp, Wp). hw (Tuple): original height and width (H, W) before + padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size, k_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
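+ # The stored table covers 2 * input_size - 1 relative offsets, while the
+ # current query/key sizes need 2 * max(q_size, k_size) - 1 entries, so the
+ # (L, C) table is resized along L with linear interpolation (channels act
+ # as the batch dimension). For example, a table built for 14x14 windows
+ # has 27 rows, whereas a 16x16 query/key grid needs 31.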
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode='linear', + ) + rel_pos_resized = rel_pos_resized.reshape(-1, + max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - + k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) + rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) + + return attn + + +def get_abs_pos(abs_pos, has_cls_token, hw): + """Calculate absolute positional embeddings. If needed, resize embeddings + and remove cls_token dimension for the original embeddings. Args: abs_pos + (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. hw + (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h, w = hw + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode='bicubic', + align_corners=False, + ) + + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding.""" + + def __init__(self, + kernel_size=(16, 16), + stride=(16, 16), + padding=(0, 0), + in_chans=3, + embed_dim=768): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. 
+ """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + def forward(self, x): + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/projects/videochat/models/grit_src/grit/modeling/backbone/vit.py b/projects/videochat/models/grit_src/grit/modeling/backbone/vit.py new file mode 100644 index 0000000000..ad731a04d1 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/backbone/vit.py @@ -0,0 +1,608 @@ +# Modified by Jialian Wu from https://github.com/facebookresearch/detectron2 +# /blob/main/detectron2/modeling/backbone/vit.py +import logging +import math +from functools import partial + +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from detectron2.layers import CNNBlockBase, Conv2d, ShapeSpec, get_norm +from detectron2.modeling.backbone.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from timm.models.layers import DropPath, Mlp, trunc_normal_ + +from projects.videochat.models.centernet.modeling.backbone.fpn_p5 import \ + LastLevelP6P7_P5 +from .utils import (PatchEmbed, add_decomposed_rel_pos, get_abs_pos, + window_partition, window_unpartition) + +logger = logging.getLogger(__name__) + +__all__ = ['ViT'] + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: dim (int): Number of input channels. num_heads (int): Number + of attention heads. qkv_bias (bool: If True, add a learnable bias + to query, key, value. rel_pos (bool): If True, add relative + positional embeddings to the attention map. rel_pos_zero_init ( + bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the + relative positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * input_size[1] - 1, head_dim)) + + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, + -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, + self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, + -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class ResBottleneckBlock(CNNBlockBase): + """The standard bottleneck residual block without the last activation + layer. + + It contains 3 conv layers with kernels 1x1, 3x3, 1x1. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + bottleneck_channels, + norm='LN', + act_layer=nn.GELU, + ): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + act_layer (callable): activation for all conv layers. + """ + super().__init__(in_channels, out_channels, 1) + + self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False) + self.norm1 = get_norm(norm, bottleneck_channels) + self.act1 = act_layer() + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + 3, + padding=1, + bias=False, + ) + self.norm2 = get_norm(norm, bottleneck_channels) + self.act2 = act_layer() + + self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False) + self.norm3 = get_norm(norm, out_channels) + + for layer in [self.conv1, self.conv2, self.conv3]: + weight_init.c2_msra_fill(layer) + for layer in [self.norm1, self.norm2]: + layer.weight.data.fill_(1.0) + layer.bias.data.zero_() + # zero init last norm layer. + self.norm3.weight.data.zero_() + self.norm3.bias.data.zero_() + + def forward(self, x): + out = x + for layer in self.children(): + out = layer(out) + + out = x + out + return out + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual + propagation blocks.""" + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + use_residual_block=False, + input_size=None, + ): + """ + Args: dim (int): Number of input channels. num_heads (int): Number + of attention heads in each ViT block. mlp_ratio (float): Ratio of + mlp hidden dim to embedding dim. qkv_bias (bool): If True, + add a learnable bias to query, key, value. drop_path (float): + Stochastic depth rate. norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. use_rel_pos (bool): If + True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative + positional parameters. window_size (int): Window size for window + attention blocks. If it equals 0, then not use window attention. + use_residual_block (bool): If True, use a residual block after the + MLP block. input_size (int or None): Input resolution for + calculating the relative positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else + (window_size, window_size), + ) + + self.drop_path = DropPath( + drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer) + + self.window_size = window_size + + self.use_residual_block = use_residual_block + if use_residual_block: + # Use a residual block with bottleneck channel as dim // 2 + self.residual = ResBottleneckBlock( + in_channels=dim, + out_channels=dim, + bottleneck_channels=dim // 2, + norm='LN', + act_layer=act_layer, + ) + + def forward(self, x): + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + if self.use_residual_block: + x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + return x + + +class ViT(Backbone): + """This module implements Vision Transformer (ViT) backbone in + :paper:`vitdet`. + + "Exploring Plain Vision Transformer Backbones for Object Detection", + https://arxiv.org/abs/2203.16527 + """ + + def __init__( + self, + img_size=1024, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_abs_pos=True, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + window_block_indexes=(), + residual_block_indexes=(), + use_act_checkpoint=True, + pretrain_img_size=224, + pretrain_use_cls_token=True, + out_feature='last_feat', + ): + """ + Args: img_size (int): Input image size. patch_size (int): Patch + size. in_chans (int): Number of input image channels. embed_dim ( + int): Patch embedding dimension. depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, + value. drop_path_rate (float): Stochastic depth rate. norm_layer ( + nn.Module): Normalization layer. act_layer (nn.Module): Activation + layer. use_abs_pos (bool): If True, use absolute positional + embeddings. use_rel_pos (bool): If True, add relative positional + embeddings to the attention map. rel_pos_zero_init (bool): If True, + zero initialize relative positional parameters. window_size (int): + Window size for window attention blocks. window_block_indexes ( + list): Indexes for blocks using window attention. + residual_block_indexes (list): Indexes for blocks using conv + propagation. use_act_checkpoint (bool): If True, use activation + checkpointing. pretrain_img_size (int): input image size for + pretraining models. pretrain_use_cls_token (bool): If True, + pretrainig models use class token. out_feature (str): name of the + feature from the last block. 
+ """ + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + self.use_act_checkpoint = use_act_checkpoint + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image + # size. + num_patches = (pretrain_img_size // patch_size) * ( + pretrain_img_size // patch_size) + num_positions = (num_patches + + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter( + torch.zeros(1, num_positions, embed_dim)) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i in window_block_indexes else 0, + use_residual_block=i in residual_block_indexes, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self._out_feature_channels = {out_feature: embed_dim} + self._out_feature_strides = {out_feature: patch_size} + self._out_features = [out_feature] + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token, + (x.shape[1], x.shape[2])) + + for blk in self.blocks: + if self.use_act_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + return x.permute(0, 3, 1, 2) + + +class ViT_FPN(Backbone): + + def __init__(self, + bottom_up=None, + top_block=None, + out_channels=None, + strides=None, + vit_out_dim=None): + super(ViT_FPN, self).__init__() + assert isinstance(bottom_up, Backbone) + self.bottom_up = bottom_up + self.top_block = top_block + + self._out_feature_strides = { + 'p{}'.format(int(math.log2(s))): s + for s in strides + } + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = { + k: out_channels + for k in self._out_features + } + self._size_divisibility = strides[2] + + self.maxpool = nn.MaxPool2d(2, stride=2) + self.fpn_stride_16_8 = nn.ConvTranspose2d( + vit_out_dim, vit_out_dim, 2, stride=2, bias=False) + self.fpn_stride8_conv1 = nn.Conv2d( + in_channels=vit_out_dim, + out_channels=out_channels, + kernel_size=1, + bias=False) + self.fpn_stride8_norm1 = nn.LayerNorm(out_channels) + self.fpn_stride8_conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.fpn_stride8_norm2 = nn.LayerNorm(out_channels) + + self.fpn_stride16_conv1 = nn.Conv2d( + in_channels=vit_out_dim, + out_channels=out_channels, + kernel_size=1, + bias=False) + self.fpn_stride16_norm1 = nn.LayerNorm(out_channels) + self.fpn_stride16_conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + 
bias=False) + self.fpn_stride16_norm2 = nn.LayerNorm(out_channels) + + self.fpn_stride32_conv1 = nn.Conv2d( + in_channels=vit_out_dim, + out_channels=out_channels, + kernel_size=1, + bias=False) + self.fpn_stride32_norm1 = nn.LayerNorm(out_channels) + self.fpn_stride32_conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.fpn_stride32_norm2 = nn.LayerNorm(out_channels) + + def forward(self, x): + vit_output_featuremap = self.bottom_up(x) + + stride8_feature = self.fpn_stride_16_8(vit_output_featuremap) + stride8_feature = self.fpn_stride8_norm1( + self.fpn_stride8_conv1(stride8_feature).permute(0, 2, 3, + 1)).permute( + 0, 3, 1, 2) + stride8_feature = self.fpn_stride8_norm2( + self.fpn_stride8_conv2(stride8_feature).permute(0, 2, 3, + 1)).permute( + 0, 3, 1, 2) + + stride32_feature = self.maxpool(vit_output_featuremap) + stride32_feature = self.fpn_stride32_norm1( + self.fpn_stride32_conv1(stride32_feature).permute(0, 2, 3, + 1)).permute( + 0, 3, 1, 2) + stride32_feature = self.fpn_stride32_norm2( + self.fpn_stride32_conv2(stride32_feature).permute(0, 2, 3, + 1)).permute( + 0, 3, 1, 2) + + stride16_feature = self.fpn_stride16_norm1( + self.fpn_stride16_conv1(vit_output_featuremap).permute( + 0, 2, 3, 1)).permute(0, 3, 1, 2) + stride16_feature = self.fpn_stride16_norm2( + self.fpn_stride16_conv2(stride16_feature).permute(0, 2, 3, + 1)).permute( + 0, 3, 1, 2) + + results = [stride8_feature, stride16_feature, stride32_feature] + + results.extend(self.top_block(stride32_feature)) + + assert len(self._out_features) == len(results) + fpn_out = {f: res for f, res in zip(self._out_features, results)} + + return fpn_out + + @property + def size_divisibility(self): + return self._size_divisibility + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name]) + for name in self._out_features + } + + +@BACKBONE_REGISTRY.register() +def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec): + embed_dim = 768 + vit_out_dim = embed_dim + bottom_up = ViT( # Single-scale ViT backbone + img_size=1024, + patch_size=16, + embed_dim=embed_dim, + depth=12, + num_heads=12, + drop_path_rate=0.1, + window_size=14, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + window_block_indexes=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_act_checkpoint=cfg.USE_ACT_CHECKPOINT, + use_rel_pos=True, + out_feature='last_feat', + ) + + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + assert out_channels == 256 or out_channels == 768 or out_channels == 1024 + backbone = ViT_FPN( + bottom_up=bottom_up, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + out_channels=out_channels, + strides=[8, 16, 32, 64, 128], + vit_out_dim=vit_out_dim) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_vit_fpn_backbone_large(cfg, input_shape: ShapeSpec): + window_block_indexes = ( + list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + + list(range(18, 23))) + embed_dim = 1024 + vit_out_dim = embed_dim + bottom_up = ViT( # Single-scale ViT backbone + img_size=1024, + patch_size=16, + embed_dim=embed_dim, + depth=24, + num_heads=16, + drop_path_rate=0.4, + window_size=14, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + window_block_indexes=window_block_indexes, + residual_block_indexes=[], + 
use_act_checkpoint=cfg.USE_ACT_CHECKPOINT, + use_rel_pos=True, + out_feature='last_feat', + ) + + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + assert out_channels == 256 or out_channels == 768 or out_channels == 1024 + backbone = ViT_FPN( + bottom_up=bottom_up, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + out_channels=out_channels, + strides=[8, 16, 32, 64, 128], + vit_out_dim=vit_out_dim) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_vit_fpn_backbone_huge(cfg, input_shape: ShapeSpec): + window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + + list(range(24, 31))) + embed_dim = 1280 + vit_out_dim = embed_dim + bottom_up = ViT( # Single-scale ViT backbone + img_size=1024, + patch_size=16, + embed_dim=embed_dim, + depth=32, + num_heads=16, + drop_path_rate=0.5, + window_size=14, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + window_block_indexes=window_block_indexes, + residual_block_indexes=[], + use_act_checkpoint=cfg.USE_ACT_CHECKPOINT, + use_rel_pos=True, + out_feature='last_feat', + ) + + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + assert out_channels == 256 or out_channels == 768 or out_channels == 1024 + backbone = ViT_FPN( + bottom_up=bottom_up, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + out_channels=out_channels, + strides=[8, 16, 32, 64, 128], + vit_out_dim=vit_out_dim) + return backbone diff --git a/projects/videochat/models/grit_src/grit/modeling/meta_arch/grit.py b/projects/videochat/models/grit_src/grit/modeling/meta_arch/grit.py new file mode 100644 index 0000000000..652370d83b --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/meta_arch/grit.py @@ -0,0 +1,66 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from detectron2.config import configurable +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN +from detectron2.structures import Instances + + +@META_ARCH_REGISTRY.register() +class GRiT(GeneralizedRCNN): + + @configurable + def __init__(self, **kwargs): + super().__init__(**kwargs) + assert self.proposal_generator is not None + + @classmethod + def from_config(cls, cfg): + ret = super().from_config(cfg) + return ret + + def inference( + self, + batched_inputs: Tuple[Dict[str, torch.Tensor]], + detected_instances: Optional[List[Instances]] = None, + do_postprocess: bool = True, + ): + assert not self.training + assert detected_instances is None + + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + proposals, _ = self.proposal_generator(images, features, None) + results, _ = self.roi_heads(features, proposals) + if do_postprocess: + assert not torch.jit.is_scripting(), \ + 'Scripting is not supported for postprocess.' 
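+ # GeneralizedRCNN._postprocess rescales each Instances object from the
+ # resized/padded input tensor back to the original 'height'/'width'
+ # recorded in batched_inputs, so callers receive boxes in original
+ # image coordinates.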
+ return GRiT._postprocess(results, batched_inputs, + images.image_sizes) + else: + return results + + def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + if not self.training: + return self.inference(batched_inputs) + + images = self.preprocess_image(batched_inputs) + + gt_instances = [x['instances'].to(self.device) for x in batched_inputs] + + targets_task = batched_inputs[0]['task'] + for anno_per_image in batched_inputs: + assert targets_task == anno_per_image['task'] + + features = self.backbone(images.tensor) + proposals, proposal_losses = self.proposal_generator( + images, features, gt_instances) + proposals, roihead_textdecoder_losses = self.roi_heads( + features, proposals, gt_instances, targets_task=targets_task) + + losses = {} + losses.update(roihead_textdecoder_losses) + losses.update(proposal_losses) + + return losses diff --git a/projects/videochat/models/grit_src/grit/modeling/roi_heads/grit_fast_rcnn.py b/projects/videochat/models/grit_src/grit/modeling/roi_heads/grit_fast_rcnn.py new file mode 100644 index 0000000000..654304f42a --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/roi_heads/grit_fast_rcnn.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. Modified by Jialian Wu +# from https://github.com/facebookresearch/Detic/blob/main/detic/modeling +# /roi_heads/detic_fast_rcnn.py +import fvcore.nn.weight_init as weight_init +import torch +from detectron2.config import configurable +from detectron2.layers import ShapeSpec, cat, nonzero_tuple +from detectron2.modeling.roi_heads.fast_rcnn import (FastRCNNOutputLayers, + _log_classification_stats) +from fvcore.nn import giou_loss, smooth_l1_loss +from torch import nn +from torch.nn import functional as F + +__all__ = ['GRiTFastRCNNOutputLayers'] + + +class GRiTFastRCNNOutputLayers(FastRCNNOutputLayers): + + @configurable + def __init__( + self, + input_shape: ShapeSpec, + **kwargs, + ): + super().__init__( + input_shape=input_shape, + **kwargs, + ) + + input_size = input_shape.channels * \ + (input_shape.width or 1) * (input_shape.height or 1) + + self.bbox_pred = nn.Sequential( + nn.Linear(input_size, input_size), nn.ReLU(inplace=True), + nn.Linear(input_size, 4)) + weight_init.c2_xavier_fill(self.bbox_pred[0]) + nn.init.normal_(self.bbox_pred[-1].weight, std=0.001) + nn.init.constant_(self.bbox_pred[-1].bias, 0) + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + return ret + + def losses(self, predictions, proposals): + scores, proposal_deltas = predictions + gt_classes = ( + cat([p.gt_classes for p in proposals], dim=0) + if len(proposals) else torch.empty(0)) + num_classes = self.num_classes + _log_classification_stats(scores, gt_classes) + + if len(proposals): + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], + dim=0) # Nx4 + assert not proposal_boxes.requires_grad, \ + 'Proposals should not require gradients!' 
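+ # If an image has no matched gt ('gt_boxes' missing), its proposal
+ # boxes are used as an arbitrary placeholder so the concatenation
+ # lines up; those all-background entries never enter the regression
+ # loss because box_reg_loss only keeps foreground indices.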
+ gt_boxes = cat( + [(p.gt_boxes if p.has('gt_boxes') else p.proposal_boxes).tensor + for p in proposals], + dim=0, + ) + else: + proposal_boxes = gt_boxes = torch.empty( + (0, 4), device=proposal_deltas.device) + + loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes) + return { + 'loss_cls': + loss_cls, + 'loss_box_reg': + self.box_reg_loss( + proposal_boxes, + gt_boxes, + proposal_deltas, + gt_classes, + num_classes=num_classes) + } + + def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes): + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] + + loss = F.cross_entropy(pred_class_logits, gt_classes, reduction='mean') + return loss + + def box_reg_loss(self, + proposal_boxes, + gt_boxes, + pred_deltas, + gt_classes, + num_classes=-1): + num_classes = num_classes if num_classes > 0 else self.num_classes + box_dim = proposal_boxes.shape[1] + fg_inds = nonzero_tuple((gt_classes >= 0) + & (gt_classes < num_classes))[0] + if pred_deltas.shape[1] == box_dim: + fg_pred_deltas = pred_deltas[fg_inds] + else: + fg_pred_deltas = pred_deltas.view(-1, self.num_classes, + box_dim)[fg_inds, + gt_classes[fg_inds]] + + if self.box_reg_loss_type == 'smooth_l1': + gt_pred_deltas = self.box2box_transform.get_deltas( + proposal_boxes[fg_inds], + gt_boxes[fg_inds], + ) + loss_box_reg = smooth_l1_loss( + fg_pred_deltas, + gt_pred_deltas, + self.smooth_l1_beta, + reduction='sum') + elif self.box_reg_loss_type == 'giou': + fg_pred_boxes = self.box2box_transform.apply_deltas( + fg_pred_deltas, proposal_boxes[fg_inds]) + loss_box_reg = giou_loss( + fg_pred_boxes, gt_boxes[fg_inds], reduction='sum') + else: + raise ValueError( + f"Invalid bbox reg loss type '{self.box_reg_loss_type}'") + return loss_box_reg / max(gt_classes.numel(), 1.0) + + def predict_probs(self, predictions, proposals): + scores = predictions[0] + num_inst_per_image = [len(p) for p in proposals] + probs = F.softmax(scores, dim=-1) + return probs.split(num_inst_per_image, dim=0) + + def forward(self, x): + if x.dim() > 2: + x = torch.flatten(x, start_dim=1) + scores = [] + + cls_scores = self.cls_score(x) + scores.append(cls_scores) + scores = torch.cat(scores, dim=1) + + proposal_deltas = self.bbox_pred(x) + return scores, proposal_deltas diff --git a/projects/videochat/models/grit_src/grit/modeling/roi_heads/grit_roi_heads.py b/projects/videochat/models/grit_src/grit/modeling/roi_heads/grit_roi_heads.py new file mode 100644 index 0000000000..8755f6a932 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/roi_heads/grit_roi_heads.py @@ -0,0 +1,562 @@ +import logging +import math +from typing import List, Tuple + +import torch +from detectron2.config import configurable +from detectron2.layers import batched_nms +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.poolers import ROIPooler +from detectron2.modeling.roi_heads.cascade_rcnn import (CascadeROIHeads, + _ScaleGradient) +from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY +from detectron2.structures import Boxes, Instances, pairwise_iou +from detectron2.utils.events import get_event_storage +from models.grit_src.grit.data.custom_dataset_mapper import ObjDescription +from transformers import BertTokenizer + +from ..soft_nms import batched_soft_nms +from ..text.load_text_token import LoadTextTokens +from ..text.text_decoder import (AutoRegressiveBeamSearch, GRiTTextDecoder, + TransformerDecoderTextualHead) +from .grit_fast_rcnn import GRiTFastRCNNOutputLayers + 
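+# GRiTROIHeadsAndTextDecoder extends detectron2's CascadeROIHeads: the usual
+# cascade of box heads classifies and refines proposals, then a BERT-style
+# autoregressive text decoder (with beam search at inference time) generates
+# an object description for every kept box.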
+logger = logging.getLogger(__name__) + + +@ROI_HEADS_REGISTRY.register() +class GRiTROIHeadsAndTextDecoder(CascadeROIHeads): + + @configurable + def __init__( + self, + *, + text_decoder_transformer, + train_task: list, + test_task: str, + mult_proposal_score: bool = False, + mask_weight: float = 1.0, + object_feat_pooler=None, + soft_nms_enabled=False, + beam_size=1, + **kwargs, + ): + super().__init__(**kwargs) + self.mult_proposal_score = mult_proposal_score + self.mask_weight = mask_weight + self.object_feat_pooler = object_feat_pooler + self.soft_nms_enabled = soft_nms_enabled + self.test_task = test_task + self.beam_size = beam_size + + tokenizer = BertTokenizer.from_pretrained( + '/mnt/data.coronaryct.1/ZhuYichen/Ask-Anything/model/bert-base' + '-uncased', + local_files_only=True) + self.tokenizer = tokenizer + + assert test_task in train_task, 'GRiT has not been trained on {} ' \ + 'task, ' \ + 'please verify the task name or ' \ + 'train a new ' \ + 'GRiT on {} task'.format(test_task, + test_task) + task_begin_tokens = {} + for i, task in enumerate(train_task): + if i == 0: + task_begin_tokens[task] = tokenizer.cls_token_id + else: + task_begin_tokens[task] = 103 + i + self.task_begin_tokens = task_begin_tokens + + beamsearch_decode = AutoRegressiveBeamSearch( + end_token_id=tokenizer.sep_token_id, + max_steps=40, + beam_size=beam_size, + objectdet=test_task == 'ObjectDet', + per_node_beam_size=1, + ) + self.text_decoder = GRiTTextDecoder( + text_decoder_transformer, + beamsearch_decode=beamsearch_decode, + begin_token_id=task_begin_tokens[test_task], + loss_type='smooth', + tokenizer=tokenizer, + ) + self.get_target_text_tokens = LoadTextTokens( + tokenizer, max_text_len=40, padding='do_not_pad') + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + text_decoder_transformer = TransformerDecoderTextualHead( + object_feature_size=cfg.MODEL.FPN.OUT_CHANNELS, + vocab_size=cfg.TEXT_DECODER.VOCAB_SIZE, + hidden_size=cfg.TEXT_DECODER.HIDDEN_SIZE, + num_layers=cfg.TEXT_DECODER.NUM_LAYERS, + attention_heads=cfg.TEXT_DECODER.ATTENTION_HEADS, + feedforward_size=cfg.TEXT_DECODER.FEEDFORWARD_SIZE, + mask_future_positions=True, + padding_idx=0, + decoder_type='bert_en', + use_act_checkpoint=cfg.USE_ACT_CHECKPOINT, + ) + ret.update({ + 'text_decoder_transformer': text_decoder_transformer, + 'train_task': cfg.MODEL.TRAIN_TASK, + 'test_task': cfg.MODEL.TEST_TASK, + 'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE, + 'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT, + 'soft_nms_enabled': cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED, + 'beam_size': cfg.MODEL.BEAM_SIZE, + }) + return ret + + @classmethod + def _init_box_head(self, cfg, input_shape): + ret = super()._init_box_head(cfg, input_shape) + del ret['box_predictors'] + cascade_bbox_reg_weights = \ + cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS + box_predictors = [] + for box_head, bbox_reg_weights in zip(ret['box_heads'], + cascade_bbox_reg_weights): + box_predictors.append( + GRiTFastRCNNOutputLayers( + cfg, + box_head.output_shape, + box2box_transform=Box2BoxTransform( + weights=bbox_reg_weights))) + ret['box_predictors'] = box_predictors + + in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE + object_feat_pooler = ROIPooler( + output_size=cfg.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES, + 
scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + ret['object_feat_pooler'] = object_feat_pooler + return ret + + def check_if_all_background(self, proposals, targets, stage): + all_background = True + for proposals_per_image in proposals: + if not (proposals_per_image.gt_classes == self.num_classes).all(): + all_background = False + + if all_background: + logger.info( + 'all proposals are background at stage {}'.format(stage)) + proposals[0].proposal_boxes.tensor[0, :] = \ + targets[0].gt_boxes.tensor[0, :] + proposals[0].gt_boxes.tensor[0, :] = targets[0].gt_boxes.tensor[ + 0, :] + proposals[0].objectness_logits[0] = math.log( + (1.0 - 1e-10) / (1 - (1.0 - 1e-10))) + proposals[0].gt_classes[0] = targets[0].gt_classes[0] + proposals[0].gt_object_descriptions.data[0] = targets[ + 0].gt_object_descriptions.data[0] + if 'foreground' in proposals[0].get_fields().keys(): + proposals[0].foreground[0] = 1 + return proposals + + def _forward_box(self, + features, + proposals, + targets=None, + task='ObjectDet'): + if self.training: + proposals = self.check_if_all_background(proposals, targets, 0) + if (not self.training) and self.mult_proposal_score: + if len(proposals) > 0 and proposals[0].has('scores'): + proposal_scores = [p.get('scores') for p in proposals] + else: + proposal_scores = [ + p.get('objectness_logits') for p in proposals + ] + + features = [features[f] for f in self.box_in_features] + head_outputs = [] + prev_pred_boxes = None + image_sizes = [x.image_size for x in proposals] + + for k in range(self.num_cascade_stages): + if k > 0: + proposals = self._create_proposals_from_boxes( + prev_pred_boxes, + image_sizes, + logits=[p.objectness_logits for p in proposals]) + if self.training: + proposals = self._match_and_label_boxes_GRiT( + proposals, k, targets) + proposals = self.check_if_all_background( + proposals, targets, k) + predictions = self._run_stage(features, proposals, k) + prev_pred_boxes = self.box_predictor[k].predict_boxes( + (predictions[0], predictions[1]), proposals) + head_outputs.append( + (self.box_predictor[k], predictions, proposals)) + + if self.training: + object_features = self.object_feat_pooler( + features, [x.proposal_boxes for x in proposals]) + object_features = _ScaleGradient.apply( + object_features, 1.0 / self.num_cascade_stages) + foreground = torch.cat([x.foreground for x in proposals]) + object_features = object_features[foreground > 0] + + object_descriptions = [] + for x in proposals: + object_descriptions += x.gt_object_descriptions[ + x.foreground > 0].data + object_descriptions = ObjDescription(object_descriptions) + object_descriptions = object_descriptions.data + + if len(object_descriptions) > 0: + begin_token = self.task_begin_tokens[task] + text_decoder_inputs = self.get_target_text_tokens( + object_descriptions, object_features, begin_token) + object_features = object_features.view( + object_features.shape[0], object_features.shape[1], + -1).permute(0, 2, 1).contiguous() + text_decoder_inputs.update( + {'object_features': object_features}) + text_decoder_loss = self.text_decoder(text_decoder_inputs) + else: + text_decoder_loss = head_outputs[0][1][0].new_zeros([1])[0] + + losses = {} + storage = get_event_storage() + # RoI Head losses (For the proposal generator loss, please find + # it in grit.py) + for stage, (predictor, predictions, + proposals) in enumerate(head_outputs): + with storage.name_scope('stage{}'.format(stage)): + stage_losses = predictor.losses( + (predictions[0], predictions[1]), 
proposals) + losses.update({ + k + '_stage{}'.format(stage): v + for k, v in stage_losses.items() + }) + # Text Decoder loss + losses.update({'text_decoder_loss': text_decoder_loss}) + return losses + else: + scores_per_stage = [ + h[0].predict_probs(h[1], h[2]) for h in head_outputs + ] + logits_per_stage = [(h[1][0], ) for h in head_outputs] + scores = [ + sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) + for scores_per_image in zip(*scores_per_stage) + ] + logits = [ + sum(list(logits_per_image)) * (1.0 / self.num_cascade_stages) + for logits_per_image in zip(*logits_per_stage) + ] + if self.mult_proposal_score: + scores = [(s * ps[:, None])**0.5 + for s, ps in zip(scores, proposal_scores)] + predictor, predictions, proposals = head_outputs[-1] + boxes = predictor.predict_boxes((predictions[0], predictions[1]), + proposals) + assert len(boxes) == 1 + pred_instances, _ = self.fast_rcnn_inference_GRiT( + boxes, + scores, + logits, + image_sizes, + predictor.test_score_thresh, + predictor.test_nms_thresh, + predictor.test_topk_per_image, + self.soft_nms_enabled, + ) + + assert len(pred_instances) == 1, 'Only support one image' + for i, pred_instance in enumerate(pred_instances): + if len(pred_instance.pred_boxes) > 0: + object_features = self.object_feat_pooler( + features, [pred_instance.pred_boxes]) + object_features = object_features.view( + object_features.shape[0], object_features.shape[1], + -1).permute(0, 2, 1).contiguous() + text_decoder_output = self.text_decoder( + {'object_features': object_features}) + if self.beam_size > 1 and self.test_task == 'ObjectDet': + pred_boxes = [] + pred_scores = [] + pred_classes = [] + pred_object_descriptions = [] + + for beam_id in range(self.beam_size): + pred_boxes.append(pred_instance.pred_boxes.tensor) + # object score = sqrt(objectness score x + # description score) + pred_scores.append( + (pred_instance.scores * + torch.exp(text_decoder_output['logprobs']) + [:, beam_id])**0.5) + pred_classes.append(pred_instance.pred_classes) + for prediction in \ + text_decoder_output[ + 'predictions'][:, beam_id, :]: + # convert text tokens to words + description = self.tokenizer.decode( + prediction.tolist()[1:], + skip_special_tokens=True) + pred_object_descriptions.append(description) + + merged_instances = Instances(image_sizes[0]) + if torch.cat( + pred_scores, dim=0 + ).shape[0] <= predictor.test_topk_per_image: + merged_instances.scores = torch.cat( + pred_scores, dim=0) + merged_instances.pred_boxes = Boxes( + torch.cat(pred_boxes, dim=0)) + merged_instances.pred_classes = torch.cat( + pred_classes, dim=0) + merged_instances.pred_object_descriptions = \ + ObjDescription(pred_object_descriptions) + else: + pred_scores, top_idx = torch.topk( + torch.cat(pred_scores, dim=0), + predictor.test_topk_per_image) + merged_instances.scores = pred_scores + merged_instances.pred_boxes = Boxes( + torch.cat(pred_boxes, dim=0)[top_idx, :]) + merged_instances.pred_classes = torch.cat( + pred_classes, dim=0)[top_idx] + merged_instances.pred_object_descriptions = \ + ObjDescription( + ObjDescription(pred_object_descriptions)[ + top_idx].data) + + pred_instances[i] = merged_instances + else: + # object score = sqrt(objectness score x description + # score) + pred_instance.scores = \ + (pred_instance.scores * torch.exp( + text_decoder_output['logprobs'])) ** 0.5 + + pred_object_descriptions = [] + for prediction in text_decoder_output['predictions']: + # convert text tokens to words + description = self.tokenizer.decode( + prediction.tolist()[1:], + 
skip_special_tokens=True) + pred_object_descriptions.append(description) + pred_instance.pred_object_descriptions = \ + ObjDescription(pred_object_descriptions) + else: + pred_instance.pred_object_descriptions = ObjDescription([]) + + return pred_instances + + def forward(self, + features, + proposals, + targets=None, + targets_task='ObjectDet'): + if self.training: + proposals = self.label_and_sample_proposals(proposals, targets) + + losses = self._forward_box( + features, proposals, targets, task=targets_task) + if targets[0].has('gt_masks'): + mask_losses = self._forward_mask(features, proposals) + losses.update( + {k: v * self.mask_weight + for k, v in mask_losses.items()}) + else: + losses.update( + self._get_empty_mask_loss( + device=proposals[0].objectness_logits.device)) + return proposals, losses + else: + pred_instances = self._forward_box( + features, proposals, task=self.test_task) + pred_instances = self.forward_with_given_boxes( + features, pred_instances) + return pred_instances, {} + + @torch.no_grad() + def _match_and_label_boxes_GRiT(self, proposals, stage, targets): + """Add "gt_object_description" and "foreground" to detectron2's + _match_and_label_boxes.""" + num_fg_samples, num_bg_samples = [], [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + match_quality_matrix = pairwise_iou( + targets_per_image.gt_boxes, proposals_per_image.proposal_boxes) + # proposal_labels are 0 or 1 + matched_idxs, proposal_labels = self.proposal_matchers[stage]( + match_quality_matrix) + if len(targets_per_image) > 0: + gt_classes = targets_per_image.gt_classes[matched_idxs] + # Label unmatched proposals (0 label from matcher) as + # background (label=num_classes) + gt_classes[proposal_labels == 0] = self.num_classes + foreground = torch.ones_like(gt_classes) + foreground[proposal_labels == 0] = 0 + gt_boxes = targets_per_image.gt_boxes[matched_idxs] + gt_object_descriptions = \ + targets_per_image.gt_object_descriptions[ + matched_idxs] + else: + gt_classes = torch.zeros_like(matched_idxs) + self.num_classes + foreground = torch.zeros_like(gt_classes) + gt_boxes = Boxes( + targets_per_image.gt_boxes.tensor.new_zeros( + (len(proposals_per_image), 4))) + gt_object_descriptions = ObjDescription( + ['None' for i in range(len(proposals_per_image))]) + proposals_per_image.gt_classes = gt_classes + proposals_per_image.gt_boxes = gt_boxes + proposals_per_image.gt_object_descriptions = gt_object_descriptions + proposals_per_image.foreground = foreground + + num_fg_samples.append((proposal_labels == 1).sum().item()) + num_bg_samples.append(proposal_labels.numel() - num_fg_samples[-1]) + + # Log the number of fg/bg samples in each stage + storage = get_event_storage() + storage.put_scalar( + 'stage{}/roi_head/num_fg_samples'.format(stage), + sum(num_fg_samples) / len(num_fg_samples), + ) + storage.put_scalar( + 'stage{}/roi_head/num_bg_samples'.format(stage), + sum(num_bg_samples) / len(num_bg_samples), + ) + return proposals + + def fast_rcnn_inference_GRiT( + self, + boxes: List[torch.Tensor], + scores: List[torch.Tensor], + logits: List[torch.Tensor], + image_shapes: List[Tuple[int, int]], + score_thresh: float, + nms_thresh: float, + topk_per_image: int, + soft_nms_enabled: bool, + ): + result_per_image = [ + self.fast_rcnn_inference_single_image_GRiT(boxes_per_image, + scores_per_image, + logits_per_image, + image_shape, + score_thresh, + nms_thresh, + topk_per_image, + soft_nms_enabled) + for scores_per_image, boxes_per_image, image_shape, + logits_per_image in 
zip(scores, boxes, image_shapes, logits) + ] + return [x[0] + for x in result_per_image], [x[1] for x in result_per_image] + + def fast_rcnn_inference_single_image_GRiT( + self, + boxes, + scores, + logits, + image_shape: Tuple[int, int], + score_thresh: float, + nms_thresh: float, + topk_per_image: int, + soft_nms_enabled, + ): + """Add soft NMS to detectron2's fast_rcnn_inference_single_image.""" + valid_mask = torch.isfinite(boxes).all( + dim=1) & torch.isfinite(scores).all(dim=1) + if not valid_mask.all(): + boxes = boxes[valid_mask] + scores = scores[valid_mask] + logits = logits[valid_mask] + + scores = scores[:, :-1] + logits = logits[:, :-1] + num_bbox_reg_classes = boxes.shape[1] // 4 + # Convert to Boxes to use the `clip` function ... + boxes = Boxes(boxes.reshape(-1, 4)) + boxes.clip(image_shape) + boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4 + + # 1. Filter results based on detection scores. It can make NMS more + # efficient by filtering out low-confidence detections. + filter_mask = scores > score_thresh # R x K + # R' x 2. First column contains indices of the R predictions; + # Second column contains indices of classes. + filter_inds = filter_mask.nonzero() + if num_bbox_reg_classes == 1: + boxes = boxes[filter_inds[:, 0], 0] + else: + boxes = boxes[filter_mask] + scores = scores[filter_mask] + logits = logits[filter_mask] + + # 2. Apply NMS for each class independently. + if not soft_nms_enabled: + keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh) + else: + keep, soft_nms_scores = batched_soft_nms( + boxes, + scores, + filter_inds[:, 1], + 'linear', + 0.5, + nms_thresh, + 0.001, + ) + scores[keep] = soft_nms_scores + if topk_per_image >= 0: + keep = keep[:topk_per_image] + boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[ + keep] + logits = logits[keep] + + result = Instances(image_shape) + result.pred_boxes = Boxes(boxes) + result.scores = scores + result.pred_classes = filter_inds[:, 1] + result.logits = logits + return result, filter_inds[:, 0] + + def _get_empty_mask_loss(self, device): + if self.mask_on: + return { + 'loss_mask': + torch.zeros((1, ), device=device, dtype=torch.float32)[0] + } + else: + return {} + + def _create_proposals_from_boxes(self, boxes, image_sizes, logits): + boxes = [Boxes(b.detach()) for b in boxes] + proposals = [] + for boxes_per_image, image_size, logit in zip(boxes, image_sizes, + logits): + boxes_per_image.clip(image_size) + if self.training: + inds = boxes_per_image.nonempty() + boxes_per_image = boxes_per_image[inds] + logit = logit[inds] + prop = Instances(image_size) + prop.proposal_boxes = boxes_per_image + prop.objectness_logits = logit + proposals.append(prop) + return proposals + + def _run_stage(self, features, proposals, stage): + pool_boxes = [x.proposal_boxes for x in proposals] + box_features = self.box_pooler(features, pool_boxes) + box_features = _ScaleGradient.apply(box_features, + 1.0 / self.num_cascade_stages) + box_features = self.box_head[stage](box_features) + return self.box_predictor[stage](box_features) diff --git a/projects/videochat/models/grit_src/grit/modeling/soft_nms.py b/projects/videochat/models/grit_src/grit/modeling/soft_nms.py new file mode 100644 index 0000000000..1b1e17f642 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/soft_nms.py @@ -0,0 +1,148 @@ +import torch +from detectron2.structures import Boxes, pairwise_iou + + +def soft_nms(boxes, scores, method, gaussian_sigma, linear_threshold, + prune_threshold): + """Performs 
soft non-maximum suppression algorithm on axis aligned boxes. + + Args: boxes (Tensor[N, 5]): boxes where NMS will be performed. They are + expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format + scores (Tensor[N]): scores for each one of the boxes method (str): one + of ['gaussian', 'linear', 'hard'] see paper for details. users + encouraged not to use "hard", as this is the same nms available + elsewhere in detectron2 gaussian_sigma (float): parameter for Gaussian + penalty function linear_threshold (float): iou threshold for applying + linear decay. Nt from the paper re-used as threshold for standard "hard" + nms prune_threshold (float): boxes with scores below this threshold are + pruned at each iteration. Dramatically reduces computation time. Authors + use values in [10e-4, 10e-2] + + Returns: tuple(Tensor, Tensor): [0]: int64 tensor with the indices of + the elements that have been kept by Soft NMS, sorted in decreasing order + of scores [1]: float tensor with the re-scored scores of the elements + that were kept + """ + return _soft_nms( + Boxes, + pairwise_iou, + boxes, + scores, + method, + gaussian_sigma, + linear_threshold, + prune_threshold, + ) + + +def batched_soft_nms(boxes, scores, idxs, method, gaussian_sigma, + linear_threshold, prune_threshold): + """Performs soft non-maximum suppression in a batched fashion. + + Each index value correspond to a category, and NMS + will not be applied between elements of different categories. + + Args: boxes (Tensor[N, 4]): boxes where NMS will be performed. They are + expected to be in (x1, y1, x2, y2) format scores (Tensor[N]): scores for + each one of the boxes idxs (Tensor[N]): indices of the categories for + each one of the boxes. method (str): one of ['gaussian', 'linear', + 'hard'] see paper for details. users encouraged not to use "hard", + as this is the same nms available elsewhere in detectron2 gaussian_sigma + (float): parameter for Gaussian penalty function linear_threshold ( + float): iou threshold for applying linear decay. Nt from the paper + re-used as threshold for standard "hard" nms prune_threshold (float): + boxes with scores below this threshold are pruned at each iteration. + Dramatically reduces computation time. Authors use values in [10e-4, + 10e-2] Returns: tuple(Tensor, Tensor): [0]: int64 tensor with the + indices of the elements that have been kept by Soft NMS, sorted in + decreasing order of scores [1]: float tensor with the re-scored scores + of the elements that were kept + """ + if boxes.numel() == 0: + return ( + torch.empty((0, ), dtype=torch.int64, device=boxes.device), + torch.empty((0, ), dtype=torch.float32, device=scores.device), + ) + # strategy: in order to perform NMS independently per class. + # we add an offset to all the boxes. The offset is dependent + # only on the class idx, and is large enough so that boxes + # from different classes do not overlap + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + return soft_nms(boxes_for_nms, scores, method, gaussian_sigma, + linear_threshold, prune_threshold) + + +def _soft_nms( + box_class, + pairwise_iou_func, + boxes, + scores, + method, + gaussian_sigma, + linear_threshold, + prune_threshold, +): + """Soft non-max suppression algorithm. 
+ + Implementation of [Soft-NMS -- Improving Object Detection With One Line + of Codec] (https://arxiv.org/abs/1704.04503) + + Args: box_class (cls): one of Box, RotatedBoxes pairwise_iou_func ( + func): one of pairwise_iou, pairwise_iou_rotated boxes (Tensor[N, + ?]): boxes where NMS will be performed if Boxes, in (x1, y1, x2, + y2) format if RotatedBoxes, in (x_ctr, y_ctr, width, height, + angle_degrees) format scores (Tensor[N]): scores for each one of the + boxes method (str): one of ['gaussian', 'linear', 'hard'] see paper for + details. users encouraged not to use "hard", as this is the same nms + available elsewhere in detectron2 gaussian_sigma (float): parameter for + Gaussian penalty function linear_threshold (float): iou threshold for + applying linear decay. Nt from the paper re-used as threshold for + standard "hard" nms prune_threshold (float): boxes with scores below + this threshold are pruned at each iteration. Dramatically reduces + computation time. Authors use values in [10e-4, 10e-2] + + Returns: tuple(Tensor, Tensor): [0]: int64 tensor with the indices of + the elements that have been kept by Soft NMS, sorted in decreasing order + of scores [1]: float tensor with the re-scored scores of the elements + that were kept + """ + boxes = boxes.clone() + scores = scores.clone() + idxs = torch.arange(scores.size()[0]) + + idxs_out = [] + scores_out = [] + + while scores.numel() > 0: + top_idx = torch.argmax(scores) + idxs_out.append(idxs[top_idx].item()) + scores_out.append(scores[top_idx].item()) + + top_box = boxes[top_idx] + ious = pairwise_iou_func( + box_class(top_box.unsqueeze(0)), box_class(boxes))[0] + + if method == 'linear': + decay = torch.ones_like(ious) + decay_mask = ious > linear_threshold + decay[decay_mask] = 1 - ious[decay_mask] + elif method == 'gaussian': + decay = torch.exp(-torch.pow(ious, 2) / gaussian_sigma) + elif method == 'hard': # standard NMS + decay = (ious < linear_threshold).float() + else: + raise NotImplementedError( + '{} soft nms method not implemented.'.format(method)) + + scores *= decay + keep = scores > prune_threshold + keep[top_idx] = False + + boxes = boxes[keep] + scores = scores[keep] + idxs = idxs[keep] + + return torch.tensor(idxs_out).to( + boxes.device), torch.tensor(scores_out).to(scores.device) diff --git a/projects/videochat/models/grit_src/grit/modeling/text/file_utils.py b/projects/videochat/models/grit_src/grit/modeling/text/file_utils.py new file mode 100644 index 0000000000..c5737c9c52 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/text/file_utils.py @@ -0,0 +1,264 @@ +# Utilities for working with the local dataset cache. This file is adapted +# from the AllenNLP library at https://github.com/allenai/allennlp Copyright +# by the AllenNLP authors. 
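+# cached_path() below accepts either a local path or an http(s)/s3 URL;
+# remote files are downloaded once into PYTORCH_PRETRAINED_BERT_CACHE,
+# keyed by a hash of the URL and its ETag, and reused on later calls.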
+ +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import fnmatch +import json +import logging +import os +import shutil +import sys +import tempfile +from functools import wraps +from hashlib import sha256 +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from torch.hub import _get_torch_home + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv( + 'TORCH_HOME', + os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) +default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers') + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path( + os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + default_cache_path) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """Convert `url` into a hashed filename in a repeatable way. + + If `etag` is specified, append its hash to the url's, delimited by a + period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """Return the url and etag (which may be ``None``) stored for `filename`. + + Raise ``EnvironmentError`` if `filename` or its stored metadata do not + exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError('file {} not found'.format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError('file {} not found'.format(meta_path)) + + with open(meta_path, encoding='utf-8') as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """Given something that might be a URL (or might be a local path), + determine which. + + If it's a URL, download the file and cache it, and return the path to the + cached file. If it's already a local path, make sure the file exists and + then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError('file {} not found'.format(url_or_filename)) + else: + # Something unknown + raise ValueError('unable to parse {} as a URL or as a local path'. 
+ format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError('bad s3 path {}'.format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith('/'): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """Wrapper function for s3 requests in order to create more helpful error + messages.""" + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response['Error']['Code']) == 404: + raise EnvironmentError('file {} not found'.format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource('s3') + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource('s3') + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit='B', total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """Given a URL, look for the corresponding dataset in the local cache. + + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + if sys.version_info[0] == 2 and not isinstance(cache_dir, str): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith('s3://'): + etag = s3_etag(url) + else: + try: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get('ETag') + except EnvironmentError: + etag = None + + if sys.version_info[0] == 2 and etag is not None: + etag = etag.decode('utf-8') + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') + matching_files = list( + filter(lambda s: not s.endswith('.json'), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets + # interrupted. 
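+        # NamedTemporaryFile is deleted automatically when the 'with' block
+        # exits, so the downloaded bytes are flushed, rewound with seek(0)
+        # and copied into cache_path while the handle is still open.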
+ with tempfile.NamedTemporaryFile() as temp_file: + logger.info('%s not found in cache, downloading to %s', url, + temp_file.name) + + # GET file object + if url.startswith('s3://'): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid + # truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to + # the start + temp_file.seek(0) + + logger.info('copying %s to cache at %s', temp_file.name, + cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info('creating metadata file for %s', cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w') as meta_file: + output_string = json.dumps(meta) + meta_file.write(output_string) + + logger.info('removing temp file %s', temp_file.name) + + return cache_path diff --git a/projects/videochat/models/grit_src/grit/modeling/text/load_text_token.py b/projects/videochat/models/grit_src/grit/modeling/text/load_text_token.py new file mode 100644 index 0000000000..fea315cc59 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/text/load_text_token.py @@ -0,0 +1,89 @@ +import torch + + +class LoadTextTokens(object): + + def __init__(self, tokenizer, max_text_len=40, padding='do_not_pad'): + self.tokenizer = tokenizer + self.max_text_len = max_text_len + self.padding = padding + + def descriptions_to_text_tokens(self, target, begin_token): + target_encoding = self.tokenizer( + target, + padding=self.padding, + add_special_tokens=False, + truncation=True, + max_length=self.max_text_len) + + need_predict = [1] * len(target_encoding['input_ids']) + payload = target_encoding['input_ids'] + if len(payload) > self.max_text_len - 2: + payload = payload[-(self.max_text_len - 2):] + need_predict = payload[-(self.max_text_len - 2):] + + input_ids = [begin_token] + payload + [self.tokenizer.sep_token_id] + + need_predict = [0] + need_predict + [1] + data = { + 'text_tokens': torch.tensor(input_ids), + 'text_lengths': len(input_ids), + 'need_predict': torch.tensor(need_predict), + } + + return data + + def __call__(self, object_descriptions, box_features, begin_token): + text_tokens = [] + text_lengths = [] + need_predict = [] + for description in object_descriptions: + tokens = self.descriptions_to_text_tokens(description, begin_token) + text_tokens.append(tokens['text_tokens']) + text_lengths.append(tokens['text_lengths']) + need_predict.append(tokens['need_predict']) + + text_tokens = torch.cat( + self.collate(text_tokens), dim=0).to(box_features.device) + text_lengths = torch.tensor(text_lengths).to(box_features.device) + need_predict = torch.cat( + self.collate(need_predict), dim=0).to(box_features.device) + + assert text_tokens.dim() == 2 and need_predict.dim() == 2 + data = { + 'text_tokens': text_tokens, + 'text_lengths': text_lengths, + 'need_predict': need_predict + } + + return data + + def collate(self, batch): + if all(isinstance(b, torch.Tensor) for b in batch) and len(batch) > 0: + if not all(b.shape == batch[0].shape for b in batch[1:]): + assert all( + len(b.shape) == len(batch[0].shape) for b in batch[1:]) + shape = torch.tensor([b.shape for b in batch]) + max_shape = tuple(shape.max(dim=0)[0].tolist()) + batch2 = [] + for b in batch: + if any(c < m for c, m in zip(b.shape, max_shape)): + b2 = torch.zeros( + max_shape, dtype=b.dtype, device=b.device) + if b.dim() == 1: + b2[:b.shape[0]] = b + elif b.dim() == 2: + 
b2[:b.shape[0], :b.shape[1]] = b + elif b.dim() == 3: + b2[:b.shape[0], :b.shape[1], :b.shape[2]] = b + else: + raise NotImplementedError + b = b2 + batch2.append(b[None, ...]) + else: + batch2 = [] + for b in batch: + batch2.append(b[None, ...]) + return batch2 + else: + raise NotImplementedError diff --git a/projects/videochat/models/grit_src/grit/modeling/text/modeling_bert.py b/projects/videochat/models/grit_src/grit/modeling/text/modeling_bert.py new file mode 100644 index 0000000000..cd7976f189 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/text/modeling_bert.py @@ -0,0 +1,612 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace +# Inc. team. Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" +# Adapted from https://github.com/huggingface/transformers/blob/main/src +# /transformers/models/bert/modeling_bert.py + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import copy +import json +import logging +import math +import os +from io import open + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn + +from .file_utils import cached_path + +logger = logging.getLogger() + +BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'bert-base-uncased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased' + '-config.json', + 'bert-large-uncased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased' + '-config.json', + 'bert-base-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased' + '-config.json', + 'bert-large-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased' + '-config.json', + 'bert-base-multilingual-uncased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base' + '-multilingual-uncased-config.json', + 'bert-base-multilingual-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base' + '-multilingual-cased-config.json', + 'bert-base-chinese': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese' + '-config.json', + 'bert-base-german-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german' + '-cased-config.json', + 'bert-large-uncased-whole-word-masking': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased' + '-whole-word-masking-config.json', + 'bert-large-cased-whole-word-masking': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased' + '-whole-word-masking-config.json', + 'bert-large-uncased-whole-word-masking-finetuned-squad': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased' + '-whole-word-masking-finetuned-squad-config.json', + 'bert-large-cased-whole-word-masking-finetuned-squad': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased' + '-whole-word-masking-finetuned-squad-config.json', + 'bert-base-cased-finetuned-mrpc': + 
'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased' + '-finetuned-mrpc-config.json', +} + + +def qk2attn(query, key, attention_mask, gamma): + query = query / gamma + attention_scores = torch.matmul(query, key.transpose(-1, -2)) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in + # BertModel forward() function) + attention_scores = attention_scores + attention_mask + return attention_scores.softmax(dim=-1) + + +class QK2Attention(nn.Module): + + def forward(self, query, key, attention_mask, gamma): + return qk2attn(query, key, attention_mask, gamma) + + +LayerNormClass = torch.nn.LayerNorm + + +class BertSelfAttention(nn.Module): + + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of ' + 'attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * \ + self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.softmax = nn.Softmax(dim=-1) + self.qk2attn = QK2Attention() + + def transpose_for_scores(self, x): + if torch._C._get_tracing_state(): + # exporter is not smart enough to detect dynamic size for some + # paths + x = x.view(x.shape[0], -1, self.num_attention_heads, + self.attention_head_size) + else: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, + hidden_states, + attention_mask, + head_mask=None, + history_state=None): + if history_state is not None: + x_states = torch.cat([history_state, hidden_states], dim=1) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(x_states) + mixed_value_layer = self.value(x_states) + else: + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_probs = self.qk2attn(query_layer, key_layer, attention_mask, + math.sqrt(self.attention_head_size)) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if self.output_attentions else ( + context_layer, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm + if not self.pre_norm: + self.LayerNorm = LayerNormClass( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + if not self.pre_norm: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + else: + hidden_states = hidden_states + input_tensor + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config): + super(BertAttention, self).__init__() + self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm + if self.pre_norm: + self.LayerNorm = LayerNormClass( + config.hidden_size, eps=config.layer_norm_eps) + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, + input_tensor, + attention_mask, + head_mask=None, + history_state=None): + if self.pre_norm: + self_outputs = self.self( + self.LayerNorm(input_tensor), attention_mask, head_mask, + self.layerNorm(history_state) + if history_state else history_state) + else: + self_outputs = self.self(input_tensor, attention_mask, head_mask, + history_state) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + assert config.hidden_act == 'gelu', \ + 'Please implement other activation functions' + self.intermediate_act_fn = _gelu_python + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if not self.pre_norm: + self.LayerNorm = LayerNormClass( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + if not self.pre_norm: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + else: + hidden_states = hidden_states + input_tensor + return hidden_states + + +class Mlp(nn.Module): + + def __init__(self, config): + super().__init__() + self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm + self.intermediate = BertIntermediate(config) + if 
self.pre_norm: + self.LayerNorm = LayerNormClass( + config.hidden_size, eps=config.layer_norm_eps) + self.output = BertOutput(config) + + def forward(self, attention_output): + if not self.pre_norm: + intermediate_output = self.intermediate(attention_output) + else: + intermediate_output = self.intermediate( + self.LayerNorm(attention_output)) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertLayer(nn.Module): + + def __init__(self, config, use_act_checkpoint=True): + super(BertLayer, self).__init__() + self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm + self.use_mlp_wrapper = hasattr( + config, 'use_mlp_wrapper') and config.use_mlp_wrapper + self.attention = BertAttention(config) + self.use_act_checkpoint = use_act_checkpoint + if self.use_mlp_wrapper: + self.mlp = Mlp(config) + else: + self.intermediate = BertIntermediate(config) + if self.pre_norm: + self.LayerNorm = LayerNormClass( + config.hidden_size, eps=config.layer_norm_eps) + self.output = BertOutput(config) + + def forward(self, + hidden_states, + attention_mask, + head_mask=None, + history_state=None): + if self.use_act_checkpoint: + attention_outputs = checkpoint.checkpoint(self.attention, + hidden_states, + attention_mask, + head_mask, history_state) + else: + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask, history_state) + attention_output = attention_outputs[0] + if self.use_mlp_wrapper: + layer_output = self.mlp(attention_output) + else: + if not self.pre_norm: + intermediate_output = self.intermediate(attention_output) + else: + intermediate_output = self.intermediate( + self.LayerNorm(attention_output)) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output, ) + attention_outputs[ + 1:] # add attentions if we output them + return outputs + + +class BertEncoder(nn.Module): + + def __init__(self, config, use_act_checkpoint=True): + super(BertEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([ + BertLayer(config, use_act_checkpoint=use_act_checkpoint) + for _ in range(config.num_hidden_layers) + ]) + self.pre_norm = hasattr(config, 'pre_norm') and config.pre_norm + if self.pre_norm: + self.LayerNorm = LayerNormClass( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, + hidden_states, + attention_mask, + head_mask=None, + encoder_history_states=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + history_state = None \ + if encoder_history_states is None \ + else encoder_history_states[i] + layer_outputs = layer_module( + hidden_states, + attention_mask, + (None if head_mask is None else head_mask[i]), + history_state, + ) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + if self.pre_norm: + hidden_states = self.LayerNorm(hidden_states) + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + return outputs + + +CONFIG_NAME = 'config.json' + + +class PretrainedConfig(object): + """Base class for all configuration classes. + + Handle a few common parameters and methods for loading/downloading/saving + configurations. 
+ """ + pretrained_config_archive_map = {} + + def __init__(self, **kwargs): + self.finetuning_task = kwargs.pop('finetuning_task', None) + self.num_labels = kwargs.pop('num_labels', 2) + self.output_attentions = kwargs.pop('output_attentions', False) + self.output_hidden_states = kwargs.pop('output_hidden_states', False) + self.torchscript = kwargs.pop('torchscript', False) + + def save_pretrained(self, save_directory): + """Save a configuration object to a directory, so that it can be re- + loaded using the `from_pretrained(save_directory)` class method.""" + assert os.path.isdir( + save_directory + ), 'Saving path should be a directory where the model and ' \ + 'configuration can be saved ' + + # If we save using the predefined names, we can load using + # `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" Instantiate a PretrainedConfig from a pre-trained model + configuration. + + Params: **pretrained_model_name_or_path**: either: - a string with + the `shortcut name` of a pre-trained model configuration to load + from cache or download and cache if not already stored in cache ( + e.g. 'bert-base-uncased'). - a path to a `directory` containing a + configuration file saved using the `save_pretrained(save_directory)` + method. - a path or url to a saved configuration `file`. + **cache_dir**: (`optional`) string: Path to a directory in which a + downloaded pre-trained model configuration should be cached if the + standard cache should not be used. **return_unused_kwargs**: ( + `optional`) bool: - If False, then this function returns just the + final configuration object. - If True, then this functions returns a + tuple `(config, unused_kwargs)` where `unused_kwargs` is a + dictionary consisting of the key/value pairs whose keys are not + configuration attributes: ie the part of kwargs which has not been + used to update `config` and is otherwise ignored. **kwargs**: ( + `optional`) dict: Dictionary of key/value pairs with which to update + the configuration object after loading. - The values in kwargs of + any keys which are configuration attributes will be used to override + the loaded values. - Behavior concerning key/value pairs whose keys + are *not* configuration attributes is controlled by the + `return_unused_kwargs` keyword parameter. + + """ + cache_dir = kwargs.pop('cache_dir', None) + return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) + + if pretrained_model_name_or_path in cls.pretrained_config_archive_map: + config_file = cls.pretrained_config_archive_map[ + pretrained_model_name_or_path] + elif os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, + CONFIG_NAME) + else: + config_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_config_file = cached_path( + config_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in \ + cls.pretrained_config_archive_map: + logger.error( + "Couldn't reach server at '{}' to download pretrained " + 'model configuration file. '.format(config_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). 
" + "We assumed '{}' was a path or url but couldn't find any " + 'file ' + 'associated to this path or url.'.format( + pretrained_model_name_or_path, + ', '.join(cls.pretrained_config_archive_map.keys()), + config_file)) + return None + if resolved_config_file == config_file: + logger.info('loading configuration file {}'.format(config_file)) + else: + logger.info( + 'loading configuration file {} from cache at {}'.format( + config_file, resolved_config_file)) + + # Load config + config = cls.from_json_file(resolved_config_file) + + # Update config with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + # add img_layer_norm_eps, use_img_layernorm + if 'img_layer_norm_eps' in kwargs: + setattr(config, 'img_layer_norm_eps', kwargs['img_layer_norm_eps']) + to_remove.append('img_layer_norm_eps') + if 'use_img_layernorm' in kwargs: + setattr(config, 'use_img_layernorm', kwargs['use_img_layernorm']) + to_remove.append('use_img_layernorm') + for key in to_remove: + kwargs.pop(key, None) + + logger.info('Model config %s', config) + if return_unused_kwargs: + return config, kwargs + else: + return config + + @classmethod + def from_dict(cls, json_object): + """Constructs a `Config` from a Python dictionary of parameters.""" + config = cls(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, 'r', encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + def to_json_file(self, json_file_path): + """Save this instance to a json file.""" + with open(json_file_path, 'w', encoding='utf-8') as writer: + writer.write(self.to_json_string()) + + +class BertConfig(PretrainedConfig): + r""" + :class:`~pytorch_transformers.BertConfig` is the configuration class to + store the configuration of a `BertModel`. + + + Arguments: vocab_size_or_config_json_file: Vocabulary size of + `inputs_ids` in `BertModel`. hidden_size: Size of the encoder layers + and the pooler layer. num_hidden_layers: Number of hidden layers in + the Transformer encoder. num_attention_heads: Number of attention + heads for each attention layer in the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., + feed-forward) layer in the Transformer encoder. hidden_act: The + non-linear activation function (function or string) in the encoder + and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully + connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. max_position_embeddings: The maximum sequence length + that this model might ever be used with. Typically set this to + something large just in case (e.g., 512 or 1024 or 2048). 
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed + into `BertModel`. initializer_range: The sttdev of the + truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + """ + pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__(self, + vocab_size_or_config_json_file=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + **kwargs): + super(BertConfig, self).__init__(**kwargs) + if isinstance(vocab_size_or_config_json_file, str): + with open( + vocab_size_or_config_json_file, 'r', + encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + else: + raise ValueError( + 'First argument must be either a vocabulary size (int)' + 'or the path to a pretrained model config file (str)') + + +def _gelu_python(x): + + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) diff --git a/projects/videochat/models/grit_src/grit/modeling/text/text_decoder.py b/projects/videochat/models/grit_src/grit/modeling/text/text_decoder.py new file mode 100644 index 0000000000..a3429daff3 --- /dev/null +++ b/projects/videochat/models/grit_src/grit/modeling/text/text_decoder.py @@ -0,0 +1,707 @@ +# Modified by Jialian Wu from +# https://github.com/microsoft/GenerativeImage2Text/blob/main/generativeimage2text/layers/decoder.py +# and https://github.com/kdexd/virtex +import functools +import warnings + +import torch +from torch import nn +from torch.nn import functional as F + + +class TextualHead(nn.Module): + + def __init__(self, visual_feature_size: int, vocab_size: int, + hidden_size: int): + super().__init__() + self.visual_feature_size = visual_feature_size + self.vocab_size = vocab_size + self.hidden_size = hidden_size + + @property + def textual_feature_size(self): + return self.hidden_size + + +class WordAndPositionalEmbedding(nn.Module): + + def __init__( + self, + vocab_size: int, + hidden_size: int, + dropout: float = 0.0, + max_caption_length: int = 30, + padding_idx: int = 0, + ): + super().__init__() + self.vocab_size = vocab_size + self.padding_idx = padding_idx + + # self.words = nn.Embedding + # (vocab_size, hidden_size, padding_idx=padding_idx) + self.words = nn.Embedding(vocab_size, hidden_size) + + # We provide no "padding index" for positional embeddings. We zero out + # the positional embeddings of padded positions as a post-processing. 
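+        # Note: the positional table has max_caption_length slots, so token
+        # sequences longer than that cannot be embedded; callers are expected
+        # to truncate captions to this length beforehand.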
+ self.positions = nn.Embedding(max_caption_length, hidden_size) + self.layer_norm = nn.LayerNorm( + hidden_size, eps=1e-8, elementwise_affine=True) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, tokens: torch.Tensor): + position_indices = self._create_position_indices(tokens) + + # shape: (batch_size, max_caption_length, hidden_size) + word_embeddings = self.words(tokens) + position_embeddings = self.positions(position_indices) + + # shape: (batch_size, max_caption_length, hidden_size) + embeddings = self.layer_norm(word_embeddings + position_embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + @functools.lru_cache(maxsize=128) + def _create_position_indices(self, tokens: torch.Tensor): + + # Create position indices of the same size as token indices. + batch_size, max_caption_length = tokens.size() + positions = torch.arange( + max_caption_length, dtype=tokens.dtype, device=tokens.device) + # shape: (batch_size, max_caption_length) + positions = positions.unsqueeze(0).expand(batch_size, + max_caption_length) + return positions + + +class BertEncoderAsDecoder(nn.Module): + + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + + def forward( + self, + tgt, + memory, + tgt_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None, + tgt_bi_valid_mask=None, + encoder_history_states=None, + ): + assert tgt_key_padding_mask is None, 'not supported' + assert tgt_mask.dim() == 2 + assert tgt_mask.shape[0] == tgt_mask.shape[1] + # tgt_mask should always be 0/negative infinity + tgt = tgt.transpose(0, 1) + memory = memory.transpose(0, 1) + + hidden_states = torch.cat((memory, tgt), dim=1) + num_tgt = tgt.shape[1] + num_memory = memory.shape[1] + device = tgt.device + dtype = tgt.dtype + top_left = torch.zeros((num_memory, num_memory), + device=device, + dtype=dtype) + top_right = torch.full( + (num_memory, num_tgt), + float('-inf'), + device=tgt.device, + dtype=dtype, + ) + bottom_left = torch.zeros( + (num_tgt, num_memory), + dtype=dtype, + device=tgt_mask.device, + ) + left = torch.cat((top_left, bottom_left), dim=0) + right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0) + + full_attention_mask = torch.cat((left, right), dim=1)[None, :] + + if memory_key_padding_mask is None: + memory_key_padding_mask = torch.full( + (memory.shape[0], memory.shape[1]), + fill_value=False, + device=device) + # if it is False, it means valid. 
That is, it is not a padding + assert memory_key_padding_mask.dtype == torch.bool + zero_negative_infinity = torch.zeros_like( + memory_key_padding_mask, dtype=tgt.dtype) + zero_negative_infinity[memory_key_padding_mask] = float('-inf') + full_attention_mask = full_attention_mask.expand( + (memory_key_padding_mask.shape[0], num_memory + num_tgt, + num_memory + num_tgt)) + full_attention_mask = full_attention_mask.clone() + origin_left = full_attention_mask[:, :, :num_memory] + update = zero_negative_infinity[:, None, :] + full_attention_mask[:, :, :num_memory] = origin_left + update + + if tgt_bi_valid_mask is not None: + # verify the correctness + bs = full_attention_mask.shape[0] + # during inference, tgt_bi_valid_mask's length is not changed, but + # num_tgt can be increased + max_valid_target = tgt_bi_valid_mask.shape[1] + mask = tgt_bi_valid_mask[:, None, :].expand( + (bs, num_memory + num_tgt, max_valid_target)) + full_attention_mask[:, :, num_memory:(num_memory + + max_valid_target)][mask] = 0 + + # add axis for multi-head + full_attention_mask = full_attention_mask[:, None, :, :] + + if encoder_history_states is None: + result = self.encoder( + hidden_states=hidden_states, + attention_mask=full_attention_mask, + encoder_history_states=encoder_history_states, + ) + result = list(result) + result[0] = result[0][:, num_memory:].transpose(0, 1) + if self.encoder.output_hidden_states: + return result[0], result[1] + else: + # make it back-compatible + return result[0] + else: + encoder_out = self.encoder( + hidden_states=hidden_states[:, -1:], + attention_mask=full_attention_mask[:, :, -1:], + encoder_history_states=encoder_history_states, + ) + result = encoder_out[0].transpose(0, 1) + if self.encoder.output_hidden_states: + return result, encoder_out[1] + else: + return result + + +def create_transformer( + decoder_type, + norm_type, + textual_feature_size, + attention_heads, + feedforward_size, + dropout, + num_layers, + output_hidden_states=False, + use_mlp_wrapper=None, + use_act_checkpoint=True, +): + assert norm_type in ['post', 'pre'] + if decoder_type is None: + LayerClass = ( + nn.TransformerDecoderLayer + if norm_type == 'post' else PreNormTransformerDecoderLayer) + _layer = LayerClass( + textual_feature_size, + attention_heads, + dim_feedforward=feedforward_size, + dropout=dropout, + activation='gelu', + ) + return nn.TransformerDecoder(_layer, num_layers) + elif decoder_type == 'bert_en': + from .modeling_bert import BertConfig, BertEncoder + config = BertConfig( + vocab_size_or_config_json_file=30522, + hidden_size=textual_feature_size, + num_hidden_layers=num_layers, + num_attention_heads=attention_heads, + intermediate_size=feedforward_size, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + layer_norm_eps=1e-12, + ) + config.pre_norm = (norm_type == 'pre') + config.use_mlp_wrapper = use_mlp_wrapper + config.output_hidden_states = output_hidden_states + encoder = BertEncoder(config, use_act_checkpoint=use_act_checkpoint) + return BertEncoderAsDecoder(encoder) + + +class PreNormTransformerDecoderLayer(nn.TransformerDecoderLayer): + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + # fmt: off + # We use the members (modules) from super-class, just the order of + # operations is changed here. First layernorm, then attention. 
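+        # In other words, each sublayer computes x + Dropout(Sublayer(LayerNorm(x)))
+        # ("pre-norm"), whereas the parent nn.TransformerDecoderLayer applies
+        # LayerNorm after the residual addition ("post-norm").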
+ tgt2 = self.norm1(tgt) + tgt2, _ = self.self_attn( + tgt2, + tgt2, + tgt2, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + + # Layernorm first, then decoder attention. + tgt2 = self.norm2(tgt) + tgt2, _ = self.multihead_attn( + tgt2, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout2(tgt2) + + # Layernorm first, then transformation through feedforward network. + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class TransformerDecoderTextualHead(TextualHead): + + def __init__( + self, + object_feature_size: int, + vocab_size: int, + hidden_size: int, + num_layers: int, + attention_heads: int, + feedforward_size: int, + dropout: float = 0.1, + norm_type: str = 'post', + mask_future_positions: bool = True, + max_caption_length: int = 1024, + padding_idx: int = 0, + decoder_type=None, + not_tie_weight=None, + output_hidden_states=None, + use_mlp_wrapper=None, + use_act_checkpoint=True, + ): + super().__init__(object_feature_size, vocab_size, hidden_size) + self.num_layers = num_layers + self.attention_heads = attention_heads + self.feedforward_size = feedforward_size + self.dropout = dropout + assert mask_future_positions + self.padding_idx = padding_idx + + self.object_feature_projection = nn.Sequential( + nn.Linear(object_feature_size, self.textual_feature_size), + nn.LayerNorm(self.textual_feature_size)) + + self.embedding = WordAndPositionalEmbedding( + self.vocab_size, + self.textual_feature_size, + dropout=dropout, + max_caption_length=max_caption_length, + padding_idx=padding_idx, + ) + self.transformer = create_transformer( + decoder_type=decoder_type, + norm_type=norm_type, + textual_feature_size=self.textual_feature_size, + attention_heads=self.attention_heads, + feedforward_size=self.feedforward_size, + dropout=dropout, + num_layers=self.num_layers, + output_hidden_states=output_hidden_states, + use_mlp_wrapper=use_mlp_wrapper, + use_act_checkpoint=use_act_checkpoint, + ) + self.apply(self._init_weights) + + # Create an output linear layer and tie the input and output word + # embeddings to reduce parametejs. + self.output = nn.Linear(self.textual_feature_size, vocab_size) + if not not_tie_weight: + self.output.weight = self.embedding.words.weight + + @staticmethod + def _init_weights(module): + """Initialize weights like BERT - N(0.0, 0.02), bias = 0.""" + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.MultiheadAttention): + module.in_proj_weight.data.normal_(mean=0.0, std=0.02) + module.out_proj.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward( + self, + hidden_states, + text_tokens, + ): + projected_object_features = self.object_feature_projection( + hidden_states) if hidden_states is not None else None + batch_size, max_text_length = text_tokens.size() + text_embeddings = self.embedding(text_tokens) + + # An additive mask for masking the future (one direction). + uni_mask_zero_neg = self._generate_future_mask(max_text_length, + text_embeddings.dtype, + text_embeddings.device) + + # We transpose the first two dimensions of tokens embeddings and visual + # features, as required by decoder. 
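+        # After the transposes below, text embeddings are
+        # (max_text_length, batch_size, hidden_size) and object features are
+        # (num_object_tokens, batch_size, hidden_size), i.e. the
+        # sequence-first layout both decoder variants expect here.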
+ text_embeddings = text_embeddings.transpose(0, 1) + + projected_object_features = projected_object_features.transpose(0, 1) + + # if transformer here is the pytorch/decoder, there is no chance, the + # output is always tensor + trans_out = self.transformer( + text_embeddings, + projected_object_features, + tgt_mask=uni_mask_zero_neg, + ) + if isinstance(trans_out, tuple): + textual_features = trans_out[0] + else: + assert isinstance(trans_out, torch.Tensor) + textual_features = trans_out + # Undo the transpose and bring batch to dim 0. + # shape: (batch_size, max_caption_length, hidden_size) + textual_features = textual_features.transpose(0, 1) + + # shape: (batch_size, max_caption_length, vocab_size) + output_logits = self.output(textual_features) + if isinstance(trans_out, tuple): + return output_logits, trans_out[1] + else: + return output_logits + + def _generate_future_mask(self, size: int, dtype: torch.dtype, + device: torch.device): + # Default mask is for forward direction. Flip for backward direction. + mask = torch.triu( + torch.ones(size, size, device=device, dtype=dtype), diagonal=1) + mask = mask.masked_fill(mask == 1, float('-inf')) + return mask + + +class AutoRegressiveBeamSearch(object): + + def __init__( + self, + end_token_id: int, + max_steps: int = 50, + beam_size: int = 5, + objectdet=True, + per_node_beam_size: int = 2, + ): + self._eos_index = end_token_id + self.max_steps = max_steps + self.beam_size = beam_size + self.objectdet = objectdet + self.per_node_beam_size = per_node_beam_size or beam_size + + def search(self, begin_tokens, step): + if self.beam_size > 1 and self.objectdet: + only_return_best = False + else: + only_return_best = True + + batch_size = begin_tokens.size()[0] + + predictions = begin_tokens.unsqueeze(1).expand( + (batch_size, self.beam_size, begin_tokens.shape[-1])) + # Calculate the first timestep. This is done outside the main loop + # because we are going from a single decoder input (the output from + # the encoder) to the top `beam_size` decoder outputs. On the other + # hand, within the main loop we are going from the `beam_size` + # elements of the beam to `beam_size`^2 candidates from which we + # will select the top `beam_size` elements for the next iteration. + # shape: (batch_size, num_classes) + start_class_logits = step(begin_tokens) + + # Convert logits to logprobs. + # shape: (batch_size * beam_size, vocab_size) + start_class_logprobs = F.log_softmax(start_class_logits, dim=1) + + num_classes = start_class_logprobs.size()[1] + + # shape: (batch_size, beam_size), (batch_size, beam_size) + start_top_logprobs, start_predicted_classes = \ + start_class_logprobs.topk(self.beam_size) + + if (self.beam_size == 1 + and (start_predicted_classes == self._eos_index).all()): + warnings.warn( + 'Empty object description predicted. You may want to ' + 'increase beam ' + 'size or ensure your step function is working properly.', + RuntimeWarning, + ) + if only_return_best: + return start_predicted_classes, start_top_logprobs + else: + return start_predicted_classes.unsqueeze( + -1), start_top_logprobs + + # The log probs for the last time step. + # shape: (batch_size, beam_size) + last_logprobs = start_top_logprobs + + # shape: (batch_size, beam_size, sequence_length) + predictions = torch.cat( + [predictions, start_predicted_classes.unsqueeze(-1)], dim=-1) + + # Log probability tensor that mandates that the end token is selected. 
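+        # Both "*_after_end" tensors below are -inf everywhere except the EOS
+        # column, so once a beam has emitted EOS its step distribution
+        # collapses onto EOS and its cumulative score stops changing.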
+ # shape: (batch_size * beam_size, num_classes) + logprobs_after_end = start_class_logprobs.new_full( + (batch_size * self.beam_size, num_classes), float('-inf')) + logprobs_after_end[:, self._eos_index] = 0.0 + + logits_after_end = start_class_logprobs.new_full( + (batch_size * self.beam_size, num_classes), float('-inf')) + logits_after_end[:, self._eos_index] = 0 + + while predictions.shape[-1] < self.max_steps: + # shape: (batch_size * beam_size,) + last_predictions = predictions[:, :, -1].reshape(batch_size * + self.beam_size) + + # If every predicted token from the last step is `self._eos_index`, + # then we can stop early. + if (last_predictions == self._eos_index).all(): + break + + predictions_so_far = predictions.view(batch_size * self.beam_size, + -1) + # shape: (batch_size * beam_size, num_classes) + class_logits = step(predictions_so_far) + + # Set logprobs of last predicted tokens as high negative value + # to avoid repetition in description. + class_logits = class_logits.scatter( + 1, predictions_so_far[:, -1].view((-1, 1)), -10000) + + # shape: (batch_size * beam_size, num_classes) + last_predictions_expanded = last_predictions.unsqueeze(-1).expand( + batch_size * self.beam_size, num_classes) + + # Here we are finding any beams where we predicted the end token in + # the previous timestep and replacing the distribution with a + # one-hot distribution, forcing the beam to predict the end token + # this timestep as well. + class_logits = torch.where( + last_predictions_expanded == self._eos_index, + logits_after_end, + class_logits, + ) + + # Convert logits to logprobs. + # shape: (batch_size * beam_size, vocab_size) + class_logprobs = F.log_softmax(class_logits, dim=1) + + # shape (both): (batch_size * beam_size, per_node_beam_size) + top_logprobs, predicted_classes = class_logprobs.topk( + self.per_node_beam_size) + + # Here we expand the last log probs to `(batch_size * beam_size, + # per_node_beam_size)` so that we can add them to the current log + # probs for this timestep. This lets us maintain the log + # probability of each element on the beam. + # shape: (batch_size * beam_size, per_node_beam_size) + expanded_last_logprobs = ( + last_logprobs.unsqueeze(2).expand( + batch_size, self.beam_size, + self.per_node_beam_size).reshape( + batch_size * self.beam_size, self.per_node_beam_size)) + # shape: (batch_size * beam_size, per_node_beam_size) + summed_top_logprobs = top_logprobs + expanded_last_logprobs + + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_summed = summed_top_logprobs.reshape( + batch_size, self.beam_size * self.per_node_beam_size) + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_predicted_classes = predicted_classes.reshape( + batch_size, self.beam_size * self.per_node_beam_size) + # Append the predictions to the current beam. + reshaped_beam = ( + predictions.view(batch_size * self.beam_size, 1, -1).repeat( + 1, self.per_node_beam_size, + 1).reshape(batch_size, + self.beam_size * self.per_node_beam_size, -1)) + # batch_size, (beam_size * per_node_beach_size), #token + reshaped_beam = torch.cat( + [reshaped_beam, + reshaped_predicted_classes.unsqueeze(-1)], + dim=-1) + + # Keep only the top `beam_size` beam indices. 
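+            # topk over the beam_size * per_node_beam_size candidates selects
+            # the surviving hypotheses; the gather below then pulls their full
+            # token histories out of reshaped_beam.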
+ # shape: (batch_size, beam_size), (batch_size, beam_size) + restricted_beam_logprobs, restricted_beam_indices = \ + reshaped_summed.topk(self.beam_size) + predictions = reshaped_beam.gather( + 1, + restricted_beam_indices.unsqueeze(-1).repeat( + 1, 1, reshaped_beam.shape[-1])) + + # shape: (batch_size, beam_size) + last_logprobs = restricted_beam_logprobs + + if not torch.isfinite(last_logprobs).all(): + warnings.warn( + 'Infinite log probs encountered. Some final descriptions may ' + 'not ' + 'make sense. This can happen when the beam size is larger than' + ' the number of valid (non-zero probability) transitions that ' + 'the step function produces.', + RuntimeWarning, + ) + + # Optionally select best beam and its logprobs. + if only_return_best: + # shape: (batch_size, sequence_length) + predictions = predictions[:, 0, :] + last_logprobs = last_logprobs[:, 0] + num_valid = (predictions != self._eos_index).sum(dim=-1) + num_valid += (predictions == self._eos_index).sum(dim=-1) > 0 + num_valid = num_valid - begin_tokens.shape[1] + num_valid = num_valid.clip(min=1) + + last_logprobs = last_logprobs / num_valid + + return predictions, last_logprobs + + +class GRiTTextDecoder(nn.Module): + + def __init__( + self, + transformer, + begin_token_id=101, + beamsearch_decode=None, + loss_type=None, + tokenizer=None, + ): + super().__init__() + self.textual = transformer + self.padding_idx = self.textual.padding_idx + + self.begin_token_id = begin_token_id + self.beamsearch_decode = beamsearch_decode + self.tokenizer = tokenizer + + if loss_type is None: + self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx) + elif loss_type == 'smooth': + self.loss = SmoothLabelCrossEntropyLoss( + ignore_index=self.padding_idx) + else: + raise NotImplementedError(loss_type) + + def forward(self, batch): + object_features = batch['object_features'] + + if self.training: + caption_token_input = batch['text_tokens'] + + output_logits = self.textual( + object_features, + caption_token_input, + ) + + if 'need_predict' in batch: + # in place should also be good, but we do not choose that for + # safety as we may use it in prediction results in future + target = batch['text_tokens'].clone() + target[batch['need_predict'] == 0] = self.padding_idx + else: + target = batch['text_tokens'] + + feat = output_logits[:, :-1].contiguous() + target = target[:, 1:].contiguous() + feat = feat.view(-1, self.textual.vocab_size) + target = target.view(-1) + + valid_mask = target != self.padding_idx + target = target[valid_mask] + feat = feat[valid_mask] + loss = self.loss(feat, target) + + return loss + else: + output_dict = self.infer(object_features) + return output_dict + + def infer(self, object_features): + batch_size = object_features.size(0) + begin_tokens = object_features.new_full((batch_size, 1), + self.begin_token_id).long() + + decoding_step = functools.partial(self.decoding_step, object_features) + + object_description_tokens, logprobs = self.beamsearch_decode.search( + begin_tokens, decoding_step) + + output_dict = { + 'predictions': object_description_tokens, + 'logprobs': logprobs, + } + + return output_dict + + def decoding_step(self, object_features, partial_text): + batch_size = object_features.shape[0] + beam_size = int(partial_text.size(0) / batch_size) + if beam_size > 1: + batch_size, num_token, channels = object_features.size() + object_features = object_features.unsqueeze(1).repeat( + 1, beam_size, 1, 1) + object_features = object_features.view(batch_size * beam_size, + num_token, channels) + + 
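+        # At this point object_features is laid out as
+        # (batch_size * beam_size, num_token, channels) (tiled above when
+        # beam_size > 1), so each beam hypothesis attends over its own copy
+        # of the visual tokens.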
text_lengths = torch.ones_like(partial_text) + if len(text_lengths.size()) != 2: + partial_text = partial_text.unsqueeze(1) + + # shape: (batch_size * beam_size, partial_caption_length, vocab_size) + logits = self.textual( + object_features, + partial_text, + ) + + return logits[:, -1, :].float() + + +class SmoothLabelCrossEntropyLoss(nn.Module): + + def __init__(self, eps=0.1, log_prefix='', ignore_index=None): + super().__init__() + self.eps = eps + self.log_soft = nn.LogSoftmax(dim=1) + self.kl = nn.KLDivLoss(reduction='none') + + self.iter = 0 + self.max_loss = 0 + self.min_loss = 0 + self.log_prefix = log_prefix + self.ignore_index = ignore_index + + def forward(self, feature, target): + feature = feature.float() + if self.ignore_index is not None: + valid_mask = target != self.ignore_index + target = target[valid_mask] + feature = feature[valid_mask] + assert target.numel() > 0 + self.iter += 1 + eps = self.eps + n_class = feature.size(1) + one_hot = torch.zeros_like(feature).scatter(1, target.view(-1, 1), 1) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = self.log_soft(feature) + loss = self.kl(log_prb, one_hot) + return loss.sum(dim=1).mean() diff --git a/projects/videochat/models/grit_src/grit/predictor.py b/projects/videochat/models/grit_src/grit/predictor.py new file mode 100644 index 0000000000..0e9f35000d --- /dev/null +++ b/projects/videochat/models/grit_src/grit/predictor.py @@ -0,0 +1,101 @@ +# Copyright (c) Facebook, Inc. and its affiliates. Modified by Jialian Wu +# from https://github.com/facebookresearch/detectron2/blob/main/detectron2 +# /utils/visualizer.py +import torch +from detectron2.engine.defaults import DefaultPredictor +from detectron2.utils.visualizer import ColorMode, Visualizer + + +class BatchDefaultPredictor(DefaultPredictor): + + def __call__(self, original_images): + """ + Args: + original_image (np.ndarray): an image of shape (H, W, C) + (in BGR order). + + Returns: + predictions (dict): + the output of the model for one image only. + See :doc:`/tutorials/models` for details about the format. + """ + with torch.no_grad( + ): # https://github.com/sphinx-doc/sphinx/issues/4258 + # Apply pre-processing to image. 
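+            # original_images is assumed to be a batch of HWC images in BGR
+            # order stacked along dim 0; the shared height/width below are
+            # read from that batch shape.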
+ height, width = original_images.shape[1:3] + batch_inputs = [] + for original_image in original_images: + image = self.aug.get_transform(original_image).apply_image( + original_image) + image = torch.as_tensor( + image.astype('float32').transpose(2, 0, 1)) + + inputs = {'image': image, 'height': height, 'width': width} + batch_inputs.append(inputs) + predictions = self.model(batch_inputs)[0] + return predictions + + +class Visualizer_GRiT(Visualizer): + + def __init__(self, image, instance_mode=None): + super().__init__(image, instance_mode=instance_mode) + + def draw_instance_predictions(self, predictions): + boxes = predictions.pred_boxes if predictions.has( + 'pred_boxes') else None + # scores = predictions.scores if predictions.has('scores') else None + classes = predictions.pred_classes.tolist() if predictions.has( + 'pred_classes') else None + object_description = predictions.pred_object_descriptions.data + # uncomment to output scores in visualized images + # object_description = [c + '|' + str(round(s.item(), 1)) + # for c, s in zip(object_description, scores)] + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get( + 'thing_colors'): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in classes + ] + alpha = 0.8 + else: + colors = None + alpha = 0.5 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image(( + predictions.pred_masks.any(dim=0) > 0 + ).numpy() if predictions.has('pred_masks') else None)) + alpha = 0.3 + + self.overlay_instances( + masks=None, + boxes=boxes, + labels=object_description, + keypoints=None, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + +class VisualizationDemo(object): + + def __init__(self, cfg, instance_mode=ColorMode.IMAGE): + self.cpu_device = torch.device('cpu') + self.instance_mode = instance_mode + + self.predictor = DefaultPredictor(cfg) + + def run_on_image(self, image): + predictions = self.predictor(image) + # Convert image from OpenCV BGR format to Matplotlib RGB format. 
+ image = image[:, :, ::-1] + visualizer = Visualizer_GRiT(image, instance_mode=self.instance_mode) + instances = predictions['instances'].to(self.cpu_device) + vis_output = visualizer.draw_instance_predictions( + predictions=instances) + + return predictions, vis_output diff --git a/projects/videochat/models/grit_src/image_dense_captions.py b/projects/videochat/models/grit_src/image_dense_captions.py new file mode 100644 index 0000000000..6133cc89e1 --- /dev/null +++ b/projects/videochat/models/grit_src/image_dense_captions.py @@ -0,0 +1,84 @@ +from detectron2.config import get_cfg +from detectron2.data.detection_utils import read_image +from models.grit_src.grit.config import add_grit_config +from models.grit_src.grit.predictor import VisualizationDemo + +from projects.videochat.models.centernet.config import add_centernet_config + + +def dense_pred_to_caption(predictions): + boxes = predictions['instances'].pred_boxes if predictions[ + 'instances'].has('pred_boxes') else None + object_description = predictions['instances'].pred_object_descriptions.data + new_caption = '' + for i in range(len(object_description)): + new_caption += (object_description[i] + ': ' + str( + [int(a) + for a in boxes[i].tensor.cpu().detach().numpy()[0]])) + '; ' + return new_caption + + +def dense_pred_to_caption_only_name(predictions): + object_description = predictions['instances'].pred_object_descriptions.data + new_caption = ','.join(object_description) + del predictions + return new_caption + + +def setup_cfg(args): + cfg = get_cfg() + if args['cpu']: + cfg.MODEL.DEVICE = 'cpu' + add_centernet_config(cfg) + add_grit_config(cfg) + cfg.merge_from_file(args['config_file']) + cfg.merge_from_list(args['opts']) + # Set score_threshold for builtin models + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args['confidence_threshold'] + cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args[ + 'confidence_threshold'] + if args['test_task']: + cfg.MODEL.TEST_TASK = args['test_task'] + cfg.MODEL.BEAM_SIZE = 1 + cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False + cfg.USE_ACT_CHECKPOINT = False + cfg.freeze() + return cfg + + +def get_parser(device): + arg_dict = { + 'config_file': + 'models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml', + 'cpu': + False, + 'confidence_threshold': + 0.5, + 'test_task': + 'DenseCap', + 'opts': [ + 'MODEL.WEIGHTS', 'pretrained_models' + '/grit_b_densecap_objectdet.pth' + ] + } + if device.type == 'cpu': + arg_dict['cpu'] = True + return arg_dict + + +def image_caption_api(image_src, device): + args2 = get_parser(device) + cfg = setup_cfg(args2) + demo = VisualizationDemo(cfg) + if image_src: + img = read_image(image_src, format='BGR') + predictions, visualized_output = demo.run_on_image(img) + new_caption = dense_pred_to_caption(predictions) + return new_caption + + +def init_demo(device): + args2 = get_parser(device) + cfg = setup_cfg(args2) + demo = VisualizationDemo(cfg) + return demo diff --git a/projects/videochat/models/intern_action.py b/projects/videochat/models/intern_action.py new file mode 100644 index 0000000000..b440bc89eb --- /dev/null +++ b/projects/videochat/models/intern_action.py @@ -0,0 +1,656 @@ +#!/usr/bin/env python +import os +from collections import OrderedDict + +import torch +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath +from torch import nn +from torch.nn import MultiheadAttention + +MODEL_PATH = '../' +_MODELS = { + 'ViT-B/16': os.path.join(MODEL_PATH, 'vit_b16.pth'), + 'ViT-L/14': os.path.join(MODEL_PATH, 
'vit_l14.pth'), + 'ViT-L/14_336': os.path.join(MODEL_PATH, 'vit_l14_336.pth'), +} + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class Local_MHRA(nn.Module): + + def __init__(self, d_model, dw_reduction=1.5, pos_kernel_size=3): + super().__init__() + + padding = pos_kernel_size // 2 + re_d_model = int(d_model // dw_reduction) + self.pos_embed = nn.Sequential( + nn.BatchNorm3d(d_model), + nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0), + nn.Conv3d( + re_d_model, + re_d_model, + kernel_size=(pos_kernel_size, 1, 1), + stride=(1, 1, 1), + padding=(padding, 0, 0), + groups=re_d_model), + nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0), + ) + + # init zero + print('Init zero for Conv in pos_emb') + nn.init.constant_(self.pos_embed[3].weight, 0) + nn.init.constant_(self.pos_embed[3].bias, 0) + + def forward(self, x): + return self.pos_embed(x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model, + n_head, + attn_mask=None, + drop_path=0.0, + dw_reduction=1.5, + no_lmhra=False, + double_lmhra=True): + super().__init__() + + self.n_head = n_head + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + print(f'Drop path rate: {drop_path}') + + self.no_lmhra = no_lmhra + self.double_lmhra = double_lmhra + print(f'No L_MHRA: {no_lmhra}') + print(f'Double L_MHRA: {double_lmhra}') + if not no_lmhra: + self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction) + if double_lmhra: + self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction) + + # spatial + self.attn = MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x, T=8, use_checkpoint=False): + # x: 1+HW, NT, C + if not self.no_lmhra: + # Local MHRA + tmp_x = x[1:, :, :] + L, NT, C = tmp_x.shape + N = NT // T + H = W = int(L**0.5) + tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, + 1).contiguous() + tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x)) + tmp_x = tmp_x.view(N, C, T, + L).permute(3, 0, 2, + 1).contiguous().view(L, NT, C) + x = torch.cat([x[:1, :, :], tmp_x], dim=0) + # MHSA + if use_checkpoint: + attn_out = checkpoint.checkpoint(self.attention, self.ln_1(x)) + x = x + self.drop_path(attn_out) + else: + x = x + self.drop_path(self.attention(self.ln_1(x))) + # Local MHRA + if not self.no_lmhra and self.double_lmhra: + tmp_x = x[1:, :, :] + tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, + 1).contiguous() + tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x)) + tmp_x = tmp_x.view(N, C, T, + L).permute(3, 0, 2, + 1).contiguous().view(L, NT, C) + x = torch.cat([x[:1, :, :], tmp_x], dim=0) + # FFN + if use_checkpoint: + mlp_out = checkpoint.checkpoint(self.mlp, self.ln_2(x)) + x = x + self.drop_path(mlp_out) + else: + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Extractor(nn.Module): 
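+    # Cross-attention block used as the "global" head: a learnable class
+    # token (the query) attends over the backbone's spatio-temporal tokens
+    # (keys/values); see Transformer.forward below. This appears to follow
+    # the UniFormerV2-style design of the pretrained checkpoint loaded in
+    # load_internvideo.py.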
+ + def __init__( + self, + d_model, + n_head, + attn_mask=None, + mlp_factor=4.0, + dropout=0.0, + drop_path=0.0, + ): + super().__init__() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + print(f'Drop path rate: {drop_path}') + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = nn.LayerNorm(d_model) + d_mlp = round(mlp_factor * d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_mlp)), + ('gelu', QuickGELU()), + ('dropout', nn.Dropout(dropout)), + ('c_proj', nn.Linear(d_mlp, d_model))])) + self.ln_2 = nn.LayerNorm(d_model) + self.ln_3 = nn.LayerNorm(d_model) + self.attn_mask = attn_mask + + # zero init + nn.init.xavier_uniform_(self.attn.in_proj_weight) + nn.init.constant_(self.attn.out_proj.weight, 0.) + nn.init.constant_(self.attn.out_proj.bias, 0.) + nn.init.xavier_uniform_(self.mlp[0].weight) + nn.init.constant_(self.mlp[-1].weight, 0.) + nn.init.constant_(self.mlp[-1].bias, 0.) + + def attention(self, x, y): + d_model = self.ln_1.weight.size(0) + q = (x @ self.attn.in_proj_weight[:d_model].T + ) + self.attn.in_proj_bias[:d_model] + + k = (y @ self.attn.in_proj_weight[d_model:-d_model].T + ) + self.attn.in_proj_bias[d_model:-d_model] + v = (y @ self.attn.in_proj_weight[-d_model:].T + ) + self.attn.in_proj_bias[-d_model:] + Tx, Ty, N = q.size(0), k.size(0), q.size(1) + q = q.view(Tx, N, self.attn.num_heads, + self.attn.head_dim).permute(1, 2, 0, 3) + k = k.view(Ty, N, self.attn.num_heads, + self.attn.head_dim).permute(1, 2, 0, 3) + v = v.view(Ty, N, self.attn.num_heads, + self.attn.head_dim).permute(1, 2, 0, 3) + aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim**0.5)) + + aff = aff.softmax(dim=-1) + out = aff @ v + out = out.permute(2, 0, 1, 3).flatten(2) + out = self.attn.out_proj(out) + return out + + def forward(self, x, y): + x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y))) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + + def __init__( + self, + width, + layers, + heads, + attn_mask=None, + backbone_drop_path_rate=0., + use_checkpoint=False, + checkpoint_num=[0], + t_size=8, + dw_reduction=2, + no_lmhra=False, + double_lmhra=True, + return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + n_layers=12, + n_dim=768, + n_head=12, + mlp_factor=4.0, + drop_path_rate=0., + mlp_dropout=[ + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5 + ], + cls_dropout=0.5, + num_classes=400, + ): + super().__init__() + self.T = t_size + self.return_list = return_list + # backbone + b_dpr = [ + x.item() + for x in torch.linspace(0, backbone_drop_path_rate, layers) + ] + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + attn_mask, + drop_path=b_dpr[i], + dw_reduction=dw_reduction, + no_lmhra=no_lmhra, + double_lmhra=double_lmhra, + ) for i in range(layers) + ]) + # checkpoint + self.use_checkpoint = use_checkpoint + self.checkpoint_num = checkpoint_num + self.n_layers = n_layers + print(f'Use checkpoint: {self.use_checkpoint}') + print(f'Checkpoint number: {self.checkpoint_num}') + + # global block + assert n_layers == len(return_list) + if n_layers > 0: + self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim)) + self.dpe = nn.ModuleList([ + nn.Conv3d( + n_dim, + n_dim, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=n_dim) for i in range(n_layers) + ]) + for m in self.dpe: + nn.init.constant_(m.bias, 0.) 
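+            # self.dpe: one depthwise 3D-conv positional encoding per global
+            # block (groups=n_dim makes the convolution depthwise); only the
+            # biases are explicitly zero-initialised above.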
+ dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, n_layers) + ] + self.dec = nn.ModuleList([ + Extractor( + n_dim, + n_head, + mlp_factor=mlp_factor, + dropout=mlp_dropout[i], + drop_path=dpr[i], + ) for i in range(n_layers) + ]) + self.balance = nn.Parameter(torch.zeros((n_dim))) + self.sigmoid = nn.Sigmoid() + # projection + self.proj = nn.Sequential( + nn.LayerNorm(n_dim), + nn.Dropout(cls_dropout), + nn.Linear(n_dim, num_classes), + ) + + def forward(self, x): + T_down = self.T + L, NT, C = x.shape + N = NT // T_down + H = W = int((L - 1)**0.5) + + if self.n_layers > 0: + cls_token = self.temporal_cls_token.repeat(1, N, 1) + + j = -1 + for i, resblock in enumerate(self.resblocks): + if self.use_checkpoint and i < self.checkpoint_num[0]: + x = resblock(x, self.T, use_checkpoint=True) + else: + x = resblock(x, T_down) + if i in self.return_list: + j += 1 + tmp_x = x.clone() + tmp_x = tmp_x.view(L, N, T_down, C) + # dpe + _, tmp_feats = tmp_x[:1], tmp_x[1:] + tmp_feats = tmp_feats.permute(1, 3, 2, + 0).reshape(N, C, T_down, H, W) + tmp_feats = self.dpe[j](tmp_feats).view( + N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous() + tmp_x[1:] = tmp_x[1:] + tmp_feats + # global block + tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C + cls_token = self.dec[j](cls_token, tmp_x) + + if self.n_layers > 0: + weight = self.sigmoid(self.balance) + residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C + return self.proj((1 - weight) * cls_token[0, :, :] + + weight * residual) + else: + residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C + return self.proj(residual) + + +class VisionTransformer(nn.Module): + + def __init__( + self, + # backbone + input_resolution, + patch_size, + width, + layers, + heads, + output_dim, + backbone_drop_path_rate=0., + use_checkpoint=False, + checkpoint_num=[0], + t_size=8, + kernel_size=3, + dw_reduction=1.5, + temporal_downsample=True, + no_lmhra=- False, + double_lmhra=True, + # global block + return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + n_layers=12, + n_dim=768, + n_head=12, + mlp_factor=4.0, + drop_path_rate=0., + mlp_dropout=[ + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5 + ], + cls_dropout=0.5, + num_classes=400, + ): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + padding = (kernel_size - 1) // 2 + if temporal_downsample: + self.conv1 = nn.Conv3d( + 3, + width, (kernel_size, patch_size, patch_size), + (2, patch_size, patch_size), (padding, 0, 0), + bias=False) + t_size = t_size // 2 + else: + self.conv1 = nn.Conv3d( + 3, + width, (1, patch_size, patch_size), + (1, patch_size, patch_size), (0, 0, 0), + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer( + width, + layers, + heads, + dw_reduction=dw_reduction, + backbone_drop_path_rate=backbone_drop_path_rate, + use_checkpoint=use_checkpoint, + checkpoint_num=checkpoint_num, + t_size=t_size, + no_lmhra=no_lmhra, + double_lmhra=double_lmhra, + return_list=return_list, + n_layers=n_layers, + n_dim=n_dim, + n_head=n_head, + mlp_factor=mlp_factor, + drop_path_rate=drop_path_rate, + mlp_dropout=mlp_dropout, + cls_dropout=cls_dropout, + num_classes=num_classes, + ) + + def forward(self, x): + x = self.conv1(x) # shape = [*, width, grid, grid] + N, C, T, H, W = x.shape + x = 
x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C) + + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + out = self.transformer(x) + return out + + +def inflate_weight(weight_2d, time_dim, center=True): + print(f'Init center: {center}') + if center: + weight_3d = torch.zeros(*weight_2d.shape) + weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) + middle_idx = time_dim // 2 + weight_3d[:, :, middle_idx, :, :] = weight_2d + else: + weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) + weight_3d = weight_3d / time_dim + return weight_3d + + +def load_state_dict(model, state_dict): + state_dict_3d = model.state_dict() + for k in state_dict.keys(): + if state_dict[k].shape != state_dict_3d[k].shape: + if len(state_dict_3d[k].shape) <= 2: + print(f'Ignore: {k}') + continue + print(f'Inflate: ' + f'{k}, {state_dict[k].shape} => {state_dict_3d[k].shape}') + time_dim = state_dict_3d[k].shape[2] + state_dict[k] = inflate_weight(state_dict[k], time_dim) + model.load_state_dict(state_dict, strict=False) + + +def intern_action_b16( + pretrained=True, + use_checkpoint=False, + checkpoint_num=[0], + t_size=16, + dw_reduction=1.5, + backbone_drop_path_rate=0., + temporal_downsample=True, + no_lmhra=False, + double_lmhra=True, + return_list=[8, 9, 10, 11], + n_layers=4, + n_dim=768, + n_head=12, + mlp_factor=4.0, + drop_path_rate=0., + mlp_dropout=[0.5, 0.5, 0.5, 0.5], + cls_dropout=0.5, + num_classes=400, +): + model = VisionTransformer( + input_resolution=224, + patch_size=16, + width=768, + layers=12, + heads=12, + output_dim=512, + use_checkpoint=use_checkpoint, + checkpoint_num=checkpoint_num, + t_size=t_size, + dw_reduction=dw_reduction, + backbone_drop_path_rate=backbone_drop_path_rate, + temporal_downsample=temporal_downsample, + no_lmhra=no_lmhra, + double_lmhra=double_lmhra, + return_list=return_list, + n_layers=n_layers, + n_dim=n_dim, + n_head=n_head, + mlp_factor=mlp_factor, + drop_path_rate=drop_path_rate, + mlp_dropout=mlp_dropout, + cls_dropout=cls_dropout, + num_classes=num_classes, + ) + + if pretrained: + print('load pretrained weights') + state_dict = torch.load(_MODELS['ViT-B/16'], map_location='cpu') + load_state_dict(model, state_dict) + return model.eval() + + +def intern_action_l14( + pretrained=True, + use_checkpoint=False, + checkpoint_num=[0], + t_size=16, + dw_reduction=1.5, + backbone_drop_path_rate=0., + temporal_downsample=True, + no_lmhra=False, + double_lmhra=True, + return_list=[20, 21, 22, 23], + n_layers=4, + n_dim=1024, + n_head=16, + mlp_factor=4.0, + drop_path_rate=0., + mlp_dropout=[0.5, 0.5, 0.5, 0.5], + cls_dropout=0.5, + num_classes=400, +): + model = VisionTransformer( + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768, + use_checkpoint=use_checkpoint, + checkpoint_num=checkpoint_num, + t_size=t_size, + dw_reduction=dw_reduction, + backbone_drop_path_rate=backbone_drop_path_rate, + temporal_downsample=temporal_downsample, + no_lmhra=no_lmhra, + double_lmhra=double_lmhra, + return_list=return_list, + n_layers=n_layers, + n_dim=n_dim, + n_head=n_head, + mlp_factor=mlp_factor, + drop_path_rate=drop_path_rate, + mlp_dropout=mlp_dropout, + cls_dropout=cls_dropout, + num_classes=num_classes, + ) + + if pretrained: + print('load pretrained weights') + state_dict 
= torch.load(_MODELS['ViT-L/14'], map_location='cpu') + load_state_dict(model, state_dict) + return model.eval() + + +def intern_action_l14_336( + pretrained=True, + use_checkpoint=False, + checkpoint_num=[0], + t_size=16, + dw_reduction=1.5, + backbone_drop_path_rate=0., + no_temporal_downsample=True, + no_lmhra=False, + double_lmhra=True, + return_list=[20, 21, 22, 23], + n_layers=4, + n_dim=1024, + n_head=16, + mlp_factor=4.0, + drop_path_rate=0., + mlp_dropout=[0.5, 0.5, 0.5, 0.5], + cls_dropout=0.5, + num_classes=400, +): + model = VisionTransformer( + input_resolution=336, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768, + use_checkpoint=use_checkpoint, + checkpoint_num=checkpoint_num, + t_size=t_size, + dw_reduction=dw_reduction, + backbone_drop_path_rate=backbone_drop_path_rate, + no_temporal_downsample=no_temporal_downsample, + no_lmhra=no_lmhra, + double_lmhra=double_lmhra, + return_list=return_list, + n_layers=n_layers, + n_dim=n_dim, + n_head=n_head, + mlp_factor=mlp_factor, + drop_path_rate=drop_path_rate, + mlp_dropout=mlp_dropout, + cls_dropout=cls_dropout, + num_classes=num_classes, + ) + + if pretrained: + print('load pretrained weights') + state_dict = torch.load(_MODELS['ViT-L/14_336'], map_location='cpu') + load_state_dict(model, state_dict) + return model.eval() + + +if __name__ == '__main__': + import time + + import numpy as np + from fvcore.nn import FlopCountAnalysis, flop_count_table + + seed = 4217 + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + num_frames = 16 + + model = intern_action_l14( + pretrained=False, + t_size=num_frames, + backbone_drop_path_rate=0., + drop_path_rate=0., + dw_reduction=1.5, + no_lmhra=False, + temporal_downsample=True, + return_list=[ + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 + ], + mlp_dropout=[0.5] * 16, + n_layers=16) + print(model) + + flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224)) + s = time.time() + print(flop_count_table(flops, max_depth=1)) + print(time.time() - s) diff --git a/projects/videochat/models/load_internvideo.py b/projects/videochat/models/load_internvideo.py new file mode 100644 index 0000000000..6bb53461d2 --- /dev/null +++ b/projects/videochat/models/load_internvideo.py @@ -0,0 +1,483 @@ +import os + +import numpy as np +# from kinetics_class_index import kinetics_classnames +import torch +import torch.nn as nn +import torchvision.transforms as T +from huggingface_hub import hf_hub_download +from transforms import (GroupCenterCrop, GroupNormalize, GroupScale, Stack, + ToTorchFormatTensor) + +from projects.videochat.models.intern_action import intern_action_b16 + + +class Intern_Action(nn.Module): + + def __init__(self, model): + super().__init__() + self.backbone = model + + def forward(self, x): + return self.backbone(x) + + +def get_index(num_frames, num_segments=8): + seg_size = float(num_frames - 1) / num_segments + start = int(seg_size / 2) + offsets = np.array( + [start + int(np.round(seg_size * idx)) for idx in range(num_segments)]) + return offsets + + +def transform_action(): + # transform + crop_size = 224 + scale_size = 256 + input_mean = [0.485, 0.456, 0.406] + input_std = [0.229, 0.224, 0.225] + + return T.Compose([ + # T.ToPILImage(), + GroupScale(int(scale_size)), + GroupCenterCrop(crop_size), + Stack(), + ToTorchFormatTensor(), + GroupNormalize(input_mean, input_std) + ]) + + +def load_intern_action(device, pretrained): + # Create an id to label name mapping + 
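+    # kinetics_classnames (defined at the bottom of this module) maps
+    # Kinetics-400 class ids, stored as strings, to human-readable labels,
+    # e.g. kinetics_id_to_classname['0'] -> 'riding a bike'.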
kinetics_id_to_classname = {} + for k, v in kinetics_classnames.items(): + kinetics_id_to_classname[k] = v + + if os.path.exists(pretrained): + model_path = pretrained + else: + model_path = hf_hub_download( + repo_id='Andy1621/uniformerv2', + filename='k400+k710_uniformerv2_b16_8x224.pyth') + # Pick a pretrained model + model = Intern_Action( + intern_action_b16( + pretrained=False, + t_size=8, + no_lmhra=True, + temporal_downsample=False)) + state_dict = torch.load(model_path, map_location=device) + model.load_state_dict(state_dict) + # Set to eval mode and move to desired device + model = model.to(device) + model = model.eval() + return model + + +def cut_frame_to_8(data): + index = np.linspace(0, len(data) - 1, 8).astype(int) + return data[index] + + +kinetics_classnames = { + '0': 'riding a bike', + '1': 'marching', + '2': 'dodgeball', + '3': 'playing cymbals', + '4': 'checking tires', + '5': 'roller skating', + '6': 'tasting beer', + '7': 'clapping', + '8': 'drawing', + '9': 'juggling fire', + '10': 'bobsledding', + '11': 'petting animal (not cat)', + '12': 'spray painting', + '13': 'training dog', + '14': 'eating watermelon', + '15': 'building cabinet', + '16': 'applauding', + '17': 'playing harp', + '18': 'balloon blowing', + '19': 'sled dog racing', + '20': 'wrestling', + '21': 'pole vault', + '22': 'hurling (sport)', + '23': 'riding scooter', + '24': 'shearing sheep', + '25': 'sweeping floor', + '26': 'eating carrots', + '27': 'skateboarding', + '28': 'dunking basketball', + '29': 'disc golfing', + '30': 'eating spaghetti', + '31': 'playing flute', + '32': 'riding mechanical bull', + '33': 'making sushi', + '34': 'trapezing', + '35': 'picking fruit', + '36': 'stretching leg', + '37': 'playing ukulele', + '38': 'tying tie', + '39': 'skydiving', + '40': 'playing cello', + '41': 'jumping into pool', + '42': 'shooting goal (soccer)', + '43': 'trimming trees', + '44': 'bookbinding', + '45': 'ski jumping', + '46': 'walking the dog', + '47': 'riding unicycle', + '48': 'shaving head', + '49': 'hopscotch', + '50': 'playing piano', + '51': 'parasailing', + '52': 'bartending', + '53': 'kicking field goal', + '54': 'finger snapping', + '55': 'dining', + '56': 'yawning', + '57': 'peeling potatoes', + '58': 'canoeing or kayaking', + '59': 'front raises', + '60': 'laughing', + '61': 'dancing macarena', + '62': 'digging', + '63': 'reading newspaper', + '64': 'hitting baseball', + '65': 'clay pottery making', + '66': 'exercising with an exercise ball', + '67': 'playing saxophone', + '68': 'shooting basketball', + '69': 'washing hair', + '70': 'lunge', + '71': 'brushing hair', + '72': 'curling hair', + '73': 'kitesurfing', + '74': 'tapping guitar', + '75': 'bending back', + '76': 'skipping rope', + '77': 'situp', + '78': 'folding paper', + '79': 'cracking neck', + '80': 'assembling computer', + '81': 'cleaning gutters', + '82': 'blowing out candles', + '83': 'shaking hands', + '84': 'dancing gangnam style', + '85': 'windsurfing', + '86': 'tap dancing', + '87': 'skiing (not slalom or crosscountry)', + '88': 'bandaging', + '89': 'push up', + '90': 'doing nails', + '91': 'punching person (boxing)', + '92': 'bouncing on trampoline', + '93': 'scrambling eggs', + '94': 'singing', + '95': 'cleaning floor', + '96': 'krumping', + '97': 'drumming fingers', + '98': 'snowmobiling', + '99': 'gymnastics tumbling', + '100': 'headbanging', + '101': 'catching or throwing frisbee', + '102': 'riding elephant', + '103': 'bee keeping', + '104': 'feeding birds', + '105': 'snatch weight lifting', + '106': 'mowing lawn', 
+ '107': 'fixing hair', + '108': 'playing trumpet', + '109': 'flying kite', + '110': 'crossing river', + '111': 'swinging legs', + '112': 'sanding floor', + '113': 'belly dancing', + '114': 'sneezing', + '115': 'clean and jerk', + '116': 'side kick', + '117': 'filling eyebrows', + '118': 'shuffling cards', + '119': 'recording music', + '120': 'cartwheeling', + '121': 'feeding fish', + '122': 'folding clothes', + '123': 'water skiing', + '124': 'tobogganing', + '125': 'blowing leaves', + '126': 'smoking', + '127': 'unboxing', + '128': 'tai chi', + '129': 'waxing legs', + '130': 'riding camel', + '131': 'slapping', + '132': 'tossing salad', + '133': 'capoeira', + '134': 'playing cards', + '135': 'playing organ', + '136': 'playing violin', + '137': 'playing drums', + '138': 'tapping pen', + '139': 'vault', + '140': 'shoveling snow', + '141': 'playing tennis', + '142': 'getting a tattoo', + '143': 'making a sandwich', + '144': 'making tea', + '145': 'grinding meat', + '146': 'squat', + '147': 'eating doughnuts', + '148': 'ice fishing', + '149': 'snowkiting', + '150': 'kicking soccer ball', + '151': 'playing controller', + '152': 'giving or receiving award', + '153': 'welding', + '154': 'throwing discus', + '155': 'throwing axe', + '156': 'ripping paper', + '157': 'swimming butterfly stroke', + '158': 'air drumming', + '159': 'blowing nose', + '160': 'hockey stop', + '161': 'taking a shower', + '162': 'bench pressing', + '163': 'planting trees', + '164': 'pumping fist', + '165': 'climbing tree', + '166': 'tickling', + '167': 'high kick', + '168': 'waiting in line', + '169': 'slacklining', + '170': 'tango dancing', + '171': 'hurdling', + '172': 'carrying baby', + '173': 'celebrating', + '174': 'sharpening knives', + '175': 'passing American football (in game)', + '176': 'headbutting', + '177': 'playing recorder', + '178': 'brush painting', + '179': 'garbage collecting', + '180': 'robot dancing', + '181': 'shredding paper', + '182': 'pumping gas', + '183': 'rock climbing', + '184': 'hula hooping', + '185': 'braiding hair', + '186': 'opening present', + '187': 'texting', + '188': 'decorating the christmas tree', + '189': 'answering questions', + '190': 'playing keyboard', + '191': 'writing', + '192': 'bungee jumping', + '193': 'sniffing', + '194': 'eating burger', + '195': 'playing accordion', + '196': 'making pizza', + '197': 'playing volleyball', + '198': 'tasting food', + '199': 'pushing cart', + '200': 'spinning poi', + '201': 'cleaning windows', + '202': 'arm wrestling', + '203': 'changing oil', + '204': 'swimming breast stroke', + '205': 'tossing coin', + '206': 'deadlifting', + '207': 'hoverboarding', + '208': 'cutting watermelon', + '209': 'cheerleading', + '210': 'snorkeling', + '211': 'washing hands', + '212': 'eating cake', + '213': 'pull ups', + '214': 'surfing water', + '215': 'eating hotdog', + '216': 'holding snake', + '217': 'playing harmonica', + '218': 'ironing', + '219': 'cutting nails', + '220': 'golf chipping', + '221': 'shot put', + '222': 'hugging', + '223': 'playing clarinet', + '224': 'faceplanting', + '225': 'trimming or shaving beard', + '226': 'drinking shots', + '227': 'riding mountain bike', + '228': 'tying bow tie', + '229': 'swinging on something', + '230': 'skiing crosscountry', + '231': 'unloading truck', + '232': 'cleaning pool', + '233': 'jogging', + '234': 'ice climbing', + '235': 'mopping floor', + '236': 'making bed', + '237': 'diving cliff', + '238': 'washing dishes', + '239': 'grooming dog', + '240': 'weaving basket', + '241': 'frying vegetables', + '242': 
'stomping grapes', + '243': 'moving furniture', + '244': 'cooking sausages', + '245': 'doing laundry', + '246': 'dying hair', + '247': 'knitting', + '248': 'reading book', + '249': 'baby waking up', + '250': 'punching bag', + '251': 'surfing crowd', + '252': 'cooking chicken', + '253': 'pushing car', + '254': 'springboard diving', + '255': 'swing dancing', + '256': 'massaging legs', + '257': 'beatboxing', + '258': 'breading or breadcrumbing', + '259': 'somersaulting', + '260': 'brushing teeth', + '261': 'stretching arm', + '262': 'juggling balls', + '263': "massaging person's head", + '264': 'eating ice cream', + '265': 'extinguishing fire', + '266': 'hammer throw', + '267': 'whistling', + '268': 'crawling baby', + '269': 'using remote controller (not gaming)', + '270': 'playing cricket', + '271': 'opening bottle', + '272': 'playing xylophone', + '273': 'motorcycling', + '274': 'driving car', + '275': 'exercising arm', + '276': 'passing American football (not in game)', + '277': 'playing kickball', + '278': 'sticking tongue out', + '279': 'flipping pancake', + '280': 'catching fish', + '281': 'eating chips', + '282': 'shaking head', + '283': 'sword fighting', + '284': 'playing poker', + '285': 'cooking on campfire', + '286': 'doing aerobics', + '287': 'paragliding', + '288': 'using segway', + '289': 'folding napkins', + '290': 'playing bagpipes', + '291': 'gargling', + '292': 'skiing slalom', + '293': 'strumming guitar', + '294': 'javelin throw', + '295': 'waxing back', + '296': 'riding or walking with horse', + '297': 'plastering', + '298': 'long jump', + '299': 'parkour', + '300': 'wrapping present', + '301': 'egg hunting', + '302': 'archery', + '303': 'cleaning toilet', + '304': 'swimming backstroke', + '305': 'snowboarding', + '306': 'catching or throwing baseball', + '307': 'massaging back', + '308': 'blowing glass', + '309': 'playing guitar', + '310': 'playing chess', + '311': 'golf driving', + '312': 'presenting weather forecast', + '313': 'rock scissors paper', + '314': 'high jump', + '315': 'baking cookies', + '316': 'using computer', + '317': 'washing feet', + '318': 'arranging flowers', + '319': 'playing bass guitar', + '320': 'spraying', + '321': 'cutting pineapple', + '322': 'waxing chest', + '323': 'auctioning', + '324': 'jetskiing', + '325': 'drinking', + '326': 'busking', + '327': 'playing monopoly', + '328': 'salsa dancing', + '329': 'waxing eyebrows', + '330': 'watering plants', + '331': 'zumba', + '332': 'chopping wood', + '333': 'pushing wheelchair', + '334': 'carving pumpkin', + '335': 'building shed', + '336': 'making jewelry', + '337': 'catching or throwing softball', + '338': 'bending metal', + '339': 'ice skating', + '340': 'dancing charleston', + '341': 'abseiling', + '342': 'climbing a rope', + '343': 'crying', + '344': 'cleaning shoes', + '345': 'dancing ballet', + '346': 'driving tractor', + '347': 'triple jump', + '348': 'throwing ball', + '349': 'getting a haircut', + '350': 'running on treadmill', + '351': 'climbing ladder', + '352': 'blasting sand', + '353': 'playing trombone', + '354': 'drop kicking', + '355': 'country line dancing', + '356': 'changing wheel', + '357': 'feeding goats', + '358': 'tying knot (not on a tie)', + '359': 'setting table', + '360': 'shaving legs', + '361': 'kissing', + '362': 'riding mule', + '363': 'counting money', + '364': 'laying bricks', + '365': 'barbequing', + '366': 'news anchoring', + '367': 'smoking hookah', + '368': 'cooking egg', + '369': 'peeling apples', + '370': 'yoga', + '371': 'sharpening pencil', + '372': 
'dribbling basketball', + '373': 'petting cat', + '374': 'playing ice hockey', + '375': 'milking cow', + '376': 'shining shoes', + '377': 'juggling soccer ball', + '378': 'scuba diving', + '379': 'playing squash or racquetball', + '380': 'drinking beer', + '381': 'sign language interpreting', + '382': 'playing basketball', + '383': 'breakdancing', + '384': 'testifying', + '385': 'making snowman', + '386': 'golf putting', + '387': 'playing didgeridoo', + '388': 'biking through snow', + '389': 'sailing', + '390': 'jumpstyle dancing', + '391': 'water sliding', + '392': 'grooming horse', + '393': 'massaging feet', + '394': 'playing paintball', + '395': 'making a cake', + '396': 'bowling', + '397': 'contact juggling', + '398': 'applying cream', + '399': 'playing badminton' +} diff --git a/projects/videochat/models/med.py b/projects/videochat/models/med.py new file mode 100644 index 0000000000..79fb9fff63 --- /dev/null +++ b/projects/videochat/models/med.py @@ -0,0 +1,1137 @@ +"""* Copyright (c) 2022, salesforce.com, inc. * All rights reserved. * +SPDX-License-Identifier: BSD-3-Clause * For full license text, +see LICENSE.txt file in the repo root or +https://opensource.org/licenses/BSD-3-Clause * By Junnan Li * Based on +huggingface code base * https://github.com/huggingface/transformers/blob/v4 +.15.0/src/transformers/models/bert """ + +import math +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BertEmbeddings_nopos(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + # if input_ids is not None: + # input_shape = input_ids.size() + # else: + # input_shape = inputs_embeds.size()[:-1] + + # seq_length = input_shape[1] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + # if self.position_embedding_type == "absolute": + # position_embeddings = self.position_embeddings(position_ids) + # # print('add position_embeddings!!!!') + # embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + + # any TensorFlow 
checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward(self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = \ + self.position_ids[ + :, + past_key_values_length:seq_length + past_key_values_length + ] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + # print('add position_embeddings!!!!') + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not \ + hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of ' + 'attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = \ + self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if self.position_embedding_type == 'relative_key' \ + or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding 
tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + # print(self.key.weight.shape) + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == 'relative_key' \ + or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = \ + attention_scores + \ + relative_position_scores_query + \ + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
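+        # Concretely: nn.Dropout zeroes individual (query, key) attention
+        # weights and rescales the surviving ones by 1/(1 - p) during
+        # training, so a given query randomly ignores some key positions.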
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = \ + self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = 
self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + + if mode == 'mlr': + + assert encoder_hidden_states is not None, \ + 'encoder_hidden_states must be given for cross-attention ' \ + 'layers ' + + # print('attention_output.shape',attention_output.shape) + # print('encoder_hidden_states.shape',encoder_hidden_states.shape) + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = cross_attention_outputs[1:-1] + + present_key_value = cross_attention_outputs[-1] + + else: + self_attn_past_key_value = past_key_value[:2] \ + if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == 'multimodal': + assert encoder_hidden_states is not None, \ + 'encoder_hidden_states must be given for ' \ + 'cross-attention layers' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + 
next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with gradient ' + 'checkpointing. Setting `use_cache=False`... ') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
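+        # The actual tying of decoder.weight to the input word embeddings is
+        # handled by HuggingFace's PreTrainedModel machinery (when
+        # config.tie_word_embeddings is enabled, its default) through
+        # get_output_embeddings()/set_output_embeddings() on BertLMHeadModel
+        # below; only the bias declared here stays an independent parameter.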
+ self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: dict of {layer_num: list of heads to prune in this + layer} See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones + indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): The shape of the input to the + model. device: (:obj:`torch.device`): The device of the input to the + model. + + Returns: :obj:`torch.Tensor` The extended attention mask, with the + same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, + # from_seq_length, to_seq_length] ourselves in which case we just + # need to make it broadcastable to all heads. 
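+        # For the 2-D decoder case handled below, the causal mask is lower
+        # triangular; e.g. with seq_length = 3:
+        #     [[1, 0, 0],
+        #      [1, 1, 0],
+        #      [1, 1, 1]]
+        # so position i only attends to positions j <= i. It is then
+        # combined with the padding mask and turned into an additive
+        # 0.0 / -10000.0 mask at the end of this function.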
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition + # to the padding mask - if the model is an encoder, make the + # mask broadcastable to [batch_size, num_heads, seq_length, + # seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix + # ones mask to the causal mask causal and attention masks + # must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = \ + causal_mask[:, None, :, :] * \ + attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (' + 'shape {}) '.format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and + # 0.0 for masked positions, this operation will create a tensor + # which is 0.0 for positions we want to attend and -10000.0 for + # masked positions. Since we are adding it to the raw scores before + # the softmax, this is effectively the same as removing these + # entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`( + batch_size, sequence_length, hidden_size)`, `optional`): Sequence of + hidden-states at the output of the last layer of the encoder. Used + in the cross-attention if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`( + batch_size, sequence_length)`, `optional`): Mask to avoid performing + attention on the padding token indices of the encoder input. This + mask is used in the cross-attention if the model is configured as a + decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are + **not masked**, - 0 for tokens that are **masked**. past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, + embed_size_per_head)`): Contains precomputed key and value hidden + states of the attention blocks. Can be used to speed up decoding. 
If + :obj:`past_key_values` are used, the user can optionally input only + the last :obj:`decoder_input_ids` (those that don't have their past + key value states given to this model) of shape :obj:`(batch_size, + 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`( + batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are + returned and can be used to speed up decoding (see + :obj:`past_key_values`). + """ + output_attentions = output_attentions \ + if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict \ + if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache \ + if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the ' + 'same time ') + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or ' + 'encoder_embeds ') + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length + past_key_values_length), + device=device) + + # We can provide a self-attention mask of dimensions [batch_size, + # from_seq_length, to_seq_length] ourselves in which case we just + # need to make it broadcastable to all heads. 
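+        # get_extended_attention_mask() (defined above) broadcasts the
+        # 2-D/3-D mask so it applies to every attention head, adds the
+        # causal mask when is_decoder=True, and converts it to additive
+        # 0.0 / -10000.0 values that are added to the raw attention scores.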
+ extended_attention_mask: torch.Tensor = \ + self.get_extended_attention_mask( + attention_mask, + input_shape, + device, + is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, + # seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = \ + encoder_hidden_states[ + 0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = \ + encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed 1.0 in head_mask indicate we keep the + # head attention_probs has shape bsz x n_heads x N x N input + # head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x + # num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, 
+ output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`( + batch_size, sequence_length, hidden_size)`, `optional`): Sequence of + hidden-states at the output of the last layer of the encoder. Used + in the cross-attention if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`( + batch_size, sequence_length)`, `optional`): Mask to avoid performing + attention on the padding token indices of the encoder input. This + mask is used in the cross-attention if the model is configured as a + decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are + **not masked**, - 0 for tokens that are **masked**. labels ( + :obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): Labels for computing the + left-to-right language modeling loss (next word prediction). Indices + should be in ``[-100, 0, ..., config.vocab_size]`` (see + ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with + labels n ``[0, ..., config.vocab_size]`` past_key_values ( + :obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, + embed_size_per_head)`): Contains precomputed key and value hidden + states of the attention blocks. Can be used to speed up decoding. If + :obj:`past_key_values` are used, the user can optionally input only + the last :obj:`decoder_input_ids` (those that don't have their past + key value states given to this model) of shape :obj:`(batch_size, + 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`( + batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are + returned and can be used to speed up decoding (see + :obj:`past_key_values`). 
Returns: Example:: >>> from transformers + import BertTokenizer, BertLMHeadModel, BertConfig >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> + config = BertConfig.from_pretrained("bert-base-cased") >>> model = + BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits + """ + return_dict = return_dict \ + if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + # sequence_output.shape torch.Size([85, 30, 768]) + # prediction_scores.shape torch.Size([85, 30, 30524]) + # labels.shape torch.Size([85, 30]) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores + # and input ids by one + shifted_prediction_scores = \ + prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss( + reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if reduction == 'none': + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past diff --git a/projects/videochat/models/ocr.py b/projects/videochat/models/ocr.py new file mode 100644 index 0000000000..75a80b557e --- /dev/null +++ b/projects/videochat/models/ocr.py @@ -0,0 +1,225 @@ +import ast +import math +import os +import subprocess +from collections import defaultdict + +import openai +from paddleocr import PaddleOCR + +from projects.videochat.util import loadvideo_decord_origin + +openai.api_key = 
os.getenv('OPENAI_API_KEY') + + +def compute_square(pos): + return (pos[2] - pos[0]) * (pos[3] - pos[1]) + + +def remove_duplicates_by_text_key(lst, value): + seen_text_values = set() + result = [] + + for obj in lst: + text_value = obj.get(value) + + if text_value not in seen_text_values: + seen_text_values.add(text_value) + result.append(obj) + + return result + + +class ProcessOCR: + + def __init__(self): + self.video_path = None + + def inference(self, video_path, data=None): + self.video_path = video_path + ocr = PaddleOCR( + use_angle_cls=True, lang='ch' + ) # need to run only once to download and load model into memory + if data is None: + try: + data = loadvideo_decord_origin(video_path) + except Exception as e: + print(e) + return '' + text_result = list() + + for i in range(data.shape[0]): + result = ocr.ocr(data[i], cls=True) + text_per_image = list() + for idx in range(len(result)): + res = result[idx] + for line in res: + if line[1][1] < 0.8: + continue + text = dict() + text['pos'] = [ + line[0][0][0], line[0][0][1], line[0][2][0], + line[0][2][1] + ] + text['text'] = line[1][0] + text_per_image.append(text) + if len(text_per_image) > 0: + text_result.append({'begin': i, 'text': text_per_image}) + return text_result + + def merge(self, features): + # Paddleocr目前支持的多语言语种可以通过修改lang参数进行切换 + # 例如`ch`, `en`, `fr`, `german`, `korean`, `japan` + # video_path = self.video_path + video_width, video_height = self.get_video_resolution_ffmpeg() + # dense_with_pos = '' + # ocr_subtitle_result = '' + # ocr_subtitle_result2 = '' + dense_with_ocr, ocr_subtitle = self.find_text_in_dense( + features, video_width, video_height) + return dense_with_ocr, ocr_subtitle + + def get_video_resolution_ffmpeg(self): + video_path = self.video_path + try: + # 调用FFmpeg命令行获取视频宽度和高度信息 + cmd = [ + 'ffprobe', '-v', 'error', '-select_streams', 'v:0', + '-show_entries', 'stream=width,height', '-of', 'csv=p=0', + video_path + ] + result = subprocess.check_output( + cmd, stderr=subprocess.STDOUT, text=True) + + # 解析输出并转换为整数 + dimensions = [int(dim) for dim in result.strip().split(',')] + if len(dimensions) == 2: + return dimensions[0], dimensions[1] + else: + return None, None + except subprocess.CalledProcessError as e: + print(f'Error: {e.output}') + return None, None + + def chatgpt_check_subtitle(self, question): + + response = openai.ChatCompletion.create( + model='gpt-4', + messages=[ + { + 'role': + 'system', + 'content': + '用户将输入一段字幕,其中相邻的几句可能是同一句话,但是由于识别错误导致' + '了重复。你需要将其中识别错误的语句删去。\n\n例如用户输入:```\n' + 'Second 61: 你这瓜要是熟我肯定要啊\nSecond 62: 你这瓜要是熟我背定' + '要啊\n```\n你应该删去```Second 62: 你这瓜要是熟我背定要啊```,' + '保留```Second 61: 你这瓜要是熟我肯定要啊```,所以你的输出应该是:' + '\n```Second 61: 你这瓜要是熟我肯定要啊```\n注意:你只需要给出修改' + '后的字幕,不要输出任何思考过程。下面是用户输入: ' + }, + { + 'role': 'user', + 'content': question + }, + ], + temperature=1, + max_tokens=2048, + top_p=1, + frequency_penalty=0, + presence_penalty=0) + answer = response.choices[0].message.content + return answer + + def merge_ocr_with_subtitle(self, data): + if len(data['ocr_subtitle']) == 0: + return data['merged_subtitle'] + question = '"merged_subtitle": {},\n"ocr_subtitle": {}'.format( + data['merged_subtitle'], data['ocr_subtitle']) + response = openai.ChatCompletion.create( + model='gpt-4', + messages=[ + { + 'role': + 'system', + 'content': + '用户将输入一段语音识别字幕和OCR识别字幕,它们都有识别错误的地方,你需' + '要将它们相互校对,形成最符合逻辑的字幕。\n你可以删除几种语音识别字幕中' + '的错误:\n1. 中文对话中突然出现英文\n2. 突然出现没有逻辑的语气词,' + '如“啊”“吧”等\n3. 
不符合语境的词语' + }, + { + 'role': 'user', + 'content': question + }, + ], + temperature=1, + max_tokens=2048, + top_p=1, + frequency_penalty=0, + presence_penalty=0) + answer = response.choices[0].message.content + second_index = answer.find('Second') + if second_index != -1: + result = answer[second_index:] + else: + result = answer + return result.strip('"') + + def find_text_in_dense(self, features, video_width, video_height): + ocr = features['ocr'] + dense = features['dense_with_pos'] + dense_with_ocr = list() + ocr_subtitle = list() + start_time = dense[0]['begin'] + for cur_ocr in ocr: + time = cur_ocr['begin'] + cur_dense_text = dense[time - start_time]['text'] + cur_dense_with_ocr = defaultdict(str) + cur_ocr_subtitle = '' + for cur_text in cur_ocr['text']: + cur_text_pos = cur_text['pos'] + cur_text_text = cur_text['text'] + if cur_text_pos[1] > video_height * 0.8: + cur_ocr_subtitle += cur_text_text + ' ' + print(f'{cur_text_text} 可能是字幕') + continue + if len(cur_dense_text) == 0: + continue + belong_obj = self.find_pos(cur_text_pos, cur_dense_text) + if belong_obj: + cur_dense_with_ocr[belong_obj] += f'{cur_text_text} ' + if len(cur_dense_with_ocr) > 0: + final_text = '' + for key, value in cur_dense_with_ocr.items(): + final_text += \ + f'{key} with the words "{value.strip()}" on it, ' + dense_with_ocr.append({ + 'begin': time, + 'text': final_text.strip(', '), + }) + if len(cur_ocr_subtitle.strip()): + ocr_subtitle.append({ + 'begin': time, + 'text': cur_ocr_subtitle.strip(), + }) + ocr_subtitle = remove_duplicates_by_text_key(ocr_subtitle, 'text') + return dense_with_ocr, ocr_subtitle + + def find_pos(self, text_pos, dense_text): + obj_list = dense_text.split(';') + min_square = math.inf + belong_obj = None + for obj in obj_list: + obj_item = obj.split(': ')[0] + if obj_item == ' ': + continue + obj_pos = ast.literal_eval(obj.split(': ')[1]) + if obj_pos[0] <= text_pos[0] and \ + obj_pos[1] <= text_pos[1] and \ + obj_pos[2] >= text_pos[2] and obj_pos[3] >= \ + text_pos[3]: + if compute_square(obj_pos) < min_square: + belong_obj = obj_item + min_square = compute_square(obj_pos) + return belong_obj diff --git a/projects/videochat/models/qwen.py b/projects/videochat/models/qwen.py new file mode 100644 index 0000000000..8c8ea16f14 --- /dev/null +++ b/projects/videochat/models/qwen.py @@ -0,0 +1,44 @@ +""" +Description: +Version: 1.0 +Author: ZhuYichen +Date: 2023-09-10 17:03:05 +LastEditors: ZhuYichen +LastEditTime: 2023-09-11 17:57:32 +""" +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + + +class Qwen: + + def __init__(self, model_path): + self.model_path = model_path + + def init_model(self): + # Note: The default behavior now has injection attack prevention off. 
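+        # The Qwen / Qwen-VL checkpoints ship their own tokenizer and
+        # modeling code on the Hub, hence trust_remote_code=True below; the
+        # model is placed on a single CUDA device (device_map='cuda') and
+        # switched to eval mode since it is only used for inference here.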
+ self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, device_map='cuda', trust_remote_code=True).eval() + + # Specify hyperparameters for generation + self.model.generation_config = GenerationConfig.from_pretrained( + self.model_path, trust_remote_code=True) + + def inference(self, image, question): + # 1st dialogue turn + query = self.tokenizer.from_list_format([ + { + 'image': image + }, # Either a local path or an url + { + 'text': + '''该图片是从一个视频片段中截取的一帧,问题可能与这一帧图片无关。 + 因此如果图片中不包含问题中提到的信息,直接回答“缺失信息”。''' + f'\n问题:{question}' + }, + ]) + response, history = self.model.chat( + self.tokenizer, query=query, history=None) + print(response) + return response diff --git a/projects/videochat/models/subtitle.py b/projects/videochat/models/subtitle.py new file mode 100644 index 0000000000..7ce03f3643 --- /dev/null +++ b/projects/videochat/models/subtitle.py @@ -0,0 +1,193 @@ +import base64 +import hashlib +import hmac +import json +import math +import os +import time +import urllib + +import requests + + +class RequestApi(object): + + def __init__(self, appid, secret_key, upload_file_path): + self.lfasr_host = 'https://raasr.xfyun.cn/v2/api' + self.api_upload = '/upload' + self.api_get_result = '/getResult' + self.result = None + self.appid = appid + self.secret_key = secret_key + self.upload_file_path = upload_file_path + self.ts = str(int(time.time())) + self.signa = self.get_signa() + + def get_signa(self): + appid = self.appid + secret_key = self.secret_key + m2 = hashlib.md5() + m2.update((appid + self.ts).encode('utf-8')) + md5 = m2.hexdigest() + md5 = bytes(md5, encoding='utf-8') + # 以secret_key为key, 上面的md5为msg, 使用hashlib.sha1加密结果为signa + signa = hmac.new(secret_key.encode('utf-8'), md5, + hashlib.sha1).digest() + signa = base64.b64encode(signa) + signa = str(signa, 'utf-8') + return signa + + def upload(self): + print('上传部分:') + upload_file_path = self.upload_file_path + file_len = os.path.getsize(upload_file_path) + file_name = os.path.basename(upload_file_path) + + param_dict = {} + param_dict['appId'] = self.appid + param_dict['signa'] = self.signa + param_dict['ts'] = self.ts + param_dict['fileSize'] = file_len + param_dict['fileName'] = file_name + param_dict['duration'] = '200' + param_dict['roleType'] = 1 + # param_dict["roleNum"] = 2 + print('upload参数:', param_dict) + data = open(upload_file_path, 'rb').read(file_len) + + response = requests.post( + url=self.lfasr_host + self.api_upload + '?' + + urllib.parse.urlencode(param_dict), + headers={'Content-type': 'application/json'}, + data=data) + print('upload_url:', response.request.url) + result = json.loads(response.text) + print('upload resp:', result) + return result + + def get_result(self): + uploadresp = self.upload() + orderId = uploadresp['content']['orderId'] + param_dict = { + 'appId': self.appid, + 'signa': self.signa, + 'ts': self.ts, + 'orderId': orderId, + 'resultType': 'transfer' + } + print('') + print('查询部分:') + print('get result参数:', param_dict) + status = 3 + # 建议使用回调的方式查询结果,查询接口有请求频率限制 + while status == 3: + response = requests.post( + url=self.lfasr_host + self.api_get_result + '?' 
+ + urllib.parse.urlencode(param_dict), + headers={'Content-type': 'application/json'}) + # print("get_result_url:",response.request.url) + result = json.loads(response.text) + print(result) + status = result['content']['orderInfo']['status'] + print('status=', status) + if status == 4: + break + time.sleep(5) + print('get_result resp:', result) + self.result = result + return result + + def result2text(self, return_list=False): + order_result = json.loads(self.result['content']['orderResult']) + lattice = order_result['lattice'] + sentence_list = list() + for item in lattice: + sentence = dict() + text = '' + for item2 in json.loads(item['json_1best'])['st']['rt'][0]['ws']: + text += item2['cw'][0]['w'] + sentence['text'] = text + sentence['speaker'] = json.loads(item['json_1best'])['st']['rl'] + sentence['begin'] = int( + int(json.loads(item['json_1best'])['st']['bg']) / 1000) + sentence['end'] = int( + int(json.loads(item['json_1best'])['st']['ed']) / 1000) + sentence_list.append(sentence) + result = '' + for sentence in sentence_list: + result += 'Second {} to Second {}, Speaker {}: {}\n'.format( + sentence['begin'], sentence['end'], sentence['speaker'], + sentence['text']) + if return_list: + return sentence_list + else: + return result + + +class ProcessSubtitle: + + def __init__(self, features): + self.features = features + + def merge_whisper_and_xunfei(self): + whisper = self.features['whisper'] + xunfei = self.features['subtitle'] + result = [] + for whisper_item in whisper: + # 第一个字是中文,使用讯飞的结果 + if '\u4e00' <= whisper_item['text'][0] <= '\u9fff': + find_item = self.find_match_subtitle(whisper_item['begin'], + xunfei) + if not find_item: + result.append(whisper_item) + else: + result.append(find_item) + # 不是中文,使用whisper的结果,但使用讯飞的speaker + else: + find_item = self.find_match_subtitle(whisper_item['begin'], + xunfei) + if not find_item: + whisper_item['speaker'] = '1' + result.append(whisper_item) + else: + whisper_item['speaker'] = find_item['speaker'] + result.append(whisper_item) + result = self.remove_duplicates_by_text_key(result, 'text') + return result + + def find_match_subtitle(self, time, subtitle): + time_gap = math.inf + temp = None + for item in subtitle: + if item['begin'] == time: + return item + elif 0 < time - item['begin'] < time_gap: + time_gap = time - item['begin'] + temp = item + elif time - item['begin'] < 0: + if item['begin'] - time < time_gap: + return item + else: + return temp + return temp + + def remove_duplicates_by_text_key(self, lst, value): + seen_text_values = set() + result = [] + + for obj in lst: + text_value = obj.get(value) + + if text_value not in seen_text_values: + seen_text_values.add(text_value) + result.append(obj) + + return result + + +# 输入讯飞开放平台的appid,secret_key和待转写的文件路径 +if __name__ == '__main__': + api = RequestApi(appid='', secret_key='', upload_file_path=r'') + + api.get_result() + api.result2text() diff --git a/projects/videochat/models/swin_transformer.py b/projects/videochat/models/swin_transformer.py new file mode 100644 index 0000000000..1f02428d49 --- /dev/null +++ b/projects/videochat/models/swin_transformer.py @@ -0,0 +1,782 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from scipy import interpolate +from 
timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative + position bias. It supports both of shifted and non-shifted window. + + Args: dim (int): Number of input channels. window_size (tuple[int]): The + height and width of the window. num_heads (int): Number of attention + heads. qkv_bias (bool, optional): If True, add a learnable bias to + query, key, value. Default: True qk_scale (float | None, optional): + Override default qk scale of head_dim ** -0.5 if set attn_drop (float, + optional): Dropout ratio of attention weight. Default: 0.0 proj_drop ( + float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the + # window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - \ + coords_flatten[:, None, :] + # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += \ + self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: x: input features with shape of (num_windows*B, N, C) mask: ( + 0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: dim (int): Number of input channels. 
input_resolution (tuple[ + int]): Input resolution. num_heads (int): Number of attention heads. + window_size (int): Window size. shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias ( + bool, optional): If True, add a learnable bias to query, key, value. + Default: True qk_scale (float | None, optional): Override default qk + scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. + Default: 0.0 attn_drop (float, optional): Attention dropout rate. + Default: 0.0 drop_path (float, optional): Stochastic depth rate. + Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: + nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: + nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't + # partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, \ + 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, + W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, ' \ + f'input_resolution={self.input_resolution}, ' \ + f'num_heads={self.num_heads}, '\ + f'window_size={self.window_size}, ' \ + f'shift_size={self.shift_size}, ' \ + f'mlp_ratio={self.mlp_ratio} ' + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: input_resolution (tuple[int]): Resolution of input feature. dim ( + int): Number of input channels. norm_layer (nn.Module, optional): + Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.' + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f'input_resolution={self.input_resolution}, dim={self.dim}' + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + + Args: dim (int): Number of input channels. input_resolution (tuple[ + int]): Input resolution. depth (int): Number of blocks. num_heads (int): + Number of attention heads. window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias ( + bool, optional): If True, add a learnable bias to query, key, value. + Default: True qk_scale (float | None, optional): Override default qk + scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. + Default: 0.0 attn_drop (float, optional): Attention dropout rate. + Default: 0.0 drop_path (float | tuple[float], optional): Stochastic + depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization + layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): + Downsample layer at the end of the layer. Default: None use_checkpoint ( + bool): Whether to use checkpointing to save memory. Default: False. 
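+
+    Example (an illustrative sketch, not taken from this repo: the dim,
+    resolution, head count and window size below are assumed values chosen
+    only to show the expected call pattern and output shape)::
+
+        layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
+                           num_heads=3, window_size=7,
+                           downsample=PatchMerging)
+        out = layer(torch.rand(1, 56 * 56, 96))  # -> (1, 28 * 28, 192)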
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, ' \ + f'input_resolution={self.input_resolution}, depth={self.depth}' + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: img_size (int): Image size. Default: 224. patch_size (int): Patch + token size. Default: 4. in_chans (int): Number of input image channels. + Default: 3. embed_dim (int): Number of linear projection output + channels. Default: 96. norm_layer (nn.Module, optional): Normalization + layer. Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model " \ + f'({self.img_size[0]}*{self.img_size[1]}). ' + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * ( + self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer A PyTorch impl of : `Swin Transformer: + Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): + Number of input image channels. 
Default: 3 num_classes (int): Number of + classes for classification head. Default: 1000 embed_dim (int): Patch + embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin + Transformer layer. num_heads (tuple(int)): Number of attention heads in + different layers. window_size (int): Window size. Default: 7 mlp_ratio ( + float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias ( + bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate + (float): Attention dropout rate. Default: 0 drop_path_rate (float): + Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): + Normalization layer. Default: nn.LayerNorm. ape (bool): If True, + add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. + Default: True use_checkpoint (bool): Whether to use checkpointing to + save memory. Default: False + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=(patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if + # num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif 
isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + + x_cls = self.avgpool(x.transpose(1, 2)) # B C 1 + + if idx_to_group_img is None: + return torch.cat([x_cls.transpose(1, 2), x], dim=1) + else: + x_bs = torch.gather( + x, + dim=0, + index=idx_to_group_img.view(-1, 1, + 1).expand(-1, x.shape[1], + x.shape[2])) + weights = image_atts[:, 1:].unsqueeze(2) # B L 1 + x_bs_cls = torch.sum( + (weights * x_bs).transpose(1, 2), dim=-1, + keepdim=True) # B C 1 + x_bs_cls = x_bs_cls / torch.sum( + weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool + + return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \ + torch.cat([x_cls.transpose(1, 2), x], dim=1) + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[ + 0] * self.patches_resolution[1] // (2**self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''): + # from: https://github.com/microsoft/unilm/blob + # /8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py + # #L348 + + # rel_pos_bias: relative_position_bias_table + src_num_pos, num_attn_heads = rel_pos_bias.size() + + num_extra_tokens = 0 + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + print('Position interpolate %s from %dx%d to %dx%d' % + (param_name, src_size, src_size, dst_size, dst_size)) + + # extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + # print("Original positions = %s" % str(x)) + # print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to( + rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + return rel_pos_bias diff --git a/projects/videochat/models/tag2text.py b/projects/videochat/models/tag2text.py new file mode 100644 index 0000000000..ecf5a572e4 --- /dev/null +++ b/projects/videochat/models/tag2text.py @@ -0,0 +1,533 @@ +""" + * Tag2Text + * Written by Xinyu Huang +""" +import json +import math +import os +import warnings +from typing import List +from 
urllib.parse import urlparse + +import numpy as np +import torch +from timm.models.hub import download_cached_file +from torch import nn +from transformers import BertTokenizer + +from projects.videochat.models.med import (BertConfig, BertLMHeadModel, + BertModel, logger) +from projects.videochat.models.swin_transformer import ( + SwinTransformer, interpolate_relative_pos_embed) +from projects.videochat.models.vit import (VisionTransformer, + interpolate_pos_embed) +from projects.videochat.tag_class.tag_class import tra_array + +warnings.filterwarnings('ignore') + + +def read_json(rpath): + with open(rpath, 'r') as f: + return json.load(f) + + +# delete some tags that may disturb captioning +delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359] + +# adjust thresholds for some tags +# default threshold: 0.68 +# 2701: "person"; 2828: "man"; 1167: "woman"; +tag_thrshold = {2701: 0.7, 2828: 0.7, 1167: 0.7} + + +class Tag2Text_Caption(nn.Module): + + def __init__( + self, + med_config='configs/med_config.json', + image_size=384, + vit='base', + vit_grad_ckpt=False, + vit_ckpt_layer=0, + prompt='a picture of ', + threshold=0.68, + ): + """ + Args: med_config (str): path for the mixture of encoder-decoder + model's configuration file image_size (int): input image size vit ( + str): model size of vision transformer + """ + super().__init__() + + if vit == 'swin_b': + if image_size == 224: + vision_config_path = 'configs/swin/config_swinB_224.json' + elif image_size == 384: + vision_config_path = 'configs/swin/config_swinB_384.json' + vision_config = read_json(vision_config_path) + assert image_size == vision_config['image_res'] + # assert config['patch_size'] == 32 + vision_width = vision_config['vision_width'] + + self.visual_encoder = SwinTransformer( + img_size=vision_config['image_res'], + patch_size=4, + in_chans=3, + embed_dim=vision_config['embed_dim'], + depths=vision_config['depths'], + num_heads=vision_config['num_heads'], + window_size=vision_config['window_size'], + mlp_ratio=4., + qkv_bias=True, + drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False) + + else: + self.visual_encoder, vision_width = create_vit( + vit, image_size, vit_grad_ckpt, vit_ckpt_layer) + + self.tokenizer = init_tokenizer() + + # create the decoder + decoder_config = BertConfig.from_json_file(med_config) + decoder_config.encoder_width = 768 + self.text_decoder = BertLMHeadModel(config=decoder_config) + + # create encoder + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.tag_encoder = BertModel( + config=encoder_config, add_pooling_layer=False) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + self.threshold = threshold + num_features = 768 + self.num_class = 3429 + + q2l_config = BertConfig.from_json_file('configs/q2l_config.json') + q2l_config.encoder_width = vision_width + self.vision_multi = BertModel( + config=q2l_config, add_pooling_layer=False) + self.vision_multi.resize_token_embeddings(len(self.tokenizer)) + self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) + self.fc = GroupWiseLinear(self.num_class, num_features, bias=True) + self.del_selfattention() + + tie_encoder_decoder_weights(self.tag_encoder, self.vision_multi, '', + ' ') + self.tag_array = tra_array + + self.class_threshold = torch.ones(self.num_class) * self.threshold + for key, value in tag_thrshold.items(): + self.class_threshold[key] = value + + def 
del_selfattention(self): + del self.vision_multi.embeddings + for layer in self.vision_multi.encoder.layer: + del layer.attention + + def generate(self, + image, + sample=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + tag_input=None, + return_tag_predict=False): + image_embeds = self.visual_encoder(image) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + # ==============generate tag==============# + if tag_input is None: + image_spatial_embeds = image_embeds[:, 1:, :] + # image_cls_embeds = image_embeds[:, 0, :] + + bs = image_spatial_embeds.shape[0] + label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) + mlr_tagembedding = self.vision_multi( + encoder_embeds=label_embed, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + mode='mlr', + ) + + logits = self.fc(mlr_tagembedding[0]) + + # targets = torch.where(torch.sigmoid(logits) > self.threshold , + # torch.tensor(1.0).to(image.device), torch.zeros( + # self.num_class).to(image.device)) + targets = torch.where( + torch.sigmoid(logits) > self.class_threshold.to(image.device), + torch.tensor(1.0).to(image.device), + torch.zeros(self.num_class).to(image.device)) + + tag = targets.cpu().numpy() + tag[:, delete_tag_index] = 0 + bs = image.size(0) + tag_input = [] + for b in range(bs): + index = np.argwhere(tag[b] == 1) + token = self.tag_array[index].squeeze(axis=1) + tag_input.append(' | '.join(token)) + # ========================================# + + if not sample: + image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) + image_atts = image_atts.repeat_interleave(num_beams, dim=0) + tag_input_temp = [] + for tag in tag_input: + for i in range(num_beams): + tag_input_temp.append(tag) + tag_input = tag_input_temp + + tag_input_tokenzier = self.tokenizer( + tag_input, + padding='max_length', + truncation=True, + max_length=40, + return_tensors='pt').to(image.device) + encoder_input_ids = tag_input_tokenzier.input_ids + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + + output_tagembedding = self.tag_encoder( + encoder_input_ids, + attention_mask=tag_input_tokenzier.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + prompt = [self.prompt] * image.size(0) + input_ids = self.tokenizer( + prompt, return_tensors='pt').input_ids.to(image.device) + input_ids[:, 0] = self.tokenizer.bos_token_id + input_ids = input_ids[:, :-1] + + if sample: + # nucleus sampling + model_kwargs = { + 'encoder_hidden_states': output_tagembedding.last_hidden_state, + 'encoder_attention_mask': None + } + outputs = self.text_decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + # beam search + model_kwargs = { + 'encoder_hidden_states': output_tagembedding.last_hidden_state, + 'encoder_attention_mask': None + } + outputs = self.text_decoder.generate( + input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + captions = [] + for output in outputs: + caption = self.tokenizer.decode(output, skip_special_tokens=True) + 
captions.append(caption[len(self.prompt):]) + if return_tag_predict: + if sample: + return captions, tag_input + else: + return captions, tag_input[0:int(len(tag_input) / num_beams)] + return captions + + def generate_sublists(self, + image, + sample=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + tag_input=None, + return_tag_predict=False): + n = image.shape[0] # total number of images + captions = [] + tags = [] + # iterate over image sub-tensors of size 10 or less + for i in range(0, n, 50): + image_subtensor = image[i:i + 50] + # get subtensor of size 10 or less + # call original generate function on image subtensor + sublist_captions, sublist_tags = self.generate( + image_subtensor, + sample=sample, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + tag_input=tag_input, + return_tag_predict=return_tag_predict) + + # append sublist captions to overall captions + captions.extend(sublist_captions) + tags.extend(sublist_tags) + + return captions, tags + + +def tag2text_caption(pretrained='', **kwargs): + model = Tag2Text_Caption(**kwargs) + if pretrained: + if kwargs['vit'] == 'swin_b': + model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) + else: + model, msg = load_checkpoint(model, pretrained) + # print('vit:',kwargs['vit']) + # print('msg_v2',msg) + return model + + +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, + base_model_prefix: str, skip_key: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f'{decoder.__class__} and {encoder.__class__} are not equal. In ' + f'this case make sure that all encoder weights are correctly ' + f'initialized. 
') + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f'{decoder_pointer} and {encoder_pointer} have to be of type ' \ + f'torch.nn.Module ' + if hasattr(decoder_pointer, 'weight') and skip_key not in module_name: + assert hasattr(encoder_pointer, 'weight') + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, 'bias'): + assert hasattr(encoder_pointer, 'bias') + encoder_pointer.bias = decoder_pointer.bias + # print(module_name+' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f'Encoder module {encoder_pointer} does not match decoder ' \ + f'module {decoder_pointer} ' + + all_encoder_weights = set([ + module_name + '/' + sub_name + for sub_name in encoder_modules.keys() + ]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name])) and \ + len(encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to the + # position in a list module list of layers in this + # case the decoder has added a cross-attention that + # the encoder does not have thus skip this step and + # subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + 'Max depth of recursive function ' + '`tie_encoder_to_decoder` reached. It seems that ' + 'there is a circular dependency between two or more ' + '`nn.Modules` of your model. ') + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + '/' + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + '/' + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, + uninitialized_encoder_weights, skip_key) + + +class GroupWiseLinear(nn.Module): + # could be changed to: + # output = torch.einsum('ijk,zjk->ij', x, self.W) + # or output = torch.einsum('ijk,jk->ij', x, self.W[0]) + def __init__(self, num_class, hidden_dim, bias=True): + super().__init__() + self.num_class = num_class + self.hidden_dim = hidden_dim + self.bias = bias + + self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim)) + if bias: + self.b = nn.Parameter(torch.Tensor(1, num_class)) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. 
/ math.sqrt(self.W.size(2)) + for i in range(self.num_class): + self.W[0][i].data.uniform_(-stdv, stdv) + if self.bias: + for i in range(self.num_class): + self.b[0][i].data.uniform_(-stdv, stdv) + + def forward(self, x): + # x: B,K,d + x = (self.W * x).sum(-1) + if self.bias: + x = x + self.b + return x + + +def init_tokenizer(): + tokenizer = BertTokenizer.from_pretrained( + '/mnt/data.coronaryct.1/ZhuYichen/Ask-Anything/model/bert-base' + '-uncased/', + local_files_only=True) + tokenizer.add_special_tokens({'bos_token': '[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, + image_size, + use_grad_checkpointing=False, + ckpt_layer=0, + drop_path_rate=0): + assert vit in ['base', 'large'], 'vit parameter must be base or large' + if vit == 'base': + vision_width = 768 + visual_encoder = VisionTransformer( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=12, + num_heads=12, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate) + elif vit == 'large': + vision_width = 1024 + visual_encoder = VisionTransformer( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=24, + num_heads=16, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate) + return visual_encoder, vision_width + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ('http', 'https') + + +def load_checkpoint(model, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed( + state_dict['visual_encoder.pos_embed'], model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed( + state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape != model.state_dict()[key].shape: + del state_dict[key] + + msg = model.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % url_or_filename) + return model, msg + + +def load_checkpoint_swinbase(model, url_or_filename, kwargs): + if kwargs['image_size'] == 224: + vision_config_path = 'configs/swin/config_swinB_224.json' + elif kwargs['image_size'] == 384: + vision_config_path = 'configs/swin/config_swinB_384.json' + elif kwargs['image_size'] == 480: + vision_config_path = 'configs/swin/config_swinB_480.json' + elif kwargs['image_size'] == 576: + vision_config_path = 'configs/swin/config_swinB_576.json' + elif kwargs['image_size'] == 608: + vision_config_path = 'configs/swin/config_swinB_608.json' + window_size = read_json(vision_config_path)['window_size'] + # print('--------------') + # print(url_or_filename) + # print('--------------') + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif 
os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + for k in list(state_dict.keys()): + if 'relative_position_bias_table' in k: + dst_num_pos = (2 * window_size - 1)**2 + state_dict[k] = interpolate_relative_pos_embed( + state_dict[k], dst_num_pos, param_name=k) + elif ('relative_position_index' in k) or ('attn_mask' in k): + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + print('load checkpoint from %s' % url_or_filename) + return model, msg diff --git a/projects/videochat/models/transnetv2.py b/projects/videochat/models/transnetv2.py new file mode 100644 index 0000000000..99d436c440 --- /dev/null +++ b/projects/videochat/models/transnetv2.py @@ -0,0 +1,477 @@ +import random +from fractions import Fraction + +import ffmpeg +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as functional + + +class ShotProcessor: + + def shot(self, video_path, shot): + time_intervals = [0] + + frame_rate = self.get_frame_rate(video_path) + + for line in shot: + start_frame, end_frame = line + # start_time = self.frames_to_seconds(start_frame, frame_rate) + end_time = self.frames_to_seconds(end_frame, frame_rate) + time_intervals.append(int(end_time)) + + return time_intervals + + def frames_to_seconds(self, frame_number, frame_rate): + return frame_number / frame_rate + + def get_frame_rate(self, video_path): + info = ffmpeg.probe(video_path) + vs = next(c for c in info['streams'] if c['codec_type'] == 'video') + fps = float(Fraction(vs['r_frame_rate'])) + return fps + + +class Shot: + + def __init__(self, pretrained): + self.model = TransNetV2() + state_dict = torch.load(pretrained) + self.model.load_state_dict(state_dict) + self.model.eval() + + def predictions_to_scenes(self, + predictions: np.ndarray, + threshold: float = 0.5): + predictions = (predictions > threshold).astype(np.uint8) + scenes = [] + t, t_prev, start = -1, 0, 0 + for i, t in enumerate(predictions): + if t_prev == 1 and t == 0: + start = i + if t_prev == 0 and t == 1 and i != 0: + scenes.append([start, i]) + t_prev = t + if t == 0: + scenes.append([start, i]) + + # just fix if all predictions are 1 + if len(scenes) == 0: + return np.array([[0, len(predictions) - 1]], dtype=np.int32) + + return np.array(scenes, dtype=np.int32) + + def inference(self, video_path): + print('[TransNetV2] Extracting frames from {}'.format(video_path)) + video_stream, err = ffmpeg.input(video_path).output( + 'pipe:', format='rawvideo', pix_fmt='rgb24', s='48x27').run( + capture_stdout=True, capture_stderr=True) + + video = np.frombuffer(video_stream, + np.uint8).reshape([1, -1, 27, 48, 3]) + video_tensor = torch.from_numpy(video) + single_frame_pred, all_frame_pred = self.model(video_tensor) + + single_frame_pred = torch.sigmoid( + single_frame_pred).cpu().detach().numpy() + scenes = self.predictions_to_scenes(single_frame_pred[0]) + return scenes + + +class TransNetV2(nn.Module): + + def __init__( + self, + F=16, + L=3, + S=2, + D=1024, + use_many_hot_targets=True, + use_frame_similarity=True, + use_color_histograms=True, + use_mean_pooling=False, + dropout_rate=0.5, + use_convex_comb_reg=False, # not supported + use_resnet_features=False, # not supported + use_resnet_like_top=False, # not supported + frame_similarity_on_last_layer=False): # not supported + super(TransNetV2, self).__init__() + + if use_resnet_features or use_resnet_like_top or 
use_convex_comb_reg \ + or frame_similarity_on_last_layer: + raise NotImplementedError( + 'Some options not implemented in Pytorch version of Transnet!') + + self.SDDCNN = nn.ModuleList([ + StackedDDCNNV2( + in_filters=3, + n_blocks=S, + filters=F, + stochastic_depth_drop_prob=0.) + ] + [ + StackedDDCNNV2( + in_filters=(F * 2**(i - 1)) * 4, n_blocks=S, filters=F * 2**i) + for i in range(1, L) + ]) + + self.frame_sim_layer = FrameSimilarity( + sum([(F * 2**i) * 4 for i in range(L)]), + lookup_window=101, + output_dim=128, + similarity_dim=128, + use_bias=True) if use_frame_similarity else None + self.color_hist_layer = ColorHistograms( + lookup_window=101, + output_dim=128) if use_color_histograms else None + + self.dropout = nn.Dropout( + dropout_rate) if dropout_rate is not None else None + + output_dim = ((F * 2**(L - 1)) * 4) * 3 * 6 + if use_frame_similarity: + output_dim += 128 + if use_color_histograms: + output_dim += 128 + + self.fc1 = nn.Linear(output_dim, D) + self.cls_layer1 = nn.Linear(D, 1) + self.cls_layer2 = nn.Linear(D, 1) if use_many_hot_targets else None + + self.use_mean_pooling = use_mean_pooling + self.eval() + + def forward(self, inputs): + assert isinstance(inputs, torch.Tensor) \ + and list(inputs.shape[2:]) == [27, 48, 3] \ + and inputs.dtype == torch.uint8, \ + 'incorrect input type and/or shape' + # uint8 of shape [B, T, H, W, 3] to float of shape [B, 3, T, H, W] + x = inputs.permute([0, 4, 1, 2, 3]).float() + x = x.div_(255.) + + block_features = [] + for block in self.SDDCNN: + x = block(x) + block_features.append(x) + + if self.use_mean_pooling: + x = torch.mean(x, dim=[3, 4]) + x = x.permute(0, 2, 1) + else: + x = x.permute(0, 2, 3, 4, 1) + x = x.reshape(x.shape[0], x.shape[1], -1) + + if self.frame_sim_layer is not None: + x = torch.cat([self.frame_sim_layer(block_features), x], 2) + + if self.color_hist_layer is not None: + x = torch.cat([self.color_hist_layer(inputs), x], 2) + + x = self.fc1(x) + x = functional.relu(x) + + if self.dropout is not None: + x = self.dropout(x) + + one_hot = self.cls_layer1(x) + + if self.cls_layer2 is not None: + return one_hot, {'many_hot': self.cls_layer2(x)} + + return one_hot + + +class StackedDDCNNV2(nn.Module): + + def __init__( + self, + in_filters, + n_blocks, + filters, + shortcut=True, + use_octave_conv=False, # not supported + pool_type='avg', + stochastic_depth_drop_prob=0.0): + super(StackedDDCNNV2, self).__init__() + + if use_octave_conv: + raise NotImplementedError( + 'Octave convolution not implemented in Pytorch version of ' + 'Transnet! ') + + assert pool_type == 'max' or pool_type == 'avg' + if use_octave_conv and pool_type == 'max': + print( + 'WARN: Octave convolution was designed with average pooling, ' + 'not max pooling. 
') + + self.shortcut = shortcut + self.DDCNN = nn.ModuleList([ + DilatedDCNNV2( + in_filters if i == 1 else filters * 4, + filters, + octave_conv=use_octave_conv, + activation=functional.relu if i != n_blocks else None) + for i in range(1, n_blocks + 1) + ]) + self.pool = nn.MaxPool3d( + kernel_size=(1, 2, 2)) if pool_type == 'max' else nn.AvgPool3d( + kernel_size=(1, 2, 2)) + self.stochastic_depth_drop_prob = stochastic_depth_drop_prob + + def forward(self, inputs): + x = inputs + shortcut = None + + for block in self.DDCNN: + x = block(x) + if shortcut is None: + shortcut = x + + x = functional.relu(x) + + if self.shortcut is not None: + if self.stochastic_depth_drop_prob != 0.: + if self.training: + if random.random() < self.stochastic_depth_drop_prob: + x = shortcut + else: + x = x + shortcut + else: + x = (1 - self.stochastic_depth_drop_prob) * x + shortcut + else: + x += shortcut + + x = self.pool(x) + return x + + +class DilatedDCNNV2(nn.Module): + + def __init__(self, + in_filters, + filters, + batch_norm=True, + activation=None, + octave_conv=False): # not supported + super(DilatedDCNNV2, self).__init__() + + if octave_conv: + raise NotImplementedError( + 'Octave convolution not implemented in Pytorch version of ' + 'Transnet! ') + + assert not (octave_conv and batch_norm) + + self.Conv3D_1 = Conv3DConfigurable( + in_filters, filters, 1, use_bias=not batch_norm) + self.Conv3D_2 = Conv3DConfigurable( + in_filters, filters, 2, use_bias=not batch_norm) + self.Conv3D_4 = Conv3DConfigurable( + in_filters, filters, 4, use_bias=not batch_norm) + self.Conv3D_8 = Conv3DConfigurable( + in_filters, filters, 8, use_bias=not batch_norm) + + self.bn = nn.BatchNorm3d(filters * 4, eps=1e-3) if batch_norm else None + self.activation = activation + + def forward(self, inputs): + conv1 = self.Conv3D_1(inputs) + conv2 = self.Conv3D_2(inputs) + conv3 = self.Conv3D_4(inputs) + conv4 = self.Conv3D_8(inputs) + + x = torch.cat([conv1, conv2, conv3, conv4], dim=1) + + if self.bn is not None: + x = self.bn(x) + + if self.activation is not None: + x = self.activation(x) + + return x + + +class Conv3DConfigurable(nn.Module): + + def __init__( + self, + in_filters, + filters, + dilation_rate, + separable=True, + octave=False, # not supported + use_bias=True, + kernel_initializer=None): # not supported + super(Conv3DConfigurable, self).__init__() + + if octave: + raise NotImplementedError( + 'Octave convolution not implemented in Pytorch version of ' + 'Transnet! ') + if kernel_initializer is not None: + raise NotImplementedError( + 'Kernel initializers are not implemented in Pytorch version ' + 'of Transnet! 
') + + assert not (separable and octave) + + if separable: + # (2+1)D convolution https://arxiv.org/pdf/1711.11248.pdf + conv1 = nn.Conv3d( + in_filters, + 2 * filters, + kernel_size=(1, 3, 3), + dilation=(1, 1, 1), + padding=(0, 1, 1), + bias=False) + conv2 = nn.Conv3d( + 2 * filters, + filters, + kernel_size=(3, 1, 1), + dilation=(dilation_rate, 1, 1), + padding=(dilation_rate, 0, 0), + bias=use_bias) + self.layers = nn.ModuleList([conv1, conv2]) + else: + conv = nn.Conv3d( + in_filters, + filters, + kernel_size=3, + dilation=(dilation_rate, 1, 1), + padding=(dilation_rate, 1, 1), + bias=use_bias) + self.layers = nn.ModuleList([conv]) + + def forward(self, inputs): + x = inputs + for layer in self.layers: + x = layer(x) + return x + + +class FrameSimilarity(nn.Module): + + def __init__( + self, + in_filters, + similarity_dim=128, + lookup_window=101, + output_dim=128, + stop_gradient=False, # not supported + use_bias=False): + super(FrameSimilarity, self).__init__() + + if stop_gradient: + raise NotImplementedError( + 'Stop gradient not implemented in Pytorch version of Transnet!' + ) + + self.projection = nn.Linear(in_filters, similarity_dim, bias=use_bias) + self.fc = nn.Linear(lookup_window, output_dim) + + self.lookup_window = lookup_window + assert lookup_window % 2 == 1, '`lookup_window` must be odd integer' + + def forward(self, inputs): + x = torch.cat([torch.mean(x, dim=[3, 4]) for x in inputs], dim=1) + x = torch.transpose(x, 1, 2) + + x = self.projection(x) + x = functional.normalize(x, p=2, dim=2) + + batch_size, time_window = x.shape[0], x.shape[1] + similarities = torch.bmm(x, x.transpose( + 1, 2)) # [batch_size, time_window, time_window] + similarities_padded = functional.pad(similarities, + [(self.lookup_window - 1) // 2, + (self.lookup_window - 1) // 2]) + + batch_indices = torch.arange( + 0, batch_size, device=x.device).view([batch_size, 1, 1]).repeat( + [1, time_window, self.lookup_window]) + time_indices = torch.arange( + 0, time_window, device=x.device).view([1, time_window, 1]).repeat( + [batch_size, 1, self.lookup_window]) + lookup_indices = torch.arange( + 0, self.lookup_window, device=x.device).view([ + 1, 1, self.lookup_window + ]).repeat([batch_size, time_window, 1]) + time_indices + + similarities = similarities_padded[batch_indices, time_indices, + lookup_indices] + return functional.relu(self.fc(similarities)) + + +class ColorHistograms(nn.Module): + + def __init__(self, lookup_window=101, output_dim=None): + super(ColorHistograms, self).__init__() + + self.fc = nn.Linear(lookup_window, + output_dim) if output_dim is not None else None + self.lookup_window = lookup_window + assert lookup_window % 2 == 1, '`lookup_window` must be odd integer' + + @staticmethod + def compute_color_histograms(frames): + frames = frames.int() + + def get_bin(frames): + # returns 0 .. 
511 + R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2] + R, G, B = R >> 5, G >> 5, B >> 5 + return (R << 6) + (G << 3) + B + + batch_size, time_window, height, width, no_channels = frames.shape + assert no_channels == 3 + frames_flatten = frames.view(batch_size * time_window, height * width, + 3) + + binned_values = get_bin(frames_flatten) + frame_bin_prefix = (torch.arange( + 0, batch_size * time_window, device=frames.device) << 9).view( + -1, 1) + binned_values = (binned_values + frame_bin_prefix).view(-1) + + histograms = torch.zeros( + batch_size * time_window * 512, + dtype=torch.int32, + device=frames.device) + histograms.scatter_add_( + 0, binned_values, + torch.ones( + len(binned_values), dtype=torch.int32, device=frames.device)) + + histograms = histograms.view(batch_size, time_window, 512).float() + histograms_normalized = functional.normalize(histograms, p=2, dim=2) + return histograms_normalized + + def forward(self, inputs): + x = self.compute_color_histograms(inputs) + + batch_size, time_window = x.shape[0], x.shape[1] + similarities = torch.bmm(x, x.transpose( + 1, 2)) # [batch_size, time_window, time_window] + similarities_padded = functional.pad(similarities, + [(self.lookup_window - 1) // 2, + (self.lookup_window - 1) // 2]) + + batch_indices = torch.arange( + 0, batch_size, device=x.device).view([batch_size, 1, 1]).repeat( + [1, time_window, self.lookup_window]) + time_indices = torch.arange( + 0, time_window, device=x.device).view([1, time_window, 1]).repeat( + [batch_size, 1, self.lookup_window]) + lookup_indices = torch.arange( + 0, self.lookup_window, device=x.device).view([ + 1, 1, self.lookup_window + ]).repeat([batch_size, time_window, 1]) + time_indices + + similarities = similarities_padded[batch_indices, time_indices, + lookup_indices] + + if self.fc is not None: + return functional.relu(self.fc(similarities)) + return similarities diff --git a/projects/videochat/models/vit.py b/projects/videochat/models/vit.py new file mode 100644 index 0000000000..80bd8d5c23 --- /dev/null +++ b/projects/videochat/models/vit.py @@ -0,0 +1,379 @@ +"""* Copyright (c) 2022, salesforce.com, inc. * All rights reserved. 
* +SPDX-License-Identifier: BSD-3-Clause * For full license text, +see LICENSE.txt file in the repo root or +https://opensource.org/licenses/BSD-3-Clause * By Junnan Li * Based on timm +code base * https://github.com/rwightman/pytorch-image-models/tree/master +/timm """ + +from functools import partial + +import torch +import torch.nn as nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper +from timm.models.helpers import adapt_input_conv +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, resize_pos_embed + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_grad_checkpointing=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path( + self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 + Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + representation_size=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=None, + use_grad_checkpointing=False, + ckpt_layer=0): + """ + Args: img_size (int, tuple): input image size patch_size (int, + tuple): patch size in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension depth (int): depth of + transformer num_heads (int): number of attention heads mlp_ratio ( + int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): + enable bias for qkv if True qk_scale (float): override default qk + scale of head_dim ** -0.5 if set representation_size (Optional[ + int]): enable and set representation layer (pre-logits) to this + value if set drop_rate (float): dropout rate attn_drop_rate (float): + attention dropout rate drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing + and i >= depth - ckpt_layer)) + for i in range(depth) + ]) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + 
self.pos_embed[:, :x.size(1), :] + x = self.pos_drop(x) + + for i, blk in enumerate(self.blocks): + x = blk(x, register_blk == i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, + checkpoint_path: str, + prefix: str = ''): + """Load weights from .npz checkpoints for official Google Brain Flax + implementation.""" + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_( + adapt_input_conv(stem.conv.weight.shape[1], + _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_( + _n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_( + _n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_( + _n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_( + _n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_( + _n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_( + _n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], + _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p( + w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed(pos_embed_w, model.pos_embed, + getattr(model, 'num_tokens', 1), + model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_( + torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T + for n in ('query', 'key', 'value') + ])) + block.attn.qkv.bias.copy_( + torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) + for n in ('query', 'key', 'value') + ])) + block.attn.proj.weight.copy_( + _n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + 
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_( + _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_( + _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + + if orig_size != new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode='bicubic', + align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d' % + (orig_size**2, new_size**2)) + + return new_pos_embed + else: + return pos_embed_checkpoint diff --git a/projects/videochat/requirements.txt b/projects/videochat/requirements.txt new file mode 100644 index 0000000000..928e6e1932 --- /dev/null +++ b/projects/videochat/requirements.txt @@ -0,0 +1,52 @@ +accelerate +bitsandbytes +boto3~=1.28.9 +botocore~=1.31.9 +decord~=0.6.0 +decord +einops +entrypoints +fairscale==0.4.4 +ffmpeg-python +git+https://github.com/openai/whisper.git +gradio~=3.38.0 + +https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz +imageio +imageio-ffmpeg +jsonschema +langchain==0.0.292 +lvis~=0.5.3 +matplotlib~=3.7.2 +mmcv +nltk +numpy~=1.24.4 +omegaconf~=2.3.0 +openai~=0.27.8 + +opencv-python~=4.6.0.66 +openmim +paddleocr~=2.6.1.3 +paddlepaddle~=2.5.0 +Pillow~=10.0.0 +psutil~=5.9.5 +pycocoevalcap +pycocotools~=2.0.6 +pydantic==1.10.4 +PyYAML~=6.0.1 +requests~=2.31.0 +scipy~=1.10.1 +scipy +setuptools==59.5.0 +simplet5~=0.1.4 +simplet5 +spacy +tabulate~=0.9.0 +termcolor~=2.3.0 +timm==0.4.12 +torch==1.13.1 +torchvision==0.14.1 +tqdm~=4.65.0 +transformers~=4.16.2 +webdataset +wget==3.2 diff --git a/projects/videochat/tag_class/tag_class.py b/projects/videochat/tag_class/tag_class.py new file mode 100644 index 0000000000..eeaafcba45 --- /dev/null +++ b/projects/videochat/tag_class/tag_class.py @@ -0,0 +1,563 @@ +import numpy as np + +tra_array = [ + 'tennis', 'bear cub', 'observatory', 'bicycle', 'hillside', 'judge', + 'watercolor illustration', 'granite', 'lobster', 'livery', 'stone', + 'ceramic', 'ranch', 'cloth', 'smile', 'building', 'tattoo', 'cricketer', + 'cheek', 'pear', 'source', 'winter', 'surface', 'spray', 'ceremony', + 'magic', 'curve', 'container', 'fair', 'medicine', 'baby', + 'tennis racquet', 'ornament', 'bamboo', 'duckling', 'song', 'safari', + 'team presentation', 'daffodil', 'cross', 'toothpaste', 'shield', + 'fashion model', 'capsule', 'map', 
'creek', 'glass house', 'glass plate', + 'siding', 'corner', 'water buffalo', 'bison', 'figure skater', 'diploma', + 'tire', 'race', 'cable car', 'brain', 'gas stove', 'soap bubble', + 'palette', 'snowboard', 'school child', 'trench coat', 'monk', 'fiber', + 'kitchen window', 'sunglass', 'coffee', 'security', 'strawberry', + 'penguin', 'tree root', 'loaf', 'engagement ring', 'lamb', + 'vector cartoon illustration', 'sandwich', 'mountain village', 'shape', + 'charm', 'fiction', 'knot', 'greenhouse', 'sushi', 'text', 'disaster', + 'trophy', 'gang', 'strap', 'soccer game', 'cardinal', 'tee', 'turtle', + 'water surface', 'grassland', 'dolphin', 'store', 'dirt', 'iceberg', + 'pergola', 'farmer market', 'publicity portrait', 'tote bag', + 'teenage girl', 'view mirror', 'session', 'commuter', 'dressing room', + 'tricycle', 'christmas ball', 'headlight', 'police', 'armchair', 'chart', + 'yacht', 'saw', 'printer', 'rock band', 'gingerbread house', 'tag', + 'table lamp', 'hockey game', 'slope', 'font', 'wicker basket', 'jewelry', + 'quarter', 'software', 'weapon', 'pin', 'worship', 'painter', 'goal', + 'morning light', 'bike', 'baseball bat', 'elevator', 'cuisine', 'sausage', + 'stunt', 'wrestler', 'statue', 'landing', 'pillar', 'willow tree', + 'sea wave', 'chicken', 'peanut', 'muscle', 'bob', 'tv genre', + 'bathroom window', 'radish', 'textile', 'pelican', 'marketplace', 'crest', + 'elevation map', 'gift', 'parish', 'traffic light', 'campfire', 'fog', + 'award winner', 'beach ball', 'mat', 'white house', 'plaster', 'moped', + 'football team', 'solution', 'bicyclist', 'bit', 'playground', 'darkness', + 'cake', 'maple leave', 'mold', 'cracker', 'blueberry', 'rubble', + 'container ship', 'pedestrian bridge', 'snail', 'parrot', 'form', + 'circuit', 'highlight', 'pickup truck', 'koala', 'rain', 'system', + 'weather', 'raincoat', 'soccer team', 'windshield', 'thunderstorm', 'mike', + 'bird house', 'bridge', 'grandfather', 'restroom', 'animation', + 'wilderness', 'clown', 'banana', 'brown', 'braid', 'dining room', + 'kindergarten', 'launch event', 'purple', 'school', 'stairwell', 'brooch', + 'movie poster image', 'mountain river', 'shelf', 'wicket', 'headboard', + 'buddha', 'flower field', 'dugout', 'cd', 'bald eagle', 'lagoon', + 'seaweed', 'agriculture', 'emergency service', 'maple tree', 'parachute', + 'continent', 'amusement park', 'remote', 'bun', 'tackle', 'hospital', + 'garage door', 'birthday party', 'friendship', 'go', 'mausoleum', 'jeep', + 'raccoon', 'step', 'ice hockey team', 'cigarette', 'lace dress', + 'forest floor', 'mall', 'captain', 'milk', 'golf course', 'meal', + 'picnic table', 'sail', 'volleyball', 'canal', 'terrace', 'computer desk', + 'caravan', 'hotel', 'cheerleader', 'nurse', 'museum', 'marsh', 'fox', + 'plateau', 'night', 'twin', 'letter logo', 'autumn tree', 'powder', + 'convention', 'creature', 'lighthouse', 'shop window', 'jacket', 'stork', + 'taxi', 'trade', 'blackboard', 'olive', 'road sign', 'resort', 'snowflake', + 'cemetery', 'travel', 'evening dress', 'picnic', 'drink', 'winter morning', + 'football player', 'snack', 'boxing glove', 'dinner party', 'airline', + 'swing', 'port', 'wheelbarrow', 'bathroom sink', 'sweater', 'ambulance', + 'gear', 'oil', 'wii controller', 'array', 'home office', 'car show', + 'mixture', 'profession', 'tree frog', 'square', 'facility', 'coral reef', + 'sea wall', 'pizza', 'exhibit', 'demolition', 'trout', 'ring', + 'coffee shop', 'bracelet', 'bean', 'lip', 'fencing', 'landscape', + 'sitting', 'package', 'metal', 'bust', 'king', 'hair', 
'window seat', + 'wildlife', 'trunk', 'greenery', 'stencil', 'fire hydrant', 'bridesmaid', + 'plaza', 'alps', 'tower bridge', 'crop top', 'crossing', 'cinema', + 'pedestrian crossing', 'family', 'shopping cart', 'stomach', + 'church building', 'screen door', 'skater', 'soccer field', 'kettle', + 'mussel', 'raindrop', 'candy cane', 'water lily', 'flower girl', 'desert', + 'enclosure', 'christmas light', 'kitchen', 'caterpillar', 'plaid', 'bath', + 'bush', 'mud', 'ballet', 'knee', 'adult', 'raft', 'sea view', 'cactus', + 'office chair', 'overall', 'rim', 'scaffolding', 'pig', 'cover', + 'poster page', 'sprinkle', 'chandelier', 'algae', 'traffic', 'surfboard', + 'book', 'filming', 'flash', 'mansion', 'camouflage', 'trouser', 'ticket', + 'weed', 'cab', 'trench', 'elephant', 'huddle', 'sphere', + 'christmas decoration', 'city', 'launch', 'doll', 'christmas ornament', + 'fabric', 'bikini', 'biplane', 'breakfast', 'neighbourhood', 'race track', + 'foliage', 'avocado', 'school bus', 'footwear', 'highway', 'ocean view', + 'art vector illustration', 'wall clock', 'curtain', 'teenager', + 'kitchen area', 'robot', 'tusk', 'lounge chair', 'beam', 'paddle', 'camel', + 'lid', 'world map', 'city view', 'newlywed', 'cargo ship', 'yellow', + 'exhibition', 'bend', 'novel', 'wool', 'ontario', 'bread', 'campus', + 'coastline', 'cutting board', 'booth', 'table top', 'carpet', + 'beach chair', 'workout', 'street food', 'fun', 'costumer film designer', + 'gadget', 'artist', 'fishing village', 'builder', 'violinist', 'iphone', + 'spider web', 'traffic sign', 'ruin', 'rescue', 'clipboard', 'seal', + 'film director', 'paw', 'nursery', 'intersection', 'tomato sauce', 'taste', + 'paddy field', 'christmas tree', 'wave', 'stool', 'watering can', 'rug', + 'daytime', 'subway station', 'craft', 'pine forest', 'black', 'planet', + 'motif', 'christmas market', 'glass window', 'college', 'wheat', 'damage', + 'rectangle', 'picture frame', 'chess', 'guest room', 'street corner', + 'religion', 'seed', 'puzzle', 'freeway', 'beauty', 'ocean', 'watch', + 'mother', 'garage', 'quote', 'dj', 'supporter', 'hip hop artist', 'muffin', + 'eiffel tower', 'cash', 'firefighter', 'cauliflower', 'bunker', 'sled', + 'manicure', 'shark', 'stall', 'jungle', 'family home', 'tour bus', + 'chimney', 'touchdown', 'roundabout', 'coyote', 'street scene', 'tank', + 'wedding dress', 'mantle', 'bedroom window', 'coconut', 'chapel', 'goat', + 'living space', 'rock wall', 'polka dot', 'railway', 'mandala', 'mango', + 'lesson', 'mountain landscape', 'team photo', 'bookshelf', 'meter', + 'bulldog', 'evening sun', 'stick', 'card', 'pink', 'fish pond', 'paint', + 'pill', 'cart', 'pea', 'van', 'album', 'football college game', + 'mountain pass', 'doughnut', 'ski slope', 'match', 'official', 'shadow', + 'organ', 'celebration', 'coin', 'log cabin', 'firework display', 'present', + 'twig', 'chef', 'confetti', 'footpath', 'tour', 'ponytail', 'artwork', + 'race car', 'club', 'season', 'hose', 'pencil', 'aircraft', + 'rock formation', 'wardrobe', 'participant', 'politician', 'engineer', + 'peace', 'filter', 'sailing boat', 'water bottle', 'service dog', 'poodle', + 'loki', 'statesman', 'sleeping bag', 'outskirt', 'clock', 'factory', + 'oak tree', 'physician', 'color', 'room', 'stairway', 'company', 'lady', + 'graph', 'faucet', 'tablecloth', 'subway train', 'chocolate chip cookie', + 'headquarters', 'screw', 'goggle', 'halloween', 'city street', 'swirl', + 'cord', 'forward', 'bone', 'bedding', 'archway', 'wig', 'lobby', 'mask', + 'attic', 'kitchen table', 'skylight', 
'fire', 'exit', 'oil painting', + 'passenger', 'meditation', 'salmon', 'fedora', 'rubber stamp', + 'orange juice', 'arch', 'scientist', 'stroll', 'manhattan', 'float', + 'baseball uniform', 'circle', 'church', 'decker bus', 'competitor', 'zoo', + 'basketball team', 'tourist', 'daughter', 'silverware', 'ceiling fan', + 'birth', 'vase', 'jack', 'mushroom', 'spiral', 'cage', 'limb', 'salad', + 'ad', 'control', 'earth', 'party', 'bolt', 'tractor', 'barley', + 'wedding photo', 'hawk', 'warehouse', 'vegetable garden', 'chocolate cake', + 'cabbage', 'floor window', 'baby shower', 'magnifying glass', 'table', + 'stethoscope', 'reading', 'mission', 'croissant', 'gift box', 'rocket', + 'forest road', 'cooking', 'suite', 'hill country', 'motorcycle', + 'baseball player', 'angle', 'drug', 'sport association', 'championship', + 'family portrait', 'florist', 'softball', 'egret', 'office', 'plywood', + 'jockey', 'mosque', 'brunch', 'beanie', 'office building', 'pattern', + 'calendar', 'indoor', 'pepper', 'ledge', 'trail', 'fuel', + 'laptop computer', 'tennis shoe', 'deck chair', 'guitarist', 'barn', + 'surgery', 'cartoon illustration', 'nebula', 'railroad', 'mountain goat', + 'goose', 'car door', 'cheer', 'liquid', 'hardwood floor', 'pathway', + 'acorn', 'gull', 'airliner', 'couch', 'lake house', 'spaghetti', + 'promenade', 'collection', 'garden', 'bank', 'robin', 'tennis ball', + 'peony', 'gymnast', 'lavender', 'deck', 'test', 'riverside', 'rapper', + 'domino', 'bride', 'mouse', 'basil', 'wedding couple', 'ocean wave', 'arm', + 'kitchen floor', 'grove', 'family member', 'backyard', 'raspberry', + 'forest fire', 'officer', 'hibiscus', 'canyon', 'composer', 'signature', + 'olive oil', 'hibiscus flower', 'rose', 'vector icon', 'sunrise', + 'horseback', 'motor scooter', 'office worker', 'tradition', 'ingredient', + 'washing machine', 'lighting', 'bagel', 'sailboat', 'policeman', 'mare', + 'graphic', 'halloween pumpkin', 'stock', 'pilot', 'education', 'team', + 'body', 'horse', 'kimono', 'bazaar', 'bag', 'recording studio', 'parsley', + 'entrance', 'denim', 'vet', 'horse farm', 'charcoal', 'architecture', + 'glass vase', 'puppy', 'estuary', 'television show host', 'city bus', + 'shoulder', 'beast', 'balance', 'golfer', 'roadside', 'denim jacket', + 'stone wall', 'counter top', 'app icon', 'toast', 'head coach', 'ham', + 'warrior', 'gem', 'refrigerator', 'snowman', 'construction worker', 'coal', + 'website', 'morning fog', 'mustard', 'human', 'owl', 'puppy dog', + 'piggy bank', 'vegetation', 'pirate', 'action film', 'marshmallow', + 'thanksgiving', 'business', 'disease', 'signage', 'greeting', 'skate park', + 'tile', 'mouth', 'spinach', 'vacation', 'leader', 'shrine', 'walker', + 'science fiction film', 'bill', 'rabbit', 'motor boat', 'bar', 'radio', + 'barge', 'tail', 'chainsaw', 'gallery', 'rainbow', 'pasta', 'padlock', + 'web', 'pastry', 'ink', 'reef', 'school uniform', 'shawl', 'treasure', + 'peach', 'dinner table', 'injury', 'harbor', 'witch', 'car dealership', + 'litter', 'gesture', 'documentary', 'marriage', 'sea shell', 'priest', + 'dome', 'kit', 'icon', 'seaside', 'bucket', 'entertainment', 'stable', + 'hat', 'puddle', 'sock', 'shopper', 'technology', 'harbour', 'orbit', + 'antler', 'tube', 'flag waving', 'cook', 'tight', 'commander', 'farmland', + 'switch', 'hiker', 'wedding ceremony', 'award ceremony', 'champion', + 'chopstick', 'farmhouse', 'performer', 'spike', 'accident', 'cruise ship', + 'passenger train', 'attraction', 'entertainer', 'rear view', 'sidewalk', + 'parade', 'racing', 'plane', 
'ritual', 'peacock', 'pocket', 'plum', 'drop', + 'carrot', 'floor', 'sunset', 'troop', 'architect', 'coffee table', 'dust', + 'outline', 'leather', 'charity event', 'heat', 'whale', 'laundry', + 'coconut tree', 'crosswalk', 'pony', 'ant', 'pipe', 'string', 'coat', + 'angel', 'beef', 'church tower', 'dish', 'pitch', 'cupboard', + 'thermometer', 'dirt field', 'fireworks', 'minute', 'cane', 'pajama', + 'flower garden', 'autumn', 'trash can', 'dachshund', 'banana tree', 'tray', + 'moose', 'roadway', 'carnival', 'antenna', 'pole', 'castle wall', 'ram', + 'cattle', 'hay', 'cookie', 'swimmer', 'baseball team', 'strait', 'hedge', + 'jet', 'fire pit', 'octopus', 'calf', 'cube', 'opera', 'cardboard box', + 'tiara', 'kitchen sink', 'prairie', 'bowl', 'galaxy', 'straw hat', 'linen', + 'ski resort', 'stitch', 'street lamp', 'motorist', 'icicle', 'stain', + 'flora', 'drain', 'kitchen cabinet', 'decor', 'bouquet', 'pound', + 'interior design', 'nail polish', 'figurine', 'tomb', 'disc', 'twist', + 'blouse', 'ribbon', 'figure', 'burger', 'cork', 'soccer goalkeeper', + 'train bridge', 'drinking water', 'dew', 'baker', 'storm cloud', 'tarmac', + 'tv drama', 'sponge', 'magnet', 'sailor', 'entry', 'swan', 'exercise', + 'sloth', 'jewel', 'scuba diver', 'bite', 'cat tree', 'tent', 'can', + 'tennis match', 'ecosystem', 'picket fence', 'palm', 'train car', + 'frying pan', 'rally', 'tablet pc', 'reindeer', 'image', 'wolf', 'chin', + 'conservatory', 'flood water', 'cityscape', 'beach sand', 'car park', + 'pavement', 'farm field', 'swimming', 'winter storm', 'stem', 'pillow', + 'inning', 'gorilla', 'desk', 'avenue', 'fern', 'money', 'pearl', + 'train station', 'skillet', 'nap', 'barber', 'library', 'freezer', 'label', + 'rainforest', 'parking sign', 'mirror', 'wing', 'noodle', 'press room', + 'sculpture', 'tablet', 'viewer', 'prayer', 'mini', 'mechanic', 'laugh', + 'rice field', 'hand', 'mustache', 'mountain road', 'catwalk', 'conference', + 'cape', 'installation', 'musician', 'stream', 'machine', 'speech', + 'crocodile', 'soccer match', 'town square', 'passport', 'post box', + 'point', 'stone building', 'motorway', 'mix', 'dentist', 'businessperson', + 'happiness', 'boat', 'vineyard', 'treadmill', 'glass wall', + 'water droplet', 'coffee mug', 'graduate', 'sunflower', 'parliament', + 'shepherd', 'movie', 'wine', 'orchard', 'tulip', 'motherboard', 'cup', + 'broom', 'spot', 'drawing', 'polo shirt', 'graduation', 'film producer', + 'moonlight', 'glow', 'film format', 't shirt', 'rock face', 'sword', + 'clinic', 'festival day', 'meadow', 'staple', 'pupil', 'training ground', + 'rider', 'flower', 'foal', 'wharf', 'foot bridge', 'shooting', 'top', + 'mast', 'police car', 'robe', 'wedding bouquet', 'stop sign', + 'birthday cake', 'glitter', 'butter', 'scooter', 'tundra', 'superhero', + 'pocket watch', 'inscription', 'youngster', 'fruit tree', 'movie poster', + 'engine', 'foundation', 'motorcyclist', 'take', 'woman', 'antelope', + 'country artist', 'road trip', 'typewriter', 'tuxedo', 'brand', 'pine', + 'bathroom', 'paradise', 'texture', 'balloon', 'dining table', 'home', + 'computer screen', 'actor', 'clip', 'tv tower', 'panorama', 'summit', + 'cat', 'plot', 'eagle', 'dancer', 'pup', 'studio shot', 'tear', + 'bird bath', 'classroom', 'bookstore', 'city wall', 'tv programme', + 'blade', 'easel', 'buttercream', 'sweet', 'designer', 'diamond', + 'handshake', 'herb', 'corn field', 'seafront', 'concrete', 'street artist', + 'gas', 'stamp', 'window display', 'paper', 'note', 'pint', 'quarry', + 'research', 'fixture', 'manager', 
'soil', 'leopard', 'board game', + 'ladder', 'stop light', 'island', 'ramp', 'football match', 'icing', + 'drill', 'currency', 'summer evening', 'topping', 'pyramid', 'pomegranate', + 'cell', 'ivy', 'squad', 'scenery', 'computer', 'locomotive', 'surf', + 'mascot', 'dune', 'path', 'duck', 'twilight', 'wire', 'bow tie', 'strike', + 'cormorant', 'car wash', 'crane', 'market', 'philosopher', 'alarm clock', + 'camera', 'birch', 'greeting card', 'plain', 'clay', 'donut', 'lock', + 'moth', 'laboratory', 'fan', 'violin', 'jazz fusion artist', + 'mountain biker', 'terrain', 'magazine', 'pickup', 'comedy film', + 'smartphone', 'film', 'bed', 'microwave oven', 'tournament', 'lawn', + 'car window', 'alligator', 'screen', 'jetty', 'shopping bag', + 'landscape view', 'cabinetry', 'friendly match', 'thing', 'petal', + 'shopping center', 'transport', 'ballet dancer', 'shoreline', 'princess', + 'car seat', 'parking meter', 'green', 'vodka', 'band', 'rock', 'costume', + 'warning sign', 'strip', 'plaque', 'wheelchair', 'headband', 'ginger', + 'dice', 'media', 'hairdresser', 'press', 'living room', 'stove', 'player', + 'cherry', 'workshop', 'carving', 'embroidery', 'doodle', 'adventure', + 'rugby player', 'monument', 'brush', 'marker', 'loft', 'postcard', + 'collage', 'ball', 'professor', 'dresser', 'gig', 'festival', 'blackbird', + 'makeup artist', 'video camera', 'sticker', 'peak', 'wildflower', + 'santa hat', 'rodeo', 'wedding photographer', 'guy', 'staff', 'waterfall', + 'operation', 'defender', 'falcon', 'haze', 'individual', 'gentleman', + 'greyhound', 'rocking chair', 'rice', 'garbage', 'platter', 'chocolate', + 'splash', 'business suit', 'cheetah', 'valley', 'maze', 'trampoline', + 'garland', 'slalom', 'unicorn', 'tree stump', 'painting', 'romance', + 'fight', 'alcohol', 'ghost', 'fondant', 'spa', 'shutter', 'death', + 'demonstration', 'cotton', 'pier', 'flea market', 'history', 'savannah', + 'fist', 'aisle', 'crew', 'jug', 'pose', 'anchor', 'teapot', 'boat house', + 'business team', 'tripod', 'bee', 'pebble', 'mattress', 'canvas', + 'hallway', 'campaign', 'pod', 'lake district', 'article', 'white', 'sofa', + 'honey', 'marathon', 'pancake', 'tourist attraction', 'wedding gown', + 'battle', 'shelving', 'sea', 'sheet music', 'pie', 'yarn', + 'construction site', 'flyer', 'tie', 'star', 'lettuce', 'martial artist', + 'dart', 'straw', 'reflection', 'conference room', 'temperature', 'rugby', + 'mosquito', 'physicist', 'rock climber', 'crash', 'backdrop', + 'toilet seat', 'sand castle', 'water park', 'toy car', 'waste', 'luxury', + 'hangar', 'rv', 'tree trunk', 'board', 'gold', 'project picture', 'cap', + 'cottage', 'relief', 'attire', 'microscope', 'battery', 'roll', 'line', + 'parking garage', 'crystal', 'broadcasting', 'brick wall', 'lab', + 'flooring', 'meeting', '3d cg rendering', 'desktop computer', 'cowboy', + 'sailing ship', 'junction', 'hairstyle', 'homework', 'profile', 'model', + 'flower pot', 'street light', 'salt lake', 'maple', 'space', 'blizzard', + 'throw', 'zebras', 'brochure', 'constellation', 'beak', 'kilt', 'pond', + 'blue sky', 'sneaker', 'sand dune', 'morning sun', 'almond', 'grill', + 'curl', 'basketball girl game', 'chameleon', 'toilet bowl', 'prince', + 'keyboard', 'queen', 'computer monitor', 'writing', 'crown', 'basilica', + 'kiss', 'house', 'parking', 'football competition', 'shell', + 'sport equipment', 'comedy', 'baboon', 'vendor', 'rise building', 'wrap', + 'food truck', 'cat bed', 'rickshaw', 'flare', 'teal', 'nectar', 'eclipse', + 'vehicle', 'steam locomotive', 'gorge', 
'cow', 'christmas card', + 'demonstrator', 'memorial', 'towel', 'jewellery', 'train', 'frisbee', + 'baseball game', 'fur', 'afternoon sun', 'community', 'sparkler', + 'bandage', 'firework', 'dollar', 'pasture', 'video', 'bus', 'tree house', + 'seashore', 'field', 'hamburger', 'souvenir', 'hedgehog', 'worm', + 'pine cone', 'osprey', 'dinosaur', 'vegetable', 'junk', 'poster', 'army', + 'winger', 'bundle', 'stage', 'growth', 'wedding party', 'service', + 'blanket', 'ruler', 'eye', 'credit card', 'castle', 'diner', 'hut', 'elk', + 'hard rock artist', 'nun', 'dog breed', 'nest', 'drma film', 'number icon', + 'water tank', 'giraffe', 'altar', 'pavilion', 'tv personality', 'suv', + 'street vendor', 'street sign', 'ditch', 'debris', 'foam', 'takeoff', + 'spice', 'mountain lake', 'tea', 'orchestra', 'spacecraft', 'counter', + 'abbey', 'mountain', 'hydrangea', 'racer', 'orange tree', 'tide', + 'cowboy hat', 'rapid', 'town', 'wild', 'herd', 'vein', 'driveway', 'jar', + 'bark', 'illustration', 'horror film', 'corn', 'stroller', 'industry', + 'mountain stream', 'gym', 'neckline', 'pan', 'client', 'spectator', + 'eggplant', 'camper', 'fawn', 'hoodie', 'meat', 'lemonade', 'food market', + 'slum', 'comic book character', 'flower market', 'love', 'palace', 'gun', + 'heel', 'shopping street', 'shooting basketball guard', 'family photo', + 'rooftop', 'laundry basket', 'airport runway', 'horn', 'face mask', + 'flight', 'appetizer', 'violet', 'country lane', 'cement', 'instrument', + 'tv actor', 'spark', 'celebrity', 'award', 'country house', 'standing', + 'auction', 'date', 'engagement', 'puck', 'advertisement', 'chair', 'zebra', + 'driftwood', 'bumblebee', 'maple leaf', 'bonnet', 'orange', 'water tower', + 'door', 'singer', 'floor plan', 'discussion', 'theatre', 'pilgrim', 'mug', + 'branch', 'window sill', 'baseball pitcher', 'bakery', 'lollipop', + 'basketball player', 'toilet paper', 'chalkboard', 'cabin', 'sign', + 'night sky', 'cannon', 'fishing net', 'submarine', 'suit', 'fur coat', + 'wine bottle', 'folder', 'street art', 'suspension bridge', 'evening sky', + 'billboard', 'postage stamp', 'newspaper', 'transportation', 'surgeon', + 'light', 'park', 'horizon', 'road', 'sand bar', 'trumpet', 'lounge', + 'cloud forest', 'birthday celebration', 'balcony', 'anime', 'beehive', + 'umbrella', 'goldfish', 'baseball cap', 'waterhole', 'ceiling', 'carousel', + 'backpack', 'plant pot', 'atmosphere', 'sunflower field', 'spire', + 'vision', 'woodpecker', 'chip', 'pool table', 'lotus flower', 'cone', + 'humpback whale', 'reservoir', 'hunt', 'piano', 'plate', 'dining area', + 'luggage', 'skier', 'dance floor', 'crow', 'stair', 'overpass', + 'opera house', 'bear', 'jazz artist', 'water', 'vessel', 'cast', 'yard', + 'cathedral', 'basketball hoop', 'graveyard', 'sound', 'berry', 'onlooker', + 'fauna', 'birch tree', 'retail', 'hill', 'skeleton', 'journalist', 'frost', + 'basket', 'nail', 'dusk', 'trash', 'dawn', 'clover', 'hen', 'volcano', + 'basketball coach', 'home decor', 'charge', 'haircut', 'sense', + 'university', 'lizard', 'daisy', 'tablet computer', 'grass field', + 'prison', 'metal artist', 'bathroom mirror', 'window frame', 'chest', + 'flavor', 'pop country artist', 'market square', 'monkey', 'blog', 'deer', + 'speech bubble', 'dog', 'independence day', 'girl', 'boy', 'tartan', + 'furniture', 'appliance', 'office window', 'fish boat', 'sand box', + 'tv sitcom', 'drama', 'sleigh', 'depression', 'paper towel', 'baseball', + 'protestor', 'grape', 'wedding cake', 'invitation', 'accessory', 'pick', + 'grandparent', 
'racket', 'tea plantation', 'outdoors', 'egg', 'glass bowl', + 'sun', 'organization', 'lion', 'panel', 'station', 'wallpaper', + 'helicopter', 'salt', 'vanity', 'patio', 'lunch', 'street performer', + 'mountain range', 'soup', 'bacon', 'power station', 'cantilever bridge', + 'hummingbird', 'shirt', 'rope', 'hip', 'chalk', 'pendant', 'choir', 'tv', + 'lichen', 'railway bridge', 'art gallery', 'bartender', 'wagon', + 'baby elephant', 'accordion', 'horseshoe', 'building site', 'clutch', + 'harvest', 'savanna', 'geranium', 'business woman', 'paddock', 'patch', + 'beech tree', 'war', 'suburbs', 'hospital bed', 'motorcycle racer', 'moss', + 'gravel', 'government agency', 'dollar bill', 'father', 'fjord', 'concert', + 'nut', 'wedding photography', 'finish line', 'home plate', 'food', 'nose', + 'thumb', 'village', 'dining room table', 'bumper', 'monster', 'blackberry', + 'lime', 'conflict', 'gala', 'wallet', 'wrist', 'hug', 'mermaid', 'lava', + 'lawyer', 'folk rock artist', 'arena', 'onion', 'toothbrush', 'fashion', + 'perfume', 'flip', 'triangle', 'woodland', 'mail', 'grasshopper', 'studio', + 'wood floor', 'den', 'racquet', 'cello', 'lemur', 'astronaut', + 'glass table', 'blood', 'dvd', 'planter', 'silver', 'leash', + 'master bedroom', 'forest', 'batter', 'shoe', 'engraving', 'opening', + 'product', 'toe', 'cocktail', 'mallard duck', 'bike ride', 'oasis', + 'wedding ring', 'cinematographer', 'holly', 'autograph', 'fence', + 'ice cube', 'cove', 'pineapple', 'aurora', 'glass bead', 'produce', + 'apartment building', 'cob', 'miniature', 'cockpit', 'flashlight', 'frog', + 'sheep', 'groom', 'steel', 'watermelon', 'clip art', 'paper plate', + 'ostrich', 'contour', 'mural', 'cub', 'paisley bandanna', 'winery', 'turn', + 'handle', 'satellite', 'post', 'pork', 'child', 'asphalt', 'grocery store', + 'vulture', 'trolley', 'nightclub', 'brick', 'trailer', 'compass', 'cereal', + 'cafe', 'cartoon character', 'sugar', 'fiction book', 'glass floor', + 'umpire', 'guitar', 'hamster', 'protester', 'airplane', 'garment', + 'blazer', 'railway line', 'wedding', 'shoe box', 'parking lot', + 'construction', 'graduation ceremony', 'tram', 'telescope', 'copper', + 'pain', 'autumn forest', 'guest house', 'partner', 'crayon', 'dip', 'boot', + 'corridor', 'computer keyboard', 'hockey player', 'chicken coop', + 'bus station', 'gathering', 'ankle', 'bunk bed', 'wood table', + 'football coach', 'monarch', 'pharmacy', 'legging', 'mannequin', 'female', + 'train track', 'stack', 'canopy', 'design element', 'grandmother', + 'symbol', 'beach hut', 'zucchini', 'bomb', 'businessman', 'skyscraper', + 'tongue', 'case', 'sparkle', 'highland', 'ballroom', 'prom', 'estate', + 'customer', 'archipelago', 'cheese', 'debate', 'carriage', 'bulldozer', + 'pumpkin', 'sitting room', 'gas station', 'wedding reception', 'camp', + 'dog bed', 'tower', 'property', 'river bed', 'pop latin artist', 'fridge', + 'wine glass', 'coast', 'beer', 'tow truck', 'fire truck', 'mountain bike', + 'thigh', 'heron', 'boat ride', 'gondola', 'turquoise', 'lake', 'llama', + 'kitty', 'tin', 'waiting room', 'coffee cup', 'socialite', 'guard', 'tap', + 'waterway', 'forehead', 'list', 'erosion', 'box', 'sea lion', 'pollen', + 'dam', 'wasp', 'salon', 'tennis tournament', 'flower box', 'aquarium', + 'rain cloud', 'clothing store', 'lead singer', 'cupcake', 'tortoise', + 'lettering', 'sport facility', 'dance', 'dog house', 'nature', 'football', + 'rooster', 'footballer', 'railway track', 'crowd', 'fishing rod', + 'silhouette', 'wind turbine', 'sari', 'bus window', 'cloud', 
'charity', + 'medal', 'yoga', 'event', 'veil', 'fashion menswear milan week', 'news', + 'knife', 'print', 'screen tv', 'walnut', 'fungus', 'ice cream', + 'computer mouse', 'play', 'tribe', 'picture', 'video game', + 'business card', 'music festival', 'rack', 'envelope', 'shower', + 'dirt road', 'mine', 'oyster', 'monarch butterfly', 'dude', 'fruit salad', + 'podium', 'fork', 'lace', 'test match', 'boulder', 'cricket player', + 'staircase', 'peninsula', 'shopping', 'popcorn', 'oak', 'market stall', + 'pine tree', 'mountaineer', 'student', 'closet', 'hood', 'handstand', + 'centerpiece', 'insect', 'patient', 'makeover', 'tennis player', 'sheet', + 'park bench', 'apple', 'organism', 'hook', 'turkey', 'tangerine', + 'sibling', 'shopping mall', 'bird', 'scarf', 'smoothie', 'net', 'grass', + 'napkin', 'ray', 'eyebrow', 'laptop keyboard', 'motorbike', 'woman hand', + 'oven', 'book cover', 'easter egg', 'microwave', 'sand', 'snapshot', + 'soccer ball', 'makeup', 'knight', 'bowling ball', 'shower curtain', + 'flame', 'lightning', 'running', 'power plant', 'crib', 'cartoon', 'moat', + 'fashion girl', 'wedding invitation', 'bottle', 'cliff', 'monastery', + 'file photo', 'apartment', 'casino', 'cream', 'sweatshirt', 'storm', + 'cruise', 'teddy bear', 'shovel', 'wind farm', 'writer', 'dock', + 'professional', 'hotel room', 'job', 'monitor', 'donkey', 'pass', + 'interview', 'duchess', 'mark', 'plank', 'beard', 'zombie', 'trio', + 'channel', 'cricket team', 'windmill', 'vest', 'diagram', 'cable', + 'winter scene', 'golden gate bridge', 'buffalo', 'studio portrait', + 'pagoda', 'whiskey', 'freight train', 'kite', 'future', 'steam train', + 'phone box', 'headset', 'wood', 'snowboarder', 'paper bag', 'slide', + 'grapefruit', 'seating', 'morning', 'bronze sculpture', 'theatre actor', + 'stump', 'jean', 'landmark', 'jam', 'waist', 'watercolor', 'hammock', + 'light fixture', 'ice', 'basin', 'beverage', 'shelter', 'premiere', + 'mound', 'ear', 'bronze', 'sunlight', 'street', 'energy', 'barn door', + 'hike', 'fleet', 'claw', 'beach', 'pepperoni', 'bin', 'trainer', 'buffet', + 'archive', 'todler', 'referee', 'bay window', 'dove', 'production company', + 'evening light', 'gate', 'farm', 'reed', 'fruit stand', 'explorer', + 'snow storm', 'throw pillow', 'button', 'display case', 'bookcase', 'lead', + 'lipstick', 'basketball court', 'cargo', 'ensemble', 'pope', 'clock tower', + 'teen', 'speaker', 'rat', 'laptop', 'ski', 'mess', 'stadium', 'ferry boat', + 'bunny', 'waterfront', 'downtown', 'sink', 'press conference', 'dinner', + 'condiment', 'thread', 'audience', 'grid', 'car', 'plastic', 'people', + 'barbecue', 'pigeon', 'urinal', 'seagull', 'volunteer', 'hockey', + 'fir tree', 'pollution', 'trial', 'collar', 'area', 'meeting room', + 'circus', 'yogurt', 'orangutan', 'viaduct', 'comedian', 'drone', 'scissor', + 'pop rock artist', 'biscuit', 'panda', 'water feature', 'air balloon', + 'remote control', 'watercolor painting', 'show', 'walk', 'post office', + 'bike path', 'rap gangsta artist', 'microphone', 'crack', 'sunset sky', + 'glass', 'tv show', 'cartoon style', 'stripe', 'foyer', 'signal', + 'calligraphy', 'bulb', 'gardener', 'coffee bean', 'spider', 'tapestry', + 'city skyline', 'necklace', 'kitten', 'traveler', 'veteran', 'frosting', + 'fry', 'tennis court', 'tank top', 'butterfly house', 'mist', 'drummer', + 'water level', 'scale', 'baseball glove', 'music video performer', + 'champagne', 'camping', 'clothing', 'water drop', 'telephone box', 'pen', + 'morning mist', 'fire engine', 'porch', 'opening ceremony', 
'style', + 'palm tree', 'fashion show', 'universe', 'scratch', 'axe', 'ottoman', + 'explosion', 'rib', 'boutique', 'game', 'cucumber', 'fruit', + 'stone bridge', 'nature reserve', 'track', 'train window', 'punch', + 'telephone pole', 'velvet', 'sauce', 'moon', 'contrast', 'flamingo', 'bat', + 'vending machine', 'ship', 'equestrian', 'shade', 'comforter', 'pallet', + 'sparrow', 'wii', 'glaze', 'grocery', 'steeple', 'soccer player', + 'contract', 'advertising', 'runner', 'chimpanzee', 'world', 'seat', + 'project', 'chihuahua', 'bubble', 'willow', 'pedestal', + 'soul hip hop artist', 'curb', 'drawer', 'leaf', 'banner', 'launch party', + 'coach', 'government', 'snowball', 'toy', 'portrait', 'doctor', + 'whiteboard', 'electronic', 'tiger', 'graffiti', 'column', 'nightstand', + 'whistle', 'maxi dress', 'bench', 'wetsuit', 'bird feeder', + 'football game', 'basketball', 'class', 'bathroom door', 'store window', + 'text message', 'wreath', 'street view', 'binocular', 'pet', 'facade', + 'drought', 'lemon', 'new year', 'night view', 'airplane window', 'specie', + 'rule', 'jaw', 'wheat field', 'diet', 'pop artist', 'habitat', + 'screenshot', 'scoreboard', 'shore', 'mane', 'quilt', 'ski lift', 'orchid', + 'turban', 'christmas', 'airport', 'marina', 'glass door', 'glass bottle', + 'restaurant', 'conductor', 'logo', 'sleep', 'tape', 'tomato', 'river bank', + 'lilac', 'tooth', 'training', 'pottery', 'shop', 'steam engine', + 'mason jar', 'base', 'procession', 'border', 'shoot', 'footprint', + 'hotdog', 'bull', 'stocking', 'recreation', 'automobile model', 'design', + 'country pop artist', 'river', 'retriever', 'department store', + 'auditorium', 'sport car', 'supermarket', 'belt', 'cricket', 'window box', + 'dress shirt', 'letter', 'residence', 'megaphone', 'pant', 'wildfire', + 'bird nest', 'crab', 'swimsuit', 'candle', 'funeral', 'mill', + 'national park', 'plant', 'cop', 'power line', 'perch', 'blue', 'finger', + 'ferris wheel', 'globe', 'skateboard', 'helmet', 'movie theater', + 'uniform', 'hammer', 'material', 'kid', 'well', 'butterfly', 'sideline', + 'fashion fall show', 'planet earth', 'lift', 'male', 'sauna', 'gray', + 'flour', 'sand sculpture', 'program', 'cabinet', 'infant', 'wheel', + 'aircraft model', 'dough', 'garlic', 'skate', 'arrow', 'wrapping paper', + 'ripple', 'lamp', 'iron', 'banknote', 'beaver', 'ferry', 'courtyard', + 'bassist', 'countryside', 'steak', 'comfort', 'boxer', 'laundry room', + 'campsite', 'brick building', 'golf', 'subway', 'headphone', 'fort', + 'handbag', 'drum', 'flood', 'saddle', 'bass', 'labyrinth', 'needle', + 'sun ray', 'app', 'menu', 'president', 'cardigan', 'dandelion', 'wetland', + 'ice hockey player', 'number', 'city hall', 'fishing', 'portrait session', + 'pug', 'key', 'art print', 'minister', 'hurdle', 'emergency', + 'painting artist', 'flag pole', 'evening', 'purse', 'recipe', 'golf ball', + 'coloring book', 'mountain peak', 'senior', 'holiday', 'bud', 'cousin', + 'pantry', 'lap', 'skin', 'flag', 'tissue paper', 'ridge', 'wire fence', + 'surfer', 'climber', 'photograph', 'sewing machine', 'cooler', 'actress', + 'apple tree', 'cancer', 'starfish', 'automobile make', 'dumbbell', 'brace', + 'tunnel', 'window', 'paint artist', 'composition', 'school student', + 'condo', 'convertible', 'cushion', 'selfie', 'territory', 'guide', 'tree', + 'court', 'shrimp', 'stone house', 'dress', 'eyelash', 'juice', 'broccoli', + 'chain', 'tourism', 'mountain top', 'concept car', 'film premiere', + 'light bulb', 'cafeteria', 'badge', 'flower bed', 'theater', 'root', + 'racecar 
driver', 'basketball boy game', 'glove', 'skyline', 'wall', + 'glacier', 'airport terminal', 'bug', 'trim', 'railway station', + 'briefcase', 'flat', 'fountain', 'person', 'lane', 'asparagus', 'art', + 'lantern', 'dishwasher', 'director', 'snake', 'lecture', 'game controller', + 'tree branch', 'pub', 'bathing suit', 'queue', 'belly', 'poppy', 'bow', + 'pitcher', 'ice cream cone', 'cave', 'candy', 'road bridge', 'host', + 'traffic jam', 'earring', 'file', 'foot', 'watermark overlay stamp', + 'mailbox', 'supercar', 'railing', 'bedroom', 'seafood', 'waffle', + 'bronze statue', 'plan', 'flow', 'marble', 'basketball game', 'automobile', + 'scene', 'cypress tree', 'soldier', 'skateboarder', 'glass building', + 'cherry tree', 'pump', 'grain', 'wildebeest', 'loop', 'frame', 'bathtub', + 'saxophone', 'diver', 'stalk', 'lily', 'bead', 'alley', 'flock', + 'family room', 'manufacturing', 'pointer', 'worker', 'navy', 'potato', + 'teacher', 'photography', 'dolly', 'boardwalk', 'water fountain', + 'athlete', 'side dish', 'bay', 'ice hockey', 'phone', 'hero', 'face', + 'gold medal', 'blind', 'swamp', 'researcher', 'swim', 'meatball', 'iguana', + 'leather jacket', 'jellyfish', 'site', 'smoke', 'traffic signal', 'melon', + 'beetle', 'calculator', 'skirt', 'plantation', 'sculptor', 'barrier', + 'catcher', 'security guard', 'sketch', 'awning', 'steering wheel', + 'mountain view', 'bus stop', 'pool', 'leg', 'spotlight', 'apron', + 'mineral', 'inlet', 'sleeve', 'torch', 'emotion', 'march', 'police oficer', + 'performance', 'lamp post', 'fishing boat', 'summer', 'presentation', + 'saucer', 'suitcase', 'supermodel', 'goalkeeper', 'shrub', 'rock artist', + 'document', 'beach house', 'man', 'blue artist', 'cigar', 'railroad track', + 'gown', 'mosaic', 'bungalow', 'alphabet', 'baseball field', 'shed', + 'pedestrian', 'rail', 'soap', 'kitchen counter', 'dessert', 'dunk', + 'blossom', 'conversation', 'fruit market', 'glass jar', 'military', + 'beer bottle', 'photographer', 'tennis racket', 'competition', 'escalator', + 'bell tower', 'stilt', 'ballerina', 'television', 'feather', 'fence post', + 'rear', 'dahlia', 'red carpet', 'tub', 'hole', 'fortress', 'pack', + 'telephone', 'cardboard', 'city park', 'platform', 'college student', + 'arch bridge', 'wind', 'blender', 'bloom', 'ice rink', 'birthday', 'raven', + 'fairy', 'embankment', 'hall', 'flower shop', 'suburb', 'barrel', 'biker', + 'steam', 'dragonfly', 'formation', 'electricity', 'business people', + 'symmetry', 'walkway', 'fisherman', 'gas mask', 'loch', 'youth', 'hanger', + 'dot', 'fish', 'street market', 'animation film', 'crime fiction film', + 'boar', 'emblem', 'halloween costume', 'kangaroo', 'couple', 'spoon', + 'squirrel', 'neon sign', 'sky', 'office desk', 'beauty salon', + 'breakwater', 'fashion look', 'toaster', 'author', 'news conference', + 'outdoor', 'canoe', 'dragon', 'tool', 'shopping centre', 'ladybug', + 'swimming pool', 'landscaping', 'ski pole', 'red', 'truck', 'fly', + 'temple', 'level', 'sunday', 'railroad bridge', 'car mirror', 'lawn mower', + 'flute', 'aircraft carrier', 'fashion menswear london week', 'sunshine', + 'tile floor', 'skull', 'fossil', 'flower arrangement', 'diaper', + 'sea turtle', 'cherry blossom', 'fireman', 'shack', 'lens', 'waiter', + 'animal', 'basement', 'snow', 'autumn park', 'glass box', 'kick', 'head', + 'anniversary', 'vine', 'back', 'paper lantern', 'fish tank', 'cellphone', + 'silk', 'coral', 'notebook', 'photo', 'gazebo', 'ketchup', 'driver', + 'farmer', 'bonfire', 'chestnut', 'photoshoot', 'football field', + 
'olive tre', 'pheasant', 'sandal', 'toilet', 'fireplace', 'music', 'deity', + 'fish market', 'fig', 'bell', 'neck', 'grave', 'villa', 'cyclist', 'crate', + 'grey', 'asphalt road', 'soccer', 'hostel', 'municipality', 'courthouse', + 'roof', 'end table', 'pot', 'sedan', 'structure', 'folk artist', 'sport', + 'sport team', 'protest', 'syringe', 'fashion designer', 'jersey', + 'heart shape', 'kayak', 'stare', 'sit with', 'direct', 'read', + 'photograph', 'spin', 'teach', 'laugh', 'carve', 'grow on', 'warm', + 'watch', 'stretch', 'smell', 'decorate', 'shine', 'light', 'dance', 'send', + 'park', 'chase', 'collect', 'lead', 'kiss', 'lead to', 'lick', 'smile', + 'cheer', 'sit', 'point', 'block', 'rock', 'drop', 'cut', 'ski', 'wrap', + 'lose', 'serve', 'provide', 'sleep', 'dress', 'embrace', 'burn', 'pack', + 'stir', 'create', 'touch', 'wash', 'stick', 'reveal', 'shop', 'train', + 'paint', 'groom', 'hunt', 'bloom', 'play', 'pay', 'brush', 'shoot', 'hold', + 'picture', 'carry', 'sip', 'contain', 'turn', 'pour', 'pitch', 'give', + 'add', 'blow', 'look in', 'show', 'walk', 'illuminate', 'kneel', 'cover', + 'drag', 'post', 'present', 'fit', 'operate', 'fish', 'race', 'write', + 'deliver', 'peel', 'push', 'run', 'sit around', 'buy', 'jump', 'walk on', + 'attend', 'clean', 'sell', 'ride on', 'mount', 'host', 'dry', 'plant', + 'sing', 'row', 'shake', 'perch', 'ride', 'fight', 'skateboard', 'live', + 'call', 'surround', 'practice', 'play on', 'work on', 'step', 'relax', + 'hit', 'fall in', 'flow', 'greet', 'launch', 'wear', 'hang on', 'drive', + 'sit in', 'break', 'learn', 'fly', 'connect', 'display', 'locate', + 'compete', 'go for', 'sail', 'lift', 'toast', 'help', 'run on', 'reflect', + 'pose', 'scratch', 'frame', 'dribble', 'herd', 'enter', 'exit', 'place', + 'inspect', 'build', 'pick', 'fill', 'grind', 'skate', 'offer', 'float', + 'sit by', 'stand', 'release', 'rest', 'singe', 'climb', 'tie', 'mark', + 'lay', 'stand around', 'capture', 'set', 'land', 'swinge', 'run in', + 'kick', 'lean', 'head', 'sign', 'approach', 'swim', 'close', 'crash', + 'control', 'fall', 'remove', 'repair', 'open', 'appear', 'travel', 'load', + 'miss', 'check', 'surf', 'moor', 'smoke', 'drink', 'board', 'seat', 'feed', + 'rise', 'sit on', 'swing', 'grow', 'strike', 'date', 'slide', 'share', + 'graze', 'jump in', 'lie', 'extrude', 'roll', 'move', 'gather', 'eat', + 'pull', 'run through', 'squeeze', 'lay on', 'draw', 'play with', 'wave', + 'assemble', 'perform', 'march', 'score', 'attach', 'adjust', 'hang', 'hug', + 'sleep on', 'throw', 'live in', 'talk', 'pet', 'work', 'run with', 'see', + 'flip', 'catch', 'cook', 'receive', 'celebrate', 'look', 'classic', + 'bridal', 'indoor', 'industrial', 'teenage', 'mini', 'grassy', 'aged', + 'long', 'warm', 'light', 'handsome', 'happy', 'three', 'pregnant', + 'circular', 'urban', 'silver', 'ceramic', '3d', 'green', 'blonde', + 'golden', 'dark', 'tropical', 'ripe', 'deep', 'fat', 'musical', 'giant', + 'medical', 'medieval', 'bare', 'stunning', 'bold', 'geographical', 'huge', + 'plastic', 'foggy', 'stormy', 'gothic', 'biological', 'empty', 'clear', + 'antique', 'pink', 'steep', 'brown', 'striped', 'aerial', 'rainy', 'cool', + 'flying', 'commercial', 'purple', 'trendy', 'blank', 'haired', 'dead', + 'wooden', 'flat', 'high', 'beige', 'panoramic', 'angry', 'dozen', 'rural', + 'solar', 'big', 'small', 'stained', 'thick', 'many', 'fresh', 'clean', + 'strong', 'abstract', 'crowded', 'retro', 'dry', 'gorgeous', 'martial', + 'modern', 'blue', 'cloudy', 'low', 'four', 'outdoor', 'single', 'much', + 
'beautiful', 'snowy', 'pretty', 'new', 'short', 'sunny', 'closed', 'rocky', + 'red', 'two', 'double', 'male', 'gray', 'five', 'colorful', 'automotive', + 'various', 'one', 'old', 'rusty', 'tall', 'wild', 'narrow', 'natural', + 'several', 'frozen', 'textured', 'lush', 'young', 'hot', 'mixed', 'white', + 'float', 'quiet', 'round', 'bright', 'religious', 'female', 'historical', + 'shiny', 'traditional', 'tourist', 'yellow', 'bald', 'coastal', 'lovely', + 'little', 'broken', 'romantic', 'wide', 'royal', 'rich', 'open', 'cute', + 'ancient', 'cold', 'political', 'elderly', 'gold', 'full', 'rustic', + 'metallic', 'floral', 'sad', 'wet', 'fancy', 'senior', 'tiny', 'stylish', + 'large', 'frosty', 'orange', 'transparent', 'electronic', 'shallow', + 'scared', 'armed', 'dirty', 'historic', 'black', 'few', 'windy', 'some', + 'square', 'ornamental', 'sandy', 'thin' +] + +tra_array = np.array(tra_array) diff --git a/projects/videochat/test.py b/projects/videochat/test.py new file mode 100644 index 0000000000..cc53cca1fc --- /dev/null +++ b/projects/videochat/test.py @@ -0,0 +1,593 @@ +import configparser +import json +import os +# from models.qwen import Qwen +import subprocess +# from models.mPlug_owl import Owl +from datetime import datetime + +import numpy as np +import torch +import torchvision.transforms as T +import whisper +from chatbot import ChainOfThought, ConversationBot +from models.grit_model import DenseCaptioning +from models.ocr import ProcessOCR +from models.tag2text import tag2text_caption +from models.transnetv2 import Shot, ShotProcessor +from simplet5 import SimpleT5 +from torchvision import transforms +from util import loadvideo_decord_origin + +from projects.videochat.models.load_internvideo import (load_intern_action, + transform_action) +from projects.videochat.models.subtitle import ProcessSubtitle, RequestApi + +config = configparser.ConfigParser() +config.read('configs.ini') +args = { + 'videos_path': config.get('Arguments', 'videos_path'), + 'openai_api_key': os.environ['OPENAI_API_KEY'], + 'output_path': config.get('Arguments', 'output_path'), + 'images_path': config.get('Arguments', 'images_path'), + 'evaluate_path': config.get('Arguments', 'evaluate_path'), + 'appid': config.get('Arguments', 'appid'), + 'secret_key': config.get('Arguments', 'secret_key'), + 'segment_length': int(config.get('Arguments', 'segment_length')), + 'remarks': config.get('Arguments', 'remarks'), + 'llm': config.get('Arguments', 'llm'), + 'predict': config.get('Arguments', 'predict') == 'True', + 'evaluate': config.get('Arguments', 'evaluate') == 'True', + 'mode': config.get('Arguments', 'mode'), + 'qa_mode': config.get('Arguments', 'qa_mode'), +} +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +torch.cuda.set_device(int(config.get('Arguments', 'device'))) + + +class VideoChat: + + def __init__(self): + self.ocr_model = None + self.whisper_model = None + self.dense_caption_model = None + self.topil = None + self.intern_action = None + self.model_T5 = None + self.model = None + self.trans_action = None + self.shot_model = None + self.transform = None + self.bot = ConversationBot() + + def load_model(self): + image_size = 384 + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), normalize + ]) + + # define model + self.shot_model = Shot('pretrained_models/' + 'transnetv2-pytorch-weights.pth') + + 
self.model = tag2text_caption( + pretrained='pretrained_models/tag2text_swin_14m.pth', + image_size=image_size, + vit='swin_b') + self.model.eval() + self.model = self.model.to(device) + print('[INFO] initialize caption model success!') + + self.model_T5 = SimpleT5() + if torch.cuda.is_available(): + self.model_T5.load_model( + 't5', 'pretrained_models/' + 'flan-t5-large-finetuned-openai-summarize_from_feedback', + use_gpu=True) + else: + self.model_T5.load_model( + 't5', 'pretrained_models/' + 'flan-t5-large-finetuned-openai-summarize_from_feedback', + use_gpu=False) + print('[INFO] initialize summarize model success!') + # action recognition + self.intern_action = load_intern_action( + device, + pretrained='pretrained_models/uniformerv2/' + 'k400+k710_uniformerv2_b16_8x224.pyth') + self.trans_action = transform_action() + self.topil = T.ToPILImage() + print('[INFO] initialize InternVideo model success!') + + self.dense_caption_model = DenseCaptioning(device) + self.dense_caption_model.initialize_model() + print('[INFO] initialize dense caption model success!') + + self.whisper_model = whisper.load_model('large') + print('[INFO] initialize whisper model success!') + + self.ocr_model = ProcessOCR() + print('[INFO] initialize ocr model success!') + + def inference_second(self, video_path, input_tag): + # shot + shot_result = self.shot_model.inference(video_path) + # Whisper + whisper = list() + try: + whisper_result = self.whisper_model.transcribe(video_path) + for segment in whisper_result['segments']: + whisper.append({ + 'begin': int(segment['start']), + 'end': int(segment['end']), + 'text': segment['text'], + }) + except Exception as e: + print(e) + + # 讯飞API + subtitle = list() + try: + api = RequestApi( + appid=args['appid'], + secret_key=args['secret_key'], + upload_file_path=video_path) + + api.get_result() + subtitle = api.result2text(return_list=True) + except Exception as e: + print(e) + try: + data = loadvideo_decord_origin(video_path) + except Exception as e: + print(e) + return None + tmp = [] + for i, img in enumerate(data): + tmp.append(self.transform(img).to(device).unsqueeze(0)) + # dense caption + dense_caption = list() + dense_caption_with_pos = list() + dense_foot = 1 + dense_index = np.arange(0, len(data), dense_foot) + original_images = data[dense_index, :, :, ::-1] + with torch.no_grad(): + for index, original_image in zip(dense_index, original_images): + new_caption_only_name, new_caption = \ + self.dense_caption_model.run_caption_tensor( + original_image, + video_path, + index, + args['images_path'] + ) + dense_caption.append({ + 'begin': index, + 'text': new_caption_only_name, + }) + + dense_caption_with_pos.append({ + 'begin': index, + 'text': new_caption, + }) + # ocr + ocr_result = self.ocr_model.inference(video_path, data) + del data, original_images + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + # Frame Caption + image = torch.cat(tmp).to(device) + + self.model.threshold = 0.68 + if input_tag == '' or input_tag == 'none' or input_tag == 'None': + input_tag_list = None + else: + input_tag_list = [input_tag.replace(',', ' | ')] + with torch.no_grad(): + caption, tag_predict = self.model.generate_sublists( + image, + tag_input=input_tag_list, + max_length=50, + return_tag_predict=True) + frame_caption = list() + for i, j in enumerate(caption): + frame_caption.append({ + 'begin': i, + 'text': j, + }) + tag_1 = set(tag_predict) + synth_caption = self.model_T5.predict('. 
'.join(caption)) + + del image, tmp + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + features = dict() + features['shot'] = shot_result.tolist() + features['subtitle'] = subtitle + features['whisper'] = whisper + features['dense'] = dense_caption + features['dense_with_pos'] = dense_caption_with_pos + features['frame'] = frame_caption + features['synth_caption'] = synth_caption + features['tag'] = list(tag_1) + + # ocr + features['ocr'] = ocr_result + dense_with_ocr, ocr_subtitle = self.ocr_model.merge(features) + features['dense_with_ocr'] = dense_with_ocr + features['ocr_subtitle'] = ocr_subtitle + + return features + + +class InputVideo: + + def __init__(self, videos_path): + self.output_path = None + self.cur_time = None + self.evaluate_path = '' + self.videos_path = videos_path + self.video_chat = VideoChat() + self.videos = list() + self.questions = list() + self.features = list() + self.exist_features = True + self.video_num = 0 + self.owl = None + self.qwen = None + for item in os.listdir(self.videos_path): + try: + num = int(item) + if num > 0: + continue + except ValueError: + continue + json_path = os.path.join(self.videos_path, item, 'data.json') + video_path = os.path.join(self.videos_path, item, 'video.mp4') + if not os.path.exists(video_path): + continue + if not os.path.exists(json_path): + continue + features_path = os.path.join(self.videos_path, item, + 'features.json') + if not os.path.exists(features_path): + self.exist_features = False + self.features.append(features_path) + self.questions.append(json_path) + self.videos.append(video_path) + self.video_num += 1 + print('Test video numbers:', self.video_num) + self.extract_features() + + def extract_features(self): + if not self.exist_features: + self.video_chat.load_model() + for index, video_path in enumerate(self.videos): + if not os.path.exists(self.features[index]): + features = self.video_chat.inference_second(video_path, '') + if features is None: + print(f'Can not extract features from {video_path}') + features = {'frame': ''} + else: + # Merge Whisper and Xunfei + process_subtitle = ProcessSubtitle(features) + features['merged_subtitle'] = \ + process_subtitle.merge_whisper_and_xunfei() + # Shot + shot_processor = ShotProcessor() + features['time_intervals'] = shot_processor.shot( + video_path, features['shot']) + with open(self.features[index], 'w') as file: + json.dump( + features, + file, + indent=4, + ensure_ascii=False, + cls=CustomEncoder) + + def start_test(self): + self.cur_time = datetime.now() + self.output_path = os.path.join(args['output_path'], + self.cur_time.strftime('%Y%m%d%H%M%S')) + for index, video_path in enumerate(self.videos): + if os.path.exists(self.features[index]): + with open(self.features[index], 'r') as file: + data = json.load(file) + features = data + if len(features['frame']) == 0: + continue + else: + print('features.json not found') + continue + if args['mode'] == 'normal': + self.video_chat.bot.init_agent(args['openai_api_key'], + features) + self.qa_test(index, features, video_path) + elif args['mode'] == 'shot': + begin_time = 0 + summary_list = list() + for shot_time in features['time_intervals']: + end_time = shot_time + if end_time == 0: + continue + shot_features = dict() + for feature_type, feature_content in features.items(): + if len(feature_content) == 0 or not isinstance( + feature_content[0], dict) or 'begin' not in \ + feature_content[0]: + continue + shot_features[feature_type] = list() + for item in feature_content: + if begin_time <= item['begin'] < 
end_time: + shot_features[feature_type].append(item) + if len(shot_features['frame']) == 0: + continue + # dense_with_ocr = find_text_in_dense(shot_features) + # shot_features['dense_with_ocr'] = dense_with_ocr + prompt, question = self.video_chat.bot.init_agent_shot( + args['openai_api_key'], shot_features) + summary = self.video_chat.bot.run_text( + question, args['llm'], None, t=1) + summary_list.append({ + 'begin': begin_time, + 'end': end_time, + 'text': summary, + }) + begin_time = shot_time + output = dict() + output['summary'] = summary_list + output['subtitle'] = features['merged_subtitle'] + if self.output_path: + folder = os.path.basename(os.path.dirname(video_path)) + folder_path = os.path.join(self.output_path, folder) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + save_path = os.path.join(folder_path, 'summary.json') + with open(save_path, 'w') as json_file: + json.dump( + summary_list, + json_file, + indent=4, + ensure_ascii=False) + prompt = self.video_chat.bot.init_agent_with_summary( + args['openai_api_key'], output) + self.qa_test(index, prompt, video_path) + + def qa_test(self, index, features, video_path): + output = dict() + qa_output = list() + + with open(self.questions[index]) as file: + data = json.load(file) + output['video_name'] = data['video_name'] + output['test_time'] = self.cur_time.strftime('%Y-%m-%d %H:%M:%S') + output['remarks'] = args['remarks'] + for qa in data['qa']: + answer = qa['a'] + question = qa['q'] + # infer_answer = self.video_chat.bot.run_text(question, + # args['llm'], t=1) + chain_of_thought = ChainOfThought( + self.video_chat.bot.system_prompt) + question_type = chain_of_thought.get_question_type(question) + # if int(question_type['question_type']) == 2: + # if not self.qwen: + # self.qwen = Qwen('/mnt/data.coronaryct.1/ + # ZhuYichen/Qwen-VL/model/Qwen-VL-Chat/') + # self.qwen.init_model() + # image_paths = extract_frames(video_path, + # features['time_intervals']) + # ans_list = list() + # for path in image_paths: + # answer = self.qwen.inference(path, question) + # if '缺失信息' in answer or '无法' in answer or '没有' in answer: + # continue + # ans_list.append(answer) + # infer_answer = ans_list + + if question_type['is_summary']: + infer_answer = chain_of_thought.summarize() + else: + infer_answer = self.video_chat.bot.run_text( + question, args['llm']) + answer_type = chain_of_thought.get_answer_type( + question, infer_answer) + if not answer_type['is_solved']: + # if int(question_type['question_type']) == 2: + # if not self.owl: + # self.owl = Owl( + # '/mnt/data.coronaryct.1/ZhuYichen/ + # mPLUG-Owl/model/ + # mplug-owl-bloomz-7b-multilingual/') + # self.owl.init_model() + # question = '根据用户提出的问题,在视频描述中找到与之对应 + # 的时间点,\n问题:“{}”\n'.format( + # qa['q']) + '以如下的json格式返回{"time":int}。' + # try: + # time_point = int( + # self.video_chat.bot.run_text(question, + # args['llm']).strip('{}').split(':')[ + # 1].strip()) + # image_paths = extract_frames(video_path, + # features['time_intervals']) + # infer_answer = self.owl.inference( + # image_paths, qa['q']) + # print(infer_answer) + # except Exception as e: + # print(e) + # else: + question = '你可以按照以下流程思考:' \ + '1. 找出用户提问的关键词;' \ + '2. 在描述中搜索和关键词相关的描述;' \ + '3. 
根据片段中的内容推理答案。\n' \ + '请你给出你的思考过程:推理' + \ + qa['q'] + infer_answer = self.video_chat.bot.run_text( + question, args['llm']) + if len(infer_answer) == 0: + continue + qa_output.append({ + 'q': qa['q'], + 'a': answer, + 'predict': infer_answer + }) + if len(qa_output) == 0: + return + output['qa'] = qa_output + output['features'] = features + if self.output_path: + folder = os.path.basename(os.path.dirname(video_path)) + folder_path = os.path.join(self.output_path, folder) + if not os.path.exists(folder_path): + os.makedirs(folder_path) + save_path = os.path.join(folder_path, 'output.json') + with open(save_path, 'w') as json_file: + json.dump( + output, + json_file, + indent=4, + ensure_ascii=False, + cls=CustomEncoder) + + def evaluate_by_chatgpt(self): + if args['predict'] and args['evaluate']: + all_predict_results = os.listdir(args['output_path']) + int_list = [int(s) for s in all_predict_results] + self.evaluate_path = os.path.join(args['output_path'], + str(max(int_list))) + elif not args['predict'] and args['evaluate']: + self.evaluate_path = args['evaluate_path'] + + evaluate_result = dict() + total_score = 0 + qa_number = 0 + evaluate_save_path = os.path.join(self.evaluate_path, + 'evaluate_result.json') + answer_list = [] + # num_reasoning = 0 + # num_visual = 0 + # num_other = 0 + # reasoning_score = 0 + # visual_score = 0 + # other_score = 0 + for output_folder in os.listdir(self.evaluate_path): + # if int(output_folder) != 14: + # continue + if output_folder.endswith('.json'): + continue + output_file = os.path.join(self.evaluate_path, output_folder, + 'output.json') + if not os.path.exists(output_file): + continue + with open(output_file) as file: + data = json.load(file) + cur_score = 0 + cur_number = 0 + cur_result = dict() + cur_answer = list() + for qa in data['qa']: + try: + answer = self.video_chat.bot.evaluate_qa( + args['openai_api_key'], str(qa), args['llm']) + total_score += int(answer['score']) + qa_number += 1 + cur_score += int(answer['score']) + cur_number += 1 + cur_answer.append(answer) + + # chain_of_thought = + # ChainOfThought(self.video_chat.bot.system_prompt) + # question_type = + # chain_of_thought.get_question_type(qa['q']) + # if question_type['question_type'] == 1: + # print('reasoning') + # num_reasoning += 1 + # reasoning_score += int(answer['score']) + # elif question_type['question_type'] == 2: + # print('visual') + # num_visual += 1 + # visual_score += int(answer['score']) + # elif question_type['question_type'] == 0: + # print('other') + # num_other += 1 + # other_score += int(answer['score']) + except Exception as e: + print(e) + cur_result['video'] = output_folder + cur_result['mean_score'] = cur_score / cur_number + cur_result['answer'] = cur_answer + answer_list.append(cur_result) + evaluate_result['remarks'] = args['remarks'] + evaluate_result['mean_score'] = total_score / qa_number + evaluate_result['total_number'] = qa_number + # evaluate_result['num_reasoning'] = num_reasoning + # evaluate_result['num_visual'] = num_visual + # evaluate_result['num_other'] = num_other + # evaluate_result['reasoning_score'] = reasoning_score / num_reasoning + # evaluate_result['visual_score'] = visual_score / num_visual + # evaluate_result['other_score'] = other_score / num_other + evaluate_result['answer_list'] = answer_list + with open(evaluate_save_path, 'w') as json_file: + json.dump(evaluate_result, json_file, indent=4, ensure_ascii=False) + + +class CustomEncoder(json.JSONEncoder): + + def default(self, o): + if isinstance(o, np.int64): + return int(o) 
# Convert int64 to Python int + return super().default(o) + + +def extract_frames(video_path, seconds): + folder = os.path.dirname(video_path) + save_path = os.path.join(folder, 'tmp') + if not os.path.exists(save_path): + os.mkdir(save_path) + output_paths = list() + for second in seconds: + output_path = os.path.join(save_path, f'{second}.png') + if os.path.exists(output_path): + output_paths.append(output_path) + continue + ffmpeg_cmd = [ + 'ffmpeg', + '-i', + video_path, + '-ss', + str(second), # 指定要提取的时间点 + '-vframes', + '1', # 仅提取一个帧 + output_path + ] + subprocess.run(ffmpeg_cmd) + output_paths.append(output_path) + return output_paths + + +def get_duration(video_path): + result = subprocess.run([ + 'ffprobe', '-v', 'error', '-show_entries', 'format=duration', '-of', + 'default=noprint_wrappers=1:nokey=1', video_path + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + return float(result.stdout) + + +def main(): + input_video = InputVideo(args['videos_path']) + if args['predict']: + input_video.start_test() + if args['evaluate']: + input_video.evaluate_by_chatgpt() + + +if __name__ == '__main__': + for i in range(1): + main() diff --git a/projects/videochat/transforms.py b/projects/videochat/transforms.py new file mode 100644 index 0000000000..2a6805002e --- /dev/null +++ b/projects/videochat/transforms.py @@ -0,0 +1,450 @@ +import math +import numbers +import random + +import numpy as np +import torch +import torchvision +from PIL import Image, ImageOps + + +class GroupRandomCrop(object): + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img_group): + + w, h = img_group[0].size + th, tw = self.size + + out_images = list() + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + + for img in img_group: + assert (img.size[0] == w and img.size[1] == h) + if w == tw and h == th: + out_images.append(img) + else: + out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) + + return out_images + + +class MultiGroupRandomCrop(object): + + def __init__(self, size, groups=1): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.groups = groups + + def __call__(self, img_group): + + w, h = img_group[0].size + th, tw = self.size + + out_images = list() + + for i in range(self.groups): + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + + for img in img_group: + assert (img.size[0] == w and img.size[1] == h) + if w == tw and h == th: + out_images.append(img) + else: + out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) + + return out_images + + +class GroupCenterCrop(object): + + def __init__(self, size): + self.worker = torchvision.transforms.CenterCrop(size) + + def __call__(self, img_group): + return [self.worker(img) for img in img_group] + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of + 0.5.""" + + def __init__(self, is_flow=False): + self.is_flow = is_flow + + def __call__(self, img_group, is_flow=False): + v = random.random() + if v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if self.is_flow: + for i in range(0, len(ret), 2): + # invert flow pixel values when flipping + ret[i] = ImageOps.invert(ret[i]) + return ret + else: + return img_group + + +class GroupNormalize(object): + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + 
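+        # 'tensor' is the stacked (T*C, H, W) output of Stack + ToTorchFormatTensor;
+        # the 3-element mean/std lists are tiled across the repeated channel
+        # dimension below and each channel slice is normalised in place.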
rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) + rep_std = self.std * (tensor.size()[0] // len(self.std)) + + for t, m, s in zip(tensor, rep_mean, rep_std): + t.sub_(m).div_(s) + + return tensor + + +class GroupScale(object): + """Rescales the input PIL.Image to the given 'size'. 'size' will be the + size of the smaller edge. For example, if height > width, then image will + be. + + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + self.worker = torchvision.transforms.Resize(size, interpolation) + + def __call__(self, img_group): + return [self.worker(img) for img in img_group] + + +class GroupOverSample(object): + + def __init__(self, crop_size, scale_size=None, flip=True): + self.crop_size = crop_size if not isinstance(crop_size, int) else ( + crop_size, crop_size) + + if scale_size is not None: + self.scale_worker = GroupScale(scale_size) + else: + self.scale_worker = None + self.flip = flip + + def __call__(self, img_group): + + if self.scale_worker is not None: + img_group = self.scale_worker(img_group) + + image_w, image_h = img_group[0].size + crop_w, crop_h = self.crop_size + + offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, + crop_w, crop_h) + oversample_group = list() + for o_w, o_h in offsets: + normal_group = list() + flip_group = list() + for i, img in enumerate(img_group): + crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) + normal_group.append(crop) + flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode == 'L' and i % 2 == 0: + flip_group.append(ImageOps.invert(flip_crop)) + else: + flip_group.append(flip_crop) + + oversample_group.extend(normal_group) + if self.flip: + oversample_group.extend(flip_group) + return oversample_group + + +class GroupFullResSample(object): + + def __init__(self, crop_size, scale_size=None, flip=True): + self.crop_size = crop_size if not isinstance(crop_size, int) else ( + crop_size, crop_size) + + if scale_size is not None: + self.scale_worker = GroupScale(scale_size) + else: + self.scale_worker = None + self.flip = flip + + def __call__(self, img_group): + + if self.scale_worker is not None: + img_group = self.scale_worker(img_group) + + image_w, image_h = img_group[0].size + crop_w, crop_h = self.crop_size + + w_step = (image_w - crop_w) // 4 + h_step = (image_h - crop_h) // 4 + + offsets = list() + offsets.append((0 * w_step, 2 * h_step)) # left + offsets.append((4 * w_step, 2 * h_step)) # right + offsets.append((2 * w_step, 2 * h_step)) # center + + oversample_group = list() + for o_w, o_h in offsets: + normal_group = list() + flip_group = list() + for i, img in enumerate(img_group): + crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) + normal_group.append(crop) + if self.flip: + flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode == 'L' and i % 2 == 0: + flip_group.append(ImageOps.invert(flip_crop)) + else: + flip_group.append(flip_crop) + + oversample_group.extend(normal_group) + oversample_group.extend(flip_group) + return oversample_group + + +class GroupMultiScaleCrop(object): + + def __init__(self, + input_size, + scales=None, + max_distort=1, + fix_crop=True, + more_fix_crop=True): + self.scales = scales if scales is not None else [1, .875, .75, .66] + self.max_distort = max_distort + self.fix_crop = fix_crop + self.more_fix_crop = more_fix_crop + self.input_size = input_size if not isinstance(input_size, int) 
else [ + input_size, input_size + ] + self.interpolation = Image.BILINEAR + + def __call__(self, img_group): + + im_size = img_group[0].size + + crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) + crop_img_group = [ + img.crop( + (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) + for img in img_group + ] + ret_img_group = [ + img.resize((self.input_size[0], self.input_size[1]), + self.interpolation) for img in crop_img_group + ] + return ret_img_group + + def _sample_crop_size(self, im_size): + image_w, image_h = im_size[0], im_size[1] + + # find a crop size + base_size = min(image_w, image_h) + crop_sizes = [int(base_size * x) for x in self.scales] + crop_h = [ + self.input_size[1] if abs(x - self.input_size[1]) < 3 else x + for x in crop_sizes + ] + crop_w = [ + self.input_size[0] if abs(x - self.input_size[0]) < 3 else x + for x in crop_sizes + ] + + pairs = [] + for i, h in enumerate(crop_h): + for j, w in enumerate(crop_w): + if abs(i - j) <= self.max_distort: + pairs.append((w, h)) + + crop_pair = random.choice(pairs) + if not self.fix_crop: + w_offset = random.randint(0, image_w - crop_pair[0]) + h_offset = random.randint(0, image_h - crop_pair[1]) + else: + w_offset, h_offset = self._sample_fix_offset( + image_w, image_h, crop_pair[0], crop_pair[1]) + + return crop_pair[0], crop_pair[1], w_offset, h_offset + + def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): + offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, + crop_w, crop_h) + return random.choice(offsets) + + @staticmethod + def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): + w_step = (image_w - crop_w) // 4 + h_step = (image_h - crop_h) // 4 + + ret = list() + ret.append((0, 0)) # upper left + ret.append((4 * w_step, 0)) # upper right + ret.append((0, 4 * h_step)) # lower left + ret.append((4 * w_step, 4 * h_step)) # lower right + ret.append((2 * w_step, 2 * h_step)) # center + + if more_fix_crop: + ret.append((0, 2 * h_step)) # center left + ret.append((4 * w_step, 2 * h_step)) # center right + ret.append((2 * w_step, 4 * h_step)) # lower center + ret.append((2 * w_step, 0 * h_step)) # upper center + + ret.append((1 * w_step, 1 * h_step)) # upper left quarter + ret.append((3 * w_step, 1 * h_step)) # upper right quarter + ret.append((1 * w_step, 3 * h_step)) # lower left quarter + ret.append((3 * w_step, 3 * h_step)) # lower righ quarter + + return ret + + +class GroupRandomSizedCrop(object): + """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the + original size and and a random aspect ratio of 3/4 to 4/3 of the original + aspect ratio This is popularly used to train the Inception networks size: + size of the smaller edge interpolation: Default: + + PIL.Image.BILINEAR + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, img_group): + for attempt in range(10): + area = img_group[0].size[0] * img_group[0].size[1] + target_area = random.uniform(0.08, 1.0) * area + aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if random.random() < 0.5: + w, h = h, w + + if w <= img_group[0].size[0] and h <= img_group[0].size[1]: + x1 = random.randint(0, img_group[0].size[0] - w) + y1 = random.randint(0, img_group[0].size[1] - h) + found = True + break + else: + found = False + x1 = 0 + y1 = 0 + + if found: + out_group = list() + for img in img_group: + img = img.crop((x1, y1, x1 + w, y1 + h)) + assert (img.size == (w, h)) + out_group.append( + img.resize((self.size, self.size), self.interpolation)) + return out_group + else: + # Fallback + scale = GroupScale(self.size, interpolation=self.interpolation) + crop = GroupRandomCrop(self.size) + return crop(scale(img_group)) + + +class ConvertDataFormat(object): + + def __init__(self, model_type): + self.model_type = model_type + + def __call__(self, images): + if self.model_type == '2D': + return images + tc, h, w = images.size() + t = tc // 3 + images = images.view(t, 3, h, w) + images = images.permute(1, 0, 2, 3) + return images + + +class Stack(object): + + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + if img_group[0].mode == 'L': + return np.concatenate([np.expand_dims(x, 2) for x in img_group], + axis=2) + elif img_group[0].mode == 'RGB': + if self.roll: + return np.concatenate( + [np.array(x)[:, :, ::-1] for x in img_group], axis=2) + else: + # print(img_group[0].shape) + return np.concatenate(img_group, axis=2) + + +class ToTorchFormatTensor(object): + """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, + 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]""" + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + return img.float().div(255) if self.div else img.float() + + +class IdentityTransform(object): + + def __call__(self, data): + return data + + +if __name__ == '__main__': + trans = torchvision.transforms.Compose([ + GroupScale(256), + GroupRandomCrop(224), + Stack(), + ToTorchFormatTensor(), + GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]) + ]) + + im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') + + color_group = [im] * 3 + rst = trans(color_group) + + gray_group = [im.convert('L')] * 9 + gray_rst = trans(gray_group) + + trans2 = torchvision.transforms.Compose([ + GroupRandomSizedCrop(256), + Stack(), + ToTorchFormatTensor(), + GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]) + ]) + print(trans2(color_group)) diff --git a/projects/videochat/util.py b/projects/videochat/util.py new file mode 100644 index 0000000000..311629fdd9 --- /dev/null +++ b/projects/videochat/util.py @@ -0,0 +1,98 @@ +import numpy as np +from decord import VideoReader, cpu + + +def loadvideo_decord(sample, + sample_rate_scale=1, + new_width=384, + new_height=384, + clip_len=8, + frame_sample_rate=2, + num_segment=1): + fname = sample + vr = VideoReader( + fname, width=new_width, height=new_height, num_threads=1, ctx=cpu(0)) + # handle temporal segments + # converted_len = 
int(clip_len * frame_sample_rate)
+    seg_len = len(vr) // num_segment
+    duration = max(len(vr) // vr.get_avg_fps(), 8)
+
+    all_index = []
+    for i in range(num_segment):
+        index = np.linspace(0, seg_len, num=int(duration))
+        index = np.clip(index, 0, seg_len - 1).astype(np.int64)
+        index = index + i * seg_len
+        all_index.extend(list(index))
+
+    all_index = all_index[::int(sample_rate_scale)]
+    vr.seek(0)
+    buffer = vr.get_batch(all_index).asnumpy()
+    return buffer
+
+
+def loadvideo_decord_origin(sample,
+                            sample_rate_scale=1,
+                            new_width=384,
+                            new_height=384,
+                            clip_len=8,
+                            frame_sample_rate=2,
+                            num_segment=1):
+    fname = sample
+    vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+    # handle temporal segments
+    # converted_len = int(clip_len * frame_sample_rate)
+    seg_len = len(vr) // num_segment
+    duration = max(len(vr) // vr.get_avg_fps(), 8)
+
+    all_index = []
+    for i in range(num_segment):
+        index = np.linspace(0, seg_len, num=int(duration))
+        index = np.clip(index, 0, seg_len - 1).astype(np.int64)
+        index = index + i * seg_len
+        all_index.extend(list(index))
+
+    all_index = all_index[::int(sample_rate_scale)]
+    vr.seek(0)
+    buffer = vr.get_batch(all_index).asnumpy()
+    return buffer
+
+
+def loadvideo_decord_time_segment(sample,
+                                  segment_length=10,
+                                  sample_rate_scale=1,
+                                  new_width=384,
+                                  new_height=384,
+                                  frame_sample_rate=2):
+    fname = sample
+    vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
+    total_duration = len(vr) // vr.get_avg_fps()  # total video duration (seconds)
+    num_segment = int(total_duration // segment_length)  # number of full segments
+    remaining_time = total_duration % segment_length  # leftover time (seconds)
+
+    all_index = []
+    buffer = []
+    time_index = []
+    for i in range(num_segment):
+        start_time = i * segment_length
+        end_time = start_time + segment_length
+        index = np.arange(
+            start_time * vr.get_avg_fps(),
+            end_time * vr.get_avg_fps(),
+            step=int(sample_rate_scale))
+        all_index.append(list(index))
+        vr.seek(0)
+        buffer.append(vr.get_batch(list(index)).asnumpy())
+        time_index.append(i * segment_length)
+        time_index.append(i * segment_length)
+
+    if remaining_time > 0:  # treat any leftover time as one extra segment
+        start_time = num_segment * segment_length
+        index = np.arange(
+            start_time * vr.get_avg_fps(),
+            len(vr),
+            step=int(sample_rate_scale))
+        all_index.append(list(index))
+        vr.seek(0)
+        buffer.append(vr.get_batch(list(index)).asnumpy())
+
+    return buffer, time_index
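+
+
+if __name__ == '__main__':
+    # Minimal usage sketch, mirroring the self-test block at the end of
+    # transforms.py. 'video.mp4' is a placeholder path, not a file that ships
+    # with the project; point it at any local clip to smoke-test the loaders.
+    demo_video = 'video.mp4'
+    frames = loadvideo_decord_origin(demo_video)
+    print('uniformly sampled frames:', frames.shape)  # (num_frames, H, W, 3)
+    segments, starts = loadvideo_decord_time_segment(
+        demo_video, segment_length=10)
+    print('segments:', len(segments), 'segment start times (s):', starts)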