رفتن به نوشته‌ها

tf.hub و دسترسی به مدل‌های از پیش آموزش داده شده

بنا به گفته سایت رسمی تنسورفلو:

TensorFlow Hub is a library for reusable machine learning modules

که خب تقریبا به طور کامل تنسورفلو هاب رو تعریف کرده است. تنسورفلو هاب جایی است که محققان یا توسعه‌دهنده‌ها مدل‌های یادگیری ماشین اشون رو آپلود میکنن تا بقیه بتونن به راحتی ازش استفاده کنند. این مدل‌ها در مواقعی که میخواهیم انتقال یادگیری (Transfer learning ) انجام بدیم یا از مدلهای قبلی برای استخراج ویژگی استفاده کنیم میتونه بسیار مفید باشه. مدل‌های مختلفی در حوزه بینایی کامپیوتر،‌ متن و … آموزش داده شده است و در تنسورفلو هاب قرار داده شده است که میتونید از آنها استفاده کنید.

در این پست میخواهیم با استفاده از این ماژول (تنسورفلوهاب) یک مدل ساده برای تشخیص اسپم در Kaggle پیاده‌سازی کنیم که دقت نسبتا قابل قبولی رو میده (۹۸٪-۹۹٪).

در قدم اول دیتاست Kaggle رو از اینجا دانلود کنید و روت پروژه قرار بدید ( یادتون نره از حالت فشرده zip خارج کنید).

در قدم بعدی باید تنسورفلو و تنسورفلوهاب رو نصب کنیم:

برای اینکه یک مدل از پیش آموزش داده شده رو لوود کنید یا باید خودتون قبل مدل رو روی tfhub.dev آپلود کرده باشید یا از مدل‌هایی که قبلا آپلود شده استفاده کنید. در این پست من از مدل Universial Setence Encoder گوگل استفاده میکنم.

در این مقاله دو مدل عمیق برای زبان انگلیسی آموزش داده شده است. یک مدل بر پایه مدل Transformer و دیگری بر اساس مدل DAN است. مدل Transformer پیچیدگی و زمان اجرای بیشتری دارد ولی دقت و عملکرد بهتری به ارمغان میآورد. این دو مدل تقریبا همه منظوره آموزش داده شده‌اند (طبق گفته مقاله روی داده‌های ویکی‌پدیا، اخبار و… آموزش داده شده است) و بخاطر همین میتونید تقریبا در هر تسکی به عنوان یک بازنمایی (representation) مناسب برای جملاتتون ازشون استفاده کنید.

در مقاله نحوه استفاده از مدل در یک تیکه کد آورده شده است که خب ماهم از همین استفاده میکنیم.

که این رو اگر بخواهیم در پایتون پیاده‌سازی کنیم چیزی شبیه زیر خواهد شد:

که جمله ورودی (The quick brown fox jumps over the lazy dogs.) را به یک بردار با اندازه 500 تبدیل میکند که از این بردار میتونید در تسک‌های مختلف استفاده کنید.

خب بیایید به پیاده‌سازی تشخیص اسپم بپردازیم. ابتدا کمی دیتاست رو پیش‌پردازش میکنیم تا به اون چیزی که میخواهیم تبدیل شود. اگر فایل spam.csv رو مشاهده کنید، میبینید که چندتا ستون اضافی داره ولی اون چیزی که ما برای آموزش مدل لازم داریم ستون‌های اول و دوم (v1 , v2) است. این ستون‌ها اسمشون رو تغییر میدهیم. همچنین اگر دقت کنید برچسب‌ها به صورت spam/ham است که اینها رو هم تبدیل به صفر و یک میکنیم و در یک ستون جدید قرار میدهیم. برای اینکه بتونیم ارزیابی منصفانه‌ای داشته باشیم دیتامون رو به آموزش (۵۰۰۰ نمونه) و آزمون (۵۵۷ نمونه) تقسیم میکنیم. پس به طور کلی چیزی مشابه زیر خواهیم داشت:

مرحله بعدی به تعریف مدل میرسه. مدل‌مون رو با استفاده از کراس پیاده‌سازی میکنیم. همچنین برای اینکه بتونیم از tfhub استفاده کنیم باید یک لایه Lambda استفاده تا بتونیم جملات ورودی رو به بردارهای نهفته (sentence vector embedding) تبدیل کنیم. پس از اینکه بردارهای نهفته رو بدست آوردیم جملاتمون رو به یک MLP classifier با سایز ۲۵۶ و ۲ میدهیم تا دسته‌بندی درست را برای ما تشخیص دهد:

تنها نکته‌ای که احتمالا در بالا مشاهده میکنید این است که برخلاف اکثرا مدل‌هایی که تا به حال پیاده‌سازی کردید در اینجا به جای اینکه ورودی یک توالی از اعداد باشه (sequence of integers) یک رشته (string) است. دلیل این موضوع هم این است که Universal Encoder اینجوری پیاده‌سازی شده است که یک رشته به عنوان ورودی میگیرد.

مرحله آخر نوبت به آموزش و ارزیابی مدل میرسه. در حال حاضر tfhub نمیشه مستقیم از session پیش فرض استفاده کرد و باید براش یک session تعریف کنید. پس از اینکه session رو تعریف کردیم، مدل رو آموزش دادیم و در نهایت با کمک تابع predict روی داده تست ارزیابی رو انجام میدهیم. در پایان هم ماتریس درهمریختگی رو براش میکشیم که دقتی در حدود ۹۸-۹۹ به ما خواهد داد:

کد کامل این پیاده‌سازی رو میتونید در اینجا ملاحظه کنید.

به طور کلی TfHub میتونه در بسیاری از موارد که یک بازنمایی سریع از جملات/تصاویر یا … میخواهید بهتون کمک کنه. دیگه لازم نیست یک مدل خیلی بزرگ رو از ابتدا آموزش بدید. در اینجا میتونید برخی مدل‌های دیگر رو مشاهده کنید.

منتشر شده در آموزشتنسورفلو

اولین باشید که نظر می دهید

دیدگاهتان را بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *